├── examples ├── 08_optimization.py ├── 10_pattern_creation.py ├── 07_gradient_accumulation.py ├── 09_point_pattern_optimization.py ├── 11_domain_specific_pattern_optim.py ├── 01_hello_world.py ├── 03_parent_child.py ├── 04_material_randomization.py ├── 05_light_randomization.py ├── 06_animation.py ├── 02_general_transformations.py ├── 06_sampling.py └── vocalfold_scene.py ├── fireflies ├── utils │ ├── __init__.py │ ├── transforms.py │ ├── torch_grads.py │ ├── warnings.py │ ├── intersections.py │ ├── io.py │ ├── math.py │ └── laser_estimation.py ├── graphics │ ├── __init__.py │ ├── depth.py │ └── rasterization.py ├── emitter │ ├── __init__.py │ └── base.py ├── __init__.py ├── material │ ├── __init__.py │ └── base.py ├── projection │ ├── __init__.py │ ├── camera.py │ └── laser.py ├── entity │ ├── __init__.py │ ├── shape.py │ ├── curve.py │ ├── flame.py │ ├── mesh.py │ └── base.py ├── postprocessing │ ├── __init__.py │ ├── base.py │ ├── postprocessor.py │ ├── white_noise.py │ ├── gauss_blur.py │ └── apply_silhouette.py ├── sampling │ ├── __init__.py │ ├── uniform.py │ ├── gaussian_distribution.py │ ├── uniform_integer.py │ ├── uniform_scalar_to_vec3.py │ ├── animation.py │ ├── base.py │ ├── noise_texture_lerp.py │ └── poisson.py └── scene.py ├── assets ├── eval.gif ├── logo.png ├── train.gif └── teaser.png ├── setup.py ├── LICENSE ├── requirements.txt ├── .gitignore ├── README.md └── main.py /examples/08_optimization.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fireflies/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/10_pattern_creation.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /fireflies/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fireflies/utils/transforms.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/07_gradient_accumulation.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/09_point_pattern_optimization.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/11_domain_specific_pattern_optim.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fireflies/emitter/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Light 2 | -------------------------------------------------------------------------------- /fireflies/__init__.py: -------------------------------------------------------------------------------- 1 | from fireflies.scene import Scene 2 | -------------------------------------------------------------------------------- /fireflies/material/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Material 2 | -------------------------------------------------------------------------------- /assets/eval.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/eval.gif 
# ---- /assets/logo.png ----
# https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/logo.png

# ---- /assets/train.gif ----
# https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/train.gif

# ---- /assets/teaser.png ----
# https://raw.githubusercontent.com/Henningson/Fireflies/HEAD/assets/teaser.png

# ---- /fireflies/projection/__init__.py ----
from .camera import Camera
from .laser import Laser

# ---- /fireflies/entity/__init__.py ----
from .base import Transformable
from .curve import Curve
from .mesh import Mesh

# ---- /fireflies/utils/torch_grads.py ----
import torch
from typing import List


def retain_grads(non_leaf_tensor: List[torch.Tensor]) -> None:
    """Call ``retain_grad()`` on every tensor in *non_leaf_tensor*.

    Autograd only populates ``.grad`` on leaf tensors by default; calling
    ``retain_grad()`` makes the gradients of these intermediate tensors
    available after ``backward()``.
    """
    for tensor in non_leaf_tensor:
        tensor.retain_grad()

# ---- /fireflies/postprocessing/__init__.py ----
from .base import BasePostProcessingFunction
from .white_noise import WhiteNoise
from .postprocessor import PostProcessor
from .apply_silhouette import ApplySilhouette
from .gauss_blur import GaussianBlur

# ---- /fireflies/sampling/__init__.py ----
from .base import Sampler
from .gaussian_distribution import GaussianSampler
from .uniform import UniformSampler
from .uniform_integer import UniformIntegerSampler
from .uniform_scalar_to_vec3 import UniformScalarToVec3Sampler
from .animation import AnimationSampler
from .noise_texture_lerp import NoiseTextureLerpSampler

# ---- /fireflies/emitter/base.py ----
import torch
from typing import List

import fireflies.utils.math
import fireflies.entity


class Light(fireflies.entity.Transformable):
    """A named scene light; all behaviour is inherited from ``Transformable``."""

    def __init__(
        self,
        name: str,
        device: torch.cuda.device = torch.device("cuda"),
    ):
        super(Light, self).__init__(name, device)

# ---- /fireflies/postprocessing/base.py ----
import numpy as np
import random
import fireflies.sampling


class BasePostProcessingFunction:
    """Base class for image post-processing steps applied with some probability.

    Subclasses override :meth:`post_process`; callers invoke :meth:`apply`.
    """

    def __init__(self, probability: float):
        # Chance that apply() runs post_process(). random.uniform(0, 1) never
        # reaches 1, so any value >= 1 means "always apply".
        self._probability = probability

    def apply(self, image: np.array) -> np.array:
        """Return ``post_process(image)`` with probability ``probability``, else *image* unchanged."""
        if random.uniform(0, 1) < self._probability:
            return self.post_process(image)

        return image

    def post_process(self, image: np.array) -> np.array:
        """Transform *image*; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # Raise explicitly: using NotImplementedError as a *decorator* would
        # replace this method with a non-callable exception instance.
        raise NotImplementedError(
            f"{type(self).__name__} must implement post_process()"
        )

# ---- /fireflies/postprocessing/postprocessor.py ----
import numpy as np
from .base import BasePostProcessingFunction

from typing import List


class PostProcessor:
    """Apply a chain of post-processing functions to a copy of an image."""

    def __init__(
        self,
        post_process_funcs: List[BasePostProcessingFunction],
    ):
        self._post_process_functs = post_process_funcs

    def post_process(self, image: np.array) -> np.array:
        """Run each function's ``apply`` in order and return the result.

        Works on ``image.copy()``, so the caller's array is never modified.
        """
        result = image.copy()
        for func in self._post_process_functs:
            result = func.apply(result)

        return result

# ---- /fireflies/sampling/uniform.py ----
import torch
import fireflies.sampling.base as base
import fireflies.utils.math


class UniformSampler(base.Sampler):
    """Sampler whose training samples come from ``randomBetweenTensors(min, max)``."""

    def __init__(
        self,
        min: torch.tensor,
        max: torch.tensor,
        eval_step_size: float = 0.01,
        device: torch.cuda.device = torch.device("cuda"),
    ) -> None:
        super(UniformSampler, self).__init__(min, max, eval_step_size, device)

    def sample_train(self) -> torch.tensor:
        return fireflies.utils.math.randomBetweenTensors(
            self._min_range, self._max_range
        )

# ---- /fireflies/postprocessing/white_noise.py ----
import numpy as np
import fireflies.postprocessing.base as base


class WhiteNoise(base.BasePostProcessingFunction):
    """Add per-pixel Gaussian noise N(mean, std), then clip to [0, 1]."""

    def __init__(
        self,
        mean: float,
        std: float,
        probability: float,
    ):
        super(WhiteNoise, self).__init__(probability)
        self._mean = mean
        self._std = std

    def post_process(self, image: np.array) -> np.array:
        """Return a noisy, clipped copy of *image*; the input is not modified.

        Floating-point images keep their dtype; other dtypes are promoted to
        float64 by the addition (an in-place ``+=`` would fail on them).
        """
        noise = np.random.normal(self._mean, self._std, size=image.shape)
        noisy = image + noise
        if np.issubdtype(image.dtype, np.floating):
            # Keep e.g. float32 images in their original precision.
            noisy = noisy.astype(image.dtype, copy=False)
        return np.clip(noisy, 0, 1)

# ---- /fireflies/sampling/gaussian_distribution.py ----
import torch
import fireflies.sampling.base as base


class GaussianSampler(base.Sampler):
    """Sampler whose training samples are drawn from ``torch.normal(mean, std)``.

    ``min``/``max`` are forwarded to ``base.Sampler``; note that
    ``sample_train`` does not clip its draws to that interval.
    """

    def __init__(
        self,
        min: torch.tensor,
        max: torch.tensor,
        mean: torch.tensor,
        std: torch.tensor,
        eval_step_size: float = 0.01,
        device: torch.cuda.device = torch.device("cuda"),
    ) -> None:
        super(GaussianSampler, self).__init__(min, max, eval_step_size, device)
        self._mean = mean
        self._std = std

    def sample_train(self) -> torch.tensor:
        return torch.normal(self._mean, self._std)

# ---- /setup.py ----
from setuptools import setup, find_packages
from os import path


this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

# Resolve requirements.txt next to this file (not the current working
# directory) and drop blank lines, which would otherwise become "" entries.
with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f:
    install_requires = [line.strip() for line in f if line.strip()]

setup(
    name="Fireflies",
    version="1.0",
    description="A module for randomizing mitsuba scenes and their parameters originally created for Structured Light Endoscopy.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Jann-Ole Henningson",
    author_email="jann-ole.henningson@fau.de",
    url="https://github.com/Henningson/Fireflies",
    # Pass ``packages`` exactly once (repeating a keyword is a SyntaxError);
    # find_packages() also picks up sub-packages such as fireflies.sampling.
    packages=find_packages(),
    install_requires=install_requires,
)

# ---- /fireflies/postprocessing/gauss_blur.py ----
import numpy as np
import fireflies.postprocessing.base as base
import kornia
import torch
from typing import Tuple


class GaussianBlur(base.BasePostProcessingFunction):
    """Gaussian-blur an image with kornia, applied with probability *probability*."""

    def __init__(
        self,
        kernel_size: Tuple[int, int],
        sigma: Tuple[float, float],
        probability: float,
    ):
        # typing.Tuple instead of builtin tuple[...] so the annotations also
        # evaluate on Python < 3.9.
        super(GaussianBlur, self).__init__(probability)
        self._kernel_size = kernel_size
        self._sigma = sigma

    def post_process(self, image: np.array) -> np.array:
        """Return a blurred copy of *image*.

        Expects a single-channel 2-D array: it is lifted to (1, 1, H, W) for
        kornia's batched (B, C, H, W) input and taken back out with ``[0, 0]``.
        """
        batched = torch.tensor(image).unsqueeze(0).unsqueeze(0)
        blurred = kornia.filters.gaussian_blur2d(
            batched,
            self._kernel_size,
            self._sigma,
        )
        # Index rather than squeeze(): squeeze() would also drop H or W if 1.
        return blurred[0, 0].numpy()

# ---- /fireflies/entity/shape.py ----
import torch
import numpy as np
from typing import List

from .mesh import Mesh
import fireflies.utils.math


class ShapeModel(Mesh):
    """Work-in-progress mesh driven by shape-model parameters.

    Not exported from ``fireflies.entity``; ``set_model_params`` and
    ``get_vertices`` are still unimplemented and raise NotImplementedError.
    """

    def __init__(
        self,
        name: str,
        vertex_data: torch.tensor,
        face_data: torch.tensor,
        device: torch.cuda.device = torch.device("cuda"),
    ):
        # NOTE(review): assumes Mesh.__init__ accepts
        # (name, vertex_data, face_data, device) — confirm against mesh.py.
        super(ShapeModel, self).__init__(name, vertex_data, face_data, device)
        self._device = device
        self._name = name
        self._model_params = {}

    def load_animation(self):
        return None

    def get_model_params(self) -> dict:
        return self._model_params

    def set_model_params(self, dict: dict) -> None:
        raise NotImplementedError("ShapeModel.set_model_params is not implemented")

    def get_vertices(self):
        raise NotImplementedError("ShapeModel.get_vertices is not implemented")

# ---- /examples/01_hello_world.py ----
import cv2
import mitsuba as mi
import numpy as np

# Selected before importing fireflies, presumably so fireflies sees this variant.
mi.set_variant("cuda_ad_rgb")

import torch
import fireflies


def render_to_opencv(render):
    """Convert a Mitsuba render to an 8-bit BGR array for OpenCV display/saving."""
    render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy()
    return (render * 255).astype(np.uint8)


if __name__ == "__main__":
    path = "scenes/hello_world/hello_world.xml"

    mitsuba_scene = mi.load_file(path)
    mitsuba_params = mi.traverse(mitsuba_scene)
    fireflies_scene = fireflies.Scene(mitsuba_params)

    fireflies_scene.mesh_at(0).rotate_z(-np.pi, np.pi)

    fireflies_scene.train()
    for i in range(100):
        fireflies_scene.randomize()

        render = mi.render(mitsuba_scene, spp=10)
        # Convert once; the frame is both shown and written to disk.
        frame = render_to_opencv(render)

        cv2.imshow("a", frame)
        cv2.imwrite(f"im/{i:05d}.png", frame)
        cv2.waitKey(10)

# ---- /fireflies/sampling/uniform_integer.py ----
import torch
import fireflies.sampling.base as base
import random


class UniformIntegerSampler(base.Sampler):
    def __init__(
        self,
        min_integer: int,
        max_integer: int,
        eval_step_size: int = 1,
        device: torch.cuda.device = torch.device("cuda"),
    ) -> None:
        """
        Will generate samples from the integer interval given by [min_integer, ..., max_integer) similar to how range() is defined in python.

        Training samples are drawn uniformly from that interval. Evaluation
        samples start at ``min_integer``, advance by ``eval_step_size`` and
        wrap back to ``min_integer`` once ``max_integer`` is reached.
        Requires ``min_integer < max_integer``.
        """
        # Forward the interval bounds; passing the builtins min/max here made
        # the base class try (and fail) to wrap a function in a tensor.
        super(UniformIntegerSampler, self).__init__(
            min_integer, max_integer, eval_step_size, device
        )
        # Plain int cursor so sample_eval() always returns ints.
        self._current_step = min_integer

    def sample_eval(self) -> int:
        """Return the current evaluation integer and advance, wrapping at the upper bound."""
        sample = self._current_step
        self._current_step += self._eval_step_size

        if self._current_step >= int(self._max_range):
            self._current_step = int(self._min_range)

        return sample

    def sample_train(self) -> int:
        """Return a random integer from the half-open interval [min, max)."""
        return random.randint(int(self._min_range), int(self._max_range) - 1)

# ---- /LICENSE ----
# MIT License
#
# Copyright (c) 2024 Henningson
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
22 | -------------------------------------------------------------------------------- /examples/03_parent_child.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | import mitsuba as mi 6 | 7 | mi.set_variant("cuda_ad_rgb") 8 | 9 | import fireflies 10 | 11 | 12 | def render_to_opencv(render): 13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 14 | return (render * 255).astype(np.uint8) 15 | 16 | 17 | if __name__ == "__main__": 18 | path = "scenes/parent_child/parent_child.xml" 19 | 20 | mi_scene = mi.load_file(path) 21 | mi_params = mi.traverse(mi_scene) 22 | ff_scene = fireflies.Scene(mi_params) 23 | 24 | cone = ff_scene.mesh("mesh-Cone") 25 | sphere = ff_scene.mesh("mesh-Sphere") 26 | 27 | # Add sphere as the cones parent 28 | cone.setParent(sphere) 29 | 30 | # Also let the cone be randomizable, since it wouldn't be randomized if this is not set. 31 | cone.set_randomizable(True) 32 | 33 | # Rotate everything around the z-axis. 
34 | sphere.rotate_z(-np.pi, np.pi) 35 | 36 | ff_scene.eval() 37 | for i in range(100): 38 | ff_scene.randomize() 39 | 40 | render = mi.render(mi_scene, spp=10) 41 | 42 | cv2.imshow("a", render_to_opencv(render)) 43 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render)) 44 | cv2.waitKey(10) 45 | -------------------------------------------------------------------------------- /fireflies/postprocessing/apply_silhouette.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import fireflies.postprocessing.base as base 4 | import cv2 5 | import random 6 | import kornia 7 | import torch 8 | 9 | 10 | class ApplySilhouette(base.BasePostProcessingFunction): 11 | def __init__( 12 | self, 13 | probability: float = 2.0, 14 | ): 15 | super(ApplySilhouette, self).__init__(probability) 16 | 17 | def post_process(self, image: np.array) -> np.array: 18 | silhouette_image = np.zeros_like(image) 19 | spawning_rect_x = [100, 200] 20 | spawning_rect_y = [200, 300] 21 | radius_interval = [170, 230] 22 | 23 | cc_x = random.randint(spawning_rect_x[0], spawning_rect_x[1]) 24 | cc_y = random.randint(spawning_rect_y[0], spawning_rect_y[1]) 25 | radius = random.randint(radius_interval[0], radius_interval[1]) 26 | silhouette_image = cv2.circle( 27 | silhouette_image, (cc_x, cc_y), radius, color=1, thickness=-1 28 | ) 29 | 30 | silhouette_image = ( 31 | kornia.filters.gaussian_blur2d( 32 | torch.tensor(silhouette_image).unsqueeze(0).unsqueeze(0), 33 | (11, 11), 34 | (5, 5), 35 | ) 36 | .squeeze() 37 | .numpy() 38 | ) 39 | 40 | return image * silhouette_image 41 | -------------------------------------------------------------------------------- /examples/04_material_randomization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | import mitsuba as mi 6 | 7 | mi.set_variant("cuda_ad_rgb") 8 | 9 | import fireflies 10 | 11 | 12 | def 
render_to_opencv(render): 13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 14 | return (render * 255).astype(np.uint8) 15 | 16 | 17 | if __name__ == "__main__": 18 | path = "examples/scenes/hello_world/hello_world.xml" 19 | 20 | mi_scene = mi.load_file(path) 21 | mi_params = mi.traverse(mi_scene) 22 | ff_scene = fireflies.Scene(mi_params) 23 | 24 | # Lets randomize the color of the cube. 25 | min_color = torch.tensor([0.2, 0.3, 0.2], device=ff_scene._device) 26 | max_color = torch.tensor([0.8, 1.0, 0.8], device=ff_scene._device) 27 | 28 | material = ff_scene.material("mat-Material") 29 | material.add_vec3_key("brdf_0.base_color.value", min_color, max_color) 30 | 31 | # Keys for randomizable attributes can be accessed using: 32 | # material.vec3_attributes().keys() 33 | # material.float_attributes().keys() 34 | 35 | ff_scene.train() 36 | for i in range(100): 37 | ff_scene.randomize() 38 | 39 | render = mi.render(mi_scene, spp=10) 40 | 41 | cv2.imshow("a", render_to_opencv(render)) 42 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render)) 43 | cv2.waitKey(10) 44 | -------------------------------------------------------------------------------- /fireflies/sampling/uniform_scalar_to_vec3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fireflies.sampling.base as base 3 | import fireflies.utils.math 4 | 5 | 6 | class UniformScalarToVec3Sampler(base.Sampler): 7 | def __init__( 8 | self, 9 | min: torch.tensor, 10 | max: torch.tensor, 11 | eval_step_size: float = 0.01, 12 | device: torch.cuda.device = torch.device("cuda"), 13 | ) -> None: 14 | super(UniformScalarToVec3Sampler, self).__init__( 15 | min, max, eval_step_size, device 16 | ) 17 | 18 | def sample_train(self) -> torch.tensor: 19 | scalar_value = fireflies.utils.math.randomBetweenTensors( 20 | self._min_range, self._max_range 21 | ) 22 | return torch.tensor( 23 | [scalar_value, scalar_value, scalar_value], 
device=self._device 24 | ) 25 | 26 | def sample_eval(self) -> torch.tensor: 27 | if (self._min_range == self._max_range).all(): 28 | return torch.tensor( 29 | [self._min_range, self._min_range, self._min_range], device=self._device 30 | ) 31 | 32 | sample = self._current_step 33 | self._current_step += self._eval_step_size 34 | 35 | if (self._current_step > self._max_range).any(): 36 | self._current_step = self._min_range 37 | 38 | return torch.tensor([sample, sample, sample], device=self._device) 39 | -------------------------------------------------------------------------------- /examples/05_light_randomization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | import mitsuba as mi 6 | 7 | mi.set_variant("cuda_ad_rgb") 8 | 9 | import fireflies 10 | 11 | 12 | def render_to_opencv(render): 13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 14 | return (render * 255).astype(np.uint8) 15 | 16 | 17 | if __name__ == "__main__": 18 | path = "examples/scenes/parent_child/parent_child.xml" 19 | 20 | mi_scene = mi.load_file(path) 21 | mi_params = mi.traverse(mi_scene) 22 | ff_scene = fireflies.Scene(mi_params) 23 | 24 | cone = ff_scene.mesh("mesh-Cone") 25 | sphere = ff_scene.mesh("mesh-Sphere") 26 | light = ff_scene.light("emit-Light") 27 | 28 | # Add sphere as the cones parent 29 | # Also let the cone be randomizable, since it wouldn't be randomized if this is not set. 30 | cone.setParent(sphere) 31 | cone.set_randomizable(True) 32 | 33 | min_intensity = torch.tensor([150, 0, 0], device=ff_scene._device) 34 | max_intensity = torch.tensor([150, 150, 150], device=ff_scene._device) 35 | light.add_vec3_key("intensity.value", min_intensity, max_intensity) 36 | 37 | # Rotate everything around the z-axis. 
38 | sphere.rotate_z(-np.pi, np.pi) 39 | 40 | ff_scene.eval() 41 | for i in range(100): 42 | ff_scene.randomize() 43 | 44 | render = mi.render(mi_scene, spp=10) 45 | 46 | cv2.imshow("a", render_to_opencv(render)) 47 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render)) 48 | cv2.waitKey(10) 49 | -------------------------------------------------------------------------------- /examples/06_animation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mitsuba as mi 3 | import numpy as np 4 | 5 | mi.set_variant("cuda_ad_rgb") 6 | 7 | import torch 8 | import fireflies 9 | import fireflies.sampling 10 | 11 | 12 | def render_to_opencv(render): 13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 14 | return (render * 255).astype(np.uint8) 15 | 16 | 17 | # Let's define an animation function that will visualize a sine wave. 18 | # You can define completely arbitrary functions. 19 | def animation_function(vertices: torch.tensor, time: float) -> torch.tensor: 20 | # Let's not change the incoming vertices in place. 
21 | vertices_clone = vertices.clone() 22 | 23 | # Change z coordinate of plane via the sin(x_coordinate + time) 24 | wave_direction = 0 25 | 26 | vertices_clone[:, 1] = ( 27 | vertices_clone[:, 1] 28 | + torch.sin(vertices_clone[:, 2] * 10.0 + time * 20.0) / 10.0 29 | ) 30 | 31 | return vertices_clone 32 | 33 | 34 | if __name__ == "__main__": 35 | path = "examples/scenes/animation/animation.xml" 36 | 37 | mitsuba_scene = mi.load_file(path) 38 | mitsuba_params = mi.traverse(mitsuba_scene) 39 | fireflies_scene = fireflies.Scene(mitsuba_params) 40 | 41 | mesh = fireflies_scene.mesh("mesh-Animation") 42 | mesh.add_animation_func( 43 | animation_function, 44 | torch.tensor([0.0]).to(fireflies_scene.device()), 45 | torch.tensor([2 * np.pi]).to(fireflies_scene.device()), 46 | ) 47 | 48 | fireflies_scene.eval() 49 | for i in range(1000): 50 | fireflies_scene.randomize() 51 | render = mi.render(mitsuba_scene, spp=12) 52 | cv2.imshow("Animation Example", render_to_opencv(render)) 53 | cv2.waitKey(25) -------------------------------------------------------------------------------- /examples/02_general_transformations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mitsuba as mi 3 | import numpy as np 4 | 5 | mi.set_variant("cuda_ad_rgb") 6 | 7 | import torch 8 | import fireflies 9 | 10 | 11 | def render_to_opencv(render): 12 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 13 | return (render * 255).astype(np.uint8) 14 | 15 | 16 | if __name__ == "__main__": 17 | path = "scenes/hello_world/hello_world.xml" 18 | 19 | mi_scene = mi.load_file(path) 20 | mi_params = mi.traverse(mi_scene) 21 | ff_scene = fireflies.Scene(mi_params) 22 | 23 | mesh = ff_scene.mesh_at(0) 24 | 25 | # Rotations 26 | mesh.rotate_x(-0.5, 0.5) 27 | mesh.rotate_y(-0.5, 0.5) 28 | mesh.rotate_z(-0.5, 0.5) 29 | mesh.rotate( 30 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device), 31 | torch.tensor([0.5, 0.5, 0.5], 
device=ff_scene._device), 32 | ) 33 | 34 | # Translations 35 | mesh.translate_x(-0.5, 0.5) 36 | mesh.translate_y(-0.5, 0.5) 37 | mesh.translate_z(-0.5, 0.5) 38 | mesh.translate( 39 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device), 40 | torch.tensor([0.5, 0.5, 0.5], device=ff_scene._device), 41 | ) 42 | 43 | # Scale 44 | mesh.scale_x(-0.5, 0.5) 45 | mesh.scale_y(-0.5, 0.5) 46 | mesh.scale_z(-0.5, 0.5) 47 | mesh.scale( 48 | torch.tensor([-0.5, -0.5, -0.5], device=ff_scene._device), 49 | torch.tensor([0.5, 0.5, 0.5], device=ff_scene._device), 50 | ) 51 | 52 | # There's more in later examples :) 53 | 54 | ff_scene.train() 55 | for i in range(100): 56 | ff_scene.randomize() 57 | 58 | render = mi.render(mi_scene, spp=10) 59 | 60 | cv2.imshow("a", render_to_opencv(render)) 61 | cv2.imwrite(f"im/{i:05d}.png", render_to_opencv(render)) 62 | cv2.waitKey(10) 63 | -------------------------------------------------------------------------------- /fireflies/sampling/animation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fireflies.sampling.base as base 3 | import random 4 | 5 | 6 | class AnimationSampler(base.Sampler): 7 | def __init__( 8 | self, 9 | min_integer_train: int, 10 | max_integer_train: int, 11 | min_integer_eval: int, 12 | max_integer_eval: int, 13 | eval_step_size: int = 1, 14 | device: torch.cuda.device = torch.device("cuda"), 15 | ) -> None: 16 | """ 17 | Will generate samples from the integer interval given by [min_integer, ..., max_integer) similar to how range() is defined in python. 18 | Assuming we have a set of train and eval objs that were loaded in the Mesh class using load_train_objs() and load_eval_objs(). 
19 | """ 20 | super(AnimationSampler, self).__init__(min_integer_train, max_integer_train, eval_step_size, device) 21 | self._min_integer_train = min_integer_train 22 | self._max_integer_train = max_integer_train 23 | self._min_integer_eval = min_integer_eval 24 | self._max_integer_eval = max_integer_eval 25 | self._current_step = min_integer_eval 26 | 27 | def sample_eval(self) -> int: 28 | sample = self._current_step 29 | self._current_step += self._eval_step_size 30 | 31 | if self._current_step > self._max_integer_eval: 32 | self._current_step = self._min_integer_eval 33 | 34 | return sample 35 | 36 | def sample_train(self) -> int: 37 | return random.randint(self._min_integer_train, self._max_integer_train - 1) 38 | 39 | def set_train_interval(self, min_integer_train: int, max_integer_train: int) -> None: 40 | self._min_integer_train = min_integer_train 41 | self._max_integer_train = max_integer_train 42 | 43 | def set_eval_interval(self, min_integer_eval: int, max_integer_eval: int) -> None: 44 | self._min_integer_eval = min_integer_eval 45 | self._max_integer_eval = max_integer_eval -------------------------------------------------------------------------------- /examples/06_sampling.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mitsuba as mi 3 | import numpy as np 4 | 5 | mi.set_variant("cuda_ad_rgb") 6 | 7 | import torch 8 | import fireflies 9 | import fireflies.sampling 10 | 11 | 12 | def render_to_opencv(render): 13 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 14 | return (render * 255).astype(np.uint8) 15 | 16 | 17 | # Let's define an animation function that will visualize a sine wave. 18 | # You can define completely arbitrary functions. 19 | def animation_function(vertices: torch.tensor, time: float) -> torch.tensor: 20 | # Let's not change the incoming vertices in place. 
21 | vertices_clone = vertices.clone() 22 | 23 | # Change z coordinate of plane via the sin(x_coordinate + time) 24 | wave_direction = 0 25 | 26 | vertices_clone[:, 1] = ( 27 | vertices_clone[:, 1] 28 | + torch.sin(vertices_clone[:, 2] * 10.0 + time * 20.0) / 10.0 29 | ) 30 | 31 | return vertices_clone 32 | 33 | 34 | if __name__ == "__main__": 35 | path = "examples/scenes/animation/animation.xml" 36 | 37 | mitsuba_scene = mi.load_file(path) 38 | mitsuba_params = mi.traverse(mitsuba_scene) 39 | ff_scene = fireflies.Scene(mitsuba_params) 40 | 41 | mesh = ff_scene.mesh("mesh-Animation") 42 | mesh.add_animation_func( 43 | animation_function, 44 | torch.tensor([0.0]).to(ff_scene.device()), 45 | torch.tensor([2 * np.pi]).to(ff_scene.device()), 46 | ) 47 | 48 | normal_distribution_sampler = fireflies.sampling.GaussianSampler( 49 | min=torch.ones(3, device=ff_scene.device())*0.5, 50 | max=torch.ones(3, device=ff_scene.device())*1.5, 51 | mean=torch.ones(3, device=ff_scene.device())*1.0, 52 | std=torch.ones(3, device=ff_scene.device())*0.5, 53 | eval_step_size=0.01, 54 | ) 55 | mesh.set_scale_sampler(normal_distribution_sampler) 56 | 57 | ff_scene.train() 58 | for i in range(1000): 59 | ff_scene.randomize() 60 | render = mi.render(mitsuba_scene, spp=12) 61 | cv2.imwrite("a.png", render_to_opencv(render)) 62 | cv2.waitKey(25) -------------------------------------------------------------------------------- /fireflies/projection/camera.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fireflies.utils.math 3 | import fireflies.utils.transforms 4 | 5 | import fireflies.entity.base 6 | 7 | 8 | class Camera: 9 | id = 0 10 | MITSUBA_KEYS = { 11 | "fov": "x_fov", 12 | "f": "x_fov", 13 | "to_world": "to_world", 14 | "world": "to_world", 15 | } 16 | 17 | def __init__( 18 | self, 19 | transform: fireflies.entity.base.Transformable, 20 | perspective: torch.tensor, 21 | fov: float, 22 | near_clip: float = 0.01, 23 | far_clip: 
float = 1000.0, 24 | device: torch.cuda.device = torch.device("cuda"), 25 | ): 26 | self.device = device 27 | 28 | self._transformable = transform 29 | self._perspective = perspective 30 | self._near_clip = near_clip 31 | self._far_clip = far_clip 32 | self._fov = fov 33 | 34 | self._key = self.generate_mitsuba_key() 35 | Camera.id += 1 36 | 37 | def full_key(self, key: str): 38 | return self._key + "." + Camera.MITSUBA_KEYS[key] 39 | 40 | def key(self) -> str: 41 | return self._key 42 | 43 | def near_clip(self) -> float: 44 | return self._near_clip 45 | 46 | def generate_mitsuba_key(self) -> str: 47 | if Camera.id == 0: 48 | return "PerspectiveCamera" 49 | 50 | return "PerspectiveCamera_{0}".format(id) 51 | 52 | def far_clip(self) -> float: 53 | return self._far_clip 54 | 55 | def fov(self) -> torch.tensor: 56 | return self._fov 57 | 58 | def origin(self) -> torch.tensor: 59 | return self._transformable.origin() 60 | 61 | def world(self) -> torch.tensor: 62 | return self._transformable.world() 63 | 64 | def randomize(self) -> None: 65 | self._transformable.randomize() 66 | 67 | def pointsToNDC(self, points) -> torch.tensor: 68 | view_space_points = fireflies.utils.transforms.transform_points( 69 | points, self.world().inverse() 70 | ) 71 | ndc_points = fireflies.utils.transforms.transform_points( 72 | view_space_points, self._perspective 73 | ) 74 | return ndc_points 75 | -------------------------------------------------------------------------------- /fireflies/sampling/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler: 5 | def __init__( 6 | self, 7 | min: torch.tensor, 8 | max: torch.tensor, 9 | eval_step_size: float = 0.01, 10 | device: torch.cuda.device = torch.device("cuda"), 11 | ) -> None: 12 | self._device = device 13 | self._min_range = ( 14 | min.clone() 15 | if type(min) is torch.Tensor 16 | else torch.tensor([min], device=device) 17 | ) 18 | self._max_range = ( 19 | 
max.clone() 20 | if type(max) is torch.Tensor 21 | else torch.tensor([max], device=device) 22 | ) 23 | self._train = True 24 | 25 | self._eval_step_size = eval_step_size 26 | self._current_step = ( 27 | self._min_range.clone() 28 | if type(min) is torch.Tensor 29 | else torch.tensor([min], device=device) 30 | ) 31 | 32 | def set_sample_interval(self, min: torch.tensor, max: torch.tensor) -> None: 33 | self._min_range = min.clone() 34 | self._max_range = max.clone() 35 | 36 | def get_min(self) -> torch.tensor: 37 | return self._min_range 38 | 39 | def get_max(self) -> torch.tensor: 40 | return self._max_range 41 | 42 | def set_sample_max(self, max: torch.tensor) -> None: 43 | self._max_range = max.clone() 44 | 45 | def set_sample_min(self, min: torch.tensor) -> None: 46 | self._min_range = min.clone() 47 | 48 | def train(self) -> None: 49 | self._train = True 50 | 51 | def eval(self) -> None: 52 | self._train = False 53 | 54 | def sample(self) -> torch.tensor: 55 | if self._train: 56 | return self.sample_train() 57 | else: 58 | return self.sample_eval() 59 | 60 | @NotImplementedError 61 | def sample_train(self) -> torch.tensor: 62 | return None 63 | 64 | def sample_eval(self) -> torch.tensor: 65 | if (self._min_range == self._max_range).all(): 66 | return self._min_range 67 | 68 | sample = self._current_step 69 | self._current_step += self._eval_step_size 70 | 71 | if (self._current_step > self._max_range).any(): 72 | self._current_step = self._min_range 73 | 74 | return sample 75 | -------------------------------------------------------------------------------- /fireflies/utils/warnings.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import functools 3 | 4 | 5 | def RotationAssignmentWarning(func): 6 | @functools.wraps(func) 7 | def new_func(*args, **kwargs): 8 | warnings.simplefilter("always", Warning) # turn off filter 9 | warnings.warn( 10 | "This object should generally not have a transformation 
assignment via {}.".format( 11 | func.__name__ 12 | ), 13 | category=Warning, 14 | stacklevel=2, 15 | ) 16 | warnings.simplefilter("default", Warning) # reset filter 17 | return func(*args, **kwargs) 18 | 19 | return new_func 20 | 21 | 22 | def RelativeAssignmentWarning(func): 23 | @functools.wraps(func) 24 | def new_func(*args, **kwargs): 25 | warnings.simplefilter("always", Warning) # turn off filter 26 | warnings.warn( 27 | "This object should generally not have a parent/child assignment via {}.".format( 28 | func.__name__ 29 | ), 30 | category=Warning, 31 | stacklevel=2, 32 | ) 33 | warnings.simplefilter("default", Warning) # reset filter 34 | return func(*args, **kwargs) 35 | 36 | return new_func 37 | 38 | 39 | def TranslationAssignmentWarning(func): 40 | @functools.wraps(func) 41 | def new_func(*args, **kwargs): 42 | warnings.simplefilter("always", Warning) # turn off filter 43 | warnings.warn( 44 | "This object should generally not have a translation assignment via {}.".format( 45 | func.__name__ 46 | ), 47 | category=Warning, 48 | stacklevel=2, 49 | ) 50 | warnings.simplefilter("default", Warning) # reset filter 51 | return new_func(*args, **kwargs) 52 | 53 | 54 | def WorldAssignmentWarning(func): 55 | @functools.wraps(func) 56 | def new_func(*args, **kwargs): 57 | warnings.simplefilter("always", Warning) # turn off filter 58 | warnings.warn( 59 | "This object should generally not have a to-world matrix via {}.".format( 60 | func.__name__ 61 | ), 62 | category=Warning, 63 | stacklevel=2, 64 | ) 65 | warnings.simplefilter("default", Warning) # reset filter 66 | return new_func(*args, **kwargs) 67 | -------------------------------------------------------------------------------- /fireflies/utils/intersections.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Batchwise ray-plane intersection 5 | def rayPlane(laserOrigin, laserDirection, planeOrigin, planeNormal): 6 | denom = 
torch.sum(planeNormal * laserDirection, axis=1) 7 | 8 | denom = torch.where(torch.abs(denom) < 0.000001, denom / denom, denom) 9 | t = torch.sum((planeOrigin - laserOrigin) * planeNormal, axis=1) / denom 10 | 11 | return t[:, None] 12 | 13 | 14 | # Batchwise Sphere-Sphere intersection: 15 | # Inputs: 16 | # - a_coords: Tensor of size NxD where N is the number of circles, and D the number of dimensions 17 | # - a_radii: radii of the spheres in a_xy 18 | # - b_coords: Tensor of size NxD where N is the number of circles, and D the number of dimensions 19 | # - b_radii: radii of the spheres in a_xy 20 | # Outputs: 21 | # - Tensor of size Nx1, with boolean values being 22 | # - TRUE when a hit occured 23 | # - FALSE otherwise 24 | 25 | 26 | def sphereSphere(a_coords, a_radius, b_coords, b_radius): 27 | dist_ab = a_coords - b_coords 28 | squared_dist = dist_ab.pow(2).sum(dim=1, keepdim=True) 29 | 30 | sum_radii = a_radius + b_radius 31 | squared_radii = sum_radii.pow(2) 32 | 33 | return squared_dist <= squared_radii 34 | 35 | 36 | if __name__ == "__main__": 37 | import numpy as np 38 | import cv2 39 | 40 | # Here we create random circles and intersect them 41 | # They're rendered green, if they do not intersect 42 | # And red, if they intersect 43 | for _ in range(1000): 44 | a_coords = torch.rand(1, 2) 45 | a_radius = torch.rand(1, 1) / 2.0 46 | b_coords = torch.rand(1, 2) 47 | b_radius = torch.rand(1, 1) / 2.0 48 | 49 | intersections = sphereSphere(a_coords, a_radius, b_coords, b_radius) 50 | 51 | image = np.zeros((512, 512, 3), dtype=np.uint8) 52 | for i in range(a_coords.shape[0]): 53 | cv2.circle( 54 | image, 55 | (a_coords[i].detach().cpu().numpy() * 512).astype(np.int), 56 | (a_radius[i, 0].detach().cpu().numpy() * 512).astype(np.int), 57 | color=(0, 0, 255) if intersections[i, 0] else (0, 255, 0), 58 | ) 59 | cv2.circle( 60 | image, 61 | (b_coords[i].detach().cpu().numpy() * 512).astype(np.int), 62 | (b_radius[i, 0].detach().cpu().numpy() * 512).astype(np.int), 
63 | color=(0, 0, 255) if intersections[i, 0] else (0, 255, 0), 64 | ) 65 | 66 | cv2.imshow("Test", image) 67 | cv2.waitKey(0) 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.12.3 2 | certifi==2024.2.2 3 | charset-normalizer==3.3.2 4 | chumpy==0.70 5 | cmake==3.26.4 6 | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work 7 | contourpy==1.1.0 8 | cycler==0.11.0 9 | dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work 10 | drjit==0.4.4 11 | filelock==3.12.2 12 | fonttools==4.40.0 13 | freetype-py==2.4.0 14 | fsspec==2024.2.0 15 | fvcore @ file:///home/conda/feedstock_root/build_artifacts/fvcore_1671623667463/work 16 | geomdl==5.3.1 17 | idna==3.6 18 | imageio==2.31.1 19 | iopath @ file:///home/conda/feedstock_root/build_artifacts/iopath_1636568816351/work 20 | Jinja2==3.1.2 21 | kiwisolver==1.4.4 22 | kornia==0.7.1 23 | lazy_loader==0.3 24 | lightning-utilities==0.10.1 25 | lit==16.0.6 26 | lxml==5.1.0 27 | MarkupSafe==2.1.3 28 | matplotlib==3.7.2 29 | mitsuba==3.5.0 30 | mpmath==1.3.0 31 | networkx==3.1 32 | numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1651020388975/work 33 | nvidia-cublas-cu11==11.11.3.6 34 | nvidia-cuda-cupti-cu11==11.8.87 35 | nvidia-cuda-nvrtc-cu11==11.8.89 36 | nvidia-cuda-runtime-cu11==11.8.89 37 | nvidia-cudnn-cu11==8.7.0.84 38 | nvidia-cufft-cu11==10.9.0.58 39 | nvidia-curand-cu11==10.3.0.86 40 | nvidia-cusolver-cu11==11.4.1.48 41 | nvidia-cusparse-cu11==11.7.5.86 42 | nvidia-nccl-cu11==2.20.5 43 | nvidia-nvtx-cu11==11.8.86 44 | opencv-python==4.9.0.80 45 | packaging==23.1 46 | Pillow==10.0.0 47 | plotly==5.19.0 48 | portalocker @ file:///home/conda/feedstock_root/build_artifacts/portalocker_1695662047585/work 49 | pyglet==2.0.12 50 | PyOpenGL==3.1.0 51 | pyparsing==3.0.9 52 | 
pyrender==0.1.45 53 | python-dateutil==2.8.2 54 | pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@f34104cf6ebefacd7b7e07955ee7aaa823e616ac 55 | PyWavefront==1.3.3 56 | PyYAML==6.0 57 | rdfpy==1.0.0 58 | requests==2.31.0 59 | scikit-image==0.22.0 60 | scipy==1.11.0 61 | six==1.16.0 62 | smplx==0.1.28 63 | soupsieve==2.5 64 | sympy==1.12 65 | tabulate @ file:///home/conda/feedstock_root/build_artifacts/tabulate_1665138452165/work 66 | tenacity==8.2.3 67 | termcolor @ file:///home/conda/feedstock_root/build_artifacts/termcolor_1704357939450/work 68 | tifffile==2024.2.12 69 | torch==2.3.0+cu118 70 | torchaudio==2.3.0+cu118 71 | torchmetrics==1.3.1 72 | torchvision==0.18.0+cu118 73 | tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1707598593068/work 74 | trimesh==4.1.7 75 | triton==2.3.0 76 | typing_extensions==4.9.0 77 | urllib3==2.2.1 78 | yacs @ file:///home/conda/feedstock_root/build_artifacts/yacs_1645705974477/work 79 | -------------------------------------------------------------------------------- /fireflies/entity/curve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | from geomdl import NURBS 5 | from typing import List 6 | 7 | import fireflies.entity.base as base 8 | import fireflies.utils.math 9 | 10 | 11 | class Curve(base.Transformable): 12 | count = 0.0 13 | 14 | def fromObj(path): 15 | # TODO: Implement me 16 | pass 17 | 18 | def __init__( 19 | self, 20 | name: str, 21 | curve: NURBS.Curve, 22 | device: torch.cuda.device = torch.device("cuda"), 23 | ): 24 | super(Curve, self).__init__(self, name, device) 25 | 26 | self._curve = curve 27 | self.curve_epsilon = 0.05 28 | 29 | self.curve_delta = self.curve_epsilon 30 | 31 | self._interp_steps = 1000 32 | self._interp_delta = 1.0 / self._interp_steps 33 | 34 | self.eval_interval_start = 0.05 35 | 36 | def train(self) -> None: 37 | self._train = True 38 | self._continuous = False 39 | 40 | def 
eval(self) -> None: 41 | self._train = False 42 | self._continuous = True 43 | self._curve_delta = self.eval_interval_start 44 | 45 | def setContinuous(self, continuous: bool) -> None: 46 | self._continuous = continuous 47 | 48 | def sample_rotation(self) -> torch.tensor: 49 | t = self.curve_delta 50 | t_new = self.curve_delta + 0.001 51 | 52 | t_new = torch.tensor(self._curve.evaluate_single(t_new), device=self._device) 53 | t = torch.tensor(self._curve.evaluate_single(t), device=self._device) 54 | 55 | curve_direction = t_new - t 56 | curve_direction[0] *= -1.0 57 | curve_direction[2] *= -1.0 58 | 59 | # curve_normal = torch.tensor(self._curve.normal(t), device=self._device) 60 | # curve_direction /= torch.linalg.norm(curve_direction) 61 | # curve_normal /= torch.linalg.norm(curve_normal) 62 | 63 | # camera_up_vector = torch.tensor([0, 0, 1], device=self._device) 64 | 65 | camera_direction = torch.tensor([0.0, 1.0, 0.0], device=self._device) 66 | return fireflies.utils.transforms.toMat4x4( 67 | fireflies.utils.math.rotation_matrix_from_vectors( 68 | camera_direction, curve_direction 69 | ) 70 | ) 71 | 72 | def sample_translation(self) -> torch.tensor: 73 | translationMatrix = torch.eye(4, device=self._device) 74 | translation = self._curve.evaluate_single(self.curve_delta) 75 | 76 | translationMatrix[0, 3] = translation[0] 77 | translationMatrix[1, 3] = translation[1] 78 | translationMatrix[2, 3] = translation[2] 79 | 80 | return translationMatrix 81 | 82 | def randomize(self) -> None: 83 | if self._train: 84 | self.curve_delta = random.uniform( 85 | 0 + self.curve_epsilon, self.eval_interval_start 86 | ) 87 | else: 88 | self.curve_delta += self._interp_delta 89 | 90 | if self.curve_delta > 1.0 - self.curve_epsilon: 91 | self.curve_delta = self.eval_interval_start 92 | 93 | self._randomized_world = ( 94 | self.sample_translation() @ self.sample_rotation() @ self._world 95 | ) 96 | -------------------------------------------------------------------------------- 
/fireflies/material/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | 4 | import fireflies.utils.math 5 | import fireflies.entity 6 | from fireflies.utils.warnings import ( 7 | RotationAssignmentWarning, 8 | RelativeAssignmentWarning, 9 | TranslationAssignmentWarning, 10 | WorldAssignmentWarning, 11 | ) 12 | 13 | 14 | class Material(fireflies.entity.Transformable): 15 | def __init__( 16 | self, 17 | name: str, 18 | device: torch.cuda.device = torch.device("cuda"), 19 | ): 20 | super(Material, self).__init__(name, device) 21 | 22 | def randomize(self) -> None: 23 | for key, sampler in self._float_attributes.items(): 24 | self._randomized_float_attributes[key] = sampler.sample() 25 | 26 | for key, sampler in self._vec3_attributes.items(): 27 | self._randomized_vec3_attributes[key] = sampler.sample() 28 | 29 | @WorldAssignmentWarning 30 | def set_world(self, _origin: torch.tensor) -> None: 31 | super(Material, self).set_world(_origin) 32 | 33 | @RelativeAssignmentWarning 34 | def setParent(self, parent) -> None: 35 | super(Material, self).setParent(parent) 36 | 37 | @RelativeAssignmentWarning 38 | def setChild(self, child) -> None: 39 | super(Material, self).setChild(child) 40 | 41 | @RotationAssignmentWarning 42 | def rotate_x(self, min_rot: float, max_rot: float) -> None: 43 | super(Material, self).rotate_x(min_rot, max_rot) 44 | 45 | @RotationAssignmentWarning 46 | def rotate_y(self, min_rot: float, max_rot: float) -> None: 47 | super(Material, self).rotate_y(min_rot, max_rot) 48 | 49 | @RotationAssignmentWarning 50 | def rotate_z(self, min_rot: float, max_rot: float) -> None: 51 | super(Material, self).rotate_z(min_rot, max_rot) 52 | 53 | @RotationAssignmentWarning 54 | def rotate(self, min: torch.tensor, max: torch.tensor) -> None: 55 | super(Material, self).rotate(min, max) 56 | 57 | @TranslationAssignmentWarning 58 | def translate_x(self, min_translation: float, 
max_translation: float) -> None: 59 | super(Material, self).translate_x(min_translation, max_translation) 60 | 61 | @TranslationAssignmentWarning 62 | def translate_y(self, min_translation: float, max_translation: float) -> None: 63 | super(Material, self).translate_y(min_translation, max_translation) 64 | 65 | @TranslationAssignmentWarning 66 | def translate_z(self, min_translation: float, max_translation: float) -> None: 67 | super(Material, self).translate_z(min_translation, max_translation) 68 | 69 | @TranslationAssignmentWarning 70 | def translate(self, min: torch.tensor, max: torch.tensor) -> None: 71 | super(Material, self).translate(min, max) 72 | 73 | @RotationAssignmentWarning 74 | def sample_rotation(self) -> torch.tensor: 75 | return super(Material, self).sample_rotation() 76 | 77 | @TranslationAssignmentWarning 78 | def sample_translation(self) -> torch.tensor: 79 | return super(Material, self).sample_translation() 80 | 81 | @RelativeAssignmentWarning 82 | def relative(self) -> None: 83 | return super(Material, self).relative() 84 | 85 | @WorldAssignmentWarning 86 | def world(self) -> torch.tensor: 87 | return super(Material, self).world() 88 | 89 | @WorldAssignmentWarning 90 | def nonRandomizedWorld(self) -> torch.tensor: 91 | return super(Material, self).nonRandomizedWorld() 92 | -------------------------------------------------------------------------------- /fireflies/sampling/noise_texture_lerp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | from typing import List 5 | import fireflies.sampling.base as base 6 | 7 | 8 | def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): 9 | delta = (res[0] / shape[0], res[1] / shape[1]) 10 | d = (shape[0] // res[0], shape[1] // res[1]) 11 | 12 | grid = ( 13 | torch.stack( 14 | torch.meshgrid( 15 | torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]) 16 | ), 17 | dim=-1, 18 | ) 19 
| % 1 20 | ) 21 | angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1) 22 | gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) 23 | 24 | tile_grads = ( 25 | lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]] 26 | .repeat_interleave(d[0], 0) 27 | .repeat_interleave(d[1], 1) 28 | ) 29 | dot = lambda grad, shift: ( 30 | torch.stack( 31 | ( 32 | grid[: shape[0], : shape[1], 0] + shift[0], 33 | grid[: shape[0], : shape[1], 1] + shift[1], 34 | ), 35 | dim=-1, 36 | ) 37 | * grad[: shape[0], : shape[1]] 38 | ).sum(dim=-1) 39 | 40 | n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]) 41 | n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]) 42 | n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]) 43 | n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]) 44 | t = fade(grid[: shape[0], : shape[1]]) 45 | return math.sqrt(2) * torch.lerp( 46 | torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1] 47 | ) 48 | 49 | 50 | def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5): 51 | noise = torch.zeros(shape) 52 | frequency = 1 53 | amplitude = 1 54 | for _ in range(octaves): 55 | noise += amplitude * rand_perlin_2d( 56 | shape, (frequency * res[0], frequency * res[1]) 57 | ) 58 | frequency *= 2 59 | amplitude *= persistence 60 | return noise 61 | 62 | 63 | class NoiseTextureLerpSampler(base.Sampler): 64 | def __init__( 65 | self, 66 | color_a: torch.tensor, 67 | color_b: torch.tensor, 68 | texture_shape: List[int], 69 | eval_step_size: float = 0.01, 70 | device: torch.cuda.device = torch.device("cuda"), 71 | ) -> None: 72 | super(NoiseTextureLerpSampler, self).__init__( 73 | torch.tensor([0.0], device=device), 74 | torch.tensor([1.0], device=device), 75 | eval_step_size, 76 | device, 77 | ) 78 | self._color_a = color_a 79 | self._color_b = color_b 80 | self._texture_shape = texture_shape 81 | 82 | def sample_train(self) -> torch.tensor: 83 | i = 2 ** random.randint(1, 6) 84 | octaves = 
random.randint(1, 4) 85 | persistence = random.uniform(0.1, 2.0) 86 | tex = rand_perlin_2d_octaves( 87 | self._texture_shape, res=(i, i), octaves=octaves, persistence=persistence 88 | ).to(self._device) 89 | tex = (tex - tex.min()) / (tex.max() - tex.min()) 90 | 91 | col_a = torch.ones_like(tex).unsqueeze(0).repeat( 92 | 3, 1, 1 93 | ) * self._color_a.unsqueeze(-1).unsqueeze(-1) 94 | col_b = torch.ones_like(tex).unsqueeze(0).repeat( 95 | 3, 1, 1 96 | ) * self._color_b.unsqueeze(-1).unsqueeze(-1) 97 | tex = tex.unsqueeze(0).repeat(3, 1, 1) 98 | return torch.lerp(col_a, col_b, tex) 99 | 100 | # To lazy to implement it right now. 101 | def sample_eval(self) -> torch.tensor: 102 | return self.sample_train() 103 | -------------------------------------------------------------------------------- /fireflies/utils/io.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import math 4 | 5 | from geomdl import NURBS 6 | from pathlib import Path 7 | 8 | 9 | def read_config_yaml(file_path: str) -> dict: 10 | return yaml.safe_load(Path(file_path).read_text()) 11 | 12 | 13 | # From: https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/renderer/cameras.html#FoVPerspectiveCameras.compute_projection_matrix 14 | def build_projection_matrix( 15 | fov: float, 16 | near_clip: float, 17 | far_clip: float, 18 | device: torch.cuda.device = torch.device("cuda"), 19 | ) -> torch.tensor: 20 | """ 21 | Compute the calibration matrix K of shape (N, 4, 4) 22 | 23 | Args: 24 | znear: near clipping plane of the view frustrum. 25 | zfar: far clipping plane of the view frustrum. 26 | fov: field of view angle of the camera. 27 | aspect_ratio: aspect ratio of the image pixels. 28 | 1.0 indicates square pixels. 29 | degrees: bool, set to True if fov is specified in degrees. 
30 | 31 | Returns: 32 | torch.FloatTensor of the calibration matrix with shape (N, 4, 4) 33 | """ 34 | K = torch.zeros((4, 4), dtype=torch.float32, device=device) 35 | fov = (math.pi / 180) * fov 36 | 37 | if not torch.is_tensor(fov): 38 | fov = torch.tensor(fov, device=device) 39 | 40 | tanHalfFov = torch.tan((fov / 2.0)) 41 | max_y = tanHalfFov * near_clip 42 | min_y = -max_y 43 | max_x = max_y * 1.0 44 | min_x = -max_x 45 | 46 | # NOTE: In OpenGL the projection matrix changes the handedness of the 47 | # coordinate frame. i.e the NDC space positive z direction is the 48 | # camera space negative z direction. This is because the sign of the z 49 | # in the projection matrix is set to -1.0. 50 | # In pytorch3d we maintain a right handed coordinate system throughout 51 | # so the so the z sign is 1.0. 52 | z_sign = -1.0 53 | 54 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 55 | K[0, 0] = 2.0 * near_clip / (max_x - min_x) 56 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 57 | K[1, 1] = 2.0 * near_clip / (max_y - min_y) 58 | K[0, 2] = (max_x + min_x) / (max_x - min_x) 59 | K[1, 2] = (max_y + min_y) / (max_y - min_y) 60 | K[3, 2] = z_sign 61 | 62 | # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point 63 | # is at the near clipping plane and z = 1 when the point is at the far 64 | # clipping plane. 
65 | K[2, 2] = z_sign * far_clip / (far_clip - near_clip) 66 | K[2, 3] = -(far_clip * near_clip) / (far_clip - near_clip) 67 | 68 | return K 69 | 70 | 71 | # def tensorToFloatImage(tensor: torch.tensor) -> np.array: 72 | # return 73 | 74 | 75 | def importBlenderNurbsObj(path): 76 | obj_file = open(path, "r") 77 | lines = obj_file.readlines() 78 | 79 | control_points = [] 80 | deg = None 81 | knotvector = None 82 | 83 | for line in lines: 84 | token = "v " 85 | if line.startswith("v "): 86 | line = line.replace(token, "") 87 | values = line.split(" ") 88 | values = [float(value) for value in values] 89 | control_points.append(values) 90 | continue 91 | 92 | token = "deg " 93 | if line.startswith(token): 94 | line = line.replace(token, "") 95 | deg = int(line) 96 | continue 97 | 98 | token = "parm u " 99 | if line.startswith(token): 100 | line = line.replace(token, "") 101 | values = line.split(" ") 102 | knotvector = [float(value) for value in values] 103 | continue 104 | 105 | spline = NURBS.Curve() 106 | spline.degree = deg 107 | spline.ctrlpts = control_points 108 | spline.knotvector = knotvector 109 | 110 | return spline 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | 163 | 164 | 165 | 166 | # Stuff 167 | scenes* 168 | *.png 169 | *.mp4 170 | *.yaml 171 | *.obj 172 | *.jpg 173 | *.gif 174 | Old* 175 | .vscode* -------------------------------------------------------------------------------- /examples/vocalfold_scene.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mitsuba as mi 3 | import numpy as np 4 | import kornia 5 | 6 | mi.set_variant("cuda_ad_rgb") 7 | 8 | import torch 9 | import fireflies 10 | import fireflies.sampling 11 | import fireflies.projection.laser 12 | 13 | 14 | def render_to_opencv(render): 15 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 16 | return (render * 255).astype(np.uint8) 17 | 18 | 19 | if __name__ == "__main__": 20 | path = "examples/scenes/vocalfold/vocalfold.xml" 21 | 22 | mitsuba_scene = mi.load_file(path) 23 | mitsuba_params = mi.traverse(mitsuba_scene) 24 | ff_scene = fireflies.Scene(mitsuba_params) 25 | 26 | projector_sensor = mitsuba_scene.sensors()[1] 27 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"] 28 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"] 29 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"] 30 | 31 | K_PROJECTOR = mi.perspective_projection( 32 | projector_sensor.film().size(), 33 | projector_sensor.film().crop_size(), 34 | projector_sensor.film().crop_offset(), 35 | x_fov, 36 | near_clip, 37 | far_clip, 38 | ).matrix.torch()[0] 39 | 40 | # laser_rays = fireflies.projection.Laser.generate_uniform_rays( 41 | # 0.0275, 18, 18, device=ff_scene.device() 42 | # ) 43 | laser_rays = fireflies.projection.Laser.generate_blue_noise_rays( 44 | 500, 500, 18 * 18, K_PROJECTOR, device=ff_scene.device() 45 | ) 46 | 47 | laser = fireflies.projection.Laser( 48 | ff_scene._projector, 49 | laser_rays, 50 | K_PROJECTOR, 51 | x_fov, 52 | near_clip, 53 | far_clip, 54 | device=ff_scene.device(), 55 | ) 56 | texture = laser.generateTexture( 57 | 10.0, 
torch.tensor([500, 500], device=ff_scene.device()) 58 | ) 59 | texture = texture.sum(dim=0) 60 | 61 | texture = kornia.filters.gaussian_blur2d( 62 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3) 63 | ).squeeze() 64 | texture = torch.stack( 65 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)] 66 | ) 67 | texture = torch.movedim(texture, 0, -1) 68 | 69 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy()) 70 | 71 | vocalfold_mesh = ff_scene.mesh("mesh-VocalFold") 72 | larynx_mesh = ff_scene.mesh("mesh-Larynx") 73 | larynx_mesh.scale_x(0.8, 1.2) 74 | larynx_mesh.rotate_y(-0.1, 0.1) 75 | 76 | vocalfold_mesh.scale_x(0.5, 2.0) 77 | vocalfold_mesh.rotate_y(-0.25, 0.25) 78 | 79 | material = ff_scene.material("mat-Default OBJ") 80 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler( 81 | 1.0, 20.0, device=ff_scene.device() 82 | ) 83 | 84 | light = ff_scene.light("emit-Spot") 85 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler) 86 | 87 | material.add_vec3_key( 88 | "brdf_0.base_color.value", 89 | torch.tensor([0.8, 0.14, 0.34], device=ff_scene.device()), 90 | torch.tensor([0.85, 0.5, 0.44], device=ff_scene.device()), 91 | ) 92 | material.add_float_key("brdf_0.specular", 0.0, 0.75) 93 | 94 | # ff_scene._camera.rotate_y(-0.25, 0.25) 95 | # ff_scene._projector.setParent(ff_scene._camera) 96 | # ff_scene._projector._randomizable = True 97 | # ff_scene._projector.rotate_x(3.141, 3.141) 98 | 99 | ff_scene.train() 100 | for i in range(1000): 101 | ff_scene.randomize() 102 | render = mi.render(mitsuba_scene, spp=100) 103 | render = render_to_opencv(render) 104 | 105 | if i % 2 == 0: 106 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY).astype(int) 107 | noise = np.random.normal(np.zeros_like(render), np.ones_like(render) * 0.05) 108 | noise *= 255 109 | noise = noise.astype(np.int) 110 | render += noise 111 | render[render > 255] = 255 112 | render[render < 0] = 0 113 | render = render.astype(np.uint8) 114 | 
115 | cv2.imwrite("vf_renderings/{0:05d}.png".format(i), render) 116 | # cv2.imshow("A", render) 117 | # cv2.waitKey(0) 118 | -------------------------------------------------------------------------------- /fireflies/entity/flame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import flame_pytorch.flame as flame 4 | from typing import List 5 | 6 | import fireflies.entity.base as base 7 | import shape 8 | 9 | import fireflies.utils.math 10 | 11 | 12 | class FlameShapeModel(shape.ShapeModel): 13 | def __init__( 14 | self, 15 | name: str, 16 | vertex_data: List[float], 17 | config: dict, 18 | device: torch.cuda.device = torch.device("cuda"), 19 | base_path: str = None, 20 | sequential_animation: bool = False, 21 | ): 22 | 23 | self._device = device 24 | self._name = name 25 | 26 | self.setTranslationBoundaries(config["translation"]) 27 | self.setRotationBoundaries(config["rotation"]) 28 | self.setWorld(config["to_world"]) 29 | self._world = self._world @ fireflies.utils.transforms.toMat4x4( 30 | fireflies.utils.math.getXTransform(np.pi * 0.5, self._device) 31 | ) 32 | self._randomized_world = self._world.clone() 33 | 34 | self._randomizable = bool(config["randomizable"]) 35 | self._relative = bool(config["is_relative"]) 36 | 37 | self._parent_name = config["parent_name"] if self._relative else None 38 | # Is loaded in a second step 39 | self._parent = None 40 | self._child = None 41 | 42 | self.setVertices(vertex_data) 43 | self.setScaleBoundaries(config["scale"]) 44 | self._animated = bool(config["animated"]) 45 | self._sequential_animation = sequential_animation 46 | 47 | self._animation_index = 0 48 | 49 | flame_config = Namespace( 50 | **{ 51 | "batch_size": 1, 52 | "dynamic_landmark_embedding_path": "./Objects/flame_pytorch/model/flame_dynamic_embedding.npy", 53 | "expression_params": 50, 54 | "flame_model_path": "./Objects/flame_pytorch/model/generic_model.pkl", 55 | "num_worker": 
4, 56 | "optimize_eyeballpose": True, 57 | "optimize_neckpose": True, 58 | "pose_params": 6, 59 | "ring_loss_weight": 1.0, 60 | "ring_margin": 0.5, 61 | "shape_params": 100, 62 | "static_landmark_embedding_path": "./Objects/flame_pytorch/model/flame_static_embedding.pkl", 63 | "use_3D_translation": True, 64 | "use_face_contour": True, 65 | } 66 | ) 67 | 68 | self.setVertices(vertex_data) 69 | self.setScaleBoundaries(config["scale"]) 70 | self._animated = True 71 | self._stddev_range = config["stddev_range"] 72 | self._shape_layer = flame.FLAME(flame_config).to(self._device) 73 | self._faces = self._shape_layer.faces 74 | self._pose_params = torch.zeros(1, 6, device=self._device) 75 | self._expression_params = torch.zeros(1, 50, device=self._device) 76 | self._shape_params = ( 77 | (torch.rand(1, 100, device=self._device) - 0.5) * 2.0 * self._stddev_range 78 | ) 79 | self._shape_params[:, 20:] = 0.0 80 | 81 | self._shape_params *= 0.0 82 | self._invert = False 83 | 84 | def train(self) -> None: 85 | base.Transformable.train(self) 86 | 87 | def eval(self) -> None: 88 | base.Transformable.eval(self) 89 | 90 | def loadAnimation(self): 91 | return None 92 | 93 | def modelParams(self) -> dict: 94 | return self._shape_params 95 | 96 | def shapeParams(self) -> torch.tensor: 97 | return self._shape_params 98 | 99 | def expressionParams(self) -> torch.tensor: 100 | return self._expression_params 101 | 102 | def poseParams(self) -> torch.tensor: 103 | return self._pose_params 104 | 105 | def randomize(self) -> None: 106 | if self._shape_params[0, 0] > 2.0: 107 | self._invert = True 108 | self._shape_params = self._shape_params + (0.05 if not self._invert else -0.05) 109 | self._shape_params[:, 20:] = 0.0 110 | 111 | self._randomized_world = ( 112 | self.sampleTranslation() @ self.sampleRotation() @ self.sampleScale() 113 | ) 114 | 115 | def getVertexData(self): 116 | if not self._animated: 117 | return self._vertices, self._shape_layer.faces 118 | 119 | vertices, _ = 
self._shape_layer( 120 | self._shape_params, self._expression_params, self._pose_params 121 | ) 122 | vertices = vertices[0] 123 | 124 | vertices = fireflies.utils.transforms.transform_points( 125 | vertices, 126 | self.world() 127 | @ fireflies.utils.transforms.toMat4x4( 128 | fireflies.utils.math.getXTransform(np.pi * 0.5, self._device) 129 | ), 130 | ) 131 | 132 | return vertices, self._shape_layer.faces 133 | -------------------------------------------------------------------------------- /fireflies/sampling/poisson.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """ 4 | Implementation of the fast Poisson Disk Sampling algorithm of 5 | Bridson (2007) adapted to support spatially varying sampling radii. 6 | 7 | Adrian Bittner, 2021 8 | Published under MIT license. 9 | """ 10 | 11 | 12 | def getGridCoordinates(coords): 13 | return np.floor(coords).astype("int") 14 | 15 | 16 | def bridson(radius, k=30, radiusType="default"): 17 | """ 18 | Implementation of the Poisson Disk Sampling algorithm. 19 | 20 | :param radius: 2d array specifying the minimum sampling radius for each spatial position in the sampling box. The 21 | size of the sampling box is given by the size of the radius array. 22 | :param k: Number of iterations to find a new particle in an annulus between radius r and 2r from a sample particle. 23 | :param radiusType: Method to determine the distance to newly spawned particles. 'default' follows the algorithm of 24 | Bridson (2007) and generates particles uniformly in the annulus between radius r and 2r. 25 | 'normDist' instead creates new particles at distances drawn from a normal distribution centered 26 | around 1.5r with a dispersion of 0.2r. 27 | :return: nParticle: Number of particles in the sampling. 28 | particleCoordinates: 2d array containing the coordinates of the created particles. 
29 | """ 30 | # Set-up background grid 31 | gridHeight, gridWidth = radius.shape 32 | grid = np.zeros((gridHeight, gridWidth)) 33 | 34 | # Pick initial (active) point 35 | coords = (np.random.random() * gridHeight, np.random.random() * gridWidth) 36 | idx = getGridCoordinates(coords) 37 | nParticle = 1 38 | grid[idx[0], idx[1]] = nParticle 39 | 40 | # Initialise active queue 41 | queue = [ 42 | coords 43 | ] # Appending to list is much quicker than to numpy array, if you do it very often 44 | particleCoordinates = [ 45 | coords 46 | ] # List containing the exact positions of the final particles 47 | 48 | # Continue iteration while there is still points in active list 49 | while queue: 50 | 51 | # Pick random element in active queue 52 | idx = np.random.randint(len(queue)) 53 | activeCoords = queue[idx] 54 | activeGridCoords = getGridCoordinates(activeCoords) 55 | 56 | success = False 57 | for _ in range(k): 58 | 59 | if radiusType == "default": 60 | # Pick radius for new sample particle ranging between 1 and 2 times the local radius 61 | newRadius = radius[activeGridCoords[0], activeGridCoords[1]] * ( 62 | np.random.random() + 1 63 | ) 64 | elif radiusType == "normDist": 65 | # Pick radius for new sample particle from a normal distribution around 1.5 times the local radius 66 | newRadius = radius[ 67 | activeGridCoords[0], activeGridCoords[1] 68 | ] * np.random.normal(1.5, 0.2) 69 | 70 | # Pick the angle to the sample particle and determine its coordinates 71 | angle = 2 * np.pi * np.random.random() 72 | newCoords = np.zeros(2) 73 | newCoords[0] = activeCoords[0] + newRadius * np.sin(angle) 74 | newCoords[1] = activeCoords[1] + newRadius * np.cos(angle) 75 | 76 | # Prevent that the new particle is outside of the grid 77 | if not (0 <= newCoords[1] <= gridWidth and 0 <= newCoords[0] <= gridHeight): 78 | continue 79 | 80 | # Check that particle is not too close to other particle 81 | newGridCoords = getGridCoordinates((newCoords[0], newCoords[1])) 82 | 83 | 
radiusThere = np.ceil(radius[newGridCoords[0], newGridCoords[1]]) 84 | 85 | gridRangeX = ( 86 | np.max([newGridCoords[1] - radiusThere, 0]).astype("int"), 87 | np.min([newGridCoords[1] + radiusThere + 1, gridWidth]).astype("int"), 88 | ) 89 | gridRangeY = ( 90 | np.max([newGridCoords[0] - radiusThere, 0]).astype("int"), 91 | np.min([newGridCoords[0] + radiusThere + 1, gridHeight]).astype("int"), 92 | ) 93 | 94 | searchGrid = grid[ 95 | slice(gridRangeY[0], gridRangeY[1]), slice(gridRangeX[0], gridRangeX[1]) 96 | ] 97 | conflicts = np.where(searchGrid > 0) 98 | 99 | if len(conflicts[0]) == 0 and len(conflicts[1]) == 0: 100 | # No conflicts detected. Create a new particle at this position! 101 | queue.append(newCoords) 102 | particleCoordinates.append(newCoords) 103 | nParticle += 1 104 | grid[newGridCoords[0], newGridCoords[1]] = nParticle 105 | success = True 106 | 107 | else: 108 | # There is a conflict. Do NOT create a new particle at this position! 109 | continue 110 | 111 | if success == False: 112 | # No new particle could be associated to the currently active particle. 113 | # Remove current particle from the active queue! 
114 | del queue[idx] 115 | 116 | return (nParticle, np.array(particleCoordinates)) 117 | 118 | 119 | if __name__ == "__main__": 120 | import matplotlib.pyplot as plt 121 | 122 | width = 512 123 | height = 512 124 | 125 | radius = 10 126 | max_radius = 3 * radius 127 | 128 | import cv2 129 | 130 | weight_matrix = ( 131 | np.ones([height, width], np.float32) * max_radius 132 | ) # Should be between 0 and 1 133 | cv2.circle(weight_matrix, np.array(weight_matrix.shape) // 2, 50, radius, -1) 134 | 135 | # cv2.imshow("Weight Matrix", weight_matrix) 136 | # cv2.waitKey(0) 137 | 138 | npoints, points = bridson(weight_matrix) 139 | # points = np.array(poisson_disc_samples(width=width, height=height, r=radius)) 140 | 141 | plt.scatter(points[:, 0], points[:, 1]) 142 | plt.xlim(0, width) 143 | plt.ylim(0, height) 144 | plt.show() 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Fireflies](https://github.com/Henningson/Fireflies/assets/27073509/36254690-b42a-4604-849f-ebfa4ffa69c6) 2 | 3 | **Fireflies** is a wrapper for the Mitsuba Renderer and allows for rapid prototyping and generation of physically-based renderings and simulation data in a differentiable manner. 4 | It can be used for example, to easily generate highly realistic medical imaging data for medical machine learning tasks or (its intended use) test the reconstruction capabilities of Structured Light projection systems in simulated environments. 5 | I originally created it to research if the task of finding an optimal point-based laser pattern for structured light laryngoscopy can be reformulated as a gradient-based optimization problem. 6 | 7 | This code accompanies the paper **Fireflies: Photorealistic Simulation and Optimization of Structured Light Endoscopy** accepted at **SASHIMI 2024**. 
🎊 8 | 9 | 10 | # Main features 11 | - **Easy torch-like and pythonic scene randomization description.** This library is made to be easily usable for everyone who regularly uses Python and PyTorch. We implement train() and eval() functionality from the get-go. 12 | - **Can be integrated into online deep learning and machine learning tasks** due to the differentiability of the Mitsuba renderer w.r.t. the scene parameters. 13 | - **Simple animation description**. Have a look at the examples. 14 | - **Single Shot Structured Light specific**. You can easily test different projection patterns and reconstruction algorithms on randomized scenes, giving a good estimate of the quality and viability of patterns/systems/algorithms. 15 | 16 | # Installation 17 | Make sure to create a conda environment first. 18 | I tested Fireflies on Python 3.10; however, it should work with every Python version that is also supported by Mitsuba and PyTorch. 19 | I'm working on adding Fireflies to PyPI in the future. 20 | First install the necessary dependencies: 21 | ``` 22 | pip install pywavefront geomdl 23 | pip install torch 24 | pip install mitsuba 25 | ``` 26 | To run the examples, you also need OpenCV: 27 | ``` 28 | pip install opencv-python 29 | ``` 30 | Finally, you can install Fireflies via: 31 | ``` 32 | git clone https://github.com/Henningson/Fireflies.git 33 | cd Fireflies 34 | pip install . 35 | ``` 36 | 37 | ![Datasets](https://github.com/Henningson/Fireflies/assets/27073509/9c617876-356a-420d-8632-cf4c286d6778) 38 | # Usage 39 | ``` 40 | import mitsuba as mi 41 | import fireflies as ff 42 | mi.set_variant("cuda_ad_rgb") 43 | mi_scene = mi.load_file(path) 44 | mi_params = mi.traverse(mi_scene) 45 | ff_scene = ff.Scene(mi_params) 46 | 47 | mesh = ff_scene.mesh_at(0) 48 | mesh.rotate_z(-3.141, 3.141) 49 | 50 | ff_scene.eval() 51 | # ff_scene.train() generates uniformly sampled results on the right 52 | for i in range(0, 20): 53 |     ff_scene.randomize() 54 |     mi.render(mi_scene) 55 | ``` 56 | 57 |

58 | 59 | 60 |

61 | 62 | # Examples 63 | A bunch of different examples can be found in the examples folder. 64 | Papercode can be found in the **Paper** branch. 65 | 66 | # Building and loading your own scene 67 | You can easily generate a scene using Blender. 68 | To export a scene in Mitsubas required .xml format, you first need to install the Mitsuba Blender Add-On. 69 | You can then export it under the File -> Export Tab. 70 | Make sure to tick the ✅ export ids Checkbox, as fireflies infers the object type by checking for name qualifiers with specific keys, e.g.: "mesh", "brdf", etc. 71 | 72 | # Render Gallery 73 |

74 | 75 | 76 | 77 | 78 |

79 | These are some renderings that were created during my work on the aforementioned paper. 80 | From left to right: Reconstructed in-vivo colon, the flame shapemodel, phonating human vocal folds with a point-based structured light pattern. 81 | 82 | ## More Discussion about the Paper 83 | Can be found in the **README** of the **paper** branch. 84 | 85 | ## Why did you call this Fireflies? 86 | Because optimizing a point-based laser pattern looks like fireflies that jet around. :) 87 |

88 | 89 |

90 | 91 | 92 | ## Acknowledgements 93 | A big thank you to Wenzel Jakob and team for their wonderful work on the Mitsuba renderer. 94 | You should definitely check out their work: Mitsuba Homepage, Mitsuba Github. 95 | 96 | Furthermore, this work was supported by Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under grant STA662/6-1, Project-ID 448240908 and (partly) funded by the DFG – SFB 1483 – Project-ID 442419336, EmpkinS. 97 | 98 | 99 |

100 | 101 | 102 |

103 | 104 | ## Citation 105 | Please cite this, if this work helps you with your research: 106 | ``` 107 | @InProceedings{10.1007/978-3-031-73281-2_10, 108 | author="Henningson, Jann-Ole and Veltrup, Reinhard and Semmler, Marion and D{\"o}llinger, Michael and Stamminger, Marc", 109 | title="Fireflies: Photorealistic Simulation and Optimization of Structured Light Endoscopy", 110 | booktitle="Simulation and Synthesis in Medical Imaging", 111 | year="2025", 112 | publisher="Springer Nature Switzerland", 113 | address="Cham", 114 | pages="102--112", 115 | isbn="978-3-031-73281-2" 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /fireflies/graphics/depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mitsuba as mi 3 | import drjit as dr 4 | import fireflies.utils.laser_estimation 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | @dr.wrap_ad(source="torch", target="drjit") 10 | def from_laser(scene, params, laser): 11 | sensor = scene.sensors()[0] 12 | film = sensor.film() 13 | # TODO: Add device 14 | size = torch.tensor(film.size(), device="cuda") 15 | 16 | hit_points = fireflies.utils.laser_estimation.cast_laser(scene, laser=laser) 17 | ndc_coords = fireflies.utils.laser_estimation.project_to_camera_space( 18 | params, hit_points 19 | ) 20 | pixel_coords = ndc_coords * 0.5 + 0.5 21 | pixel_coords = pixel_coords[0, :, 0:2] 22 | pixel_coords = torch.floor(pixel_coords * size).int() 23 | 24 | mask = torch.zeros(size.tolist(), device=size.device) 25 | mask[pixel_coords[:, 0], pixel_coords[:, 1]] = 1.0 26 | 27 | depth_map = from_camera_non_wrapped(scene, spp=1) 28 | depth_map = depth_map.reshape(film.size()[0], film.size()[1]) 29 | 30 | return depth_map * mask 31 | 32 | 33 | @dr.wrap_ad(source="torch", target="drjit") 34 | def cast_laser_id(scene, origin, direction): 35 | origin_point = mi.Point3f( 36 | origin[:, 0].array, origin[:, 1].array, origin[:, 2].array 
37 | ) 38 | rays_vector = mi.Vector3f( 39 | direction[:, 0].array, direction[:, 1].array, direction[:, 2].array 40 | ) 41 | surface_interaction = scene.ray_intersect(mi.Ray3f(origin_point, rays_vector)) 42 | shape_pointer = mi.Int( 43 | dr.reinterpret_array_v(mi.UInt, surface_interaction.shape) 44 | ).torch() 45 | shape_pointer -= shape_pointer.min() 46 | return shape_pointer 47 | 48 | 49 | def from_camera_non_wrapped(scene, spp=64): 50 | sensor = scene.sensors()[0] 51 | film = sensor.film() 52 | sampler = sensor.sampler() 53 | film_size = film.crop_size() 54 | total_samples = dr.prod(film_size) * spp 55 | 56 | if sampler.wavefront_size() != total_samples: 57 | sampler.seed(0, total_samples) 58 | 59 | # Enumerate discrete sample & pixel indices, and uniformly sample 60 | # positions within each pixel. 61 | pos = dr.arange(mi.UInt32, total_samples) 62 | 63 | pos //= spp 64 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 65 | pos = mi.Vector2f( 66 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0])) 67 | ) 68 | 69 | # pos += sampler.next_2d() 70 | 71 | # Sample rays starting from the camera sensor 72 | rays, weights = sensor.sample_ray( 73 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0 74 | ) 75 | 76 | # Intersect rays with the scene geometry 77 | surface_interaction = scene.ray_intersect(rays) 78 | 79 | # Given intersection, compute the final pixel values as the depth t 80 | # of the sampled surface interaction 81 | result = surface_interaction.t 82 | 83 | # Set to zero if no intersection was found 84 | result[~surface_interaction.is_valid()] = 0 85 | 86 | return result 87 | 88 | 89 | def get_segmentation_from_camera(scene, spp=1): 90 | sensor = scene.sensors()[0] 91 | film = sensor.film() 92 | sampler = sensor.sampler() 93 | film_size = film.crop_size() 94 | total_samples = dr.prod(film_size) * spp 95 | 96 | if sampler.wavefront_size() != total_samples: 97 | sampler.seed(0, total_samples) 98 | 99 | # Enumerate 
discrete sample & pixel indices, and uniformly sample 100 | # positions within each pixel. 101 | pos = dr.arange(mi.UInt32, total_samples) 102 | 103 | pos //= spp 104 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 105 | pos = mi.Vector2f( 106 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0])) 107 | ) 108 | 109 | # Sample rays starting from the camera sensor 110 | rays, weights = sensor.sample_ray( 111 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0 112 | ) 113 | 114 | # Intersect rays with the scene geometry 115 | surface_interaction = scene.ray_intersect(rays) 116 | 117 | # Watch out, hacky stuff going on! 118 | # Solution from: https://github.com/mitsuba-renderer/mitsuba3/discussions/882 119 | shape_pointer = mi.Int( 120 | dr.reinterpret_array_v(mi.UInt, surface_interaction.shape) 121 | ).torch() 122 | shape_pointer -= shape_pointer.min() 123 | shape_pointer = shape_pointer.max() - shape_pointer 124 | 125 | return shape_pointer.reshape(film_size[1], film_size[0]) 126 | 127 | 128 | @dr.wrap_ad(source="drjit", target="torch") 129 | def from_camera(scene, spp=64): 130 | sensor = scene.sensors()[0] 131 | film = sensor.film() 132 | sampler = sensor.sampler() 133 | film_size = film.crop_size() 134 | total_samples = dr.prod(film_size) * spp 135 | 136 | if sampler.wavefront_size() != total_samples: 137 | sampler.seed(0, total_samples) 138 | 139 | # Enumerate discrete sample & pixel indices, and uniformly sample 140 | # positions within each pixel. 
141 | pos = dr.arange(mi.UInt32, total_samples) 142 | 143 | pos //= spp 144 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 145 | pos = mi.Vector2f( 146 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0])) 147 | ) 148 | 149 | pos += sampler.next_2d() 150 | 151 | # Sample rays starting from the camera sensor 152 | rays, weights = sensor.sample_ray( 153 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0 154 | ) 155 | 156 | # Intersect rays with the scene geometry 157 | surface_interaction = scene.ray_intersect(rays) 158 | 159 | # Given intersection, compute the final pixel values as the depth t 160 | # of the sampled surface interaction 161 | result = surface_interaction.t 162 | 163 | # Set to zero if no intersection was found 164 | result[~surface_interaction.is_valid()] = 0 165 | 166 | return result 167 | 168 | 169 | def random_depth_maps( 170 | firefly_scene, mi_scene, num_maps: int = 100, spp: int = 1 171 | ) -> torch.tensor: 172 | stacked_depth_maps = [] 173 | im_size = mi_scene.sensors()[0].film().size() 174 | 175 | for i in tqdm(range(num_maps)): 176 | firefly_scene.randomize() 177 | 178 | depth_map = from_camera_non_wrapped(mi_scene, spp=1) 179 | 180 | # vis_depth = torch.log(depth_map.torch().reshape(im_size[1], im_size[0]).clone()) 181 | # vis_depth = utils.normalize(vis_depth) 182 | # vis_depth = ((1 - vis_depth.detach().cpu().numpy()) * 255).astype(np.uint8) 183 | # colored = cv2.applyColorMap(vis_depth, cv2.COLORMAP_INFERNO) 184 | # cv2.imshow("Depth", colored) 185 | # cv2.waitKey(1) 186 | 187 | depth_map = depth_map.torch().reshape(im_size[1], im_size[0], spp).mean(dim=-1) 188 | stacked_depth_maps.append(depth_map) 189 | 190 | return torch.stack(stacked_depth_maps) 191 | -------------------------------------------------------------------------------- /fireflies/entity/mesh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import 
pywavefront 4 | 5 | import fireflies.entity.base as base 6 | import fireflies.utils.math 7 | import fireflies.sampling 8 | 9 | 10 | class Mesh(base.Transformable): 11 | def __init__( 12 | self, 13 | name: str, 14 | vertex_data: torch.tensor, 15 | device: torch.cuda.device = torch.device("cuda"), 16 | ): 17 | super(Mesh, self).__init__(name, device) 18 | 19 | self._vertices = vertex_data.to(self._device) 20 | self._vertices_animation = None 21 | 22 | ones = torch.ones(3, device=self._device) 23 | self._scale_sampler = fireflies.sampling.UniformSampler( 24 | ones.clone(), ones.clone() 25 | ) 26 | 27 | self._animated = False 28 | 29 | self._anim_data_train = None 30 | self._anim_data_eval = None 31 | self._animation_func = None 32 | self._animation_sampler = None 33 | 34 | def set_scale_sampler(self, sampler: fireflies.sampling.Sampler) -> None: 35 | self._scale_sampler = sampler 36 | 37 | def scale_x(self, min_scale: float, max_scale: float) -> None: 38 | self._randomizable = True 39 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 0) 40 | 41 | def scale_y(self, min_scale: float, max_scale: float) -> None: 42 | self._randomizable = True 43 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 1) 44 | 45 | def scale_z(self, min_scale: float, max_scale: float) -> None: 46 | self._randomizable = True 47 | self.update_index_from_sampler(self._scale_sampler, min_scale, max_scale, 2) 48 | 49 | def scale(self, min: torch.tensor, max: torch.tensor) -> None: 50 | self._randomizable = True 51 | self._scale_sampler.set_sample_interval( 52 | min.to(self._device), max.to(self._device) 53 | ) 54 | 55 | def set_scale_sampler(self, sampler: fireflies.sampling.Sampler) -> None: 56 | self._scale_sampler = sampler 57 | 58 | def animated(self) -> bool: 59 | return self._animated 60 | 61 | def add_animation(self, animation_data: torch.tensor) -> None: 62 | self._animation_vertices = animation_data.to(self._device) 63 | self._animated = 
True 64 | self._randomizable = True 65 | 66 | def add_animation_func(self, func, min_range, max_range) -> None: 67 | self._animation_func = func 68 | self._animation_sampler = fireflies.sampling.UniformSampler( 69 | min_range, max_range, device=self._device 70 | ) 71 | self._animated = True 72 | self._randomizable = True 73 | 74 | def add_train_animation_from_obj( 75 | self, path: str, min: int = None, max: int = None 76 | ) -> None: 77 | self._anim_data_train = self.load_animation(path) 78 | 79 | if self._animation_sampler: 80 | self._animation_sampler.set_train_interval( 81 | 0 if min is None else 0, 82 | self._anim_data_train.shape[0] if max is None else max, 83 | ) 84 | return 85 | 86 | self._animation_sampler = fireflies.sampling.AnimationSampler(0, 1, 0, 1) 87 | self._animation_sampler.set_train_interval( 88 | 0 if min is None else 0, 89 | self._anim_data_train.shape[0] if max is None else max, 90 | ) 91 | self._animated = True 92 | 93 | def add_eval_animation_from_obj( 94 | self, path: str, min: int = None, max: int = None 95 | ) -> None: 96 | self._anim_data_eval = self.load_animation(path) 97 | 98 | if self._animation_sampler: 99 | self._animation_sampler.set_eval_interval( 100 | 0 if min is None else 0, 101 | self._anim_data_eval.shape[0] if max is None else max, 102 | ) 103 | return 104 | 105 | self._animation_sampler = fireflies.sampling.AnimationSampler(0, 1, 0, 1) 106 | self._animation_sampler.set_eval_interval( 107 | 0 if min is None else 0, 108 | self._anim_data_eval.shape[0] if max is None else max, 109 | ) 110 | 111 | def train(self) -> None: 112 | super(Mesh, self).train() 113 | self._scale_sampler.train() 114 | 115 | if self._animation_sampler: 116 | self._animation_sampler.train() 117 | 118 | def eval(self) -> None: 119 | super(Mesh, self).eval() 120 | self._scale_sampler.eval() 121 | 122 | if self._animation_sampler: 123 | self._animation_sampler.eval() 124 | 125 | def set_faces(self, faces: torch.tensor) -> None: 126 | self._faces = 
faces.to(self._device) 127 | 128 | def set_vertices(self, vertices: torch.tensor) -> None: 129 | self._vertices = vertices.to(self._device) 130 | 131 | def sample_scale(self) -> torch.tensor: 132 | scale_matrix = torch.eye(4, device=self._device) 133 | 134 | random_scale = self._scale_sampler.sample() 135 | 136 | scale_matrix[0, 0] = random_scale[0] 137 | scale_matrix[1, 1] = random_scale[1] 138 | scale_matrix[2, 2] = random_scale[2] 139 | return scale_matrix 140 | 141 | def randomize(self) -> None: 142 | if not self.randomizable(): 143 | return 144 | 145 | self._randomized_world = ( 146 | (self.sample_translation() + self._centroid_mat) 147 | @ self.sample_rotation() 148 | @ self.sample_scale() 149 | @ self._world 150 | ) 151 | 152 | def faces(self) -> torch.tensor: 153 | return self._faces 154 | 155 | def get_vertices(self) -> torch.tensor: 156 | return self._vertices 157 | 158 | def get_randomized_vertices(self) -> torch.tensor: 159 | # Sample Animations 160 | temp_vertex = self.sample_animation() if self._animated else self._vertices 161 | 162 | # Transform by world transform 163 | temp_vertex = fireflies.utils.math.transform_points(temp_vertex, self.world()) 164 | 165 | return temp_vertex 166 | 167 | def load_animation(self, path: str) -> torch.tensor: 168 | animation_data = [] 169 | for file in sorted(os.listdir(path)): 170 | if file.endswith(".obj"): 171 | obj_path = os.path.join(path, file) 172 | 173 | obj = pywavefront.Wavefront(obj_path, collect_faces=True) 174 | 175 | vertices = torch.tensor(obj.vertices, device=self._device).reshape( 176 | -1, 3 177 | ) 178 | 179 | animation_data.append(vertices) 180 | 181 | return torch.stack(animation_data) 182 | 183 | def sample_animation(self): 184 | if not self._animated: 185 | return self._vertices 186 | 187 | # Can either be an integer or a float, depending if we loaded meshes or defined an animation function 188 | time_sample = self._animation_sampler.sample() 189 | if self._animation_func is not None: 190 | 
return self._animation_func(self._vertices, time_sample) 191 | elif self._anim_data_train is not None and self._anim_data_eval is not None: 192 | return ( 193 | self._anim_data_train[time_sample] 194 | if self._train 195 | else self._anim_data_eval[time_sample] 196 | ) 197 | 198 | return None 199 | -------------------------------------------------------------------------------- /fireflies/utils/math.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | def uniformBetweenValues(a: float, b: float) -> float: 9 | return random.uniform(a, b) 10 | 11 | 12 | def getZTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 13 | return getYawTransform(alpha, _device) 14 | 15 | 16 | def getYTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 17 | return getPitchTransform(alpha, _device) 18 | 19 | 20 | def getXTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 21 | return getRollTransform(alpha, _device) 22 | 23 | 24 | def getYawTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 25 | rotZ = torch.tensor( 26 | [ 27 | [math.cos(alpha), -math.sin(alpha), 0], 28 | [math.sin(alpha), math.cos(alpha), 0], 29 | [0, 0, 1], 30 | ], 31 | device=_device, 32 | ) 33 | 34 | return rotZ 35 | 36 | 37 | def getPitchTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 38 | rotY = torch.tensor( 39 | [ 40 | [math.cos(alpha), 0, math.sin(alpha)], 41 | [0, 1, 0], 42 | [-math.sin(alpha), 0, math.cos(alpha)], 43 | ], 44 | device=_device, 45 | ) 46 | 47 | return rotY 48 | 49 | 50 | def getRollTransform(alpha: float, _device: torch.cuda.device) -> torch.tensor: 51 | rotX = torch.tensor( 52 | [ 53 | [1, 0, 0], 54 | [0, math.cos(alpha), -math.sin(alpha)], 55 | [0, math.sin(alpha), math.cos(alpha)], 56 | ], 57 | device=_device, 58 | ) 59 | 60 | return rotX 61 | 62 | 63 | def 
vector_dot(A: torch.tensor, B: torch.tensor) -> torch.tensor: 64 | return torch.sum(A * B, dim=-1) 65 | 66 | 67 | def rotation_matrix_from_vectors(v1, v2): 68 | """ 69 | Calculates the rotation matrix that transforms vector v1 to v2. 70 | 71 | Args: 72 | - v1 (torch.Tensor): The source 3D vector (3x1). 73 | - v2 (torch.Tensor): The target 3D vector (3x1). 74 | 75 | Returns: 76 | - torch.Tensor: The 3x3 rotation matrix. 77 | """ 78 | v1 = F.normalize(v1, dim=0) 79 | v2 = F.normalize(v2, dim=0) 80 | 81 | # Compute the cross product and dot product 82 | cross_product = torch.cross(v1, v2) 83 | dot_product = torch.dot(v1, v2) 84 | 85 | # Skew-symmetric matrix for cross product 86 | skew_sym_matrix = torch.tensor( 87 | [ 88 | [0, -cross_product[2], cross_product[1]], 89 | [cross_product[2], 0, -cross_product[0]], 90 | [-cross_product[1], cross_product[0], 0], 91 | ], 92 | dtype=torch.float32, 93 | device=v1.device, 94 | ) 95 | 96 | # Rotation matrix using Rodrigues' formula 97 | rotation_matrix = ( 98 | torch.eye(3, device=v1.device) 99 | + skew_sym_matrix 100 | + torch.mm(skew_sym_matrix, skew_sym_matrix) 101 | * (1 - dot_product) 102 | / torch.norm(cross_product) ** 2 103 | ) 104 | 105 | return rotation_matrix 106 | 107 | 108 | def rotation_matrix_from_vectors_with_fixed_up( 109 | v1, v2, up_vector=torch.tensor([0.0, 0.0, 1.0]) 110 | ): 111 | """ 112 | Calculates the rotation matrix that transforms vector v1 to v2, while keeping an "up" direction fixed. 113 | 114 | Args: 115 | - v1 (torch.Tensor): The source 3D vector (3x1). 116 | - v2 (torch.Tensor): The target 3D vector (3x1). 117 | - up_vector (torch.Tensor): The fixed "up" direction (3x1). Default is [0, 0, 1]. 118 | 119 | Returns: 120 | - torch.Tensor: The 3x3 rotation matrix. 
121 | """ 122 | v1 = F.normalize(v1, dim=0) 123 | v2 = F.normalize(v2, dim=0) 124 | up_vector = F.normalize(up_vector, dim=0) 125 | 126 | # Compute the cross product and dot product 127 | cross_product = torch.cross(v1, v2) 128 | dot_product = torch.dot(v1, v2) 129 | 130 | # Skew-symmetric matrix for cross product 131 | skew_sym_matrix = torch.tensor( 132 | [ 133 | [0, -cross_product[2], cross_product[1]], 134 | [cross_product[2], 0, -cross_product[0]], 135 | [-cross_product[1], cross_product[0], 0], 136 | ], 137 | dtype=torch.float32, 138 | device=v1.device, 139 | ) 140 | 141 | # Rotation matrix using Rodrigues' formula 142 | rotation_matrix = ( 143 | torch.eye(3, device=v1.device) 144 | + skew_sym_matrix 145 | + torch.mm(skew_sym_matrix, skew_sym_matrix) 146 | * (1 - dot_product) 147 | / torch.norm(cross_product) ** 2 148 | ) 149 | 150 | # Ensure the "up" direction is fixed 151 | rotated_up_vector = torch.mv(rotation_matrix, up_vector) 152 | correction_axis = torch.cross(rotated_up_vector, up_vector) 153 | correction_angle = torch.acos(torch.dot(rotated_up_vector, up_vector)) 154 | 155 | # Apply the correction to the rotation matrix 156 | correction_matrix = F.normalize(skew_sym_matrix, dim=0) * correction_angle 157 | corrected_rotation_matrix = torch.eye(3, device=v1.device) + correction_matrix 158 | 159 | return corrected_rotation_matrix 160 | 161 | 162 | def singleRandomBetweenTensors(a: torch.tensor, b: torch.tensor) -> torch.tensor: 163 | assert a.size() == b.size() 164 | assert a.device == b.device 165 | 166 | rands = random.uniform(0, 1) 167 | return rands * (b - a) + b 168 | 169 | 170 | def randomBetweenTensors(a: torch.tensor, b: torch.tensor) -> torch.tensor: 171 | assert a.size() == b.size() 172 | assert a.device == b.device 173 | 174 | rands = torch.rand(a.shape, device=a.device) 175 | return rands * (b - a) + a 176 | 177 | 178 | def normalize(tensor: torch.tensor) -> torch.tensor: 179 | tensor = tensor - tensor.amin() 180 | tensor = tensor / 
tensor.amax() 181 | return tensor 182 | 183 | 184 | def normalize_channelwise( 185 | tensor: torch.tensor, 186 | dim: int = -1, 187 | device: torch.cuda.device = torch.device("cuda"), 188 | ) -> torch.tensor: 189 | indices = torch.arange(0, len(tensor.shape), device=device) 190 | mask = torch.ones(indices.shape, dtype=torch.bool, device=device) 191 | mask[dim] = False 192 | indices = indices[mask].tolist() 193 | 194 | tensor = tensor - tensor.amin(indices) 195 | tensor = tensor / tensor.amax(indices) 196 | return tensor 197 | 198 | 199 | def convert_points_to_homogeneous(points: torch.tensor) -> torch.tensor: 200 | return torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=1.0) 201 | 202 | 203 | def toMat4x4(mat: torch.tensor, addOne: bool = True) -> torch.tensor: 204 | mat4x4 = torch.nn.functional.pad(mat, pad=(0, 1, 0, 1), mode="constant", value=0.0) 205 | 206 | if addOne: 207 | mat4x4[3, 3] = 1.0 208 | 209 | return mat4x4 210 | 211 | 212 | def convert_points_from_homogeneous(points: torch.tensor) -> torch.tensor: 213 | return points[..., :-1] / points[..., -1:] 214 | 215 | 216 | def convert_points_to_nonhomogeneous(points: torch.tensor) -> torch.tensor: 217 | return torch.nn.functional.pad(points, pad=(0, 1), mode="constant", value=0.0) 218 | 219 | 220 | def transform_points(points: torch.tensor, transform: torch.tensor) -> torch.tensor: 221 | points_1_h = convert_points_to_homogeneous(points) 222 | 223 | points_0_h = torch.matmul(transform.unsqueeze(0), points_1_h.unsqueeze(-1)) 224 | 225 | points_0_h = points_0_h.squeeze(dim=-1) 226 | 227 | points_0 = convert_points_from_homogeneous(points_0_h) 228 | return points_0 229 | 230 | 231 | def transform_directions(points: torch.tensor, transform: torch.tensor) -> torch.tensor: 232 | points = convert_points_to_nonhomogeneous(points) 233 | points = transform.unsqueeze(0) @ points.unsqueeze(-1) 234 | points = torch.squeeze(points, axis=-1) 235 | return points[..., :-1] 236 | 
-------------------------------------------------------------------------------- /fireflies/entity/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List 3 | 4 | import fireflies.utils.math 5 | import fireflies.sampling 6 | 7 | 8 | class Transformable: 9 | def __init__( 10 | self, 11 | name: str, 12 | device: torch.cuda.device = torch.device("cuda"), 13 | ): 14 | 15 | self._device: torch.cuda.device = device 16 | self._name: str = name 17 | 18 | self._randomizable: bool = False 19 | 20 | self._parent = None 21 | self._child = None 22 | 23 | self._train = True 24 | 25 | self._float_attributes = {} 26 | self._randomized_float_attributes = {} 27 | 28 | self._vec3_attributes = {} 29 | self._randomized_vec3_attributes = {} 30 | 31 | zeros = torch.zeros(3, device=self._device) 32 | self._rotation_sampler = fireflies.sampling.UniformSampler( 33 | zeros.clone(), zeros.clone() 34 | ) 35 | 36 | self._translation_sampler = fireflies.sampling.UniformSampler( 37 | zeros.clone(), zeros.clone() 38 | ) 39 | 40 | self._world = torch.eye(4, device=self._device) 41 | self._randomized_world = torch.eye(4, device=self._device) 42 | 43 | self._centroid_mat = torch.zeros((4, 4), device=self._device) 44 | 45 | self._eval_delta = 0.01 46 | self._num_updates = 0 47 | 48 | def randomizable(self) -> bool: 49 | return self._randomizable 50 | 51 | def set_centroid(self, centroid: torch.tensor) -> None: 52 | self._centroid_mat[0, 3] = centroid.squeeze()[0] 53 | self._centroid_mat[1, 3] = centroid.squeeze()[1] 54 | self._centroid_mat[2, 3] = centroid.squeeze()[2] 55 | 56 | def set_randomizable(self, randomizable: bool) -> None: 57 | self._randomizable = randomizable 58 | 59 | def get_randomized_vec3_attributes(self) -> dict: 60 | return self._randomized_vec3_attributes 61 | 62 | def get_randomized_float_attributes(self) -> dict: 63 | return self._randomized_float_attributes 64 | 65 | def vec3_attributes(self) -> dict: 
66 | return self._vec3_attributes 67 | 68 | def float_attributes(self) -> dict: 69 | return self._float_attributes 70 | 71 | def add_float_sampler(self, key: str, sampler: fireflies.sampling.Sampler) -> None: 72 | self._randomizable = True 73 | self._float_attributes[key] = sampler 74 | 75 | def add_float_key(self, key: str, min: float, max: float) -> None: 76 | """Transforms float key into a Uniform Sampler""" 77 | self._randomizable = True 78 | self._float_attributes[key] = fireflies.sampling.UniformSampler( 79 | min, max, device=self._device 80 | ) 81 | 82 | def add_vec3_key(self, key: str, min: torch.tensor, max: torch.tensor) -> None: 83 | """Transforms vec3 into Uniform Sampler""" 84 | self._randomizable = True 85 | self._vec3_attributes[key] = fireflies.sampling.UniformSampler( 86 | min, max, device=self._device 87 | ) 88 | 89 | def add_vec3_sampler(self, key: str, sampler: fireflies.sampling.Sampler) -> None: 90 | self._randomizable = True 91 | self._vec3_attributes[key] = sampler 92 | 93 | def parent(self): 94 | return self._parent 95 | 96 | def child(self): 97 | return self._child 98 | 99 | def name(self): 100 | return self._name 101 | 102 | def train(self) -> None: 103 | self._train = True 104 | self._translation_sampler.train() 105 | self._rotation_sampler.train() 106 | 107 | for sampler in self._float_attributes.values(): 108 | sampler.train() 109 | 110 | for sampler in self._vec3_attributes.values(): 111 | sampler.train() 112 | 113 | def eval(self) -> None: 114 | self._train = False 115 | self._translation_sampler.eval() 116 | self._rotation_sampler.eval() 117 | 118 | for sampler in self._float_attributes.values(): 119 | sampler.eval() 120 | 121 | for sampler in self._vec3_attributes.values(): 122 | sampler.eval() 123 | 124 | def set_world(self, _origin: torch.tensor) -> None: 125 | self._world = _origin 126 | self._randomized_world = self._world.clone() 127 | 128 | def setParent(self, parent) -> None: 129 | self._parent = parent 130 | 
parent.setChild(self) 131 | 132 | def setChild(self, child) -> None: 133 | self._child = child 134 | 135 | def set_rotation_sampler(self, sampler: fireflies.sampling.Sampler) -> None: 136 | self._rotation_sampler = sampler 137 | 138 | def set_translation_sampler(self, sampler: fireflies.sampling.Sampler) -> None: 139 | self._translation_sampler = sampler 140 | 141 | def update_index_from_sampler(self, sampler, min, max, index) -> None: 142 | sampler_min = sampler.get_min() 143 | sampler_max = sampler.get_max() 144 | 145 | sampler_min[index] = min 146 | sampler_max[index] = max 147 | 148 | def rotate_x(self, min_rot: float, max_rot: float) -> None: 149 | """Convenience function for Uniform Sampler""" 150 | self._randomizable = True 151 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 0) 152 | 153 | def rotate_y(self, min_rot: float, max_rot: float) -> None: 154 | """Convenience function for Uniform Sampler""" 155 | self._randomizable = True 156 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 1) 157 | 158 | def rotate_z(self, min_rot: float, max_rot: float) -> None: 159 | """Convenience function for Uniform Sampler""" 160 | self._randomizable = True 161 | self.update_index_from_sampler(self._rotation_sampler, min_rot, max_rot, 2) 162 | 163 | def rotate(self, min: torch.tensor, max: torch.tensor) -> None: 164 | """Convenience function for Uniform Sampler""" 165 | self._randomizable = True 166 | self._rotation_sampler.set_sample_interval( 167 | min.to(self._device), max.to(self._device) 168 | ) 169 | 170 | def translate_x(self, min_translation: float, max_translation: float) -> None: 171 | self._randomizable = True 172 | self.update_index_from_sampler( 173 | self._translation_sampler, min_translation, max_translation, 0 174 | ) 175 | 176 | def translate_y(self, min_translation: float, max_translation: float) -> None: 177 | self._randomizable = True 178 | self.update_index_from_sampler( 179 | 
self._translation_sampler, min_translation, max_translation, 1 180 | ) 181 | 182 | def translate_z(self, min_translation: float, max_translation: float) -> None: 183 | self._randomizable = True 184 | self.update_index_from_sampler( 185 | self._translation_sampler, min_translation, max_translation, 2 186 | ) 187 | 188 | def translate(self, min: torch.tensor, max: torch.tensor) -> None: 189 | self._randomizable = True 190 | self._translation_sampler.set_sample_interval( 191 | min.to(self._device), max.to(self._device) 192 | ) 193 | 194 | def sample_rotation(self) -> torch.tensor: 195 | self._sampled_rotation = self._rotation_sampler.sample() 196 | 197 | zMat = fireflies.utils.math.getPitchTransform( 198 | self._sampled_rotation[2], self._device 199 | ) 200 | yMat = fireflies.utils.math.getYawTransform( 201 | self._sampled_rotation[1], self._device 202 | ) 203 | xMat = fireflies.utils.math.getRollTransform( 204 | self._sampled_rotation[0], self._device 205 | ) 206 | 207 | return fireflies.utils.math.toMat4x4(zMat @ yMat @ xMat) 208 | 209 | def sample_translation(self) -> torch.tensor: 210 | translation = torch.eye(4, device=self._device) 211 | 212 | self._random_translation = self._translation_sampler.sample() 213 | 214 | translation[0, 3] = self._random_translation[0] 215 | translation[1, 3] = self._random_translation[1] 216 | translation[2, 3] = self._random_translation[2] 217 | self._last_translation = translation 218 | return translation 219 | 220 | def randomize(self) -> None: 221 | if not self.randomizable(): 222 | return 223 | 224 | self._randomized_world = ( 225 | (self.sample_translation() + self._centroid_mat) 226 | @ self.sample_rotation() 227 | @ self._world 228 | ) 229 | 230 | for key, sampler in self._float_attributes.items(): 231 | self._randomized_float_attributes[key] = sampler.sample() 232 | 233 | for key, sampler in self._vec3_attributes.items(): 234 | self._randomized_vec3_attributes[key] = sampler.sample() 235 | 236 | def relative(self) -> None: 
237 | return self._parent is not None 238 | 239 | def world(self) -> torch.tensor: 240 | # If no parent exists, just return the current translation 241 | if self._parent is None: 242 | return self._randomized_world.clone() 243 | 244 | return self._parent.world() @ self._randomized_world 245 | 246 | def nonRandomizedWorld(self) -> torch.tensor: 247 | if self._parent is None: 248 | return self._world 249 | 250 | return self._parent.nonRandomizedWorld() @ self._world 251 | -------------------------------------------------------------------------------- /fireflies/projection/laser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | import fireflies.graphics.rasterization 6 | import fireflies.utils.math 7 | import fireflies.sampling.poisson 8 | import fireflies.entity 9 | 10 | import fireflies.projection as projection 11 | import yaml 12 | 13 | from typing import List 14 | 15 | 16 | class Laser(projection.Camera): 17 | # Static Convenience Function 18 | @staticmethod 19 | def generate_uniform_rays( 20 | intra_ray_angle: float, 21 | num_beams_x: int, 22 | num_beams_y: int, 23 | device: torch.cuda.device = torch.device("cuda"), 24 | ) -> torch.tensor: 25 | laserRays = torch.zeros((num_beams_y * num_beams_x, 3), device=device) 26 | 27 | for x in range(num_beams_x): 28 | for y in range(num_beams_y): 29 | laserRays[x * num_beams_x + y, :] = torch.tensor( 30 | [ 31 | math.tan((x - (num_beams_x - 1) / 2) * intra_ray_angle), 32 | math.tan((y - (num_beams_y - 1) / 2) * intra_ray_angle), 33 | -1.0, 34 | ] 35 | ) 36 | 37 | return laserRays / torch.linalg.norm(laserRays, dim=-1, keepdims=True) 38 | 39 | @staticmethod 40 | def generate_uniform_rays_by_count( 41 | num_beams_x: int, 42 | num_beams_y: int, 43 | intrinsic_matrix: torch.tensor, 44 | device: torch.cuda.device = torch.device("cuda"), 45 | ) -> torch.tensor: 46 | laserRays = torch.zeros((num_beams_y * num_beams_x, 3), 
device=device) 47 | 48 | x_steps = torch.arange((1 / num_beams_x) / 2, 1, 1 / num_beams_x) 49 | y_steps = torch.arange((1 / num_beams_y) / 2, 1, 1 / num_beams_y) 50 | 51 | xy = torch.stack(torch.meshgrid(x_steps, y_steps)) 52 | xy = xy.movedim(0, -1).reshape(-1, 2) 53 | 54 | # Set Z to 1 55 | laserRays[:, 0:2] = xy 56 | laserRays[:, 2] = -1.0 57 | 58 | # Project to world 59 | rays = fireflies.utils.transforms.transform_points( 60 | laserRays, intrinsic_matrix.inverse() 61 | ) 62 | 63 | # Normalize 64 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True) 65 | rays[:, 2] *= -1.0 66 | return rays 67 | 68 | @staticmethod 69 | def generate_random_rays( 70 | num_beams: int, 71 | intrinsic_matrix: torch.tensor, 72 | device: torch.cuda.device = torch.device("cuda"), 73 | ) -> torch.tensor: 74 | 75 | # Random points and move into NDC 76 | spawned_points = ( 77 | torch.ones([num_beams, 3], device=device) * 0.5 78 | + (torch.rand([num_beams, 3], device=device) - 0.5) / 10.0 79 | ) 80 | 81 | # Set Z to 1 82 | spawned_points[:, 2] = -1.0 83 | 84 | # Project to world 85 | rays = fireflies.utils.transforms.transform_points( 86 | spawned_points, intrinsic_matrix.inverse() 87 | ) 88 | 89 | # Normalize 90 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True) 91 | rays[:, 2] *= -1.0 92 | return rays 93 | 94 | @staticmethod 95 | def generate_blue_noise_rays( 96 | image_size_x: int, 97 | image_size_y: int, 98 | num_beams: int, 99 | intrinsic_matrix: torch.tensor, 100 | device: torch.cuda.device = torch.device("cuda"), 101 | ) -> torch.tensor: 102 | 103 | # We want to know the radius of the poisson disk so that we get roughly N beams 104 | # 105 | # So we say N < (X*Y) / PI*r*r <=> sqrt(X*Y / PI*N) ~ r 106 | # 107 | 108 | poisson_radius = math.sqrt( 109 | (image_size_x * image_size_y) / (math.pi * num_beams) 110 | ) 111 | poisson_radius += poisson_radius / 4.0 112 | im = np.ones([image_size_x, image_size_y]) * poisson_radius 113 | num_samples, poisson_samples = 
fireflies.sampling.poisson.bridson(im) 114 | # print(len(poisson_samples)) 115 | poisson_samples = torch.tensor(poisson_samples, device=device) 116 | 117 | # Remove random points from poisson samples such that num_beams is correct again. 118 | # indices = torch.linspace(0, poisson_samples.shape[0] - 1, poisson_samples.shape[0], device=device) 119 | # indices = torch.multinomial(indices, num_beams, replacement=False).long() 120 | # poisson_samples = poisson_samples[indices] 121 | 122 | # From image space to 0 1 123 | poisson_samples /= torch.tensor([image_size_x, image_size_y], device=device) 124 | 125 | # Create empty tensor for copying 126 | temp = torch.ones([poisson_samples.shape[0], 3], device=device) * -1.0 127 | 128 | # Copy to temp and add 1 for z coordinate 129 | temp[:, 0:2] = poisson_samples 130 | # temp[:, 0:2] = poisson_samples 131 | 132 | # Project to world 133 | rays = fireflies.utils.math.transform_points(temp, intrinsic_matrix.inverse()) 134 | 135 | # rays = fireflies.utils.transforms.transform_points( 136 | # rays, 137 | # fireflies.utils.transforms.toMat4x4( 138 | # utils_math.getZTransform(0.5 * np.pi, intrinsic_matrix.device) 139 | # ), 140 | # ) 141 | 142 | # Normalize 143 | rays = rays / torch.linalg.norm(rays, dim=-1, keepdims=True) 144 | rays[:, 2] *= -1.0 145 | return rays 146 | 147 | def __init__( 148 | self, 149 | transformable: fireflies.entity.base.Transformable, 150 | ray_directions, 151 | perspective: torch.tensor, 152 | max_fov: float, 153 | near_clip: float = 0.01, 154 | far_clip: float = 1000.0, 155 | device: torch.cuda.device = torch.device("cuda"), 156 | ): 157 | super(Laser, self).__init__( 158 | transformable, perspective, max_fov, near_clip, far_clip, device 159 | ) 160 | self._rays = ray_directions.to(self.device) 161 | self.device = device 162 | 163 | def rays(self) -> torch.tensor: 164 | return fireflies.utils.math.transform_directions( 165 | self._rays, 166 | 
self._fireflies.transformable.transformable.Transformable.world(), 167 | ) 168 | 169 | def origin(self) -> torch.tensor: 170 | return self._fireflies.transformable.transformable.Transformable.world() 171 | 172 | def originPerRay(self) -> torch.tensor: 173 | return ( 174 | self._fireflies.transformable.transformable.Transformable.world()[0:3, 3] 175 | .unsqueeze(0) 176 | .repeat(self._rays.shape[0], 1) 177 | ) 178 | 179 | def near_clip(self) -> float: 180 | return self._near_clip 181 | 182 | def far_clip(self) -> float: 183 | return self._far_clip 184 | 185 | def initRandomRays(self): 186 | # Spawn random points in [-1.0, 1.0] 187 | spawned_points = torch.rand(self._rays.shape, device=self.device) * 2.0 - 1.0 188 | 189 | # Set Z to 1 190 | spawned_points[:, 2] = 1.0 191 | 192 | # Project to world 193 | rand_rays = self.projectNDCPointsToWorld(spawned_points) 194 | self._rays = self.normalize(rand_rays) 195 | 196 | def initPoissonDiskSamples(self, width, height, radius): 197 | return None 198 | 199 | def clamp_to_fov(self, clamp_val: float = 0.95, epsilon: float = 0.0001) -> None: 200 | # TODO: Check, if laser beam falls out of fov. If it does, clamp it back. 201 | # If randomize is set, spawn a new random laser inside NDC. 202 | # Else, clamp it to the edge. 203 | ndc_coords = self.projectRaysToNDC() 204 | ndc_coords[:, 0:2] = torch.clamp(ndc_coords[:, 0:2], 1 - clamp_val, clamp_val) 205 | clamped_rays = self.projectNDCPointsToWorld(ndc_coords) 206 | self._rays[:] = self.normalize(clamped_rays) 207 | 208 | def randomize_laser_out_of_bounds(self) -> None: 209 | # TODO: Check, if laser beam falls out of fov. If it does, spawn a new randomly in NDC in (-1, 1). 
210 | new_rays = self._rays.clone() 211 | 212 | # No need to transform as rays are in laser space anyway 213 | ndc_coords = fireflies.utils.transforms.transform_points( 214 | new_rays, self._perspective 215 | ) 216 | xy_coords = ndc_coords[:, 0:2] 217 | out_of_bounds_indices = ((xy_coords >= 1.0) | (xy_coords <= 0.0)).any(dim=1) 218 | 219 | out_of_bounds_points = ndc_coords[out_of_bounds_indices] 220 | 221 | if out_of_bounds_points.nelement() == 0: 222 | return 0 223 | 224 | new_ray_point = torch.rand(out_of_bounds_points.shape, device=self.device) 225 | new_ray_point[:, 2] = -1.0 226 | 227 | clamped_rays = self.projectNDCPointsToWorld(new_ray_point) 228 | new_rays[out_of_bounds_indices] = clamped_rays 229 | new_rays = self.normalize(new_rays) 230 | 231 | self._rays[:] = new_rays 232 | 233 | def randomize_camera_out_of_bounds(self, ndc_coords) -> None: 234 | new_rays = self._rays.clone() 235 | xy_coords = ndc_coords[:, 0:2] 236 | out_of_bounds_indices = ((xy_coords >= 1.0) | (xy_coords <= -1.0)).any(dim=1) 237 | out_of_bounds_points = ndc_coords[out_of_bounds_indices] 238 | 239 | if out_of_bounds_points.nelement() == 0: 240 | return 0 241 | 242 | new_ray_point = torch.rand(out_of_bounds_points.shape, device=self.device) 243 | new_ray_point[:, 2] = -1.0 244 | 245 | clamped_rays = self.projectNDCPointsToWorld(new_ray_point) 246 | new_rays[out_of_bounds_indices] = clamped_rays 247 | new_rays = self.normalize(new_rays) 248 | 249 | self._rays[:] = new_rays 250 | 251 | def normalize(self, tensor: torch.tensor) -> torch.tensor: 252 | return tensor / torch.linalg.norm(tensor, dim=-1, keepdims=True) 253 | 254 | def normalize_rays(self) -> None: 255 | self._rays[:] = self.normalize(self._rays) 256 | 257 | def setToWorld(self, to_world: torch.tensor) -> None: 258 | self._to_world = ( 259 | self._fireflies.transformable.transformable.Transformable.setWorld(to_world) 260 | ) 261 | 262 | def projectRaysToNDC(self) -> torch.tensor: 263 | # rays_in_world = 
fireflies.utils.transforms_torch.transform_directions(self._rays, self._to_world) 264 | FLIP_Y = torch.tensor( 265 | [ 266 | [1.0, 0.0, 0.0, 0.0], 267 | [0.0, -1.0, 0.0, 0.0], 268 | [0.0, 0.0, 1.0, 0.0], 269 | [0.0, 0.0, 0.0, 1.0], 270 | ], 271 | device=self.device, 272 | ) 273 | return fireflies.utils.math.transform_points( 274 | self._rays, self._perspective @ FLIP_Y 275 | ) 276 | 277 | def projectNDCPointsToWorld(self, points: torch.tensor) -> torch.tensor: 278 | FLIP_Y = torch.tensor( 279 | [ 280 | [1.0, 0.0, 0.0, 0.0], 281 | [0.0, -1.0, 0.0, 0.0], 282 | [0.0, 0.0, 1.0, 0.0], 283 | [0.0, 0.0, 0.0, 1.0], 284 | ], 285 | device=self.device, 286 | ) 287 | 288 | return fireflies.utils.math.transform_points( 289 | points, (self._perspective @ FLIP_Y).inverse() 290 | ) 291 | 292 | def generateTexture(self, sigma: float, texture_size: List[int]) -> torch.tensor: 293 | points = self.projectRaysToNDC()[:, 0:2] 294 | return fireflies.graphics.rasterization.rasterize_points( 295 | points.cpu(), sigma, texture_size.cpu(), device="cpu" 296 | ) 297 | 298 | def render_epipolar_lines( 299 | self, sigma: float, texture_size: torch.tensor 300 | ) -> torch.tensor: 301 | epipolar_min = self.originPerRay() + self._near_clip * self.rays() 302 | epipolar_max = self.originPerRay() + self._far_clip * self.rays() 303 | 304 | CAMERA_TO_WORLD = self._fireflies.entity.Transformable.world() 305 | WORLD_TO_CAMERA = CAMERA_TO_WORLD.inverse() 306 | 307 | epipolar_max = fireflies.utils.transforms.transform_points( 308 | epipolar_max, WORLD_TO_CAMERA 309 | ) 310 | epipolar_max = fireflies.utils.transforms.transform_points( 311 | epipolar_max, self._perspective 312 | )[:, 0:2] 313 | 314 | epipolar_min = fireflies.utils.transforms.transform_points( 315 | epipolar_min, WORLD_TO_CAMERA 316 | ) 317 | epipolar_min = fireflies.utils.transforms.transform_points( 318 | epipolar_min, self._perspective 319 | )[:, 0:2] 320 | 321 | lines = torch.stack([epipolar_min, epipolar_max], dim=1) 322 | 323 | return 
fireflies.graphics.rasterization.rasterize_lines( 324 | lines, sigma, texture_size 325 | ) 326 | 327 | def save(self, filepath: str): 328 | save_dict = { 329 | "rays": self._rays.detach().cpu().numpy().tolist(), 330 | "fov": self._fov, 331 | "near_clip": self._near_clip, 332 | "far_clip": self._far_clip, 333 | } 334 | 335 | with open(filepath, "w") as file: 336 | yaml.dump(save_dict, file) 337 | -------------------------------------------------------------------------------- /fireflies/utils/laser_estimation.py: -------------------------------------------------------------------------------- 1 | import mitsuba as mi 2 | import cv2 3 | import numpy as np 4 | import torch 5 | 6 | from scipy.spatial import ConvexHull 7 | import numpy as np 8 | 9 | import fireflies.sampling.poisson 10 | 11 | import fireflies.entity 12 | import fireflies.projection 13 | 14 | import fireflies.graphics.depth 15 | import fireflies.graphics.rasterization 16 | 17 | import fireflies.utils.math 18 | import fireflies.utils.transforms 19 | import fireflies.utils.intersections 20 | 21 | 22 | import math 23 | 24 | 25 | def probability_distribution_from_depth_maps( 26 | depth_maps: np.array, uniform_weight: float = 0.0 27 | ) -> np.array: 28 | 29 | variance_map = depth_maps.std(axis=0) 30 | variance_map += uniform_weight 31 | 32 | return variance_map 33 | 34 | 35 | def points_from_probability_distribution( 36 | prob_distribution: torch.tensor, num_samples: int 37 | ) -> torch.tensor: 38 | 39 | p = prob_distribution.flatten() 40 | chosen_points = p.multinomial(num_samples, replacement=False) 41 | 42 | return chosen_points 43 | 44 | 45 | def get_camera_direction(sensor, device: torch.cuda.device) -> torch.tensor: 46 | film = sensor.film() 47 | sampler = sensor.sampler() 48 | film_size = film.size() 49 | total_samples = 1 50 | 51 | if sampler.wavefront_size() != total_samples: 52 | sampler.seed(0, total_samples) 53 | 54 | # Enumerate discrete sample & pixel indices, and uniformly sample 55 | # 
positions within each pixel. 56 | # pos = mi.UInt32(points.split(split_size=1)) 57 | 58 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 59 | pos = mi.Vector2f(mi.Float(0.5), mi.Float(0.5)) 60 | 61 | # pos += sampler.next_2d() 62 | 63 | # Sample rays starting from the camera sensor 64 | rays, weights = sensor.sample_ray( 65 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0 66 | ) 67 | 68 | return rays.o.torch(), rays.d.torch() 69 | 70 | 71 | def get_camera_frustum(sensor, device: torch.cuda.device) -> torch.tensor: 72 | # film = sensor.film() 73 | sampler = sensor.sampler() 74 | # film_size = film.size() 75 | # total_samples = 4 76 | 77 | # if sampler.wavefront_size() != total_samples: 78 | # sampler.seed(0, total_samples) 79 | 80 | # Enumerate discrete sample & pixel indices, and uniformly sample 81 | # positions within each pixel. 82 | # pos = mi.UInt32(points.split(split_size=1)) 83 | 84 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 85 | pos = mi.Vector2f(mi.Float([0.0, 1.0, 0.0, 1.0]), mi.Float([0.0, 0.0, 1.0, 1.0])) 86 | 87 | # pos += sampler.next_2d() 88 | 89 | # Sample rays starting from the camera sensor 90 | rays, weights = sensor.sample_ray( 91 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0 92 | ) 93 | 94 | ray_origins = rays.o.torch() 95 | ray_directions = rays.d.torch() 96 | 97 | # x_transform = transforms.toMat4x4(utils_math.getXTransform(np.pi*0.5, ray_origins.device)) 98 | # ray_origins = transforms.transform_points(ray_origins, x_transform) 99 | # ray_directions = transforms.transform_directions(ray_directions, x_transform) 100 | 101 | return ray_origins, ray_directions 102 | 103 | 104 | def getRayFromSensor(sensor, ray_coordinate_in_ndc): 105 | sampler = sensor.sampler() 106 | 107 | # scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 108 | pos = mi.Vector2f( 109 | mi.Float([ray_coordinate_in_ndc[0]]), mi.Float([ray_coordinate_in_ndc[1]]) 110 | ) 111 | 112 | # Sample rays starting 
from the camera sensor 113 | rays, weights = sensor.sample_ray( 114 | time=0, sample1=sampler.next_1d(), sample2=pos, sample3=0 115 | ) 116 | 117 | return rays.o.torch(), rays.d.torch() 118 | 119 | 120 | def create_rays(sensor, points) -> torch.tensor: 121 | film = sensor.film() 122 | sampler = sensor.sampler() 123 | film_size = film.size() 124 | total_samples = points.shape[0] 125 | 126 | if sampler.wavefront_size() != total_samples: 127 | sampler.seed(0, total_samples) 128 | 129 | # Enumerate discrete sample & pixel indices, and uniformly sample 130 | # positions within each pixel. 131 | pos = mi.UInt32(points.split(split_size=1)) 132 | 133 | scale = mi.Vector2f(1.0 / film_size[0], 1.0 / film_size[1]) 134 | pos = mi.Vector2f( 135 | mi.Float(pos % int(film_size[0])), mi.Float(pos // int(film_size[0])) 136 | ) 137 | 138 | # pos += sampler.next_2d() 139 | 140 | # Sample rays starting from the camera sensor 141 | rays, weights = sensor.sample_ray( 142 | time=0, sample1=sampler.next_1d(), sample2=pos * scale, sample3=0 143 | ) 144 | 145 | return rays.o.torch(), rays.d.torch() 146 | 147 | 148 | def laser_from_ndc_points( 149 | sensor, laser_origin, depth_maps, chosen_points, device: torch.cuda.device("cuda") 150 | ) -> torch.tensor: 151 | ray_origins, ray_directions = create_rays(sensor, chosen_points) 152 | 153 | # Get camera origin and direction 154 | camera_origin, camera_direction = get_camera_direction(sensor, device) 155 | 156 | camera_origin = sensor.world_transform().translation().torch() 157 | 158 | camera_direction = camera_direction / torch.linalg.norm( 159 | camera_direction, dim=-1, keepdims=True 160 | ) 161 | 162 | # Build plane from depth map 163 | plane_origin = camera_origin + camera_direction * depth_maps.mean() 164 | plane_normal = -camera_direction 165 | 166 | # Compute intersections inbetween mean plane and randomly sampled rays 167 | intersection_distances = fireflies.utils.intersections.rayPlane( 168 | ray_origins, ray_directions, plane_origin, 
plane_normal 169 | ) 170 | world_points = ray_origins + ray_directions * intersection_distances 171 | 172 | laser_dir = world_points - laser_origin 173 | laser_dir = laser_dir / torch.linalg.norm(laser_dir, dim=-1, keepdims=True) 174 | return laser_dir 175 | 176 | 177 | def draw_lines(ax, rayOrigin, rayDirection, ray_length=1.0, color="g"): 178 | for i in range(rayDirection.shape[0]): 179 | ax.plot( 180 | [rayOrigin[i, 0], rayOrigin[i, 0] + ray_length * rayDirection[i, 0]], 181 | [rayOrigin[i, 1], rayOrigin[i, 1] + ray_length * rayDirection[i, 1]], 182 | [rayOrigin[i, 2], rayOrigin[i, 2] + ray_length * rayDirection[i, 2]], 183 | color=color, 184 | ) 185 | 186 | 187 | def generate_epipolar_constraints(scene, params, device): 188 | 189 | camera_sensor = scene.sensors()[0] 190 | 191 | projector_sensor = scene.sensors()[1] 192 | proj_xwidth, proj_ywidth = projector_sensor.film().size() 193 | 194 | ray_origins, ray_directions = get_camera_frustum(projector_sensor, device) 195 | camera_origins, camera_directions = get_camera_frustum(camera_sensor, device) 196 | 197 | near_clip = params["PerspectiveCamera_1.near_clip"] 198 | far_clip = params["PerspectiveCamera_1.far_clip"] 199 | # steps = 1 200 | # delta = (far_clip - near_clip / steps) 201 | 202 | projection_points = ray_origins + far_clip * ray_directions 203 | epipolar_points = projection_points 204 | 205 | # K = utils.build_projection_matrix(params['PerspectiveCamera.x_fov'], params['PerspectiveCamera.near_clip'], params['PerspectiveCamera.far_clip']) 206 | K = mi.perspective_projection( 207 | camera_sensor.film().size(), 208 | camera_sensor.film().crop_size(), 209 | camera_sensor.film().crop_offset(), 210 | params["PerspectiveCamera.x_fov"], 211 | params["PerspectiveCamera.near_clip"], 212 | params["PerspectiveCamera.far_clip"], 213 | ).matrix.torch()[0] 214 | CAMERA_WORLD = params["PerspectiveCamera.to_world"].matrix.torch()[0] 215 | # CAMERA_WORLD[0:3, 0:3] = CAMERA_WORLD[0:3, 0:3] @ 
utils_math.getYTransform(np.pi, CAMERA_WORLD.device) 216 | 217 | # mi.perspective_transformation(scene.sensors()[0].film.size()) 218 | 219 | epipolar_points = transforms.transform_points( 220 | epipolar_points, CAMERA_WORLD.inverse() 221 | ) 222 | epipolar_points = transforms.transform_points(epipolar_points, K)[:, 0:2] 223 | 224 | # Is in [0 -> 1] 225 | epi_points_np = epipolar_points.detach().cpu().numpy() 226 | 227 | hull = ConvexHull(epi_points_np) 228 | line_segments = epipolar_points[hull.vertices] 229 | 230 | # We could also calculate the fundamental matrix 231 | # and use this to estimate epipolar lines here 232 | # However, we 233 | # Find closest point between min and max 234 | # Replace this point by the epipolar minimum 235 | # This gives us the convex hull of the epipolar constraints 236 | # In clockwise order 237 | 238 | camera_size = np.array(camera_sensor.film().crop_size()) 239 | # camera_size = camera_size[[1, 0]] # swap image size to Y,X 240 | 241 | epi_points_np = line_segments.cpu().numpy() 242 | # epi_points_np = epi_points_np[:, [1, 0]] 243 | epi_points_np *= camera_size 244 | 245 | image = np.zeros(camera_size[[1, 0]], dtype=np.uint8) 246 | image = cv2.fillPoly(image, [epi_points_np.astype(int)], color=1) 247 | # cv2.imshow("Epipolar Image", image * 255) 248 | # cv2.waitKey(0) 249 | 250 | return torch.from_numpy(image).to(device) 251 | 252 | 253 | def initialize_laser( 254 | mitsuba_scene, mitsuba_params, firefly_scene, config, mode, device 255 | ): 256 | projector_sensor = mitsuba_scene.sensors()[1] 257 | 258 | near_clip = mitsuba_scene.sensors()[1].near_clip() 259 | far_clip = mitsuba_scene.sensors()[1].far_clip() 260 | laser_fov = float(mitsuba_params["PerspectiveCamera_1.x_fov"][0]) 261 | near_clip = mitsuba_scene.sensors()[1].near_clip() 262 | 263 | radians = math.pi / 180.0 264 | 265 | image_size = torch.tensor(mitsuba_scene.sensors()[1].film().size(), device=device) 266 | LASER_K = mi.perspective_projection( 267 | 
projector_sensor.film().size(), 268 | projector_sensor.film().crop_size(), 269 | projector_sensor.film().crop_offset(), 270 | laser_fov, 271 | near_clip, 272 | far_clip, 273 | ).matrix.torch()[0] 274 | n_beams = config.n_beams 275 | 276 | local_laser_dir = None 277 | if mode == "RANDOM": 278 | local_laser_dir = fireflies.projection.Laser.generate_random_rays( 279 | num_beams=n_beams, intrinsic_matrix=LASER_K, device=device 280 | ) 281 | elif mode == "POISSON": 282 | local_laser_dir = fireflies.projection.Laser.generate_blue_noise_rays( 283 | image_size_x=image_size[0], 284 | image_size_y=image_size[1], 285 | num_beams=n_beams, 286 | intrinsic_matrix=LASER_K, 287 | device=device, 288 | ) 289 | elif mode == "GRID": 290 | grid_width = int(math.sqrt(config.n_beams)) 291 | local_laser_dir = fireflies.projection.Laser.generate_uniform_rays_by_count( 292 | num_beams_x=grid_width, 293 | num_beams_y=grid_width, 294 | intrinsic_matrix=LASER_K, 295 | device=device, 296 | ) 297 | elif mode == "SMARTY": 298 | # Doesnt work, IDK why 299 | constraint_map = generate_epipolar_constraints( 300 | mitsuba_scene, mitsuba_params, device 301 | ) 302 | 303 | # Generate random depth maps by uniformly sampling from scene parameter ranges 304 | # print(config.n_depthmaps) 305 | depth_maps = fireflies.graphics.depth.random_depth_maps( 306 | firefly_scene, mitsuba_scene, num_maps=config.n_depthmaps 307 | ) 308 | 309 | # Given depth maps, generate probability distribution 310 | variance_map = probability_distribution_from_depth_maps( 311 | depth_maps, config.variational_epsilon 312 | ) 313 | variance_map = fireflies.utils.math.normalize(variance_map) 314 | vm = (variance_map.cpu().numpy() * 255).astype(np.uint8) 315 | vm = cv2.applyColorMap(vm, cv2.COLORMAP_INFERNO) 316 | 317 | # cv2.imshow("Variance Map", vm) 318 | # cv2.waitKey(0) 319 | 320 | # Final multiplication and normalization 321 | final_sampling_map = variance_map # * constraint_map 322 | final_sampling_map /= final_sampling_map.sum() 
323 | 324 | # Gotta flip this in y direction, since apparently I can't program 325 | # final_sampling_map = torch.fliplr(final_sampling_map) 326 | # final_sampling_map = torch.flip(final_sampling_map, (0,)) 327 | 328 | # sample points for laser rays 329 | 330 | min_radius = config.smarty_min_radius 331 | max_radius = config.smarty_max_radius 332 | normalized_sampling = 1 - fireflies.utils.math.normalize(final_sampling_map) 333 | normalized_sampling = ( 334 | min_radius + (max_radius - min_radius) * normalized_sampling 335 | ) 336 | n_points, points = fireflies.sampling.poisson.bridson( 337 | normalized_sampling.detach().cpu().numpy(), 50 338 | ) 339 | points = torch.from_numpy(points).to(device).floor().int() 340 | chosen_points = points[:, 0] * final_sampling_map.shape[1] + points[:, 1] 341 | 342 | # chosen_points = points_from_probability_distribution(final_sampling_map, config.n_beams) 343 | 344 | vm = variance_map.cpu().numpy() 345 | cp = chosen_points.cpu().numpy() 346 | cm = constraint_map.cpu().numpy() 347 | """ 348 | if config.save_images: 349 | vm = (vm * 255).astype(np.uint8) 350 | vm = cv2.applyColorMap(vm, cv2.COLORMAP_INFERNO) 351 | vm.reshape(-1, 3)[cp, :] = ~vm.reshape(-1, 3)[cp, :] 352 | cv2.imwrite("sampling_map.png", vm) 353 | cm = cm * 255 354 | cv2.imwrite("constraint_map.png", cm) 355 | """ 356 | 357 | laser_world = firefly_scene.projector.world() 358 | laser_origin = laser_world[0:3, 3] 359 | # Sample directions of laser beams from variance map 360 | laser_dir = laser_from_ndc_points( 361 | mitsuba_scene.sensors()[0], 362 | laser_origin, 363 | depth_maps, 364 | chosen_points, 365 | device=device, 366 | ) 367 | 368 | # Apply inverse rotation of the projector, such that we get a normalized direction 369 | # The laser direction up until now is in world coordinates! 
370 | local_laser_dir = transforms.transform_directions( 371 | laser_dir, laser_world.inverse() 372 | ) 373 | 374 | # local_laser_dir = transforms.transform_points( 375 | # local_laser_dir, 376 | # transforms.toMat4x4(utils_math.getZTransform(0.5 * np.pi, device)), 377 | # ) 378 | 379 | # local_laser_dir = transforms.transform_points( 380 | # local_laser_dir, 381 | # transforms.toMat4x4(utils_math.getYTransform(-0.5 * np.pi, device)), 382 | # ) 383 | 384 | return fireflies.projection.Laser( 385 | firefly_scene.projector, 386 | local_laser_dir, 387 | LASER_K, 388 | laser_fov, 389 | near_clip, 390 | far_clip, 391 | ) 392 | -------------------------------------------------------------------------------- /fireflies/scene.py: -------------------------------------------------------------------------------- 1 | import mitsuba as mi 2 | 3 | import torch 4 | import fireflies.entity 5 | import fireflies.utils 6 | import fireflies.emitter 7 | import fireflies.material 8 | 9 | from typing import List 10 | 11 | 12 | class Scene: 13 | MESH_KEYS = ["mesh", "ply"] 14 | CAM_KEYS = ["camera", "perspective", "perspectivecamera"] 15 | PROJ_KEYS = ["projector"] 16 | MAT_KEYS = ["mat", "bsdf"] 17 | LIGHT_KEYS = ["light", "spot"] 18 | TEX_KEYS = ["tex"] 19 | 20 | def __init__( 21 | self, 22 | mitsuba_params, 23 | device: torch.cuda.device = torch.device("cuda"), 24 | ): 25 | # Here, only objects are saved, that have a "randomizable"-tag inside the yaml file. 
26 | self._meshes = [] 27 | self._projector = None 28 | self._camera = None 29 | self._lights = [] 30 | self._curves = [] 31 | self._materials = [] 32 | 33 | self._transformables = [] 34 | 35 | self._device = device 36 | 37 | self._mitsuba_params = mitsuba_params 38 | 39 | self.init_from_params(self._mitsuba_params) 40 | 41 | def device(self) -> torch.cuda.device: 42 | return self._device 43 | 44 | def mesh_at(self, index: int) -> fireflies.entity.Transformable: 45 | return self._meshes[index] 46 | 47 | def meshes(self) -> fireflies.entity.Transformable: 48 | return self._meshes 49 | 50 | def get_mesh(self, name: str) -> fireflies.entity.Transformable: 51 | for mesh in self._meshes: 52 | if mesh.name() == name: 53 | return mesh 54 | 55 | return None 56 | 57 | def mesh(self, name: str) -> fireflies.entity.Transformable: 58 | return self.get_mesh(name) 59 | 60 | def light_at(self, index: int) -> fireflies.entity.Transformable: 61 | return self._lights[index] 62 | 63 | def lights(self) -> fireflies.entity.Transformable: 64 | return self._lights 65 | 66 | def get_light(self, name: str) -> fireflies.entity.Transformable: 67 | for light in self._lights: 68 | if light.name() == name: 69 | return light 70 | 71 | return None 72 | 73 | def light(self, name: str) -> fireflies.entity.Transformable: 74 | return self.get_light(name) 75 | 76 | def material_at(self, index: int) -> fireflies.entity.Transformable: 77 | return self._materials[index] 78 | 79 | def materials(self) -> fireflies.entity.Transformable: 80 | return self._materials 81 | 82 | def get_material(self, name: str) -> fireflies.entity.Transformable: 83 | for mesh in self._materials: 84 | if mesh.name() == name: 85 | return mesh 86 | 87 | return None 88 | 89 | def material(self, name: str) -> fireflies.entity.Transformable: 90 | return self.get_material(name) 91 | 92 | def init_from_params(self, mitsuba_params) -> None: 93 | # Get all scene keys 94 | param_keys = [key.split(".")[0] for key in mitsuba_params.keys()] 
95 | 96 | # Remove multiples 97 | param_keys = set(param_keys) 98 | param_keys = sorted(param_keys) 99 | 100 | for key in param_keys: 101 | # Check if its a mesh 102 | if any(MESH_KEY.lower() in key.lower() for MESH_KEY in self.MESH_KEYS): 103 | self.load_mesh(key) 104 | continue 105 | elif any(CAMERA_KEY.lower() in key.lower() for CAMERA_KEY in self.CAM_KEYS): 106 | self.load_camera(key) 107 | continue 108 | elif any(PROJ_KEY.lower() in key.lower() for PROJ_KEY in self.PROJ_KEYS): 109 | self.load_projector(key) 110 | continue 111 | elif any(LIGHT_KEY.lower() in key.lower() for LIGHT_KEY in self.LIGHT_KEYS): 112 | self.load_light(key) 113 | continue 114 | elif any(MAT_KEY.lower() in key.lower() for MAT_KEY in self.MAT_KEYS): 115 | self.load_material(key) 116 | continue 117 | 118 | def load_mesh(self, base_key: str): 119 | # Gotta compute the centroid here, as mitsuba does not have a world transform for meshes 120 | vertices = torch.tensor( 121 | self._mitsuba_params[base_key + ".vertex_positions"], device=self._device 122 | ).reshape(-1, 3) 123 | centroid = vertices.sum(dim=0, keepdim=True) / vertices.shape[0] 124 | 125 | aligned_vertices = vertices - centroid 126 | 127 | transformable_mesh = fireflies.entity.Mesh( 128 | base_key, aligned_vertices, self._device 129 | ) 130 | transformable_mesh.set_centroid(centroid) 131 | 132 | self._meshes.append(transformable_mesh) 133 | 134 | def load_camera(self, base_key: str) -> None: 135 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch() 136 | to_world = to_world.squeeze().to(self._device) 137 | transformable_camera = fireflies.entity.Transformable(base_key, self._device) 138 | transformable_camera.set_world(to_world) 139 | transformable_camera.set_randomizable(False) 140 | self._camera = transformable_camera 141 | 142 | def load_projector(self, base_key: str) -> None: 143 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch() 144 | to_world = to_world.squeeze().to(self._device) 145 | 
transformable_projector = fireflies.entity.Transformable(base_key, self._device) 146 | transformable_projector.set_world(to_world) 147 | transformable_projector.set_randomizable(False) 148 | self._projector = transformable_projector 149 | 150 | def load_light(self, base_key: str) -> None: 151 | new_light = fireflies.emitter.Light(base_key, device=self._device) 152 | 153 | if base_key + ".to_world" in self._mitsuba_params.keys(): 154 | to_world = self._mitsuba_params[base_key + ".to_world"].matrix.torch() 155 | to_world = to_world.squeeze().to(self._device) 156 | new_light.set_world(to_world) 157 | 158 | light_keys = [] 159 | for key in self._mitsuba_params.keys(): 160 | if base_key in key: 161 | light_keys.append(key) 162 | 163 | for key in light_keys: 164 | key_without_base = ".".join(key.split(".")[1:]) 165 | value = self._mitsuba_params[key] 166 | 167 | if type(value) == mi.Transform4f: 168 | continue 169 | 170 | if isinstance(value, mi.Float) or isinstance(value, float): 171 | new_light.add_float_key(key_without_base, value, value) 172 | elif len(value) == 3: 173 | value = value.torch().squeeze() 174 | new_light.add_vec3_key(key_without_base, value, value) 175 | 176 | new_light.set_randomizable(False) 177 | self._lights.append(new_light) 178 | 179 | def load_material(self, base_key: str) -> None: 180 | new_material = fireflies.material.Material(base_key, device=self._device) 181 | 182 | material_keys = [] 183 | for key in self._mitsuba_params.keys(): 184 | if base_key in key: 185 | material_keys.append(key) 186 | 187 | for key in material_keys: 188 | key_without_base = ".".join(key.split(".")[1:]) 189 | value = self._mitsuba_params[key] 190 | 191 | if type(value) == mi.Transform4f or isinstance(value, mi.ScalarTransform3f): 192 | continue 193 | 194 | if isinstance(value, mi.Float) or isinstance(value, float): 195 | new_material.add_float_key(key_without_base, value, value) 196 | elif len(value) == 3: 197 | value = value.torch().squeeze() 198 | 
new_material.add_vec3_key(key_without_base, value, value) 199 | 200 | new_material.set_randomizable(False) 201 | self._materials.append(new_material) 202 | 203 | def train(self) -> None: 204 | # We first randomize all of our objects 205 | for mesh in self._meshes: 206 | mesh.train() 207 | 208 | for light in self._lights: 209 | light.train() 210 | 211 | for material in self._materials: 212 | material.train() 213 | 214 | if self._camera is not None: 215 | self._camera.train() 216 | 217 | if self._projector is not None: 218 | self._projector.train() 219 | 220 | def eval(self) -> None: 221 | # We first randomize all of our objects 222 | for mesh in self._meshes: 223 | mesh.eval() 224 | 225 | for light in self._lights: 226 | light.eval() 227 | 228 | for material in self._materials: 229 | material.eval() 230 | 231 | if self._camera is not None: 232 | self._camera.eval() 233 | 234 | if self._projector is not None: 235 | self._projector.eval() 236 | 237 | def load_curve(self, path: str, name: str = "Curve") -> None: 238 | curve = fireflies.utils.importBlenderNurbsObj(path) 239 | transformable_curve = fireflies.entity.Curve(name, curve, self._device) 240 | 241 | self.curves.append(transformable_curve) 242 | 243 | def update_meshes(self) -> None: 244 | for mesh in self._meshes: 245 | if not mesh.randomizable(): 246 | continue 247 | 248 | vertex_data = mesh.get_randomized_vertices() 249 | self._mitsuba_params[mesh.name() + ".vertex_positions"] = mi.Float32( 250 | vertex_data.flatten() 251 | ) 252 | 253 | def update_camera(self) -> None: 254 | if not self._camera.randomizable(): 255 | return 256 | 257 | self._mitsuba_params[self._camera.name() + ".to_world"] = mi.Transform4f( 258 | self._camera.world().tolist() 259 | ) 260 | 261 | float_dict = self._camera.get_randomized_float_attributes() 262 | vec3_dict = self._camera.get_randomized_vec3_attributes() 263 | 264 | for key, value in float_dict.items(): 265 | joined_key = self._camera.name() + "." 
+ key 266 | temp_type = type(self._mitsuba_params[joined_key]) 267 | self._mitsuba_params[joined_key] = temp_type(value.item()) 268 | 269 | for key, value in vec3_dict.items(): 270 | joined_key = self._camera.name() + "." + key 271 | temp_type = type(self._mitsuba_params[joined_key]) 272 | self._mitsuba_params[self._camera.name() + "." + key] = temp_type( 273 | value.tolist() 274 | ) 275 | 276 | def update_projector(self) -> None: 277 | if not self._projector.randomizable(): 278 | return 279 | 280 | self._mitsuba_params[self._projector.name() + ".to_world"] = mi.Transform4f( 281 | self._projector.world().tolist() 282 | ) 283 | 284 | float_dict = self._projector.get_randomized_float_attributes() 285 | vec3_dict = self._projector.get_randomized_vec3_attributes() 286 | 287 | for key, value in float_dict.items(): 288 | joined_key = self._projector.name() + "." + key 289 | temp_type = type(self._mitsuba_params[joined_key]) 290 | self._mitsuba_params[joined_key] = temp_type(value.item()) 291 | 292 | for key, value in vec3_dict.items(): 293 | joined_key = self._projector.name() + "." + key 294 | temp_type = type(self._mitsuba_params[joined_key]) 295 | self._mitsuba_params[self._projector.name() + "." + key] = temp_type( 296 | value.tolist() 297 | ) 298 | 299 | def update_lights(self) -> None: 300 | for light in self._lights: 301 | if not light.randomizable(): 302 | continue 303 | 304 | if light.name() + ".to_world" in self._mitsuba_params.keys(): 305 | self._mitsuba_params[light.name() + ".to_world"] = mi.Transform4f( 306 | light.world().tolist() 307 | ) 308 | 309 | float_dict = light.get_randomized_float_attributes() 310 | vec3_dict = light.get_randomized_vec3_attributes() 311 | 312 | for key, value in float_dict.items(): 313 | joined_key = light.name() + "." + key 314 | temp_type = type(self._mitsuba_params[joined_key]) 315 | self._mitsuba_params[joined_key] = temp_type(value.item()) 316 | 317 | for key, value in vec3_dict.items(): 318 | joined_key = light.name() + "." 
+ key 319 | temp_type = type(self._mitsuba_params[joined_key]) 320 | self._mitsuba_params[light.name() + "." + key] = temp_type( 321 | value.tolist() 322 | ) 323 | 324 | def update_materials(self) -> None: 325 | for material in self._materials: 326 | if not material.randomizable(): 327 | continue 328 | 329 | float_dict = material.get_randomized_float_attributes() 330 | vec3_dict = material.get_randomized_vec3_attributes() 331 | 332 | for key, value in float_dict.items(): 333 | joined_key = material.name() + "." + key 334 | temp_type = type(self._mitsuba_params[joined_key]) 335 | self._mitsuba_params[joined_key] = temp_type(value.item()) 336 | 337 | for key, value in vec3_dict.items(): 338 | joined_key = material.name() + "." + key 339 | temp_type = type(self._mitsuba_params[joined_key]) 340 | self._mitsuba_params[material.name() + "." + key] = temp_type( 341 | value.tolist() 342 | ) 343 | 344 | def randomize_list(self, entity_list: List[fireflies.entity.Transformable]) -> None: 345 | # First find parent objects, i.e. 
child is none 346 | parent_objects = [] 347 | for entity in entity_list: 348 | if entity.parent() is None: 349 | parent_objects.append(entity) 350 | 351 | # Now iterate through every parent object and iteratively call each child randomization function 352 | for entity in parent_objects: 353 | entity.randomize() 354 | 355 | iterator_child = entity.child() 356 | while iterator_child is not None: 357 | iterator_child.randomize() 358 | iterator_child = iterator_child.child() 359 | 360 | def randomize(self) -> None: 361 | # We first randomize all of our objects 362 | self.randomize_list(self._meshes) 363 | self.randomize_list(self._lights) 364 | self.randomize_list(self._materials) 365 | 366 | if self._camera is not None: 367 | self._camera.randomize() 368 | 369 | if self._projector is not None: 370 | self._projector.randomize() 371 | 372 | # And then copy the updates to the mitsuba parameters 373 | self.update_meshes() 374 | 375 | if self._camera is not None: 376 | self.update_camera() 377 | 378 | if self._projector is not None: 379 | self.update_projector() 380 | self.update_lights() 381 | self.update_materials() 382 | 383 | # We finally update the mitsuba scene graph itself 384 | self._mitsuba_params.update() 385 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mitsuba as mi 3 | import numpy as np 4 | import kornia 5 | 6 | mi.set_variant("cuda_ad_rgb") 7 | 8 | import torch 9 | import fireflies 10 | import fireflies.sampling 11 | import fireflies.projection.laser 12 | import fireflies.postprocessing 13 | import fireflies.utils.math 14 | import fireflies.graphics.depth 15 | import os 16 | 17 | from tqdm import tqdm 18 | 19 | 20 | def render_to_numpy(render): 21 | render = torch.clamp(render.torch(), 0, 1)[:, :, [2, 1, 0]].cpu().numpy() 22 | return render 23 | 24 | 25 | if __name__ == "__main__": 26 | path = 
"examples/scenes/realistic_vf/vocalfold.xml" 27 | dataset_path = "../LearningFromFireflies/fireflies_dataset_v4/" 28 | 29 | mitsuba_scene = mi.load_file(path, parallel=False) 30 | mitsuba_params = mi.traverse(mitsuba_scene) 31 | ff_scene = fireflies.Scene(mitsuba_params) 32 | ff_scene._camera._name = "Camera" 33 | 34 | projector_sensor = mitsuba_scene.sensors()[1] 35 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"] 36 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"] 37 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"] 38 | 39 | K_PROJECTOR = mi.perspective_projection( 40 | projector_sensor.film().size(), 41 | projector_sensor.film().crop_size(), 42 | projector_sensor.film().crop_offset(), 43 | x_fov, 44 | near_clip, 45 | far_clip, 46 | ).matrix.torch()[0] 47 | 48 | # laser_rays = fireflies.projection.Laser.generate_uniform_rays( 49 | # 0.0275, 18, 18, device=ff_scene.device() 50 | # ) 51 | laser_rays = fireflies.projection.Laser.generate_uniform_rays( 52 | 0.0275, 18, 18, device=ff_scene.device() 53 | ) 54 | 55 | laser = fireflies.projection.Laser( 56 | ff_scene._projector, 57 | laser_rays, 58 | K_PROJECTOR, 59 | x_fov, 60 | near_clip, 61 | far_clip, 62 | device=ff_scene.device(), 63 | ) 64 | texture = laser.generateTexture( 65 | 10.0, torch.tensor([500, 500], device=ff_scene.device()) 66 | ) 67 | texture = texture.sum(dim=0) 68 | 69 | texture = kornia.filters.gaussian_blur2d( 70 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3) 71 | ).squeeze() 72 | texture = torch.stack( 73 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)] 74 | ) 75 | texture = torch.movedim(texture, 0, -1) 76 | 77 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy()) 78 | 79 | vocalfold_mesh = ff_scene.mesh("mesh-Vocalfold") 80 | vocalfold_mesh.scale_x(1.0, 3.0) 81 | vocalfold_mesh.scale_z(1.0, 3.0) 82 | vocalfold_mesh.rotate_y(-0.2, 0.2) 83 | vocalfold_mesh.translate_y(-0.05, -0.05) 84 | 
vocalfold_mesh.add_train_animation_from_obj("examples/scenes/vocalfold_new/train/") 85 | vocalfold_mesh.add_eval_animation_from_obj("examples/scenes/vocalfold_new/test/") 86 | 87 | a = vocalfold_mesh._anim_data_train 88 | scale_mat = torch.eye(4, device=ff_scene.device()) * 0.054 89 | scale_mat[3, 3] = 1.0 90 | for i in range(a.shape[0]): 91 | a[i] = fireflies.utils.math.transform_points(a[i], scale_mat) 92 | 93 | larynx_mesh = ff_scene.mesh("mesh-Larynx") 94 | larynx_mesh.scale_x(1.0, 4.0) 95 | larynx_mesh.scale_z(1.0, 2.0) 96 | 97 | # Randomization of Mucosa material 98 | material = ff_scene.material("mat-Mucosa") 99 | material.add_float_key("brdf_0.clearcoat.value", 0.0, 1.0) 100 | material.add_float_key("brdf_0.clearcoat_gloss.value", 0.0, 1.0) 101 | material.add_float_key("brdf_0.metallic.value", 0.0, 0.5) 102 | material.add_float_key("brdf_0.specular", 0.0, 1.0) 103 | material.add_float_key("brdf_0.roughness.value", 0.0, 1.0) 104 | material.add_float_key("brdf_0.anisotropic.value", 0.0, 1.0) 105 | material.add_float_key("brdf_0.sheen.value", 0.0, 0.5) 106 | material.add_float_key("brdf_0.spec_trans.value", 0.0, 0.4) 107 | material.add_float_key("brdf_0.flatness.value", 0.0, 1.0) 108 | 109 | # Camera Randomization 110 | ff_scene._camera.translate_x(-0.15, 0.15) 111 | ff_scene._camera.translate_y(-0.15, 0.0) 112 | ff_scene._camera.translate_z(-0.15, 0.15) 113 | ff_scene._camera.rotate_x(-0.2, 0.2) 114 | ff_scene._camera.rotate_y(-0.5, 0.5) 115 | ff_scene._camera.rotate_z((-np.pi / 2.0) - 0.5, (-np.pi / 2.0) + 0.5) 116 | ff_scene._camera.add_float_key("x_fov", 70.0, 130.0) 117 | 118 | # Light Randomization 119 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler( 120 | 0.1, 10.0, device=ff_scene.device() 121 | ) 122 | light = ff_scene.light("emit-Spot") 123 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler) 124 | 125 | texture = ( 126 | mitsuba_params["mat-Mucosa.brdf_0.base_color.data"] 127 | .torch() 128 | .moveaxis(-1, 0) 
129 | .shape 130 | ) 131 | 132 | lerp_sampler = fireflies.sampling.NoiseTextureLerpSampler( 133 | color_a=torch.tensor([0.0, 0.0, 0.0], device=ff_scene.device()), 134 | color_b=torch.tensor([1.0, 1.0, 1.0], device=ff_scene.device()), 135 | texture_shape=(1024, 1024), 136 | ) 137 | 138 | post_process_funcs = [ 139 | fireflies.postprocessing.GaussianBlur((3, 3), (5, 5), 0.5), 140 | fireflies.postprocessing.ApplySilhouette(), 141 | fireflies.postprocessing.WhiteNoise(0.0, 0.05, 0.5), 142 | ] 143 | post_processor = fireflies.postprocessing.PostProcessor(post_process_funcs) 144 | spp_sampler = fireflies.sampling.AnimationSampler(1, 100, 1, 100) 145 | ff_scene.train() 146 | count = 0 147 | while count != 10000: 148 | lerp_sampler._color_a = torch.rand((3), device=ff_scene.device()) 149 | lerp_sampler._color_b = torch.rand((3), device=ff_scene.device()) 150 | mucosa_texture = lerp_sampler.sample() 151 | mitsuba_params["mat-Mucosa.brdf_0.base_color.data"] = mi.TensorXf( 152 | mucosa_texture.moveaxis(0, -1).cpu().numpy() 153 | ) 154 | 155 | ff_scene.randomize() 156 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample()) 157 | render = render_to_numpy(render) 158 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY) 159 | render = post_processor.post_process(render) 160 | 161 | segmentation = ( 162 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene) 163 | .cpu() 164 | .numpy() 165 | .astype(np.float32) 166 | ) 167 | 168 | if segmentation.max() == 0: 169 | continue 170 | 171 | seg_test = segmentation.copy() 172 | seg_test[seg_test == 1] = 0 173 | seg_test[seg_test == 2] = 1 174 | seg_test = 1 - seg_test 175 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( 176 | seg_test.astype(np.uint8) 177 | ) 178 | 179 | if n_labels > 3: 180 | continue 181 | 182 | segmentation_map = segmentation / segmentation.max() 183 | final = cv2.hconcat([render, segmentation_map]) 184 | 185 | cv2.imwrite( 186 | os.path.join(dataset_path, 
"train/images/{0:05d}.png".format(count)), 187 | (render * 255).astype(np.uint8), 188 | ) 189 | cv2.imwrite( 190 | os.path.join(dataset_path, "train/segmentation/{0:05d}.png".format(count)), 191 | segmentation.astype(np.uint8), 192 | ) 193 | count += 1 194 | 195 | count = 0 196 | while count != 500: 197 | lerp_sampler._color_a = torch.rand((3), device=ff_scene.device()) 198 | lerp_sampler._color_b = torch.rand((3), device=ff_scene.device()) 199 | mucosa_texture = lerp_sampler.sample() 200 | mitsuba_params["mat-Mucosa.brdf_0.base_color.data"] = mi.TensorXf( 201 | mucosa_texture.moveaxis(0, -1).cpu().numpy() 202 | ) 203 | 204 | ff_scene.randomize() 205 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample()) 206 | render = render_to_numpy(render) 207 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY) 208 | render = post_processor.post_process(render) 209 | 210 | segmentation = ( 211 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene) 212 | .cpu() 213 | .numpy() 214 | .astype(np.float32) 215 | ) 216 | 217 | if segmentation.max() == 0: 218 | continue 219 | 220 | seg_test = segmentation.copy() 221 | seg_test[seg_test == 1] = 0 222 | seg_test[seg_test == 2] = 1 223 | seg_test = 1 - seg_test 224 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( 225 | seg_test.astype(np.uint8) 226 | ) 227 | 228 | if n_labels > 3: 229 | continue 230 | 231 | segmentation_map = segmentation / segmentation.max() 232 | final = cv2.hconcat([render, segmentation_map]) 233 | 234 | cv2.imwrite( 235 | os.path.join(dataset_path, "eval/images/{0:05d}.png".format(count)), 236 | (render * 255).astype(np.uint8), 237 | ) 238 | cv2.imwrite( 239 | os.path.join(dataset_path, "eval/segmentation/{0:05d}.png".format(count)), 240 | segmentation.astype(np.uint8), 241 | ) 242 | count += 1 243 | 244 | ''' 245 | if __name__ == "__main__": 246 | path = "examples/scenes/vocalfold_new/vocalfold.xml" 247 | 248 | mitsuba_scene = mi.load_file(path, parallel=False) 249 | 
mitsuba_params = mi.traverse(mitsuba_scene) 250 | ff_scene = fireflies.Scene(mitsuba_params) 251 | ff_scene._camera._name = "Camera" 252 | 253 | projector_sensor = mitsuba_scene.sensors()[1] 254 | x_fov = mitsuba_params["PerspectiveCamera_1.x_fov"] 255 | near_clip = mitsuba_params["PerspectiveCamera_1.near_clip"] 256 | far_clip = mitsuba_params["PerspectiveCamera_1.far_clip"] 257 | 258 | dataset_path = "../LearningFromFireflies/fireflies_dataset_v3/" 259 | 260 | K_PROJECTOR = mi.perspective_projection( 261 | projector_sensor.film().size(), 262 | projector_sensor.film().crop_size(), 263 | projector_sensor.film().crop_offset(), 264 | x_fov, 265 | near_clip, 266 | far_clip, 267 | ).matrix.torch()[0] 268 | 269 | # laser_rays = fireflies.projection.Laser.generate_uniform_rays( 270 | # 0.0275, 18, 18, device=ff_scene.device() 271 | # ) 272 | laser_rays = fireflies.projection.Laser.generate_uniform_rays( 273 | 0.0275, 18, 18, device=ff_scene.device() 274 | ) 275 | 276 | laser = fireflies.projection.Laser( 277 | ff_scene._projector, 278 | laser_rays, 279 | K_PROJECTOR, 280 | x_fov, 281 | near_clip, 282 | far_clip, 283 | device=ff_scene.device(), 284 | ) 285 | texture = laser.generateTexture( 286 | 10.0, torch.tensor([500, 500], device=ff_scene.device()) 287 | ) 288 | texture = texture.sum(dim=0) 289 | 290 | texture = kornia.filters.gaussian_blur2d( 291 | texture.unsqueeze(0).unsqueeze(0), (5, 5), (3, 3) 292 | ).squeeze() 293 | texture = torch.stack( 294 | [torch.zeros_like(texture), texture, torch.zeros_like(texture)] 295 | ) 296 | texture = torch.movedim(texture, 0, -1) 297 | 298 | mitsuba_params["tex.data"] = mi.TensorXf(texture.cpu().numpy()) 299 | 300 | vocalfold_mesh = ff_scene.mesh("mesh-VocalFold") 301 | vocalfold_mesh.scale_x(0.75, 3.0) 302 | vocalfold_mesh.scale_y(1.1, 2.0) 303 | vocalfold_mesh.rotate_y(-0.2, 0.2) 304 | vocalfold_mesh.add_train_animation_from_obj("examples/scenes/vocalfold_new/train/") 305 | 
vocalfold_mesh.add_eval_animation_from_obj("examples/scenes/vocalfold_new/test/") 306 | 307 | a = vocalfold_mesh._anim_data_train 308 | scale_mat = torch.eye(4, device=ff_scene.device()) * 0.05 309 | scale_mat[3, 3] = 1.0 310 | rot_mat = fireflies.utils.math.toMat4x4( 311 | fireflies.utils.math.getXTransform(np.pi / 2.0, ff_scene.device()) 312 | ) 313 | for i in range(a.shape[0]): 314 | a[i] = fireflies.utils.math.transform_points(a[i], scale_mat) 315 | a[i] = fireflies.utils.math.transform_points(a[i], rot_mat) 316 | 317 | larynx_mesh = ff_scene.mesh("mesh-Larynx") 318 | larynx_mesh.scale_x(0.3, 1.2) 319 | larynx_mesh.rotate_y(-0.5, 0.5) 320 | # larynx_mesh.scale_z(1.0, 2.5) 321 | 322 | material = ff_scene.material("mat-Default OBJ") 323 | material.add_vec3_key( 324 | "brdf_0.base_color.value", 325 | torch.tensor([0.3, 0.3, 0.33], device=ff_scene.device()), 326 | torch.tensor([0.85, 0.85, 0.85], device=ff_scene.device()), 327 | ) 328 | 329 | for key in material.float_attributes(): 330 | if "sampling_rate" in key: 331 | continue 332 | 333 | material.add_float_key(key, 0.01, 0.99) 334 | 335 | scalar_to_vec3_sampler = fireflies.sampling.UniformScalarToVec3Sampler( 336 | 1.0, 80.0, device=ff_scene.device() 337 | ) 338 | light = ff_scene.light("emit-Spot") 339 | light.add_vec3_sampler("intensity.value", scalar_to_vec3_sampler) 340 | 341 | post_process_funcs = [ 342 | fireflies.postprocessing.GaussianBlur((3, 3), (5, 5), 0.5), 343 | fireflies.postprocessing.ApplySilhouette(), 344 | fireflies.postprocessing.WhiteNoise(0.0, 0.05, 0.5), 345 | ] 346 | post_processor = fireflies.postprocessing.PostProcessor(post_process_funcs) 347 | 348 | ff_scene._camera.translate_x(-0.5, 0.5) 349 | ff_scene._camera.translate_y(-0.5, 0.5) 350 | ff_scene._camera.translate_z(-0.5, 0.5) 351 | ff_scene._camera.add_float_key("x_fov", 20.0, 50.0) 352 | ff_scene.train() 353 | 354 | spp_sampler = fireflies.sampling.AnimationSampler(1, 100, 1, 100) 355 | spp_sampler.train() 356 | """ 357 | count = 0 
358 | while count != 10000: 359 | ff_scene.randomize() 360 | render = mi.render(mitsuba_scene, spp=spp_sampler.sample()) 361 | render = render_to_numpy(render) 362 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY) 363 | 364 | render = post_processor.post_process(render) 365 | 366 | segmentation = ( 367 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene) 368 | .cpu() 369 | .numpy() 370 | .astype(np.float32) 371 | ) 372 | 373 | if segmentation.max() == 0: 374 | continue 375 | 376 | seg_test = segmentation.copy() 377 | seg_test[seg_test == 2] = 1 378 | seg_test = 1 - seg_test 379 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( 380 | seg_test.astype(np.uint8) 381 | ) 382 | 383 | if n_labels > 3: 384 | continue 385 | 386 | segmentation_map = segmentation / segmentation.max() 387 | final = cv2.hconcat([render, segmentation_map]) 388 | 389 | cv2.imwrite( 390 | os.path.join(dataset_path, "train/images/{0:05d}.png".format(count)), 391 | (render * 255).astype(np.uint8), 392 | ) 393 | cv2.imwrite( 394 | os.path.join(dataset_path, "train/segmentation/{0:05d}.png".format(count)), 395 | segmentation.astype(np.uint8), 396 | ) 397 | count += 1 398 | """ 399 | 400 | count = 0 401 | while count != 500: 402 | ff_scene.randomize() 403 | render = mi.render(mitsuba_scene, spp=100) 404 | render = render_to_numpy(render) 405 | render = cv2.cvtColor(render, cv2.COLOR_RGB2GRAY) 406 | 407 | render = post_processor.post_process(render) 408 | 409 | segmentation = ( 410 | fireflies.graphics.depth.get_segmentation_from_camera(mitsuba_scene) 411 | .cpu() 412 | .numpy() 413 | .astype(np.float32) 414 | ) 415 | 416 | if segmentation.max() == 0: 417 | continue 418 | 419 | seg_test = segmentation.copy() 420 | seg_test[seg_test == 2] = 1 421 | seg_test = 1 - seg_test 422 | n_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( 423 | seg_test.astype(np.uint8) 424 | ) 425 | 426 | if n_labels > 3: 427 | continue 428 | 429 | segmentation_map = 
segmentation / segmentation.max() 430 | final = cv2.hconcat([render, segmentation_map]) 431 | 432 | cv2.imwrite( 433 | os.path.join(dataset_path, "eval/images/{0:05d}.png".format(count)), 434 | (render * 255).astype(np.uint8), 435 | ) 436 | cv2.imwrite( 437 | os.path.join(dataset_path, "eval/segmentation/{0:05d}.png".format(count)), 438 | segmentation.astype(np.uint8), 439 | ) 440 | count += 1 441 | ''' 442 | -------------------------------------------------------------------------------- /fireflies/graphics/rasterization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | 5 | 6 | # We assume points to be in camera space [0, 1] 7 | def rasterize_points( 8 | points: torch.tensor, 9 | sigma: float, 10 | texture_size: torch.tensor, 11 | device: torch.cuda.device = torch.device("cuda"), 12 | ) -> torch.tensor: 13 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 14 | tex = tex[None, ...] 
15 | tex = tex.repeat((points.shape[0], 1, 1, 1)) 16 | 17 | # Somewhere between [0, texture_size] but in float 18 | points = points.clone() * texture_size 19 | 20 | # Generate x, y indices 21 | x, y = torch.meshgrid( 22 | torch.arange(0, texture_size[1], device=device), 23 | torch.arange(0, texture_size[0], device=device), 24 | indexing="ij", 25 | ) 26 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1)) 27 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1)) 28 | 29 | y_dist = y - points[:, 0:1].unsqueeze(-1) 30 | x_dist = x - points[:, 1:2].unsqueeze(-1) 31 | 32 | point_distances = ( 33 | y_dist * y_dist + x_dist * x_dist 34 | ) # / (texture_size * texture_size).sum().sqrt() 35 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2)) 36 | 37 | return point_distances 38 | 39 | 40 | def rasterize_points_in_non_ndc( 41 | points: torch.tensor, 42 | sigma: float, 43 | texture_size: torch.tensor, 44 | device: torch.cuda.device = torch.device("cuda"), 45 | ) -> torch.tensor: 46 | x, y = torch.meshgrid( 47 | torch.arange(0, texture_size[1], device=device), 48 | torch.arange(0, texture_size[0], device=device), 49 | indexing="ij", 50 | ) 51 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1)) 52 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1)) 53 | 54 | y_dist = y - points[:, 0:1].unsqueeze(-1) 55 | x_dist = x - points[:, 1:2].unsqueeze(-1) 56 | 57 | point_distances = ( 58 | y_dist * y_dist + x_dist * x_dist 59 | ) # / (texture_size * texture_size).sum().sqrt() 60 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2)) 61 | 62 | return point_distances 63 | 64 | 65 | # We assume points to be in NDC [-1, 1] 66 | def rasterize_depth( 67 | points: torch.tensor, 68 | depth_vals: torch.tensor, 69 | sigma: float, 70 | texture_size: torch.tensor, 71 | device: torch.cuda.device = torch.device("cuda"), 72 | ) -> torch.tensor: 73 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 74 | tex = tex[None, ...] 
75 | tex = tex.repeat((points.shape[0], 1, 1, 1)) 76 | 77 | # Somewhere between [0, texture_size] but in float 78 | points = points.clone() * texture_size 79 | 80 | # Generate x, y indices 81 | x, y = torch.meshgrid( 82 | torch.arange(0, texture_size[1], device=device), 83 | torch.arange(0, texture_size[0], device=device), 84 | indexing="ij", 85 | ) 86 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1)) 87 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1)) 88 | 89 | y_dist = y - points[:, 0:1].unsqueeze(-1) 90 | x_dist = x - points[:, 1:2].unsqueeze(-1) 91 | 92 | point_distances = ( 93 | y_dist * y_dist + x_dist * x_dist 94 | ) # / (texture_size * texture_size).sum().sqrt() 95 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2)) 96 | 97 | # normalize 98 | point_distances = ( 99 | point_distances 100 | / point_distances.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] 101 | ) 102 | 103 | # scale by depth in range [0, 1] 104 | return point_distances * depth_vals.unsqueeze(-1) 105 | 106 | 107 | def rasterize_lines( 108 | lines: torch.tensor, 109 | sigma: float, 110 | texture_size: torch.tensor, 111 | device: torch.cuda.device = torch.device("cuda"), 112 | ) -> torch.tensor: 113 | # lines are in NDC [-1, 1] 114 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 115 | tex = tex[None, ...] 
116 | tex = tex.repeat((lines.shape[0], 1, 1, 1)) 117 | 118 | lines_start = lines[:, 0, :] 119 | lines_end = lines[:, 1, :] 120 | 121 | # Somewhere between [0, texture_size] but in float 122 | lines_start *= texture_size 123 | lines_end *= texture_size 124 | 125 | lines_start = lines_start.permute(1, 0).unsqueeze(-1).unsqueeze(-1) 126 | lines_end = lines_end.permute(1, 0).unsqueeze(-1).unsqueeze(-1) 127 | 128 | y, x = torch.meshgrid( 129 | torch.arange(0, texture_size[1], device=device), 130 | torch.arange(0, texture_size[0], device=device), 131 | indexing="ij", 132 | ) 133 | y = y.unsqueeze(0).repeat((lines.shape[0], 1, 1)) 134 | x = x.unsqueeze(0).repeat((lines.shape[0], 1, 1)) 135 | xy = torch.stack([x, y]) 136 | 137 | # See: https://github.com/jonhare/DifferentiableSketching/blob/main/dsketch/raster/disttrans.py 138 | # If you found this, you should definitely give them a star. That's beautiful code they wrote there. 139 | 140 | pa = xy - lines_start 141 | pb = xy - lines_end 142 | m = lines_end - lines_start 143 | 144 | t0 = (pa * m).sum(dim=0) / ((m * m).sum(dim=0) + torch.finfo().eps) 145 | patm = xy - (lines_start + t0.unsqueeze(0) * m) 146 | 147 | distance_smaller_zero = (t0 <= 0) * (pa * pa).sum(dim=0) 148 | distance_inbetween = (t0 > 0) * (t0 < 1) * (patm * patm).sum(dim=0) 149 | distance_greater_one = (t0 >= 1) * (pb * pb).sum(dim=0) 150 | 151 | distances = distance_smaller_zero + distance_inbetween + distance_greater_one 152 | # distances = distances.sqrt() 153 | return torch.exp(-(distances * distances) / (sigma * sigma)) 154 | 155 | 156 | def softor(texture: torch.tensor, dim=0, keepdim: bool = False) -> torch.tensor: 157 | return 1 - torch.prod(1 - texture, dim=dim, keepdim=keepdim) 158 | 159 | 160 | def sum(texture: torch.tensor, dim=0, keepdim: bool = False) -> torch.tensor: 161 | return torch.sum(texture, dim=dim, keepdim=keepdim) 162 | 163 | 164 | def baked_sum( 165 | points: torch.tensor, 166 | sigma: torch.tensor, 167 | texture_size: 
torch.tensor, 168 | num_std: int = 4, 169 | device: torch.cuda.device = torch.device("cuda"), 170 | ) -> torch.tensor: 171 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 172 | 173 | # Somewhere between [0, texture_size] but in float 174 | points = points.clone() * texture_size 175 | 176 | for i in range(points.shape[0]): 177 | point = points[i] 178 | 179 | # We use 3*sigma^2 here, to include most of the gaussian 180 | footprint = math.floor(sigma.sqrt().item()) * num_std 181 | footprint = footprint + 1 if footprint % 2 == 0 else footprint 182 | half_footprint = int((footprint - 1) / 2) 183 | 184 | point_in_middle_of_footprint = point - point.floor() + half_footprint 185 | 186 | footprint_origin_in_original_image = (point - half_footprint).floor() # [Y, X] 187 | 188 | y, x = torch.meshgrid( 189 | torch.arange(0, footprint, device=device), 190 | torch.arange(0, footprint, device=device), 191 | indexing="ij", 192 | ) 193 | 194 | y_dist = y - point_in_middle_of_footprint[0:1] 195 | x_dist = x - point_in_middle_of_footprint[1:2] 196 | dist = y_dist * y_dist + x_dist * x_dist 197 | dist = torch.exp(-torch.pow(dist / sigma, 2)) 198 | 199 | wo = (point.floor() - half_footprint).int() 200 | re = [footprint, footprint] 201 | rs = [0, 0] 202 | 203 | # tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]] = tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]].clone() + dist[rs[0]:re[0], rs[1]:re[1]] 204 | 205 | # This is probably the worst code I've ever written. 206 | # It is an out of bounds check for rectangles such that we can copy the correct parts from our tensors. 207 | # There's certainly better and more precise ways to solve this, but for know it works. 
208 | 209 | rect_start = torch.tensor([0, 0], dtype=torch.int32, device=device) 210 | rect_end = torch.tensor( 211 | [footprint, footprint], dtype=torch.int32, device=device 212 | ) 213 | 214 | if footprint_origin_in_original_image[0] < 0: 215 | rect_start[0] = footprint_origin_in_original_image.abs()[0] 216 | footprint_origin_in_original_image[0] = 0 217 | 218 | if footprint_origin_in_original_image[1] < 0: 219 | rect_start[1] = footprint_origin_in_original_image.abs()[1] 220 | footprint_origin_in_original_image[1] = 0 221 | 222 | if footprint_origin_in_original_image[0] + footprint >= texture_size[0]: 223 | rect_end[0] = texture_size[0] - footprint_origin_in_original_image[0] 224 | 225 | if footprint_origin_in_original_image[1] + footprint >= texture_size[1]: 226 | rect_end[1] = texture_size[1] - footprint_origin_in_original_image[1] 227 | 228 | wo = footprint_origin_in_original_image.int() 229 | rs = rect_start.int() 230 | re = rect_end.int() 231 | 232 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]] = ( 233 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]].clone() 234 | + dist[rs[0] : re[0], rs[1] : re[1]] 235 | ) 236 | 237 | return tex.T 238 | 239 | 240 | def baked_sum_2( 241 | points: torch.tensor, 242 | sigma: torch.tensor, 243 | texture_size: torch.tensor, 244 | num_std: int = 4, 245 | device: torch.cuda.device = torch.device("cuda"), 246 | ) -> torch.tensor: 247 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 248 | # Somewhere between [0, texture_size] but in float 249 | points = points.clone() * texture_size 250 | 251 | # We use 3*sigma^2 here, to include most of the gaussian 252 | footprint = math.floor(sigma.sqrt().item()) * num_std 253 | footprint = footprint + 1 if footprint % 2 == 0 else footprint 254 | half_footprint = int((footprint - 1) / 2) 255 | 256 | point_in_middle_of_footprint = points - points.floor() + half_footprint 257 | 258 | footprint_origin_in_original_image = (points 
- half_footprint).floor() # [Y, X] 259 | 260 | y, x = torch.meshgrid( 261 | torch.arange(0, footprint, device=device), 262 | torch.arange(0, footprint, device=device), 263 | indexing="ij", 264 | ) 265 | 266 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1)) 267 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1)) 268 | 269 | y_dist = y - point_in_middle_of_footprint[:, 0:1].unsqueeze(-1) 270 | x_dist = x - point_in_middle_of_footprint[:, 1:2].unsqueeze(-1) 271 | 272 | dist = y_dist * y_dist + x_dist * x_dist 273 | dist = torch.exp(-torch.pow(dist / sigma, 2)) 274 | 275 | wo = footprint_origin_in_original_image.int() 276 | rect_start = torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) 277 | rect_end = ( 278 | torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) + footprint 279 | ) 280 | 281 | if (wo[:, 0] < 0).any(): 282 | rect_start[:, 0] = torch.where(wo[:, 0] < 0, wo.abs()[:, 0], rect_start[:, 0]) 283 | wo[:, 0] = torch.where(wo[:, 0] < 0, 0, wo[:, 0]) 284 | 285 | if (wo[:, 1] < 0).any(): 286 | rect_start[:, 1] = torch.where(wo[:, 1] < 1, wo.abs()[:, 1], rect_start[:, 1]) 287 | wo[:, 1] = torch.where(wo[:, 1] < 0, 0, wo[:, 1]) 288 | 289 | if (wo[:, 0] + footprint >= texture_size[0]).any(): 290 | rect_end[:, 0] = torch.where( 291 | wo[:, 0] + footprint >= texture_size[0], 292 | texture_size[0] - wo[:, 0], 293 | rect_end[:, 0], 294 | ) 295 | 296 | if (wo[:, 1] + footprint >= texture_size[1]).any(): 297 | rect_end[:, 1] = torch.where( 298 | wo[:, 1] + footprint >= texture_size[1], 299 | texture_size[1] - wo[:, 1], 300 | rect_end[:, 1], 301 | ) 302 | 303 | re = rect_end.clone() 304 | rs = rect_start.clone() 305 | 306 | for i in range(points.shape[0]): 307 | tex[ 308 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0], 309 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1], 310 | ] = ( 311 | tex[ 312 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0], 313 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1], 314 | ].clone() 315 | + dist[i, rs[i, 0] : re[i, 0], rs[i, 
1] : re[i, 1]] 316 | ) 317 | 318 | return tex 319 | 320 | 321 | def baked_softor( 322 | points: torch.tensor, 323 | sigma: torch.tensor, 324 | texture_size: torch.tensor, 325 | num_std: int = 5, 326 | device: torch.cuda.device = torch.device("cuda"), 327 | ) -> torch.tensor: 328 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device) 329 | # Somewhere between [0, texture_size] but in float 330 | points = points.clone() * texture_size 331 | 332 | for i in range(points.shape[0]): 333 | point = points[i] 334 | 335 | # We use 3*sigma^2 here, to include most of the gaussian 336 | footprint = math.floor(sigma.sqrt().item()) * num_std 337 | footprint = footprint + 1 if footprint % 2 == 0 else footprint 338 | half_footprint = int((footprint - 1) / 2) 339 | 340 | point_in_middle_of_footprint = point - point.floor() + half_footprint 341 | 342 | footprint_origin_in_original_image = (point - half_footprint).floor() # [Y, X] 343 | 344 | y, x = torch.meshgrid( 345 | torch.arange(0, footprint, device=device), 346 | torch.arange(0, footprint, device=device), 347 | indexing="ij", 348 | ) 349 | 350 | y_dist = y - point_in_middle_of_footprint[0:1] 351 | x_dist = x - point_in_middle_of_footprint[1:2] 352 | dist = y_dist * y_dist + x_dist * x_dist 353 | dist = torch.exp(-torch.pow(dist / sigma, 2)) 354 | 355 | wo = (point.floor() - half_footprint).int() 356 | re = [footprint, footprint] 357 | rs = [0, 0] 358 | 359 | # tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]] = tex[wo[0]:wo[0]+re[0]-rs[0], wo[1]:wo[1]+re[1]-rs[1]].clone() + dist[rs[0]:re[0], rs[1]:re[1]] 360 | 361 | # This is probably the worst code I've ever written. 362 | # It is an out of bounds check for rectangles such that we can copy the correct parts from our tensors. 363 | # There's certainly better and more precise ways to solve this, but for know it works. 
364 | 365 | rect_start = torch.tensor([0, 0], dtype=torch.int32, device=device) 366 | rect_end = torch.tensor( 367 | [footprint, footprint], dtype=torch.int32, device=device 368 | ) 369 | 370 | if footprint_origin_in_original_image[0] < 0: 371 | rect_start[0] = footprint_origin_in_original_image.abs()[0] 372 | footprint_origin_in_original_image[0] = 0 373 | 374 | if footprint_origin_in_original_image[1] < 0: 375 | rect_start[1] = footprint_origin_in_original_image.abs()[1] 376 | footprint_origin_in_original_image[1] = 0 377 | 378 | if footprint_origin_in_original_image[0] + footprint >= texture_size[0]: 379 | rect_end[0] = texture_size[0] - footprint_origin_in_original_image[0] 380 | 381 | if footprint_origin_in_original_image[1] + footprint >= texture_size[1]: 382 | rect_end[1] = texture_size[1] - footprint_origin_in_original_image[1] 383 | 384 | wo = footprint_origin_in_original_image.int() 385 | rs = rect_start.int() 386 | re = rect_end.int() 387 | 388 | tex[wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1]] = tex[ 389 | wo[0] : wo[0] + re[0] - rs[0], wo[1] : wo[1] + re[1] - rs[1] 390 | ].clone() * (1 - dist[rs[0] : re[0], rs[1] : re[1]]) 391 | 392 | return (1 - tex).T 393 | 394 | 395 | def baked_softor_2( 396 | points: torch.tensor, 397 | sigma: torch.tensor, 398 | texture_size: torch.tensor, 399 | num_std: int = 5, 400 | device: torch.cuda.device = torch.device("cuda"), 401 | ) -> torch.tensor: 402 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device) 403 | # Somewhere between [0, texture_size] but in float 404 | points = points.clone() * texture_size 405 | 406 | # We use 3*sigma^2 here, to include most of the gaussian 407 | footprint = math.floor(sigma.sqrt().item()) * num_std 408 | footprint = footprint + 1 if footprint % 2 == 0 else footprint 409 | half_footprint = int((footprint - 1) / 2) 410 | 411 | point_in_middle_of_footprint = points - points.floor() + half_footprint 412 | 413 | footprint_origin_in_original_image = 
(points - half_footprint).floor() # [Y, X] 414 | 415 | y, x = torch.meshgrid( 416 | torch.arange(0, footprint, device=device), 417 | torch.arange(0, footprint, device=device), 418 | indexing="ij", 419 | ) 420 | 421 | y = y.unsqueeze(0).repeat((points.shape[0], 1, 1)) 422 | x = x.unsqueeze(0).repeat((points.shape[0], 1, 1)) 423 | 424 | y_dist = y - point_in_middle_of_footprint[:, 0:1].unsqueeze(-1) 425 | x_dist = x - point_in_middle_of_footprint[:, 1:2].unsqueeze(-1) 426 | 427 | dist = y_dist * y_dist + x_dist * x_dist 428 | dist = torch.exp(-torch.pow(dist / sigma, 2)) 429 | 430 | wo = footprint_origin_in_original_image.int() 431 | rect_start = torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) 432 | rect_end = ( 433 | torch.zeros(points.shape[0], 2, dtype=torch.int32, device=device) + footprint 434 | ) 435 | 436 | if (wo[:, 0] < 0).any(): 437 | rect_start[:, 0] = torch.where(wo[:, 0] < 0, wo.abs()[:, 0], rect_start[:, 0]) 438 | wo[:, 0] = torch.where(wo[:, 0] < 0, 0, wo[:, 0]) 439 | 440 | if (wo[:, 1] < 0).any(): 441 | rect_start[:, 1] = torch.where(wo[:, 1] < 1, wo.abs()[:, 1], rect_start[:, 1]) 442 | wo[:, 1] = torch.where(wo[:, 1] < 0, 0, wo[:, 1]) 443 | 444 | if (wo[:, 0] + footprint >= texture_size[0]).any(): 445 | rect_end[:, 0] = torch.where( 446 | wo[:, 0] + footprint >= texture_size[0], 447 | texture_size[0] - wo[:, 0], 448 | rect_end[:, 0], 449 | ) 450 | 451 | if (wo[:, 1] + footprint >= texture_size[1]).any(): 452 | rect_end[:, 1] = torch.where( 453 | wo[:, 1] + footprint >= texture_size[1], 454 | texture_size[1] - wo[:, 1], 455 | rect_end[:, 1], 456 | ) 457 | 458 | re = rect_end.clone() 459 | rs = rect_start.clone() 460 | 461 | for i in range(points.shape[0]): 462 | tex[ 463 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0], 464 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1], 465 | ] = tex[ 466 | wo[i, 0] : wo[i, 0] + re[i, 0] - rs[i, 0], 467 | wo[i, 1] : wo[i, 1] + re[i, 1] - rs[i, 1], 468 | ].clone() * ( 469 | 1 - dist[i, rs[i, 0] : re[i, 0], 
rs[i, 1] : re[i, 1]] 470 | ) 471 | 472 | return (1 - tex).T 473 | 474 | 475 | def rasterize_points_baked_softor( 476 | points: torch.tensor, 477 | sigma: float, 478 | texture_size: torch.tensor, 479 | device: torch.cuda.device = torch.device("cuda"), 480 | ) -> torch.tensor: 481 | tex = torch.ones(texture_size.tolist(), dtype=torch.float32, device=device) 482 | 483 | # Somewhere between [0, texture_size] but in float 484 | points = points.clone() * texture_size 485 | 486 | for i in range(points.shape[0]): 487 | # Generate x, y indices 488 | x, y = torch.meshgrid( 489 | torch.arange(0, texture_size[1], device=device), 490 | torch.arange(0, texture_size[0], device=device), 491 | indexing="ij", 492 | ) 493 | 494 | y_dist = y - points[i, 0:1].unsqueeze(-1) 495 | x_dist = x - points[i, 1:2].unsqueeze(-1) 496 | 497 | point_distances = ( 498 | y_dist * y_dist + x_dist * x_dist 499 | ) # / (texture_size * texture_size).sum().sqrt() 500 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2)) 501 | tex = tex.clone() * (1 - point_distances) 502 | 503 | return 1 - tex 504 | 505 | 506 | # We assume points to be in camera space [0, 1] 507 | def rasterize_points_baked_sum( 508 | points: torch.tensor, 509 | sigma: float, 510 | texture_size: torch.tensor, 511 | device: torch.cuda.device = torch.device("cuda"), 512 | ) -> torch.tensor: 513 | tex = torch.zeros(texture_size.tolist(), dtype=torch.float32, device=device) 514 | 515 | # Somewhere between [0, texture_size] but in float 516 | points = points.clone() * texture_size 517 | 518 | for i in range(points.shape[0]): 519 | # Generate x, y indices 520 | x, y = torch.meshgrid( 521 | torch.arange(0, texture_size[1], device=device), 522 | torch.arange(0, texture_size[0], device=device), 523 | indexing="ij", 524 | ) 525 | 526 | y_dist = y - points[i, 0:1].unsqueeze(-1) 527 | x_dist = x - points[i, 1:2].unsqueeze(-1) 528 | 529 | point_distances = ( 530 | y_dist * y_dist + x_dist * x_dist 531 | ) # / (texture_size * 
texture_size).sum().sqrt() 532 | point_distances = torch.exp(-torch.pow(point_distances / sigma, 2)) 533 | tex += point_distances 534 | 535 | return tex 536 | 537 | 538 | def subsampled_point_raster(ndc_points, num_subsamples, sigma, sensor_size): 539 | subsampled_rastered_depth = [] 540 | for i in range(num_subsamples): 541 | rastered_depth = rasterize_depth( 542 | ndc_points[:, 0:2], ndc_points[:, 2:3], sigma, sensor_size // 2**i 543 | ) 544 | rastered_depth = softor(rastered_depth, keepdim=True) 545 | # rastered_depth = (rastered_depth - rastered_depth.min()) / ( 546 | # rastered_depth.max() - rastered_depth.min() 547 | # ) 548 | subsampled_rastered_depth.append(rastered_depth) 549 | return subsampled_rastered_depth 550 | 551 | 552 | def get_mpl_colormap(cmap): 553 | import matplotlib.pyplot as plt 554 | 555 | # Initialize the matplotlib color map 556 | sm = plt.cm.ScalarMappable(cmap=cmap) 557 | 558 | # Obtain linear color range 559 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:, 2::-1] 560 | 561 | return color_range.reshape(256, 1, 3) 562 | 563 | 564 | def test_point_reg(reduce_overlap: bool = True): 565 | import cv2 566 | import numpy as np 567 | from tqdm import tqdm 568 | import imageio 569 | import matplotlib.colors 570 | import timeit 571 | import matplotlib.pyplot as plt 572 | 573 | device = "cuda" if torch.cuda.is_available() else "cpu" 574 | 575 | points = torch.rand([500, 2], device=device) 576 | points.requires_grad = True 577 | sigma = torch.tensor([15.0], device=device) ** 2 578 | texture_size = torch.tensor([512, 512], device=device) 579 | loss_func = torch.nn.L1Loss() 580 | 581 | opt_steps = 200 582 | 583 | optim = torch.optim.Adam([{"params": points, "lr": 0.001}]) 584 | 585 | images = [] 586 | for i in tqdm(range(opt_steps)): 587 | optim.zero_grad() 588 | 589 | summed = baked_sum_2(points, sigma, texture_size) 590 | softored = baked_softor_2(points, sigma, texture_size) 591 | 592 | # rasterized_points = 
rasterize_points(points, sigma, texture_size) 593 | # softored = softor(rasterized_points) 594 | # summed = rasterized_points.sum(dim=0) 595 | 596 | loss = ( 597 | loss_func(softored, summed) 598 | if reduce_overlap 599 | else -loss_func(softored, summed) 600 | ) 601 | print(loss.item()) 602 | loss.backward() 603 | optim.step() 604 | 605 | with torch.no_grad(): 606 | points[points >= 1.0] = 0.999 607 | points[points <= 0.0] = 0.001 608 | 609 | # Apply custom colormap 610 | colors = [(0.0, 0.1921, 0.4156), (0, 0.69, 0.314)] # R -> G -> B 611 | 612 | # fig = plt.figure(frameon=False) 613 | # fig.set_size_inches(10, 10) 614 | # ax = plt.Axes(fig, [0., 0., 1., 1.]) 615 | # ax.set_axis_off() 616 | # ax.set_aspect(aspect='equal') 617 | # ax.set_facecolor(colors[0]) 618 | # fig.add_axes(ax) 619 | 620 | # ax.scatter(points.detach().cpu().numpy()[:, 0], points.detach().cpu().numpy()[:, 1], s=60.0*10, color=colors[0]) 621 | # fig.canvas.draw() 622 | # img_plot = np.array(fig.canvas.renderer.buffer_rgba()) 623 | # np_points = img_plot 624 | # images.append(np_points) 625 | bla = cv2.applyColorMap( 626 | (softored.detach().cpu().numpy() * 255).astype(np.uint8), 627 | cv2.COLORMAP_VIRIDIS, 628 | ) 629 | cv2.imshow("Show", bla) 630 | cv2.waitKey(1) 631 | # if i == 0 or i == opt_steps - 1: 632 | # fig.savefig("assets/point_reduced_overlap{0}.eps".format(i) if reduce_overlap else "assets/point_increased_overlap{0}.eps".format(i), 633 | # facecolor=ax.get_facecolor(), 634 | # edgecolor='none', 635 | # bbox_inches = 'tight', 636 | # pad_inches=0) 637 | plt.close() 638 | 639 | # cv2.imshow("Optim Lines", np_points) 640 | # cv2.waitKey(1) 641 | # lines.requires_grad = True 642 | # imageio.v3.imwrite("assets/point_regularization.mp4", np.stack(images, axis=0), fps=25) 643 | 644 | 645 | def test_line_reg(): 646 | import cv2 647 | import numpy as np 648 | from tqdm import tqdm 649 | import imageio 650 | import matplotlib.colors 651 | import matplotlib.pyplot as plt 652 | 653 | device = 
"cuda" if torch.cuda.is_available() else "cpu" 654 | 655 | # We define line as: 656 | # P0: x + -t*d 657 | # P1: x + t*d 658 | # Where d is the direction, x the location vector and t half its length. 659 | # We want to optimize the location vector x of all lines, such that they do not overlap. 660 | 661 | num_lines = 50 662 | sigma = 10.0 663 | opt_steps = 250 664 | 665 | t = torch.tensor([0.5], device=device) 666 | direction = torch.rand([2], device=device).unsqueeze(0) 667 | direction = direction / direction.norm() 668 | 669 | location_vector = ( 670 | (torch.rand([num_lines, 2], device=device) - 0.5) * 2.0 / 10.0 671 | ) # Every line should be roughly in the middle of our frame 672 | location_vector.requires_grad = True 673 | 674 | sigma = torch.tensor([sigma], device=device) 675 | texture_size = torch.tensor([512, 512], device=device) 676 | loss_func = torch.nn.L1Loss() 677 | 678 | optim = torch.optim.Adam([{"params": location_vector, "lr": 0.005}]) 679 | 680 | images = [] 681 | for i in tqdm(range(opt_steps)): 682 | optim.zero_grad() 683 | 684 | p0 = location_vector + t * direction 685 | p1 = location_vector - t * direction 686 | lines = torch.concat([p0.unsqueeze(-1), p1.unsqueeze(-1)], dim=-1).transpose( 687 | 1, 2 688 | ) 689 | 690 | rasterized_lines = rasterize_lines(lines, sigma, texture_size) 691 | 692 | softored = softor(rasterized_lines) 693 | summed = rasterized_lines.sum(dim=0) 694 | 695 | loss = loss_func(softored, summed) 696 | loss.backward() 697 | optim.step() 698 | 699 | with torch.no_grad(): 700 | location_vector[p0 > 1.0] -= 0.01 701 | location_vector[p0 < -1.0] += 0.01 702 | location_vector[p1 > 1.0] -= 0.01 703 | location_vector[p1 < -1.0] += 0.01 704 | 705 | colors = [(0.0, 0.1921, 0.4156), (0, 0.69, 0.314)] # R -> G -> B 706 | fig = plt.figure(frameon=False) 707 | fig.set_size_inches(10, 10) 708 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) 709 | ax.set_xlim([-1, 1]) 710 | ax.set_ylim([-1, 1]) 711 | ax.set_axis_off() 712 | 
ax.set_aspect(aspect="equal") 713 | fig.add_axes(ax) 714 | 715 | lines_copy = lines.transpose(1, 2).detach().cpu().numpy() 716 | for j in range(lines_copy.shape[0]): 717 | ax.plot( 718 | lines_copy[j, 0, :], 719 | lines_copy[j, 1, :], 720 | c=colors[0], 721 | linewidth=9.5, 722 | solid_capstyle="round", 723 | ) # c=colors[0], linewidth=60) 724 | 725 | fig.canvas.draw() 726 | img_plot = np.array(fig.canvas.renderer.buffer_rgba()) 727 | np_points = img_plot 728 | images.append(np_points) 729 | 730 | if i == 0 or i == opt_steps - 1: 731 | fig.savefig( 732 | "assets/line_reduced_overlap{0}.eps".format(i), 733 | facecolor=ax.get_facecolor(), 734 | edgecolor="none", 735 | bbox_inches="tight", 736 | pad_inches=0, 737 | ) 738 | plt.close() 739 | 740 | imageio.v3.imwrite( 741 | "assets/line_regularization.mp4", np.stack(images, axis=0), fps=25 742 | ) 743 | # optimize("line_regularization.gif") 744 | 745 | 746 | def main(): 747 | import matplotlib.pyplot as plt 748 | 749 | device = "cuda" if torch.cuda.is_available() else "cpu" 750 | 751 | points_a = torch.rand(10, 2, device=device) # (Y, X) 752 | 753 | # points_a = torch.tensor([[0.565, 0.5555]], device=device) 754 | # points_b = torch.tensor([[0.51,0.51]], device=device) 755 | 756 | texture_size = torch.tensor([100, 100], device=device) # (Y, X) 757 | 758 | sigma = torch.tensor([10.0], device=device) ** 2 759 | sum = baked_sum(points_a, sigma, texture_size, device=device) 760 | # sum_og = baked_sum(points_b, sigma, texture_size, device=device) 761 | sum_og = rasterize_points(points_a, sigma, texture_size, device=device).sum(dim=0) 762 | # sum_og = rasterize_points(points_b, sigma, texture_size, device=device).sum(dim=0) 763 | 764 | fig, (ax1, ax2, ax3) = plt.subplots(1, 3) 765 | ax1.imshow(sum_og.detach().cpu().numpy()) 766 | ax2.imshow(sum.detach().cpu().numpy()) 767 | ax3.imshow((sum_og - sum).detach().cpu().numpy()) 768 | ax1.set_title("OG") 769 | ax2.set_title("NEW") 770 | ax3.set_title("DIFF") 771 | fig.show() 772 | 
plt.show() 773 | 774 | 775 | def time_it(): 776 | import timeit, functools 777 | 778 | device = "cuda" if torch.cuda.is_available() else "cpu" 779 | 780 | points = torch.rand([500, 2], device=device) 781 | points.requires_grad = True 782 | sigma = torch.tensor([10.0], device=device) ** 2 783 | texture_size = torch.tensor([512, 512], device=device) 784 | repeats = 50 785 | 786 | og_sum = timeit.Timer( 787 | lambda: rasterize_points(points, sigma, texture_size, device).sum(dim=0) 788 | ) 789 | print(og_sum.timeit(repeats)) 790 | 791 | t_baked_sum = timeit.Timer( 792 | lambda: baked_sum(points, sigma, texture_size, 4, device) 793 | ) 794 | print(t_baked_sum.timeit(repeats)) 795 | 796 | t_baked_sum2 = timeit.Timer( 797 | lambda: baked_sum_2(points, sigma, texture_size, 4, device) 798 | ) 799 | print(t_baked_sum2.timeit(repeats)) 800 | 801 | og_softor = timeit.Timer( 802 | lambda: softor(rasterize_points(points, sigma, texture_size, device)) 803 | ) 804 | print(og_softor.timeit(repeats)) 805 | 806 | t_baked_softor1 = timeit.Timer( 807 | lambda: baked_softor(points, sigma, texture_size, 4, device) 808 | ) 809 | print(t_baked_softor1.timeit(repeats)) 810 | 811 | t_baked_softor2 = timeit.Timer( 812 | lambda: baked_softor_2(points, sigma, texture_size, 4, device) 813 | ) 814 | print(t_baked_softor2.timeit(repeats)) 815 | 816 | 817 | if __name__ == "__main__": 818 | # time_it() 819 | test_point_reg(reduce_overlap=True) 820 | # test_point_reg(reduce_overlap=False) 821 | # test_line_reg() 822 | # main() 823 | --------------------------------------------------------------------------------