├── pyproject.toml
├── pytest.ini
├── nebula
├── __init__.py
├── evaluators
│ ├── __init__.py
│ └── bspline.py
├── topology
│ ├── __init__.py
│ ├── topology.py
│ ├── shells.py
│ ├── solids.py
│ ├── faces.py
│ ├── edges.py
│ └── wires.py
├── prim
│ ├── __init__.py
│ ├── bspline_surfaces.py
│ ├── sparse.py
│ ├── lines.py
│ ├── axes.py
│ └── bspline_curves.py
├── helpers
│ ├── types.py
│ ├── vector.py
│ ├── intersection.py
│ ├── wire.py
│ └── clipper.py
├── tools
│ ├── edge.py
│ └── solid.py
├── render
│ ├── visualization.py
│ └── tesselation.py
├── cases
│ └── airfoil.py
└── workplane.py
├── assets
├── example1_width.png
└── example1_arc_height.png
├── requirements_dev.txt
├── .pylintrc
├── .github
└── dependabot.yml
├── .vscode
├── launch.json
└── settings.json
├── .gitpod.yml
├── setup.py
├── .devcontainer
└── devcontainer.json
├── tests
├── test_intersection.py
└── test_bspline.py
├── LICENSE
├── TODO
├── README.md
└── .gitignore
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 90
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | pythonpath = . nebula
--------------------------------------------------------------------------------
/nebula/__init__.py:
--------------------------------------------------------------------------------
1 | from .workplane import Workplane
--------------------------------------------------------------------------------
/nebula/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | from .bspline import BSplineEvaluator
--------------------------------------------------------------------------------
/assets/example1_width.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenOrion/nebula/HEAD/assets/example1_width.png
--------------------------------------------------------------------------------
/assets/example1_arc_height.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenOrion/nebula/HEAD/assets/example1_arc_height.png
--------------------------------------------------------------------------------
/nebula/topology/__init__.py:
--------------------------------------------------------------------------------
1 | from .edges import Edges
2 | from .wires import Wires
3 | from .faces import Faces
4 | from .shells import Shells
5 | from .solids import Solids
--------------------------------------------------------------------------------
/nebula/prim/__init__.py:
--------------------------------------------------------------------------------
1 | from .sparse import SparseArray, SparseIndexable
2 | from .bspline_curves import BSplineCurves
3 | from .lines import Lines
4 | from .axes import Axes
5 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | plotly==5.11.0
2 | ipywidgets==7.6
3 | jupyterlab
4 | matplotlib
5 | pytest
6 |
7 | numpy
8 | ipython_genutils
9 | jupyter_cadquery
10 | jax[cpu]
11 | jax_dataclasses
12 | kaleido
--------------------------------------------------------------------------------
/nebula/topology/topology.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, Union
2 | from nebula.prim.sparse import SparseIndexable
3 |
class Topology(SparseIndexable, Protocol):
    """Structural interface for index-carrying containers that can be combined with ``add``."""

    def add(self, topology: "Topology", reorder_index: bool = False):
        """Combine ``topology`` with this one.

        The implementations visible in this package (``Shells``, ``Solids``,
        ``Faces``) return a new container and forward ``reorder_index`` to
        ``add_indices``.
        """
        ...
7 |
--------------------------------------------------------------------------------
/nebula/helpers/types.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import jax.numpy as jnp
3 |
# A scalar: either a Python float or a JAX array.
Number = Union[float, jnp.ndarray]
# Coordinates: a list of 3-tuples, a list of 2-tuples, or a JAX array.
CoordLike = Union[
    list[tuple[Number, Number, Number]], list[tuple[Number, Number]], jnp.ndarray
]
# A flat sequence of values: a list of floats or a JAX array.
ArrayLike = Union[list[float], jnp.ndarray]
9 |
--------------------------------------------------------------------------------
/nebula/helpers/vector.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 |
class VectorHelper:
    """Small array utilities."""

    @staticmethod
    def normalize(arr: jnp.ndarray) -> jnp.ndarray:
        """
        Min-max rescales an array so that its values span ``[0, 1]``.

        :param arr: values to rescale
        :type arr: jnp.ndarray
        :return: ``(arr - min) / (max - min)``; all zeros when every value is equal
        :rtype: jnp.ndarray
        """
        arr_min = jnp.min(arr)
        span = jnp.max(arr) - arr_min
        # A constant array has zero span; dividing by it would produce NaN.
        # Substituting 1 leaves the numerator (all zeros) as the result.
        safe_span = jnp.where(span == 0, 1, span)
        return (arr - arr_min) / safe_span
10 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 |
2 | [MESSAGES CONTROL]
3 | disable=import-error,
4 | invalid-name,
5 | non-ascii-name,
6 | missing-function-docstring,
7 | missing-module-docstring,
8 | missing-class-docstring,
9 | line-too-long,
10 | dangerous-default-value,
11 | line-too-long,
12 | unnecessary-lambda
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for more information:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 | # https://containers.dev/guide/dependabot
6 |
7 | version: 2
8 | updates:
9 | - package-ecosystem: "devcontainers"
10 | directory: "/"
11 | schedule:
12 | interval: weekly
13 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: Current File",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal",
13 | "justMyCode": true
14 | }
15 | ]
16 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.defaultInterpreterPath": "/opt/conda/bin/python",
3 | "jupyter.notebookFileRoot": "${workspaceFolder}",
4 | "python.analysis.typeCheckingMode": "basic",
5 | "python.analysis.inlayHints.variableTypes": true,
6 | "python.analysis.inlayHints.functionReturnTypes": true,
7 | "python.analysis.inlayHints.pytestParameters": true,
8 | "python.analysis.inlayHints.callArgumentNames": "partial",
9 | "python.testing.pytestArgs": [
10 | "tests"
11 | ],
12 | "python.testing.unittestEnabled": false,
13 | "python.testing.pytestEnabled": true,
14 | }
--------------------------------------------------------------------------------
/.gitpod.yml:
--------------------------------------------------------------------------------
1 | # This configuration file was automatically generated by Gitpod.
2 | # Please adjust to your needs (see https://www.gitpod.io/docs/config-gitpod-file)
3 | # and commit this file to your remote git repository to share the goodness with others.
4 | image:
5 | file: .devcontainer/Dockerfile
6 |
7 | github:
8 | prebuilds:
9 | # enable for the master/default branch (defaults to true)
10 | master: true
11 | # enable for pull requests coming from this repo (defaults to true)
12 | pullRequests: false
13 | # add a "Review in Gitpod" button as a comment to pull requests (defaults to true)
14 | addComment: false
15 |
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
# Distribution metadata for the nebula package.
setup(
    name="nebula",
    version="1.0.0",
    description="auto differentiable CAD library in JAX",
    author="Afshawn Lotfi",
    author_email="",
    # Sub-packages are listed explicitly rather than discovered automatically.
    packages=[
        "nebula",
        "nebula.cases",
        "nebula.prim",
        "nebula.evaluators",
        "nebula.helpers",
        "nebula.render",
        "nebula.tools",
        "nebula.topology",
    ],
    # Same entries as requirements_dev.txt, except that pytest is not listed here.
    install_requires=[
        "plotly==5.11.0",
        "ipywidgets==7.6",
        "jupyterlab",
        "matplotlib",
        "numpy",
        "ipython_genutils",
        "jupyter_cadquery",
        "jax[cpu]",
        "jax_dataclasses",
        "kaleido"
    ],
)
32 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the
2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python
3 | {
4 | "name": "Python 3",
5 | // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
6 | "image": "mcr.microsoft.com/devcontainers/python:1-3.12-bullseye"
7 |
8 | // Features to add to the dev container. More info: https://containers.dev/features.
9 | // "features": {},
10 |
11 | // Use 'forwardPorts' to make a list of ports inside the container available locally.
12 | // "forwardPorts": [],
13 |
14 | // Use 'postCreateCommand' to run commands after the container is created.
15 | // "postCreateCommand": "pip3 install --user -r requirements.txt",
16 |
17 | // Configure tool-specific properties.
18 | // "customizations": {},
19 |
20 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
21 | // "remoteUser": "root"
22 | }
23 |
--------------------------------------------------------------------------------
/tests/test_intersection.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from pytest import mark
3 | from nebula.helpers.intersection import Intersection
4 |
5 |
@mark.parametrize(
    "line1, line2, expected",
    [
        (
            jnp.array([[50, 50], [163, 215]]),
            jnp.array([[50, 250], [170, 170]]),
            jnp.array([144.036, 187.309, 0]),
        ),
    ],
)
def test_intersection(line1: jnp.ndarray, line2: jnp.ndarray, expected: jnp.ndarray):
    """Two crossing segments report the expected intersection point."""
    result = Intersection.line_segment_intersection(line1, line2)
    vertices = result.intersected_vertices
    assert jnp.allclose(vertices, expected, atol=1e-3)
19 |
20 |
@mark.parametrize(
    "line1, line2",
    [
        (
            jnp.array([[50, 50], [205, 172]]),
            jnp.array([[50, 250], [170, 170]]),
        ),
    ],
)
def test_no_intersect(line1: jnp.ndarray, line2: jnp.ndarray):
    """Segments that do not cross report an all-NaN intersection point."""
    result = Intersection.line_segment_intersection(line1, line2)
    vertices = result.intersected_vertices
    assert jnp.isnan(vertices).all()
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Open Orion, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/nebula/topology/shells.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import jax_dataclasses as jdc
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import numpy.typing as npt
6 | from nebula.topology.topology import Topology
7 | from nebula.topology.faces import Faces
8 |
9 |
@jdc.pytree_dataclass
class Shells(Topology):
    """Faces together with a static ``index`` that assigns entries to shells."""

    faces: Faces
    index: jdc.Static[npt.NDArray[np.int32]]

    def __repr__(self) -> str:
        return f"Shells(count={self.count}, faces={self.faces})"

    def add(self, shells: "Shells", reorder_index: bool = False):
        """Append other shells after the ones stored here

        :param shells: Shells to append
        :type shells: Shells
        :param reorder_index: Forwarded to ``add_indices`` for the incoming index
        :type reorder_index: bool, optional
        :return: New shells holding both sets
        :rtype: Shells
        """
        combined_faces = self.faces.add(shells.faces)
        combined_index = self.add_indices(shells.index, reorder_index)
        return Shells(faces=combined_faces, index=combined_index)

    @staticmethod
    def empty():
        """Create shells with no faces and a zero-length index."""
        no_index = np.zeros((0,), dtype=np.int32)
        return Shells(faces=Faces.empty(), index=no_index)

    @staticmethod
    def from_faces(faces: Faces):
        """Put every entry of ``faces`` into a single shell (index 0)."""
        single_shell = np.zeros(len(faces.index), dtype=np.int32)
        return Shells(faces=faces, index=single_shell)

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Select shells with a mask

        :param mask: Mask applied to both the faces and the shell index
        :type mask: Union[jnp.ndarray, np.ndarray]
        :return: Masked shells with a re-numbered index
        :rtype: Shells
        """
        kept_faces = self.faces.mask(mask)
        kept_index = Topology.reorder_index(self.index[mask])
        return Shells(faces=kept_faces, index=kept_index)
43 |
--------------------------------------------------------------------------------
/nebula/topology/solids.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import jax_dataclasses as jdc
3 |
4 | import numpy as np
5 | import numpy.typing as npt
6 | from nebula.topology.topology import Topology
7 | from nebula.topology.shells import Shells
8 | import jax.numpy as jnp
9 |
10 |
@jdc.pytree_dataclass
class Solids(Topology):
    """Shells together with a static ``index`` that assigns entries to solids."""

    shells: Shells
    index: jdc.Static[npt.NDArray[np.int32]]

    @staticmethod
    def empty():
        """Create solids with no shells and a zero-length index."""
        no_index = np.zeros((0,), dtype=np.int32)
        return Solids(shells=Shells.empty(), index=no_index)

    @staticmethod
    def from_shells(shells: Shells):
        """Put every entry of ``shells`` into a single solid (index 0)."""
        single_solid = np.zeros(len(shells.index), dtype=np.int32)
        return Solids(shells=shells, index=single_solid)

    def add(self, solids: "Solids", reorder_index: bool = False):
        """Append other solids after the ones stored here

        :param solids: Solids to append
        :type solids: Solids
        :param reorder_index: Forwarded to ``add_indices`` for the incoming index
        :type reorder_index: bool, optional
        :return: New solids holding both sets
        :rtype: Solids
        """
        combined_shells = self.shells.add(solids.shells)
        combined_index = self.add_indices(solids.index, reorder_index)
        return Solids(shells=combined_shells, index=combined_index)

    def __repr__(self) -> str:
        return f"Solids(count={self.count}, shells={self.shells})"

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Select solids with a mask

        :param mask: Mask applied to both the shells and the solid index
        :type mask: Union[jnp.ndarray, np.ndarray]
        :return: Masked solids with a re-numbered index
        :rtype: Solids
        """
        kept_shells = self.shells.mask(mask)
        kept_index = Topology.reorder_index(self.index[mask])
        return Solids(shells=kept_shells, index=kept_index)
44 |
--------------------------------------------------------------------------------
/nebula/prim/bspline_surfaces.py:
--------------------------------------------------------------------------------
1 | import jax_dataclasses as jdc
2 | import jax.numpy as jnp
3 | from nebula.evaluators.bspline import BSplineEvaluator
4 | from nebula.helpers.types import Number
5 | from nebula.topology.wires import Wires
6 |
@jdc.pytree_dataclass
class BSplineSurfaces:
    """B-spline surfaces derived from the non-planar wires of ``all_wires``."""

    all_wires: Wires

    def evaluate(self, u: jnp.ndarray, v: jnp.ndarray, degree: Number = 1):
        """
        Evaluates every B-spline surface at the given parameters.

        :param u: parameter values passed to ``BSplineEvaluator.eval_surface`` as ``u``
        :type u: jnp.ndarray
        :param v: parameter values passed to ``BSplineEvaluator.eval_surface`` as ``v``
        :type v: jnp.ndarray
        :param degree: degree used in the ``u`` direction, defaults to 1
        :type degree: Number, optional
        :return: surface points of all wires stacked into one ``(n, 3)`` array
        :rtype: jnp.ndarray
        """
        bspline_wires = self.wires
        bspline_curves = bspline_wires.edges.curves
        num_edges = bspline_wires.get_num_edges()

        # Collect per-wire results and concatenate once instead of re-copying
        # the accumulated array on every iteration. The empty (0, 3) seed keeps
        # the result shape and dtype promotion identical when there are no wires.
        vertex_blocks = [jnp.empty((0, 3))]
        for i in range(bspline_wires.count):
            curve = bspline_curves.mask(bspline_wires.index == i)
            # TODO: this is only if all curves in group are the same length
            ctrl_pnts = curve.ctrl_pnts.val.reshape(curve.count, -1, 3)
            v_degree = curve.degree[0]
            v_knots = curve.knots.val[curve.knots.index == 0]
            u_degree = degree
            u_knots = BSplineEvaluator.generate_clamped_knots(u_degree, num_edges[i])
            surface_pnts = BSplineEvaluator.eval_surface(
                u_degree, v_degree, ctrl_pnts, u_knots, v_knots, u, v
            )
            vertex_blocks.append(surface_pnts.reshape(-1, 3))
        return jnp.concatenate(vertex_blocks)

    def __repr__(self) -> str:
        return f"BSplineSurfaces(count={self.count})"

    @property
    def wires(self):
        """Wires of ``all_wires`` that are not planar (recomputed on every access)."""
        return self.all_wires.mask(~self.all_wires.is_planar)

    @property
    def count(self) -> int:
        """Number of non-planar wires."""
        return self.wires.count
44 |
45 |
46 | # TODO: switch to this later
47 | # @property
48 | # def count(self) -> int:
49 | # index = SparseIndexable.reorder_index(self.all_wires.index[~self.all_wires.is_planar])
50 | # return SparseIndexable.get_count(index)
51 |
--------------------------------------------------------------------------------
/nebula/topology/faces.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Union
3 | import jax_dataclasses as jdc
4 | import numpy as np
5 | import numpy.typing as npt
6 | import jax.numpy as jnp
7 | from nebula.prim.bspline_curves import BSplineCurves
8 | from nebula.prim.bspline_surfaces import BSplineSurfaces
9 | from nebula.prim.axes import Axes
10 | from nebula.topology.topology import Topology
11 | from nebula.topology.wires import WireLike, Wires
12 |
13 |
@jdc.pytree_dataclass
class Faces(WireLike):
    """Wires together with a static ``index`` that assigns entries to faces."""

    wires: Wires
    index: jdc.Static[npt.NDArray[np.int32]]

    @property
    def bspline_surfaces(self):
        """B-spline surfaces built over all wires of these faces."""
        return BSplineSurfaces(all_wires=self.wires)

    @property
    def edges(self):
        """Edges of the underlying wires."""
        return self.wires.edges

    def add(self, faces: "Faces", reorder_index: bool = False):
        """Append other faces after the ones stored here

        :param faces: Faces to append
        :type faces: Faces
        :param reorder_index: Forwarded to ``add_indices`` for the incoming index
        :type reorder_index: bool, optional
        :return: New faces holding both sets
        :rtype: Faces
        """
        combined_wires = self.wires.add(faces.wires)
        combined_index = self.add_indices(faces.index, reorder_index)
        return Faces(wires=combined_wires, index=combined_index)

    @staticmethod
    def empty():
        """Create faces with no wires and a zero-length index."""
        no_index = np.zeros((0,), dtype=np.int32)
        return Faces(wires=Wires.empty(), index=no_index)

    @staticmethod
    def from_wires(wires: Wires, is_exterior: bool = False):
        """
        Converts Wires into Faces.

        :param wires: Wires to convert to faces
        :type wires: Wires
        :param is_exterior: reuse the wire index as the face index; otherwise all wires form one face
        :type is_exterior: bool, optional
        :return: Faces wrapping the given wires
        :rtype: Faces
        """
        if not is_exterior:
            single_face = np.zeros(len(wires.index), dtype=np.int32)
            return Faces(wires=wires, index=single_face)
        return Faces(wires=wires, index=wires.index)

    def translate(self, translation: jnp.ndarray):
        """Translate the underlying wires; the face index is kept."""
        moved_wires = self.wires.translate(translation)
        return Faces(wires=moved_wires, index=self.index)

    def flip_winding(self):
        """Flip the winding of the wires and reverse the face index."""
        flipped_wires = self.wires.flip_winding()
        return Faces(wires=flipped_wires, index=self.index[::-1])

    def project(self, new_axes: Axes):
        """Project the underlying wires onto ``new_axes``; the face index is kept."""
        projected_wires = self.wires.project(new_axes)
        return Faces(wires=projected_wires, index=self.index)

    def __repr__(self) -> str:
        return f"Faces(count={self.count}, wires={self.wires})"

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Select faces with a mask

        :param mask: Mask applied to both the wires and the face index
        :type mask: Union[jnp.ndarray, np.ndarray]
        :return: Masked faces with a re-numbered index
        :rtype: Faces
        """
        kept_wires = self.wires.mask(mask)
        kept_index = Topology.reorder_index(self.index[mask])
        return Faces(wires=kept_wires, index=kept_index)
74 |
75 |
--------------------------------------------------------------------------------
/nebula/helpers/intersection.py:
--------------------------------------------------------------------------------
1 | import jax_dataclasses as jdc
2 | import jax
3 | import jax.numpy as jnp
4 |
5 |
# Result container returned by Intersection.line_segment_intersection.
@jdc.pytree_dataclass
class IntersectionResult:
    # Intersection point as (x, y, 0); filled with NaN when there is no intersection.
    intersected_vertices: jnp.ndarray
    # Boolean flag: True when the segments intersect.
    is_intersection: jnp.ndarray
    # 1 when the segments intersect, otherwise 0.
    num_intersections: jnp.ndarray
11 |
12 |
class Intersection:
    """Line-segment intersection helpers; only the x and y components are used."""

    @staticmethod
    @jax.jit
    def on_line(point: jnp.ndarray, line: jnp.ndarray):
        """
        Determines if a point lies within the x/y bounding box of a line segment.

        This is a bounds check only: it is meant for points already known to be
        on the infinite line through the segment.

        :param point: point to check; only the first two components are read
        :type point: jnp.ndarray
        :param line: start and endpoint of line; only the first two columns are read
        :type line: jnp.ndarray
        :return: True if point is within the segment's bounds, False otherwise
        :rtype: jnp.ndarray
        """
        line_x1, line_y1, line_x2, line_y2 = line[:, :2].flatten()

        return jnp.all(
            jnp.array(
                [
                    jnp.minimum(line_x1, line_x2) <= point[0],
                    jnp.minimum(line_y1, line_y2) <= point[1],
                    jnp.maximum(line_x1, line_x2) >= point[0],
                    jnp.maximum(line_y1, line_y2) >= point[1],
                ]
            )
        )

    def line_segment_intersection(
        line1: jnp.ndarray, line2: jnp.ndarray, decimal_accuracy=5
    ):
        """
        Finds the intersection of two line segments given the endpoints of each
        line. Only the x and y components are used; the returned z is 0.
        https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection

        :param line1: start and endpoint of line 1; only the first two columns are read
        :type line1: jnp.ndarray
        :param line2: start and endpoint of line 2; only the first two columns are read
        :type line2: jnp.ndarray
        :param decimal_accuracy: decimals used when rounding for the on-segment check, defaults to 5
        :type decimal_accuracy: int, optional
        :return: intersection result; the point is NaN if the lines are parallel
            or the crossing point lies outside either segment
        :rtype: IntersectionResult
        """
        x1, y1, x2, y2 = line1[:, :2].flatten()
        x3, y3, x4, y4 = line2[:, :2].flatten()

        nx = (x1 * y2 - y1 * x2) * (x3 - x4) - (x1 - x2) * (x3 * y4 - y3 * x4)
        ny = (x1 * y2 - y1 * x2) * (y3 - y4) - (y1 - y2) * (x3 * y4 - y3 * x4)
        denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)

        # Crossing point of the two infinite lines (inf/NaN when denom is 0;
        # that case is masked out below).
        unconstrained_intersection_vertices = jnp.array([nx / denom, ny / denom, 0])

        # Round both the point and the segments so that floating point noise
        # does not push an endpoint hit just outside the bounds.
        rounded_unconstrained_intersection_vertices = jnp.round(
            unconstrained_intersection_vertices, decimal_accuracy
        )
        intersection_on_line1 = Intersection.on_line(
            rounded_unconstrained_intersection_vertices,
            jnp.round(line1, decimal_accuracy),
        )
        intersection_on_line2 = Intersection.on_line(
            rounded_unconstrained_intersection_vertices,
            jnp.round(line2, decimal_accuracy),
        )
        is_intersection = (denom != 0) & intersection_on_line1 & intersection_on_line2

        nan_array = jnp.full((3,), jnp.nan)

        intersection_vertices = jnp.where(
            is_intersection, unconstrained_intersection_vertices, nan_array
        )
        return IntersectionResult(
            intersected_vertices=intersection_vertices,
            is_intersection=is_intersection,
            num_intersections=jnp.where(is_intersection, 1, 0),
        )

    # ``decimal_accuracy`` is handed to ``jnp.round``, which needs a concrete
    # Python int. Marking it static lets callers pass a value explicitly;
    # a plain ``@jax.jit`` would trace it and fail.
    line_segment_intersection = staticmethod(
        jax.jit(line_segment_intersection, static_argnames=("decimal_accuracy",))
    )
88 |
--------------------------------------------------------------------------------
/nebula/prim/sparse.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Protocol, Union
2 | import jax_dataclasses as jdc
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import numpy.typing as npt
6 |
7 |
class SparseIndexable(Protocol):
    """Mixin protocol for containers that carry a static integer ``index``."""

    index: jdc.Static[npt.NDArray[np.int32]]

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]) -> "SparseIndexable": ...

    @property
    def last_index(self):
        """Largest value in ``index``, or 0 when the index is empty."""
        if not len(self.index):
            return 0
        return np.max(self.index).astype(np.int32)

    @property
    def count(self):
        """Number of groups, i.e. largest index value plus one (0 when empty)."""
        return SparseIndexable.get_count(self.index)

    def add_indices(self, indices: np.ndarray, reorder_index: bool = False):
        """Concatenate ``indices`` after ``self.index``, offset by the current count.

        :param indices: index values to append
        :type indices: np.ndarray
        :param reorder_index: re-number ``indices`` with ``reorder_index`` before offsetting
        :type reorder_index: bool, optional
        :return: combined index
        :rtype: np.ndarray
        """
        if reorder_index:
            indices = SparseIndexable.reorder_index(indices)

        shifted = indices + self.count
        return np.concatenate([self.index, shifted])

    @staticmethod
    def get_count(index: np.ndarray):
        """Largest value of ``index`` plus one, or 0 for an empty index."""
        if not len(index):
            return 0
        return np.max(index).astype(np.int32) + 1

    @staticmethod
    def reorder_index(index: np.ndarray):
        """Replace each value by its rank among the sorted unique values."""
        _, inverse = np.unique(index, return_inverse=True)
        return inverse

    @staticmethod
    def expanded_index(index: np.ndarray, repeats: int):
        """Repeat ``index`` ``repeats`` times, shifting each copy by a multiple of the count."""
        group_count = SparseIndexable.get_count(index)
        per_copy_offset = np.arange(0, repeats) * group_count
        offsets = per_copy_offset.repeat(len(index))
        copies = np.repeat(index[None, :], repeats, axis=0).flatten()
        return copies + offsets
41 |
42 |
@jdc.pytree_dataclass
class SparseArray(SparseIndexable):
    """Values stored alongside a static ``index`` holding one group id per value."""

    val: jnp.ndarray
    index: jdc.Static[npt.NDArray[np.int32]]

    @staticmethod
    def empty(shape=(0,)):
        """Create an array of the given (empty) shape with a zero-length index."""
        return SparseArray(jnp.empty(shape), np.empty((0,), dtype=np.int32))

    @staticmethod
    def from_array(array: jnp.ndarray):
        """Wrap ``array`` with every entry assigned to group 0."""
        return SparseArray(array, np.full(len(array), 0, dtype=np.int32))

    def add(self, item: "SparseArray"):
        """Append ``item`` after the stored values, offsetting its index by the current count."""
        val = jnp.concatenate([self.val, item.val])
        index = self.add_indices(item.index)
        return SparseArray(val, index)

    def flip_winding(self, mask: Optional[jnp.ndarray] = None):
        """
        Reverse the order of the entries.

        :param mask: when given, only the index entries selected by the mask are reversed
        :type mask: Optional[jnp.ndarray]
        :return: The flipped array.
        :rtype: SparseArray
        """
        if mask is None:
            return SparseArray(self.val[::-1], self.index[::-1])

        # NOTE(review): in the masked path only ``index`` is reversed while
        # ``val`` is left untouched, unlike the unmasked path - confirm intent.
        new_index = self.index.copy()
        new_index[mask] = np.flip(new_index[mask], axis=0)

        return SparseArray(self.val, new_index)

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]) -> "SparseArray":
        """Select entries with a mask and re-number the remaining index."""
        return SparseArray(
            val=self.val[mask],
            index=SparseIndexable.reorder_index(self.index[mask]),
        )

    def reorder(self, index: np.ndarray):
        """
        Reorder the array according to the given index.

        For every value ``i`` in ``index`` (in order), the entries whose group
        id equals ``i`` are emitted in their stored order.

        :param index: The new order of the array.
        :type index: np.ndarray
        :return: The reordered array.
        :rtype: SparseArray
        """
        # Resolve all entry positions on the host first so the values are
        # gathered once instead of concatenated once per requested group.
        # The empty seed keeps np.concatenate valid when ``index`` is empty.
        positions = [np.empty((0,), dtype=np.intp)]
        positions.extend(np.flatnonzero(self.index == i) for i in index)
        order = np.concatenate(positions)

        # Concatenating onto the empty seeds reproduces the dtype promotion of
        # the previous incremental implementation.
        new_val = jnp.concatenate(
            [jnp.empty((0, *self.val.shape[1:])), self.val[order]]
        )
        new_index = np.concatenate(
            [np.empty((0,), dtype=np.int32), self.index[order]]
        )

        return SparseArray(new_val, new_index)

    def __repr__(self) -> str:
        return f"SparseArray(count={self.count})"
98 |
--------------------------------------------------------------------------------
/nebula/prim/lines.py:
--------------------------------------------------------------------------------
1 | import jax_dataclasses as jdc
2 | import numpy as np
3 | import jax.numpy as jnp
4 | from typing import Optional, Union
5 | from nebula.prim.axes import Axes
6 | import numpy.typing as npt
7 |
@jdc.pytree_dataclass
class Lines:
    """Line segments stored as a vertex array plus a static ``(n, 2)`` index into it."""

    vertices: jnp.ndarray
    index: jdc.Static[npt.NDArray[np.int32]]

    def add(
        self,
        lines: "Lines",
    ):
        """Append ``lines`` after the stored lines."""
        # Incoming indices must point past the vertices that are already stored.
        shifted_index = lines.index + len(self.vertices)

        merged_vertices = jnp.concatenate([self.vertices, lines.vertices], axis=0)
        merged_index = np.concatenate([self.index, shifted_index], axis=0)
        return Lines(merged_vertices, merged_index)

    @staticmethod
    def empty():
        """Create lines with no vertices and no segments."""
        no_vertices = jnp.empty((0, 3))
        no_index = np.empty((0, 2), dtype=np.int32)
        return Lines(vertices=no_vertices, index=no_index)

    @staticmethod
    def from_segments(segments: jnp.ndarray):
        """Build lines from segments; consecutive vertex pairs form one line each."""
        flat_vertices = segments.reshape(-1, 3)
        pair_index = np.arange(flat_vertices.shape[0], dtype=np.int32).reshape(-1, 2)
        return Lines(vertices=flat_vertices, index=pair_index)

    def get_segments(self):
        """Return the vertices gathered per line through ``index``."""
        return self.vertices[self.index]

    def clone(
        self,
        vertices: Optional[jnp.ndarray] = None,
        index: Optional[np.ndarray] = None,
    ):
        """Copy these lines, optionally replacing the vertices and/or the index."""
        if vertices is None:
            vertices = self.vertices
        if index is None:
            index = self.index
        return Lines(vertices=vertices, index=index)

    def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray):
        """Linearly interpolate between the two endpoints of line ``index`` at ``u``."""
        endpoints = self.vertices[self.index[index]]
        start, end = endpoints[0], endpoints[1]
        return (1 - u) * start + u * end

    def project(self, curr_axes: Axes, new_axes: Axes):
        """Re-express the vertices from ``curr_axes`` in ``new_axes``, one copy per new axis."""
        local_coords = curr_axes.to_local_coords(self.vertices)
        new_vertices = new_axes.to_world_coords(local_coords).reshape(-1, 3)

        num_lines = len(self.index)
        # NOTE(review): the per-copy offset uses the largest index value rather
        # than the vertex count (max + 1) - confirm this is intended.
        max_index = np.max(self.index) if num_lines else 0
        index_increment = (np.arange(0, new_axes.count) * max_index).repeat(num_lines)
        repeated_index = np.repeat(self.index[None, :], new_axes.count, axis=0)
        new_index = repeated_index.reshape(-1, 2) + np.expand_dims(
            index_increment, axis=1
        )

        return Lines(vertices=new_vertices, index=new_index)

    def translate(self, translation: jnp.ndarray):
        """Translate by one shared vector (1-D) or by one vector per line."""
        if len(translation.shape) != 1:
            per_line = jnp.expand_dims(translation, axis=1)
            moved = self.vertices.at[self.index].add(per_line)
        else:
            moved = self.vertices + translation

        return self.clone(vertices=moved)

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Select lines with a mask and rebuild them from the selected segments."""
        selected_segments = self.vertices[self.index[mask]]
        return Lines.from_segments(selected_segments)

    def flip_winding(self, mask: Optional[jnp.ndarray] = None):
        """Flip the index along both axes, for all lines or only the masked ones."""
        if mask is not None:
            flipped = self.index.copy()
            flipped[mask] = np.flip(flipped[mask], (0, 1))
            return Lines(vertices=self.vertices, index=flipped)
        return Lines(vertices=self.vertices, index=np.flip(self.index, (0, 1)))

    def reorder(self, index: np.ndarray):
        """Reorder the lines by indexing the line index with ``index``."""
        return self.clone(index=self.index[index])

    def __repr__(self) -> str:
        return f"Lines(count={len(self.index)})"

    @property
    def count(self) -> int:
        """Number of lines."""
        return len(self.index)
97 |
--------------------------------------------------------------------------------
/nebula/helpers/wire.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | from nebula.helpers.types import CoordLike
5 | from nebula.prim.axes import Axes
6 |
7 |
@jax.jit
def is_inside(polygon_segment: jnp.ndarray, point: jnp.ndarray):
    """
    Tests whether a ray cast from ``point`` in the +x direction crosses a segment.

    Only the x and y components are read.

    :param polygon_segment: start and endpoint of the segment
    :type polygon_segment: jnp.ndarray
    :param point: point the ray starts from
    :type point: jnp.ndarray
    :return: True if the ray crosses the segment
    :rtype: jnp.ndarray
    """
    (x1, y1), (x2, y2) = polygon_segment[0, :2], polygon_segment[1, :2]
    xp, yp = point[0], point[1]

    # The segment must straddle the horizontal line through the point ...
    straddles = (yp < y1) != (yp < y2)
    # ... and the crossing on that line must be to the right of the point.
    x_crossing = x1 + ((yp - y1) / (y2 - y1)) * (x2 - x1)
    return straddles & (xp < x_crossing)
14 |
15 |
class WireHelper:
    """Geometric helpers for polygons described by their line segments."""

    @staticmethod
    def contains_vertex(
        polygon_segments: jnp.ndarray,
        vertex: jnp.ndarray,
    ):
        """
        Checks if polygon contains vertex.

        :param polygon_segments: polygon segments (num_segments, 2, 3)
        :type polygon_segments: jnp.ndarray
        :param vertex: vertex to check (3,)
        :type vertex: jnp.ndarray
        :return: True if polygon contains vertex
        :rtype: bool

        """
        ray_intersection = jax.vmap(is_inside, in_axes=(0, None))(
            polygon_segments, vertex
        ).astype(jnp.int32)

        # if odd number of intersections, then vertex is inside polygon
        total_intersections = jnp.sum(ray_intersection)
        return jnp.mod(total_intersections, 2) == 1

    @staticmethod
    def contains_segment(
        polygon_segments: jnp.ndarray,
        segment: jnp.ndarray,
    ):
        """
        Checks if the midpoint of a segment is inside polygon.

        :param polygon_segments: polygon segments (num_segments, 2, 3)
        :type polygon_segments: jnp.ndarray
        :param segment: segment to check (2, 3)
        :type segment: jnp.ndarray
        :return: True if the segment's midpoint is inside polygon
        :rtype: bool

        """
        mean_vertex = jnp.mean(segment, axis=0)
        return WireHelper.contains_vertex(polygon_segments, mean_vertex)

    @staticmethod
    def fully_contains_segment(
        polygon_segments: jnp.ndarray,
        segment: jnp.ndarray,
    ):
        """
        Checks if every vertex of a segment is inside polygon.

        :param polygon_segments: polygon segments (num_segments, 2, 3)
        :type polygon_segments: jnp.ndarray
        :param segment: segment to check (2, 3)
        :type segment: jnp.ndarray
        :return: True if all vertices of the segment are inside polygon
        :rtype: bool

        """
        return jax.vmap(WireHelper.contains_vertex, in_axes=(None, 0))(
            polygon_segments, segment
        ).all()

    @staticmethod
    def to_3d_vertices(vertices: CoordLike):
        """
        Converts coordinates to an array, padding 2D coordinates with z = 0.

        Works for a single vertex ``(2,)`` as well as any batch shape
        ``(..., 2)``; input whose last axis is not 2 is returned unchanged.

        :param vertices: 2D or 3D coordinates
        :type vertices: CoordLike
        :return: coordinates as an array, with a zero z column added to 2D input
        :rtype: jnp.ndarray
        """
        vertices = jnp.asarray(vertices)
        if vertices.shape[-1] == 2:
            # Build the z column from the leading shape so any rank is supported.
            z_column = jnp.zeros((*vertices.shape[:-1], 1))
            return jnp.concatenate([vertices, z_column], axis=-1)
        return vertices

    @staticmethod
    @jax.jit
    def is_clockwise(first_segment: jnp.ndarray, last_segment: jnp.ndarray):
        """
        Finds if the segments are in clockwise order. IMPORTANT: This assumes that segments are on the XY plane

        :param first_segment: first segment of the wire (2, 3)
        :type first_segment: jnp.ndarray
        :param last_segment: last segment of the wire (2, 3)
        :type last_segment: jnp.ndarray
        :return: True if the components of the normal sum to a negative value
        :rtype: jnp.ndarray

        """
        normal = WireHelper.get_normal(first_segment, last_segment)
        return normal.sum() < 0

    @staticmethod
    def get_normal(first_segment: jnp.ndarray, last_segment: jnp.ndarray):
        """
        Computes the unit normal from the directions of two segments.

        The result is NaN when the two directions are parallel, since the
        cross product then has zero length.

        :param first_segment: first segment (2, 3)
        :type first_segment: jnp.ndarray
        :param last_segment: last segment (2, 3)
        :type last_segment: jnp.ndarray
        :return: normalized cross product of the last and first segment directions
        :rtype: jnp.ndarray
        """
        start_vecs = first_segment[1] - first_segment[0]
        end_vecs = last_segment[1] - last_segment[0]
        normals = jnp.cross(end_vecs, start_vecs)
        return normals / jnp.linalg.norm(normals, axis=-1, keepdims=True)
103 |
104 |
105 |
--------------------------------------------------------------------------------
/TODO:
--------------------------------------------------------------------------------
1 | Baseline MVP
2 |
3 | - Top Priority
4 | [X] Parafoil demo
5 |
6 | [ ] Tesselation
7 | - [ ] Add normals for bspline surface triangles (normals issue)
8 | - [X] L issue with triangulation
9 | - [X] Correct Orientation
10 |
11 | - [ ] Sweep
12 | - [ ] Same Wire along path
13 | - [ ] Different Wire along path
14 |
15 | - [ ] Revolve
16 | - [ ] Add circle representation of spline
17 | - [ ] Sweep curve/line around that circle
18 |
19 | [ ] Tech Debt
20 | - Performance
21 | [ ] Optimizations
22 | - [X] Use different dataclasses that work better with jit
23 | - [ ] New masking/padding approach
24 | - [ ] Better profiling for exactly what takes the most time
25 | - [ ] Project
26 | - [ ] Extrude
27 | - [ ] Clipper
28 | - [ ] Sort
29 | - [ ] Tesselation
30 | - [X] Reorder by index
31 | - [X] Improve/change unique operations
32 | - [X] Don't need to sort on all clip operations?
33 | - [ ] Jax types
34 | - [ ] Unit tests
35 | - [ ] Auto evaluation point count
36 |
37 |
38 | # Next Release
39 | [ ] Querying (next release)
40 | - [ ] Query centers by X,Y,Z
41 | - [ ] Parallel
42 | - [ ] Most/Least (+/-)
43 | - [X] Query index
44 | - [ ] Display selected entities
45 |
46 |
47 | [ ] OCC Compatibility (next release)
48 | - [ ] BSpline Surfaces
49 | - [ ] Polygons
50 | - [ ] CQ detach
51 |
52 | [X] Add tesselation to work with Jupyter Cadquery
53 | - [X] BSpline surfaces
54 | - [X] Add normals for bspline surface triangles (normals issue)
55 | - [X] Planar Surfaces
56 | - [X] Fix ordering for clipping
57 | - [X] Fix normals
58 | - [X] BSpline Curves
59 | - [X] Fix non showing planars in bspline case
60 | - [X] Trimmed Planar Surfaces
61 | - [X] Solve bug with specific case
62 | - [ ] Trimmed BSpline Surfaces (next release)
63 |
64 |
65 | [ ] Polar Array (next release)
66 |
67 | [ ] Add Edge Type
68 | - [X] Add Lines
69 | - [X] Add Bsplines
70 | - [ ] Efficient Bsplines (next release)
71 |
72 | [ ] Projection
73 | - [ ] Plane (next release)
74 | - [ ] Bspline Surface (next release)
75 |
76 | [ ] 2D Clip
77 | - [X] Line Line Intersection
78 | - [X] 2D Clip
79 | [X] 1x1 Intersections
80 | [ ] Curve intersections (next release)
81 | [ ] nx2 intersections (next release)
82 | [ ] Split into two wires (next release)
83 |
84 | [X] Auto Diff working for bspline example
85 | - [X] BSpline surface -> jacobian
86 | - [X] Plane -> BSpline surface -> jacobian
87 | - [ ] Polygon (next release)
88 |
89 | [ ] Add trims/holes
90 | - [ ] BSpline Surfaces (next release)
91 | - [X] Planar Surfaces
92 |
93 | [ ] 3D Contains (next release)
94 | (tesselate and ray test)
95 | - [ ] Planes
96 | - [ ] BSpline
97 |
98 | [] 3D Clip (next release)
99 | - [] Line Plane Intersection
100 | - [] 3D Clip
101 |
102 |
103 | [ ] Query topology and get all face points
104 | - [X] Inefficient method
105 | - [X] Eventually scatter
106 |
107 | [ ] Try to plot
108 | - [X] Fix normals
109 | - [X] Make runtime faster - might downgrade later
110 |
111 | Workplane
112 | - [X] Project to Axes
113 | - [X] Basic Workplane
114 |
115 | - [X] Extrude
116 | - [X] Plane
117 | - [X] BSpline
118 |
119 |
120 | # Tomorrow's next steps
121 | [X] Fix extrude indexes for normal cases
122 | - [X] Add quick extrude unit tests
123 | [X] Finally get tesselation to work properly
124 | [X] More unit test for existing edge cases
125 | [ ] Research into auto diff for polygons and in Jax in general
126 |
127 | [X] Start Bsplines
128 |
129 | [ ] Later: Account for multi line splits (not important right now)
130 | [ ] Later: Fix jax jit for interior intersection cases
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
Nebula
2 |
3 | open source differentiable geometry engine
4 |
5 |
6 | # About
7 | Nebula was an attempt to create simple geometries that are auto-differentiable with respect to their design variables for design optimization and machine learning tasks.
8 |
9 | In the past the jacobians of geometry surface meshes with respect to the design variables were computed via finite differences, a task of altering each design variable by small epsilons and taking the position difference of their corresponding meshes. This process was very slow depending on the geometry and also restricted to just structured meshes because triangulation is very non-deterministic to small changes.
10 |
11 | The structure of Nebula is single-to-many for performance reasons (i.e. Solids with vertices and indexes rather than a list of Solid objects). This allows us to construct multiple items simultaneously.
12 |
13 | The Workplane class of Nebula is modeled after CadQuery syntax for familiarity; although this has not been tested, CadQuery code should migrate in some cases.
14 |
15 | Nebula still has some performance issues, mainly because the JAX JIT is very slow for operations that aren't uniformly sized. Unfortunately, operations like clipping and the number of topology items are not uniform in nature, so the only way to run fast would be to pad to a uniform size. This was not implemented, mainly because of time constraints and because the main goal of this project was to get CAD sensitivities.
16 |
17 | Nebula has a lot of deficits in features and most likely has a lot of bugs so this project is for experimental use only. We will add more features as we need them as we currently don't have the resources to invest all our time to features that don't help our current roadmap.
18 |
19 |
20 | # Install
21 | ```
22 | pip install git+https://github.com/OpenOrion/nebula.git#egg=nebula
23 | ```
24 |
25 |
26 |
27 | ## Example
28 | See more examples in [examples](/examples) directory
29 |
30 | ```python
31 | import jax
32 | import jax.numpy as jnp
33 | from nebula import Workplane
34 | from nebula.render.visualization import show
35 | from nebula.render.visualization import Tesselator
36 |
37 | def make_mesh(height: jnp.ndarray, arc_height: jnp.ndarray):
38 | profile = (
39 | Workplane.init("XY")
40 | .lineTo(0, height)
41 | .bspline([
42 | (0.5, height+arc_height),
43 | (1.5, height+arc_height),
44 | (2, height)
45 | ], includeCurrent=True)
46 | .lineTo(2, 0)
47 | .close()
48 | .extrude(2)
49 | )
50 | return Tesselator.get_differentiable_mesh(profile)
51 |
52 | # design variables
53 | height = jnp.array(2.0)
54 | arc_height = jnp.array(0.5)
55 |
56 | # the jacobian with respect to height and arc height i.e (0,1)
57 | mesh_jacobian = jax.jacobian(
58 | lambda height, arc_height: make_mesh(height, arc_height).vertices,
59 | (0,1)
60 | )(height, arc_height)
61 | arc_height_jac_magnitude = jnp.linalg.norm(mesh_jacobian[1], axis=-1)
62 |
63 | # construct regular surface mesh
64 | mesh = make_mesh(height, arc_height)
65 |
66 | # show the mesh with magnitude of jacobian as color
67 | show(mesh, "plot", arc_height_jac_magnitude, name="CAD Sensitivity w.r.t arc height
Nebula by Open Orion")
68 |
69 | ```
70 |
71 | 
72 |
73 |
74 | ## Current features
75 |
76 | * Basic Sketch
77 | - rect
78 | - lineTo
79 | - bsplineTo
80 | - polarLine
81 | - polyline
82 | * Extrude
83 | * Polyline
84 | * Bspline Operations
85 | - Skin
86 | - Evaluate
87 |
88 |
89 | ## Deficits
90 | All the project roadmap features are in the TODO file. Main thing to keep in mind are:
91 | * Clip operations
92 | - Can only have 1-2 intersections, this was to maintain uniform size for faster operations, if we see it fit will add a workaround later
93 |
94 |
95 | # Development Setup
96 | ```
97 | git clone https://github.com/OpenOrion/nebula.git
98 | cd nebula
99 | pip install -r requirements_dev.txt
100 | ```
101 |
102 | # Help Wanted
103 | Please join the [Discord](https://discord.gg/H7qRauGkQ6) for project communications and collaboration.
104 |
105 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.cgns
2 | *.su2
3 | *.geo_unrolled
4 | *.msh
5 | su2/
6 | cfd/
7 | generated/
8 | simulations/
9 | paraview
10 | docs
11 | bin/CQ-editor*
12 | bin/CQ-editor/
13 | *.curaprofile
14 | ./cad
15 | *.step
16 | *.geo
17 | *.stl
18 | *.out
19 | *.mod
20 | *.pkl
21 |
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 | cover/
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 | local_settings.py
82 | db.sqlite3
83 | db.sqlite3-journal
84 |
85 | # Flask stuff:
86 | instance/
87 | .webassets-cache
88 |
89 | # Scrapy stuff:
90 | .scrapy
91 |
92 | # Sphinx documentation
93 | docs/_build/
94 |
95 | # PyBuilder
96 | .pybuilder/
97 | target/
98 |
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 |
102 | # IPython
103 | profile_default/
104 | ipython_config.py
105 |
106 | # pyenv
107 | # For a library or package, you might want to ignore these files since the code is
108 | # intended to run in multiple environments; otherwise, check them in:
109 | # .python-version
110 |
111 | # pipenv
112 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
113 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
114 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
115 | # install all needed dependencies.
116 | #Pipfile.lock
117 |
118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
119 | __pypackages__/
120 |
121 | # Celery stuff
122 | celerybeat-schedule
123 | celerybeat.pid
124 |
125 | # SageMath parsed files
126 | *.sage.py
127 |
128 | # Environments
129 | .env
130 | .venv
131 | env/
132 | venv/
133 | ENV/
134 | env.bak/
135 | venv.bak/
136 |
137 | # Spyder project settings
138 | .spyderproject
139 | .spyproject
140 |
141 | # Rope project settings
142 | .ropeproject
143 |
144 | # mkdocs documentation
145 | /site
146 |
147 | # mypy
148 | .mypy_cache/
149 | .dmypy.json
150 | dmypy.json
151 |
152 | # Pyre type checker
153 | .pyre/
154 |
155 | # pytype static type analyzer
156 | .pytype/
157 |
158 | # Cython debug symbols
159 | cython_debug/
160 |
161 | # General
162 | .DS_Store
163 | .AppleDouble
164 | .LSOverride
165 |
166 | # Icon must end with two \r
167 | Icon
168 |
169 |
170 | # Thumbnails
171 | ._*
172 |
173 | # Files that might appear in the root of a volume
174 | .DocumentRevisions-V100
175 | .fseventsd
176 | .Spotlight-V100
177 | .TemporaryItems
178 | .Trashes
179 | .VolumeIcon.icns
180 | .com.apple.timemachine.donotpresent
181 |
182 | # Directories potentially created on remote AFP share
183 | .AppleDB
184 | .AppleDesktop
185 | Network Trash Folder
186 | Temporary Items
187 | .apdisk
188 |
189 | *~
190 |
191 | # temporary files which can be created if a process still has a handle open of a deleted file
192 | .fuse_hidden*
193 |
194 | # KDE directory preferences
195 | .directory
196 |
197 | # Linux trash folder which might appear on any partition or disk
198 | .Trash-*
199 |
200 | # .nfs files are created when an open file is removed but is still being accessed
201 | .nfs*
202 |
203 | # Windows thumbnail cache files
204 | Thumbs.db
205 | Thumbs.db:encryptable
206 | ehthumbs.db
207 | ehthumbs_vista.db
208 |
209 | # Dump file
210 | *.stackdump
211 |
212 | # Folder config file
213 | [Dd]esktop.ini
214 |
215 | # Recycle Bin used on file shares
216 | $RECYCLE.BIN/
217 |
218 | # Windows Installer files
219 | *.cab
220 | *.msi
221 | *.msix
222 | *.msm
223 | *.msp
224 |
225 | # Windows shortcuts
226 | *.lnk
227 |
--------------------------------------------------------------------------------
/nebula/tools/edge.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from typing import Optional
3 | import jax.numpy as jnp
4 | from nebula.evaluators.bspline import BSplineEvaluator
5 | from nebula.prim.bspline_curves import BSplineCurves
6 | from nebula.helpers.clipper import Clipper
7 | from nebula.helpers.types import ArrayLike, Number
8 | from nebula.helpers.wire import WireHelper
9 | from nebula.prim.sparse import SparseArray
10 | from nebula.topology.edges import Edges
11 | from nebula.topology.wires import Wires
12 |
13 |
class EdgeTool:
    """Factory helpers that build edges and wires from simple geometric inputs."""

    @staticmethod
    @jax.jit
    def make_rect(
        x_len: Number, y_len: Number, origins: Optional[jnp.ndarray] = None, centered: bool = True
    ):
        """Creates a rectangle on the XY plane.

        :param x_len: length of the rectangle along x-axis
        :type x_len: float
        :param y_len: length of the rectangle along y-axis
        :type y_len: float
        :param origins: offset added (by broadcasting) to the rectangle vertices, defaults to a single origin at (0, 0, 0)
        :type origins: jnp.ndarray, optional
        :param centered: flag to center the rectangle on the origin instead of placing its corner there
        :type centered: bool
        :return: edges of the closed rectangle, as built by make_polyline
        """
        if origins is None:
            origins = jnp.zeros((1, 3))

        # closed loop: the first corner is repeated so the last side is included
        vertices = jnp.array(
            [
                (0.0, 0.0, 0.0),
                (x_len, 0.0, 0.0),
                (x_len, y_len, 0.0),
                (0.0, y_len, 0.0),
                (0.0, 0.0, 0.0),
            ]
        )

        # center the rectangle if requested (jnp.where keeps this jit-compatible)
        vertices = jnp.where(
            centered, vertices - jnp.array([x_len / 2, y_len / 2, 0.0]), vertices
        )
        return EdgeTool.make_polyline(vertices + origins)

    @staticmethod
    @jax.jit
    def make_polyline(vertices: jnp.ndarray):
        """Creates line edges joining each consecutive pair of vertices.

        :param vertices: ordered vertices of the polyline (num_vertices, 3)
        :type vertices: jnp.ndarray
        :return: edges built from the (num_vertices - 1, 2, 3) segments
        """
        segments = jnp.stack([vertices[:-1], vertices[1:]], axis=1)
        return Edges.from_line_segments(segments)

    @staticmethod
    @jax.jit
    def make_line(start_vertex: jnp.ndarray, end_vertex: jnp.ndarray):
        """Creates a single line edge between two vertices.

        :param start_vertex: start of the line
        :type start_vertex: jnp.ndarray
        :param end_vertex: end of the line, same shape as start_vertex
        :type end_vertex: jnp.ndarray
        :return: edges holding the single (1, 2, dim) segment
        """
        line_segments = jnp.stack([start_vertex, end_vertex])
        return Edges.from_line_segments(jnp.expand_dims(line_segments, axis=0))

    @staticmethod
    @jax.jit
    def make_polar_line(start_vertex: jnp.ndarray, distance: Number, angle: Number):
        """Creates a line from start_vertex to a point given in polar coordinates.

        The polar coordinates are relative to start_vertex and measured in the
        XY plane. The end point has the same dimensionality as start_vertex, so
        both 2D and 3D start vertices are supported.

        :param start_vertex: start of the line (2,) or (3,)
        :type start_vertex: jnp.ndarray
        :param distance: length of the line
        :type distance: float
        :param angle: angle from the x-axis in degrees
        :type angle: float
        :return: edges holding the single line segment
        """
        theta = jnp.radians(angle)
        planar_offset = jnp.array([jnp.cos(theta) * distance, jnp.sin(theta) * distance])

        # pad the in-plane offset with zeros so it matches the start vertex
        # dimensionality; otherwise stacking a 3D start with a 2D end fails
        padding = jnp.zeros(start_vertex.shape[-1] - 2)
        end_vertex = start_vertex + jnp.concatenate([planar_offset, padding])

        return EdgeTool.make_line(start_vertex, end_vertex)

    @staticmethod
    def make_bspline_curve(
        ctrl_pnts: jnp.ndarray,
        degree: jnp.ndarray,
        knots: Optional[ArrayLike] = None,
    ):
        """Creates a single B-spline curve.

        :param ctrl_pnts: control points of the curve
        :type ctrl_pnts: jnp.ndarray
        :param degree: degree of the curve
        :type degree: jnp.ndarray
        :param knots: knot vector, defaults to clamped knots generated from the degree and control point count
        :type knots: ArrayLike, optional
        :return: the curve wrapped in a BSplineCurves container
        :rtype: BSplineCurves
        """
        if knots is None:
            knots = BSplineEvaluator.generate_clamped_knots(degree, len(ctrl_pnts))
        knots = jnp.asarray(knots)

        return BSplineCurves(
            degree=jnp.array([degree], dtype=jnp.int32),
            ctrl_pnts=SparseArray.from_array(ctrl_pnts),
            knots=SparseArray.from_array(knots),
        )

    @staticmethod
    def consolidate_wires(wires: Wires, pending_edges: Edges, validate: bool):
        """
        Consolidates the pending edges with the existing wires.

        :param wires: The existing wires
        :type wires: Wires
        :param pending_edges: The pending edges
        :type pending_edges: Edges
        :param validate: Flag to validate the pending edges for sort order before consolidation. No validation is more performant.
        :type validate: bool
        :return: the wire built from the pending edges, cut against the existing wires when there are any

        """
        if validate:
            pending_edge_result = pending_edges.get_sorted()
            pending_edges = pending_edge_result.edges
            # flip clockwise loops so the winding of the new wire is counter-clockwise
            is_clockwise = WireHelper.is_clockwise(
                pending_edge_result.segments[0], pending_edge_result.segments[-1]
            )
            if is_clockwise:
                pending_edges = pending_edges.flip_winding()

        clipper_wire = Wires.from_edges(pending_edges)
        if wires.count > 0:
            return Clipper.cut_polygons(wires, clipper_wire)
        else:
            return clipper_wire
118 |
--------------------------------------------------------------------------------
/tests/test_bspline.py:
--------------------------------------------------------------------------------
1 | from nebula.evaluators.bspline import BSplineEvaluator
2 | import jax.numpy as jnp
3 |
4 |
def test_bspline_curve():
    """A quadratic clamped curve sampled at 10 even parameters matches the reference points."""
    params = jnp.linspace(0.0, 1.0, 10)
    control_points = jnp.array([[1, 0, 0], [1, 1, 0], [0, 1, 0]])
    clamped_knots = jnp.array([0, 0, 0, 1, 1, 1])

    expected = jnp.array(
        [
            [1.0, 0.0, 0.0],
            [0.9876543209876543, 0.20987654320987653, 0.0],
            [0.9506172839506173, 0.3950617283950617, 0.0],
            [0.8888888888888891, 0.5555555555555556, 0.0],
            [0.8024691358024691, 0.691358024691358, 0.0],
            [0.691358024691358, 0.8024691358024691, 0.0],
            [0.5555555555555556, 0.8888888888888888, 0.0],
            [0.3950617283950617, 0.9506172839506173, 0.0],
            [0.20987654320987661, 0.9876543209876544, 0.0],
            [0.0, 1.0, 0.0],
        ]
    )

    actual = BSplineEvaluator.eval_curve(
        degree=jnp.array(2),
        ctrl_pnts=control_points,
        u=params,
        knots=clamped_knots,
    )

    assert jnp.allclose(actual, expected, atol=1e-3)
29 |
30 |
def test_bspline_surface():
    """A bicubic 6x6 surface sampled on a 5x5 parameter grid matches the reference points."""
    params = jnp.linspace(0.0, 1.0, 5)

    # 36 control points, laid out as a 6 x 6 grid of (x, y, z)
    control_grid = jnp.array(
        [
            [-25.0, -25.0, -10.0],
            [-25.0, -15.0, -5.0],
            [-25.0, -5.0, 0.0],
            [-25.0, 5.0, 0.0],
            [-25.0, 15.0, -5.0],
            [-25.0, 25.0, -10.0],
            [-15.0, -25.0, -8.0],
            [-15.0, -15.0, -4.0],
            [-15.0, -5.0, -4.0],
            [-15.0, 5.0, -4.0],
            [-15.0, 15.0, -4.0],
            [-15.0, 25.0, -8.0],
            [-5.0, -25.0, -5.0],
            [-5.0, -15.0, -3.0],
            [-5.0, -5.0, -8.0],
            [-5.0, 5.0, -8.0],
            [-5.0, 15.0, -3.0],
            [-5.0, 25.0, -5.0],
            [5.0, -25.0, -3.0],
            [5.0, -15.0, -2.0],
            [5.0, -5.0, -8.0],
            [5.0, 5.0, -8.0],
            [5.0, 15.0, -2.0],
            [5.0, 25.0, -3.0],
            [15.0, -25.0, -8.0],
            [15.0, -15.0, -4.0],
            [15.0, -5.0, -4.0],
            [15.0, 5.0, -4.0],
            [15.0, 15.0, -4.0],
            [15.0, 25.0, -8.0],
            [25.0, -25.0, -10.0],
            [25.0, -15.0, -5.0],
            [25.0, -5.0, 2.0],
            [25.0, 5.0, 2.0],
            [25.0, 15.0, -5.0],
            [25.0, 25.0, -10.0],
        ]
    ).reshape(6, 6, 3)

    # the same clamped cubic knot vector is used in both parametric directions
    knot_vector = jnp.array([0.0, 0.0, 0.0, 0.0, 0.33, 0.66, 1.0, 1.0, 1.0, 1.0])

    expected = jnp.array(
        [
            [
                [-25.0, -25.0, -10.0],
                [-25.000004, -9.077171, -2.397286],
                [-25.000002, 0.11256424, -0.3082978],
                [-25.000002, 9.309787, -2.4978478],
                [-25.0, 25.0, -10.0],
            ],
            [
                [-9.077171, -25.000004, -6.2806444],
                [-9.077172, -9.077172, -4.8789206],
                [-9.077172, 0.11256425, -5.9172673],
                [-9.077171, 9.309788, -4.8489876],
                [-9.077171, 25.000004, -6.2806444],
            ],
            [
                [0.11256424, -25.000002, -4.2381387],
                [0.1125643, -9.077172, -5.3740025],
                [0.1125643, 0.11256424, -7.434884],
                [0.11256424, 9.309788, -5.299428],
                [0.11256424, 25.000002, -4.2381387],
            ],
            [
                [9.309787, -25.000002, -5.5793867],
                [9.309788, -9.077171, -4.6443815],
                [9.309788, 0.11256424, -5.7848625],
                [9.309787, 9.309787, -4.609164],
                [9.309787, 25.000002, -5.5793867],
            ],
            [
                [25.0, -25.0, -10.0],
                [25.000004, -9.077171, -1.3277057],
                [25.000002, 0.11256424, 1.5683833],
                [25.000002, 9.309787, -1.4598912],
                [25.0, 25.0, -10.0],
            ],
        ],
        dtype=jnp.float32,
    )

    actual = BSplineEvaluator.eval_surface(
        u_degree=jnp.array(3),
        v_degree=jnp.array(3),
        ctrl_pnts=control_grid,
        u_knots=knot_vector,
        v_knots=knot_vector,
        u=params,
        v=params,
    )

    assert jnp.allclose(actual, expected, atol=1e-3)
123 |
--------------------------------------------------------------------------------
/nebula/tools/solid.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import numpy as np
4 | from nebula.helpers.types import Number
5 | from nebula.prim.axes import Axes
6 | from nebula.topology.faces import Faces
7 | from nebula.topology.shells import Shells
8 | from nebula.topology.solids import Solids
9 | from nebula.topology.wires import Wires
10 |
11 |
class SolidTool:
    # Helpers that turn wires and faces into extruded faces and solids.

    @staticmethod
    # @jax.jit
    def extrude_line_segments(segments: jnp.ndarray, translation: Number):
        """
        Returns the four edges of the side quad swept out by extruding one edge.

        The edges are stacked in the order bottom, right, top, left. The top
        edge is the reversed bottom edge shifted by the translation, so the
        four edges form a closed loop.

        :param segments: edge to extrude, given by its start and end vertex
        :type segments: jnp.ndarray
        :param translation: translation to extrude by
        :type translation: Number
        :return: stacked [bottom, right, top, left] edges; (4, 2, 3) for a (2, 3) input, assuming the translation broadcasts against it


        """
        bottom_edge = segments
        # reverse the vertex order so the loop keeps a consistent winding
        top_edge = bottom_edge[::-1] + translation
        right_edge = jnp.stack([bottom_edge[-1], top_edge[0]])
        left_edge = jnp.stack([top_edge[-1], bottom_edge[0]])

        return jnp.stack([bottom_edge, right_edge, top_edge, left_edge])

    @staticmethod
    @jax.jit
    def extruded_planar_faces(
        faces: Faces, translation: Number, is_interior: jnp.ndarray
    ):
        """
        Builds the side faces created by extruding the straight edges of the faces.

        :param faces: faces whose line edges are extruded
        :type faces: Faces
        :param translation: translation to extrude by
        :type translation: Number
        :param is_interior: per line segment flag; flagged segments are taken from the array flipped along its first two axes
        :type is_interior: jnp.ndarray
        :return: faces built from the extruded side quads
        :rtype: Faces
        """
        # Handle planar extrusion faces
        line_segments = faces.wires.edges.lines.get_segments()
        # NOTE(review): jnp.flip over axes (0, 1) reverses the segment order as
        # well as the vertex order within each segment — confirm this is intended
        plane_segments = jnp.where(
            is_interior[:, None, None],
            jnp.flip(line_segments, (0, 1)),
            line_segments,
        )

        extruded_plane_segments = jax.vmap(
            SolidTool.extrude_line_segments, in_axes=(0, None)
        )(
            plane_segments, translation
        )  # (num_edges, 4, 2, 3)

        return Faces.from_wires(
            Wires.from_line_segments(extruded_plane_segments),
            is_exterior=True,
        )

    # TODO: this is a work in progress.
    # @staticmethod
    # def sweep(wires: Wires, path: Wires):
    #     assert path.is_single, "Only one path is supported for sweep"
    #     faces = Faces.from_wires(wires)
    #     start_face = faces.flip_winding()


    #     sample = jnp.linspace(0.1, 1.0, 10)

    #     new_axes = Axes(
    #         normals=path.dir_vec_at(sample),
    #         origins=path.evaluate_at(sample),
    #     )

    #     end_face = faces.project(new_axes)
    #     end_face.wires.plot()


    # @staticmethod
    # def revolve(wires: Wires, angle: Number, axis: Axes):
    #     return SolidTool.sweep(wires, Wires.make_circle(1.0, angle, axis))

    @staticmethod
    def extrude(wires: Wires, distance: Number):
        """Extrudes wires by a distance along the normals of their faces.

               E-------F
              /|      /|
             / |     / |
            A--|----B  |
            |  H----|--G
            | /  ^  | /
            |/      |/
            C--->---D

        :param wires: Wires to extrude
        :type wires: Wires
        :param distance: distance to extrude
        :type distance: float
        :return: new extruded Solids
        :rtype: Solids
        """

        faces = Faces.from_wires(wires)

        faces_normals = faces.get_axis_normals()
        translation = faces_normals * distance

        # the start cap is the original face with flipped winding, the end cap
        # is the original face moved by the translation
        start_face = faces.flip_winding()
        end_face = faces.translate(translation)

        # Handle planar extrusion faces
        line_segments = faces.wires.edges.lines.get_segments()
        # interior flag restricted to the non-curve (line) edges
        planar_is_interior = faces.wires.is_interior[~faces.wires.edges.is_curve]
        # NOTE(review): the interior segments are flipped along both the segment
        # and the vertex axis — presumably to reverse the winding of holes; confirm
        plane_segments = line_segments.at[planar_is_interior].set(
            jnp.flip(line_segments[planar_is_interior], (0, 1))
        )
        extruded_plane_segments = jax.jit(
            jax.vmap(SolidTool.extrude_line_segments, in_axes=(0, None))
        )(plane_segments, translation)

        exterior_extruded_plane_faces = Faces.from_wires(
            Wires.from_line_segments(extruded_plane_segments[~planar_is_interior]),
            is_exterior=True,
        )
        # NOTE(review): side faces of interior wires are also created with
        # is_exterior=True — confirm this is intended
        interior_extruded_plane_faces = Faces.from_wires(
            Wires.from_line_segments(extruded_plane_segments[planar_is_interior]),
            is_exterior=True,
        )

        # Handle non-planar extrusion faces
        start_curves = faces.wires.edges.curves
        end_curves = end_face.wires.edges.curves
        bspline_surface_faces = Faces.from_wires(Wires.skin((start_curves, end_curves)))

        exterior_faces = Faces.combine(
            (start_face, end_face, exterior_extruded_plane_faces, bspline_surface_faces)
        )
        # every exterior face is assigned to the shell of the first exterior wire
        start_face_exterior_index = start_face.index[~wires.is_interior][0]
        exterior_shell_index = (
            np.ones(len(exterior_faces.index), dtype=np.int32) * start_face_exterior_index
        )
        # NOTE(review): assumes each interior wire contributes exactly four side
        # faces (e.g. a rectangular hole) — TODO confirm for other hole shapes
        interior_shell_index = np.repeat(wires.index[wires.is_interior], 4)
        shell_index = np.concatenate([exterior_shell_index, interior_shell_index])

        new_faces = Faces.combine((exterior_faces, interior_extruded_plane_faces))
        shells = Shells(faces=new_faces, index=shell_index)

        return Solids.from_shells(shells)
148 |
--------------------------------------------------------------------------------
/nebula/prim/axes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import field
2 | from functools import cached_property
3 | from typing import Literal
4 | import jax_dataclasses as jdc
5 | import jax
6 | import jax.numpy as jnp
7 |
8 |
@jax.jit
def get_rotation_matrix(a: jnp.ndarray, b: jnp.ndarray):
    """Returns a rotation matrix that rotates vector a to vector b.

    Both vectors are normalized first. Identical directions give the identity
    and exactly opposite directions give the negated identity; every other
    case uses Rodrigues' formula.

    :param a: vector a
    :type a: jnp.ndarray
    :param b: vector b
    :type b: jnp.ndarray
    :return: rotation matrix
    :rtype: jnp.ndarray
    """
    # Tolerance for floating point errors
    eps = 1.0e-10

    # Normalize the vectors
    a = a / jnp.linalg.norm(a)
    b = b / jnp.linalg.norm(b)

    # identity in the dimension of the space
    identity = jnp.identity(a.size)

    # cosine of the angle between the two vectors
    c = jnp.dot(a, b)

    # the cross product matrix of a vector to rotate around
    K = jnp.outer(b, a) - jnp.outer(a, b)

    same_direction = jnp.abs(c - 1.0) < eps
    opposite_direction = jnp.abs(c + 1.0) < eps

    # For opposite vectors 1 + c is zero. That branch is not selected below,
    # but evaluating 0 / 0 in it would still turn gradients taken through
    # jnp.where into NaN, so the denominator is replaced by a harmless value.
    safe_denominator = jnp.where(opposite_direction, 1.0, 1.0 + c)

    # Rodrigues' formula
    rodrigues = identity + K + (K @ K) / safe_denominator

    return jnp.where(
        same_direction,
        identity,
        jnp.where(opposite_direction, -identity, rodrigues),
    )
51 |
52 |
# Names of the principal planes accepted by Axes.from_str.
AxesString = Literal["XY", "YX", "XZ", "ZX", "YZ", "ZY"]
54 |
55 |
56 |
57 |
@jdc.pytree_dataclass
class Axes:
    """A batch of planes, each described by an origin and a normal."""

    # plane normals, one row per plane
    normals: jnp.ndarray
    # plane origins, one row per plane; defaults to a single origin at (0, 0, 0)
    origins: jnp.ndarray = field(default_factory=lambda: jnp.array([[0.0, 0.0, 0.0]]))

    @cached_property
    def local_rotation_matrix(self):
        """Rotation matrices taking each normal to the XY normal, one per plane."""
        return jax.vmap(get_rotation_matrix, (0, None))(self.normals, XY.normals[0])

    @cached_property
    def world_rotation_matrix(self):
        """Rotation matrices taking the XY normal to each normal, one per plane."""
        return jax.vmap(get_rotation_matrix, (None, 0))(XY.normals[0], self.normals)

    @property
    def count(self):
        """Number of planes."""
        return self.normals.shape[0]

    @property
    def local_origins(self):
        """The origins expressed in local coordinates."""
        return self.to_local_coords(self.origins)

    @property
    def shape(self):
        """Shape of the origins array."""
        return self.origins.shape

    def __add__(self, translation: jnp.ndarray):
        """Returns new axes with the origins shifted by translation; normals are unchanged."""
        return Axes(origins=self.origins + translation, normals=self.normals)

    def to_local_coords(self, world_coords: jnp.ndarray):
        """Transforms world coordinates into the local coordinates of the planes.

        The coordinates are shifted by the offset between the XY origin and the
        plane origins and then multiplied by the local rotation matrices.

        :param world_coords: world coordinates to transform
        :type world_coords: jnp.ndarray
        :return: transformed coordinates
        :rtype: jnp.ndarray

        """
        # Translation vector
        translation_vector = jnp.expand_dims(XY.origins[0] - self.origins, axis=1)

        # Rotate the points
        return jax.jit(jax.vmap(jnp.matmul, (None, 0)))(
            world_coords + translation_vector, self.local_rotation_matrix
        )

    def to_world_coords(self, local_coords: jnp.ndarray):
        """Transforms local plane coordinates into world coordinates.

        2D coordinates are padded with a zero third component. The coordinates
        are multiplied by the world rotation matrices and then shifted by the
        offset between the plane origins and the XY origin.

        :param local_coords: local coordinates to transform, last axis of size 2 or 3
        :type local_coords: jnp.ndarray
        :return: transformed coordinates
        :rtype: jnp.ndarray
        """
        if local_coords.shape[-1] == 2:
            # pad along every leading axis, not only the first one
            local_coords = jnp.concatenate(
                [local_coords, jnp.zeros((*local_coords.shape[:-1], 1))], axis=-1
            )

        # Translation vector
        translation_vector = jnp.expand_dims(self.origins - XY.origins[0], axis=1)

        # Rotate the points
        return (
            jax.jit(jax.vmap(jnp.matmul, (None, 0)))(
                local_coords, self.world_rotation_matrix
            )
            + translation_vector
        )

    def __repr__(self) -> str:
        return f"Axes(origin={self.origins}, normal={self.normals})"

    @staticmethod
    def from_str(str: AxesString):
        """Creates axes at the world origin for a named principal plane.

        :param str: name of the plane, one of "XY", "YX", "XZ", "ZX", "YZ", "ZY"
        :type str: AxesString
        :return: axes with a single origin at (0, 0, 0) and the normal of the named plane
        :rtype: Axes
        :raises ValueError: if the name is not one of the supported planes
        """
        plane_normals = {
            "XY": [0.0, 0.0, 1.0],
            "YX": [0.0, 0.0, -1.0],
            "XZ": [0.0, 1.0, 0.0],
            "ZX": [0.0, -1.0, 0.0],
            "YZ": [1.0, 0.0, 0.0],
            "ZY": [-1.0, 0.0, 0.0],
        }
        if str not in plane_normals:
            # fail loudly instead of silently returning None
            raise ValueError(
                f"Unknown axes {str!r}, expected one of {', '.join(plane_normals)}"
            )
        return Axes(
            origins=jnp.array([[0.0, 0.0, 0.0]]),
            normals=jnp.array([plane_normals[str]]),
        )
167 |
168 |
# Reference world frame: origin at (0, 0, 0) with the +Z axis as normal.
XY = Axes(
    normals=jnp.array([[0.0, 0.0, 1.0]]),
    origins=jnp.array([[0.0, 0.0, 0.0]]),
)
172 |
--------------------------------------------------------------------------------
/nebula/prim/bspline_curves.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax_dataclasses as jdc
3 | from nebula.evaluators.bspline import BSplineEvaluator
4 | from nebula.prim.axes import Axes
5 | from nebula.prim import SparseArray, SparseIndexable
6 | from typing import Optional, Sequence, Union
7 | import jax.numpy as jnp
8 | import numpy as np
9 |
10 |
@jdc.pytree_dataclass
class BSplineCurves:
    """Batch of B-spline curves stored in sparse (value, index) form.

    ``degree`` is indexed per curve, while ``ctrl_pnts`` and ``knots`` are
    ``SparseArray`` containers whose ``index`` entries are matched against the
    curve number to select the values that belong to each curve (see ``mask``).
    """

    degree: jnp.ndarray
    ctrl_pnts: SparseArray
    knots: SparseArray

    @staticmethod
    def empty():
        """Return a collection that holds no curves."""
        return BSplineCurves(
            degree=jnp.empty((0,), dtype=np.int32),
            ctrl_pnts=SparseArray.empty((0, 3)),
            knots=SparseArray.empty(),
        )

    # TODO: make this more efficient
    def evaluate(self, u: Optional[jnp.ndarray] = None):
        """
        Evaluate the curves at the given u values.

        :param u: The u values to evaluate the curves at. If None, each curve
            contributes one segment from its first to its last control point.
        :type u: jnp.ndarray
        :returns: Line segments of all curves concatenated along axis 0; an
            empty ``(0, 2, 3)`` array when there are no curves.
        :rtype: jnp.ndarray

        """
        if not self.ctrl_pnts.count:
            return jnp.empty((0, 2, 3))
        curve_segments = []
        for i in range(self.ctrl_pnts.count):
            curve = self.mask(i)
            if u is not None:
                vertices = BSplineEvaluator.eval_curve(
                    curve.degree, curve.ctrl_pnts.val, u, curve.knots.val
                )
                # Pair consecutive samples into (start, end) segments.
                segments = jnp.stack([vertices[:-1], vertices[1:]], axis=1)
            else:
                segments = jnp.array([[curve.ctrl_pnts.val[0], curve.ctrl_pnts.val[-1]]])
            curve_segments.append(segments)
        return jnp.concatenate(curve_segments, axis=0)

    def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray):
        """Evaluate only the curve(s) selected by ``index`` at parameters ``u``."""
        curve = self.mask(index)
        return BSplineEvaluator.eval_curve(
            curve.degree, curve.ctrl_pnts.val, u, curve.knots.val
        )

    # TODO: Add and test this later when converting planes to bspline surfaces
    @staticmethod
    @jax.jit
    def from_line_segment(segment: jnp.ndarray):
        """Build degree-1 curve data from a line segment.

        ``segment`` is reshaped to ``(-1, 3)`` control points and paired with
        the knot vector from ``BSplineEvaluator.generate_line_knots``.
        """
        # other segments are the connectors between top and bottom segments in quad
        ctrl_pnts = segment.reshape(-1, 3)
        line_knots = BSplineEvaluator.generate_line_knots()
        return BSplineCurves(
            jnp.array([1]),
            SparseArray.from_array(ctrl_pnts),
            SparseArray.from_array(line_knots),
        )

    def get_segments(self):
        """Return one (first, last) control-point pair per curve.

        Control points are scattered into a per-curve slot by their sparse
        index; scattering in forward order is meant to leave the last point of
        each curve, and scattering in reverse order the first.
        NOTE(review): this relies on later writes winning for duplicate scatter
        indices - confirm for the JAX backend in use.
        """
        last_vertices = (
            jnp.empty((self.count, 3)).at[self.ctrl_pnts.index].set(self.ctrl_pnts.val)
        )

        first_vertices = (
            jnp.empty((self.count, 3))
            .at[self.ctrl_pnts.index[::-1]]
            .set(self.ctrl_pnts.val[::-1])
        )

        return jnp.stack([first_vertices, last_vertices], axis=1)

    def add(
        self,
        curves: "BSplineCurves",
    ):
        """Return a new collection with ``curves`` appended after these curves."""
        return BSplineCurves(
            degree=jnp.concatenate([self.degree, curves.degree]),
            ctrl_pnts=self.ctrl_pnts.add(curves.ctrl_pnts),
            knots=self.knots.add(curves.knots),
        )

    @classmethod
    def combine(cls, all_curves: Sequence["BSplineCurves"]):
        """Combine several curve collections into a single one.

        :param all_curves: curve collections to combine, in order
        :type all_curves: Sequence[BSplineCurves]
        :return: Combined curves (an empty collection for an empty sequence)
        :rtype: BSplineCurves
        """
        # Start from an empty collection and keep the result of every ``add``:
        # ``add`` returns a new object instead of mutating in place.
        # NOTE(review): assumes SparseArray.add accepts an empty left operand.
        combined = cls.empty()
        for curves in all_curves:
            combined = combined.add(curves)
        return combined

    def clone(
        self,
        degree: Optional[jnp.ndarray] = None,
        ctrl_pnts: Optional[SparseArray] = None,
        knots: Optional[SparseArray] = None,
    ):
        """Return a copy in which every argument that is not None replaces the
        corresponding field."""
        return BSplineCurves(
            degree=degree if degree is not None else self.degree,
            ctrl_pnts=(
                SparseArray(ctrl_pnts.val, ctrl_pnts.index)
                if ctrl_pnts is not None
                else self.ctrl_pnts
            ),
            knots=(
                SparseArray(knots.val, knots.index) if knots is not None else self.knots
            ),
        )

    # TODO: this is assuming it is already on the XY plane
    def project(self, curr_axes: Axes, new_axes: Axes):
        """Re-express the curves from ``curr_axes`` onto every axes in ``new_axes``.

        Control points are converted to local coordinates of ``curr_axes`` and
        then to world coordinates of ``new_axes``; degrees and knots are
        repeated ``new_axes.count`` times to match.
        """
        local_coords = curr_axes.to_local_coords(self.ctrl_pnts.val)
        return BSplineCurves(
            degree=jnp.repeat(self.degree[None, :], new_axes.count, axis=0).flatten(),
            ctrl_pnts=SparseArray(
                val=new_axes.to_world_coords(local_coords).reshape(-1, 3),
                index=SparseIndexable.expanded_index(self.ctrl_pnts.index, new_axes.count),
            ),
            knots=SparseArray(
                val=jnp.repeat(self.knots.val[None, :], new_axes.count, axis=0).flatten(),
                index=SparseIndexable.expanded_index(self.knots.index, new_axes.count),
            ),
        )

    def translate(self, translation: jnp.ndarray):
        """Return a copy with ``translation`` added to every control point."""
        return self.clone(
            ctrl_pnts=SparseArray(self.ctrl_pnts.val + translation, index=self.ctrl_pnts.index),
        )

    def mask(self, mask: Union[jnp.ndarray, np.ndarray, int]):
        """Select a subset of curves.

        :param mask: boolean mask, index array or single int applied to the
            curve numbers ``0 .. count - 1``
        :returns: curves whose sparse control-point / knot indices are among
            the selected curve numbers
        """
        index = np.arange(self.count)[mask]
        ctrl_pnts_mask = jnp.isin(self.ctrl_pnts.index, index)
        knots_mask = jnp.isin(self.knots.index, index)
        return BSplineCurves(
            degree=self.degree[mask],
            ctrl_pnts=self.ctrl_pnts.mask(ctrl_pnts_mask),
            knots=self.knots.mask(knots_mask),
        )

    def reorder(self, index: np.ndarray):
        """Return the curves rearranged according to ``index``."""
        return BSplineCurves(
            degree=self.degree[index],
            ctrl_pnts=self.ctrl_pnts.reorder(index),
            knots=self.knots.reorder(index),
        )

    def flip_winding(self, mask: Optional[jnp.ndarray] = None):
        """Return the curves with reversed winding.

        NOTE(review): ``degree`` is reversed for all curves even when ``mask``
        restricts the flip of control points and knots - confirm this is
        intended.
        """
        return BSplineCurves(
            degree=self.degree[::-1],
            ctrl_pnts=self.ctrl_pnts.flip_winding(mask),
            knots=self.knots.flip_winding(mask),
        )

    def __repr__(self) -> str:
        return f"BSplineCurves(count={self.count})"

    @property
    def count(self):
        """Number of curves, taken from the control-point sparse array."""
        return self.ctrl_pnts.count
179 |
--------------------------------------------------------------------------------
/nebula/render/visualization.py:
--------------------------------------------------------------------------------
1 | import zmq
2 | import pickle
3 | from http.client import REQUEST_TIMEOUT
4 | import jax.numpy as jnp
5 | from typing import Literal, Optional, Union
6 | from dataclasses import dataclass, asdict
7 | from nebula.helpers.vector import VectorHelper
8 | from nebula.render.tesselation import Mesh, Tesselator
9 | from nebula.topology.solids import Solids
10 |
11 |
@dataclass
class BoundingBox:
    # Axis-aligned extent along x, y and z; get_bounding_box() fills these from
    # the minimum / maximum of the vertex coordinates.
    xmin: float
    xmax: float
    ymin: float
    ymax: float
    zmin: float
    zmax: float
20 |
21 |
@dataclass
class PartShape:
    # Plain-list mesh data of one part, as produced by get_part().
    vertices: list[list[float]]  # mesh vertices
    triangles: list[int]  # flattened simplex (triangle) vertex indices
    normals: list[list[float]]  # mesh normals
    edges: list[list[list[float]]]  # edge line segments
28 |
29 |
@dataclass
class Part:
    # One named shape entry of the message sent to the viewer; get_part() uses
    # "Part_<index>" as name and "/Group/Part_<index>" as id.
    name: str
    id: str
    shape: PartShape
    type: str = "shapes"
    color: str = "#e8b024"
    renderback: bool = True
38 |
39 |
@dataclass
class Shapes:
    # Group of parts plus their overall bounding box; show() converts it with
    # asdict() before sending.
    bb: Optional[BoundingBox]
    parts: list[Part]
    name: str = "Group"
    id: str = "/Group"
    # NOTE(review): presumably (position, quaternion) of the group - confirm
    # against the viewer protocol.
    loc: Optional[tuple[list[float], tuple[float, float, float, float]]] = None
47 |
48 |
def get_part(mesh: Mesh, index: int = 0):
    """Wrap a tessellated mesh as a viewer ``Part`` named after ``index``.

    All mesh arrays are converted to plain Python lists; the simplices are
    flattened into one index list.
    """
    part_name = f"Part_{index}"
    shape = PartShape(
        vertices=mesh.vertices.tolist(),
        triangles=mesh.simplices.flatten().tolist(),
        normals=mesh.normals.tolist(),
        edges=mesh.edges.tolist(),
    )
    return Part(name=part_name, id=f"/Group/{part_name}", shape=shape)
60 |
61 |
def get_bounding_box(vertices: jnp.ndarray):
    """Compute the axis-aligned bounding box of ``vertices``.

    Columns 0, 1 and 2 of ``vertices`` are read as x, y and z. A z bound that
    is exactly zero is replaced by a tiny non-zero value (-1e-07 for the
    minimum, 1e-07 for the maximum).
    """
    lower = vertices.min(axis=0)
    upper = vertices.max(axis=0)

    zmin = float(lower[2])
    zmax = float(upper[2])

    return BoundingBox(
        xmin=float(lower[0]),
        xmax=float(upper[0]),
        ymin=float(lower[1]),
        ymax=float(upper[1]),
        zmin=zmin if zmin else -1e-07,
        zmax=zmax if zmax else 1e-07,
    )
71 |
72 |
# Viewer settings sent unchanged as the "config" entry of every message built
# by show().
DEFAULT_CONFIG = {
    "viewer": "",
    "anchor": "right",
    "theme": "light",
    "pinning": False,
    "angular_tolerance": 0.2,
    "deviation": 0.1,
    "edge_accuracy": None,
    "default_color": [232, 176, 36],
    "default_edge_color": "#707070",
    "optimal_bb": False,
    "render_normals": False,
    "render_edges": True,
    "render_mates": False,
    "parallel": False,
    "mate_scale": 1,
    "control": "trackball",
    "up": "Z",
    "axes": False,
    "axes0": False,
    "grid": [False, False, False],
    "ticks": 10,
    "ortho": True,
    "transparent": False,
    "black_edges": False,
    "ambient_intensity": 0.75,
    "direct_intensity": 0.15,
    "reset_camera": True,
    "show_parent": True,
    "show_bbox": False,
    "quaternion": None,
    "target": None,
    "zoom_speed": 1.0,
    "pan_speed": 1.0,
    "rotate_speed": 1.0,
    "collapse": 1,
    "tools": True,
    "timeit": False,
    "js_debug": False,
    "normal_len": 0,
}
# TCP port on localhost that connect() uses for the viewer endpoint.
ZMQ_PORT = 5555
115 |
116 |
def connect(context):
    """Create a REQ socket on ``context`` and connect it to the local viewer port."""
    req_socket = context.socket(zmq.REQ)
    req_socket.connect(f"tcp://localhost:{ZMQ_PORT}")
    return req_socket
122 |
123 |
def send(data):
    """Pickle ``data`` and send it to the viewer, retrying on timeout.

    The message is sent over a fresh REQ socket. If no reply arrives within the
    poll timeout, the socket is discarded and the message is re-sent on a new
    one, up to three attempts in total. The reply is read as JSON and its
    "result" / "msg" fields are printed.

    The socket and the ZeroMQ context are always released before returning,
    also on the success path and when an exception is raised.

    :param data: any picklable object
    """
    context = zmq.Context()
    socket = connect(context)
    try:
        msg = pickle.dumps(data, 4)
        print(" sending ... ", end="")
        socket.send(msg)

        retries_left = 3
        while True:
            # NOTE(review): REQUEST_TIMEOUT is imported from http.client, i.e.
            # the HTTP status 408, so the poll timeout is effectively 408 ms.
            # Consider a dedicated timeout constant.
            if (socket.poll(REQUEST_TIMEOUT) & zmq.POLLIN) != 0:
                reply = socket.recv_json()

                if reply["result"] == "success":
                    print("done")
                else:
                    print("\n", reply["msg"])
                break

            retries_left -= 1

            # Socket is confused. Close and remove it.
            socket.setsockopt(zmq.LINGER, 0)
            socket.close()
            if retries_left == 0:
                break

            print("Reconnecting to server…")
            # Create new connection
            socket = connect(context)

            print("Resending ...")
            socket.send(msg)
    finally:
        # Closing an already closed socket is a no-op; linger=0 drops unsent
        # messages so that terminating the context cannot block.
        socket.close(linger=0)
        context.term()
157 |
158 |
def show(
    item: Union[Solids, Mesh],
    type: Literal["cad", "plot"] = "cad",
    intensity: Optional[jnp.ndarray] = None,
    name: Optional[str] = None,
    file_name: Optional[str] = None,
    **kwargs,
):
    """Display solids or a mesh, either in the CAD viewer or as a plotly figure.

    :param item: ``Solids`` (tessellated first) or an already built ``Mesh``
    :param type: "cad" sends the mesh to the viewer via ``send``; any other
        value renders a plotly figure
    :param intensity: optional per-vertex values, normalized and used as the
        plotly mesh intensity (plot mode only)
    :param name: figure title (plot mode only)
    :param file_name: if given, the figure is also written to this PNG file
        (plot mode only)
    :param kwargs: accepted for compatibility; currently unused
    """
    mesh = Tesselator.get_mesh(item) if isinstance(item, Solids) else item

    if type == "cad":
        parts: list[Part] = [get_part(mesh)]

        shapes = Shapes(
            bb=get_bounding_box(mesh.vertices),
            parts=parts,
        )

        states = {"/Group/Part_0": [1, 1]}
        data = {
            "data": dict(shapes=asdict(shapes), states=states),
            "type": "data",
            "config": DEFAULT_CONFIG,
            "count": len(parts),
        }
        send(data)

    else:
        import plotly.graph_objects as go

        # Convert to plain Python lists once; indexing the arrays element by
        # element inside the loops is needlessly slow.
        triangles = mesh.vertices[mesh.simplices].tolist()

        # Wireframe: each triangle is traced as a closed loop, with ``None``
        # separating consecutive triangles.
        Xe = []
        Ye = []
        Ze = []
        for tri in triangles:
            closed_loop = tri + tri[:1]
            Xe.extend([pnt[0] for pnt in closed_loop] + [None])
            Ye.extend([pnt[1] for pnt in closed_loop] + [None])
            Ze.extend([pnt[2] for pnt in closed_loop] + [None])

        # Create a mesh object
        if intensity is not None:
            intensity = VectorHelper.normalize(intensity)
        plot_mesh = go.Mesh3d(
            x=mesh.vertices[:, 0],
            y=mesh.vertices[:, 1],
            z=mesh.vertices[:, 2],
            i=mesh.simplices[:, 0].tolist(),
            j=mesh.simplices[:, 1].tolist(),
            k=mesh.simplices[:, 2].tolist(),
            intensity=intensity,
        )

        lines = go.Scatter3d(
            x=Xe,
            y=Ye,
            z=Ze,
            mode="lines",
            line=dict(color="rgb(70,70,70)", width=1),
        )

        # Create a figure and add the mesh to it
        fig = go.Figure(data=[plot_mesh, lines])

        # Update layout to remove axes and background and make the graph larger
        fig.update_layout(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                bgcolor="rgba(0,0,0,0)",
            ),
            width=800,  # Set the width of the graph to 800 pixels
            height=600,  # Set the height of the graph to 600 pixels
            title=name,
        )
        if file_name is not None:
            fig.write_image(file_name, scale=6, format="png", engine="kaleido")
        # Display the figure
        fig.show()
238 |
--------------------------------------------------------------------------------
/nebula/cases/airfoil.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from functools import cached_property
3 | from typing import Literal, Optional
4 | import plotly.graph_objects as go
5 | from nebula.evaluators.bspline import BSplineEvaluator, get_sampling
6 | import jax.numpy as jnp
7 | from nebula.tools.edge import EdgeTool
8 |
9 |
def get_thickness_dist_ctrl_pnts(
    camber: jnp.ndarray,
    camber_normal: jnp.ndarray,
    thickness_dist: jnp.ndarray,
    thickness_sampling: jnp.ndarray,
    degree: jnp.ndarray,
):
    """Offset the camber points along their normals by the evaluated thickness.

    The thickness distribution is evaluated as a B-spline at
    ``thickness_sampling``; the offset points are framed by the first and the
    last camber point.
    """
    thickness = BSplineEvaluator.eval_curve(degree, thickness_dist, thickness_sampling)
    offset_pnts = camber + camber_normal * thickness

    leading_pnt = jnp.array([camber[0]])
    trailing_pnt = jnp.array([camber[-1]])
    return jnp.concatenate([leading_pnt, offset_pnts, trailing_pnt])
29 |
30 |
@dataclass
class CamberThicknessAirfoil:
    "parametric airfoil using B-splines"

    inlet_angle: jnp.ndarray
    "inlet angle (rad, or deg when angle_units is 'deg')"

    outlet_angle: jnp.ndarray
    "outlet angle (rad, or deg when angle_units is 'deg')"

    upper_thick_prop: jnp.ndarray
    "upper thickness proportion to chord length (length)"

    lower_thick_prop: jnp.ndarray
    "lower thickness proportion to chord length (length)"

    leading_prop: jnp.ndarray
    "leading edge tangent line proportion [0.0-1.0] (dimensionless)"

    trailing_prop: jnp.ndarray
    "trailing edge tangent line proportion [0.0-1.0] (dimensionless)"

    chord_length: jnp.ndarray
    "chord length (length)"

    stagger_angle: Optional[jnp.ndarray] = None
    "stagger angle (rad, or deg when angle_units is 'deg'); defaults to the mean of inlet and outlet angle"

    num_samples: int = 50
    "number of samples"

    is_cosine_sampling: bool = True
    "use cosine sampling"

    leading_ctrl_pnt: jnp.ndarray = field(
        default_factory=lambda: jnp.array([0.0, 0.0, 0.0])
    )
    "leading control point (length)"

    angle_units: Literal["rad", "deg"] = "rad"
    "angle units"

    def __post_init__(self):
        "convert angles to radians and derive thickness distributions, knots and samplings"
        if self.angle_units == "deg":
            self.inlet_angle = jnp.radians(self.inlet_angle)
            self.outlet_angle = jnp.radians(self.outlet_angle)
            # A user-supplied stagger angle is given in the same units as the
            # other angles, so it has to be converted as well.
            if self.stagger_angle is not None:
                self.stagger_angle = jnp.radians(self.stagger_angle)

        if self.upper_thick_prop is not None:
            self.upper_thick_dist = [
                self.chord_length * prop for prop in self.upper_thick_prop
            ]

        if self.lower_thick_prop is not None:
            self.lower_thick_dist = [
                self.chord_length * prop for prop in self.lower_thick_prop
            ]

        if self.stagger_angle is None:
            self.stagger_angle = (self.inlet_angle + self.outlet_angle) / 2

        self.degree = jnp.array(3)
        self.num_thickness_dist_pnts = len(self.upper_thick_dist) + 4
        self.camber_knots = BSplineEvaluator.generate_clamped_knots(self.degree, 4)

        self.thickness_dist_sampling = jnp.linspace(
            0, 1, self.num_thickness_dist_pnts, endpoint=True
        )
        self.sampling = get_sampling(0.0, 1.0, self.num_samples, self.is_cosine_sampling)
        self.axial_chord_length = self.chord_length * jnp.cos(self.stagger_angle)
        self.height = self.chord_length * jnp.sin(self.stagger_angle)

    @cached_property
    def camber_ctrl_pnts(self):
        "four camber control points: leading edge, two tangent points, trailing edge"
        assert self.stagger_angle is not None, "stagger angle is not defined"
        p_le = jnp.array(self.leading_ctrl_pnt)

        p_te = p_le + jnp.array(
            [
                self.chord_length * jnp.cos(self.stagger_angle),
                self.chord_length * jnp.sin(self.stagger_angle),
                0.0,
            ]
        )

        # leading edge tangent control point
        p1 = p_le + self.leading_prop * jnp.array(
            [
                self.chord_length * jnp.cos(self.inlet_angle),
                self.chord_length * jnp.sin(self.inlet_angle),
                0.0,
            ]
        )

        # trailing edge tangent control point
        p2 = p_te - self.trailing_prop * jnp.array(
            [
                self.chord_length * jnp.cos(self.outlet_angle),
                self.chord_length * jnp.sin(self.outlet_angle),
                0.0,
            ]
        )

        return jnp.vstack((p_le, p1, p2, p_te))

    @cached_property
    def top_ctrl_pnts(self):
        "upper side bspline"
        assert (
            self.upper_thick_dist is not None
        ), "upper thickness distribution is not defined"
        thickness_dist = jnp.vstack(self.upper_thick_dist)
        return get_thickness_dist_ctrl_pnts(
            self.camber_coords,
            self.camber_normal_coords,
            thickness_dist,
            self.thickness_dist_sampling,
            self.degree,
        )

    @cached_property
    def bottom_ctrl_pnts(self):
        "lower side bspline"
        assert (
            self.lower_thick_dist is not None
        ), "lower thickness distribution is not defined"
        # negated so that the offset points away from the upper side
        thickness_dist = -jnp.vstack(self.lower_thick_dist)
        return get_thickness_dist_ctrl_pnts(
            self.camber_coords,
            self.camber_normal_coords,
            thickness_dist,
            self.thickness_dist_sampling,
            self.degree,
        )

    @cached_property
    def camber_coords(self):
        "camber line coordinates"
        return BSplineEvaluator.eval_curve(
            self.degree, self.camber_ctrl_pnts, self.thickness_dist_sampling
        )

    @cached_property
    def camber_normal_coords(self):
        "camber normal line coordinates"
        dy = jnp.gradient(self.camber_coords[:, 1])
        dx = jnp.gradient(self.camber_coords[:, 0])
        # rotate the tangent (dx, dy) by 90 degrees in the XY plane
        normal = jnp.vstack([-dy, dx, jnp.zeros(len(self.camber_coords))]).T
        return normal / jnp.linalg.norm(normal, axis=1)[:, None]

    def get_coords(self):
        "airfoil coordinates"
        top_coords = BSplineEvaluator.eval_curve(
            self.degree, self.top_ctrl_pnts, self.sampling
        )
        bottom_coords = BSplineEvaluator.eval_curve(
            self.degree, self.bottom_ctrl_pnts, self.sampling
        )

        # drop the end points of the top curve; the reversed bottom curve
        # already contributes them
        return jnp.concatenate([top_coords[1:-1], bottom_coords[::-1]])

    def get_edges(self):
        "top and bottom airfoil edges as bspline curves"
        airfoil_top_edge = EdgeTool.make_bspline_curve(self.top_ctrl_pnts, self.degree)
        airfoil_bottom_edge = EdgeTool.make_bspline_curve(
            self.bottom_ctrl_pnts, self.degree
        )
        return airfoil_top_edge, airfoil_bottom_edge

    def visualize(
        self,
        include_camber=True,
        include_camber_ctrl_pnts=False,
        filename: Optional[str] = None,
    ):
        "plot the airfoil; written to filename when given, shown otherwise"
        fig = go.Figure(layout=go.Layout(title=go.layout.Title(text="Airfoil")))
        if include_camber_ctrl_pnts:
            fig.add_trace(
                go.Scatter(
                    x=self.camber_ctrl_pnts[:, 0],
                    y=self.camber_ctrl_pnts[:, 1],
                    name="Camber Control Points",
                )
            )

        if include_camber:
            camber_coords = self.camber_coords
            fig.add_trace(
                go.Scatter(x=camber_coords[:, 0], y=camber_coords[:, 1], name="Camber")
            )

        coords = self.get_coords()
        fig.add_trace(go.Scatter(x=coords[:, 0], y=coords[:, 1], name="Airfoil"))

        fig.layout.yaxis.scaleanchor = "x"  # type: ignore
        if filename:
            fig.write_image(filename, width=500, height=500)
        else:
            fig.show()
228 |
229 |
230 |
--------------------------------------------------------------------------------
/nebula/evaluators/bspline.py:
--------------------------------------------------------------------------------
1 |
2 | from nebula.helpers.types import Number
3 | import jax
4 | import jax.numpy as jnp
5 | from typing import Callable, Optional
6 |
7 |
# Signature of a span-finding routine such as BSplineEvaluator.find_spans:
# (degree, knot_vector, num_ctrlpts, knot_samples) -> spans
SpanFunction = Callable[[int, jnp.ndarray, int, jnp.ndarray], jnp.ndarray]
9 |
10 |
def get_degree(degree: Number):
    """Return ``degree`` as a scalar array.

    Scalar input is converted as is; for input with one or more dimensions the
    first entry is returned.
    """
    degree_arr = jnp.asarray(degree)
    if degree_arr.ndim > 0:
        return degree_arr[0]
    return degree_arr
14 |
def get_sampling(start: Number, end: Number, num_points: int, is_cosine_sampling: bool = False):
    """Return ``num_points`` sample parameters from ``start`` to ``end`` inclusive.

    :param start: first sample value
    :param end: last sample value
    :param num_points: number of samples
    :param is_cosine_sampling: if True the samples are cosine spaced (denser
        towards both ends), otherwise they are uniformly spaced
    :return: array of ``num_points`` samples
    """
    if is_cosine_sampling:
        beta = jnp.linspace(0.0, jnp.pi, num_points, endpoint=True)
        # Cosine spacing on [0, 1], rescaled to [start, end]. Previously start
        # and end were ignored in this branch.
        unit_samples = 0.5 * (1.0 - jnp.cos(beta))
        return start + (end - start) * unit_samples
    return jnp.linspace(start, end, num_points, endpoint=True)
20 |
21 |
class BSplineEvaluator:
    """Static helpers to build knot vectors and evaluate B-spline curves and surfaces."""

    @staticmethod
    def find_spans(
        degree: jnp.ndarray,
        knot_vector: jnp.ndarray,
        num_ctrlpts: int,
        knot_samples: jnp.ndarray,
    ):
        """Finds the knot spans of several parameters over the knot vector using linear search.

        Alternative implementation for the Algorithm A2.1 from The NURBS Book by Piegl & Tiller.

        :param degree: degree, :math:`p`
        :type degree: jnp.ndarray, (1,)
        :param knot_vector: knot vector, :math:`U`
        :type knot_vector: jnp.ndarray
        :param num_ctrlpts: number of control points, :math:`n + 1`
        :type num_ctrlpts: int
        :param knot_samples: parameters, :math:`u`, to locate
        :type knot_samples: jnp.ndarray
        :return: knot span index for every sample
        :rtype: jnp.ndarray
        """
        span_start = degree + 1
        # Count the knots after the first ``degree + 1`` that lie strictly
        # below each sample.
        span_offset = jnp.sum(
            jnp.expand_dims(knot_samples, axis=-1) > knot_vector[span_start:], axis=-1
        )
        # Upper bound only; jnp.minimum avoids the ``a_max`` keyword of
        # jnp.clip, which newer JAX releases no longer accept.
        span = jnp.minimum(span_start + span_offset, num_ctrlpts)
        return span - 1

    @staticmethod
    def basis_functions(
        degree: Number,
        knot_vector: jnp.ndarray,
        span: jnp.ndarray,
        knot_samples: jnp.ndarray,
    ):
        """Computes the non-vanishing basis functions for several parameters.

        Implementation of Algorithm A2.2 pg 70 from The NURBS Book by Piegl & Tiller.
        Uses recurrence to compute the basis functions, also known as Cox - de
        Boor recursion formula.

        :param degree: degree, :math:`p`
        :type degree: jnp.ndarray (1,)
        :param knot_vector: knot vector, :math:`U`
        :type knot_vector: jnp.ndarray
        :param span: knot span, :math:`i`, of every sample
        :type span: jnp.ndarray
        :param knot_samples: parameters, :math:`u`
        :type knot_samples: jnp.ndarray
        :return: basis functions, shape ``(degree + 1, len(knot_samples))``
        :rtype: jnp.ndarray
        """
        N = jnp.ones((degree + 1, len(knot_samples)))
        # left[j] = u - U[span + 1 - j] and right[j] = U[span + j] - u,
        # computed for all samples at once (one column per sample).
        left = jnp.expand_dims(knot_samples, axis=0) - knot_vector[span + 1 - jnp.arange(degree + 1)[:, None]]
        right = knot_vector[span + jnp.arange(degree + 1)[:, None]] - jnp.expand_dims(knot_samples, axis=0)

        def inner_body_fun(r, init_value):
            j, saved, new_N = init_value
            temp = new_N[r] / (right[r + 1] + left[j - r])
            next_N = new_N.at[r].set(saved + right[r + 1] * temp)
            saved = left[j - r] * temp
            return j, saved, next_N

        def outer_body_fun(j, N: jnp.ndarray):
            saved = jnp.zeros(len(knot_samples))
            _, saved, N = jax.lax.fori_loop(0, j, inner_body_fun, (j, saved, N))
            return N.at[j].set(saved)

        return jax.lax.fori_loop(1, degree + 1, outer_body_fun, N)

    @staticmethod
    def generate_line_knots():
        """Return the clamped knot vector of a degree-1 curve with two control points."""
        return BSplineEvaluator.generate_clamped_knots(1, 2)

    @staticmethod
    def generate_clamped_knots(degree: Number, num_ctrlpts: Number):
        """Generates a clamped knot vector.

        :param degree: non-zero degree of the curve
        :type degree: int
        :param num_ctrlpts: non-zero number of control points
        :type num_ctrlpts: int
        :return: clamped knot vector
        :rtype: Array
        """
        # Number of repetitions at the start and end of the array
        num_repeat = degree
        # Number of knots in the middle
        num_segments = int(num_ctrlpts - (degree + 1))

        return jnp.concatenate(
            (
                jnp.zeros(num_repeat),
                jnp.linspace(0.0, 1.0, num_segments + 2),
                jnp.ones(num_repeat),
            )
        )

    @staticmethod
    def generate_unclamped_knots(degree: int, num_ctrlpts: int):
        """Generates an unclamped (uniformly spaced) knot vector.

        :param degree: non-zero degree of the curve
        :type degree: int
        :param num_ctrlpts: non-zero number of control points
        :type num_ctrlpts: int
        :return: unclamped knot vector
        :rtype: Array
        """
        # Should conform the rule: m = n + p + 1
        return jnp.linspace(0.0, 1.0, degree + num_ctrlpts + 1)

    @staticmethod
    def eval_curve_pnt(
        degree: jnp.ndarray, ctrl_pnts: jnp.ndarray, basis: jnp.ndarray, span: jnp.ndarray
    ):
        """Evaluate one curve point from its span and basis function values.

        :param degree: curve degree
        :param ctrl_pnts: control points, one row per point
        :param basis: the ``degree + 1`` basis function values of the sample
        :param span: knot span of the sample
        :return: the curve point
        :raises ValueError: if there are fewer than ``degree + 1`` control points
        """
        dim = ctrl_pnts.shape[-1]
        if len(ctrl_pnts) < degree + 1:
            raise ValueError("Invalid size of control points for the given degree.")

        # Only the degree + 1 control points ending at ``span`` contribute.
        ctrl_pnt_slice = jax.lax.dynamic_slice(
            ctrl_pnts, (span - degree, 0), (1 + degree, dim)
        )
        return jnp.sum(ctrl_pnt_slice * jnp.expand_dims(basis, axis=1), axis=0)

    @staticmethod
    def eval_curve(
        degree: jnp.ndarray, ctrl_pnts: jnp.ndarray, u: jnp.ndarray, knots: Optional[jnp.ndarray] = None
    ):
        """Evaluate a B-spline curve at the parameters ``u``.

        :param degree: curve degree (scalar, or array whose first entry is used)
        :param ctrl_pnts: control points, one row per point
        :param u: parameters to evaluate at
        :param knots: knot vector; a clamped knot vector is generated when None
        :return: one curve point per parameter
        """
        degree = get_degree(degree)
        knots = (
            BSplineEvaluator.generate_clamped_knots(degree, len(ctrl_pnts))
            if knots is None
            else knots
        )

        span = BSplineEvaluator.find_spans(degree, knots, len(ctrl_pnts), u)
        basis = BSplineEvaluator.basis_functions(degree, knots, span, u)

        # basis holds one column per sample, span one entry per sample
        return jax.vmap(BSplineEvaluator.eval_curve_pnt, in_axes=(None, None, 1, 0))(
            degree, ctrl_pnts, basis, span
        )

    @staticmethod
    def eval_surface_pnt(
        u_degree: jnp.ndarray,
        v_degree: jnp.ndarray,
        ctrl_pnts: jnp.ndarray,
        basis_u: jnp.ndarray,
        basis_v: jnp.ndarray,
        span_u: jnp.ndarray,
        span_v: jnp.ndarray,
    ):
        """Evaluate one surface point from its u/v spans and basis function values.

        ``ctrl_pnts`` is indexed as ``[u, v, coordinate]``.
        """
        # Use the actual point dimension, consistent with eval_curve_pnt.
        dim = ctrl_pnts.shape[-1]
        ctrl_pnt_slice = jax.lax.dynamic_slice(
            ctrl_pnts,
            (span_u - u_degree, span_v - v_degree, 0),
            (1 + u_degree, 1 + v_degree, dim),
        )

        # Contract the v direction first, then the u direction.
        eval_prev_pnt = jnp.sum(ctrl_pnt_slice * jnp.expand_dims(basis_v, axis=1), axis=1)
        return jnp.sum(eval_prev_pnt * jnp.expand_dims(basis_u, axis=1), axis=0)

    @staticmethod
    def eval_surface(
        u_degree: Number,
        v_degree: Number,
        ctrl_pnts: jnp.ndarray,
        u_knots: jnp.ndarray,
        v_knots: jnp.ndarray,
        u: jnp.ndarray,
        v: jnp.ndarray,
    ):
        """Evaluate a B-spline surface on the grid of parameters ``u`` x ``v``.

        :param u_degree: degree in u direction
        :param v_degree: degree in v direction
        :param ctrl_pnts: control point grid indexed as ``[u, v, coordinate]``
        :param u_knots: knot vector in u direction
        :param v_knots: knot vector in v direction
        :param u: parameters in u direction
        :param v: parameters in v direction
        :return: surface points indexed as ``[u sample, v sample, coordinate]``
        """
        u_degree = get_degree(u_degree)
        v_degree = get_degree(v_degree)

        assert (
            ctrl_pnts.shape[0] >= u_degree + 1
        ), f"Number of curves should be at least {u_degree + 1}"
        assert (
            ctrl_pnts.shape[1] >= v_degree + 1
        ), f"Number of control points should be at least {v_degree + 1}"

        span_u = BSplineEvaluator.find_spans(
            u_degree, u_knots, ctrl_pnts.shape[0], u
        )
        basis_u = BSplineEvaluator.basis_functions(u_degree, u_knots, span_u, u)

        span_v = BSplineEvaluator.find_spans(
            v_degree, v_knots, ctrl_pnts.shape[1], v
        )
        basis_v = BSplineEvaluator.basis_functions(v_degree, v_knots, span_v, v)

        # Inner vmap runs over the v samples, outer vmap over the u samples.
        return jax.vmap(
            jax.vmap(
                BSplineEvaluator.eval_surface_pnt,
                in_axes=(None, None, None, None, 1, None, 0),
            ),
            in_axes=(None, None, None, 1, None, 0, None),
        )(u_degree, v_degree, ctrl_pnts, basis_u, basis_v, span_u, span_v)
225 |
226 |
--------------------------------------------------------------------------------
/nebula/render/tesselation.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Sequence, Union
3 | import jax
4 | import numpy as np
5 | from scipy.spatial import Delaunay
6 | import jax.numpy as jnp
7 | from nebula.evaluators.bspline import BSplineEvaluator, get_sampling
8 | from nebula.prim.bspline_curves import BSplineCurves
9 | from nebula.helpers.wire import WireHelper
10 | from nebula.topology.edges import DEFAULT_CURVE_SAMPLES
11 | from nebula.topology.faces import Faces
12 | from nebula.topology.solids import Solids
13 | from nebula.topology.wires import Wires
14 | from nebula.workplane import Workplane
15 |
16 |
@dataclass
class Mesh:
    """Triangle mesh with vertices, simplices, normals and edge segments."""

    vertices: jnp.ndarray
    simplices: jnp.ndarray
    normals: jnp.ndarray
    edges: jnp.ndarray

    @staticmethod
    def empty():
        """Return a mesh without vertices, simplices, normals or edges."""
        return Mesh(
            vertices=jnp.empty((0, 3)),
            simplices=jnp.empty((0, 3), dtype=jnp.int32),
            normals=jnp.empty((0, 3)),
            edges=jnp.empty((0, 2, 3)),
        )

    @staticmethod
    def combine(meshes: Sequence["Mesh"]):
        """Concatenate several meshes into one.

        The simplices of every mesh are shifted by the total number of vertices
        of all meshes before it, so they keep pointing at their own vertices in
        the concatenated vertex array.

        :param meshes: meshes to combine, in order
        :return: the combined mesh; an empty mesh for an empty sequence
        """
        if not meshes:
            return Mesh.empty()

        shifted_simplices = []
        vertex_offset = 0
        for mesh in meshes:
            shifted_simplices.append(mesh.simplices + vertex_offset)
            # cumulative offset, not just the size of the previous mesh
            vertex_offset += len(mesh.vertices)

        vertices = jnp.concatenate([mesh.vertices for mesh in meshes])
        simplices = jnp.concatenate(shifted_simplices)
        normals = jnp.concatenate([mesh.normals for mesh in meshes])
        edges = jnp.concatenate([mesh.edges for mesh in meshes])
        return Mesh(vertices, simplices, normals, edges)
46 |
47 |
48 | class Tesselator:
@staticmethod
def get_mesh(solids: Solids):
    """Tessellate the faces of ``solids`` into a single mesh.

    Faces are split by the planarity flag of their wires; planar and
    non-planar faces are meshed separately and then combined.
    """
    faces = solids.shells.faces
    meshes = [
        Tesselator.get_planar_mesh(faces.mask(faces.wires.is_planar)),
        Tesselator.get_non_planar_mesh(faces.mask(~faces.wires.is_planar)),
    ]
    return Mesh.combine(meshes)
57 |
@staticmethod
def get_differentiable_mesh(target: Union[Workplane, Solids]):
    """Tessellate a workplane's base solids, or the given solids, into one mesh.

    Planar faces go through ``get_differentiable_planar_mesh``; non-planar
    faces go through ``get_non_planar_mesh``.
    """
    if isinstance(target, Workplane):
        solids = target.base_solids
    else:
        solids = target

    faces = solids.shells.faces
    meshes = [
        Tesselator.get_differentiable_planar_mesh(faces.mask(faces.wires.is_planar)),
        Tesselator.get_non_planar_mesh(faces.mask(~faces.wires.is_planar)),
    ]
    return Mesh.combine(meshes)
69 |
@staticmethod
def get_trimmed_simplices(
    vertices: jnp.ndarray,
    simplices: jnp.ndarray,
    exterior_segments: jnp.ndarray,
    interior_segments: jnp.ndarray,
):
    """Keep the simplices whose centers lie inside the exterior wire but not
    inside an interior wire.

    Containment of each triangle center is tested with
    ``WireHelper.contains_vertex``.
    """
    centers = jnp.mean(vertices[simplices], axis=1)

    # Same segments for every center; vectorized over the centers.
    contains_center = jax.vmap(WireHelper.contains_vertex, in_axes=(None, 0))
    inside_exterior = contains_center(exterior_segments, centers)
    inside_interior = contains_center(interior_segments, centers)

    keep = inside_exterior & ~inside_interior
    return simplices[keep]
86 |
87 | @staticmethod
88 | def get_tri_normal(tri_vertices: Union[jnp.ndarray, np.ndarray]) -> jnp.ndarray:
89 | # Calculate two vectors along the edges of the triangle
90 | normal = jnp.cross(
91 | tri_vertices[1] - tri_vertices[0], tri_vertices[2] - tri_vertices[0]
92 | )
93 | unit_normal = normal / jnp.linalg.norm(normal, axis=0, keepdims=True)
94 | return unit_normal
95 |
@staticmethod
def get_differentiable_planar_mesh(faces: Faces, curve_samples: int = 50):
    """Mesh planar quad faces by converting their wires to B-spline faces.

    Faces for which ``get_num_edges() - get_num_curves()`` equals 4 are
    selected, their wires are converted to B-splines and the resulting faces
    are meshed with ``get_non_planar_mesh``.
    NOTE(review): faces that do not match the quad criterion are not part of
    the returned mesh.

    :param faces: planar faces to mesh
    :param curve_samples: number of samples per edge curve
    :return: mesh of the converted quad faces
    """
    is_quad_face = ((faces.get_num_edges() - faces.get_num_curves()) == 4).repeat(4)
    quad_faces = faces.mask(is_quad_face)
    bspline_faces = Faces.from_wires(
        Wires.convert_to_bspline(quad_faces.wires), is_exterior=True
    )
    # Pass by keyword: the second positional parameter of get_non_planar_mesh
    # is num_u, not curve_samples.
    bspline_face_mesh = Tesselator.get_non_planar_mesh(
        bspline_faces, curve_samples=curve_samples
    )
    return bspline_face_mesh
105 |
@staticmethod
def get_planar_mesh(faces: Faces, curve_samples: int = 50):
    """Triangulate planar faces with a Delaunay triangulation.

    Each face is sampled along its wire edges, triangulated in the local 2D
    coordinates of its plane, oriented like the plane normal and trimmed to
    the triangles inside the exterior wire and outside interior wires.

    :param faces: planar faces to mesh
    :param curve_samples: number of samples per edge curve
    :return: mesh of all faces, with one normal per vertex
    """
    vertices = jnp.empty((0, 3))
    simplices = jnp.empty((0, 3), dtype=jnp.int32)
    normals = jnp.empty((0, 3))

    # TODO: replace this with jax implementation of Delaunay triangulation
    for i in range(faces.count):
        face = faces.mask(faces.index == i)
        assert face.wires.is_planar.all() == True, "Only planar faces supported"

        axis = face.get_axes()
        axis_normal = axis.normals[0]

        # triangulate face
        segments = face.wires.edges.evaluate(curve_samples, is_cosine_sampling=True)
        local_segments = axis.to_local_coords(segments)[0]

        # Triangulate the segment start points in the plane's 2D coordinates.
        delaunay = Delaunay(local_segments[:, 0, 0:2])

        tri_local_vertices, tri_simplices = (
            jnp.array(delaunay.points),
            jnp.array(delaunay.simplices),
        )
        tri_vertices = axis.to_world_coords(tri_local_vertices)[0]

        # Flip triangle winding when it opposes the plane normal. The sign of
        # the dot product is used because exact float equality of the two
        # normals is not reliable.
        tri_normal = Tesselator.get_tri_normal(
            tri_vertices[tri_simplices[0]],
        )
        if jnp.dot(tri_normal, axis_normal) < 0:
            tri_simplices = jnp.flip(tri_simplices, axis=1)

        # Trim triangles that are not interior
        is_interior = face.wires.get_segment_is_interior(curve_samples)
        trimmed_simplices = Tesselator.get_trimmed_simplices(
            tri_local_vertices,
            tri_simplices,
            local_segments[~is_interior],
            local_segments[is_interior],
        )
        # One normal per vertex of THIS face (not per accumulated vertex).
        new_normals = axis.normals.repeat(len(tri_vertices), axis=0)

        # Offset by the vertices accumulated so far, before appending.
        simplices = jnp.concatenate([simplices, trimmed_simplices + len(vertices)])
        vertices = jnp.concatenate([vertices, tri_vertices])

        normals = jnp.concatenate([normals, new_normals])

    return Mesh(
        vertices,
        simplices,
        normals,
        faces.edges.evaluate(curve_samples),
    )
160 |
161 | @staticmethod
162 | @jax.jit
163 | def get_bspline_tri_index(i: jnp.ndarray, j: jnp.ndarray, num_v: jnp.ndarray):
164 | index = jnp.array(
165 | [
166 | j + (i * num_v),
167 | j + ((i + 1) * num_v),
168 | j + 1 + ((i + 1) * num_v),
169 | j + 1 + (i * num_v),
170 | ],
171 | dtype=jnp.int32,
172 | )
173 | # + curr_index
174 |
175 | return jax.vmap(
176 | lambda tri_index, i: jnp.array(
177 | # [tri_index[i + 1], tri_index[i], tri_index[0]],
178 | [tri_index[0], tri_index[i], tri_index[i + 1]],
179 | dtype=jnp.int32,
180 | ),
181 | in_axes=(None, 0),
182 | )(index, jnp.arange(1, 3))
183 |
@staticmethod
def get_non_planar_mesh(
    faces: Faces,
    num_u: int = 20,
    num_v: int = 20,
    curve_samples: int = DEFAULT_CURVE_SAMPLES,
):
    """Tessellate B-spline surface faces on a regular ``num_u`` x ``num_v`` grid.

    :param faces: Faces whose ``bspline_surfaces`` are evaluated
    :param num_u: Number of surface samples in the u direction
    :param num_v: Number of surface samples in the v direction
    :param curve_samples: Number of samples used for the boundary curves
    :return: Combined mesh of all faces (empty mesh when there are no wires)
    :raises AssertionError: If every wire of ``faces`` is planar
    """
    mesh = Mesh.empty()
    if len(faces.wires.is_planar) == 0:
        return mesh

    # The sampling grid and the triangle connectivity do not depend on the
    # face, so they are built once instead of once per face.
    u = get_sampling(0.0, 1.0, num_u)
    v = get_sampling(0.0, 1.0, num_v)
    u_curve = get_sampling(0.0, 1.0, curve_samples)
    cell_rows, cell_cols = jnp.arange(num_u - 1), jnp.arange(num_v - 1)
    simplices = jax.vmap(
        jax.vmap(
            Tesselator.get_bspline_tri_index,
            in_axes=(None, 0, None),
        ),
        in_axes=(0, None, None),
    )(cell_rows, cell_cols, jnp.array(num_v)).reshape(-1, 3)

    # TODO: get rid of for loop here
    for face_index in range(faces.count):
        face = faces.mask(faces.index == face_index)
        # NOTE(review): this checks the whole collection and only rejects the
        # case where *all* wires are planar - confirm whether a per-face
        # "no planar wire" check was intended.
        assert not faces.wires.is_planar.all(), "Only non-planar faces supported"

        vertices = face.bspline_surfaces.evaluate(u, v)

        # Only this face's curves: using the curves of all faces here would
        # duplicate every curve once per face in the combined mesh.
        line_segments = face.wires.edges.curves.evaluate(u_curve)

        # TODO: fix these normals from B-spline derivatives later
        normals = jax.vmap(Tesselator.get_tri_normal, in_axes=(0,))(
            vertices[simplices]
        )

        new_mesh = Mesh(vertices, simplices, normals, line_segments)
        mesh = Mesh.combine([mesh, new_mesh])
    return mesh
224 |
--------------------------------------------------------------------------------
/nebula/workplane.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | from typing import Optional, Union
3 | from nebula.helpers.types import ArrayLike, CoordLike, Number
4 | from nebula.helpers.wire import WireHelper
5 | from nebula.tools.solid import SolidTool
6 | from nebula.tools.edge import EdgeTool
7 | from nebula.prim.axes import Axes, AxesString
8 | from nebula.topology.edges import Edges
9 | from nebula.topology.solids import Solids
10 | from nebula.topology.wires import Wires
11 | import jax_dataclasses as jdc
12 |
13 |
@jdc.pytree_dataclass
class Workplane:
    """Sketching state used to draw edges and wires and turn them into solids.

    Drawing methods do not modify the instance they are called on; they return
    a new ``Workplane`` (see :meth:`clone`), so calls can be chained.
    """

    axes: Axes
    "The axes of the workplane"

    base_solids: Solids = jdc.field(default_factory=Solids.empty)
    "The base solids of the workplane"

    pending_edges: Edges = jdc.field(default_factory=Edges.empty)
    "The pending edges of the workplane"

    pending_wires: Wires = jdc.field(default_factory=Wires.empty)
    "The pending wires of the workplane"

    trace_vertex: jnp.ndarray = jdc.field(
        default_factory=lambda: jnp.array([0, 0, 0], dtype=jnp.float32)
    )
    "The trace vertex of the workplane"

    @staticmethod
    def init(
        axes: Union[AxesString, Axes] = "XY",
    ):
        """Initialize the workplane.

        :param axes: Either an ``Axes`` instance or an axes string that is
            converted with ``Axes.from_str``
        :return: A new workplane with default (empty) state
        """
        axes = axes if isinstance(axes, Axes) else Axes.from_str(axes)
        return Workplane(axes)

    def clone(
        self,
        axes: Optional[Axes] = None,
        base_solids: Optional[Solids] = None,
        pending_edges: Optional[Edges] = None,
        pending_wires: Optional[Wires] = None,
        trace_vertex: Optional[jnp.ndarray] = None,
    ):
        """Clone the workplane, replacing only the fields that are given.

        Arguments left as ``None`` keep the value of this workplane.
        """
        return Workplane(
            axes if axes is not None else self.axes,
            base_solids if base_solids is not None else self.base_solids,
            pending_edges if pending_edges is not None else self.pending_edges,
            pending_wires if pending_wires is not None else self.pending_wires,
            trace_vertex if trace_vertex is not None else self.trace_vertex,
        )

    def plot(self):
        "Plot the wires of all faces of the base solids"
        self.base_solids.shells.faces.wires.plot()

    @property
    def axis_locked(self):
        "True when at least one pending wire exists"
        return self.pending_wires.count > 0

    def consolidateWires(self, validate: bool = False):
        """Consolidate pending edges and pending wires via ``EdgeTool``.

        The pending edges are cleared and the trace vertex is reset to the
        origin in the returned workplane.

        :param validate: Validate the edges for order
        """
        pending_wires = EdgeTool.consolidate_wires(
            self.pending_wires, self.pending_edges, validate
        )
        return self.clone(
            pending_edges=Edges.empty(),
            pending_wires=pending_wires,
            # Same dtype as the field default so the pytree stays consistent
            trace_vertex=jnp.zeros(3, dtype=jnp.float32),
        )

    def moveTo(self, x: Number = 0.0, y: Number = 0.0):
        "Move to the specified point, without drawing"
        # The float literal keeps the vertex floating point for integer input
        trace_vertex = jnp.array([x, y, 0.0])
        return self.clone(trace_vertex=trace_vertex)

    def move(self, xDist: Number = 0.0, yDist: Number = 0.0):
        "Move the specified distance from the current point, without drawing"
        trace_vertex = self.trace_vertex + jnp.array([xDist, yDist, 0.0])
        return self.clone(trace_vertex=trace_vertex)

    def polarLineTo(self, distance: Number, angle: Number, forConstruction: bool = False):
        """
        Make a line from the current point using polar coordinates

        The edge is built by ``EdgeTool.make_polar_line`` from the current
        trace vertex, ``distance`` and ``angle``.

        :param distance: distance argument of the polar line
        :param angle: angle of the line with the x-axis
        :param forConstruction: when True the edge is not added to the pending edges
        :return: the Workplane object with the current point at the end of the new line
        """
        edge = EdgeTool.make_polar_line(self.trace_vertex, distance, angle)

        end_vertex = edge.lines.vertices[-1]
        new_workplane = self.moveTo(end_vertex[0], end_vertex[1])
        if not forConstruction:
            new_pending_edges = self.pending_edges.add(edge)
            new_workplane = new_workplane.clone(pending_edges=new_pending_edges)
        return new_workplane

    def box(self, length: Number, width: Number, height: Number, centered=True):
        "Make a box by extruding a ``length`` x ``width`` rectangle by ``height``"
        # ``centered`` is forwarded; it was previously accepted but ignored
        return self.rect(length, width, centered=centered).extrude(height)

    def lineTo(self, x: Number, y: Number, forConstruction=False):
        "Make a line from the current point to the provided point"
        edge = EdgeTool.make_line(self.trace_vertex, jnp.array([x, y, 0.0]))

        new_workplane = self.moveTo(x, y)
        if not forConstruction:
            new_pending_edges = self.pending_edges.add(edge)
            new_workplane = new_workplane.clone(pending_edges=new_pending_edges)
        return new_workplane

    def rect(self, xLen: Number, yLen: Number, centered=True, forConstruction=False):
        """Make a rectangle at the current point and consolidate it into a wire.

        :param xLen: size of the rectangle along x
        :param yLen: size of the rectangle along y
        :param centered: forwarded to ``EdgeTool.make_rect``
        :param forConstruction: when True nothing is added and the workplane
            is returned unchanged
        """
        edges = EdgeTool.make_rect(xLen, yLen, self.trace_vertex, centered=centered)
        if forConstruction:
            # Construction geometry is not queued; returning here also avoids
            # the unbound ``new_workplane`` this branch used to hit.
            return self

        new_pending_edges = self.pending_edges.add(edges)
        return self.clone(pending_edges=new_pending_edges).consolidateWires(
            validate=False
        )

    # def show(self):

    def polyline(
        self,
        vertices: CoordLike,
        forConstruction: bool = False,
        includeCurrent: bool = False,
    ):
        """Create a polyline from a list of points

        :param vertices: points of the polyline
        :param forConstruction: when True the edges are not added to the pending edges
        :param includeCurrent: when True the current point is prepended to ``vertices``
        """
        vertices = WireHelper.to_3d_vertices(vertices)

        if includeCurrent:
            vertices = jnp.concatenate(
                [jnp.expand_dims(self.trace_vertex, axis=0), vertices]
            )

        edges = EdgeTool.make_polyline(vertices)
        new_workplane = self.moveTo(vertices[-1][0], vertices[-1][1])
        if not forConstruction:
            new_pending_edges = self.pending_edges.add(edges)
            new_workplane = new_workplane.clone(pending_edges=new_pending_edges)
        return new_workplane

    def bspline(
        self,
        ctrl_pnts: CoordLike,
        degree: Number = 3,
        knots: Optional[ArrayLike] = None,
        includeCurrent: bool = False,
        forConstruction: bool = False,
    ):
        """Create a B-spline curve edge from control points

        :param ctrl_pnts: control points of the curve
        :param degree: degree of the curve
        :param knots: optional knot vector forwarded to ``EdgeTool.make_bspline_curve``
        :param includeCurrent: when True the current point is prepended to ``ctrl_pnts``
        :param forConstruction: when True the edge is not added to the pending edges
        """
        ctrl_pnts = WireHelper.to_3d_vertices(ctrl_pnts)
        if includeCurrent:
            ctrl_pnts = jnp.concatenate(
                [jnp.expand_dims(self.trace_vertex, axis=0), ctrl_pnts]
            )

        curve = EdgeTool.make_bspline_curve(ctrl_pnts, jnp.array(degree), knots)
        edges = Edges.from_bspline_curves(curve)
        new_workplane = self.moveTo(ctrl_pnts[-1][0], ctrl_pnts[-1][1])
        if not forConstruction:
            new_pending_edges = self.pending_edges.add(edges)
            new_workplane = new_workplane.clone(pending_edges=new_pending_edges)
        return new_workplane

    def close(self, validate: bool = True):
        """End construction, and attempt to build a closed wire

        A line is drawn from the current point back to the first pending
        vertex and the pending edges are consolidated into wires.

        :param validate: Validate the edges for order
        :raises AssertionError: If there are no pending edges, or the current
            point already coincides with the start point
        """
        assert self.pending_edges.count > 0, "No segments to close"

        start_point = self.pending_edges.vertices[0]
        is_closed = bool(jnp.allclose(self.trace_vertex, start_point, atol=1e-3))
        assert not is_closed, "Wire already closed"

        return self.lineTo(start_point[0], start_point[1]).consolidateWires(validate)

    def extrude(
        self,
        until: Number,
    ):
        """Extrude the pending wires into solids and add them to the base solids.

        :param until: extrusion argument forwarded to ``SolidTool.extrude``
        :return: a workplane on the same axes holding the new base solids;
            pending edges, pending wires and the trace vertex are reset
        """
        projected_wires = self.pending_wires.project(self.axes)
        extruded_solids = SolidTool.extrude(projected_wires, until)
        base_solids = self.base_solids.add(extruded_solids)
        return Workplane(self.axes, base_solids)
210 |
211 | # TODO: implement this later
212 | # def sweep(self, path: Union[Wires, Edges, "Workplane"]):
213 | # "Use all un-extruded wires in the parent chain to create a swept solid"
214 | # if isinstance(path, Workplane):
215 | # path = path.pending_wires
216 | # elif isinstance(path, Edges):
217 | # path = Wires.from_edges(path)
218 |
219 | # projected_wires = self.pending_wires.project(self.axes)
220 | # swept_solids = SolidTool.sweep(projected_wires, path)
221 | # self.base_solids = self.base_solids.add(swept_solids)
222 | # self.reset()
223 |
224 | # return self
225 |
--------------------------------------------------------------------------------
/nebula/helpers/clipper.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import jax_dataclasses as jdc
3 | import jax
4 | import jax.numpy as jnp
5 | import numpy as np
6 | from nebula.helpers.intersection import Intersection
7 | from nebula.helpers.wire import WireHelper
8 | from nebula.topology.edges import Edges
9 | from nebula.topology.wires import Wires
10 |
11 |
@jdc.pytree_dataclass
class PolygonContainsResult:
    """Containment flags produced by ``Clipper.contains_polygons``."""

    # One flag per subject segment: is it contained in the clipper segments
    subject_in_clipper: jnp.ndarray
    # One flag per clipper segment: is it contained in the subject segments
    clipper_in_subject: jnp.ndarray
16 |
17 |
@jdc.pytree_dataclass
class PolygonSplitResult:
    """Result of ``Clipper.split_polygon``."""

    # Pieces of the intersected subject segments, split at the intersections
    split_subject_segments: jnp.ndarray
    # Pieces of the intersected clipper segments, split at the intersections
    split_clipper_segments: jnp.ndarray
    # Per subject segment: does it intersect any clipper segment
    subject_intersected: jnp.ndarray
    # Per clipper segment: does it intersect any subject segment
    clipper_intersected: jnp.ndarray
24 |
25 |
@jdc.pytree_dataclass
class WireSplitResult:
    """Result of ``Clipper.split_wires``."""

    # Segments of the affected subject wires (untouched ones first, then split ones)
    subject_segments: jnp.ndarray
    # Clipper segments (untouched ones first, then split ones)
    clipper_segments: jnp.ndarray
    # Subject wires without any intersected edge
    unsplit_wires: Wires
31 |
32 |
@jdc.pytree_dataclass
class PolygonClipResult:
    """Result of a clipping operation (``Clipper.union`` / ``Clipper.intersect``)."""

    # Edges that make up the clipped outline
    clip_edges: Edges
    # Wires that are returned alongside the clipped edges without being merged into them
    unclipped_wires: Wires
37 |
38 |
class Clipper:
    """Polygon clipping helpers operating on wires made of line segments.

    Curved edges are not supported; ``split_wires`` raises
    ``NotImplementedError`` when it encounters them.
    """

    @staticmethod
    def cut_polygons(
        subject_wires: Wires, clipper: Union[Wires, Edges], sort_edges: bool = True
    ):
        """Cut ``clipper`` out of ``subject_wires``.

        The clipper is first unioned with the interior subject wires and the
        resulting edges are then intersected with the non-interior wires.

        :param subject_wires: Wires to cut
        :param clipper: Wires or edges used as the cutter
        :param sort_edges: Sort the resulting clip edges before building wires
        :return: Combined wires of the clipped outline and all unclipped wires
        """
        interior_wires = subject_wires.mask(subject_wires.is_interior)
        exterior_wires = subject_wires.mask(~subject_wires.is_interior)

        cutter_union_result = Clipper.union(interior_wires, clipper)
        intersection_result = Clipper.intersect(
            exterior_wires, cutter_union_result.clip_edges
        )

        clip_edges = intersection_result.clip_edges
        if sort_edges:
            clip_edges = clip_edges.get_sorted().edges

        return Wires.combine(
            [
                Wires.from_edges(clip_edges),
                cutter_union_result.unclipped_wires,
                intersection_result.unclipped_wires,
            ]
        )

    @staticmethod
    def intersect(subject_wires: Wires, clipper: Union[Wires, Edges]):
        """Clip ``subject_wires`` against ``clipper``.

        Subject segments outside the clipper and clipper segments inside the
        subject are kept. When no subject segment was split, the kept clipper
        segments are returned as interior wires instead of being merged.

        :raises NotImplementedError: If curved edges are involved
        """
        # Split subject and clipper edges at intersection points
        split_result = Clipper.split_wires(subject_wires, clipper)

        if subject_wires.edges.curves.count:
            raise NotImplementedError("Intersecting curves is not supported yet.")

        was_split = len(split_result.subject_segments) > 0
        if was_split:
            subject_segments = split_result.subject_segments
        else:
            subject_segments = subject_wires.edges.lines.get_segments()
        clipper_segments = split_result.clipper_segments

        # Get which subject and clipper edges are in the other polygon
        contains = Clipper.contains_polygons(subject_segments, clipper_segments)
        # TODO: check for orientation of clipper
        kept_subject = subject_segments[~contains.subject_in_clipper]
        kept_clipper = clipper_segments[contains.clipper_in_subject]

        if not was_split:
            return PolygonClipResult(
                clip_edges=Edges.from_line_segments(kept_subject),
                unclipped_wires=Wires.from_line_segments(
                    kept_clipper, is_interior=True
                ),
            )

        # Reverse both the segment order and each segment's direction
        reversed_clipper = jnp.flip(kept_clipper, (0, 1))
        merged_segments = jnp.concatenate([kept_subject, reversed_clipper])
        return PolygonClipResult(
            clip_edges=Edges.from_line_segments(merged_segments),
            unclipped_wires=split_result.unsplit_wires,
        )

    @staticmethod
    def union(subject_wires: Wires, clipper: Union[Wires, Edges]):
        """Union ``subject_wires`` with ``clipper``.

        Keeps the subject segments that are not inside the clipper and the
        clipper segments that are not inside the subject.
        """
        # Split subject and clipper edges at intersection points
        split_result = Clipper.split_wires(subject_wires, clipper)
        subject_segments = split_result.subject_segments
        clipper_segments = split_result.clipper_segments

        # Get which subject and clipper edges are in the other polygon
        contains = Clipper.contains_polygons(subject_segments, clipper_segments)

        outer_segments = jnp.concatenate(
            [
                subject_segments[~contains.subject_in_clipper],
                clipper_segments[~contains.clipper_in_subject],
            ]
        )
        return PolygonClipResult(
            clip_edges=Edges.from_line_segments(outer_segments),
            unclipped_wires=split_result.unsplit_wires,
        )

    # TODO: make segments have same count for jit
    @staticmethod
    @jax.jit
    def contains_polygons(
        subject_segments: jnp.ndarray,
        clipper_segments: jnp.ndarray,
    ):
        """Test every segment of one polygon for containment in the other.

        :return: ``PolygonContainsResult`` with one flag per subject segment
            and one flag per clipper segment
        """
        # contains_each(polygon, segments) -> one flag per entry of ``segments``
        contains_each = jax.vmap(WireHelper.contains_segment, in_axes=(None, 0))

        return PolygonContainsResult(
            subject_in_clipper=contains_each(clipper_segments, subject_segments),
            clipper_in_subject=contains_each(subject_segments, clipper_segments),
        )

    @staticmethod
    def split_polygon(subject_segments: jnp.ndarray, clipper_segments: jnp.ndarray):
        """Split subject and clipper segments where they intersect each other.

        NOTE(review): each intersected segment is paired with one intersection
        vertex, so a segment crossing several segments of the other polygon
        looks unsupported here - confirm.
        """
        # Pairwise intersections, indexed as [subject, clipper]
        pairwise = jax.vmap(
            jax.vmap(Intersection.line_segment_intersection, in_axes=(None, 0)),
            in_axes=(0, None),
        )(subject_segments, clipper_segments)
        hits = pairwise.is_intersection
        hit_vertices = pairwise.intersected_vertices

        # find which subject and clipper edges intersect
        subject_intersected = jnp.any(hits, axis=(1,))
        clipper_intersected = jnp.any(hits, axis=(0,))

        # split subject edges at intersection points
        subject_split_segments = Clipper.split_segments(
            subject_segments[subject_intersected],
            hit_vertices[hits],
        )

        # Split clipper edges at intersection points (clipper-major order)
        clipper_split_segments = Clipper.split_segments(
            clipper_segments[clipper_intersected],
            jnp.swapaxes(hit_vertices, 0, 1)[jnp.swapaxes(hits, 0, 1)],
        )

        return PolygonSplitResult(
            split_subject_segments=subject_split_segments,
            split_clipper_segments=clipper_split_segments,
            subject_intersected=subject_intersected,
            clipper_intersected=clipper_intersected,
        )

    @staticmethod
    def split_wires(subject_wires: Wires, clipper: Union[Wires, Edges]):
        """Split the line segments of ``subject_wires`` and ``clipper`` at their intersections.

        :raises NotImplementedError: If either input contains curved edges
        """
        clipper_edges = clipper.edges if isinstance(clipper, Wires) else clipper
        if subject_wires.edges.curves.count or clipper_edges.curves.count:
            raise NotImplementedError("Intersecting curves is not supported yet.")

        subject_segments = subject_wires.edges.lines.get_segments()
        clipper_segments = clipper_edges.lines.get_segments()
        split_result = Clipper.split_polygon(subject_segments, clipper_segments)

        # Count intersected edges per wire, then read the count back per edge:
        # an edge is flagged when its wire has at least one intersected edge.
        hits_per_wire = (
            jnp.zeros(subject_wires.count)
            .at[subject_wires.index]
            .add(split_result.subject_intersected.astype(jnp.int32))
        )
        is_split_wire = hits_per_wire[subject_wires.index].astype(bool)

        # Exclude all subject wires that are not affected from final subject segments
        unsplit_wires = subject_wires.mask(~is_split_wire)
        untouched_in_split_wires = subject_segments[
            ~split_result.subject_intersected & is_split_wire
        ]

        new_subject_segments = jnp.concatenate(
            [untouched_in_split_wires, split_result.split_subject_segments]
        )
        new_clipper_segments = jnp.concatenate(
            [
                clipper_segments[~split_result.clipper_intersected],
                split_result.split_clipper_segments,
            ]
        )

        return WireSplitResult(
            subject_segments=new_subject_segments,
            clipper_segments=new_clipper_segments,
            unsplit_wires=unsplit_wires,
        )

    # TODO: make segments have same count for jit
    @staticmethod
    @jax.jit
    def split_segments(segments: jnp.ndarray, intersection_vertices: jnp.ndarray):
        """Split each segment in two at its intersection vertex.

        :param segments: Segments to split, paired row by row with ``intersection_vertices``
        :param intersection_vertices: One split point per segment
        :return: Two segments per input segment, or an empty ``(0, 2, 3)``
            array when there is nothing to split
        """

        def split_at(edge_vertices, intersection_point):
            # start -> intersection, intersection -> end
            return jnp.array(
                [
                    [edge_vertices[0], intersection_point],
                    [intersection_point, edge_vertices[1]],
                ]
            )

        segment_groups = jax.vmap(split_at, in_axes=(0, 0))(
            segments,
            intersection_vertices,
        )
        if len(segment_groups) > 0:
            return jnp.concatenate(segment_groups)
        return jnp.empty((0, 2, 3))
245 |
--------------------------------------------------------------------------------
/nebula/topology/edges.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional, Union
2 | import jax_dataclasses as jdc
3 | import jax.numpy as jnp
4 | import numpy as np
5 | from nebula.evaluators.bspline import get_sampling
6 | from nebula.prim import BSplineCurves
7 | from nebula.prim import Lines
8 | import numpy.typing as npt
9 | from nebula.helpers.wire import WireHelper
10 | from nebula.prim.axes import Axes
11 | from nebula.topology.topology import Topology
12 |
13 | DEFAULT_CURVE_SAMPLES = 20
14 |
15 |
@jdc.pytree_dataclass
class SortedEdgeResult:
    """Result of ``Edges.get_sorted``."""

    # Edges reordered so that they follow each other
    edges: "Edges"
    # Per-type (line / curve) index of every edge, in sorted order
    index: jdc.Static[npt.NDArray[np.int32]]
    # Start/end segment of every edge, in sorted order
    segments: jnp.ndarray
21 |
22 |
@jdc.pytree_dataclass
class Edges(Topology):
    """Edge collection stored as straight lines plus B-spline curves.

    ``is_curve`` holds one flag per edge (in edge order) telling whether the
    edge lives in ``curves`` (True) or in ``lines`` (False).
    """

    lines: Lines
    curves: BSplineCurves
    is_curve: jdc.Static[npt.NDArray[np.bool_]]

    def add(
        self,
        edges: Union["Edges", Lines, BSplineCurves, jnp.ndarray],
        reorder_index: bool = False,
    ):
        """
        Return a new ``Edges`` with the given edges appended.

        :param edges: Edges, lines, B-spline curves or an array of line segments to add
        :type edges: Union[Edges, Lines, BSplineCurves, jnp.ndarray]
        :param reorder_index: Accepted for interface compatibility; not used here

        :returns: New edges containing the existing and the added edges
        :rtype: Edges
        """
        if isinstance(edges, Edges):
            return Edges(
                self.lines.add(edges.lines),
                self.curves.add(edges.curves),
                np.concatenate([self.is_curve, edges.is_curve]),
            )

        adds_curves = isinstance(edges, BSplineCurves)
        if adds_curves:
            curves = self.curves.add(edges)
            lines = self.lines
        else:
            if isinstance(edges, jnp.ndarray):
                # Raw segments are wrapped into Lines first
                edges = Lines.from_segments(edges)
            lines = self.lines.add(edges)
            curves = self.curves

        added_flags = np.full(edges.count, adds_curves)
        return Edges(
            lines,
            curves,
            np.concatenate([self.is_curve, added_flags]),
        )

    @staticmethod
    def empty():
        "Create an edge collection without any lines or curves"
        return Edges(
            lines=Lines.empty(),
            curves=BSplineCurves.empty(),
            is_curve=np.empty((0,), dtype=bool),
        )

    @staticmethod
    def from_line_segments(segments: jnp.ndarray):
        """Create line-only edges from segments

        :param segments: segments to create edges from
        :type segments: jnp.ndarray
        :return: Edges
        :rtype: Edges
        """
        lines = Lines.from_segments(segments)
        no_curves = np.zeros(lines.count, dtype=bool)
        return Edges(lines=lines, curves=BSplineCurves.empty(), is_curve=no_curves)

    @staticmethod
    def from_bspline_curves(
        bspline_curves: BSplineCurves,
    ):
        """Create curve-only edges from bspline curves

        :param bspline_curves: bspline curves
        :type bspline_curves: BSplineCurves
        :return: Edges
        :rtype: Edges
        """
        return Edges.empty().add(bspline_curves)

    def clone(
        self,
        lines: Optional[Lines] = None,
        curves: Optional[BSplineCurves] = None,
        is_curve: Optional[np.ndarray] = None,
    ):
        "Copy the edges, replacing only the fields that are given"
        return Edges(
            lines=self.lines if lines is None else lines,
            curves=self.curves if curves is None else curves,
            is_curve=self.is_curve if is_curve is None else is_curve,
        )

    @property
    def vertices(self):
        "Start vertices of all lines followed by the control points of all curves"
        line_starts = self.lines.get_segments()[:, 0]
        return jnp.concatenate([line_starts, self.curves.ctrl_pnts.val])

    # TODO: make this simpler and more efficient
    def get_sorted(self):
        """Order the edges so that each edge starts where the previous one ends.

        Starts with the first edge and repeatedly picks the first edge whose
        start matches the previous end within ``1e-5`` per coordinate.

        :return: ``SortedEdgeResult`` with the reordered edges, their per-type
            index and their segments
        """
        segments = self.get_segments()
        num_segments = len(segments)
        sorted_segments = jnp.zeros((num_segments, 2, 3))
        sorting_order = np.zeros(self.index.shape[0], dtype=int)
        positions = np.arange(self.count)

        for i in range(num_segments):
            if i == 0:
                chosen_segment = segments[0]
                chosen_position = self.index[0]
            else:
                previous_end = sorted_segments[i - 1][1]
                # check which start vertices match the previous segment's end
                starts_at_previous_end = jnp.all(
                    jnp.abs(segments[:, 0] - previous_end) <= 1e-5, axis=1
                )
                chosen_segment = segments[starts_at_previous_end][0]
                chosen_position = positions[starts_at_previous_end][0]
            sorted_segments = sorted_segments.at[i].set(chosen_segment)
            sorting_order[i] = chosen_position

        sorted_is_curve = self.is_curve[sorting_order]
        sorted_index = self.index[sorting_order]

        new_edges = self.clone(
            lines=self.lines.reorder(sorted_index[~sorted_is_curve]),
            curves=self.curves.reorder(sorted_index[sorted_is_curve]),
            is_curve=sorted_is_curve,
        )
        return SortedEdgeResult(new_edges, sorted_index, sorted_segments)

    def get_segments(self):
        "One ``(2, 3)`` segment per edge, in edge order (lines and curves interleaved)"
        line_segments = self.lines.get_segments()
        curve_segments = self.curves.get_segments()

        segments = jnp.empty((self.count, 2, 3))
        segments = segments.at[~self.is_curve].set(line_segments)
        segments = segments.at[self.is_curve].set(curve_segments)
        return segments

    @property
    def index(self):
        """
        Generates edge index for each component, use mask self.is_curve to mask off each index
        """
        index = np.empty(self.count, dtype=np.int32)
        # Lines and curves are numbered independently of each other
        index[~self.is_curve] = np.arange(self.lines.count)
        index[self.is_curve] = np.arange(self.curves.count)
        return index

    @property
    def count(self):
        "Total number of edges (lines and curves)"
        return len(self.is_curve)

    def evaluate(
        self,
        curve_samples: Optional[int] = DEFAULT_CURVE_SAMPLES,
        is_cosine_sampling: bool = False,
    ):
        """
        Evaluate edges, not in order (all line segments first, then curve segments)

        :param curve_samples: Number of samples to use for curves,
            if None then uses ctrl pnt start and end as segments, defaults to DEFAULT_CURVE_SAMPLES
        :type curve_samples: Optional[int], optional
        :param is_cosine_sampling: Forwarded to ``get_sampling``
        :return: Segments
        :rtype: jnp.ndarray
        """
        if curve_samples:
            u = get_sampling(0.0, 1.0, curve_samples + 1, is_cosine_sampling)
        else:
            u = None
        curve_segments = self.curves.evaluate(u)
        line_segments = self.lines.get_segments()

        return jnp.concatenate([line_segments, curve_segments], axis=0)

    def dir_vec_at(self, u: jnp.ndarray, index: jnp.ndarray, eps=1e-5):
        """Finite-difference direction of an edge at parameter ``u``.

        :param u: Parameter values
        :param index: Index of the edge
        :param eps: Step between the two evaluated parameters
        :return: Difference of the edge evaluated at ``u + eps`` and ``u``
            (at ``u >= 1`` the step is taken backwards)
        """
        at_end = u >= 1.0
        u_next = jnp.where(at_end, u, u + eps)
        u_start = jnp.where(at_end, u_next - eps, u)

        start_vertices = self.evaluate_at(u_start, index)
        end_vertices = self.evaluate_at(u_next, index)
        return end_vertices - start_vertices

    def evaluate_at(self, u: jnp.ndarray, index: jnp.ndarray):
        """
        Evaluate the edge at the given u values.

        :param u: The u values to evaluate the edge at.
        :type u: jnp.ndarray
        :param index: The index of the edge to evaluate.
        :type index: jnp.ndarray
        :return: The vertices of the edge.
        :rtype: jnp.ndarray
        """
        source = self.curves if self.is_curve[index] else self.lines
        # self.index maps the edge position to its index within lines / curves
        return source.evaluate_at(u, self.index[index])

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Get edges from mask

        :param mask: Mask for edges
        :type mask: Union[jnp.ndarray, np.ndarray]
        :return: Edges
        :rtype: Edges
        """
        masked_lines = self.lines.mask(mask[~self.is_curve])
        masked_curves = self.curves.mask(mask[self.is_curve])
        return self.clone(masked_lines, masked_curves, self.is_curve[mask])

    def project(self, curr_axes: Axes, new_axes: Axes):
        "Project the edges from ``curr_axes`` onto every axis of ``new_axes``"
        new_lines = self.lines.project(curr_axes, new_axes)
        new_curves = self.curves.project(curr_axes, new_axes)
        # The flags are repeated once per target axis
        is_curve = np.repeat(self.is_curve[None, :], new_axes.count, axis=0).flatten()
        return Edges(lines=new_lines, curves=new_curves, is_curve=is_curve)

    def translate(self, translation: jnp.ndarray):
        """Translate the edges.

        :param translation: Either a single 1D translation applied to all edges
            or one translation per edge
        """
        if len(translation.shape) == 1:
            line_translation = curve_translation = translation
        else:
            line_translation = translation[~self.is_curve]
            curve_translation = translation[self.is_curve]

        new_lines = self.lines.translate(line_translation)
        new_curves = self.curves.translate(curve_translation)
        return self.clone(new_lines, new_curves)

    def flip_winding(self, mask: Optional[jnp.ndarray] = None):
        """Flip the winding of the edges.

        :param mask: Optional per-edge mask forwarded (split by type) to the
            lines and curves
        """
        if mask is None:
            line_mask = curve_mask = None
        else:
            line_mask = mask[~self.is_curve]
            curve_mask = mask[self.is_curve]

        # NOTE(review): the flags are reversed regardless of ``mask`` - confirm
        # this matches how Lines/BSplineCurves.flip_winding reorder their items.
        return self.clone(
            self.lines.flip_winding(line_mask),
            self.curves.flip_winding(curve_mask),
            self.is_curve[::-1],
        )

    def __repr__(self) -> str:
        return f"Edges(count={self.count})"
300 |
--------------------------------------------------------------------------------
/nebula/topology/wires.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Protocol, Sequence, Union
2 | import jax
3 | import jax_dataclasses as jdc
4 | import jax.numpy as jnp
5 | import numpy as np
6 | import numpy.typing as npt
7 | from nebula.helpers.types import Number
8 | from nebula.helpers.wire import WireHelper
9 | from nebula.prim.axes import Axes
10 | from nebula.topology.topology import SparseIndexable, Topology
11 | from nebula.topology.edges import DEFAULT_CURVE_SAMPLES, BSplineCurves, Edges
12 |
13 |
class BoundingBox:
    """Axis-aligned bounding box described by its minimum and maximum corner."""

    def __init__(self, min: jnp.ndarray, max: jnp.ndarray) -> None:
        self.min = min
        self.max = max

    @staticmethod
    def from_points(points: jnp.ndarray):
        """Build the bounding box of ``points``.

        :param points: Array whose trailing axis holds 3D coordinates; it is
            flattened to ``(-1, 3)`` before taking the extrema
        """
        coords = points.reshape(-1, 3)
        lower = jnp.min(coords, axis=0)
        upper = jnp.max(coords, axis=0)
        return BoundingBox(lower, upper)

    def contains(self, point: jnp.ndarray):
        "True where ``point`` lies inside the box (bounds inclusive) on every coordinate"
        within_bounds = (point >= self.min) & (point <= self.max)
        return jnp.all(within_bounds, axis=-1)

    def __repr__(self) -> str:
        return f"BoundingBox(min={self.min}, max={self.max})"
31 |
32 |
33 | class WireLike(Topology, Protocol):
34 | edges: Edges
35 |
@classmethod
def combine(cls, topologies: Sequence["Topology"], reorder_index: bool = True):
    """Combine topologies into a single one

    Starts from ``cls.empty()`` and adds every topology in order.

    :param topologies: Topologies to combine
    :type topologies: Sequence[Topology]
    :param reorder_index: Forwarded to every ``add`` call
    :return: Combined topologies
    :rtype: Topology
    """
    combined = cls.empty()  # type: ignore
    for item in topologies:
        combined = combined.add(item, reorder_index)
    return combined
50 |
def get_segment_index(self, curve_samples: Optional[int] = None):
    """Wire index of every evaluated segment.

    Line edges contribute their wire index once, curve edges contribute it
    ``curve_samples`` times (zero times when ``curve_samples`` is None). All
    line entries come first, followed by the curve entries.
    """
    repeats = 0 if curve_samples is None else curve_samples
    is_curve = self.edges.is_curve
    line_index = self.index[~is_curve]
    # Get the wire segment index including repeats for curve segments
    curve_index = self.index[is_curve].repeat(repeats, axis=0)
    return jnp.concatenate([line_index, curve_index])
61 |
def get_axis_normals(self) -> jnp.ndarray:
    """Compute the normals of the wire axes. Only works for planar wires.

    For every wire one edge segment is gathered by scattering the segments in
    reverse order and one by scattering them in forward order; both are
    passed to ``WireHelper.get_normal``.
    """
    edge_segments = self.edges.get_segments()
    blank = jnp.zeros(shape=(self.count, 2, 3))

    # Intended to leave the first / last edge of each wire in its slot
    # (relies on later writes to the same slot winning - TODO confirm)
    first_edges = blank.at[self.index[::-1]].set(edge_segments[::-1])
    last_edges = blank.at[self.index].set(edge_segments)

    per_wire_normal = jax.vmap(WireHelper.get_normal, in_axes=(0, 0))
    return per_wire_normal(first_edges, last_edges)
74 |
def get_axes(self):
    """Build ``Axes`` from the normals returned by ``get_axis_normals``."""
    normals = self.get_axis_normals()
    return Axes(normals=normals)
78 |
def get_lengths(self, curve_samples=DEFAULT_CURVE_SAMPLES):
    """Compute the total length of each wire.

    The edges are evaluated into straight segments (curves are sampled with
    ``curve_samples``) and the segment lengths are summed per wire.

    :param curve_samples: Number of samples used for curved edges
    :return: Array of shape ``(self.count,)`` with one length per wire
    """
    segments = self.edges.evaluate(curve_samples)
    segment_index = self.get_segment_index(curve_samples)

    start_vertices = segments[:, 0]
    end_vertices = segments[:, 1]

    # Norm along the coordinate axis yields one length per segment. Without
    # ``axis`` the norm of the whole matrix (a single scalar) would be added
    # to every wire.
    lengths = jnp.linalg.norm(end_vertices - start_vertices, axis=-1)
    return jnp.zeros(self.count).at[segment_index].add(lengths)
88 |
def get_edge_index(self, u: Number):
    """Index of the edge that contains parameter ``u``.

    Placeholder: the cumulative length ratios are computed but the lookup is
    not implemented yet, so index 0 is always returned.
    """
    if isinstance(u, (int, float)):
        u = jnp.array([u])

    lengths = self.get_lengths()
    cumulative_ratios = jnp.cumsum(lengths / jnp.sum(lengths))
    # TODO: check this for correctness
    # return jnp.arange(len(cumulative_ratios))[(cumulative_ratios <= u)][-1]
    return jnp.array(0)
99 |
def dir_vec_at(self, u: Number):
    """Direction vector of the wire at parameter ``u``.

    Delegates to ``Edges.dir_vec_at`` for the edge chosen by
    ``get_edge_index``; ``u`` is passed on unchanged.
    """
    if isinstance(u, (int, float)):
        u = jnp.array([u])

    # TODO: add this later
    # edge_u = (u - length_ratios[edge_index]) / length_ratios[edge_index]
    target_edge = self.get_edge_index(u)
    return self.edges.dir_vec_at(u, target_edge)
108 |
def evaluate_at(self, u: Number):
    """Evaluate the wire at parameter ``u``.

    The edge is chosen with ``get_edge_index`` and the evaluation is
    delegated to ``edges.evaluate_at`` with ``u`` passed through unchanged.

    :param u: parameter value; a Python int/float is wrapped into a
        one-element array
    :type u: Number
    :return: whatever ``edges.evaluate_at`` returns for the selected edge
    """
    param = jnp.array([u]) if isinstance(u, (int, float)) else u
    located_edge = self.get_edge_index(param)

    # TODO: add this later - remap u to the edge-local parameter
    # edge_u = (u - length_ratios[edge_index]) / length_ratios[edge_index]
    return self.edges.evaluate_at(param, located_edge)
118 |
def get_centroids(self, curve_samples=DEFAULT_CURVE_SAMPLES):
    """Compute the centroid of each wire.

    The centroid is the mean of the start vertices of all evaluated
    segments that belong to the wire.

    :param curve_samples: sample count forwarded to ``edges.evaluate`` and
        ``get_segment_index``
    :return: centroids, shape ``(count, 3)``
    :rtype: jnp.ndarray
    """
    segments = self.edges.evaluate(curve_samples)
    owner = self.get_segment_index(curve_samples)

    # Per-wire segment tally
    tally = jnp.zeros(self.count).at[owner].add(values=jnp.ones(len(segments)))

    # Per-wire sum of segment start vertices -> (num_wires, 3)
    start_sum = jnp.zeros(shape=(self.count, 3)).at[owner].add(segments[:, 0])

    # The tally is turned into a column so it broadcasts over x, y and z
    return start_sum / jnp.expand_dims(tally, axis=1)
136 |
def get_num_edges(self):
    """Compute the number of edges in each wire.

    :return: edge count per wire as a float array of shape ``(count,)``
    :rtype: jnp.ndarray
    """
    one_per_edge = jnp.ones(shape=self.edges.count)
    counts = jnp.zeros(shape=self.count)
    # Scatter-add: every edge adds 1 to the slot of the wire it belongs to
    return counts.at[self.index].add(values=one_per_edge)
144 |
def get_num_curves(self):
    """Compute the number of curved edges in each wire.

    :return: curved-edge count per wire, shape ``(count,)``
    :rtype: jnp.ndarray
    """
    curve_flags = self.edges.is_curve.astype(jnp.int32)
    counts = jnp.zeros(shape=self.count)
    # Scatter-add: every curved edge adds 1 to the slot of its wire
    return counts.at[self.index].add(values=curve_flags)
152 |
def get_bounding_box(self):
    """Compute an axis-aligned bounding box per wire.

    The box of a wire spans the start vertices of its evaluated edge
    segments. A wire without any segment keeps the seed values, i.e. a
    minimum of ``+inf`` and a maximum of ``-inf``.

    :return: bounding boxes built from ``(min_coords, max_coords)``, each of
        shape ``(count, 3)``
    :rtype: BoundingBox
    """
    # NOTE(review): segments are paired with self.index, which has one entry
    # per edge - confirm that evaluate() with its default arguments yields
    # exactly one segment per edge when curved edges are present. Only start
    # vertices are considered, so the end point of an open wire is not covered.
    start_vertices = self.edges.evaluate()[:, 0]

    # Seed with -inf / +inf instead of 0. A zero seed folds the origin into
    # every box: the maximum could never drop below 0 and the minimum could
    # never rise above 0.
    max_coords = (
        jnp.full((self.count, 3), -jnp.inf).at[self.index].max(start_vertices)
    )  # (num_wires, 3)

    min_coords = (
        jnp.full((self.count, 3), jnp.inf).at[self.index].min(start_vertices)
    )  # (num_wires, 3)

    return BoundingBox(min_coords, max_coords)
165 |
166 |
@jdc.pytree_dataclass
class Wires(WireLike):
    """Collection of wires stored as one flat set of edges.

    ``index``, ``is_interior`` and ``is_planar`` hold one entry per edge in
    ``edges``; ``index`` is the number of the wire the edge belongs to.
    """

    edges: Edges
    index: jdc.Static[npt.NDArray[np.int32]]
    is_interior: jdc.Static[npt.NDArray[np.bool_]]
    is_planar: jdc.Static[npt.NDArray[np.bool_]]

    def get_segment_is_interior(self, curve_samples: int):
        """Return the interior flag of every evaluated segment.

        Flags of straight edges come first (one each), followed by the flags
        of curved edges repeated ``curve_samples`` times.

        :param curve_samples: number of entries emitted per curved edge
        :type curve_samples: int
        :return: concatenated interior flags
        :rtype: jnp.ndarray
        """
        return jnp.concatenate(
            [
                self.is_interior[~self.edges.is_curve],
                self.is_interior[self.edges.is_curve].repeat(curve_samples, axis=0),
            ]
        )

    @property
    def is_single(self):
        """Whether the collection holds exactly one wire."""
        return self.count == 1

    # TODO: make this more efficient
    @staticmethod
    def convert_to_bspline(wires: "Wires"):
        """Replace every planar quad wire by a skinned pair of bspline curves.

        Non-planar wires are carried over unchanged.

        :param wires: Wires to convert to bspline if they are planar and quads
        :type wires: Wires
        :return: Wires in which planar quads are expressed as bspline edges
        :rtype: Wires
        """
        eval_wires = Wires.empty()
        for i in range(wires.count):
            wire_mask = wires.index == i
            curr_wire = wires.mask(wire_mask)
            # is_planar is stored per edge, so it must be looked up through
            # the edges of wire i rather than indexed with the wire number.
            if wires.is_planar[wire_mask].all():
                segments = curr_wire.edges.get_segments()
                assert segments.shape[0] == 4, "Only quads are supported"
                top_curve = BSplineCurves.from_line_segment(segments[0])
                # Opposite side, reversed so both curves run the same way
                bottom_curve = BSplineCurves.from_line_segment(segments[2][::-1])
                bspline_plane_wire = Wires.skin([top_curve, bottom_curve])
            else:
                bspline_plane_wire = curr_wire
            eval_wires = eval_wires.add(bspline_plane_wire)
        return eval_wires

    def mask(self, mask: Union[jnp.ndarray, np.ndarray]):
        """Get wires from mask

        :param mask: Per-edge mask selecting the edges to keep
        :type mask: Union[jnp.ndarray, np.ndarray]
        :return: Masked wires
        :rtype: Wires
        """
        wires = Wires(
            edges=self.edges.mask(mask),
            index=Topology.reorder_index(self.index[mask]),
            is_interior=self.is_interior[mask],
            is_planar=self.is_planar[mask],
        )
        assert len(wires.index) == wires.edges.count, "index and node count mismatch"
        return wires

    @staticmethod
    def empty():
        """Create a Wires instance without any edges.

        :return: Empty wires
        :rtype: Wires
        """
        return Wires(
            edges=Edges.empty(),
            index=np.empty((0,), dtype=np.int32),
            is_interior=np.empty((0,), dtype=bool),
            is_planar=np.empty((0,), dtype=bool),
        )

    @staticmethod
    def skin(sections: Sequence[BSplineCurves]):
        """Create Wires from bspline curve sections

        Curve ``k`` of every section is assigned to wire ``k``. All sections
        must have the same number of control points.

        :param sections: bspline curve sections
        :type sections: Sequence[BSplineCurves]
        :return: Wires (non-interior, non-planar)
        :rtype: Wires
        """
        new_curves = BSplineCurves.empty()
        index = np.empty((0,), dtype=np.int32)
        num_section_pnts = sections[0].ctrl_pnts.count
        for section in sections:
            assert section.ctrl_pnts.count == num_section_pnts, "section size mismatch"
            index = np.concatenate([index, np.arange(section.count, dtype=np.int32)])
            new_curves = new_curves.add(section)

        edges = Edges.from_bspline_curves(new_curves)
        num_edges = edges.count

        return Wires(
            edges=edges,
            index=index,
            is_interior=np.full(num_edges, False),
            is_planar=np.full(num_edges, False),
        )

    @staticmethod
    def from_edges(edges: Edges):
        """Create a single wire from edges

        :param edges: edges
        :type edges: Edges
        :return: Wires (one wire, non-interior, planar)
        :rtype: Wires
        """
        num_edges = edges.count
        return Wires(
            edges=edges,
            index=np.full(num_edges, 0, dtype=np.int32),
            is_interior=np.full(num_edges, False),
            is_planar=np.full(num_edges, True),
        )

    @staticmethod
    def from_line_segments(
        segments: jnp.ndarray,
        is_interior: bool = False,
    ):
        """Create Wires from edge segments for each wire (num_wires, num_edges, 2, 3)

        :param segments: edge segments for each wire (num_wires, num_edges, 2, 3);
            a 3-dimensional input is treated as a single wire
        :type segments: jnp.ndarray
        :param is_interior: interior flag applied to every edge
        :type is_interior: bool
        :return: Wires (planar)
        :rtype: Wires
        """
        if len(segments.shape) == 3:
            segments = jnp.expand_dims(segments, axis=0)

        edges = Edges.from_line_segments(segments)

        index = (
            # create new indices for each wire (int32 like the other constructors)
            np.arange(0, len(segments), dtype=np.int32)
            # repeat for each edge
            .repeat(segments.shape[1], axis=0)
        )

        return Wires(
            edges=edges,
            index=index,
            is_interior=np.full(edges.count, is_interior),
            is_planar=np.full(edges.count, True),
        )

    def add(self, wires: "Wires", reorder_index: bool = False):
        """Append the edges of other wires to this collection.

        :param wires: Wires to append
        :type wires: Wires
        :param reorder_index: forwarded to ``add_indices``
        :type reorder_index: bool
        :return: Combined wires
        :rtype: Wires
        """
        is_interior = np.concatenate([self.is_interior, wires.is_interior])
        is_planar = np.concatenate([self.is_planar, wires.is_planar])
        edges = self.edges.add(wires.edges)
        index = self.add_indices(wires.index, reorder_index)
        assert len(index) == edges.count, "index and edges count mismatch"

        return Wires(edges, index, is_interior, is_planar)

    @staticmethod
    def _resolve_flags(num_edges, override, current, fallback):
        """Pick the per-edge flag array for ``clone``."""
        if override is not None:
            # Explicit value: scalar or per-edge array, broadcast to all edges
            return np.full(num_edges, override)
        if len(current) == num_edges:
            # Nothing requested and the edge count is unchanged: keep flags
            return current
        # The edge count changed, so the old flags no longer line up
        return np.full(num_edges, fallback)

    def clone(
        self,
        edges: Optional[Edges] = None,
        index: Optional[np.ndarray] = None,
        is_interior: Optional[np.ndarray] = None,
        is_planar: Optional[np.ndarray] = None,
    ):
        """Copy the wires, replacing only the fields that are given.

        Flags that are not given are preserved as long as the number of edges
        is unchanged; otherwise they fall back to non-interior / planar.

        :param edges: replacement edges, defaults to the current edges
        :type edges: Optional[Edges]
        :param index: replacement wire index, defaults to the current index
        :type index: Optional[np.ndarray]
        :param is_interior: replacement interior flags (scalar or per edge)
        :type is_interior: Optional[np.ndarray]
        :param is_planar: replacement planar flags (scalar or per edge)
        :type is_planar: Optional[np.ndarray]
        :return: New wires
        :rtype: Wires
        """
        edges = edges if edges is not None else self.edges
        return Wires(
            edges=edges,
            index=index if index is not None else self.index,
            is_interior=Wires._resolve_flags(
                edges.count, is_interior, self.is_interior, False
            ),
            is_planar=Wires._resolve_flags(edges.count, is_planar, self.is_planar, True),
        )

    def project(self, new_axes: Axes):
        """
        Project the wires onto the given axes from XY plane. Assumes all wires are on XY plane.
        :param new_axes: Axes to project on
        :type new_axes: Axes
        :return: Projected wires, one copy of the wires per axis
        :rtype: Wires
        """
        curr_axes = self.get_axes()
        new_edges = self.edges.project(curr_axes, new_axes)
        index = SparseIndexable.expanded_index(self.index, new_axes.count)
        # Tile the per-edge flags once per target axis
        is_interior = np.repeat(
            self.is_interior[None, :], new_axes.count, axis=0
        ).flatten()
        is_planar = np.repeat(self.is_planar[None, :], new_axes.count, axis=0).flatten()

        return Wires(new_edges, index, is_interior, is_planar)

    def translate(self, translation: jnp.ndarray):
        """Translate the wires by a given translation.

        :param translation: The translation to apply, either one vector of
            shape (3,) for all wires or one vector per wire (num_wires, 3).
        :type translation: jnp.ndarray
        :return: Translated wires
        :rtype: Wires
        :raises ValueError: if the translation has more than two dimensions
        """
        if len(translation.shape) == 1:
            new_edges = self.edges.translate(translation)
        elif len(translation.shape) == 2:
            # Expand the per-wire translation to one row per edge
            new_edges = self.edges.translate(translation[self.index])
        else:
            raise ValueError(
                f"Translation must be of shape (3,) or (num_wires, 3), got {translation.shape}"
            )
        return self.clone(new_edges)

    def wind(self, is_clockwise: bool = False):
        """Return the wires wound in the requested direction.

        :param is_clockwise: wind clockwise instead of counter clockwise
        :type is_clockwise: bool
        :return: Wires with adjusted winding order
        :rtype: Wires
        """
        normals = self.get_axis_normals()
        # checks if winding order is clockwise, if so it needs to be flipped to counter clockwise
        flip_mask = jnp.sum(normals, axis=1) < 0
        if is_clockwise:
            # if we want it to be clockwise, it is the inverse of the flip mask
            flip_mask = ~flip_mask
        new_flipped = self.flip_winding(flip_mask)
        return new_flipped

    def flip_winding(self, mask: Optional[jnp.ndarray] = None):
        """Returns the wires with flipped winding order.

        :param mask: per-wire mask of the wires to flip; it is expanded to one
            entry per edge before being passed to ``edges.flip_winding``
        :type mask: Optional[jnp.ndarray]
        :return: Wires with flipped winding order
        :rtype: Wires
        """
        # assign to new Wires
        edge_flip_mask = mask[self.index] if mask is not None else None
        new_edges = self.edges.flip_winding(edge_flip_mask)
        # All per-edge arrays are reversed together so they stay aligned
        new_wires = self.clone(
            new_edges,
            self.index[::-1],
            is_interior=self.is_interior[::-1],
            is_planar=self.is_planar[::-1],
        )

        return new_wires

    @property
    def num_interior(self):
        """Number of wires that contain at least one interior edge."""
        # is_interior is per edge, so count the distinct wires it touches
        interior_wire_ids = set(self.index[self.is_interior].tolist())
        return jnp.array(len(interior_wire_ids))

    def __repr__(self) -> str:
        return f"Wires(count={self.count}, edges={self.edges}, num_interior={self.num_interior})"

    def plot(self, is_3d=True):
        """Plotly plot the wires.

        :param is_3d: draw a 3D scatter plot; otherwise plot x/y only with
            equally scaled axes
        :type is_3d: bool
        """

        import plotly.graph_objects as go

        x = []
        y = []
        z = []
        segments = self.edges.evaluate()
        for segment in segments:
            # The trailing None separates this segment's points from the next
            x += [*segment[:, 0], None]
            y += [*segment[:, 1], None]
            z += [*segment[:, 2], None]
        if is_3d:
            fig = go.Figure(
                go.Scatter3d(
                    x=x,
                    y=y,
                    z=z,
                    mode="lines",
                    line=dict(color="blue", width=2),
                )
            )
            fig.update_scenes(aspectmode="data")

        else:
            fig = go.Figure(
                go.Scatter(
                    x=x,
                    y=y,
                    mode="lines",
                    line=dict(color="blue", width=2),
                )
            )
            # equal scale
            fig.update_xaxes(scaleanchor="y", scaleratio=1)
            fig.update_yaxes(scaleanchor="x", scaleratio=1)
        fig.show()
443 |
--------------------------------------------------------------------------------