├── examples
│   ├── 08_optimization.py
│   ├── 10_pattern_creation.py
│   ├── 07_gradient_accumulation.py
│   ├── 09_point_pattern_optimization.py
│   ├── 11_domain_specific_pattern_optim.py
│   ├── 01_hello_world.py
│   ├── 03_parent_child.py
│   ├── 04_material_randomization.py
│   ├── 05_light_randomization.py
│   ├── 06_animation.py
│   ├── 02_general_transformations.py
│   ├── 06_sampling.py
│   └── vocalfold_scene.py
├── fireflies
│   ├── utils
│   │   ├── __init__.py
│   │   ├── transforms.py
│   │   ├── torch_grads.py
│   │   ├── warnings.py
│   │   ├── intersections.py
│   │   ├── io.py
│   │   ├── math.py
│   │   └── laser_estimation.py
│   ├── graphics
│   │   ├── __init__.py
│   │   ├── depth.py
│   │   └── rasterization.py
│   ├── emitter
│   │   ├── __init__.py
│   │   └── base.py
│   ├── __init__.py
│   ├── material
│   │   ├── __init__.py
│   │   └── base.py
│   ├── projection
│   │   ├── __init__.py
│   │   ├── camera.py
│   │   └── laser.py
│   ├── entity
│   │   ├── __init__.py
│   │   ├── shape.py
│   │   ├── curve.py
│   │   ├── flame.py
│   │   ├── mesh.py
│   │   └── base.py
│   ├── postprocessing
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── postprocessor.py
│   │   ├── white_noise.py
│   │   ├── gauss_blur.py
│   │   └── apply_silhouette.py
│   ├── sampling
│   │   ├── __init__.py
│   │   ├── uniform.py
│   │   ├── gaussian_distribution.py
│   │   ├── uniform_integer.py
│   │   ├── uniform_scalar_to_vec3.py
│   │   ├── animation.py
│   │   ├── base.py
│   │   ├── noise_texture_lerp.py
│   │   └── poisson.py
│   └── scene.py
├── assets
│   ├── eval.gif
│   ├── logo.png
│   ├── train.gif
│   └── teaser.png
├── setup.py
├── LICENSE
├── requirements.txt
├── .gitignore
├── README.md
└── main.py
/examples/08_optimization.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fireflies/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/10_pattern_creation.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fireflies/graphics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fireflies/utils/transforms.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/07_gradient_accumulation.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/09_point_pattern_optimization.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/examples/11_domain_specific_pattern_optim.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fireflies/emitter/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Light
2 |
--------------------------------------------------------------------------------
/fireflies/__init__.py:
--------------------------------------------------------------------------------
1 | from fireflies.scene import Scene
2 |
--------------------------------------------------------------------------------
/fireflies/material/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Material
2 |
--------------------------------------------------------------------------------
/assets/eval.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/eval.gif
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/assets/train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/train.gif
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/fireflies/projection/__init__.py:
--------------------------------------------------------------------------------
1 | from .camera import Camera
2 | from .laser import Laser
3 |
--------------------------------------------------------------------------------
/fireflies/entity/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Transformable
2 | from .curve import Curve
3 | from .mesh import Mesh
4 |
--------------------------------------------------------------------------------
/fireflies/utils/torch_grads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 |
4 |
5 | def retain_grads(non_leaf_tensor: List[torch.tensor]) -> None:
6 | for tensor in non_leaf_tensor:
7 | tensor.retain_grad()
8 |
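 9 |
10 | # Hypothetical usage sketch (assumption, not part of the original module): keep
11 | # gradients of intermediate (non-leaf) tensors so they can be inspected after
12 | # backward().
13 | #
14 | #   x = torch.ones(3, requires_grad=True)
15 | #   y = x * 2.0                # non-leaf tensor, grad normally discarded
16 | #   retain_grads([y])
17 | #   y.sum().backward()
18 | #   print(y.grad)              # now populated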
--------------------------------------------------------------------------------
/fireflies/postprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BasePostProcessingFunction
2 | from .white_noise import WhiteNoise
3 | from .postprocessor import PostProcessor
4 | from .apply_silhouette import ApplySilhouette
5 | from .gauss_blur import GaussianBlur
6 |
--------------------------------------------------------------------------------
/fireflies/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import Sampler
2 | from .gaussian_distribution import GaussianSampler
3 | from .uniform import UniformSampler
4 | from .uniform_integer import UniformIntegerSampler
5 | from .uniform_scalar_to_vec3 import UniformScalarToVec3Sampler
6 | from .animation import AnimationSampler
7 | from .noise_texture_lerp import NoiseTextureLerpSampler
8 |
--------------------------------------------------------------------------------
/fireflies/emitter/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 |
4 | import fireflies.utils.math
5 | import fireflies.entity
6 |
7 |
8 | class Light(fireflies.entity.Transformable):
9 | def __init__(
10 | self,
11 | name: str,
12 | device: torch.cuda.device = torch.device("cuda"),
13 | ):
14 | super(Light, self).__init__(name, device)
15 |
--------------------------------------------------------------------------------
/fireflies/postprocessing/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import fireflies.sampling
4 |
5 |
6 | class BasePostProcessingFunction:
7 | def __init__(self, probability: float):
8 | self._probability = probability
9 |
10 | def apply(self, image: np.array) -> np.array:
11 | if random.uniform(0, 1) < self._probability:
12 | return self.post_process(image)
13 |
14 | return image
15 |
16 |     def post_process(self, image: np.array) -> np.array:
17 |         # Subclasses implement the actual post-processing operation.
18 |         raise NotImplementedError
19 |
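20 |
21 | # Hypothetical sketch (assumption, not part of the original module): a custom
22 | # post-processing step only needs to implement post_process(); apply() then
23 | # runs it with the configured probability.
24 | #
25 | #   class Invert(BasePostProcessingFunction):
26 | #       def __init__(self, probability: float = 0.5):
27 | #           super(Invert, self).__init__(probability)
28 | #
29 | #       def post_process(self, image: np.array) -> np.array:
30 | #           return 1.0 - image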
--------------------------------------------------------------------------------
/fireflies/postprocessing/postprocessor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .base import BasePostProcessingFunction
3 |
4 | from typing import List
5 |
6 |
7 | class PostProcessor:
8 | def __init__(
9 | self,
10 | post_process_funcs: List[BasePostProcessingFunction],
11 | ):
12 | self._post_process_functs = post_process_funcs
13 |
14 | def post_process(self, image: np.array) -> np.array:
15 | image_copy = image.copy()
16 | for func in self._post_process_functs:
17 | image_copy = func.apply(image_copy)
18 |
19 | return image_copy
20 |
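21 |
22 | # Hypothetical usage sketch (assumption, not part of the original module): chain
23 | # several post-processing functions and apply them to a rendered image.
24 | #
25 | #   from fireflies.postprocessing import PostProcessor, WhiteNoise, GaussianBlur
26 | #
27 | #   processor = PostProcessor(
28 | #       [
29 | #           WhiteNoise(mean=0.0, std=0.05, probability=0.5),
30 | #           GaussianBlur(kernel_size=(5, 5), sigma=(1.5, 1.5), probability=0.5),
31 | #       ]
32 | #   )
33 | #   augmented = processor.post_process(rendered_image)  # rendered_image: np.array in [0, 1]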
--------------------------------------------------------------------------------
/fireflies/sampling/uniform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.sampling.base as base
3 | import fireflies.utils.math
4 |
5 |
6 | class UniformSampler(base.Sampler):
7 | def __init__(
8 | self,
9 | min: torch.tensor,
10 | max: torch.tensor,
11 | eval_step_size: float = 0.01,
12 | device: torch.cuda.device = torch.device("cuda"),
13 | ) -> None:
14 | super(UniformSampler, self).__init__(min, max, eval_step_size, device)
15 |
16 | def sample_train(self) -> torch.tensor:
17 | return fireflies.utils.math.randomBetweenTensors(
18 | self._min_range, self._max_range
19 | )
20 |
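21 |
22 | # Hypothetical usage sketch (assumption, not part of the original module):
23 | #
24 | #   device = torch.device("cuda")
25 | #   sampler = UniformSampler(torch.zeros(3, device=device), torch.ones(3, device=device))
26 | #   sampler.train()
27 | #   vec = sampler.sample()   # uniform sample between the min and max tensors
28 | #   sampler.eval()
29 | #   vec = sampler.sample()   # deterministic sweep from min to max in eval_step_size steps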
--------------------------------------------------------------------------------
/fireflies/postprocessing/white_noise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import fireflies.postprocessing.base as base
3 |
4 |
5 | class WhiteNoise(base.BasePostProcessingFunction):
6 | def __init__(
7 | self,
8 | mean: float,
9 | std: float,
10 | probability: float,
11 | ):
12 | super(WhiteNoise, self).__init__(probability)
13 | self._mean = mean
14 | self._std = std
15 |
16 | def post_process(self, image: np.array) -> np.array:
17 | image += np.random.normal(
18 | np.ones_like(image) * self._mean, np.ones_like(image) * self._std
19 | )
20 | return np.clip(image, 0, 1)
21 |
--------------------------------------------------------------------------------
/fireflies/sampling/gaussian_distribution.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.sampling.base as base
3 |
4 |
5 | class GaussianSampler(base.Sampler):
6 | def __init__(
7 | self,
8 | min: torch.tensor,
9 | max: torch.tensor,
10 | mean: torch.tensor,
11 | std: torch.tensor,
12 | eval_step_size: float = 0.01,
13 | device: torch.cuda.device = torch.device("cuda"),
14 | ) -> None:
15 | super(GaussianSampler, self).__init__(min, max, eval_step_size, device)
16 | self._mean = mean
17 | self._std = std
18 |
19 | def sample_train(self) -> torch.tensor:
20 | return torch.normal(self._mean, self._std)
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from os import path
3 |
4 |
5 | this_directory = path.abspath(path.dirname(__file__))
6 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
7 | long_description = f.read()
8 |
9 | with open("requirements.txt") as f:
10 | install_requires = f.read().strip().split("\n")
11 |
12 | setup(
13 | name="Fireflies",
14 | version="1.0",
15 | description="A module for randomizing mitsuba scenes and their parameters originally created for Structured Light Endoscopy.",
16 | author="Jann-Ole Henningson",
17 | author_email="jann-ole.henningson@fau.de",
18 | url="https://github.com/Henningson/Fireflies",
19 |     packages=find_packages(),
20 |     install_requires=install_requires,
21 |     long_description=long_description,
22 |     long_description_content_type="text/markdown",
23 | )
24 |
--------------------------------------------------------------------------------
/fireflies/postprocessing/gauss_blur.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import fireflies.postprocessing.base as base
3 | import kornia
4 | import torch
5 |
6 |
7 | class GaussianBlur(base.BasePostProcessingFunction):
8 | def __init__(
9 | self,
10 | kernel_size: tuple[int, int],
11 | sigma: tuple[float, float],
12 | probability: float,
13 | ):
14 | super(GaussianBlur, self).__init__(probability)
15 | self._kernel_size = kernel_size
16 | self._sigma = sigma
17 |
18 | def post_process(self, image: np.array) -> np.array:
19 | image = (
20 | kornia.filters.gaussian_blur2d(
21 | torch.tensor(image).unsqueeze(0).unsqueeze(0),
22 | self._kernel_size,
23 | self._sigma,
24 | )
25 | .squeeze()
26 | .numpy()
27 | )
28 | return image
29 |
--------------------------------------------------------------------------------
/fireflies/entity/shape.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from typing import List
4 |
5 | from fireflies.entity.mesh import Mesh
6 | import fireflies.utils.math
7 |
8 |
9 | # NOTE: This shape model is not fully implemented yet.
10 | class ShapeModel(Mesh):
11 | def __init__(
12 | self,
13 | name: str,
14 | vertex_data: torch.tensor,
15 | face_data: torch.tensor,
16 | device: torch.cuda.device = torch.device("cuda"),
17 | ):
18 |         super(ShapeModel, self).__init__(name, vertex_data, face_data, device)
19 | self._device = device
20 | self._name = name
21 |
22 | def load_animation(self):
23 | return None
24 |
25 | def get_model_params(self) -> dict:
26 | return self._model_params
27 |
28 |     def set_model_params(self, model_params: dict) -> None:
29 |         raise NotImplementedError
30 |
31 |     def get_vertices(self):
32 |         raise NotImplementedError
33 |
--------------------------------------------------------------------------------
/examples/01_hello_world.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 |
5 | mi.set_variant("cuda_ad_rgb")
6 |
7 | import torch
8 | import fireflies
9 |
10 |
11 | def render_to_opencv(render):
12 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
13 | return (render * 255).astype(np.uint8)
14 |
15 |
16 | if __name__ == "__main__":
17 | path = "scenes/hello_world/hello_world.xml"
18 |
19 | mitsuba_scene = mi.load_file(path)
20 | mitsuba_params = mi.traverse(mitsuba_scene)
21 | fireflies_scene = fireflies.Scene(mitsuba_params)
22 |
23 | fireflies_scene.mesh_at(0).rotate_z(-np.pi, np.pi)
24 |
25 | fireflies_scene.train()
26 | for i in range(100):
27 | fireflies_scene.randomize()
28 |
29 | render = mi.render(mitsuba_scene, spp=10)
30 |
31 | cv2.imshow("a", render_to_opencv(render))
32 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render))
33 | cv2.waitKey(10)
34 |
--------------------------------------------------------------------------------
/fireflies/sampling/uniform_integer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.sampling.base as base
3 | import random
4 |
5 |
6 | class UniformIntegerSampler(base.Sampler):
7 | def __init__(
8 | self,
9 | min_integer: int,
10 | max_integer: int,
11 | eval_step_size: int = 1,
12 | device: torch.cuda.device = torch.device("cuda"),
13 | ) -> None:
14 | """
15 |         Generates samples from the integer interval [min_integer, max_integer), similar to how range() works in Python. A usage sketch follows at the end of this file.
16 |         """
17 |         super(UniformIntegerSampler, self).__init__(min_integer, max_integer, eval_step_size, device)
18 |         self._current_step = min_integer
19 |
20 | def sample_eval(self) -> int:
21 | sample = self._current_step
22 | self._current_step += self._eval_step_size
23 |
24 | if self._current_step >= self._max_range:
25 | self._current_step = self._min_range
26 |
27 | return sample
28 |
29 | def sample_train(self) -> int:
30 |         return random.randint(int(self._min_range), int(self._max_range) - 1)
31 |
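32 |
33 | # Hypothetical usage sketch (assumption, not part of the original module):
34 | #
35 | #   sampler = UniformIntegerSampler(0, 10)
36 | #   sampler.train()
37 | #   idx = sampler.sample()   # random integer in [0, 10)
38 | #   sampler.eval()
39 | #   idx = sampler.sample()   # steps 0, 1, 2, ... and wraps around at 10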
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Henningson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/03_parent_child.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | import mitsuba as mi
6 |
7 | mi.set_variant("cuda_ad_rgb")
8 |
9 | import fireflies
10 |
11 |
12 | def render_to_opencv(render):
13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
14 | return (render * 255).astype(np.uint8)
15 |
16 |
17 | if __name__ == "__main__":
18 | path = "scenes/parent_child/parent_child.xml"
19 |
20 | mi_scene = mi.load_file(path)
21 | mi_params = mi.traverse(mi_scene)
22 | ff_scene = fireflies.Scene(mi_params)
23 |
24 | cone = ff_scene.mesh("mesh-Cone")
25 | sphere = ff_scene.mesh("mesh-Sphere")
26 |
27 |     # Add the sphere as the cone's parent.
28 | cone.setParent(sphere)
29 |
30 |     # Also mark the cone as randomizable; otherwise it would not be randomized.
31 | cone.set_randomizable(True)
32 |
33 | # Rotate everything around the z-axis.
34 | sphere.rotate_z(-np.pi, np.pi)
35 |
36 | ff_scene.eval()
37 | for i in range(100):
38 | ff_scene.randomize()
39 |
40 | render = mi.render(mi_scene, spp=10)
41 |
42 | cv2.imshow("a", render_to_opencv(render))
43 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render))
44 | cv2.waitKey(10)
45 |
--------------------------------------------------------------------------------
/fireflies/postprocessing/apply_silhouette.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import fireflies.postprocessing.base as base
4 | import cv2
5 | import random
6 | import kornia
7 | import torch
8 |
9 |
10 | class ApplySilhouette(base.BasePostProcessingFunction):
11 | def __init__(
12 | self,
13 | probability: float = 2.0,
14 | ):
15 | super(ApplySilhouette, self).__init__(probability)
16 |
17 | def post_process(self, image: np.array) -> np.array:
18 | silhouette_image = np.zeros_like(image)
19 | spawning_rect_x = [100, 200]
20 | spawning_rect_y = [200, 300]
21 | radius_interval = [170, 230]
22 |
23 | cc_x = random.randint(spawning_rect_x[0], spawning_rect_x[1])
24 | cc_y = random.randint(spawning_rect_y[0], spawning_rect_y[1])
25 | radius = random.randint(radius_interval[0], radius_interval[1])
26 | silhouette_image = cv2.circle(
27 | silhouette_image, (cc_x, cc_y), radius, color=1, thickness=-1
28 | )
29 |
30 | silhouette_image = (
31 | kornia.filters.gaussian_blur2d(
32 | torch.tensor(silhouette_image).unsqueeze(0).unsqueeze(0),
33 | (11, 11),
34 | (5, 5),
35 | )
36 | .squeeze()
37 | .numpy()
38 | )
39 |
40 | return image * silhouette_image
41 |
--------------------------------------------------------------------------------
/examples/04_material_randomization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | import mitsuba as mi
6 |
7 | mi.set_variant("cuda_ad_rgb")
8 |
9 | import fireflies
10 |
11 |
12 | def render_to_opencv(render):
13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
14 | return (render * 255).astype(np.uint8)
15 |
16 |
17 | if __name__ == "__main__":
18 | path = "examples/scenes/hello_world/hello_world.xml"
19 |
20 | mi_scene = mi.load_file(path)
21 | mi_params = mi.traverse(mi_scene)
22 | ff_scene = fireflies.Scene(mi_params)
23 |
24 |     # Let's randomize the color of the cube.
25 | min_color = torch.tensor([0.2, 0.3, 0.2], device=ff_scene._device)
26 | max_color = torch.tensor([0.8, 1.0, 0.8], device=ff_scene._device)
27 |
28 | material = ff_scene.material("mat-Material")
29 | material.add_vec3_key("brdf_0.base_color.value", min_color, max_color)
30 |
31 | # Keys for randomizable attributes can be accessed using:
32 | # material.vec3_attributes().keys()
33 | # material.float_attributes().keys()
34 |
35 | ff_scene.train()
36 | for i in range(100):
37 | ff_scene.randomize()
38 |
39 | render = mi.render(mi_scene, spp=10)
40 |
41 | cv2.imshow("a", render_to_opencv(render))
42 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render))
43 | cv2.waitKey(10)
44 |
--------------------------------------------------------------------------------
/fireflies/sampling/uniform_scalar_to_vec3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.sampling.base as base
3 | import fireflies.utils.math
4 |
5 |
6 | class UniformScalarToVec3Sampler(base.Sampler):
7 | def __init__(
8 | self,
9 | min: torch.tensor,
10 | max: torch.tensor,
11 | eval_step_size: float = 0.01,
12 | device: torch.cuda.device = torch.device("cuda"),
13 | ) -> None:
14 | super(UniformScalarToVec3Sampler, self).__init__(
15 | min, max, eval_step_size, device
16 | )
17 |
18 | def sample_train(self) -> torch.tensor:
19 | scalar_value = fireflies.utils.math.randomBetweenTensors(
20 | self._min_range, self._max_range
21 | )
22 | return torch.tensor(
23 | [scalar_value, scalar_value, scalar_value], device=self._device
24 | )
25 |
26 | def sample_eval(self) -> torch.tensor:
27 | if (self._min_range == self._max_range).all():
28 | return torch.tensor(
29 | [self._min_range, self._min_range, self._min_range], device=self._device
30 | )
31 |
32 | sample = self._current_step
33 | self._current_step += self._eval_step_size
34 |
35 | if (self._current_step > self._max_range).any():
36 | self._current_step = self._min_range
37 |
38 | return torch.tensor([sample, sample, sample], device=self._device)
39 |
--------------------------------------------------------------------------------
/examples/05_light_randomization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | import mitsuba as mi
6 |
7 | mi.set_variant("cuda_ad_rgb")
8 |
9 | import fireflies
10 |
11 |
12 | def render_to_opencv(render):
13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
14 | return (render * 255).astype(np.uint8)
15 |
16 |
17 | if __name__ == "__main__":
18 | path = "examples/scenes/parent_child/parent_child.xml"
19 |
20 | mi_scene = mi.load_file(path)
21 | mi_params = mi.traverse(mi_scene)
22 | ff_scene = fireflies.Scene(mi_params)
23 |
24 | cone = ff_scene.mesh("mesh-Cone")
25 | sphere = ff_scene.mesh("mesh-Sphere")
26 | light = ff_scene.light("emit-Light")
27 |
28 |     # Add the sphere as the cone's parent.
29 |     # Also mark the cone as randomizable; otherwise it would not be randomized.
30 | cone.setParent(sphere)
31 | cone.set_randomizable(True)
32 |
33 | min_intensity = torch.tensor([150, 0, 0], device=ff_scene._device)
34 | max_intensity = torch.tensor([150, 150, 150], device=ff_scene._device)
35 | light.add_vec3_key("intensity.value", min_intensity, max_intensity)
36 |
37 | # Rotate everything around the z-axis.
38 | sphere.rotate_z(-np.pi, np.pi)
39 |
40 | ff_scene.eval()
41 | for i in range(100):
42 | ff_scene.randomize()
43 |
44 | render = mi.render(mi_scene, spp=10)
45 |
46 | cv2.imshow("a", render_to_opencv(render))
47 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render))
48 | cv2.waitKey(10)
49 |
--------------------------------------------------------------------------------
/examples/06_animation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 |
5 | mi.set_variant("cuda_ad_rgb")
6 |
7 | import torch
8 | import fireflies
9 | import fireflies.sampling
10 |
11 |
12 | def render_to_opencv(render):
13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
14 | return (render * 255).astype(np.uint8)
15 |
16 |
17 | # Let's define an animation function that will visualize a sine wave.
18 | # You can define completely arbitrary functions.
19 | def animation_function(vertices: torch.tensor, time: float) -> torch.tensor:
20 | # Let's not change the incoming vertices in place.
21 | vertices_clone = vertices.clone()
22 |
23 |     # Offset the vertices' second coordinate with a sine wave that travels
24 |     # along the third coordinate axis over time.
25 |
26 | vertices_clone[:, 1] = (
27 | vertices_clone[:, 1]
28 | + torch.sin(vertices_clone[:, 2] * 10.0 + time * 20.0) / 10.0
29 | )
30 |
31 | return vertices_clone
32 |
33 |
34 | if __name__ == "__main__":
35 | path = "examples/scenes/animation/animation.xml"
36 |
37 | mitsuba_scene = mi.load_file(path)
38 | mitsuba_params = mi.traverse(mitsuba_scene)
39 | fireflies_scene = fireflies.Scene(mitsuba_params)
40 |
41 | mesh = fireflies_scene.mesh("mesh-Animation")
42 | mesh.add_animation_func(
43 | animation_function,
44 | torch.tensor([0.0]).to(fireflies_scene.device()),
45 | torch.tensor([2 * np.pi]).to(fireflies_scene.device()),
46 | )
47 |
48 | fireflies_scene.eval()
49 | for i in range(1000):
50 | fireflies_scene.randomize()
51 | render = mi.render(mitsuba_scene, spp=12)
52 | cv2.imshow("Animation Example", render_to_opencv(render))
53 | cv2.waitKey(25)
--------------------------------------------------------------------------------
/examples/02_general_transformations.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 |
5 | mi.set_variant("cuda_ad_rgb")
6 |
7 | import torch
8 | import fireflies
9 |
10 |
11 | def render_to_opencv(render):
12 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
13 | return (render * 255).astype(np.uint8)
14 |
15 |
16 | if __name__ == "__main__":
17 | path = "scenes/hello_world/hello_world.xml"
18 |
19 | mi_scene = mi.load_file(path)
20 | mi_params = mi.traverse(mi_scene)
21 | ff_scene = fireflies.Scene(mi_params)
22 |
23 | mesh = ff_scene.mesh_at(0)
24 |
25 | # Rotations
26 | mesh.rotate_x(-0.5, 0.5)
27 | mesh.rotate_y(-0.5, 0.5)
28 | mesh.rotate_z(-0.5, 0.5)
29 | mesh.rotate(
30 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device),
31 | torch.tensor([0.5, 0.5, 0.5], device=ff_scene._device),
32 | )
33 |
34 | # Translations
35 | mesh.translate_x(-0.5, 0.5)
36 | mesh.translate_y(-0.5, 0.5)
37 | mesh.translate_z(-0.5, 0.5)
38 | mesh.translate(
39 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device),
40 | torch.tensor([0.5, 0.5, 0.5], device=ff_scene._device),
41 | )
42 |
43 | # Scale
44 | mesh.scale_x(-0.5, 0.5)
45 | mesh.scale_y(-0.5, 0.5)
46 | mesh.scale_z(-0.5, 0.5)
47 | mesh.scale(
48 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device),
49 | torch.tensor([0.5, 0.5, 0.5], device=ff_scene._device),
50 | )
51 |
52 | # There's more in later examples :)
53 |
54 | ff_scene.train()
55 | for i in range(100):
56 | ff_scene.randomize()
57 |
58 | render = mi.render(mi_scene, spp=10)
59 |
60 | cv2.imshow("a", render_to_opencv(render))
61 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render))
62 | cv2.waitKey(10)
63 |
--------------------------------------------------------------------------------
/fireflies/sampling/animation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.sampling.base as base
3 | import random
4 |
5 |
6 | class AnimationSampler(base.Sampler):
7 | def __init__(
8 | self,
9 | min_integer_train: int,
10 | max_integer_train: int,
11 | min_integer_eval: int,
12 | max_integer_eval: int,
13 | eval_step_size: int = 1,
14 | device: torch.cuda.device = torch.device("cuda"),
15 | ) -> None:
16 | """
17 | Will generate samples from the integer interval given by [min_integer, ..., max_integer) similar to how range() is defined in python.
18 | Assuming we have a set of train and eval objs that were loaded in the Mesh class using load_train_objs() and load_eval_objs().
19 | """
20 | super(AnimationSampler, self).__init__(min_integer_train, max_integer_train, eval_step_size, device)
21 | self._min_integer_train = min_integer_train
22 | self._max_integer_train = max_integer_train
23 | self._min_integer_eval = min_integer_eval
24 | self._max_integer_eval = max_integer_eval
25 | self._current_step = min_integer_eval
26 |
27 | def sample_eval(self) -> int:
28 | sample = self._current_step
29 | self._current_step += self._eval_step_size
30 |
31 | if self._current_step > self._max_integer_eval:
32 | self._current_step = self._min_integer_eval
33 |
34 | return sample
35 |
36 | def sample_train(self) -> int:
37 | return random.randint(self._min_integer_train, self._max_integer_train - 1)
38 |
39 | def set_train_interval(self, min_integer_train: int, max_integer_train: int) -> None:
40 | self._min_integer_train = min_integer_train
41 | self._max_integer_train = max_integer_train
42 |
43 | def set_eval_interval(self, min_integer_eval: int, max_integer_eval: int) -> None:
44 | self._min_integer_eval = min_integer_eval
45 | self._max_integer_eval = max_integer_eval
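46 |
47 |
48 | # Hypothetical usage sketch (assumption, not part of the original module): draw a
49 | # random training frame index, then step sequentially through the eval frames.
50 | #
51 | #   sampler = AnimationSampler(0, 100, 100, 120)
52 | #   sampler.train()
53 | #   frame = sampler.sample()   # random integer in [0, 100)
54 | #   sampler.eval()
55 | #   frame = sampler.sample()   # 100, 101, 102, ... wrapping back to 100 after 120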
--------------------------------------------------------------------------------
/examples/06_sampling.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 |
5 | mi.set_variant("cuda_ad_rgb")
6 |
7 | import torch
8 | import fireflies
9 | import fireflies.sampling
10 |
11 |
12 | def render_to_opencv(render):
13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
14 | return (render * 255).astype(np.uint8)
15 |
16 |
17 | # Let's define an animation function that will visualize a sine wave.
18 | # You can define completely arbitrary functions.
19 | def animation_function(vertices: torch.tensor, time: float) -> torch.tensor:
20 | # Let's not change the incoming vertices in place.
21 | vertices_clone = vertices.clone()
22 |
23 |     # Offset the vertices' second coordinate with a sine wave that travels
24 |     # along the third coordinate axis over time.
25 |
26 | vertices_clone[:, 1] = (
27 | vertices_clone[:, 1]
28 | + torch.sin(vertices_clone[:, 2] * 10.0 + time * 20.0) / 10.0
29 | )
30 |
31 | return vertices_clone
32 |
33 |
34 | if __name__ == "__main__":
35 | path = "examples/scenes/animation/animation.xml"
36 |
37 | mitsuba_scene = mi.load_file(path)
38 | mitsuba_params = mi.traverse(mitsuba_scene)
39 | ff_scene = fireflies.Scene(mitsuba_params)
40 |
41 | mesh = ff_scene.mesh("mesh-Animation")
42 | mesh.add_animation_func(
43 | animation_function,
44 | torch.tensor([0.0]).to(ff_scene.device()),
45 | torch.tensor([2 * np.pi]).to(ff_scene.device()),
46 | )
47 |
48 | normal_distribution_sampler = fireflies.sampling.GaussianSampler(
49 |         min=torch.ones(3, device=ff_scene.device()) * 0.5,
50 |         max=torch.ones(3, device=ff_scene.device()) * 1.5,
51 |         mean=torch.ones(3, device=ff_scene.device()) * 1.0,
52 |         std=torch.ones(3, device=ff_scene.device()) * 0.5,
53 | eval_step_size=0.01,
54 | )
55 | mesh.set_scale_sampler(normal_distribution_sampler)
56 |
57 | ff_scene.train()
58 | for i in range(1000):
59 | ff_scene.randomize()
60 | render = mi.render(mitsuba_scene, spp=12)
61 | cv2.imwrite("a.png", render_to_opencv(render))
62 | cv2.waitKey(25)
--------------------------------------------------------------------------------
/fireflies/projection/camera.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fireflies.utils.math
3 | import fireflies.utils.transforms
4 |
5 | import fireflies.entity.base
6 |
7 |
8 | class Camera:
9 | id = 0
10 | MITSUBA_KEYS = {
11 | "fov": "x_fov",
12 | "f": "x_fov",
13 | "to_world": "to_world",
14 | "world": "to_world",
15 | }
16 |
17 | def __init__(
18 | self,
19 | transform: fireflies.entity.base.Transformable,
20 | perspective: torch.tensor,
21 | fov: float,
22 | near_clip: float = 0.01,
23 | far_clip: float = 1000.0,
24 | device: torch.cuda.device = torch.device("cuda"),
25 | ):
26 | self.device = device
27 |
28 | self._transformable = transform
29 | self._perspective = perspective
30 | self._near_clip = near_clip
31 | self._far_clip = far_clip
32 | self._fov = fov
33 |
34 | self._key = self.generate_mitsuba_key()
35 | Camera.id += 1
36 |
37 | def full_key(self, key: str):
38 | return self._key + "." + Camera.MITSUBA_KEYS[key]
39 |
40 | def key(self) -> str:
41 | return self._key
42 |
43 | def near_clip(self) -> float:
44 | return self._near_clip
45 |
46 | def generate_mitsuba_key(self) -> str:
47 | if Camera.id == 0:
48 | return "PerspectiveCamera"
49 |
50 |         return "PerspectiveCamera_{0}".format(Camera.id)
51 |
52 | def far_clip(self) -> float:
53 | return self._far_clip
54 |
55 | def fov(self) -> torch.tensor:
56 | return self._fov
57 |
58 | def origin(self) -> torch.tensor:
59 | return self._transformable.origin()
60 |
61 | def world(self) -> torch.tensor:
62 | return self._transformable.world()
63 |
64 | def randomize(self) -> None:
65 | self._transformable.randomize()
66 |
67 | def pointsToNDC(self, points) -> torch.tensor:
68 | view_space_points = fireflies.utils.transforms.transform_points(
69 | points, self.world().inverse()
70 | )
71 | ndc_points = fireflies.utils.transforms.transform_points(
72 | view_space_points, self._perspective
73 | )
74 | return ndc_points
75 |
--------------------------------------------------------------------------------
/fireflies/sampling/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler:
5 | def __init__(
6 | self,
7 | min: torch.tensor,
8 | max: torch.tensor,
9 | eval_step_size: float = 0.01,
10 | device: torch.cuda.device = torch.device("cuda"),
11 | ) -> None:
12 | self._device = device
13 | self._min_range = (
14 | min.clone()
15 | if type(min) is torch.Tensor
16 | else torch.tensor([min], device=device)
17 | )
18 | self._max_range = (
19 | max.clone()
20 | if type(max) is torch.Tensor
21 | else torch.tensor([max], device=device)
22 | )
23 | self._train = True
24 |
25 | self._eval_step_size = eval_step_size
26 | self._current_step = (
27 | self._min_range.clone()
28 | if type(min) is torch.Tensor
29 | else torch.tensor([min], device=device)
30 | )
31 |
32 | def set_sample_interval(self, min: torch.tensor, max: torch.tensor) -> None:
33 | self._min_range = min.clone()
34 | self._max_range = max.clone()
35 |
36 | def get_min(self) -> torch.tensor:
37 | return self._min_range
38 |
39 | def get_max(self) -> torch.tensor:
40 | return self._max_range
41 |
42 | def set_sample_max(self, max: torch.tensor) -> None:
43 | self._max_range = max.clone()
44 |
45 | def set_sample_min(self, min: torch.tensor) -> None:
46 | self._min_range = min.clone()
47 |
48 | def train(self) -> None:
49 | self._train = True
50 |
51 | def eval(self) -> None:
52 | self._train = False
53 |
54 | def sample(self) -> torch.tensor:
55 | if self._train:
56 | return self.sample_train()
57 | else:
58 | return self.sample_eval()
59 |
60 |     def sample_train(self) -> torch.tensor:
61 |         # Subclasses implement the actual sampling strategy.
62 |         raise NotImplementedError
63 |
64 | def sample_eval(self) -> torch.tensor:
65 | if (self._min_range == self._max_range).all():
66 | return self._min_range
67 |
68 | sample = self._current_step
69 | self._current_step += self._eval_step_size
70 |
71 | if (self._current_step > self._max_range).any():
72 | self._current_step = self._min_range
73 |
74 | return sample
75 |
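76 |
77 | # Hypothetical sketch (assumption, not part of the original module): a custom
78 | # sampler only needs to override sample_train(); eval-mode sampling sweeps the
79 | # [min, max] interval via the base class implementation.
80 | #
81 | #   class ConstantSampler(Sampler):
82 | #       def sample_train(self) -> torch.tensor:
83 | #           return self._min_range.clone()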
--------------------------------------------------------------------------------
/fireflies/utils/warnings.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import functools
3 |
4 |
5 | def RotationAssignmentWarning(func):
6 | @functools.wraps(func)
7 | def new_func(*args, **kwargs):
8 | warnings.simplefilter("always", Warning) # turn off filter
9 | warnings.warn(
10 | "This object should generally not have a transformation assignment via {}.".format(
11 | func.__name__
12 | ),
13 | category=Warning,
14 | stacklevel=2,
15 | )
16 | warnings.simplefilter("default", Warning) # reset filter
17 | return func(*args, **kwargs)
18 |
19 | return new_func
20 |
21 |
22 | def RelativeAssignmentWarning(func):
23 | @functools.wraps(func)
24 | def new_func(*args, **kwargs):
25 | warnings.simplefilter("always", Warning) # turn off filter
26 | warnings.warn(
27 | "This object should generally not have a parent/child assignment via {}.".format(
28 | func.__name__
29 | ),
30 | category=Warning,
31 | stacklevel=2,
32 | )
33 | warnings.simplefilter("default", Warning) # reset filter
34 | return func(*args, **kwargs)
35 |
36 | return new_func
37 |
38 |
39 | def TranslationAssignmentWarning(func):
40 | @functools.wraps(func)
41 | def new_func(*args, **kwargs):
42 | warnings.simplefilter("always", Warning) # turn off filter
43 | warnings.warn(
44 | "This object should generally not have a translation assignment via {}.".format(
45 | func.__name__
46 | ),
47 | category=Warning,
48 | stacklevel=2,
49 | )
50 | warnings.simplefilter("default", Warning) # reset filter
51 |         return func(*args, **kwargs)
52 |
53 |     return new_func
54 |
55 |
56 | def WorldAssignmentWarning(func):
57 |     @functools.wraps(func)
58 |     def new_func(*args, **kwargs):
59 |         warnings.simplefilter("always", Warning)  # turn off filter
60 |         warnings.warn(
61 |             "This object should generally not have a to-world matrix via {}.".format(
62 |                 func.__name__
63 |             ),
64 |             category=Warning,
65 |             stacklevel=2,
66 |         )
67 |         warnings.simplefilter("default", Warning)  # reset filter
68 |         return func(*args, **kwargs)
69 |
70 |     return new_func
71 |
--------------------------------------------------------------------------------
/fireflies/utils/intersections.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | # Batchwise ray-plane intersection
5 | def rayPlane(laserOrigin, laserDirection, planeOrigin, planeNormal):
6 | denom = torch.sum(planeNormal * laserDirection, axis=1)
7 |
8 | denom = torch.where(torch.abs(denom) < 0.000001, denom / denom, denom)
9 | t = torch.sum((planeOrigin - laserOrigin) * planeNormal, axis=1) / denom
10 |
11 | return t[:, None]
12 |
13 |
14 | # Batchwise Sphere-Sphere intersection:
15 | # Inputs:
16 | #  - a_coords: Tensor of size NxD, where N is the number of spheres and D the number of dimensions
17 | #  - a_radius: radii of the spheres in a_coords
18 | #  - b_coords: Tensor of size NxD, where N is the number of spheres and D the number of dimensions
19 | #  - b_radius: radii of the spheres in b_coords
20 | # Outputs:
21 | #  - Tensor of size Nx1 with boolean values:
22 | #        TRUE when a hit occurred,
23 | #        FALSE otherwise
24 |
25 |
26 | def sphereSphere(a_coords, a_radius, b_coords, b_radius):
27 | dist_ab = a_coords - b_coords
28 | squared_dist = dist_ab.pow(2).sum(dim=1, keepdim=True)
29 |
30 | sum_radii = a_radius + b_radius
31 | squared_radii = sum_radii.pow(2)
32 |
33 | return squared_dist <= squared_radii
34 |
35 |
36 | if __name__ == "__main__":
37 | import numpy as np
38 | import cv2
39 |
40 | # Here we create random circles and intersect them
41 | # They're rendered green, if they do not intersect
42 | # And red, if they intersect
43 | for _ in range(1000):
44 | a_coords = torch.rand(1, 2)
45 | a_radius = torch.rand(1, 1) / 2.0
46 | b_coords = torch.rand(1, 2)
47 | b_radius = torch.rand(1, 1) / 2.0
48 |
49 | intersections = sphereSphere(a_coords, a_radius, b_coords, b_radius)
50 |
51 | image = np.zeros((512, 512, 3), dtype=np.uint8)
52 | for i in range(a_coords.shape[0]):
53 | cv2.circle(
54 | image,
55 |                 tuple((a_coords[i].detach().cpu().numpy() * 512).astype(int)),
56 |                 int(a_radius[i, 0].detach().cpu().numpy() * 512),
57 | color=(0, 0, 255) if intersections[i, 0] else (0, 255, 0),
58 | )
59 | cv2.circle(
60 | image,
61 |                 tuple((b_coords[i].detach().cpu().numpy() * 512).astype(int)),
62 |                 int(b_radius[i, 0].detach().cpu().numpy() * 512),
63 | color=(0, 0, 255) if intersections[i, 0] else (0, 255, 0),
64 | )
65 |
66 | cv2.imshow("Test", image)
67 | cv2.waitKey(0)
68 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | beautifulsoup4==4.12.3
2 | certifi==2024.2.2
3 | charset-normalizer==3.3.2
4 | chumpy==0.70
5 | cmake==3.26.4
6 | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
7 | contourpy==1.1.0
8 | cycler==0.11.0
9 | dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
10 | drjit==0.4.4
11 | filelock==3.12.2
12 | fonttools==4.40.0
13 | freetype-py==2.4.0
14 | fsspec==2024.2.0
15 | fvcore @ file:///home/conda/feedstock_root/build_artifacts/fvcore_1671623667463/work
16 | geomdl==5.3.1
17 | idna==3.6
18 | imageio==2.31.1
19 | iopath @ file:///home/conda/feedstock_root/build_artifacts/iopath_1636568816351/work
20 | Jinja2==3.1.2
21 | kiwisolver==1.4.4
22 | kornia==0.7.1
23 | lazy_loader==0.3
24 | lightning-utilities==0.10.1
25 | lit==16.0.6
26 | lxml==5.1.0
27 | MarkupSafe==2.1.3
28 | matplotlib==3.7.2
29 | mitsuba==3.5.0
30 | mpmath==1.3.0
31 | networkx==3.1
32 | numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1651020388975/work
33 | nvidia-cublas-cu11==11.11.3.6
34 | nvidia-cuda-cupti-cu11==11.8.87
35 | nvidia-cuda-nvrtc-cu11==11.8.89
36 | nvidia-cuda-runtime-cu11==11.8.89
37 | nvidia-cudnn-cu11==8.7.0.84
38 | nvidia-cufft-cu11==10.9.0.58
39 | nvidia-curand-cu11==10.3.0.86
40 | nvidia-cusolver-cu11==11.4.1.48
41 | nvidia-cusparse-cu11==11.7.5.86
42 | nvidia-nccl-cu11==2.20.5
43 | nvidia-nvtx-cu11==11.8.86
44 | opencv-python==4.9.0.80
45 | packaging==23.1
46 | Pillow==10.0.0
47 | plotly==5.19.0
48 | portalocker @ file:///home/conda/feedstock_root/build_artifacts/portalocker_1695662047585/work
49 | pyglet==2.0.12
50 | PyOpenGL==3.1.0
51 | pyparsing==3.0.9
52 | pyrender==0.1.45
53 | python-dateutil==2.8.2
54 | pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@f34104cf6ebefacd7b7e07955ee7aaa823e616ac
55 | PyWavefront==1.3.3
56 | PyYAML==6.0
57 | rdfpy==1.0.0
58 | requests==2.31.0
59 | scikit-image==0.22.0
60 | scipy==1.11.0
61 | six==1.16.0
62 | smplx==0.1.28
63 | soupsieve==2.5
64 | sympy==1.12
65 | tabulate @ file:///home/conda/feedstock_root/build_artifacts/tabulate_1665138452165/work
66 | tenacity==8.2.3
67 | termcolor @ file:///home/conda/feedstock_root/build_artifacts/termcolor_1704357939450/work
68 | tifffile==2024.2.12
69 | torch==2.3.0+cu118
70 | torchaudio==2.3.0+cu118
71 | torchmetrics==1.3.1
72 | torchvision==0.18.0+cu118
73 | tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1707598593068/work
74 | trimesh==4.1.7
75 | triton==2.3.0
76 | typing_extensions==4.9.0
77 | urllib3==2.2.1
78 | yacs @ file:///home/conda/feedstock_root/build_artifacts/yacs_1645705974477/work
79 |
--------------------------------------------------------------------------------
/fireflies/entity/curve.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 | from geomdl import NURBS
5 | from typing import List
6 |
7 | import fireflies.entity.base as base
8 | import fireflies.utils.math
9 | import fireflies.utils.transforms
10 |
11 | class Curve(base.Transformable):
12 | count = 0.0
13 |
14 | def fromObj(path):
15 | # TODO: Implement me
16 | pass
17 |
18 | def __init__(
19 | self,
20 | name: str,
21 | curve: NURBS.Curve,
22 | device: torch.cuda.device = torch.device("cuda"),
23 | ):
24 |         super(Curve, self).__init__(name, device)
25 |
26 | self._curve = curve
27 | self.curve_epsilon = 0.05
28 |
29 | self.curve_delta = self.curve_epsilon
30 |
31 | self._interp_steps = 1000
32 | self._interp_delta = 1.0 / self._interp_steps
33 |
34 | self.eval_interval_start = 0.05
35 |
36 | def train(self) -> None:
37 | self._train = True
38 | self._continuous = False
39 |
40 | def eval(self) -> None:
41 | self._train = False
42 | self._continuous = True
43 |         self.curve_delta = self.eval_interval_start
44 |
45 | def setContinuous(self, continuous: bool) -> None:
46 | self._continuous = continuous
47 |
48 | def sample_rotation(self) -> torch.tensor:
49 | t = self.curve_delta
50 | t_new = self.curve_delta + 0.001
51 |
52 | t_new = torch.tensor(self._curve.evaluate_single(t_new), device=self._device)
53 | t = torch.tensor(self._curve.evaluate_single(t), device=self._device)
54 |
55 | curve_direction = t_new - t
56 | curve_direction[0] *= -1.0
57 | curve_direction[2] *= -1.0
58 |
59 | # curve_normal = torch.tensor(self._curve.normal(t), device=self._device)
60 | # curve_direction /= torch.linalg.norm(curve_direction)
61 | # curve_normal /= torch.linalg.norm(curve_normal)
62 |
63 | # camera_up_vector = torch.tensor([0, 0, 1], device=self._device)
64 |
65 | camera_direction = torch.tensor([0.0, 1.0, 0.0], device=self._device)
66 | return fireflies.utils.transforms.toMat4x4(
67 | fireflies.utils.math.rotation_matrix_from_vectors(
68 | camera_direction, curve_direction
69 | )
70 | )
71 |
72 | def sample_translation(self) -> torch.tensor:
73 | translationMatrix = torch.eye(4, device=self._device)
74 | translation = self._curve.evaluate_single(self.curve_delta)
75 |
76 | translationMatrix[0, 3] = translation[0]
77 | translationMatrix[1, 3] = translation[1]
78 | translationMatrix[2, 3] = translation[2]
79 |
80 | return translationMatrix
81 |
82 | def randomize(self) -> None:
83 | if self._train:
84 | self.curve_delta = random.uniform(
85 | 0 + self.curve_epsilon, self.eval_interval_start
86 | )
87 | else:
88 | self.curve_delta += self._interp_delta
89 |
90 | if self.curve_delta > 1.0 - self.curve_epsilon:
91 | self.curve_delta = self.eval_interval_start
92 |
93 | self._randomized_world = (
94 | self.sample_translation() @ self.sample_rotation() @ self._world
95 | )
96 |
--------------------------------------------------------------------------------
/fireflies/material/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 |
4 | import fireflies.utils.math
5 | import fireflies.entity
6 | from fireflies.utils.warnings import (
7 | RotationAssignmentWarning,
8 | RelativeAssignmentWarning,
9 | TranslationAssignmentWarning,
10 | WorldAssignmentWarning,
11 | )
12 |
13 |
14 | class Material(fireflies.entity.Transformable):
15 | def __init__(
16 | self,
17 | name: str,
18 | device: torch.cuda.device = torch.device("cuda"),
19 | ):
20 | super(Material, self).__init__(name, device)
21 |
22 | def randomize(self) -> None:
23 | for key, sampler in self._float_attributes.items():
24 | self._randomized_float_attributes[key] = sampler.sample()
25 |
26 | for key, sampler in self._vec3_attributes.items():
27 | self._randomized_vec3_attributes[key] = sampler.sample()
28 |
29 | @WorldAssignmentWarning
30 | def set_world(self, _origin: torch.tensor) -> None:
31 | super(Material, self).set_world(_origin)
32 |
33 | @RelativeAssignmentWarning
34 | def setParent(self, parent) -> None:
35 | super(Material, self).setParent(parent)
36 |
37 | @RelativeAssignmentWarning
38 | def setChild(self, child) -> None:
39 | super(Material, self).setChild(child)
40 |
41 | @RotationAssignmentWarning
42 | def rotate_x(self, min_rot: float, max_rot: float) -> None:
43 | super(Material, self).rotate_x(min_rot, max_rot)
44 |
45 | @RotationAssignmentWarning
46 | def rotate_y(self, min_rot: float, max_rot: float) -> None:
47 | super(Material, self).rotate_y(min_rot, max_rot)
48 |
49 | @RotationAssignmentWarning
50 | def rotate_z(self, min_rot: float, max_rot: float) -> None:
51 | super(Material, self).rotate_z(min_rot, max_rot)
52 |
53 | @RotationAssignmentWarning
54 | def rotate(self, min: torch.tensor, max: torch.tensor) -> None:
55 | super(Material, self).rotate(min, max)
56 |
57 | @TranslationAssignmentWarning
58 | def translate_x(self, min_translation: float, max_translation: float) -> None:
59 | super(Material, self).translate_x(min_translation, max_translation)
60 |
61 | @TranslationAssignmentWarning
62 | def translate_y(self, min_translation: float, max_translation: float) -> None:
63 | super(Material, self).translate_y(min_translation, max_translation)
64 |
65 | @TranslationAssignmentWarning
66 | def translate_z(self, min_translation: float, max_translation: float) -> None:
67 | super(Material, self).translate_z(min_translation, max_translation)
68 |
69 | @TranslationAssignmentWarning
70 | def translate(self, min: torch.tensor, max: torch.tensor) -> None:
71 | super(Material, self).translate(min, max)
72 |
73 | @RotationAssignmentWarning
74 | def sample_rotation(self) -> torch.tensor:
75 | return super(Material, self).sample_rotation()
76 |
77 | @TranslationAssignmentWarning
78 | def sample_translation(self) -> torch.tensor:
79 | return super(Material, self).sample_translation()
80 |
81 | @RelativeAssignmentWarning
82 | def relative(self) -> None:
83 | return super(Material, self).relative()
84 |
85 | @WorldAssignmentWarning
86 | def world(self) -> torch.tensor:
87 | return super(Material, self).world()
88 |
89 | @WorldAssignmentWarning
90 | def nonRandomizedWorld(self) -> torch.tensor:
91 | return super(Material, self).nonRandomizedWorld()
92 |
--------------------------------------------------------------------------------
/fireflies/sampling/noise_texture_lerp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import random
4 | from typing import List
5 | import fireflies.sampling.base as base
6 |
7 |
8 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
9 | delta = (res[0] / shape[0], res[1] / shape[1])
10 | d = (shape[0] // res[0], shape[1] // res[1])
11 |
12 | grid = (
13 | torch.stack(
14 | torch.meshgrid(
15 | torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
16 | ),
17 | dim=-1,
18 | )
19 | % 1
20 | )
21 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
22 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
23 |
24 | tile_grads = (
25 | lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
26 | .repeat_interleave(d[0], 0)
27 | .repeat_interleave(d[1], 1)
28 | )
29 | dot = lambda grad, shift: (
30 | torch.stack(
31 | (
32 | grid[: shape[0], : shape[1], 0] + shift[0],
33 | grid[: shape[0], : shape[1], 1] + shift[1],
34 | ),
35 | dim=-1,
36 | )
37 | * grad[: shape[0], : shape[1]]
38 | ).sum(dim=-1)
39 |
40 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
41 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
42 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
43 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
44 | t = fade(grid[: shape[0], : shape[1]])
45 | return math.sqrt(2) * torch.lerp(
46 | torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
47 | )
48 |
49 |
50 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
51 | noise = torch.zeros(shape)
52 | frequency = 1
53 | amplitude = 1
54 | for _ in range(octaves):
55 | noise += amplitude * rand_perlin_2d(
56 | shape, (frequency * res[0], frequency * res[1])
57 | )
58 | frequency *= 2
59 | amplitude *= persistence
60 | return noise
61 |
62 |
63 | class NoiseTextureLerpSampler(base.Sampler):
64 | def __init__(
65 | self,
66 | color_a: torch.tensor,
67 | color_b: torch.tensor,
68 | texture_shape: List[int],
69 | eval_step_size: float = 0.01,
70 | device: torch.cuda.device = torch.device("cuda"),
71 | ) -> None:
72 | super(NoiseTextureLerpSampler, self).__init__(
73 | torch.tensor([0.0], device=device),
74 | torch.tensor([1.0], device=device),
75 | eval_step_size,
76 | device,
77 | )
78 | self._color_a = color_a
79 | self._color_b = color_b
80 | self._texture_shape = texture_shape
81 |
82 | def sample_train(self) -> torch.tensor:
83 | i = 2 ** random.randint(1, 6)
84 | octaves = random.randint(1, 4)
85 | persistence = random.uniform(0.1, 2.0)
86 | tex = rand_perlin_2d_octaves(
87 | self._texture_shape, res=(i, i), octaves=octaves, persistence=persistence
88 | ).to(self._device)
89 | tex = (tex - tex.min()) / (tex.max() - tex.min())
90 |
91 | col_a = torch.ones_like(tex).unsqueeze(0).repeat(
92 | 3, 1, 1
93 | ) * self._color_a.unsqueeze(-1).unsqueeze(-1)
94 | col_b = torch.ones_like(tex).unsqueeze(0).repeat(
95 | 3, 1, 1
96 | ) * self._color_b.unsqueeze(-1).unsqueeze(-1)
97 | tex = tex.unsqueeze(0).repeat(3, 1, 1)
98 | return torch.lerp(col_a, col_b, tex)
99 |
100 |     # Too lazy to implement it right now; fall back to the training sampler.
101 | def sample_eval(self) -> torch.tensor:
102 | return self.sample_train()
103 |
--------------------------------------------------------------------------------
/fireflies/utils/io.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import math
4 |
5 | from geomdl import NURBS
6 | from pathlib import Path
7 |
8 |
9 | def read_config_yaml(file_path: str) -> dict:
10 | return yaml.safe_load(Path(file_path).read_text())
11 |
12 |
13 | # From: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/renderer/cameras.html#FoVPerspectiveCameras.compute_projection_matrix
14 | def build_projection_matrix(
15 | fov: float,
16 | near_clip: float,
17 | far_clip: float,
18 | device: torch.cuda.device = torch.device("cuda"),
19 | ) -> torch.tensor:
20 | """
21 |     Compute the calibration matrix K of shape (4, 4).
22 |
23 |     Args:
24 |         fov: field of view angle of the camera, given in degrees.
25 |         near_clip: near clipping plane of the view frustum.
26 |         far_clip: far clipping plane of the view frustum.
27 |         device: torch device on which the matrix is allocated.
28 |
29 |     A square pixel aspect ratio of 1.0 is assumed.
30 |
31 |     Returns:
32 |         torch.FloatTensor of the calibration matrix with shape (4, 4)
33 |     """
34 | K = torch.zeros((4, 4), dtype=torch.float32, device=device)
35 | fov = (math.pi / 180) * fov
36 |
37 | if not torch.is_tensor(fov):
38 | fov = torch.tensor(fov, device=device)
39 |
40 | tanHalfFov = torch.tan((fov / 2.0))
41 | max_y = tanHalfFov * near_clip
42 | min_y = -max_y
43 | max_x = max_y * 1.0
44 | min_x = -max_x
45 |
46 | # NOTE: In OpenGL the projection matrix changes the handedness of the
47 | # coordinate frame. i.e the NDC space positive z direction is the
48 | # camera space negative z direction. This is because the sign of the z
49 | # in the projection matrix is set to -1.0.
50 | # In pytorch3d we maintain a right handed coordinate system throughout
51 | # so the so the z sign is 1.0.
52 | z_sign = -1.0
53 |
54 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
55 | K[0, 0] = 2.0 * near_clip / (max_x - min_x)
56 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
57 | K[1, 1] = 2.0 * near_clip / (max_y - min_y)
58 | K[0, 2] = (max_x + min_x) / (max_x - min_x)
59 | K[1, 2] = (max_y + min_y) / (max_y - min_y)
60 | K[3, 2] = z_sign
61 |
62 | # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
63 | # is at the near clipping plane and z = 1 when the point is at the far
64 | # clipping plane.
65 | K[2, 2] = z_sign * far_clip / (far_clip - near_clip)
66 | K[2, 3] = -(far_clip * near_clip) / (far_clip - near_clip)
67 |
68 | return K
69 |
70 |
71 | # def tensorToFloatImage(tensor: torch.tensor) -> np.array:
72 | # return
73 |
74 |
75 | def importBlenderNurbsObj(path):
76 | obj_file = open(path, "r")
77 | lines = obj_file.readlines()
78 |
79 | control_points = []
80 | deg = None
81 | knotvector = None
82 |
83 | for line in lines:
84 | token = "v "
85 | if line.startswith("v "):
86 | line = line.replace(token, "")
87 | values = line.split(" ")
88 | values = [float(value) for value in values]
89 | control_points.append(values)
90 | continue
91 |
92 | token = "deg "
93 | if line.startswith(token):
94 | line = line.replace(token, "")
95 | deg = int(line)
96 | continue
97 |
98 | token = "parm u "
99 | if line.startswith(token):
100 | line = line.replace(token, "")
101 | values = line.split(" ")
102 | knotvector = [float(value) for value in values]
103 | continue
104 |
105 | spline = NURBS.Curve()
106 | spline.degree = deg
107 | spline.ctrlpts = control_points
108 | spline.knotvector = knotvector
109 |
110 | return spline
111 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 |
164 |
165 |
166 | # Stuff
167 | scenes*
168 | *.png
169 | *.mp4
170 | *.yaml
171 | *.obj
172 | *.jpg
173 | *.gif
174 | Old*
175 | .vscode*
--------------------------------------------------------------------------------
/examples/vocalfold_scene.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 | import kornia
5 |
6 | mi.set_variant("cuda_ad_rgb")
7 |
8 | import torch
9 | import fireflies
10 | import fireflies.sampling
11 | import fireflies.projection.laser
12 |
13 |
14 | def render_to_opencv(render):
15 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
16 | return (render * 255).astype(np.uint8)
17 |
18 |
19 | if __name__ == "__main__":
20 | path = "examples/scenes/vocalfold/vocalfold.xml"
21 |
22 | mitsuba_scene = mi.load_file(path)
23 | mitsuba_params = mi.traverse(mitsuba_scene)
24 | ff_scene = fireflies.Scene(mitsuba_params)
25 |
26 | projector_sensor = mitsuba_scene.sensors()[1]
27 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"]
28 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"]
29 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"]
30 |
31 | K_PROJECTOR = mi.perspective_projection(
32 | projector_sensor.film().size(),
33 | projector_sensor.film().crop_size(),
34 | projector_sensor.film().crop_offset(),
35 | x_fov,
36 | near_clip,
37 | far_clip,
38 | ).matrix.torch()[0]
39 |
40 | # laser_rays = fireflies.projection.Laser.generate_uniform_rays(
41 | # 0.0275, 18, 18, device=ff_scene.device()
42 | # )
43 | laser_rays = fireflies.projection.Laser.generate_blue_noise_rays(
44 | 500, 500, 18 * 18, K_PROJECTOR, device=ff_scene.device()
45 | )
46 |
47 | laser = fireflies.projection.Laser(
48 | ff_scene._projector,
49 | laser_rays,
50 | K_PROJECTOR,
51 | x_fov,
52 | near_clip,
53 | far_clip,
54 | device=ff_scene.device(),
55 | )
56 | texture = laser.generateTexture(
57 | 10.0, torch.tensor([500, 500], device=ff_scene.device())
58 | )
59 | texture = texture.sum(dim=0)
60 |
61 | texture = kornia.filters.gaussian_blur2d(
62 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3)
63 | ).squeeze()
64 | texture = torch.stack(
65 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)]
66 | )
67 | texture = torch.movedim(texture, 0, -1)
68 |
69 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy())
70 |
71 | vocalfold_mesh = ff_scene.mesh("mesh-VocalFold")
72 | larynx_mesh = ff_scene.mesh("mesh-Larynx")
73 | larynx_mesh.scale_x(0.8, 1.2)
74 | larynx_mesh.rotate_y(-0.1, 0.1)
75 |
76 | vocalfold_mesh.scale_x(0.5, 2.0)
77 | vocalfold_mesh.rotate_y(-0.25, 0.25)
78 |
79 | material = ff_scene.material("mat-Default OBJ")
80 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler(
81 | 1.0, 20.0, device=ff_scene.device()
82 | )
83 |
84 | light = ff_scene.light("emit-Spot")
85 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler)
86 |
87 | material.add_vec3_key(
88 | "brdf_0.base_color.value",
89 | torch.tensor([0.8, 0.14, 0.34], device=ff_scene.device()),
90 | torch.tensor([0.85, 0.5, 0.44], device=ff_scene.device()),
91 | )
92 | material.add_float_key("brdf_0.specular", 0.0, 0.75)
93 |
94 | # ff_scene._camera.rotate_y(-0.25, 0.25)
95 | # ff_scene._projector.setParent(ff_scene._camera)
96 | # ff_scene._projector._randomizable = True
97 | # ff_scene._projector.rotate_x(3.141, 3.141)
98 |
99 | ff_scene.train()
100 | for i in range(1000):
101 | ff_scene.randomize()
102 | render = mi.render(mitsuba_scene, spp=100)
103 | render = render_to_opencv(render)
104 |
105 | if i % 2 == 0:
106 |             render = cv2.cvtColor(render, cv2.COLOR_BGR2GRAY).astype(int)
107 | noise = np.random.normal(np.zeros_like(render), np.ones_like(render) * 0.05)
108 | noise *= 255
109 |             noise = noise.astype(int)
110 | render += noise
111 | render[render > 255] = 255
112 | render[render < 0] = 0
113 | render = render.astype(np.uint8)
114 |
115 | cv2.imwrite("vf_renderings/{0:05d}.png".format(i), render)
116 | # cv2.imshow("A", render)
117 | # cv2.waitKey(0)
118 |
--------------------------------------------------------------------------------
/fireflies/entity/flame.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import flame_pytorch.flame as flame
4 | from argparse import Namespace
5 | from typing import List
6 |
7 | import fireflies.entity.base as base
8 | import fireflies.entity.shape as shape
9 | import fireflies.utils.math
10 |
11 |
12 | class FlameShapeModel(shape.ShapeModel):
13 | def __init__(
14 | self,
15 | name: str,
16 | vertex_data: List[float],
17 | config: dict,
18 | device: torch.cuda.device = torch.device("cuda"),
19 | base_path: str = None,
20 | sequential_animation: bool = False,
21 | ):
22 |
23 | self._device = device
24 | self._name = name
25 |
26 | self.setTranslationBoundaries(config["translation"])
27 | self.setRotationBoundaries(config["rotation"])
28 | self.setWorld(config["to_world"])
29 |         self._world = self._world @ fireflies.utils.math.toMat4x4(
30 | fireflies.utils.math.getXTransform(np.pi * 0.5, self._device)
31 | )
32 | self._randomized_world = self._world.clone()
33 |
34 | self._randomizable = bool(config["randomizable"])
35 | self._relative = bool(config["is_relative"])
36 |
37 | self._parent_name = config["parent_name"] if self._relative else None
38 | # Is loaded in a second step
39 | self._parent = None
40 | self._child = None
41 |
42 | self.setVertices(vertex_data)
43 | self.setScaleBoundaries(config["scale"])
44 | self._animated = bool(config["animated"])
45 | self._sequential_animation = sequential_animation
46 |
47 | self._animation_index = 0
48 |
49 | flame_config = Namespace(
50 | **{
51 | "batch_size": 1,
52 | "dynamic_landmark_embedding_path": "./Objects/flame_pytorch/model/flame_dynamic_embedding.npy",
53 | "expression_params": 50,
54 | "flame_model_path": "./Objects/flame_pytorch/model/generic_model.pkl",
55 | "num_worker": 4,
56 | "optimize_eyeballpose": True,
57 | "optimize_neckpose": True,
58 | "pose_params": 6,
59 | "ring_loss_weight": 1.0,
60 | "ring_margin": 0.5,
61 | "shape_params": 100,
62 | "static_landmark_embedding_path": "./Objects/flame_pytorch/model/flame_static_embedding.pkl",
63 | "use_3D_translation": True,
64 | "use_face_contour": True,
65 | }
66 | )
67 |
68 | self.setVertices(vertex_data)
69 | self.setScaleBoundaries(config["scale"])
70 | self._animated = True
71 | self._stddev_range = config["stddev_range"]
72 | self._shape_layer = flame.FLAME(flame_config).to(self._device)
73 | self._faces = self._shape_layer.faces
74 | self._pose_params = torch.zeros(1, 6, device=self._device)
75 | self._expression_params = torch.zeros(1, 50, device=self._device)
76 | self._shape_params = (
77 | (torch.rand(1, 100, device=self._device) - 0.5) * 2.0 * self._stddev_range
78 | )
79 | self._shape_params[:, 20:] = 0.0
80 |
81 | self._shape_params *= 0.0
82 | self._invert = False
83 |
84 | def train(self) -> None:
85 | base.Transformable.train(self)
86 |
87 | def eval(self) -> None:
88 | base.Transformable.eval(self)
89 |
90 | def loadAnimation(self):
91 | return None
92 |
93 | def modelParams(self) -> dict:
94 | return self._shape_params
95 |
96 | def shapeParams(self) -> torch.tensor:
97 | return self._shape_params
98 |
99 | def expressionParams(self) -> torch.tensor:
100 | return self._expression_params
101 |
102 | def poseParams(self) -> torch.tensor:
103 | return self._pose_params
104 |
105 | def randomize(self) -> None:
106 | if self._shape_params[0, 0] > 2.0:
107 | self._invert = True
108 | self._shape_params = self._shape_params + (0.05 if not self._invert else -0.05)
109 | self._shape_params[:, 20:] = 0.0
110 |
111 | self._randomized_world = (
112 | self.sampleTranslation() @ self.sampleRotation() @ self.sampleScale()
113 | )
114 |
115 | def getVertexData(self):
116 | if not self._animated:
117 | return self._vertices, self._shape_layer.faces
118 |
119 | vertices, _ = self._shape_layer(
120 | self._shape_params, self._expression_params, self._pose_params
121 | )
122 | vertices = vertices[0]
123 |
124 |         vertices = fireflies.utils.math.transform_points(
125 | vertices,
126 | self.world()
127 |             @ fireflies.utils.math.toMat4x4(
128 | fireflies.utils.math.getXTransform(np.pi * 0.5, self._device)
129 | ),
130 | )
131 |
132 | return vertices, self._shape_layer.faces
133 |
--------------------------------------------------------------------------------
/fireflies/sampling/poisson.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """
4 | Implementation of the fast Poisson Disk Sampling algorithm of
5 | Bridson (2007) adapted to support spatially varying sampling radii.
6 |
7 | Adrian Bittner, 2021
8 | Published under MIT license.
9 | """
10 |
11 |
12 | def getGridCoordinates(coords):
13 | return np.floor(coords).astype("int")
14 |
15 |
16 | def bridson(radius, k=30, radiusType="default"):
17 | """
18 | Implementation of the Poisson Disk Sampling algorithm.
19 |
20 | :param radius: 2d array specifying the minimum sampling radius for each spatial position in the sampling box. The
21 | size of the sampling box is given by the size of the radius array.
22 | :param k: Number of iterations to find a new particle in an annulus between radius r and 2r from a sample particle.
23 | :param radiusType: Method to determine the distance to newly spawned particles. 'default' follows the algorithm of
24 | Bridson (2007) and generates particles uniformly in the annulus between radius r and 2r.
25 | 'normDist' instead creates new particles at distances drawn from a normal distribution centered
26 | around 1.5r with a dispersion of 0.2r.
27 | :return: nParticle: Number of particles in the sampling.
28 | particleCoordinates: 2d array containing the coordinates of the created particles.
29 | """
30 | # Set-up background grid
31 | gridHeight, gridWidth = radius.shape
32 | grid = np.zeros((gridHeight, gridWidth))
33 |
34 | # Pick initial (active) point
35 | coords = (np.random.random() * gridHeight, np.random.random() * gridWidth)
36 | idx = getGridCoordinates(coords)
37 | nParticle = 1
38 | grid[idx[0], idx[1]] = nParticle
39 |
40 | # Initialise active queue
41 | queue = [
42 | coords
43 | ] # Appending to list is much quicker than to numpy array, if you do it very often
44 | particleCoordinates = [
45 | coords
46 | ] # List containing the exact positions of the final particles
47 |
48 |     # Continue iterating while there are still points in the active list
49 | while queue:
50 |
51 | # Pick random element in active queue
52 | idx = np.random.randint(len(queue))
53 | activeCoords = queue[idx]
54 | activeGridCoords = getGridCoordinates(activeCoords)
55 |
56 | success = False
57 | for _ in range(k):
58 |
59 | if radiusType == "default":
60 | # Pick radius for new sample particle ranging between 1 and 2 times the local radius
61 | newRadius = radius[activeGridCoords[0], activeGridCoords[1]] * (
62 | np.random.random() + 1
63 | )
64 | elif radiusType == "normDist":
65 | # Pick radius for new sample particle from a normal distribution around 1.5 times the local radius
66 | newRadius = radius[
67 | activeGridCoords[0], activeGridCoords[1]
68 | ] * np.random.normal(1.5, 0.2)
69 |
70 | # Pick the angle to the sample particle and determine its coordinates
71 | angle = 2 * np.pi * np.random.random()
72 | newCoords = np.zeros(2)
73 | newCoords[0] = activeCoords[0] + newRadius * np.sin(angle)
74 | newCoords[1] = activeCoords[1] + newRadius * np.cos(angle)
75 |
76 | # Prevent that the new particle is outside of the grid
77 |             if not (0 <= newCoords[1] < gridWidth and 0 <= newCoords[0] < gridHeight):
78 | continue
79 |
80 | # Check that particle is not too close to other particle
81 | newGridCoords = getGridCoordinates((newCoords[0], newCoords[1]))
82 |
83 | radiusThere = np.ceil(radius[newGridCoords[0], newGridCoords[1]])
84 |
85 | gridRangeX = (
86 | np.max([newGridCoords[1] - radiusThere, 0]).astype("int"),
87 | np.min([newGridCoords[1] + radiusThere + 1, gridWidth]).astype("int"),
88 | )
89 | gridRangeY = (
90 | np.max([newGridCoords[0] - radiusThere, 0]).astype("int"),
91 | np.min([newGridCoords[0] + radiusThere + 1, gridHeight]).astype("int"),
92 | )
93 |
94 | searchGrid = grid[
95 | slice(gridRangeY[0], gridRangeY[1]), slice(gridRangeX[0], gridRangeX[1])
96 | ]
97 | conflicts = np.where(searchGrid > 0)
98 |
99 | if len(conflicts[0]) == 0 and len(conflicts[1]) == 0:
100 | # No conflicts detected. Create a new particle at this position!
101 | queue.append(newCoords)
102 | particleCoordinates.append(newCoords)
103 | nParticle += 1
104 | grid[newGridCoords[0], newGridCoords[1]] = nParticle
105 | success = True
106 |
107 | else:
108 | # There is a conflict. Do NOT create a new particle at this position!
109 | continue
110 |
111 |         if not success:
112 | # No new particle could be associated to the currently active particle.
113 | # Remove current particle from the active queue!
114 | del queue[idx]
115 |
116 | return (nParticle, np.array(particleCoordinates))
117 |
118 |
119 | if __name__ == "__main__":
120 | import matplotlib.pyplot as plt
121 |
122 | width = 512
123 | height = 512
124 |
125 | radius = 10
126 | max_radius = 3 * radius
127 |
128 | import cv2
129 |
130 |     weight_matrix = (
131 |         np.ones([height, width], np.float32) * max_radius
132 |     )  # Per-pixel minimum sampling radius
133 |     cv2.circle(weight_matrix, (width // 2, height // 2), 50, radius, -1)
134 |
135 | # cv2.imshow("Weight Matrix", weight_matrix)
136 | # cv2.waitKey(0)
137 |
138 | npoints, points = bridson(weight_matrix)
139 | # points = np.array(poisson_disc_samples(width=width, height=height, r=radius))
140 |
141 | plt.scatter(points[:, 0], points[:, 1])
142 | plt.xlim(0, width)
143 | plt.ylim(0, height)
144 | plt.show()
145 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | **Fireflies** is a wrapper for the Mitsuba Renderer and allows for rapid prototyping and generation of physically-based renderings and simulation data in a differentiable manner.
4 | It can be used, for example, to easily generate highly realistic medical imaging data for machine learning tasks or (its intended use) to test the reconstruction capabilities of structured light projection systems in simulated environments.
5 | I originally created it to investigate whether the task of finding an optimal point-based laser pattern for structured light laryngoscopy can be reformulated as a gradient-based optimization problem.
6 |
7 | This code accompanies the paper **Fireflies: Photorealistic Simulation and Optimization of Structured Light Endoscopy** accepted at **SASHIMI 2024**. 🎊
8 |
9 |
10 | # Main features
11 | - **Easy torch-like and pythonic scene randomization description.** This library is made to be easily usable for everyone who regularly uses Python and PyTorch. We implement train() and eval() functionality from the get-go.
12 | - **Easily integrated into online deep-learning and machine learning tasks** due to the differentiability of the Mitsuba renderer w.r.t. the scene parameters.
13 | - **Simple animation description**. Have a look at the examples.
14 | - **Single Shot Structured Light specific**. You can easily test different projection patterns and reconstruction algorithms on randomized scenes, giving a good estimation of the quality and viability of patterns/systems/algorithms.
15 |
16 | # Installation
17 | Make sure to create a conda environment first.
18 | I tested Fireflies on Python 3.10; however, it should work with every Python version that is also supported by Mitsuba and PyTorch.
19 | I'm working on adding Fireflies to PyPI in the future.
20 | First install the necessary dependencies:
21 | ```
22 | pip install pywavefront geomdl
23 | pip install torch
24 | pip install mitsuba
25 | ```
26 | To run the examples, you also need OpenCV:
27 | ```
28 | pip install opencv-python
29 | ```
30 | Finally, you can install Fireflies via:
31 | ```
32 | git clone https://github.com/Henningson/Fireflies.git
33 | cd Fireflies
34 | pip install .
35 | ```
36 |
37 | 
38 | # Usage
39 | ```
40 | import mitsuba as mi
41 | import fireflies as ff
42 |
43 | mi_scene = mi.load_file(path)
44 | mi_params = mi.traverse(mi_scene)
45 | ff_scene = ff.Scene(mi_params)
46 |
47 | mesh = ff_scene.mesh_at(0)
48 | mesh.rotate_z(-3.141, 3.141)
49 |
50 | ff_scene.eval()
51 | #ff_scene.train() generates uniformly sampled results on the right
52 | for i in range(0, 20):
53 | ff_scene.randomize()
54 | mi.render(mi_scene)
55 | ```
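
Materials and emitters can be randomized in the same way by attaching value ranges or samplers to named scene parameters. The sketch below is adapted from `examples/vocalfold_scene.py`; the parameter keys (`mat-Default OBJ`, `brdf_0.base_color.value`, `emit-Spot`, `intensity.value`) belong to that particular scene and are only illustrative.

```
import torch
import fireflies.sampling

# Randomize a material's base color between two RGB values
# and its specularity between 0.0 and 0.75.
material = ff_scene.material("mat-Default OBJ")
material.add_vec3_key(
    "brdf_0.base_color.value",
    torch.tensor([0.8, 0.14, 0.34], device=ff_scene.device()),
    torch.tensor([0.85, 0.5, 0.44], device=ff_scene.device()),
)
material.add_float_key("brdf_0.specular", 0.0, 0.75)

# Randomize a spot light's intensity with a scalar sampler broadcast to a vec3.
scalar_to_vec3 = fireflies.sampling.UniformScalarToVec3Sampler(
    1.0, 20.0, device=ff_scene.device()
)
light = ff_scene.light("emit-Spot")
light.add_vec3_sampler("intensity.value", scalar_to_vec3)
```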
56 |
57 |
58 |
59 |
60 |
61 |
62 | # Examples
63 | A bunch of different examples can be found in the examples folder.
64 | The code for the paper can be found in the **Paper** branch.
65 |
66 | # Building and loading your own scene
67 | You can easily generate a scene using Blender.
68 | To export a scene in Mitsuba's required .xml format, you first need to install the Mitsuba Blender Add-On.
69 | You can then export the scene via File -> Export.
70 | Make sure to tick the ✅ export ids checkbox, as Fireflies infers the object type by checking names for specific qualifiers, e.g. "mesh", "brdf", etc. A minimal loading sketch is shown below.
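
The following is a minimal sketch of loading such an exported scene (the path and the object id are placeholders; see `examples/vocalfold_scene.py` for a complete script):

```
import mitsuba as mi
import fireflies

mi.set_variant("cuda_ad_rgb")

# Load the scene exported from Blender (with "export ids" ticked).
mi_scene = mi.load_file("path/to/exported_scene.xml")
mi_params = mi.traverse(mi_scene)
ff_scene = fireflies.Scene(mi_params)

# Objects are addressed by their exported ids, e.g. "mesh-<BlenderObjectName>".
mesh = ff_scene.mesh("mesh-MyObject")
mesh.rotate_y(-0.25, 0.25)
```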
71 |
72 | # Render Gallery
73 |
74 |
75 |
76 |
77 |
78 |
79 | These are some renderings that were created during my work on the aforementioned paper.
80 | From left to right: a reconstructed in-vivo colon, the FLAME shape model, and phonating human vocal folds with a point-based structured light pattern.
81 |
82 | ## More Discussion about the Paper
83 | Can be found in the **README** of the **paper** branch.
84 |
85 | ## Why did you call this Fireflies?
86 | Because optimizing a point-based laser pattern looks like fireflies that jet around. :)
87 |
88 |
89 |
90 |
91 |
92 | ## Acknowledgements
93 | A big thank you to Wenzel Jakob and team for their wonderful work on the Mitsuba renderer.
94 | You should definitely check out their work: Mitsuba Homepage, Mitsuba Github.
95 |
96 | Furthermore, this work was supported by Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under grant STA662/6-1, Project-ID 448240908 and (partly) funded by the DFG – SFB 1483 – Project-ID 442419336, EmpkinS.
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 | ## Citation
105 | Please cite this work if it helps you with your research:
106 | ```
107 | @InProceedings{10.1007/978-3-031-73281-2_10,
108 | author="Henningson, Jann-Ole and Veltrup, Reinhard and Semmler, Marion and D{\"o}llinger, Michael and Stamminger, Marc",
109 | title="Fireflies: Photorealistic Simulation and Optimization of Structured Light Endoscopy",
110 | booktitle="Simulation and Synthesis in Medical Imaging",
111 | year="2025",
112 | publisher="Springer Nature Switzerland",
113 | address="Cham",
114 | pages="102--112",
115 | isbn="978-3-031-73281-2"
116 | }
117 | ```
118 |
--------------------------------------------------------------------------------
/fireflies/graphics/depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import mitsuba as mi
3 | import drjit as dr
4 | import fireflies.utils.laser_estimation
5 |
6 | from tqdm import tqdm
7 |
8 |
9 | @dr.wrap_ad(source="torch", target="drjit")
10 | def from_laser(scene, params, laser):
11 | sensor = scene.sensors()[0]
12 | film = sensor.film()
13 | # TODO: Add device
14 | size = torch.tensor(film.size(), device="cuda")
15 |
16 | hit_points = fireflies.utils.laser_estimation.cast_laser(scene, laser=laser)
17 | ndc_coords = fireflies.utils.laser_estimation.project_to_camera_space(
18 | params, hit_points
19 | )
20 | pixel_coords = ndc_coords * 0.5 + 0.5
21 | pixel_coords = pixel_coords[0, :, 0:2]
22 | pixel_coords = torch.floor(pixel_coords * size).int()
23 |
24 | mask = torch.zeros(size.tolist(), device=size.device)
25 | mask[pixel_coords[:, 0], pixel_coords[:, 1]] = 1.0
26 |
27 | depth_map = from_camera_non_wrapped(scene, spp=1)
28 | depth_map = depth_map.reshape(film.size()[0], film.size()[1])
29 |
30 | return depth_map * mask
31 |
32 |
33 | @dr.wrap_ad(source="torch", target="drjit")
34 | def cast_laser_id(scene, origin, direction):
35 | origin_point = mi.Point3f(
36 | origin[:, 0].array, origin[:, 1].array, origin[:, 2].array
37 | )
38 | rays_vector = mi.Vector3f(
39 | direction[:, 0].array, direction[:, 1].array, direction[:, 2].array
40 | )
41 | surface_interaction = scene.ray_intersect(mi.Ray3f(origin_point, rays_vector))
42 | shape_pointer = mi.Int(
43 | dr.reinterpret_array_v(mi.UInt, surface_interaction.shape)
44 | ).torch()
45 | shape_pointer -= shape_pointer.min()
46 | return shape_pointer
47 |
48 |
49 | def from_camera_non_wrapped(scene, spp=64):
50 | sensor = scene.sensors()[0]
51 | film = sensor.film()
52 | sampler = sensor.sampler()
53 | film_size = film.crop_size()
54 | total_samples = dr.prod(film_size) * spp
55 |
56 | if sampler.wavefront_size() != total_samples:
57 | sampler.seed(0, total_samples)
58 |
59 | # Enumerate discrete sample & pixel indices, and uniformly sample
60 | # positions within each pixel.
61 | pos = dr.arange(mi.UInt32, total_samples)
62 |
63 | pos //= spp
64 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
65 | pos = mi.Vector2f(
66 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0]))
67 | )
68 |
69 | # pos += sampler.next_2d()
70 |
71 | # Sample rays starting from the camera sensor
72 | rays, weights = sensor.sample_ray(
73 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0
74 | )
75 |
76 | # Intersect rays with the scene geometry
77 | surface_interaction = scene.ray_intersect(rays)
78 |
79 | # Given intersection, compute the final pixel values as the depth t
80 | # of the sampled surface interaction
81 | result = surface_interaction.t
82 |
83 | # Set to zero if no intersection was found
84 | result[~surface_interaction.is_valid()] = 0
85 |
86 | return result
87 |
88 |
89 | def get_segmentation_from_camera(scene, spp=1):
90 | sensor = scene.sensors()[0]
91 | film = sensor.film()
92 | sampler = sensor.sampler()
93 | film_size = film.crop_size()
94 | total_samples = dr.prod(film_size) * spp
95 |
96 | if sampler.wavefront_size() != total_samples:
97 | sampler.seed(0, total_samples)
98 |
99 | # Enumerate discrete sample & pixel indices, and uniformly sample
100 | # positions within each pixel.
101 | pos = dr.arange(mi.UInt32, total_samples)
102 |
103 | pos //= spp
104 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
105 | pos = mi.Vector2f(
106 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0]))
107 | )
108 |
109 | # Sample rays starting from the camera sensor
110 | rays, weights = sensor.sample_ray(
111 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0
112 | )
113 |
114 | # Intersect rays with the scene geometry
115 | surface_interaction = scene.ray_intersect(rays)
116 |
117 | # Watch out, hacky stuff going on!
118 | # Solution from: https://github.com/mitsuba-renderer/mitsuba3/discussions/882
119 | shape_pointer = mi.Int(
120 | dr.reinterpret_array_v(mi.UInt, surface_interaction.shape)
121 | ).torch()
122 | shape_pointer -= shape_pointer.min()
123 | shape_pointer = shape_pointer.max() - shape_pointer
124 |
125 | return shape_pointer.reshape(film_size[1], film_size[0])
126 |
127 |
128 | @dr.wrap_ad(source="drjit", target="torch")
129 | def from_camera(scene, spp=64):
130 | sensor = scene.sensors()[0]
131 | film = sensor.film()
132 | sampler = sensor.sampler()
133 | film_size = film.crop_size()
134 | total_samples = dr.prod(film_size) * spp
135 |
136 | if sampler.wavefront_size() != total_samples:
137 | sampler.seed(0, total_samples)
138 |
139 | # Enumerate discrete sample & pixel indices, and uniformly sample
140 | # positions within each pixel.
141 | pos = dr.arange(mi.UInt32, total_samples)
142 |
143 | pos //= spp
144 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
145 | pos = mi.Vector2f(
146 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0]))
147 | )
148 |
149 | pos += sampler.next_2d()
150 |
151 | # Sample rays starting from the camera sensor
152 | rays, weights = sensor.sample_ray(
153 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0
154 | )
155 |
156 | # Intersect rays with the scene geometry
157 | surface_interaction = scene.ray_intersect(rays)
158 |
159 | # Given intersection, compute the final pixel values as the depth t
160 | # of the sampled surface interaction
161 | result = surface_interaction.t
162 |
163 | # Set to zero if no intersection was found
164 | result[~surface_interaction.is_valid()] = 0
165 |
166 | return result
167 |
168 |
169 | def random_depth_maps(
170 | firefly_scene, mi_scene, num_maps: int = 100, spp: int = 1
171 | ) -> torch.tensor:
172 | stacked_depth_maps = []
173 | im_size = mi_scene.sensors()[0].film().size()
174 |
175 | for i in tqdm(range(num_maps)):
176 | firefly_scene.randomize()
177 |
178 |         depth_map = from_camera_non_wrapped(mi_scene, spp=spp)
179 |
180 | # vis_depth = torch.log(depth_map.torch().reshape(im_size[1], im_size[0]).clone())
181 | # vis_depth = utils.normalize(vis_depth)
182 | # vis_depth = ((1 - vis_depth.detach().cpu().numpy()) * 255).astype(np.uint8)
183 | # colored = cv2.applyColorMap(vis_depth, cv2.COLORMAP_INFERNO)
184 | # cv2.imshow("Depth", colored)
185 | # cv2.waitKey(1)
186 |
187 | depth_map = depth_map.torch().reshape(im_size[1], im_size[0], spp).mean(dim=-1)
188 | stacked_depth_maps.append(depth_map)
189 |
190 | return torch.stack(stacked_depth_maps)
191 |
--------------------------------------------------------------------------------
/fireflies/entity/mesh.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pywavefront
4 |
5 | import fireflies.entity.base as base
6 | import fireflies.utils.math
7 | import fireflies.sampling
8 |
9 |
10 | class Mesh(base.Transformable):
11 | def __init__(
12 | self,
13 | name: str,
14 | vertex_data: torch.tensor,
15 | device: torch.cuda.device = torch.device("cuda"),
16 | ):
17 | super(Mesh, self).__init__(name, device)
18 |
19 | self._vertices = vertex_data.to(self._device)
20 | self._vertices_animation = None
21 |
22 | ones = torch.ones(3, device=self._device)
23 | self._scale_sampler = fireflies.sampling.UniformSampler(
24 | ones.clone(), ones.clone()
25 | )
26 |
27 | self._animated = False
28 |
29 | self._anim_data_train = None
30 | self._anim_data_eval = None
31 | self._animation_func = None
32 | self._animation_sampler = None
33 |
34 | def set_scale_sampler(self, sampler: fireflies.sampling.Sampler) -> None:
35 | self._scale_sampler = sampler
36 |
37 | def scale_x(self, min_scale: float, max_scale: float) -> None:
38 | self._randomizable = True
39 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 0)
40 |
41 | def scale_y(self, min_scale: float, max_scale: float) -> None:
42 | self._randomizable = True
43 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 1)
44 |
45 | def scale_z(self, min_scale: float, max_scale: float) -> None:
46 | self._randomizable = True
47 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 2)
48 |
49 | def scale(self, min: torch.tensor, max: torch.tensor) -> None:
50 | self._randomizable = True
51 | self._scale_sampler.set_sample_interval(
52 | min.to(self._device), max.to(self._device)
53 | )
54 |
58 | def animated(self) -> bool:
59 | return self._animated
60 |
61 | def add_animation(self, animation_data: torch.tensor) -> None:
62 |         self._vertices_animation = animation_data.to(self._device)
63 | self._animated = True
64 | self._randomizable = True
65 |
66 | def add_animation_func(self, func, min_range, max_range) -> None:
67 | self._animation_func = func
68 | self._animation_sampler = fireflies.sampling.UniformSampler(
69 | min_range, max_range, device=self._device
70 | )
71 | self._animated = True
72 | self._randomizable = True
73 |
74 | def add_train_animation_from_obj(
75 | self, path: str, min: int = None, max: int = None
76 | ) -> None:
77 | self._anim_data_train = self.load_animation(path)
78 |
79 | if self._animation_sampler:
80 | self._animation_sampler.set_train_interval(
81 |                 0 if min is None else min,
82 | self._anim_data_train.shape[0] if max is None else max,
83 | )
84 | return
85 |
86 | self._animation_sampler = fireflies.sampling.AnimationSampler(0, 1, 0, 1)
87 | self._animation_sampler.set_train_interval(
88 |             0 if min is None else min,
89 | self._anim_data_train.shape[0] if max is None else max,
90 | )
91 | self._animated = True
92 |
93 | def add_eval_animation_from_obj(
94 | self, path: str, min: int = None, max: int = None
95 | ) -> None:
96 | self._anim_data_eval = self.load_animation(path)
97 |
98 | if self._animation_sampler:
99 | self._animation_sampler.set_eval_interval(
100 |                 0 if min is None else min,
101 | self._anim_data_eval.shape[0] if max is None else max,
102 | )
103 | return
104 |
105 | self._animation_sampler = fireflies.sampling.AnimationSampler(0, 1, 0, 1)
106 | self._animation_sampler.set_eval_interval(
107 |             0 if min is None else min,
108 | self._anim_data_eval.shape[0] if max is None else max,
109 | )
110 |
111 | def train(self) -> None:
112 | super(Mesh, self).train()
113 | self._scale_sampler.train()
114 |
115 | if self._animation_sampler:
116 | self._animation_sampler.train()
117 |
118 | def eval(self) -> None:
119 | super(Mesh, self).eval()
120 | self._scale_sampler.eval()
121 |
122 | if self._animation_sampler:
123 | self._animation_sampler.eval()
124 |
125 | def set_faces(self, faces: torch.tensor) -> None:
126 | self._faces = faces.to(self._device)
127 |
128 | def set_vertices(self, vertices: torch.tensor) -> None:
129 | self._vertices = vertices.to(self._device)
130 |
131 | def sample_scale(self) -> torch.tensor:
132 | scale_matrix = torch.eye(4, device=self._device)
133 |
134 | random_scale = self._scale_sampler.sample()
135 |
136 | scale_matrix[0, 0] = random_scale[0]
137 | scale_matrix[1, 1] = random_scale[1]
138 | scale_matrix[2, 2] = random_scale[2]
139 | return scale_matrix
140 |
141 | def randomize(self) -> None:
142 | if not self.randomizable():
143 | return
144 |
145 | self._randomized_world = (
146 | (self.sample_translation() + self._centroid_mat)
147 | @ self.sample_rotation()
148 | @ self.sample_scale()
149 | @ self._world
150 | )
151 |
152 | def faces(self) -> torch.tensor:
153 | return self._faces
154 |
155 | def get_vertices(self) -> torch.tensor:
156 | return self._vertices
157 |
158 | def get_randomized_vertices(self) -> torch.tensor:
159 | # Sample Animations
160 | temp_vertex = self.sample_animation() if self._animated else self._vertices
161 |
162 | # Transform by world transform
163 | temp_vertex = fireflies.utils.math.transform_points(temp_vertex, self.world())
164 |
165 | return temp_vertex
166 |
167 | def load_animation(self, path: str) -> torch.tensor:
168 | animation_data = []
169 | for file in sorted(os.listdir(path)):
170 | if file.endswith(".obj"):
171 | obj_path = os.path.join(path, file)
172 |
173 | obj = pywavefront.Wavefront(obj_path, collect_faces=True)
174 |
175 | vertices = torch.tensor(obj.vertices, device=self._device).reshape(
176 | -1, 3
177 | )
178 |
179 | animation_data.append(vertices)
180 |
181 | return torch.stack(animation_data)
182 |
183 | def sample_animation(self):
184 | if not self._animated:
185 | return self._vertices
186 |
187 | # Can either be an integer or a float, depending if we loaded meshes or defined an animation function
188 | time_sample = self._animation_sampler.sample()
189 | if self._animation_func is not None:
190 | return self._animation_func(self._vertices, time_sample)
191 | elif self._anim_data_train is not None and self._anim_data_eval is not None:
192 | return (
193 | self._anim_data_train[time_sample]
194 | if self._train
195 | else self._anim_data_eval[time_sample]
196 | )
197 |
198 | return None
199 |
--------------------------------------------------------------------------------
/fireflies/utils/math.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | def uniformBetweenValues(a: float, b: float) -> float:
9 | return random.uniform(a, b)
10 |
11 |
12 | def getZTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
13 | return getYawTransform(alpha, _device)
14 |
15 |
16 | def getYTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
17 | return getPitchTransform(alpha, _device)
18 |
19 |
20 | def getXTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
21 | return getRollTransform(alpha, _device)
22 |
23 |
24 | def getYawTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
25 | rotZ = torch.tensor(
26 | [
27 | [math.cos(alpha), -math.sin(alpha), 0],
28 | [math.sin(alpha), math.cos(alpha), 0],
29 | [0, 0, 1],
30 | ],
31 | device=_device,
32 | )
33 |
34 | return rotZ
35 |
36 |
37 | def getPitchTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
38 | rotY = torch.tensor(
39 | [
40 | [math.cos(alpha), 0, math.sin(alpha)],
41 | [0, 1, 0],
42 | [-math.sin(alpha), 0, math.cos(alpha)],
43 | ],
44 | device=_device,
45 | )
46 |
47 | return rotY
48 |
49 |
50 | def getRollTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor:
51 | rotX = torch.tensor(
52 | [
53 | [1, 0, 0],
54 | [0, math.cos(alpha), -math.sin(alpha)],
55 | [0, math.sin(alpha), math.cos(alpha)],
56 | ],
57 | device=_device,
58 | )
59 |
60 | return rotX
61 |
62 |
63 | def vector_dot(A: torch.tensor, B: torch.tensor) -> torch.tensor:
64 | return torch.sum(A * B, dim=-1)
65 |
66 |
67 | def rotation_matrix_from_vectors(v1, v2):
68 | """
69 | Calculates the rotation matrix that transforms vector v1 to v2.
70 |
71 | Args:
72 | - v1 (torch.Tensor): The source 3D vector (3x1).
73 | - v2 (torch.Tensor): The target 3D vector (3x1).
74 |
75 | Returns:
76 | - torch.Tensor: The 3x3 rotation matrix.
77 | """
78 | v1 = F.normalize(v1, dim=0)
79 | v2 = F.normalize(v2, dim=0)
80 |
81 | # Compute the cross product and dot product
82 | cross_product = torch.cross(v1, v2)
83 | dot_product = torch.dot(v1, v2)
84 |
85 | # Skew-symmetric matrix for cross product
86 | skew_sym_matrix = torch.tensor(
87 | [
88 | [0, -cross_product[2], cross_product[1]],
89 | [cross_product[2], 0, -cross_product[0]],
90 | [-cross_product[1], cross_product[0], 0],
91 | ],
92 | dtype=torch.float32,
93 | device=v1.device,
94 | )
95 |
96 | # Rotation matrix using Rodrigues' formula
97 | rotation_matrix = (
98 | torch.eye(3, device=v1.device)
99 | + skew_sym_matrix
100 | + torch.mm(skew_sym_matrix, skew_sym_matrix)
101 | * (1 - dot_product)
102 | / torch.norm(cross_product) ** 2
103 | )
104 |
105 | return rotation_matrix
106 |
107 |
108 | def rotation_matrix_from_vectors_with_fixed_up(
109 | v1, v2, up_vector=torch.tensor([0.0, 0.0, 1.0])
110 | ):
111 | """
112 | Calculates the rotation matrix that transforms vector v1 to v2, while keeping an "up" direction fixed.
113 |
114 | Args:
115 | - v1 (torch.Tensor): The source 3D vector (3x1).
116 | - v2 (torch.Tensor): The target 3D vector (3x1).
117 | - up_vector (torch.Tensor): The fixed "up" direction (3x1). Default is [0, 0, 1].
118 |
119 | Returns:
120 | - torch.Tensor: The 3x3 rotation matrix.
121 | """
122 | v1 = F.normalize(v1, dim=0)
123 | v2 = F.normalize(v2, dim=0)
124 | up_vector = F.normalize(up_vector, dim=0)
125 |
126 | # Compute the cross product and dot product
127 | cross_product = torch.cross(v1, v2)
128 | dot_product = torch.dot(v1, v2)
129 |
130 | # Skew-symmetric matrix for cross product
131 | skew_sym_matrix = torch.tensor(
132 | [
133 | [0, -cross_product[2], cross_product[1]],
134 | [cross_product[2], 0, -cross_product[0]],
135 | [-cross_product[1], cross_product[0], 0],
136 | ],
137 | dtype=torch.float32,
138 | device=v1.device,
139 | )
140 |
141 | # Rotation matrix using Rodrigues' formula
142 | rotation_matrix = (
143 | torch.eye(3, device=v1.device)
144 | + skew_sym_matrix
145 | + torch.mm(skew_sym_matrix, skew_sym_matrix)
146 | * (1 - dot_product)
147 | / torch.norm(cross_product) ** 2
148 | )
149 |
150 | # Ensure the "up" direction is fixed
151 | rotated_up_vector = torch.mv(rotation_matrix, up_vector)
152 | correction_axis = torch.cross(rotated_up_vector, up_vector)
153 | correction_angle = torch.acos(torch.dot(rotated_up_vector, up_vector))
154 |
155 | # Apply the correction to the rotation matrix
156 | correction_matrix = F.normalize(skew_sym_matrix, dim=0) * correction_angle
157 | corrected_rotation_matrix = torch.eye(3, device=v1.device) + correction_matrix
158 |
159 | return corrected_rotation_matrix
160 |
161 |
162 | def singleRandomBetweenTensors(a: torch.tensor, b: torch.tensor) -> torch.tensor:
163 | assert a.size() == b.size()
164 | assert a.device == b.device
165 |
166 | rands = random.uniform(0, 1)
167 |     return rands * (b - a) + a
168 |
169 |
170 | def randomBetweenTensors(a: torch.tensor, b: torch.tensor) -> torch.tensor:
171 | assert a.size() == b.size()
172 | assert a.device == b.device
173 |
174 | rands = torch.rand(a.shape, device=a.device)
175 | return rands * (b - a) + a
176 |
177 |
178 | def normalize(tensor: torch.tensor) -> torch.tensor:
179 | tensor = tensor - tensor.amin()
180 | tensor = tensor / tensor.amax()
181 | return tensor
182 |
183 |
184 | def normalize_channelwise(
185 | tensor: torch.tensor,
186 | dim: int = -1,
187 | device: torch.cuda.device = torch.device("cuda"),
188 | ) -> torch.tensor:
189 | indices = torch.arange(0, len(tensor.shape), device=device)
190 | mask = torch.ones(indices.shape, dtype=torch.bool, device=device)
191 | mask[dim] = False
192 | indices = indices[mask].tolist()
193 |
194 | tensor = tensor - tensor.amin(indices)
195 | tensor = tensor / tensor.amax(indices)
196 | return tensor
197 |
198 |
199 | def convert_points_to_homogeneous(points: torch.tensor) -> torch.tensor:
200 | return torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=1.0)
201 |
202 |
203 | def toMat4x4(mat: torch.tensor, addOne: bool = True) -> torch.tensor:
204 | mat4x4 = torch.nn.functional.pad(mat, pad=(0, 1, 0, 1), mode="constant", value=0.0)
205 |
206 | if addOne:
207 | mat4x4[3, 3] = 1.0
208 |
209 | return mat4x4
210 |
211 |
212 | def convert_points_from_homogeneous(points: torch.tensor) -> torch.tensor:
213 | return points[..., :-1] / points[..., -1:]
214 |
215 |
216 | def convert_points_to_nonhomogeneous(points: torch.tensor) -> torch.tensor:
217 | return torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=0.0)
218 |
219 |
220 | def transform_points(points: torch.tensor, transform: torch.tensor) -> torch.tensor:
221 | points_1_h = convert_points_to_homogeneous(points)
222 |
223 | points_0_h = torch.matmul(transform.unsqueeze(0), points_1_h.unsqueeze(-1))
224 |
225 | points_0_h = points_0_h.squeeze(dim=-1)
226 |
227 | points_0 = convert_points_from_homogeneous(points_0_h)
228 | return points_0
229 |
230 |
231 | def transform_directions(points: torch.tensor, transform: torch.tensor) -> torch.tensor:
232 | points = convert_points_to_nonhomogeneous(points)
233 | points = transform.unsqueeze(0) @ points.unsqueeze(-1)
234 |     points = torch.squeeze(points, dim=-1)
235 | return points[..., :-1]
236 |
--------------------------------------------------------------------------------
/fireflies/entity/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List
3 |
4 | import fireflies.utils.math
5 | import fireflies.sampling
6 |
7 |
8 | class Transformable:
9 | def __init__(
10 | self,
11 | name: str,
12 | device: torch.cuda.device = torch.device("cuda"),
13 | ):
14 |
15 | self._device: torch.cuda.device = device
16 | self._name: str = name
17 |
18 | self._randomizable: bool = False
19 |
20 | self._parent = None
21 | self._child = None
22 |
23 | self._train = True
24 |
25 | self._float_attributes = {}
26 | self._randomized_float_attributes = {}
27 |
28 | self._vec3_attributes = {}
29 | self._randomized_vec3_attributes = {}
30 |
31 | zeros = torch.zeros(3, device=self._device)
32 | self._rotation_sampler = fireflies.sampling.UniformSampler(
33 | zeros.clone(), zeros.clone()
34 | )
35 |
36 | self._translation_sampler = fireflies.sampling.UniformSampler(
37 | zeros.clone(), zeros.clone()
38 | )
39 |
40 | self._world = torch.eye(4, device=self._device)
41 | self._randomized_world = torch.eye(4, device=self._device)
42 |
43 | self._centroid_mat = torch.zeros((4, 4), device=self._device)
44 |
45 | self._eval_delta = 0.01
46 | self._num_updates = 0
47 |
48 | def randomizable(self) -> bool:
49 | return self._randomizable
50 |
51 | def set_centroid(self, centroid: torch.tensor) -> None:
52 | self._centroid_mat[0, 3] = centroid.squeeze()[0]
53 | self._centroid_mat[1, 3] = centroid.squeeze()[1]
54 | self._centroid_mat[2, 3] = centroid.squeeze()[2]
55 |
56 | def set_randomizable(self, randomizable: bool) -> None:
57 | self._randomizable = randomizable
58 |
59 | def get_randomized_vec3_attributes(self) -> dict:
60 | return self._randomized_vec3_attributes
61 |
62 | def get_randomized_float_attributes(self) -> dict:
63 | return self._randomized_float_attributes
64 |
65 | def vec3_attributes(self) -> dict:
66 | return self._vec3_attributes
67 |
68 | def float_attributes(self) -> dict:
69 | return self._float_attributes
70 |
71 | def add_float_sampler(self, key: str, sampler: fireflies.sampling.Sampler) -> None:
72 | self._randomizable = True
73 | self._float_attributes[key] = sampler
74 |
75 | def add_float_key(self, key: str, min: float, max: float) -> None:
76 | """Transforms float key into a Uniform Sampler"""
77 | self._randomizable = True
78 | self._float_attributes[key] = fireflies.sampling.UniformSampler(
79 | min, max, device=self._device
80 | )
81 |
82 | def add_vec3_key(self, key: str, min: torch.tensor, max: torch.tensor) -> None:
83 | """Transforms vec3 into Uniform Sampler"""
84 | self._randomizable = True
85 | self._vec3_attributes[key] = fireflies.sampling.UniformSampler(
86 | min, max, device=self._device
87 | )
88 |
89 | def add_vec3_sampler(self, key: str, sampler: fireflies.sampling.Sampler) -> None:
90 | self._randomizable = True
91 | self._vec3_attributes[key] = sampler
92 |
93 | def parent(self):
94 | return self._parent
95 |
96 | def child(self):
97 | return self._child
98 |
99 | def name(self):
100 | return self._name
101 |
102 | def train(self) -> None:
103 | self._train = True
104 | self._translation_sampler.train()
105 | self._rotation_sampler.train()
106 |
107 | for sampler in self._float_attributes.values():
108 | sampler.train()
109 |
110 | for sampler in self._vec3_attributes.values():
111 | sampler.train()
112 |
113 | def eval(self) -> None:
114 | self._train = False
115 | self._translation_sampler.eval()
116 | self._rotation_sampler.eval()
117 |
118 | for sampler in self._float_attributes.values():
119 | sampler.eval()
120 |
121 | for sampler in self._vec3_attributes.values():
122 | sampler.eval()
123 |
124 | def set_world(self, _origin: torch.tensor) -> None:
125 | self._world = _origin
126 | self._randomized_world = self._world.clone()
127 |
128 | def setParent(self, parent) -> None:
129 | self._parent = parent
130 | parent.setChild(self)
131 |
132 | def setChild(self, child) -> None:
133 | self._child = child
134 |
135 | def set_rotation_sampler(self, sampler: fireflies.sampling.Sampler) -> None:
136 | self._rotation_sampler = sampler
137 |
138 | def set_translation_sampler(self, sampler: fireflies.sampling.Sampler) -> None:
139 | self._translation_sampler = sampler
140 |
141 | def update_index_from_sampler(self, sampler, min, max, index) -> None:
142 | sampler_min = sampler.get_min()
143 | sampler_max = sampler.get_max()
144 |
145 | sampler_min[index] = min
146 | sampler_max[index] = max
147 |
148 | def rotate_x(self, min_rot: float, max_rot: float) -> None:
149 | """Convenience function for Uniform Sampler"""
150 | self._randomizable = True
151 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 0)
152 |
153 | def rotate_y(self, min_rot: float, max_rot: float) -> None:
154 | """Convenience function for Uniform Sampler"""
155 | self._randomizable = True
156 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 1)
157 |
158 | def rotate_z(self, min_rot: float, max_rot: float) -> None:
159 | """Convenience function for Uniform Sampler"""
160 | self._randomizable = True
161 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 2)
162 |
163 | def rotate(self, min: torch.tensor, max: torch.tensor) -> None:
164 | """Convenience function for Uniform Sampler"""
165 | self._randomizable = True
166 | self._rotation_sampler.set_sample_interval(
167 | min.to(self._device), max.to(self._device)
168 | )
169 |
170 | def translate_x(self, min_translation: float, max_translation: float) -> None:
171 | self._randomizable = True
172 | self.update_index_from_sampler(
173 | self._translation_sampler, min_translation, max_translation, 0
174 | )
175 |
176 | def translate_y(self, min_translation: float, max_translation: float) -> None:
177 | self._randomizable = True
178 | self.update_index_from_sampler(
179 | self._translation_sampler, min_translation, max_translation, 1
180 | )
181 |
182 | def translate_z(self, min_translation: float, max_translation: float) -> None:
183 | self._randomizable = True
184 | self.update_index_from_sampler(
185 | self._translation_sampler, min_translation, max_translation, 2
186 | )
187 |
188 | def translate(self, min: torch.tensor, max: torch.tensor) -> None:
189 | self._randomizable = True
190 | self._translation_sampler.set_sample_interval(
191 | min.to(self._device), max.to(self._device)
192 | )
193 |
194 | def sample_rotation(self) -> torch.tensor:
195 | self._sampled_rotation = self._rotation_sampler.sample()
196 |
197 |         zMat = fireflies.utils.math.getYawTransform(
198 |             self._sampled_rotation[2], self._device
199 |         )
200 |         yMat = fireflies.utils.math.getPitchTransform(
201 | self._sampled_rotation[1], self._device
202 | )
203 | xMat = fireflies.utils.math.getRollTransform(
204 | self._sampled_rotation[0], self._device
205 | )
206 |
207 | return fireflies.utils.math.toMat4x4(zMat @ yMat @ xMat)
208 |
209 | def sample_translation(self) -> torch.tensor:
210 | translation = torch.eye(4, device=self._device)
211 |
212 | self._random_translation = self._translation_sampler.sample()
213 |
214 | translation[0, 3] = self._random_translation[0]
215 | translation[1, 3] = self._random_translation[1]
216 | translation[2, 3] = self._random_translation[2]
217 | self._last_translation = translation
218 | return translation
219 |
220 | def randomize(self) -> None:
221 | if not self.randomizable():
222 | return
223 |
224 | self._randomized_world = (
225 | (self.sample_translation() + self._centroid_mat)
226 | @ self.sample_rotation()
227 | @ self._world
228 | )
229 |
230 | for key, sampler in self._float_attributes.items():
231 | self._randomized_float_attributes[key] = sampler.sample()
232 |
233 | for key, sampler in self._vec3_attributes.items():
234 | self._randomized_vec3_attributes[key] = sampler.sample()
235 |
236 |     def relative(self) -> bool:
237 | return self._parent is not None
238 |
239 | def world(self) -> torch.tensor:
240 | # If no parent exists, just return the current translation
241 | if self._parent is None:
242 | return self._randomized_world.clone()
243 |
244 | return self._parent.world() @ self._randomized_world
245 |
246 | def nonRandomizedWorld(self) -> torch.tensor:
247 | if self._parent is None:
248 | return self._world
249 |
250 | return self._parent.nonRandomizedWorld() @ self._world
251 |
--------------------------------------------------------------------------------
/fireflies/projection/laser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 |
5 | import fireflies.graphics.rasterization
6 | import fireflies.utils.math
7 | import fireflies.sampling.poisson
8 | import fireflies.entity
9 |
10 | import fireflies.projection as projection
11 | import yaml
12 |
13 | from typing import List
14 |
15 |
16 | class Laser(projection.Camera):
17 | # Static Convenience Function
18 | @staticmethod
19 | def generate_uniform_rays(
20 | intra_ray_angle: float,
21 | num_beams_x: int,
22 | num_beams_y: int,
23 | device: torch.cuda.device = torch.device("cuda"),
24 | ) -> torch.tensor:
25 | laserRays = torch.zeros((num_beams_y * num_beams_x, 3), device=device)
26 |
27 | for x in range(num_beams_x):
28 | for y in range(num_beams_y):
29 |                 laserRays[x * num_beams_y + y, :] = torch.tensor(
30 | [
31 | math.tan((x - (num_beams_x - 1) / 2) * intra_ray_angle),
32 | math.tan((y - (num_beams_y - 1) / 2) * intra_ray_angle),
33 | -1.0,
34 | ]
35 | )
36 |
37 | return laserRays / torch.linalg.norm(laserRays, dim=-1, keepdims=True)
38 |
39 | @staticmethod
40 | def generate_uniform_rays_by_count(
41 | num_beams_x: int,
42 | num_beams_y: int,
43 | intrinsic_matrix: torch.tensor,
44 | device: torch.cuda.device = torch.device("cuda"),
45 | ) -> torch.tensor:
46 | laserRays = torch.zeros((num_beams_y * num_beams_x, 3), device=device)
47 |
48 | x_steps = torch.arange((1 / num_beams_x) / 2, 1, 1 / num_beams_x)
49 | y_steps = torch.arange((1 / num_beams_y) / 2, 1, 1 / num_beams_y)
50 |
51 | xy = torch.stack(torch.meshgrid(x_steps, y_steps))
52 | xy = xy.movedim(0, -1).reshape(-1, 2)
53 |
54 |         # Set Z to -1
55 | laserRays[:, 0:2] = xy
56 | laserRays[:, 2] = -1.0
57 |
58 | # Project to world
59 |         rays = fireflies.utils.math.transform_points(
60 | laserRays, intrinsic_matrix.inverse()
61 | )
62 |
63 | # Normalize
64 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True)
65 | rays[:, 2] *= -1.0
66 | return rays
67 |
68 | @staticmethod
69 | def generate_random_rays(
70 | num_beams: int,
71 | intrinsic_matrix: torch.tensor,
72 | device: torch.cuda.device = torch.device("cuda"),
73 | ) -> torch.tensor:
74 |
75 | # Random points and move into NDC
76 | spawned_points = (
77 | torch.ones([num_beams, 3], device=device) * 0.5
78 | + (torch.rand([num_beams, 3], device=device) - 0.5) / 10.0
79 | )
80 |
81 |         # Set Z to -1
82 | spawned_points[:, 2] = -1.0
83 |
84 | # Project to world
85 |         rays = fireflies.utils.math.transform_points(
86 | spawned_points, intrinsic_matrix.inverse()
87 | )
88 |
89 | # Normalize
90 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True)
91 | rays[:, 2] *= -1.0
92 | return rays
93 |
94 | @staticmethod
95 | def generate_blue_noise_rays(
96 | image_size_x: int,
97 | image_size_y: int,
98 | num_beams: int,
99 | intrinsic_matrix: torch.tensor,
100 | device: torch.cuda.device = torch.device("cuda"),
101 | ) -> torch.tensor:
102 |
103 | # We want to know the radius of the poisson disk so that we get roughly N beams
104 | #
105 | # So we say N < (X*Y) / PI*r*r <=> sqrt(X*Y / PI*N) ~ r
106 | #
107 |
108 | poisson_radius = math.sqrt(
109 | (image_size_x * image_size_y) / (math.pi * num_beams)
110 | )
111 | poisson_radius += poisson_radius / 4.0
112 | im = np.ones([image_size_x, image_size_y]) * poisson_radius
113 | num_samples, poisson_samples = fireflies.sampling.poisson.bridson(im)
114 | # print(len(poisson_samples))
115 | poisson_samples = torch.tensor(poisson_samples, device=device)
116 |
117 | # Remove random points from poisson samples such that num_beams is correct again.
118 | # indices = torch.linspace(0, poisson_samples.shape[0] - 1, poisson_samples.shape[0], device=device)
119 | # indices = torch.multinomial(indices, num_beams, replacement=False).long()
120 | # poisson_samples = poisson_samples[indices]
121 |
122 | # From image space to 0 1
123 | poisson_samples /= torch.tensor([image_size_x, image_size_y], device=device)
124 |
125 | # Create empty tensor for copying
126 | temp = torch.ones([poisson_samples.shape[0], 3], device=device) * -1.0
127 |
128 | # Copy to temp and add 1 for z coordinate
129 | temp[:, 0:2] = poisson_samples
130 | # temp[:, 0:2] = poisson_samples
131 |
132 | # Project to world
133 | rays = fireflies.utils.math.transform_points(temp, intrinsic_matrix.inverse())
134 |
135 | # rays = fireflies.utils.transforms.transform_points(
136 | # rays,
137 | # fireflies.utils.transforms.toMat4x4(
138 | # utils_math.getZTransform(0.5 * np.pi, intrinsic_matrix.device)
139 | # ),
140 | # )
141 |
142 | # Normalize
143 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True)
144 | rays[:, 2] *= -1.0
145 | return rays
146 |
147 | def __init__(
148 | self,
149 | transformable: fireflies.entity.base.Transformable,
150 | ray_directions,
151 | perspective: torch.tensor,
152 | max_fov: float,
153 | near_clip: float = 0.01,
154 | far_clip: float = 1000.0,
155 | device: torch.cuda.device = torch.device("cuda"),
156 | ):
157 | super(Laser, self).__init__(
158 | transformable, perspective, max_fov, near_clip, far_clip, device
159 | )
160 |         self.device = device
161 |         self._rays = ray_directions.to(self.device)
162 |
163 | def rays(self) -> torch.tensor:
164 | return fireflies.utils.math.transform_directions(
165 | self._rays,
166 |             self._transformable.world(),
167 | )
168 |
169 | def origin(self) -> torch.tensor:
170 |         return self._transformable.world()
171 |
172 | def originPerRay(self) -> torch.tensor:
173 | return (
174 |             self._transformable.world()[0:3, 3]
175 | .unsqueeze(0)
176 | .repeat(self._rays.shape[0], 1)
177 | )
178 |
179 | def near_clip(self) -> float:
180 | return self._near_clip
181 |
182 | def far_clip(self) -> float:
183 | return self._far_clip
184 |
185 | def initRandomRays(self):
186 | # Spawn random points in [-1.0, 1.0]
187 | spawned_points = torch.rand(self._rays.shape, device=self.device) * 2.0 - 1.0
188 |
189 | # Set Z to 1
190 | spawned_points[:, 2] = 1.0
191 |
192 | # Project to world
193 | rand_rays = self.projectNDCPointsToWorld(spawned_points)
194 | self._rays = self.normalize(rand_rays)
195 |
196 | def initPoissonDiskSamples(self, width, height, radius):
197 | return None
198 |
199 | def clamp_to_fov(self, clamp_val: float = 0.95, epsilon: float = 0.0001) -> None:
200 |         # TODO: Check whether a laser beam falls outside the FOV. If it does and
201 |         # randomization is enabled, spawn a new random beam inside NDC;
202 |         # otherwise clamp it to the edge, as done below.
203 | ndc_coords = self.projectRaysToNDC()
204 | ndc_coords[:, 0:2] = torch.clamp(ndc_coords[:, 0:2], 1 - clamp_val, clamp_val)
205 | clamped_rays = self.projectNDCPointsToWorld(ndc_coords)
206 | self._rays[:] = self.normalize(clamped_rays)
207 |
208 | def randomize_laser_out_of_bounds(self) -> None:
209 |         # TODO: Check whether a laser beam falls outside the FOV. If it does, respawn it at a random NDC position in (0, 1).
210 | new_rays = self._rays.clone()
211 |
212 | # No need to transform as rays are in laser space anyway
213 | ndc_coords = fireflies.utils.transforms.transform_points(
214 | new_rays, self._perspective
215 | )
216 | xy_coords = ndc_coords[:, 0:2]
217 | out_of_bounds_indices = ((xy_coords >= 1.0) | (xy_coords <= 0.0)).any(dim=1)
218 |
219 | out_of_bounds_points = ndc_coords[out_of_bounds_indices]
220 |
221 | if out_of_bounds_points.nelement() == 0:
222 |             return
223 |
224 | new_ray_point = torch.rand(out_of_bounds_points.shape, device=self.device)
225 | new_ray_point[:, 2] = -1.0
226 |
227 | clamped_rays = self.projectNDCPointsToWorld(new_ray_point)
228 | new_rays[out_of_bounds_indices] = clamped_rays
229 | new_rays = self.normalize(new_rays)
230 |
231 | self._rays[:] = new_rays
232 |
233 | def randomize_camera_out_of_bounds(self, ndc_coords) -> None:
234 | new_rays = self._rays.clone()
235 | xy_coords = ndc_coords[:, 0:2]
236 | out_of_bounds_indices = ((xy_coords >= 1.0) | (xy_coords <= -1.0)).any(dim=1)
237 | out_of_bounds_points = ndc_coords[out_of_bounds_indices]
238 |
239 | if out_of_bounds_points.nelement() == 0:
240 |             return
241 |
242 | new_ray_point = torch.rand(out_of_bounds_points.shape, device=self.device)
243 | new_ray_point[:, 2] = -1.0
244 |
245 | clamped_rays = self.projectNDCPointsToWorld(new_ray_point)
246 | new_rays[out_of_bounds_indices] = clamped_rays
247 | new_rays = self.normalize(new_rays)
248 |
249 | self._rays[:] = new_rays
250 |
251 | def normalize(self, tensor: torch.tensor) -> torch.tensor:
252 | return tensor / torch.linalg.norm(tensor, dim=-1, keepdims=True)
253 |
254 | def normalize_rays(self) -> None:
255 | self._rays[:] = self.normalize(self._rays)
256 |
257 | def setToWorld(self, to_world: torch.tensor) -> None:
258 |         # Delegate to the underlying transformable object
259 |         # (Transformable.set_world does not return a value).
260 |         self._transformable.set_world(to_world)
261 |
262 | def projectRaysToNDC(self) -> torch.tensor:
263 | # rays_in_world = fireflies.utils.transforms_torch.transform_directions(self._rays, self._to_world)
264 | FLIP_Y = torch.tensor(
265 | [
266 | [1.0, 0.0, 0.0, 0.0],
267 | [0.0, -1.0, 0.0, 0.0],
268 | [0.0, 0.0, 1.0, 0.0],
269 | [0.0, 0.0, 0.0, 1.0],
270 | ],
271 | device=self.device,
272 | )
273 | return fireflies.utils.math.transform_points(
274 | self._rays, self._perspective @ FLIP_Y
275 | )
276 |
277 | def projectNDCPointsToWorld(self, points: torch.tensor) -> torch.tensor:
278 | FLIP_Y = torch.tensor(
279 | [
280 | [1.0, 0.0, 0.0, 0.0],
281 | [0.0, -1.0, 0.0, 0.0],
282 | [0.0, 0.0, 1.0, 0.0],
283 | [0.0, 0.0, 0.0, 1.0],
284 | ],
285 | device=self.device,
286 | )
287 |
288 | return fireflies.utils.math.transform_points(
289 | points, (self._perspective @ FLIP_Y).inverse()
290 | )
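    # Note: projectRaysToNDC and projectNDCPointsToWorld use the same
    # (perspective @ FLIP_Y) matrix and its inverse, so the two mappings are
    # mutual inverses up to the ray normalization applied by their callers.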
291 |
292 |     def generateTexture(self, sigma: float, texture_size: torch.tensor) -> torch.tensor:
293 | points = self.projectRaysToNDC()[:, 0:2]
294 | return fireflies.graphics.rasterization.rasterize_points(
295 | points.cpu(), sigma, texture_size.cpu(), device="cpu"
296 | )
297 |
298 | def render_epipolar_lines(
299 | self, sigma: float, texture_size: torch.tensor
300 | ) -> torch.tensor:
301 | epipolar_min = self.originPerRay() + self._near_clip * self.rays()
302 | epipolar_max = self.originPerRay() + self._far_clip * self.rays()
303 |
304 |         CAMERA_TO_WORLD = self._transformable.world()
305 | WORLD_TO_CAMERA = CAMERA_TO_WORLD.inverse()
306 |
307 | epipolar_max = fireflies.utils.transforms.transform_points(
308 | epipolar_max, WORLD_TO_CAMERA
309 | )
310 | epipolar_max = fireflies.utils.transforms.transform_points(
311 | epipolar_max, self._perspective
312 | )[:, 0:2]
313 |
314 | epipolar_min = fireflies.utils.transforms.transform_points(
315 | epipolar_min, WORLD_TO_CAMERA
316 | )
317 | epipolar_min = fireflies.utils.transforms.transform_points(
318 | epipolar_min, self._perspective
319 | )[:, 0:2]
320 |
321 | lines = torch.stack([epipolar_min, epipolar_max], dim=1)
322 |
323 | return fireflies.graphics.rasterization.rasterize_lines(
324 | lines, sigma, texture_size
325 | )
326 |
327 | def save(self, filepath: str):
328 | save_dict = {
329 | "rays": self._rays.detach().cpu().numpy().tolist(),
330 | "fov": self._fov,
331 | "near_clip": self._near_clip,
332 | "far_clip": self._far_clip,
333 | }
334 |
335 | with open(filepath, "w") as file:
336 | yaml.dump(save_dict, file)
337 |
--------------------------------------------------------------------------------
/fireflies/utils/laser_estimation.py:
--------------------------------------------------------------------------------
1 | import mitsuba as mi
2 | import cv2
3 | import numpy as np
4 | import torch
5 |
6 | from scipy.spatial import ConvexHull
7 |
8 |
9 | import fireflies.sampling.poisson
10 |
11 | import fireflies.entity
12 | import fireflies.projection
13 |
14 | import fireflies.graphics.depth
15 | import fireflies.graphics.rasterization
16 |
17 | import fireflies.utils.math
18 | import fireflies.utils.transforms
19 | import fireflies.utils.intersections
20 |
21 |
22 | import math
23 |
24 |
25 | def probability_distribution_from_depth_maps(
26 | depth_maps: np.array, uniform_weight: float = 0.0
27 | ) -> np.array:
28 |
29 | variance_map = depth_maps.std(axis=0)
30 | variance_map += uniform_weight
31 |
32 | return variance_map
33 |
34 |
35 | def points_from_probability_distribution(
36 | prob_distribution: torch.tensor, num_samples: int
37 | ) -> torch.tensor:
38 |
39 | p = prob_distribution.flatten()
40 | chosen_points = p.multinomial(num_samples, replacement=False)
41 |
42 | return chosen_points
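    # The returned values are flat indices into the 2D probability map; they can
    # be unflattened with y, x = divmod(int(idx), prob_distribution.shape[1])
    # (create_rays below performs the equivalent unflattening on the film size).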
43 |
44 |
45 | def get_camera_direction(sensor, device: torch.cuda.device) -> torch.tensor:
46 | film = sensor.film()
47 | sampler = sensor.sampler()
48 | film_size = film.size()
49 | total_samples = 1
50 |
51 | if sampler.wavefront_size() != total_samples:
52 | sampler.seed(0, total_samples)
53 |
54 | # Enumerate discrete sample & pixel indices, and uniformly sample
55 | # positions within each pixel.
56 | # pos = mi.UInt32(points.split(split_size=1))
57 |
58 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
59 | pos = mi.Vector2f(mi.Float(0.5), mi.Float(0.5))
60 |
61 | # pos += sampler.next_2d()
62 |
63 | # Sample rays starting from the camera sensor
64 | rays, weights = sensor.sample_ray(
65 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0
66 | )
67 |
68 | return rays.o.torch(), rays.d.torch()
69 |
70 |
71 | def get_camera_frustum(sensor, device: torch.cuda.device) -> torch.tensor:
72 | # film = sensor.film()
73 | sampler = sensor.sampler()
74 | # film_size = film.size()
75 | # total_samples = 4
76 |
77 | # if sampler.wavefront_size() != total_samples:
78 | # sampler.seed(0, total_samples)
79 |
80 | # Enumerate discrete sample & pixel indices, and uniformly sample
81 | # positions within each pixel.
82 | # pos = mi.UInt32(points.split(split_size=1))
83 |
84 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
85 | pos = mi.Vector2f(mi.Float([0.0, 1.0, 0.0, 1.0]), mi.Float([0.0, 0.0, 1.0, 1.0]))
86 |
87 | # pos += sampler.next_2d()
88 |
89 | # Sample rays starting from the camera sensor
90 | rays, weights = sensor.sample_ray(
91 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0
92 | )
93 |
94 | ray_origins = rays.o.torch()
95 | ray_directions = rays.d.torch()
96 |
97 | # x_transform = transforms.toMat4x4(utils_math.getXTransform(np.pi*0.5, ray_origins.device))
98 | # ray_origins = transforms.transform_points(ray_origins, x_transform)
99 | # ray_directions = transforms.transform_directions(ray_directions, x_transform)
100 |
101 | return ray_origins, ray_directions
102 |
103 |
104 | def getRayFromSensor(sensor, ray_coordinate_in_ndc):
105 | sampler = sensor.sampler()
106 |
107 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
108 | pos = mi.Vector2f(
109 | mi.Float([ray_coordinate_in_ndc[0]]), mi.Float([ray_coordinate_in_ndc[1]])
110 | )
111 |
112 | # Sample rays starting from the camera sensor
113 | rays, weights = sensor.sample_ray(
114 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0
115 | )
116 |
117 | return rays.o.torch(), rays.d.torch()
118 |
119 |
120 | def create_rays(sensor, points) -> torch.tensor:
121 | film = sensor.film()
122 | sampler = sensor.sampler()
123 | film_size = film.size()
124 | total_samples = points.shape[0]
125 |
126 | if sampler.wavefront_size() != total_samples:
127 | sampler.seed(0, total_samples)
128 |
129 | # Enumerate discrete sample & pixel indices, and uniformly sample
130 | # positions within each pixel.
131 | pos = mi.UInt32(points.split(split_size=1))
132 |
133 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1])
134 | pos = mi.Vector2f(
135 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0]))
136 | )
137 |
138 | # pos += sampler.next_2d()
139 |
140 | # Sample rays starting from the camera sensor
141 | rays, weights = sensor.sample_ray(
142 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0
143 | )
144 |
145 | return rays.o.torch(), rays.d.torch()
146 |
147 |
148 | def laser_from_ndc_points(
149 |     sensor, laser_origin, depth_maps, chosen_points, device: torch.cuda.device = torch.device("cuda")
150 | ) -> torch.tensor:
151 | ray_origins, ray_directions = create_rays(sensor, chosen_points)
152 |
153 | # Get camera origin and direction
154 | camera_origin, camera_direction = get_camera_direction(sensor, device)
155 |
156 | camera_origin = sensor.world_transform().translation().torch()
157 |
158 | camera_direction = camera_direction / torch.linalg.norm(
159 | camera_direction, dim=-1, keepdims=True
160 | )
161 |
162 | # Build plane from depth map
163 | plane_origin = camera_origin + camera_direction * depth_maps.mean()
164 | plane_normal = -camera_direction
165 |
166 | # Compute intersections inbetween mean plane and randomly sampled rays
167 | intersection_distances = fireflies.utils.intersections.rayPlane(
168 | ray_origins, ray_directions, plane_origin, plane_normal
169 | )
170 | world_points = ray_origins + ray_directions * intersection_distances
171 |
172 | laser_dir = world_points - laser_origin
173 | laser_dir = laser_dir / torch.linalg.norm(laser_dir, dim=-1, keepdims=True)
174 | return laser_dir
175 |
176 |
177 | def draw_lines(ax, rayOrigin, rayDirection, ray_length=1.0, color="g"):
178 | for i in range(rayDirection.shape[0]):
179 | ax.plot(
180 | [rayOrigin[i, 0], rayOrigin[i, 0] + ray_length * rayDirection[i, 0]],
181 | [rayOrigin[i, 1], rayOrigin[i, 1] + ray_length * rayDirection[i, 1]],
182 | [rayOrigin[i, 2], rayOrigin[i, 2] + ray_length * rayDirection[i, 2]],
183 | color=color,
184 | )
185 |
186 |
187 | def generate_epipolar_constraints(scene, params, device):
188 |
189 | camera_sensor = scene.sensors()[0]
190 |
191 | projector_sensor = scene.sensors()[1]
192 | proj_xwidth, proj_ywidth = projector_sensor.film().size()
193 |
194 | ray_origins, ray_directions = get_camera_frustum(projector_sensor, device)
195 | camera_origins, camera_directions = get_camera_frustum(camera_sensor, device)
196 |
197 | near_clip = params["PerspectiveCamera_1.near_clip"]
198 | far_clip = params["PerspectiveCamera_1.far_clip"]
199 | # steps = 1
200 | # delta = (far_clip - near_clip / steps)
201 |
202 | projection_points = ray_origins + far_clip * ray_directions
203 | epipolar_points = projection_points
204 |
205 | # K = utils.build_projection_matrix(params['PerspectiveCamera.x_fov'], params['PerspectiveCamera.near_clip'], params['PerspectiveCamera.far_clip'])
206 | K = mi.perspective_projection(
207 | camera_sensor.film().size(),
208 | camera_sensor.film().crop_size(),
209 | camera_sensor.film().crop_offset(),
210 | params["PerspectiveCamera.x_fov"],
211 | params["PerspectiveCamera.near_clip"],
212 | params["PerspectiveCamera.far_clip"],
213 | ).matrix.torch()[0]
214 | CAMERA_WORLD = params["PerspectiveCamera.to_world"].matrix.torch()[0]
215 | # CAMERA_WORLD[0:3, 0:3] = CAMERA_WORLD[0:3, 0:3] @ utils_math.getYTransform(np.pi, CAMERA_WORLD.device)
216 |
217 | # mi.perspective_transformation(scene.sensors()[0].film.size())
218 |
219 |     epipolar_points = fireflies.utils.transforms.transform_points(
220 |         epipolar_points, CAMERA_WORLD.inverse()
221 |     )
222 |     epipolar_points = fireflies.utils.transforms.transform_points(epipolar_points, K)[:, 0:2]
223 |
224 | # Is in [0 -> 1]
225 | epi_points_np = epipolar_points.detach().cpu().numpy()
226 |
227 | hull = ConvexHull(epi_points_np)
228 | line_segments = epipolar_points[hull.vertices]
229 |
230 |     # We could also compute the fundamental matrix and use it to estimate
231 |     # the epipolar lines here. Instead, we project the projector frustum
232 |     # into the camera image and take the convex hull of the projected
233 |     # points. The hull describes the region of the camera image that the
234 |     # projector can reach; it is rasterized below into a binary constraint
235 |     # mask.
236 |     # (ConvexHull returns the hull vertices in order along the boundary.)
237 |
238 | camera_size = np.array(camera_sensor.film().crop_size())
239 | # camera_size = camera_size[[1, 0]] # swap image size to Y,X
240 |
241 | epi_points_np = line_segments.cpu().numpy()
242 | # epi_points_np = epi_points_np[:, [1, 0]]
243 | epi_points_np *= camera_size
244 |
245 | image = np.zeros(camera_size[[1, 0]], dtype=np.uint8)
246 | image = cv2.fillPoly(image, [epi_points_np.astype(int)], color=1)
247 | # cv2.imshow("Epipolar Image", image * 255)
248 | # cv2.waitKey(0)
249 |
250 | return torch.from_numpy(image).to(device)
251 |
252 |
253 | def initialize_laser(
254 | mitsuba_scene, mitsuba_params, firefly_scene, config, mode, device
255 | ):
256 | projector_sensor = mitsuba_scene.sensors()[1]
257 |
258 | near_clip = mitsuba_scene.sensors()[1].near_clip()
259 | far_clip = mitsuba_scene.sensors()[1].far_clip()
260 | laser_fov = float(mitsuba_params["PerspectiveCamera_1.x_fov"][0])
261 |
262 |
263 | radians = math.pi / 180.0
264 |
265 | image_size = torch.tensor(mitsuba_scene.sensors()[1].film().size(), device=device)
266 | LASER_K = mi.perspective_projection(
267 | projector_sensor.film().size(),
268 | projector_sensor.film().crop_size(),
269 | projector_sensor.film().crop_offset(),
270 | laser_fov,
271 | near_clip,
272 | far_clip,
273 | ).matrix.torch()[0]
274 | n_beams = config.n_beams
275 |
276 | local_laser_dir = None
277 | if mode == "RANDOM":
278 | local_laser_dir = fireflies.projection.Laser.generate_random_rays(
279 | num_beams=n_beams, intrinsic_matrix=LASER_K, device=device
280 | )
281 | elif mode == "POISSON":
282 | local_laser_dir = fireflies.projection.Laser.generate_blue_noise_rays(
283 | image_size_x=image_size[0],
284 | image_size_y=image_size[1],
285 | num_beams=n_beams,
286 | intrinsic_matrix=LASER_K,
287 | device=device,
288 | )
289 | elif mode == "GRID":
290 | grid_width = int(math.sqrt(config.n_beams))
291 | local_laser_dir = fireflies.projection.Laser.generate_uniform_rays_by_count(
292 | num_beams_x=grid_width,
293 | num_beams_y=grid_width,
294 | intrinsic_matrix=LASER_K,
295 | device=device,
296 | )
297 | elif mode == "SMARTY":
298 |         # NOTE: This mode is currently broken; the cause is unknown.
299 | constraint_map = generate_epipolar_constraints(
300 | mitsuba_scene, mitsuba_params, device
301 | )
302 |
303 | # Generate random depth maps by uniformly sampling from scene parameter ranges
304 | # print(config.n_depthmaps)
305 | depth_maps = fireflies.graphics.depth.random_depth_maps(
306 | firefly_scene, mitsuba_scene, num_maps=config.n_depthmaps
307 | )
308 |
309 | # Given depth maps, generate probability distribution
310 | variance_map = probability_distribution_from_depth_maps(
311 | depth_maps, config.variational_epsilon
312 | )
313 | variance_map = fireflies.utils.math.normalize(variance_map)
314 | vm = (variance_map.cpu().numpy() * 255).astype(np.uint8)
315 | vm = cv2.applyColorMap(vm, cv2.COLORMAP_INFERNO)
316 |
317 | # cv2.imshow("Variance Map", vm)
318 | # cv2.waitKey(0)
319 |
320 | # Final multiplication and normalization
321 | final_sampling_map = variance_map # * constraint_map
322 | final_sampling_map /= final_sampling_map.sum()
323 |
324 | # Gotta flip this in y direction, since apparently I can't program
325 | # final_sampling_map = torch.fliplr(final_sampling_map)
326 | # final_sampling_map = torch.flip(final_sampling_map, (0,))
327 |
328 | # sample points for laser rays
329 |
330 | min_radius = config.smarty_min_radius
331 | max_radius = config.smarty_max_radius
332 | normalized_sampling = 1 - fireflies.utils.math.normalize(final_sampling_map)
333 | normalized_sampling = (
334 | min_radius + (max_radius - min_radius) * normalized_sampling
335 | )
336 | n_points, points = fireflies.sampling.poisson.bridson(
337 | normalized_sampling.detach().cpu().numpy(), 50
338 | )
339 | points = torch.from_numpy(points).to(device).floor().int()
340 | chosen_points = points[:, 0] * final_sampling_map.shape[1] + points[:, 1]
341 |
342 | # chosen_points = points_from_probability_distribution(final_sampling_map, config.n_beams)
343 |
344 | vm = variance_map.cpu().numpy()
345 | cp = chosen_points.cpu().numpy()
346 | cm = constraint_map.cpu().numpy()
347 | """
348 | if config.save_images:
349 | vm = (vm * 255).astype(np.uint8)
350 | vm = cv2.applyColorMap(vm, cv2.COLORMAP_INFERNO)
351 | vm.reshape(-1, 3)[cp, :] = ~vm.reshape(-1, 3)[cp, :]
352 | cv2.imwrite("sampling_map.png", vm)
353 | cm = cm * 255
354 | cv2.imwrite("constraint_map.png", cm)
355 | """
356 |
357 | laser_world = firefly_scene.projector.world()
358 | laser_origin = laser_world[0:3, 3]
359 | # Sample directions of laser beams from variance map
360 | laser_dir = laser_from_ndc_points(
361 | mitsuba_scene.sensors()[0],
362 | laser_origin,
363 | depth_maps,
364 | chosen_points,
365 | device=device,
366 | )
367 |
368 | # Apply inverse rotation of the projector, such that we get a normalized direction
369 | # The laser direction up until now is in world coordinates!
370 |         local_laser_dir = fireflies.utils.transforms.transform_directions(
371 | laser_dir, laser_world.inverse()
372 | )
373 |
374 | # local_laser_dir = transforms.transform_points(
375 | # local_laser_dir,
376 | # transforms.toMat4x4(utils_math.getZTransform(0.5 * np.pi, device)),
377 | # )
378 |
379 | # local_laser_dir = transforms.transform_points(
380 | # local_laser_dir,
381 | # transforms.toMat4x4(utils_math.getYTransform(-0.5 * np.pi, device)),
382 | # )
383 |
384 | return fireflies.projection.Laser(
385 | firefly_scene.projector,
386 | local_laser_dir,
387 | LASER_K,
388 | laser_fov,
389 | near_clip,
390 |             far_clip, device=device,
391 | )
392 |
--------------------------------------------------------------------------------
/fireflies/scene.py:
--------------------------------------------------------------------------------
1 | import mitsuba as mi
2 |
3 | import torch
4 | import fireflies.entity
5 | import fireflies.utils
6 | import fireflies.emitter
7 | import fireflies.material
8 |
9 | from typing import List
10 |
11 |
12 | class Scene:
13 | MESH_KEYS = ["mesh", "ply"]
14 | CAM_KEYS = ["camera", "perspective", "perspectivecamera"]
15 | PROJ_KEYS = ["projector"]
16 | MAT_KEYS = ["mat", "bsdf"]
17 | LIGHT_KEYS = ["light", "spot"]
18 | TEX_KEYS = ["tex"]
19 |
20 | def __init__(
21 | self,
22 | mitsuba_params,
23 | device: torch.cuda.device = torch.device("cuda"),
24 | ):
25 |         # Only objects that carry a "randomizable" tag inside the yaml file are stored here.
26 | self._meshes = []
27 | self._projector = None
28 | self._camera = None
29 | self._lights = []
30 | self._curves = []
31 | self._materials = []
32 |
33 | self._transformables = []
34 |
35 | self._device = device
36 |
37 | self._mitsuba_params = mitsuba_params
38 |
39 | self.init_from_params(self._mitsuba_params)
40 |
41 | def device(self) -> torch.cuda.device:
42 | return self._device
43 |
44 | def mesh_at(self, index: int) -> fireflies.entity.Transformable:
45 | return self._meshes[index]
46 |
47 | def meshes(self) -> fireflies.entity.Transformable:
48 | return self._meshes
49 |
50 | def get_mesh(self, name: str) -> fireflies.entity.Transformable:
51 | for mesh in self._meshes:
52 | if mesh.name() == name:
53 | return mesh
54 |
55 | return None
56 |
57 | def mesh(self, name: str) -> fireflies.entity.Transformable:
58 | return self.get_mesh(name)
59 |
60 | def light_at(self, index: int) -> fireflies.entity.Transformable:
61 | return self._lights[index]
62 |
63 | def lights(self) -> fireflies.entity.Transformable:
64 | return self._lights
65 |
66 | def get_light(self, name: str) -> fireflies.entity.Transformable:
67 | for light in self._lights:
68 | if light.name() == name:
69 | return light
70 |
71 | return None
72 |
73 | def light(self, name: str) -> fireflies.entity.Transformable:
74 | return self.get_light(name)
75 |
76 | def material_at(self, index: int) -> fireflies.entity.Transformable:
77 | return self._materials[index]
78 |
79 | def materials(self) -> fireflies.entity.Transformable:
80 | return self._materials
81 |
82 | def get_material(self, name: str) -> fireflies.entity.Transformable:
83 | for mesh in self._materials:
84 | if mesh.name() == name:
85 | return mesh
86 |
87 | return None
88 |
89 | def material(self, name: str) -> fireflies.entity.Transformable:
90 | return self.get_material(name)
91 |
92 | def init_from_params(self, mitsuba_params) -> None:
93 | # Get all scene keys
94 | param_keys = [key.split(".")[0] for key in mitsuba_params.keys()]
95 |
96 |         # Remove duplicates
97 | param_keys = set(param_keys)
98 | param_keys = sorted(param_keys)
99 |
100 | for key in param_keys:
101 |             # Check if it's a mesh
102 | if any(MESH_KEY.lower() in key.lower() for MESH_KEY in self.MESH_KEYS):
103 | self.load_mesh(key)
104 | continue
105 | elif any(CAMERA_KEY.lower() in key.lower() for CAMERA_KEY in self.CAM_KEYS):
106 | self.load_camera(key)
107 | continue
108 | elif any(PROJ_KEY.lower() in key.lower() for PROJ_KEY in self.PROJ_KEYS):
109 | self.load_projector(key)
110 | continue
111 | elif any(LIGHT_KEY.lower() in key.lower() for LIGHT_KEY in self.LIGHT_KEYS):
112 | self.load_light(key)
113 | continue
114 | elif any(MAT_KEY.lower() in key.lower() for MAT_KEY in self.MAT_KEYS):
115 | self.load_material(key)
116 | continue
117 |
118 | def load_mesh(self, base_key: str):
119 |         # Compute the centroid here, since Mitsuba does not provide a world transform for meshes
120 | vertices = torch.tensor(
121 | self._mitsuba_params[base_key + ".vertex_positions"], device=self._device
122 | ).reshape(-1, 3)
123 | centroid = vertices.sum(dim=0, keepdim=True) / vertices.shape[0]
124 |
125 | aligned_vertices = vertices - centroid
126 |
127 | transformable_mesh = fireflies.entity.Mesh(
128 | base_key, aligned_vertices, self._device
129 | )
130 | transformable_mesh.set_centroid(centroid)
131 |
132 | self._meshes.append(transformable_mesh)
133 |
134 | def load_camera(self, base_key: str) -> None:
135 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch()
136 | to_world = to_world.squeeze().to(self._device)
137 | transformable_camera = fireflies.entity.Transformable(base_key, self._device)
138 | transformable_camera.set_world(to_world)
139 | transformable_camera.set_randomizable(False)
140 | self._camera = transformable_camera
141 |
142 | def load_projector(self, base_key: str) -> None:
143 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch()
144 | to_world = to_world.squeeze().to(self._device)
145 | transformable_projector = fireflies.entity.Transformable(base_key, self._device)
146 | transformable_projector.set_world(to_world)
147 | transformable_projector.set_randomizable(False)
148 | self._projector = transformable_projector
149 |
150 | def load_light(self, base_key: str) -> None:
151 | new_light = fireflies.emitter.Light(base_key, device=self._device)
152 |
153 | if base_key + ".to_world" in self._mitsuba_params.keys():
154 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch()
155 | to_world = to_world.squeeze().to(self._device)
156 | new_light.set_world(to_world)
157 |
158 | light_keys = []
159 | for key in self._mitsuba_params.keys():
160 | if base_key in key:
161 | light_keys.append(key)
162 |
163 | for key in light_keys:
164 | key_without_base = ".".join(key.split(".")[1:])
165 | value = self._mitsuba_params[key]
166 |
167 | if type(value) == mi.Transform4f:
168 | continue
169 |
170 | if isinstance(value, mi.Float) or isinstance(value, float):
171 | new_light.add_float_key(key_without_base, value, value)
172 | elif len(value) == 3:
173 | value = value.torch().squeeze()
174 | new_light.add_vec3_key(key_without_base, value, value)
175 |
176 | new_light.set_randomizable(False)
177 | self._lights.append(new_light)
178 |
179 | def load_material(self, base_key: str) -> None:
180 | new_material = fireflies.material.Material(base_key, device=self._device)
181 |
182 | material_keys = []
183 | for key in self._mitsuba_params.keys():
184 | if base_key in key:
185 | material_keys.append(key)
186 |
187 | for key in material_keys:
188 | key_without_base = ".".join(key.split(".")[1:])
189 | value = self._mitsuba_params[key]
190 |
191 | if type(value) == mi.Transform4f or isinstance(value, mi.ScalarTransform3f):
192 | continue
193 |
194 | if isinstance(value, mi.Float) or isinstance(value, float):
195 | new_material.add_float_key(key_without_base, value, value)
196 | elif len(value) == 3:
197 | value = value.torch().squeeze()
198 | new_material.add_vec3_key(key_without_base, value, value)
199 |
200 | new_material.set_randomizable(False)
201 | self._materials.append(new_material)
202 |
203 | def train(self) -> None:
204 |         # Put every randomizable object into train mode
205 | for mesh in self._meshes:
206 | mesh.train()
207 |
208 | for light in self._lights:
209 | light.train()
210 |
211 | for material in self._materials:
212 | material.train()
213 |
214 | if self._camera is not None:
215 | self._camera.train()
216 |
217 | if self._projector is not None:
218 | self._projector.train()
219 |
220 | def eval(self) -> None:
221 |         # Put every randomizable object into eval mode
222 | for mesh in self._meshes:
223 | mesh.eval()
224 |
225 | for light in self._lights:
226 | light.eval()
227 |
228 | for material in self._materials:
229 | material.eval()
230 |
231 | if self._camera is not None:
232 | self._camera.eval()
233 |
234 | if self._projector is not None:
235 | self._projector.eval()
236 |
237 | def load_curve(self, path: str, name: str = "Curve") -> None:
238 | curve = fireflies.utils.importBlenderNurbsObj(path)
239 | transformable_curve = fireflies.entity.Curve(name, curve, self._device)
240 |
241 |         self._curves.append(transformable_curve)
242 |
243 | def update_meshes(self) -> None:
244 | for mesh in self._meshes:
245 | if not mesh.randomizable():
246 | continue
247 |
248 | vertex_data = mesh.get_randomized_vertices()
249 | self._mitsuba_params[mesh.name() + ".vertex_positions"] = mi.Float32(
250 | vertex_data.flatten()
251 | )
252 |
253 | def update_camera(self) -> None:
254 | if not self._camera.randomizable():
255 | return
256 |
257 | self._mitsuba_params[self._camera.name() + ".to_world"] = mi.Transform4f(
258 | self._camera.world().tolist()
259 | )
260 |
261 | float_dict = self._camera.get_randomized_float_attributes()
262 | vec3_dict = self._camera.get_randomized_vec3_attributes()
263 |
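        # Write the randomized values back, casting through the type currently
        # stored in the Mitsuba parameter table so the scene graph keeps the
        # concrete types it expects.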
264 | for key, value in float_dict.items():
265 | joined_key = self._camera.name() + "." + key
266 | temp_type = type(self._mitsuba_params[joined_key])
267 | self._mitsuba_params[joined_key] = temp_type(value.item())
268 |
269 | for key, value in vec3_dict.items():
270 | joined_key = self._camera.name() + "." + key
271 | temp_type = type(self._mitsuba_params[joined_key])
272 | self._mitsuba_params[self._camera.name() + "." + key] = temp_type(
273 | value.tolist()
274 | )
275 |
276 | def update_projector(self) -> None:
277 | if not self._projector.randomizable():
278 | return
279 |
280 | self._mitsuba_params[self._projector.name() + ".to_world"] = mi.Transform4f(
281 | self._projector.world().tolist()
282 | )
283 |
284 | float_dict = self._projector.get_randomized_float_attributes()
285 | vec3_dict = self._projector.get_randomized_vec3_attributes()
286 |
287 | for key, value in float_dict.items():
288 | joined_key = self._projector.name() + "." + key
289 | temp_type = type(self._mitsuba_params[joined_key])
290 | self._mitsuba_params[joined_key] = temp_type(value.item())
291 |
292 | for key, value in vec3_dict.items():
293 | joined_key = self._projector.name() + "." + key
294 | temp_type = type(self._mitsuba_params[joined_key])
295 | self._mitsuba_params[self._projector.name() + "." + key] = temp_type(
296 | value.tolist()
297 | )
298 |
299 | def update_lights(self) -> None:
300 | for light in self._lights:
301 | if not light.randomizable():
302 | continue
303 |
304 | if light.name() + ".to_world" in self._mitsuba_params.keys():
305 | self._mitsuba_params[light.name() + ".to_world"] = mi.Transform4f(
306 | light.world().tolist()
307 | )
308 |
309 | float_dict = light.get_randomized_float_attributes()
310 | vec3_dict = light.get_randomized_vec3_attributes()
311 |
312 | for key, value in float_dict.items():
313 | joined_key = light.name() + "." + key
314 | temp_type = type(self._mitsuba_params[joined_key])
315 | self._mitsuba_params[joined_key] = temp_type(value.item())
316 |
317 | for key, value in vec3_dict.items():
318 | joined_key = light.name() + "." + key
319 | temp_type = type(self._mitsuba_params[joined_key])
320 | self._mitsuba_params[light.name() + "." + key] = temp_type(
321 | value.tolist()
322 | )
323 |
324 | def update_materials(self) -> None:
325 | for material in self._materials:
326 | if not material.randomizable():
327 | continue
328 |
329 | float_dict = material.get_randomized_float_attributes()
330 | vec3_dict = material.get_randomized_vec3_attributes()
331 |
332 | for key, value in float_dict.items():
333 | joined_key = material.name() + "." + key
334 | temp_type = type(self._mitsuba_params[joined_key])
335 | self._mitsuba_params[joined_key] = temp_type(value.item())
336 |
337 | for key, value in vec3_dict.items():
338 | joined_key = material.name() + "." + key
339 | temp_type = type(self._mitsuba_params[joined_key])
340 | self._mitsuba_params[material.name() + "." + key] = temp_type(
341 | value.tolist()
342 | )
343 |
344 | def randomize_list(self, entity_list: List[fireflies.entity.Transformable]) -> None:
345 |         # First find the root objects, i.e. those whose parent is None
346 | parent_objects = []
347 | for entity in entity_list:
348 | if entity.parent() is None:
349 | parent_objects.append(entity)
350 |
351 |         # Then walk down each chain of children and randomize them in order
352 | for entity in parent_objects:
353 | entity.randomize()
354 |
355 | iterator_child = entity.child()
356 | while iterator_child is not None:
357 | iterator_child.randomize()
358 | iterator_child = iterator_child.child()
359 |
360 | def randomize(self) -> None:
361 | # We first randomize all of our objects
362 | self.randomize_list(self._meshes)
363 | self.randomize_list(self._lights)
364 | self.randomize_list(self._materials)
365 |
366 | if self._camera is not None:
367 | self._camera.randomize()
368 |
369 | if self._projector is not None:
370 | self._projector.randomize()
371 |
372 | # And then copy the updates to the mitsuba parameters
373 | self.update_meshes()
374 |
375 | if self._camera is not None:
376 | self.update_camera()
377 |
378 | if self._projector is not None:
379 | self.update_projector()
380 | self.update_lights()
381 | self.update_materials()
382 |
383 | # We finally update the mitsuba scene graph itself
384 | self._mitsuba_params.update()
385 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import mitsuba as mi
3 | import numpy as np
4 | import kornia
5 |
6 | mi.set_variant("cuda_ad_rgb")
7 |
8 | import torch
9 | import fireflies
10 | import fireflies.sampling
11 | import fireflies.projection.laser
12 | import fireflies.postprocessing
13 | import fireflies.utils.math
14 | import fireflies.graphics.depth
15 | import os
16 |
17 | from tqdm import tqdm
18 |
19 |
20 | def render_to_numpy(render):
21 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
22 | return render
23 |
24 |
25 | if __name__ == "__main__":
26 | path = "examples/scenes/realistic_vf/vocalfold.xml"
27 | dataset_path = "../LearningFromFireflies/fireflies_dataset_v4/"
28 |
29 | mitsuba_scene = mi.load_file(path, parallel=False)
30 | mitsuba_params = mi.traverse(mitsuba_scene)
31 | ff_scene = fireflies.Scene(mitsuba_params)
32 | ff_scene._camera._name = "Camera"
33 |
34 | projector_sensor = mitsuba_scene.sensors()[1]
35 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"]
36 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"]
37 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"]
38 |
39 | K_PROJECTOR = mi.perspective_projection(
40 | projector_sensor.film().size(),
41 | projector_sensor.film().crop_size(),
42 | projector_sensor.film().crop_offset(),
43 | x_fov,
44 | near_clip,
45 | far_clip,
46 | ).matrix.torch()[0]
47 |
51 | laser_rays = fireflies.projection.Laser.generate_uniform_rays(
52 | 0.0275, 18, 18, device=ff_scene.device()
53 | )
54 |
55 | laser = fireflies.projection.Laser(
56 | ff_scene._projector,
57 | laser_rays,
58 | K_PROJECTOR,
59 | x_fov,
60 | near_clip,
61 | far_clip,
62 | device=ff_scene.device(),
63 | )
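    # Rasterize the laser dots into a single-channel image, blur it slightly,
    # and place the result in the green channel of the projector texture.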
64 | texture = laser.generateTexture(
65 | 10.0, torch.tensor([500, 500], device=ff_scene.device())
66 | )
67 | texture = texture.sum(dim=0)
68 |
69 | texture = kornia.filters.gaussian_blur2d(
70 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3)
71 | ).squeeze()
72 | texture = torch.stack(
73 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)]
74 | )
75 | texture = torch.movedim(texture, 0, -1)
76 |
77 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy())
78 |
79 | vocalfold_mesh = ff_scene.mesh("mesh-Vocalfold")
80 | vocalfold_mesh.scale_x(1.0, 3.0)
81 | vocalfold_mesh.scale_z(1.0, 3.0)
82 | vocalfold_mesh.rotate_y(-0.2, 0.2)
83 | vocalfold_mesh.translate_y(-0.05, -0.05)
84 | vocalfold_mesh.add_train_animation_from_obj("examples/scenes/vocalfold_new/train/")
85 | vocalfold_mesh.add_eval_animation_from_obj("examples/scenes/vocalfold_new/test/")
86 |
87 | a = vocalfold_mesh._anim_data_train
88 | scale_mat = torch.eye(4, device=ff_scene.device()) * 0.054
89 | scale_mat[3, 3] = 1.0
90 | for i in range(a.shape[0]):
91 | a[i] = fireflies.utils.math.transform_points(a[i], scale_mat)
92 |
93 | larynx_mesh = ff_scene.mesh("mesh-Larynx")
94 | larynx_mesh.scale_x(1.0, 4.0)
95 | larynx_mesh.scale_z(1.0, 2.0)
96 |
97 | # Randomization of Mucosa material
98 | material = ff_scene.material("mat-Mucosa")
99 | material.add_float_key("brdf_0.clearcoat.value", 0.0, 1.0)
100 | material.add_float_key("brdf_0.clearcoat_gloss.value", 0.0, 1.0)
101 | material.add_float_key("brdf_0.metallic.value", 0.0, 0.5)
102 | material.add_float_key("brdf_0.specular", 0.0, 1.0)
103 | material.add_float_key("brdf_0.roughness.value", 0.0, 1.0)
104 | material.add_float_key("brdf_0.anisotropic.value", 0.0, 1.0)
105 | material.add_float_key("brdf_0.sheen.value", 0.0, 0.5)
106 | material.add_float_key("brdf_0.spec_trans.value", 0.0, 0.4)
107 | material.add_float_key("brdf_0.flatness.value", 0.0, 1.0)
108 |
109 | # Camera Randomization
110 | ff_scene._camera.translate_x(-0.15, 0.15)
111 | ff_scene._camera.translate_y(-0.15, 0.0)
112 | ff_scene._camera.translate_z(-0.15, 0.15)
113 | ff_scene._camera.rotate_x(-0.2, 0.2)
114 | ff_scene._camera.rotate_y(-0.5, 0.5)
115 | ff_scene._camera.rotate_z((-np.pi / 2.0) - 0.5, (-np.pi / 2.0) + 0.5)
116 | ff_scene._camera.add_float_key("x_fov", 70.0, 130.0)
117 |
118 | # Light Randomization
119 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler(
120 | 0.1, 10.0, device=ff_scene.device()
121 | )
122 | light = ff_scene.light("emit-Spot")
123 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler)
124 |
125 |     # This shape query appears to be leftover debugging; its value is not used below.
126 |     texture_shape = (
127 |         mitsuba_params["mat-Mucosa.brdf_0.base_color.data"]
128 |         .torch()
129 |         .moveaxis(-1, 0).shape
130 |     )
131 |
132 | lerp_sampler = fireflies.sampling.NoiseTextureLerpSampler(
133 | color_a=torch.tensor([0.0, 0.0, 0.0], device=ff_scene.device()),
134 | color_b=torch.tensor([1.0, 1.0, 1.0], device=ff_scene.device()),
135 | texture_shape=(1024, 1024),
136 | )
137 |
138 | post_process_funcs = [
139 | fireflies.postprocessing.GaussianBlur((3, 3), (5, 5), 0.5),
140 | fireflies.postprocessing.ApplySilhouette(),
141 | fireflies.postprocessing.WhiteNoise(0.0, 0.05, 0.5),
142 | ]
143 | post_processor = fireflies.postprocessing.PostProcessor(post_process_funcs)
144 | spp_sampler = fireflies.sampling.AnimationSampler(1, 100, 1, 100)
145 | ff_scene.train()
146 | count = 0
147 | while count != 10000:
148 | lerp_sampler._color_a = torch.rand((3), device=ff_scene.device())
149 | lerp_sampler._color_b = torch.rand((3), device=ff_scene.device())
150 | mucosa_texture = lerp_sampler.sample()
151 | mitsuba_params["mat-Mucosa.brdf_0.base_color.data"] = mi.TensorXf(
152 | mucosa_texture.moveaxis(0, -1).cpu().numpy()
153 | )
154 |
155 | ff_scene.randomize()
156 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample())
157 | render = render_to_numpy(render)
158 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY)
159 | render = post_processor.post_process(render)
160 |
161 | segmentation = (
162 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene)
163 | .cpu()
164 | .numpy()
165 | .astype(np.float32)
166 | )
167 |
168 | if segmentation.max() == 0:
169 | continue
170 |
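        # Heuristic frame filter: class-2 pixels are mapped to 0 and everything
        # else to 1, then connectedComponentsWithStats counts the pieces of that
        # non-class-2 region. Frames with three or more pieces (n_labels > 3,
        # which includes the background label) are skipped, presumably to reject
        # implausible segmentations.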
171 | seg_test = segmentation.copy()
172 | seg_test[seg_test == 1] = 0
173 | seg_test[seg_test == 2] = 1
174 | seg_test = 1 - seg_test
175 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
176 | seg_test.astype(np.uint8)
177 | )
178 |
179 | if n_labels > 3:
180 | continue
181 |
182 | segmentation_map = segmentation / segmentation.max()
183 | final = cv2.hconcat([render, segmentation_map])
184 |
185 | cv2.imwrite(
186 | os.path.join(dataset_path, "train/images/{0:05d}.png".format(count)),
187 | (render * 255).astype(np.uint8),
188 | )
189 | cv2.imwrite(
190 | os.path.join(dataset_path, "train/segmentation/{0:05d}.png".format(count)),
191 | segmentation.astype(np.uint8),
192 | )
193 | count += 1
194 |
195 | count = 0
196 | while count != 500:
197 | lerp_sampler._color_a = torch.rand((3), device=ff_scene.device())
198 | lerp_sampler._color_b = torch.rand((3), device=ff_scene.device())
199 | mucosa_texture = lerp_sampler.sample()
200 | mitsuba_params["mat-Mucosa.brdf_0.base_color.data"] = mi.TensorXf(
201 | mucosa_texture.moveaxis(0, -1).cpu().numpy()
202 | )
203 |
204 | ff_scene.randomize()
205 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample())
206 | render = render_to_numpy(render)
207 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY)
208 | render = post_processor.post_process(render)
209 |
210 | segmentation = (
211 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene)
212 | .cpu()
213 | .numpy()
214 | .astype(np.float32)
215 | )
216 |
217 | if segmentation.max() == 0:
218 | continue
219 |
220 | seg_test = segmentation.copy()
221 | seg_test[seg_test == 1] = 0
222 | seg_test[seg_test == 2] = 1
223 | seg_test = 1 - seg_test
224 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
225 | seg_test.astype(np.uint8)
226 | )
227 |
228 | if n_labels > 3:
229 | continue
230 |
231 | segmentation_map = segmentation / segmentation.max()
232 | final = cv2.hconcat([render, segmentation_map])
233 |
234 | cv2.imwrite(
235 | os.path.join(dataset_path, "eval/images/{0:05d}.png".format(count)),
236 | (render * 255).astype(np.uint8),
237 | )
238 | cv2.imwrite(
239 | os.path.join(dataset_path, "eval/segmentation/{0:05d}.png".format(count)),
240 | segmentation.astype(np.uint8),
241 | )
242 | count += 1
243 |
244 | '''
245 | if __name__ == "__main__":
246 | path = "examples/scenes/vocalfold_new/vocalfold.xml"
247 |
248 | mitsuba_scene = mi.load_file(path, parallel=False)
249 | mitsuba_params = mi.traverse(mitsuba_scene)
250 | ff_scene = fireflies.Scene(mitsuba_params)
251 | ff_scene._camera._name = "Camera"
252 |
253 | projector_sensor = mitsuba_scene.sensors()[1]
254 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"]
255 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"]
256 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"]
257 |
258 | dataset_path = "../LearningFromFireflies/fireflies_dataset_v3/"
259 |
260 | K_PROJECTOR = mi.perspective_projection(
261 | projector_sensor.film().size(),
262 | projector_sensor.film().crop_size(),
263 | projector_sensor.film().crop_offset(),
264 | x_fov,
265 | near_clip,
266 | far_clip,
267 | ).matrix.torch()[0]
268 |
269 | # laser_rays = fireflies.projection.Laser.generate_uniform_rays(
270 | # 0.0275, 18, 18, device=ff_scene.device()
271 | # )
272 | laser_rays = fireflies.projection.Laser.generate_uniform_rays(
273 | 0.0275, 18, 18, device=ff_scene.device()
274 | )
275 |
276 | laser = fireflies.projection.Laser(
277 | ff_scene._projector,
278 | laser_rays,
279 | K_PROJECTOR,
280 | x_fov,
281 | near_clip,
282 | far_clip,
283 | device=ff_scene.device(),
284 | )
285 | texture = laser.generateTexture(
286 | 10.0, torch.tensor([500, 500], device=ff_scene.device())
287 | )
288 | texture = texture.sum(dim=0)
289 |
290 | texture = kornia.filters.gaussian_blur2d(
291 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3)
292 | ).squeeze()
293 | texture = torch.stack(
294 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)]
295 | )
296 | texture = torch.movedim(texture, 0, -1)
297 |
298 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy())
299 |
300 | vocalfold_mesh = ff_scene.mesh("mesh-VocalFold")
301 | vocalfold_mesh.scale_x(0.75, 3.0)
302 | vocalfold_mesh.scale_y(1.1, 2.0)
303 | vocalfold_mesh.rotate_y(-0.2, 0.2)
304 | vocalfold_mesh.add_train_animation_from_obj("examples/scenes/vocalfold_new/train/")
305 | vocalfold_mesh.add_eval_animation_from_obj("examples/scenes/vocalfold_new/test/")
306 |
307 | a = vocalfold_mesh._anim_data_train
308 | scale_mat = torch.eye(4, device=ff_scene.device()) * 0.05
309 | scale_mat[3, 3] = 1.0
310 | rot_mat = fireflies.utils.math.toMat4x4(
311 | fireflies.utils.math.getXTransform(np.pi / 2.0, ff_scene.device())
312 | )
313 | for i in range(a.shape[0]):
314 | a[i] = fireflies.utils.math.transform_points(a[i], scale_mat)
315 | a[i] = fireflies.utils.math.transform_points(a[i], rot_mat)
316 |
317 | larynx_mesh = ff_scene.mesh("mesh-Larynx")
318 | larynx_mesh.scale_x(0.3, 1.2)
319 | larynx_mesh.rotate_y(-0.5, 0.5)
320 | # larynx_mesh.scale_z(1.0, 2.5)
321 |
322 | material = ff_scene.material("mat-Default OBJ")
323 | material.add_vec3_key(
324 | "brdf_0.base_color.value",
325 | torch.tensor([0.3, 0.3, 0.33], device=ff_scene.device()),
326 | torch.tensor([0.85, 0.85, 0.85], device=ff_scene.device()),
327 | )
328 |
329 | for key in material.float_attributes():
330 | if "sampling_rate" in key:
331 | continue
332 |
333 | material.add_float_key(key, 0.01, 0.99)
334 |
335 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler(
336 | 1.0, 80.0, device=ff_scene.device()
337 | )
338 | light = ff_scene.light("emit-Spot")
339 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler)
340 |
341 | post_process_funcs = [
342 | fireflies.postprocessing.GaussianBlur((3, 3), (5, 5), 0.5),
343 | fireflies.postprocessing.ApplySilhouette(),
344 | fireflies.postprocessing.WhiteNoise(0.0, 0.05, 0.5),
345 | ]
346 | post_processor = fireflies.postprocessing.PostProcessor(post_process_funcs)
347 |
348 | ff_scene._camera.translate_x(-0.5, 0.5)
349 | ff_scene._camera.translate_y(-0.5, 0.5)
350 | ff_scene._camera.translate_z(-0.5, 0.5)
351 | ff_scene._camera.add_float_key("x_fov", 20.0, 50.0)
352 | ff_scene.train()
353 |
354 | spp_sampler = fireflies.sampling.AnimationSampler(1, 100, 1, 100)
355 | spp_sampler.train()
356 | """
357 | count = 0
358 | while count != 10000:
359 | ff_scene.randomize()
360 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample())
361 | render = render_to_numpy(render)
362 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY)
363 |
364 | render = post_processor.post_process(render)
365 |
366 | segmentation = (
367 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene)
368 | .cpu()
369 | .numpy()
370 | .astype(np.float32)
371 | )
372 |
373 | if segmentation.max() == 0:
374 | continue
375 |
376 | seg_test = segmentation.copy()
377 | seg_test[seg_test == 2] = 1
378 | seg_test = 1 - seg_test
379 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
380 | seg_test.astype(np.uint8)
381 | )
382 |
383 | if n_labels > 3:
384 | continue
385 |
386 | segmentation_map = segmentation / segmentation.max()
387 | final = cv2.hconcat([render, segmentation_map])
388 |
389 | cv2.imwrite(
390 | os.path.join(dataset_path, "train/images/{0:05d}.png".format(count)),
391 | (render * 255).astype(np.uint8),
392 | )
393 | cv2.imwrite(
394 | os.path.join(dataset_path, "train/segmentation/{0:05d}.png".format(count)),
395 | segmentation.astype(np.uint8),
396 | )
397 | count += 1
398 | """
399 |
400 | count = 0
401 | while count != 500:
402 | ff_scene.randomize()
403 | render = mi.render(mitsuba_scene, spp=100)
404 | render = render_to_numpy(render)
405 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY)
406 |
407 | render = post_processor.post_process(render)
408 |
409 | segmentation = (
410 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene)
411 | .cpu()
412 | .numpy()
413 | .astype(np.float32)
414 | )
415 |
416 | if segmentation.max() == 0:
417 | continue
418 |
419 | seg_test = segmentation.copy()
420 | seg_test[seg_test == 2] = 1
421 | seg_test = 1 - seg_test
422 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
423 | seg_test.astype(np.uint8)
424 | )
425 |
426 | if n_labels > 3:
427 | continue
428 |
429 | segmentation_map = segmentation / segmentation.max()
430 | final = cv2.hconcat([render, segmentation_map])
431 |
432 | cv2.imwrite(
433 | os.path.join(dataset_path, "eval/images/{0:05d}.png".format(count)),
434 | (render * 255).astype(np.uint8),
435 | )
436 | cv2.imwrite(
437 | os.path.join(dataset_path, "eval/segmentation/{0:05d}.png".format(count)),
438 | segmentation.astype(np.uint8),
439 | )
440 | count += 1
441 | '''
442 |
--------------------------------------------------------------------------------
/fireflies/graphics/rasterization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import math
4 |
5 |
6 | # We assume points to be in camera space [0, 1]
7 | def rasterize_points(
8 | points: torch.tensor,
9 | sigma: float,
10 | texture_size: torch.tensor,
11 | device: torch.cuda.device = torch.device("cuda"),
12 | ) -> torch.tensor:
13 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
14 | tex = tex[None, ...]
15 | tex = tex.repeat((points.shape[0], 1, 1, 1))
16 |
17 | # Somewhere between [0, texture_size] but in float
18 | points = points.clone() * texture_size
19 |
20 | # Generate x, y indices
21 | x, y = torch.meshgrid(
22 | torch.arange(0, texture_size[1], device=device),
23 | torch.arange(0, texture_size[0], device=device),
24 | indexing="ij",
25 | )
26 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1))
27 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1))
28 |
29 | y_dist = y - points[:, 0:1].unsqueeze(-1)
30 | x_dist = x - points[:, 1:2].unsqueeze(-1)
31 |
32 | point_distances = (
33 | y_dist * y_dist + x_dist * x_dist
34 | ) # / (texture_size * texture_size).sum().sqrt()
35 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2))
36 |
37 | return point_distances
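# Note on sigma (applies to the rasterize_* helpers in this file): the falloff
# is exp(-(d^2 / sigma)^2), so the response drops to 1/e at a pixel distance of
# d = sqrt(sigma). The baked_* variants below size their footprints with
# sigma.sqrt() * num_std for the same reason.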
38 |
39 |
40 | def rasterize_points_in_non_ndc(
41 | points: torch.tensor,
42 | sigma: float,
43 | texture_size: torch.tensor,
44 | device: torch.cuda.device = torch.device("cuda"),
45 | ) -> torch.tensor:
46 | x, y = torch.meshgrid(
47 | torch.arange(0, texture_size[1], device=device),
48 | torch.arange(0, texture_size[0], device=device),
49 | indexing="ij",
50 | )
51 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1))
52 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1))
53 |
54 | y_dist = y - points[:, 0:1].unsqueeze(-1)
55 | x_dist = x - points[:, 1:2].unsqueeze(-1)
56 |
57 | point_distances = (
58 | y_dist * y_dist + x_dist * x_dist
59 | ) # / (texture_size * texture_size).sum().sqrt()
60 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2))
61 |
62 | return point_distances
63 |
64 |
65 | # We assume points to be in NDC [-1, 1]
66 | def rasterize_depth(
67 | points: torch.tensor,
68 | depth_vals: torch.tensor,
69 | sigma: float,
70 | texture_size: torch.tensor,
71 | device: torch.cuda.device = torch.device("cuda"),
72 | ) -> torch.tensor:
73 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
74 | tex = tex[None, ...]
75 | tex = tex.repeat((points.shape[0], 1, 1, 1))
76 |
77 | # Somewhere between [0, texture_size] but in float
78 | points = points.clone() * texture_size
79 |
80 | # Generate x, y indices
81 | x, y = torch.meshgrid(
82 | torch.arange(0, texture_size[1], device=device),
83 | torch.arange(0, texture_size[0], device=device),
84 | indexing="ij",
85 | )
86 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1))
87 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1))
88 |
89 | y_dist = y - points[:, 0:1].unsqueeze(-1)
90 | x_dist = x - points[:, 1:2].unsqueeze(-1)
91 |
92 | point_distances = (
93 | y_dist * y_dist + x_dist * x_dist
94 | ) # / (texture_size * texture_size).sum().sqrt()
95 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2))
96 |
97 | # normalize
98 | point_distances = (
99 | point_distances
100 | / point_distances.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
101 | )
102 |
103 | # scale by depth in range [0, 1]
104 | return point_distances * depth_vals.unsqueeze(-1)
105 |
106 |
107 | def rasterize_lines(
108 | lines: torch.tensor,
109 | sigma: float,
110 | texture_size: torch.tensor,
111 | device: torch.cuda.device = torch.device("cuda"),
112 | ) -> torch.tensor:
113 | # lines are in NDC [-1, 1]
114 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
115 | tex = tex[None, ...]
116 | tex = tex.repeat((lines.shape[0], 1, 1, 1))
117 |
118 | lines_start = lines[:, 0, :]
119 | lines_end = lines[:, 1, :]
120 |
121 | # Somewhere between [0, texture_size] but in float
122 | lines_start *= texture_size
123 | lines_end *= texture_size
124 |
125 | lines_start = lines_start.permute(1, 0).unsqueeze(-1).unsqueeze(-1)
126 | lines_end = lines_end.permute(1, 0).unsqueeze(-1).unsqueeze(-1)
127 |
128 | y, x = torch.meshgrid(
129 | torch.arange(0, texture_size[1], device=device),
130 | torch.arange(0, texture_size[0], device=device),
131 | indexing="ij",
132 | )
133 | y = y.unsqueeze(0).repeat((lines.shape[0], 1, 1))
134 | x = x.unsqueeze(0).repeat((lines.shape[0], 1, 1))
135 | xy = torch.stack([x, y])
136 |
137 | # See: https://github.com/jonhare/DifferentiableSketching/blob/main/dsketch/raster/disttrans.py
138 | # If you found this, you should definitely give them a star. That's beautiful code they wrote there.
139 |
140 | pa = xy - lines_start
141 | pb = xy - lines_end
142 | m = lines_end - lines_start
143 |
144 | t0 = (pa * m).sum(dim=0) / ((m * m).sum(dim=0) + torch.finfo().eps)
145 | patm = xy - (lines_start + t0.unsqueeze(0) * m)
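    # t0 is the unclamped projection parameter of each pixel onto the infinite
    # line; the three cases below select the squared distance to the start
    # point (t0 <= 0), to the segment interior (0 < t0 < 1), or to the end
    # point (t0 >= 1).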
146 |
147 | distance_smaller_zero = (t0 <= 0) * (pa * pa).sum(dim=0)
148 | distance_inbetween = (t0 > 0) * (t0 < 1) * (patm * patm).sum(dim=0)
149 | distance_greater_one = (t0 >= 1) * (pb * pb).sum(dim=0)
150 |
151 | distances = distance_smaller_zero + distance_inbetween + distance_greater_one
152 | # distances = distances.sqrt()
153 | return torch.exp(-(distances * distances) / (sigma * sigma))
154 |
155 |
156 | def softor(texture: torch.tensor, dim=0, keepdim: bool = False) -> torch.tensor:
157 | return 1 - torch.prod(1 - texture, dim=dim, keepdim=keepdim)
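# softor composites per-primitive splats with a differentiable "soft union":
# for values in [0, 1], 1 - prod(1 - t_i) is ~0 where all splats are ~0 and
# approaches 1 as soon as any single splat approaches 1, unlike a hard max.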
158 |
159 |
160 | def sum(texture: torch.tensor, dim=0, keepdim: bool = False) -> torch.tensor:
161 | return torch.sum(texture, dim=dim, keepdim=keepdim)
162 |
163 |
164 | def baked_sum(
165 | points: torch.tensor,
166 | sigma: torch.tensor,
167 | texture_size: torch.tensor,
168 | num_std: int = 4,
169 | device: torch.cuda.device = torch.device("cuda"),
170 | ) -> torch.tensor:
171 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
172 |
173 | # Somewhere between [0, texture_size] but in float
174 | points = points.clone() * texture_size
175 |
176 | for i in range(points.shape[0]):
177 | point = points[i]
178 |
179 |         # Use a footprint of num_std * sqrt(sigma) pixels per side to cover most of the Gaussian
180 | footprint = math.floor(sigma.sqrt().item()) * num_std
181 | footprint = footprint + 1 if footprint % 2 == 0 else footprint
182 | half_footprint = int((footprint - 1) / 2)
183 |
184 | point_in_middle_of_footprint = point - point.floor() + half_footprint
185 |
186 | footprint_origin_in_original_image = (point - half_footprint).floor() # [Y, X]
187 |
188 | y, x = torch.meshgrid(
189 | torch.arange(0, footprint, device=device),
190 | torch.arange(0, footprint, device=device),
191 | indexing="ij",
192 | )
193 |
194 | y_dist = y - point_in_middle_of_footprint[0:1]
195 | x_dist = x - point_in_middle_of_footprint[1:2]
196 | dist = y_dist * y_dist + x_dist * x_dist
197 | dist = torch.exp(-torch.pow(dist / sigma, 2))
198 |
199 | wo = (point.floor() - half_footprint).int()
200 | re = [footprint, footprint]
201 | rs = [0, 0]
202 |
203 | # tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]] = tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]].clone() + dist[rs[0]:re[0], rs[1]:re[1]]
204 |
205 |         # Admittedly clumsy out-of-bounds handling: clip the footprint
206 |         # rectangle so that only the parts that fall inside the texture are
207 |         # copied. There are certainly cleaner ways to do this, but for now it works.
208 |
209 | rect_start = torch.tensor([0, 0], dtype=torch.int32, device=device)
210 | rect_end = torch.tensor(
211 | [footprint, footprint], dtype=torch.int32, device=device
212 | )
213 |
214 | if footprint_origin_in_original_image[0] < 0:
215 | rect_start[0] = footprint_origin_in_original_image.abs()[0]
216 | footprint_origin_in_original_image[0] = 0
217 |
218 | if footprint_origin_in_original_image[1] < 0:
219 | rect_start[1] = footprint_origin_in_original_image.abs()[1]
220 | footprint_origin_in_original_image[1] = 0
221 |
222 | if footprint_origin_in_original_image[0] + footprint >= texture_size[0]:
223 | rect_end[0] = texture_size[0] - footprint_origin_in_original_image[0]
224 |
225 | if footprint_origin_in_original_image[1] + footprint >= texture_size[1]:
226 | rect_end[1] = texture_size[1] - footprint_origin_in_original_image[1]
227 |
228 | wo = footprint_origin_in_original_image.int()
229 | rs = rect_start.int()
230 | re = rect_end.int()
231 |
232 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]] = (
233 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]].clone()
234 | + dist[rs[0] : re[0], rs[1] : re[1]]
235 | )
236 |
237 | return tex.T
238 |
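# Footprint sizing in the baked_* rasterizers, worked through for the values used
# in test_point_reg() below (sigma = 15.0**2, num_std = 4, 512x512 texture):
#
#   sigma.sqrt()            = 15.0
#   floor(15.0) * num_std   = 60   -> even, so bumped to 61
#   half_footprint          = (61 - 1) / 2 = 30
#
# i.e. each point is splatted into a 61x61 window centred on its pixel instead of
# evaluating the falloff over the full 512x512 texture.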
239 |
240 | def baked_sum_2(
241 | points: torch.tensor,
242 | sigma: torch.tensor,
243 | texture_size: torch.tensor,
244 | num_std: int = 4,
245 | device: torch.cuda.device = torch.device("cuda"),
246 | ) -> torch.tensor:
247 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
248 |     # Scale normalized [0, 1] point coordinates to float pixel coordinates in [0, texture_size]
249 | points = points.clone() * texture_size
250 |
251 |     # The footprint covers num_std * sqrt(sigma) pixels around each point (sigma is passed squared), enough to include most of the falloff
252 | footprint = math.floor(sigma.sqrt().item()) * num_std
253 | footprint = footprint + 1 if footprint % 2 == 0 else footprint
254 | half_footprint = int((footprint - 1) / 2)
255 |
256 | point_in_middle_of_footprint = points - points.floor() + half_footprint
257 |
258 | footprint_origin_in_original_image = (points - half_footprint).floor() # [Y, X]
259 |
260 | y, x = torch.meshgrid(
261 | torch.arange(0, footprint, device=device),
262 | torch.arange(0, footprint, device=device),
263 | indexing="ij",
264 | )
265 |
266 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1))
267 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1))
268 |
269 | y_dist = y - point_in_middle_of_footprint[:, 0:1].unsqueeze(-1)
270 | x_dist = x - point_in_middle_of_footprint[:, 1:2].unsqueeze(-1)
271 |
272 | dist = y_dist * y_dist + x_dist * x_dist
273 | dist = torch.exp(-torch.pow(dist / sigma, 2))
274 |
275 | wo = footprint_origin_in_original_image.int()
276 | rect_start = torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device)
277 | rect_end = (
278 | torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) + footprint
279 | )
280 |
281 | if (wo[:, 0] < 0).any():
282 | rect_start[:, 0] = torch.where(wo[:, 0] < 0, wo.abs()[:, 0], rect_start[:, 0])
283 | wo[:, 0] = torch.where(wo[:, 0] < 0, 0, wo[:, 0])
284 |
285 | if (wo[:, 1] < 0).any():
286 |         rect_start[:, 1] = torch.where(wo[:, 1] < 0, wo.abs()[:, 1], rect_start[:, 1])
287 | wo[:, 1] = torch.where(wo[:, 1] < 0, 0, wo[:, 1])
288 |
289 | if (wo[:, 0] + footprint >= texture_size[0]).any():
290 | rect_end[:, 0] = torch.where(
291 | wo[:, 0] + footprint >= texture_size[0],
292 | texture_size[0] - wo[:, 0],
293 | rect_end[:, 0],
294 | )
295 |
296 | if (wo[:, 1] + footprint >= texture_size[1]).any():
297 | rect_end[:, 1] = torch.where(
298 | wo[:, 1] + footprint >= texture_size[1],
299 | texture_size[1] - wo[:, 1],
300 | rect_end[:, 1],
301 | )
302 |
303 | re = rect_end.clone()
304 | rs = rect_start.clone()
305 |
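    # A possible fully vectorized alternative to the per-point loop below (only a
    # sketch, assuming no footprint gets clipped at the texture border):
    #
    #   ar = torch.arange(footprint, device=device)
    #   ys = (wo[:, 0:1, None] + ar[None, :, None]).expand(-1, footprint, footprint)
    #   xs = (wo[:, 1:2, None] + ar[None, None, :]).expand(-1, footprint, footprint)
    #   tex.index_put_((ys.reshape(-1).long(), xs.reshape(-1).long()),
    #                  dist.reshape(-1), accumulate=True)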
306 | for i in range(points.shape[0]):
307 | tex[
308 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0],
309 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1],
310 | ] = (
311 | tex[
312 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0],
313 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1],
314 | ].clone()
315 | + dist[i, rs[i, 0] : re[i, 0], rs[i, 1] : re[i, 1]]
316 | )
317 |
318 | return tex
319 |
320 |
321 | def baked_softor(
322 | points: torch.tensor,
323 | sigma: torch.tensor,
324 | texture_size: torch.tensor,
325 | num_std: int = 5,
326 | device: torch.cuda.device = torch.device("cuda"),
327 | ) -> torch.tensor:
328 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device)
329 |     # Scale normalized [0, 1] point coordinates to float pixel coordinates in [0, texture_size]
330 | points = points.clone() * texture_size
331 |
332 | for i in range(points.shape[0]):
333 | point = points[i]
334 |
335 |         # The footprint covers num_std * sqrt(sigma) pixels around the point (sigma is passed squared), enough to include most of the falloff
336 | footprint = math.floor(sigma.sqrt().item()) * num_std
337 | footprint = footprint + 1 if footprint % 2 == 0 else footprint
338 | half_footprint = int((footprint - 1) / 2)
339 |
340 | point_in_middle_of_footprint = point - point.floor() + half_footprint
341 |
342 | footprint_origin_in_original_image = (point - half_footprint).floor() # [Y, X]
343 |
344 | y, x = torch.meshgrid(
345 | torch.arange(0, footprint, device=device),
346 | torch.arange(0, footprint, device=device),
347 | indexing="ij",
348 | )
349 |
350 | y_dist = y - point_in_middle_of_footprint[0:1]
351 | x_dist = x - point_in_middle_of_footprint[1:2]
352 | dist = y_dist * y_dist + x_dist * x_dist
353 | dist = torch.exp(-torch.pow(dist / sigma, 2))
354 |
355 | wo = (point.floor() - half_footprint).int()
356 | re = [footprint, footprint]
357 | rs = [0, 0]
358 |
359 | # tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]] = tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]].clone() + dist[rs[0]:re[0], rs[1]:re[1]]
360 |
361 |         # This is an out-of-bounds check for rectangles so that we can copy the
362 |         # correct parts of our tensors. There are certainly better and more precise
363 |         # ways to solve this, but for now it works.
364 |
365 | rect_start = torch.tensor([0, 0], dtype=torch.int32, device=device)
366 | rect_end = torch.tensor(
367 | [footprint, footprint], dtype=torch.int32, device=device
368 | )
369 |
370 | if footprint_origin_in_original_image[0] < 0:
371 | rect_start[0] = footprint_origin_in_original_image.abs()[0]
372 | footprint_origin_in_original_image[0] = 0
373 |
374 | if footprint_origin_in_original_image[1] < 0:
375 | rect_start[1] = footprint_origin_in_original_image.abs()[1]
376 | footprint_origin_in_original_image[1] = 0
377 |
378 | if footprint_origin_in_original_image[0] + footprint >= texture_size[0]:
379 | rect_end[0] = texture_size[0] - footprint_origin_in_original_image[0]
380 |
381 | if footprint_origin_in_original_image[1] + footprint >= texture_size[1]:
382 | rect_end[1] = texture_size[1] - footprint_origin_in_original_image[1]
383 |
384 | wo = footprint_origin_in_original_image.int()
385 | rs = rect_start.int()
386 | re = rect_end.int()
387 |
388 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]] = tex[
389 | wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]
390 | ].clone() * (1 - dist[rs[0] : re[0], rs[1] : re[1]])
391 |
392 | return (1 - tex).T
393 |
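# baked_softor() never materializes the (N, H, W) stack that softor() would need.
# It accumulates the complement in place instead: multiplying tex by (1 - g_i) for
# every splat g_i and returning 1 - tex yields the same soft OR, since
#
#   1 - prod_i (1 - g_i) == softor(stack([g_1, ..., g_N]), dim=0)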
394 |
395 | def baked_softor_2(
396 | points: torch.tensor,
397 | sigma: torch.tensor,
398 | texture_size: torch.tensor,
399 | num_std: int = 5,
400 | device: torch.cuda.device = torch.device("cuda"),
401 | ) -> torch.tensor:
402 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device)
403 |     # Scale normalized [0, 1] point coordinates to float pixel coordinates in [0, texture_size]
404 | points = points.clone() * texture_size
405 |
406 |     # The footprint covers num_std * sqrt(sigma) pixels around each point (sigma is passed squared), enough to include most of the falloff
407 | footprint = math.floor(sigma.sqrt().item()) * num_std
408 | footprint = footprint + 1 if footprint % 2 == 0 else footprint
409 | half_footprint = int((footprint - 1) / 2)
410 |
411 | point_in_middle_of_footprint = points - points.floor() + half_footprint
412 |
413 | footprint_origin_in_original_image = (points - half_footprint).floor() # [Y, X]
414 |
415 | y, x = torch.meshgrid(
416 | torch.arange(0, footprint, device=device),
417 | torch.arange(0, footprint, device=device),
418 | indexing="ij",
419 | )
420 |
421 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1))
422 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1))
423 |
424 | y_dist = y - point_in_middle_of_footprint[:, 0:1].unsqueeze(-1)
425 | x_dist = x - point_in_middle_of_footprint[:, 1:2].unsqueeze(-1)
426 |
427 | dist = y_dist * y_dist + x_dist * x_dist
428 | dist = torch.exp(-torch.pow(dist / sigma, 2))
429 |
430 | wo = footprint_origin_in_original_image.int()
431 | rect_start = torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device)
432 | rect_end = (
433 | torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) + footprint
434 | )
435 |
436 | if (wo[:, 0] < 0).any():
437 | rect_start[:, 0] = torch.where(wo[:, 0] < 0, wo.abs()[:, 0], rect_start[:, 0])
438 | wo[:, 0] = torch.where(wo[:, 0] < 0, 0, wo[:, 0])
439 |
440 | if (wo[:, 1] < 0).any():
441 |         rect_start[:, 1] = torch.where(wo[:, 1] < 0, wo.abs()[:, 1], rect_start[:, 1])
442 | wo[:, 1] = torch.where(wo[:, 1] < 0, 0, wo[:, 1])
443 |
444 | if (wo[:, 0] + footprint >= texture_size[0]).any():
445 | rect_end[:, 0] = torch.where(
446 | wo[:, 0] + footprint >= texture_size[0],
447 | texture_size[0] - wo[:, 0],
448 | rect_end[:, 0],
449 | )
450 |
451 | if (wo[:, 1] + footprint >= texture_size[1]).any():
452 | rect_end[:, 1] = torch.where(
453 | wo[:, 1] + footprint >= texture_size[1],
454 | texture_size[1] - wo[:, 1],
455 | rect_end[:, 1],
456 | )
457 |
458 | re = rect_end.clone()
459 | rs = rect_start.clone()
460 |
461 | for i in range(points.shape[0]):
462 | tex[
463 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0],
464 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1],
465 | ] = tex[
466 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0],
467 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1],
468 | ].clone() * (
469 | 1 - dist[i, rs[i, 0] : re[i, 0], rs[i, 1] : re[i, 1]]
470 | )
471 |
472 | return (1 - tex).T
473 |
474 |
475 | def rasterize_points_baked_softor(
476 | points: torch.tensor,
477 | sigma: float,
478 | texture_size: torch.tensor,
479 | device: torch.cuda.device = torch.device("cuda"),
480 | ) -> torch.tensor:
481 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device)
482 |
483 |     # Scale normalized [0, 1] point coordinates to float pixel coordinates in [0, texture_size]
484 | points = points.clone() * texture_size
485 |
486 | for i in range(points.shape[0]):
487 | # Generate x, y indices
488 | x, y = torch.meshgrid(
489 | torch.arange(0, texture_size[1], device=device),
490 | torch.arange(0, texture_size[0], device=device),
491 | indexing="ij",
492 | )
493 |
494 | y_dist = y - points[i, 0:1].unsqueeze(-1)
495 | x_dist = x - points[i, 1:2].unsqueeze(-1)
496 |
497 | point_distances = (
498 | y_dist * y_dist + x_dist * x_dist
499 | ) # / (texture_size * texture_size).sum().sqrt()
500 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2))
501 | tex = tex.clone() * (1 - point_distances)
502 |
503 | return 1 - tex
504 |
505 |
506 | # We assume points to be in camera space [0, 1]
507 | def rasterize_points_baked_sum(
508 | points: torch.tensor,
509 | sigma: float,
510 | texture_size: torch.tensor,
511 | device: torch.cuda.device = torch.device("cuda"),
512 | ) -> torch.tensor:
513 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device)
514 |
515 |     # Scale normalized [0, 1] point coordinates to float pixel coordinates in [0, texture_size]
516 | points = points.clone() * texture_size
517 |
518 | for i in range(points.shape[0]):
519 | # Generate x, y indices
520 | x, y = torch.meshgrid(
521 | torch.arange(0, texture_size[1], device=device),
522 | torch.arange(0, texture_size[0], device=device),
523 | indexing="ij",
524 | )
525 |
526 | y_dist = y - points[i, 0:1].unsqueeze(-1)
527 | x_dist = x - points[i, 1:2].unsqueeze(-1)
528 |
529 | point_distances = (
530 | y_dist * y_dist + x_dist * x_dist
531 | ) # / (texture_size * texture_size).sum().sqrt()
532 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2))
533 | tex += point_distances
534 |
535 | return tex
536 |
537 |
538 | def subsampled_point_raster(ndc_points, num_subsamples, sigma, sensor_size):
539 | subsampled_rastered_depth = []
540 | for i in range(num_subsamples):
541 | rastered_depth = rasterize_depth(
542 | ndc_points[:, 0:2], ndc_points[:, 2:3], sigma, sensor_size // 2**i
543 | )
544 | rastered_depth = softor(rastered_depth, keepdim=True)
545 | # rastered_depth = (rastered_depth - rastered_depth.min()) / (
546 | # rastered_depth.max() - rastered_depth.min()
547 | # )
548 | subsampled_rastered_depth.append(rastered_depth)
549 | return subsampled_rastered_depth
550 |
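# Usage sketch with made-up values (point count, sigma and device are hypothetical;
# rasterize_depth is whatever this module binds above): build a coarse-to-fine
# stack of soft depth rasters, e.g. for a multi-scale comparison.
#
#   ndc_points = torch.rand(128, 3, device="cuda")         # (x, y, depth) in [0, 1]
#   sensor_size = torch.tensor([512, 512], device="cuda")
#   pyramid = subsampled_point_raster(ndc_points, 3, sigma=100.0, sensor_size=sensor_size)
#   # pyramid[i] was rasterized at resolution sensor_size // 2**i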
551 |
552 | def get_mpl_colormap(cmap):
553 | import matplotlib.pyplot as plt
554 |
555 | # Initialize the matplotlib color map
556 | sm = plt.cm.ScalarMappable(cmap=cmap)
557 |
558 | # Obtain linear color range
559 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1]
560 |
561 | return color_range.reshape(256, 1, 3)
562 |
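# Usage sketch: the returned 256x1x3 uint8 array is already in BGR order (note the
# [:, 2::-1] flip above), so it can be passed to OpenCV as a custom lookup table:
#
#   lut = get_mpl_colormap("viridis")
#   colored = cv2.applyColorMap((gray * 255).astype(np.uint8), lut)  # gray: float image in [0, 1]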
563 |
564 | def test_point_reg(reduce_overlap: bool = True):
565 | import cv2
566 | import numpy as np
567 | from tqdm import tqdm
568 | import imageio
569 | import matplotlib.colors
570 | import timeit
571 | import matplotlib.pyplot as plt
572 |
573 | device = "cuda" if torch.cuda.is_available() else "cpu"
574 |
575 | points = torch.rand([500, 2], device=device)
576 | points.requires_grad = True
577 | sigma = torch.tensor([15.0], device=device) ** 2
578 | texture_size = torch.tensor([512, 512], device=device)
579 | loss_func = torch.nn.L1Loss()
580 |
581 | opt_steps = 200
582 |
583 | optim = torch.optim.Adam([{"params": points, "lr": 0.001}])
584 |
585 | images = []
586 | for i in tqdm(range(opt_steps)):
587 | optim.zero_grad()
588 |
589 | summed = baked_sum_2(points, sigma, texture_size)
590 | softored = baked_softor_2(points, sigma, texture_size)
591 |
592 | # rasterized_points = rasterize_points(points, sigma, texture_size)
593 | # softored = softor(rasterized_points)
594 | # summed = rasterized_points.sum(dim=0)
595 |
596 | loss = (
597 | loss_func(softored, summed)
598 | if reduce_overlap
599 | else -loss_func(softored, summed)
600 | )
601 | print(loss.item())
602 | loss.backward()
603 | optim.step()
604 |
605 | with torch.no_grad():
606 | points[points >= 1.0] = 0.999
607 | points[points <= 0.0] = 0.001
608 |
609 | # Apply custom colormap
610 |         colors = [(0.0, 0.1921, 0.4156), (0, 0.69, 0.314)]  # dark blue, green
611 |
612 | # fig = plt.figure(frameon=False)
613 | # fig.set_size_inches(10, 10)
614 | # ax = plt.Axes(fig, [0., 0., 1., 1.])
615 | # ax.set_axis_off()
616 | # ax.set_aspect(aspect='equal')
617 | # ax.set_facecolor(colors[0])
618 | # fig.add_axes(ax)
619 |
620 | # ax.scatter(points.detach().cpu().numpy()[:, 0], points.detach().cpu().numpy()[:, 1], s=60.0*10, color=colors[0])
621 | # fig.canvas.draw()
622 | # img_plot = np.array(fig.canvas.renderer.buffer_rgba())
623 | # np_points = img_plot
624 | # images.append(np_points)
625 |         preview = cv2.applyColorMap(
626 |             (softored.detach().cpu().numpy() * 255).astype(np.uint8),
627 |             cv2.COLORMAP_VIRIDIS,
628 |         )
629 |         cv2.imshow("Show", preview)
630 | cv2.waitKey(1)
631 | # if i == 0 or i == opt_steps - 1:
632 | # fig.savefig("assets/point_reduced_overlap{0}.eps".format(i) if reduce_overlap else "assets/point_increased_overlap{0}.eps".format(i),
633 | # facecolor=ax.get_facecolor(),
634 | # edgecolor='none',
635 | # bbox_inches = 'tight',
636 | # pad_inches=0)
637 | plt.close()
638 |
639 | # cv2.imshow("Optim Lines", np_points)
640 | # cv2.waitKey(1)
641 | # lines.requires_grad = True
642 | # imageio.v3.imwrite("assets/point_regularization.mp4", np.stack(images, axis=0), fps=25)
643 |
644 |
645 | def test_line_reg():
646 | import cv2
647 | import numpy as np
648 | from tqdm import tqdm
649 | import imageio
650 | import matplotlib.colors
651 | import matplotlib.pyplot as plt
652 |
653 | device = "cuda" if torch.cuda.is_available() else "cpu"
654 |
655 |     # We define a line as:
656 |     # P0: x + t*d
657 |     # P1: x - t*d
658 |     # where d is the direction, x the midpoint (location vector), and t half the line length.
659 |     # We want to optimize the location vector x of all lines such that the lines do not overlap.
660 |
661 | num_lines = 50
662 | sigma = 10.0
663 | opt_steps = 250
664 |
665 | t = torch.tensor([0.5], device=device)
666 | direction = torch.rand([2], device=device).unsqueeze(0)
667 | direction = direction / direction.norm()
668 |
669 | location_vector = (
670 | (torch.rand([num_lines, 2], device=device) - 0.5) * 2.0 / 10.0
671 | ) # Every line should be roughly in the middle of our frame
672 | location_vector.requires_grad = True
673 |
674 | sigma = torch.tensor([sigma], device=device)
675 | texture_size = torch.tensor([512, 512], device=device)
676 | loss_func = torch.nn.L1Loss()
677 |
678 | optim = torch.optim.Adam([{"params": location_vector, "lr": 0.005}])
679 |
680 | images = []
681 | for i in tqdm(range(opt_steps)):
682 | optim.zero_grad()
683 |
684 | p0 = location_vector + t * direction
685 | p1 = location_vector - t * direction
686 | lines = torch.concat([p0.unsqueeze(-1), p1.unsqueeze(-1)], dim=-1).transpose(
687 | 1, 2
688 | )
689 |
690 | rasterized_lines = rasterize_lines(lines, sigma, texture_size)
691 |
692 | softored = softor(rasterized_lines)
693 | summed = rasterized_lines.sum(dim=0)
694 |
695 | loss = loss_func(softored, summed)
696 | loss.backward()
697 | optim.step()
698 |
699 | with torch.no_grad():
700 | location_vector[p0 > 1.0] -= 0.01
701 | location_vector[p0 < -1.0] += 0.01
702 | location_vector[p1 > 1.0] -= 0.01
703 | location_vector[p1 < -1.0] += 0.01
704 |
705 |         colors = [(0.0, 0.1921, 0.4156), (0, 0.69, 0.314)]  # dark blue, green
706 | fig = plt.figure(frameon=False)
707 | fig.set_size_inches(10, 10)
708 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
709 | ax.set_xlim([-1, 1])
710 | ax.set_ylim([-1, 1])
711 | ax.set_axis_off()
712 | ax.set_aspect(aspect="equal")
713 | fig.add_axes(ax)
714 |
715 | lines_copy = lines.transpose(1, 2).detach().cpu().numpy()
716 | for j in range(lines_copy.shape[0]):
717 | ax.plot(
718 | lines_copy[j, 0, :],
719 | lines_copy[j, 1, :],
720 | c=colors[0],
721 | linewidth=9.5,
722 | solid_capstyle="round",
723 | ) # c=colors[0], linewidth=60)
724 |
725 | fig.canvas.draw()
726 | img_plot = np.array(fig.canvas.renderer.buffer_rgba())
727 | np_points = img_plot
728 | images.append(np_points)
729 |
730 | if i == 0 or i == opt_steps - 1:
731 | fig.savefig(
732 | "assets/line_reduced_overlap{0}.eps".format(i),
733 | facecolor=ax.get_facecolor(),
734 | edgecolor="none",
735 | bbox_inches="tight",
736 | pad_inches=0,
737 | )
738 | plt.close()
739 |
740 | imageio.v3.imwrite(
741 | "assets/line_regularization.mp4", np.stack(images, axis=0), fps=25
742 | )
743 | # optimize("line_regularization.gif")
744 |
745 |
746 | def main():
747 | import matplotlib.pyplot as plt
748 |
749 | device = "cuda" if torch.cuda.is_available() else "cpu"
750 |
751 | points_a = torch.rand(10, 2, device=device) # (Y, X)
752 |
753 | # points_a = torch.tensor([[0.565, 0.5555]], device=device)
754 | # points_b = torch.tensor([[0.51,0.51]], device=device)
755 |
756 | texture_size = torch.tensor([100, 100], device=device) # (Y, X)
757 |
758 | sigma = torch.tensor([10.0], device=device) ** 2
759 | sum = baked_sum(points_a, sigma, texture_size, device=device)
760 | # sum_og = baked_sum(points_b, sigma, texture_size, device=device)
761 | sum_og = rasterize_points(points_a, sigma, texture_size, device=device).sum(dim=0)
762 | # sum_og = rasterize_points(points_b, sigma, texture_size, device=device).sum(dim=0)
763 |
764 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
765 | ax1.imshow(sum_og.detach().cpu().numpy())
766 | ax2.imshow(sum.detach().cpu().numpy())
767 | ax3.imshow((sum_og - sum).detach().cpu().numpy())
768 | ax1.set_title("OG")
769 | ax2.set_title("NEW")
770 | ax3.set_title("DIFF")
771 | fig.show()
772 | plt.show()
773 |
774 |
775 | def time_it():
776 | import timeit, functools
777 |
778 | device = "cuda" if torch.cuda.is_available() else "cpu"
779 |
780 | points = torch.rand([500, 2], device=device)
781 | points.requires_grad = True
782 | sigma = torch.tensor([10.0], device=device) ** 2
783 | texture_size = torch.tensor([512, 512], device=device)
784 | repeats = 50
785 |
786 | og_sum = timeit.Timer(
787 | lambda: rasterize_points(points, sigma, texture_size, device).sum(dim=0)
788 | )
789 | print(og_sum.timeit(repeats))
790 |
791 | t_baked_sum = timeit.Timer(
792 | lambda: baked_sum(points, sigma, texture_size, 4, device)
793 | )
794 | print(t_baked_sum.timeit(repeats))
795 |
796 | t_baked_sum2 = timeit.Timer(
797 | lambda: baked_sum_2(points, sigma, texture_size, 4, device)
798 | )
799 | print(t_baked_sum2.timeit(repeats))
800 |
801 | og_softor = timeit.Timer(
802 | lambda: softor(rasterize_points(points, sigma, texture_size, device))
803 | )
804 | print(og_softor.timeit(repeats))
805 |
806 | t_baked_softor1 = timeit.Timer(
807 | lambda: baked_softor(points, sigma, texture_size, 4, device)
808 | )
809 | print(t_baked_softor1.timeit(repeats))
810 |
811 | t_baked_softor2 = timeit.Timer(
812 | lambda: baked_softor_2(points, sigma, texture_size, 4, device)
813 | )
814 | print(t_baked_softor2.timeit(repeats))
815 |
816 |
817 | if __name__ == "__main__":
818 | # time_it()
819 | test_point_reg(reduce_overlap=True)
820 | # test_point_reg(reduce_overlap=False)
821 | # test_line_reg()
822 | # main()
823 |
--------------------------------------------------------------------------------