├── README.md
├── example_data
│   ├── scene_lens_doubleslit.png
│   └── scene_optical_fibers.png
├── images
│   ├── optical_cavity.jpg
│   ├── simulation_1.jpg
│   ├── simulation_2.jpg
│   ├── simulation_3.jpg
│   ├── simulation_4.jpg
│   └── source_antialiasing.png
├── requirements.txt
└── wave_sim2d
    ├── __init__.py
    ├── develop_tests.py
    ├── examples
    │   ├── example0.py
    │   ├── example1.py
    │   ├── example2.py
    │   ├── example3.py
    │   └── example4.py
    ├── main.py
    ├── scene_objects
    │   ├── __pycache__
    │   │   └── border_absorber.cpython-310.pyc
    │   ├── source.py
    │   ├── static_dampening.py
    │   ├── static_image_scene.py
    │   ├── static_refractive_index.py
    │   └── strain_refractive_index.py
    ├── wave_simulation.py
    └── wave_visualizer.py


/README.md:
--------------------------------------------------------------------------------
# 2D Wave Simulation on the GPU

This repository contains a lightweight 2D wave simulator that runs on the GPU using the CuPy library (which most likely requires an NVIDIA GPU). It can be used for 2D light and sound simulations.
A simple visualizer shows the field and its intensity on the screen and can write a movie file for each of them to disk. The goal is to provide a fast and easy-to-use, yet still flexible, wave simulator.

[Image: Example Image 1] [Image: Example Image 2]
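The solver itself lives in wave_sim2d/wave_simulation.py, which is not reproduced in this listing. As a rough mental model only, the following sketch shows a generic explicit (leapfrog) finite-difference step for the 2D wave equation in plain NumPy. It uses a per-cell wave speed, derived here from the refractive index, and a per-cell dampening factor, which are the two kinds of fields the scene objects in this repository write to. The helper names (`laplacian`, `step`), the assumed base speed of 0.5 cells per step and the fixed zero border are illustrative assumptions, not the repository's implementation. Since CuPy implements most of the NumPy array API, code like this can usually be moved to the GPU by swapping the import.

```python
import numpy as np


def laplacian(u):
    """5-point Laplacian on a unit grid; border cells stay at zero (fixed boundary)."""
    lap = np.zeros_like(u)
    lap[1:-1, 1:-1] = (u[:-2, 1:-1] + u[2:, 1:-1] + u[1:-1, :-2] + u[1:-1, 2:]
                       - 4.0 * u[1:-1, 1:-1])
    return lap


def step(u, u_prev, wave_speed, dampening):
    """One explicit leapfrog step of u_tt = c^2 * laplacian(u) with dt = dx = 1 (hypothetical helper).

    wave_speed: per-cell speed c; stability in 2D requires c <= 1 / sqrt(2).
    dampening: per-cell factor in [0, 1]; 1.0 means no dampening.
    Returns (u_next, u) so the caller can rotate the two time levels.
    """
    u_next = 2.0 * u - u_prev + wave_speed ** 2 * laplacian(u)
    u_next *= dampening
    return u_next, u


# usage: a 65x65 domain with refractive index 1.5 and a sinusoidal point source in the center
size = 65
refractive_index = np.full((size, size), 1.5)
wave_speed = 0.5 / refractive_index  # assumed base speed of 0.5 cells per step for n = 1
dampening = np.ones((size, size))
u = np.zeros((size, size))
u_prev = np.zeros((size, size))
for t in range(50):
    u[32, 32] = np.sin(0.2 * t)  # the source overwrites its cell, similar to PointSource.update_field
    u, u_prev = step(u, u_prev, wave_speed, dampening)
```

The stability limit in the docstring is the usual CFL condition for this stencil in 2D. Ramping the dampening factor down towards 0 near the edges, as StaticDampening does, reduces reflections at the fixed border.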

### Update 06.04.2025

* Scene objects can now draw to a visualization layer (most of them do not do so yet; feel free to contribute)!
* Example 4 now shows a two-mirror optical cavity and how standing waves emerge in it.
* Added new line sources (LineSource).
* Added a refractive index polygon object (StaticRefractiveIndexPolygon).
* Added a refractive index box object (StaticRefractiveIndexBox).
* Fixed some issues with the examples.

[Image: Example 4 - Optical Cavity with Standing Waves]
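For idealized mirrors with identical reflection phase, a standing wave fits between the two mirrors when one round trip is a whole number of wavelengths: 2 n L = m λ, with mirror spacing L, refractive index n between the mirrors and integer mode order m. The helper below (a hypothetical name, not part of wave_sim2d) lists the first few resonant wavelengths. How a wavelength maps to the `frequency` argument of the sources depends on the wave speed used in wave_simulation.py, which is not shown here.

```python
def resonant_wavelengths(mirror_spacing, refractive_index=1.0, max_order=5):
    """Wavelengths satisfying 2 * n * L = m * wavelength for m = 1..max_order (idealized mirrors)."""
    if mirror_spacing <= 0 or refractive_index <= 0:
        raise ValueError("mirror_spacing and refractive_index must be positive")
    optical_round_trip = 2.0 * refractive_index * mirror_spacing
    return [optical_round_trip / m for m in range(1, max_order + 1)]


# usage: mirrors 100 cells apart in vacuum
print(resonant_wavelengths(100, 1.0, 4))  # [200.0, 100.0, 66.66666666666667, 50.0]
```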

### Update 01.04.2024

* Refactored the code to support a more flexible scene description. A simulation scene now consists of a list of objects that each add their contribution to the fields.
  They can be combined to build complex and time-dependent simulations. The refactoring also made the core simulation code even simpler.
* Added a few new custom colormaps that work well for wave simulations.
* Added new examples, which should make it easier to understand how to use the program and how to set up your own simulations: [examples](wave_sim2d/examples).

[Image: Example Image 3] [Image: Example Image 4]
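The following NumPy sketch illustrates the scene-object idea. The method names `render(field, wave_speed_field, dampening_field)` and `update_field(field, t)` match the scene objects in wave_sim2d/scene_objects, but the classes here are simplified stand-ins, `render_scene` is a hypothetical driver that does not reproduce WaveSimulator2D, and the mapping wave speed = 1 / n is an assumption. Because a dampening object like StaticDampening overwrites the whole dampening field, it goes first in the list so that later objects can draw on top of it.

```python
import math

import numpy as np


class SceneObject:
    """Simplified stand-in for wave_sim2d.wave_simulation.SceneObject."""

    def render(self, field, wave_speed_field, dampening_field):
        pass

    def update_field(self, field, t):
        pass


class BorderDampening(SceneObject):
    """Overwrites the dampening field with a border ramp (same ramp formula as StaticDampening)."""

    def __init__(self, border_thickness):
        self.border_thickness = border_thickness

    def render(self, field, wave_speed_field, dampening_field):
        h, w = dampening_field.shape
        dampening_field[:] = 1.0
        for i in range(self.border_thickness):
            v = (i / self.border_thickness) ** 0.5  # 0.0 at the outer edge, approaching 1.0 inside
            dampening_field[i, i:w - i] = v
            dampening_field[-(1 + i), i:w - i] = v
            dampening_field[i:h - i, i] = v
            dampening_field[i:h - i, -(1 + i)] = v


class ConstantIndex(SceneObject):
    """Sets a uniform refractive index; the speed = 1 / n mapping is an assumption."""

    def __init__(self, n):
        self.n = n

    def render(self, field, wave_speed_field, dampening_field):
        wave_speed_field[:] = 1.0 / self.n


class SinePointSource(SceneObject):
    """Writes sin(phase + frequency * t) * amplitude into one cell per step, like PointSource."""

    def __init__(self, x, y, frequency, amplitude=1.0, phase=0.0):
        self.x, self.y = x, y
        self.frequency, self.amplitude, self.phase = frequency, amplitude, phase

    def update_field(self, field, t):
        field[self.y, self.x] = math.sin(self.phase + self.frequency * t) * self.amplitude


def render_scene(scene_objects, shape):
    """Hypothetical driver: every object draws into freshly allocated fields, in list order."""
    field = np.zeros(shape)
    wave_speed_field = np.ones(shape)
    dampening_field = np.ones(shape)
    for obj in scene_objects:
        obj.render(field, wave_speed_field, dampening_field)
    return field, wave_speed_field, dampening_field


objects = [BorderDampening(4), ConstantIndex(1.5), SinePointSource(10, 8, 0.5)]
field, speed, damp = render_scene(objects, (16, 20))
for obj in objects:  # once per time step, time-dependent objects write into the field
    obj.update_field(field, math.pi)
```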

The old image-based scene description is still available as a scene object ('StaticImageScene'). You can continue to use the convenience of an image editing program to create simulations
without much programming.

### Image Scene Description Usage ###

When using the 'StaticImageScene' class, the simulation scene can be given as an 8-bit RGB image with the following channel semantics:
* Red: the refractive index times 100 (for a refractive index of 1.5, use the value 150).
* Green: each pixel with a green value above 0 is a sinusoidal wave source. The green value defines its frequency.
* Blue: absorption field. Larger values correspond to stronger dampening of the waves; use graduated transitions to avoid reflections.

WARNING: Do not use anti-aliasing for the green channel! The intermediate shades it produces are interpreted as different source frequencies, which yields unexpected results.

[Image: Example Image 5]
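To make the channel semantics concrete, the following NumPy sketch paints a small scene and decodes it with the same formulas StaticImageScene uses in wave_sim2d/scene_objects/static_image_scene.py: refractive index = red / 100, dampening factor = 1 - blue / 255, source frequency = green / 255 * 0.5 * source_fequency_scale. `decode_scene_image` is a hypothetical helper written for illustration; the real class additionally adds a 48-pixel absorbing border and moves the data to the GPU. If you save such an image with OpenCV, convert it to BGR first, since the examples load scene images with cv2.imread and then convert BGR to RGB.

```python
import numpy as np


def decode_scene_image(rgb, source_frequency_scale=1.0):
    """Decode an 8-bit RGB scene image using the channel semantics of StaticImageScene."""
    rgb = np.asarray(rgb)
    refractive_index = rgb[:, :, 0] / 100.0  # red: refractive index times 100
    dampening = 1.0 - rgb[:, :, 2] / 255.0   # blue: absorption, 1.0 means no dampening
    ys, xs = np.nonzero(rgb[:, :, 1] > 0)    # green > 0: source pixels
    frequencies = rgb[ys, xs, 1] / 255.0 * 0.5 * source_frequency_scale
    sources = [(int(x), int(y), float(f)) for x, y, f in zip(xs, ys, frequencies)]
    return refractive_index, dampening, sources


height, width = 100, 200
scene = np.zeros((height, width, 3), dtype=np.uint8)
scene[:, :, 0] = 100            # background: refractive index 1.0
scene[30:70, 120:160, 0] = 150  # glass block: refractive index 1.5
scene[50, 40, 1] = 128          # a single hard-edged source pixel (no anti-aliasing)
scene[:, 190:, 2] = 255         # strongly absorbing strip on the right

n, damp, sources = decode_scene_image(scene, source_frequency_scale=2.0)
```

An anti-aliased source edge would add extra green pixels with smaller values, and each of them would become a separate source with a lower frequency, which is why the warning above matters.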

### Recommended Installation ###

1. Install Python and the PyCharm IDE.
2. Clone the project to your hard disk.
3. Open the folder as a project in PyCharm.
4. If prompted to install the requirements, accept (or install them manually with `pip install -r requirements.txt`).
5. Right-click one of the examples in wave_sim2d/examples and select Run.

--------------------------------------------------------------------------------
/example_data/scene_lens_doubleslit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/example_data/scene_lens_doubleslit.png
--------------------------------------------------------------------------------
/example_data/scene_optical_fibers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/example_data/scene_optical_fibers.png
--------------------------------------------------------------------------------
/images/optical_cavity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/optical_cavity.jpg
--------------------------------------------------------------------------------
/images/simulation_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_1.jpg
--------------------------------------------------------------------------------
/images/simulation_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_2.jpg
--------------------------------------------------------------------------------
/images/simulation_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_3.jpg
--------------------------------------------------------------------------------
/images/simulation_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_4.jpg
--------------------------------------------------------------------------------
/images/source_antialiasing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/source_antialiasing.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy
opencv-python
matplotlib
cupy
--------------------------------------------------------------------------------
/wave_sim2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/wave_sim2d/__init__.py
--------------------------------------------------------------------------------
/wave_sim2d/develop_tests.py:
--------------------------------------------------------------------------------
import wave_visualizer as vis
import wave_simulation as sim
import numpy as np
import cv2
import math
import json
from scene_objects.static_dampening import StaticDampening
from scene_objects.static_refractive_index import StaticRefractiveIndex
from scene_objects.static_image_scene import StaticImageScene
from scene_objects.source import PointSource, ModulatorSmoothSquare, ModulatorDiscreteSignal


def build_example_scene1(scene_image):
    """
    This example uses the old image scene description. See 'StaticImageScene' for more information.
    """
    scene_objects = [StaticImageScene(scene_image)]
    return scene_objects


def build_example_scene2(width, height):
    """
    In this example, a new scene is created from scratch and a few emitters are placed manually.
    One of the emitters uses an amplitude modulation object to change its brightness over time.
    """
    objects = []

    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening).
    # However, a dampening layer is added at the border to avoid reflections (see parameter 'border_thickness').
    objects.append(StaticDampening(np.ones((height, width)), 48))

    # add a constant refractive index field
    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))

    # add a simple point source
    objects.append(PointSource(200, 250, 0.19, 5))

    # add a point source with an amplitude modulator
    amplitude_modulator = ModulatorDiscreteSignal(np.random.randint(2, size=64), 0.0006)
    objects.append(PointSource(200, 350, 0.19, 5, amp_modulator=amplitude_modulator))

    return objects


def simulate(scene_image_fn, num_iterations,
             simulation_steps_per_frame, write_videos,
             field_colormap, intensity_colormap,
             background_image_fn=None):
    # reset random number generator
    np.random.seed(0)

    # load scene image
    scene_image = cv2.cvtColor(cv2.imread(scene_image_fn), cv2.COLOR_BGR2RGB)

    background_image = None
    if background_image_fn is not None:
        background_image = cv2.imread(background_image_fn)
        background_image = cv2.resize(background_image, (scene_image.shape[1], scene_image.shape[0]))

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(scene_image.shape[1], scene_image.shape[0])
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # build simulation scene
    simulator.scene_objects = build_example_scene2(scene_image.shape[1], scene_image.shape[0])

    # create video writers
    if write_videos:
        video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'),
                                        60, (scene_image.shape[1], scene_image.shape[0]))
        video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'),
                                        60, (scene_image.shape[1], scene_image.shape[0]))

    # run simulation
    for i in range(num_iterations):
        simulator.update_scene()
        simulator.update_field()
        visualizer.update(simulator)

        if i % simulation_steps_per_frame == 0:
            frame_int = visualizer.render_intensity(1.0)
            frame_field = visualizer.render_field(1.0)

            if background_image is not None:
                frame_int = cv2.add(background_image, frame_int)
                frame_field = cv2.add(background_image, frame_field)

            # frame_int = cv2.pyrDown(frame_int)
            # frame_field = cv2.pyrDown(frame_field)
            cv2.imshow("Wave Simulation", frame_field)  # cv2.resize(frame_int, dsize=(1024, 1024)))
            cv2.waitKey(1)

            if write_videos:
                video_writer1.write(frame_field)
                video_writer2.write(frame_int)

        if i % 128 == 0:
            print(f'{int((i+1)/num_iterations*100)}%')


if __name__ == "__main__":
    print('This file contains tests for development and you may not be able to run it without errors')
    print('Please take a look at the provided examples')

    # increase simulation_steps_per_frame to better utilize the GPU
    # good colormaps for field: RdBu[invert=True], colormap_wave1, colormap_wave2, colormap_wave4, icefire
    simulate('../example_data/scene_lens_doubleslit.png',
             20000,
             simulation_steps_per_frame=16,
             write_videos=True,
             field_colormap=vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05),
             # field_colormap=vis.get_colormap_lut('RdBu', invert=True, make_symmetric=True),
             intensity_colormap=vis.get_colormap_lut('afmhot', invert=False, black_level=0.0),
             background_image_fn=None)

--------------------------------------------------------------------------------
/wave_sim2d/examples/example0.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa: make the 'wave_sim2d' package importable

import cv2
import wave_sim2d.wave_visualizer as vis
import wave_sim2d.wave_simulation as sim
from wave_sim2d.scene_objects.source import *
from wave_sim2d.scene_objects.static_refractive_index import *

def build_scene():
    """
    This example creates the simplest possible simulation using a single emitter.
    """
    width = 512
    height = 512
    objects = [PointSource(200, 256, 0.1, 5)]
    # objects.append(StaticRefractiveIndexPolygon([[400, 255], [300, 200], [300, 300]], 1.5))
    # objects = [LineSource((200, 265), (250, 105), 0.2, 0.5)]

    return objects, width, height


def main():
    # create colormaps
    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)

    # build simulation scene
    scene_objects, w, h = build_scene()

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(w, h, scene_objects)
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # run simulation
    for i in range(1000):
        simulator.update_scene()
        simulator.update_field()
        visualizer.update(simulator)

        # show field
        frame_field = visualizer.render_field(1.0)
        cv2.imshow("Wave Simulation Field", frame_field)

        # show intensity
        # frame_int = visualizer.render_intensity(1.0)
        # cv2.imshow("Wave Simulation Intensity", frame_int)

        cv2.waitKey(1)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/wave_sim2d/examples/example1.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa: make the 'wave_sim2d' package importable

import numpy as np
import cv2

import wave_sim2d.wave_visualizer as vis
import wave_sim2d.wave_simulation as sim
from wave_sim2d.scene_objects.static_image_scene import StaticImageScene


def build_scene(scene_image_path):
    """
    This example uses the 'old' image scene description. See 'StaticImageScene' for more information.
    """
    # load scene image
    scene_image = cv2.cvtColor(cv2.imread(scene_image_path), cv2.COLOR_BGR2RGB)

    # create the scene object list with a 'StaticImageScene' entry as the only scene object
    # more scene objects can be added to the list to build more complex scenes
    scene_objects = [StaticImageScene(scene_image, source_fequency_scale=2.0)]

    return scene_objects, scene_image.shape[1], scene_image.shape[0]


def main():
    # Set the scene image path. The image encodes refractive index, dampening and emitters in its color channels;
    # see the 'static_image_scene.StaticImageScene' class for a more detailed description.
    # Please take a look at the image to understand what is happening in the simulation.
    # The path is resolved relative to this file, so the example does not depend on the working directory.
    scene_image_path = os.path.join(os.path.dirname(__file__), '../../example_data/scene_lens_doubleslit.png')

    # create colormaps
    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)

    # reset random number generator
    np.random.seed(0)

    # build simulation scene
    scene_objects, w, h = build_scene(scene_image_path)

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(w, h, scene_objects)
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # run simulation
    for i in range(2000):
        simulator.update_scene()
        simulator.update_field()
        visualizer.update(simulator)

        # visualize every N frames
        if (i % 4) == 0:
            # show field
            frame_field = visualizer.render_field(1.0)
            cv2.imshow("Wave Simulation Field", frame_field)

            # show intensity
            # frame_int = visualizer.render_intensity(1.0)
            # cv2.imshow("Wave Simulation Intensity", frame_int)

            cv2.waitKey(1)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/wave_sim2d/examples/example2.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa: make the 'wave_sim2d' package importable

import numpy as np
import cv2
import wave_sim2d.wave_visualizer as vis
import wave_sim2d.wave_simulation as sim
from wave_sim2d.scene_objects.static_dampening import StaticDampening
from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex
from wave_sim2d.scene_objects.source import PointSource, ModulatorSmoothSquare


def build_scene():
    """
    In this example, a new scene is created from scratch and a few emitters are placed manually.
    One of the emitters uses an amplitude modulation object to change its brightness over time.
    """
    width = 600
    height = 600
    objects = []

    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening).
    # However, a dampening layer is added at the border to avoid reflections (see parameter 'border_thickness').
    objects.append(StaticDampening(np.ones((height, width)), 32))

    # add a constant refractive index field
    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))

    # add a simple point source
    objects.append(PointSource(200, 220, 0.2, 8))

    # add a point source with an amplitude modulator
    amplitude_modulator = ModulatorSmoothSquare(0.025, 0.0, smoothness=0.5)
    objects.append(PointSource(200, 380, 0.2, 8, amp_modulator=amplitude_modulator))

    return objects, width, height


def main():
    # create colormaps
    field_colormap = vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05)
    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)

    # reset random number generator
    np.random.seed(0)

    # build simulation scene
    scene_objects, w, h = build_scene()

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(w, h, scene_objects)
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # run simulation
    for i in range(2000):
        simulator.update_scene()
        simulator.update_field()
        visualizer.update(simulator)

        # visualize every N frames
        if (i % 2) == 0:
            # show field
            frame_field = visualizer.render_field(1.0)
            cv2.imshow("Wave Simulation Field", frame_field)

            # show intensity
            # frame_int = visualizer.render_intensity(1.0)
            # cv2.imshow("Wave Simulation Intensity", frame_int)

            cv2.waitKey(1)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/wave_sim2d/examples/example3.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa: make the 'wave_sim2d' package importable

import numpy as np
import cupy as cp
import math
import cv2

import wave_sim2d.wave_visualizer as vis
import wave_sim2d.wave_simulation as sim
from wave_sim2d.scene_objects.static_dampening import StaticDampening
from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex


def gaussian_kernel(size, sigma):
    """
    creates a normalized 2D gaussian kernel with side length `size` and standard deviation `sigma`
    """
    ax = np.linspace(-(size - 1) / 2., (size - 1) / 2., size)
    gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
    kernel = np.outer(gauss, gauss)
    return kernel / np.sum(kernel)


class MovingCharge(sim.SceneObject):
    """
    Implements a moving charge scene object: a smooth Gaussian disturbance that oscillates around a
    center position and is added to the field in every time step.
    :param x: center position x.
    :param y: center position y.
    :param frequency: motion frequency
    :param amplitude: vertical motion amplitude (the horizontal motion uses a fixed amplitude of 200 pixels)
    """
    def __init__(self, x, y, frequency, amplitude):
        self.x = x
        self.y = y
        self.frequency = frequency
        self.amplitude = amplitude
        self.size = 11

        # create a smooth source shape
        self.source_array = cp.array(gaussian_kernel(self.size, self.size/3))

    def render(self, field, wave_speed_field, dampening_field):
        # no changes to the refractive index or dampening field required for this class
        pass

    def update_field(self, field, t):
        fade_in = math.sin(min(t*0.1, math.pi/2))

        # compute the current position of the moving charge;
        # cast to int so the coordinates can be used as slice indices
        x = int(round(self.x + math.sin(self.frequency * t*0.05)*200))
        y = int(round(self.y + math.sin(self.frequency * t)*self.amplitude))

        # add the source shape to the field at the current position
        wh = self.source_array.shape[1]//2
        hh = self.source_array.shape[0]//2
        field[y-hh:y+hh+1, x-wh:x+wh+1] += self.source_array * fade_in * 0.25


def build_scene():
    """
    In this example, a custom scene object is implemented and used to simulate a moving field disturbance.
    """
    width = 600
    height = 600
    objects = []

    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening).
    # However, a dampening layer is added at the border to avoid reflections (see parameter 'border_thickness').
    objects.append(StaticDampening(np.ones((height, width)), 64))

    # add a constant refractive index field
    objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))

    # add the moving charge
    objects.append(MovingCharge(300, 300, 0.1, 10))

    return objects, width, height


def main():
    # create colormaps
    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)

    # reset random number generator
    np.random.seed(0)

    # build simulation scene
    scene_objects, w, h = build_scene()

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(w, h, scene_objects)
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # run simulation
    for i in range(8000):
        simulator.update_scene()
        simulator.update_field()
        visualizer.update(simulator)

        # visualize every N frames
        if (i % 2) == 0:
            # show field
            frame_field = visualizer.render_field(1.0)
            cv2.imshow("Wave Simulation Field", frame_field)

            # show intensity
            # frame_int = visualizer.render_intensity(1.0)
            # cv2.imshow("Wave Simulation Intensity", frame_int)

            cv2.waitKey(1)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/wave_sim2d/examples/example4.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))  # noqa: make the 'wave_sim2d' package importable

import cv2
import numpy as np
import cupy as cp
import wave_sim2d.wave_visualizer as vis
import wave_sim2d.wave_simulation as sim
from wave_sim2d.scene_objects.source import *
from wave_sim2d.scene_objects.static_refractive_index import *
from wave_sim2d.scene_objects.static_dampening import *


def build_scene():
    """
    This example creates a Fabry-Pérot cavity and shows the standing waves that form inside it.
    """
    width = 768
    height = 512
    objects = []

    # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening).
    # However, a dampening layer is added at the border to avoid reflections (see parameter 'border_thickness').
    objects.append(StaticDampening(np.ones((height, width)), 48))

    # add two refractive index boxes that act as the cavity mirrors
    objects.append(StaticRefractiveIndexBox((50, height//2), (50, int(height*0.8)), 0.0, 100.0))
    objects.append(StaticRefractiveIndexBox((width-180, height//2), (40, int(height*0.8)), 0.0, 10.0))

    # add a line source (without amplitude modulator)
    # objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.0215, amplitude=0.5))
    objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.1003, amplitude=0.3))

    return objects, width, height


def show_field(field, brightness_scale):
    # map the field from [-1, 1] to an 8-bit grayscale image; it has a single channel, so no color conversion is needed
    gray = (cp.clip(field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)
    img = gray.get()
    cv2.imshow("Strain Simulation Field", img)


def main():
    write_videos = False
    write_video_frame_every = 2

    # create colormaps
    field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
    intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)

    # build simulation scene
    scene_objects, w, h = build_scene()

    # create simulator and visualizer objects
    simulator = sim.WaveSimulator2D(w, h, scene_objects)
    visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)

    # optionally create video writers
    if write_videos:
        video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))
        video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))

    # run simulation
    for i in range(100000):
        simulator.update_scene()
        simulator.update_field()

        visualizer.update(simulator)
        # show field
        frame_field = visualizer.render_field(1.0)
        cv2.imshow("Wave Simulation Field", frame_field)

        # render intensity (used for the intensity video)
        frame_int = visualizer.render_intensity(1.0)
        # cv2.imshow("Wave Simulation Intensity", frame_int)

        if write_videos and (i % write_video_frame_every) == 0:
            video_writer1.write(frame_field)
            video_writer2.write(frame_int)

        cv2.waitKey(1)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/wave_sim2d/main.py:
--------------------------------------------------------------------------------
if __name__ == "__main__":
    print('please run one of the examples from the wave_sim2d/examples folder...')

--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/__pycache__/border_absorber.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/wave_sim2d/scene_objects/__pycache__/border_absorber.cpython-310.pyc
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/source.py:
--------------------------------------------------------------------------------
from wave_sim2d.wave_simulation import SceneObject
import cupy as cp
import numpy as np
import math


class PointSource(SceneObject):
    """
    Implements a point source scene object. The amplitude can be optionally modulated using a modulator object.
    :param x: source position x.
    :param y: source position y.
    :param frequency: emitting frequency.
    :param amplitude: emitting amplitude; when an amplitude modulator is given, it scales the modulator output
    :param phase: emitter phase
    :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source
        over time.
    """
    def __init__(self, x, y, frequency, amplitude=1.0, phase=0, amp_modulator=None):
        self.x = x
        self.y = y
        self.frequency = frequency
        self.amplitude = amplitude
        self.phase = phase
        self.amplitude_modulator = amp_modulator

    def set_amplitude_modulator(self, func):
        self.amplitude_modulator = func

    def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
        pass

    def update_field(self, field, t):
        if self.amplitude_modulator is not None:
            amplitude = self.amplitude_modulator(t) * self.amplitude
        else:
            amplitude = self.amplitude

        v = cp.sin(self.phase + self.frequency * t) * amplitude
        field[self.y, self.x] = v

    def render_visualization(self, image: np.ndarray):
        """ renders a visualization of the scene object to the image """
        pass


class LineSource(SceneObject):
    """
    Implements a line source scene object. The amplitude can be optionally modulated using a modulator object.
    The source emits along a line defined by a start and end point.
    :param start: starting (x, y) coordinates of the line as a tuple.
    :param end: ending (x, y) coordinates of the line as a tuple.
52 | :param frequency: emitting frequency. 53 | :param amplitude: emitting amplitude, not used when an amplitude modulator is given 54 | :param phase: emitter phase 55 | :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source 56 | over time. 57 | """ 58 | def __init__(self, start, end, frequency, amplitude=1.0, phase=0, amp_modulator=None): 59 | self.start = start 60 | self.end = end 61 | self.frequency = frequency 62 | self.amplitude = amplitude 63 | self.phase = phase 64 | self.amplitude_modulator = amp_modulator 65 | 66 | def set_amplitude_modulator(self, func): 67 | self.amplitude_modulator = func 68 | 69 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray): 70 | pass 71 | 72 | def update_field(self, field, t): 73 | if self.amplitude_modulator is not None: 74 | amplitude = self.amplitude_modulator(t) * self.amplitude 75 | else: 76 | amplitude = self.amplitude 77 | 78 | v = cp.sin(self.phase + self.frequency * t) * amplitude 79 | 80 | # Determine the points along the line using NumPy 81 | x1, y1 = self.start 82 | x2, y2 = self.end 83 | 84 | distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2) 85 | num_points = int(distance) + 1 86 | 87 | if num_points > 0: 88 | x_coords = cp.linspace(x1, x2, num_points).round().astype(int) 89 | y_coords = cp.linspace(y1, y2, num_points).round().astype(int) 90 | 91 | # Create boolean masks for valid indices 92 | valid_x = (x_coords >= 0) & (x_coords < field.shape[1]) 93 | valid_y = (y_coords >= 0) & (y_coords < field.shape[0]) 94 | valid_indices = valid_x & valid_y 95 | 96 | # Use these valid indices to update the field directly 97 | valid_y_coords = y_coords[valid_indices] 98 | valid_x_coords = x_coords[valid_indices] 99 | field[valid_y_coords, valid_x_coords] = v 100 | 101 | def render_visualization(self, image: np.ndarray): 102 | """ renders a visualization of the scene object to the image """ 103 | pass 104 | 105 | # --- Modulators 
------------------------------------------------------------------------------------------------------- 106 | 107 | 108 | class ModulatorSmoothSquare: 109 | """ 110 | A modulator that creates a smoothed square wave 111 | """ 112 | def __init__(self, frequency, phase, smoothness=0.5): 113 | self.frequency = frequency 114 | self.phase = phase 115 | self.smoothness = min(max(smoothness, 1e-4), 1.0) 116 | 117 | def __call__(self, t): 118 | s = math.pow(self.smoothness, 4.0) 119 | a = (0.5 / math.atan(1.0/s)) * math.atan(math.sin(t * self.frequency + self.phase) / s)+0.5 120 | return a 121 | 122 | 123 | class ModulatorDiscreteSignal: 124 | """ 125 | A modulator that creates a smoothed binary signal 126 | """ 127 | def __init__(self, signal_array, time_factor, transition_slope=8.0): 128 | self.signal_array = signal_array 129 | self.time_factor = time_factor 130 | self.transition_slope = transition_slope 131 | 132 | def __call__(self, t): 133 | def smooth_step(t): 134 | return t * t * (3 - 2 * t) 135 | 136 | # Wrap around the position if it's outside the array range 137 | sl = len(self.signal_array) 138 | t = math.fmod(t*self.time_factor, sl) 139 | 140 | # Find the indices of the neighboring values 141 | index_low = int(t) 142 | index_high = (index_low + 1) % sl 143 | 144 | # Calculate the interpolation factor 145 | tf = (t - index_low) 146 | tf = max(0.0, min(1.0, (tf-0.5)*self.transition_slope+0.5)) 147 | 148 | # Use smooth step to interpolate between neighboring values 149 | l = smooth_step(tf) 150 | interpolated_value = (1 - l) * self.signal_array[index_low] + l * self.signal_array[index_high] 151 | 152 | return interpolated_value 153 | -------------------------------------------------------------------------------- /wave_sim2d/scene_objects/static_dampening.py: -------------------------------------------------------------------------------- 1 | from wave_sim2d.wave_simulation import SceneObject 2 | import cupy as cp 3 | import numpy as np 4 | 5 | 6 | class 
StaticDampening(SceneObject): 7 | """ 8 | Implements a static dampening field that overwrites the entire domain. 9 | Therefore, use it as the base layer of your scene. 10 | """ 11 | 12 | def __init__(self, dampening_field, border_thickness): 13 | """ 14 | Creates a static dampening field object 15 | @param dampening_field: A NxM array with dampening factors (1.0 equals no dampening) of the same size as the simulation domain. 16 | @param border_thickness: Thickness in pixels of the graded absorbing layer at the domain boundaries, which reduces reflections. 17 | """ 18 | w = dampening_field.shape[1] 19 | h = dampening_field.shape[0] 20 | self.d = cp.ones((h, w), dtype=cp.float32) 21 | self.d = cp.clip(cp.array(dampening_field), 0.0, 1.0) 22 | 23 | # apply border dampening 24 | for i in range(border_thickness): 25 | v = (i / border_thickness) ** 0.5 26 | self.d[i, i:w - i] = v 27 | self.d[-(1 + i), i:w - i] = v 28 | self.d[i:h - i, i] = v 29 | self.d[i:h - i, -(1 + i)] = v 30 | 31 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray): 32 | assert (dampening_field.shape == self.d.shape) 33 | 34 | # overwrite existing dampening field 35 | dampening_field[:] = self.d 36 | 37 | def update_field(self, field: cp.ndarray, t): 38 | pass 39 | 40 | def render_visualization(self, image: np.ndarray): 41 | """ renders a visualization of the scene object to the image """ 42 | pass 43 | -------------------------------------------------------------------------------- /wave_sim2d/scene_objects/static_image_scene.py: -------------------------------------------------------------------------------- 1 | from wave_sim2d.wave_simulation import SceneObject 2 | 3 | import numpy as np 4 | import cupy as cp 5 | from wave_sim2d.scene_objects.static_dampening import StaticDampening 6 | from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex 7 | 8 | 9 | class StaticImageScene(SceneObject): 10 | """ 11 | Implements a static scene, where the RGB
channels of the input image encode the refractive index, the dampening and the sources. 12 | This allows scenes to be created with an ordinary image editor. 13 | """ 14 | def __init__(self, scene_image, source_amplitude=1.0, source_fequency_scale=1.0): 15 | """ 16 | Load the scene from an image description. 17 | The simulation scenes are given as an 8-bit RGB image with the following channel semantics: 18 | * Red: The refractive index times 100 (for refractive index 1.5 you would use value 150) 19 | * Green: Each pixel with a green value above 0 is a sinusoidal wave source. The green value 20 | defines its frequency. WARNING: Do not use antialiasing for the green channel! 21 | * Blue: Absorption field. Larger values correspond to higher dampening of the waves; 22 | use gradual transitions to avoid reflections. 23 | """ 24 | # Set the opacity of source pixels to incoming waves. If the opacity is 0.0, 25 | # the field is completely overwritten by the source term; 26 | # a nonzero value (e.g. 0.5) allows antialiased sources to work 27 | self.source_opacity = 0.9 28 | 29 | # set refractive index field 30 | self.refractive_index = StaticRefractiveIndex(scene_image[:, :, 0] / 100) 31 | 32 | # set absorber field 33 | self.dampening = StaticDampening(1.0 - scene_image[:, :, 2] / 255, border_thickness=48) 34 | 35 | # set sources, each entry describes a source with the following parameters: 36 | # (x, y, phase, amplitude, frequency) 37 | sources_pos = np.flip(np.argwhere(scene_image[:, :, 1] > 0), axis=1) 38 | phase_amplitude_freq = np.tile(np.array([0, source_amplitude, 0.3]), (sources_pos.shape[0], 1)) 39 | self.sources = np.concatenate((sources_pos, phase_amplitude_freq), axis=1) 40 | 41 | # set source frequency to channel value 42 | self.sources[:, 4] = scene_image[sources_pos[:, 1], sources_pos[:, 0], 1] / 255 * 0.5 * source_fequency_scale 43 | self.sources = cp.array(self.sources).astype(cp.float32) 44 | 45 | def render(self, field: cp.ndarray, wave_speed_field:
cp.ndarray, dampening_field: cp.ndarray): 46 | """ 47 | Render the static dampening and refractive index fields. 48 | """ 49 | self.dampening.render(field, wave_speed_field, dampening_field) 50 | self.refractive_index.render(field, wave_speed_field, dampening_field) 51 | 52 | def update_field(self, field: cp.ndarray, t): 53 | # Update the sources in the simulation field based on their properties. 54 | v = cp.sin(self.sources[:, 2]+self.sources[:, 4]*t)*self.sources[:, 3] 55 | coords = self.sources[:, 0:2].astype(cp.int32) 56 | 57 | o = self.source_opacity 58 | field[coords[:, 1], coords[:, 0]] = field[coords[:, 1], coords[:, 0]]*o + v*(1.0-o) 59 | 60 | def render_visualization(self, image: np.ndarray): 61 | """ renders a visualization of the scene object to the image """ 62 | pass 63 | -------------------------------------------------------------------------------- /wave_sim2d/scene_objects/static_refractive_index.py: -------------------------------------------------------------------------------- 1 | from wave_sim2d.wave_simulation import SceneObject 2 | import cupy as cp 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | class StaticRefractiveIndex(SceneObject): 8 | """ 9 | Implements a static refractive index field that overwrites the entire domain with the given per-pixel refractive index (IOR) values. 10 | Use this as the base layer of your scene. 11 | """ 12 | 13 | def __init__(self, refractive_index_field): 14 | """ 15 | Creates a static refractive index field object 16 | :param refractive_index_field: The refractive index field, same size as the simulation domain.
17 | Values are clipped to [0.9, 10.0]; clipping values below 0.9 prevents the simulation 18 | from becoming unstable. 19 | """ 20 | shape = refractive_index_field.shape 21 | self.c = cp.ones((shape[0], shape[1]), dtype=cp.float32) 22 | self.c = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0) 23 | 24 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray): 25 | assert (wave_speed_field.shape == self.c.shape) 26 | wave_speed_field[:] = self.c 27 | 28 | def update_field(self, field: cp.ndarray, t): 29 | pass 30 | 31 | def render_visualization(self, image: np.ndarray): 32 | """ renders a visualization of the scene object to the image """ 33 | pass 34 | 35 | 36 | class StaticRefractiveIndexPolygon(SceneObject): 37 | """ 38 | Draws a static polygon with a given refractive index into the wave_speed_field using an 39 | anti-aliased mask and indexing. Caches the pixel coordinates and mask values. 40 | """ 41 | 42 | def __init__(self, vertices, refractive_index): 43 | """ 44 | Initializes the StaticRefractiveIndexPolygon. 45 | 46 | Args: 47 | vertices (list or np.ndarray): A list or array of (x, y) coordinates defining the polygon. 48 | refractive_index (float): The refractive index of the polygon. Values are clamped to [0.9, 10.0]. 49 | """ 50 | self.vertices = np.array(vertices, dtype=np.float32) 51 | self.refractive_index = min(max(refractive_index, 0.9), 10.0) 52 | self._cached_coords = None 53 | self._cached_mask_values = None 54 | self._cached_field_shape = (0, 0) 55 | 56 | def _create_polygon_data(self, field_shape): 57 | """ 58 | Creates and caches the pixel coordinates and anti-aliased mask values for the polygon. 59 | 60 | Args: 61 | field_shape (tuple): The shape (rows, cols) of the simulation field. 62 | 63 | Returns: 64 | tuple: A tuple containing: 65 | - coords (tuple of cp.ndarray): (y_coordinates, x_coordinates) of the polygon pixels within the field.
66 | - mask_values (cp.ndarray): Corresponding anti-aliased mask values (0.0 to 1.0). 67 | """ 68 | if self._cached_coords is not None and self._cached_field_shape == field_shape: 69 | return self._cached_coords, self._cached_mask_values 70 | 71 | rows, cols = field_shape 72 | 73 | # Find the bounding box of the polygon 74 | min_x = np.min(self.vertices[:, 0]) 75 | max_x = np.max(self.vertices[:, 0]) 76 | min_y = np.min(self.vertices[:, 1]) 77 | max_y = np.max(self.vertices[:, 1]) 78 | 79 | mask_width = int(np.ceil(max_x - min_x)) + 1 80 | mask_height = int(np.ceil(max_y - min_y)) + 1 81 | offset_x = int(np.floor(min_x)) 82 | offset_y = int(np.floor(min_y)) 83 | 84 | # Create the mask 85 | mask = np.zeros((mask_height, mask_width), dtype=np.float32) 86 | translated_vertices = self.vertices - [offset_x, offset_y] 87 | translated_vertices_cv = np.round(translated_vertices).astype(np.int32) 88 | cv2.fillPoly(mask, [translated_vertices_cv], 1.0, lineType=cv2.LINE_AA) 89 | 90 | # Get coordinates and mask values of non-black pixels 91 | coords_y, coords_x = np.where(mask > 0) 92 | mask_values = mask[coords_y, coords_x] 93 | 94 | # Adjust coordinates to the position in the main field 95 | global_coords_y = coords_y + offset_y 96 | global_coords_x = coords_x + offset_x 97 | 98 | # Perform out-of-bounds check here 99 | in_bounds = (global_coords_y >= 0) & (global_coords_y < rows) & \ 100 | (global_coords_x >= 0) & (global_coords_x < cols) 101 | 102 | valid_global_y = global_coords_y[in_bounds] 103 | valid_global_x = global_coords_x[in_bounds] 104 | valid_mask_values = mask_values[in_bounds] 105 | 106 | self._cached_coords = (cp.array(valid_global_y), cp.array(valid_global_x)) 107 | self._cached_mask_values = cp.array(valid_mask_values, dtype=cp.float32) 108 | self._cached_field_shape = field_shape 109 | return self._cached_coords, self._cached_mask_values 110 | 111 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray): 112 | 
coords, mask_values = self._create_polygon_data(wave_speed_field.shape) 113 | 114 | # Use advanced indexing to update the field and perform alpha blending 115 | bg_wave_speed = wave_speed_field[coords[0], coords[1]] 116 | wave_speed_field[coords[0], coords[1]] = (bg_wave_speed * (1.0 - mask_values) + 117 | mask_values / self.refractive_index) 118 | 119 | def update_field(self, field: cp.ndarray, t): 120 | pass 121 | 122 | def render_visualization(self, image: np.ndarray): 123 | vertices = np.round(self.vertices).astype(np.int32) 124 | cv2.fillPoly(image, [vertices], (60, 60, 60), lineType=cv2.LINE_AA) 125 | 126 | 127 | class StaticRefractiveIndexBox(StaticRefractiveIndexPolygon): 128 | """ 129 | Draws a static rotated box with a given refractive index into the wave_speed_field by 130 | inheriting from StaticRefractiveIndexPolygon. 131 | """ 132 | 133 | def __init__(self, center: tuple, box_size: tuple, box_angle_rad: float, refractive_index: float): 134 | """ 135 | Initializes the StaticRefractiveIndexBox. 136 | 137 | Args: 138 | center (tuple): A tuple (center_x, center_y) representing the box's center. 139 | box_size (tuple): A tuple (width, height) representing the box's dimensions. 140 | box_angle_rad (float): The rotation angle of the box in radians (counter-clockwise). 141 | refractive_index (float): The refractive index of the box. Values are clamped to [0.9, 10.0]. 
142 | """ 143 | self.center = center 144 | self.box_size = box_size 145 | self.box_angle_rad = box_angle_rad 146 | refractive_index = min(max(refractive_index, 0.9), 10.0) 147 | 148 | # Unpack center and box size 149 | center_x, center_y = self.center 150 | width, height = self.box_size 151 | 152 | # Calculate the vertices of the rotated box 153 | half_width = width / 2 154 | half_height = height / 2 155 | local_vertices = np.array([[-half_width, -half_height], 156 | [half_width, -half_height], 157 | [half_width, half_height], 158 | [-half_width, half_height]], dtype=np.float32) 159 | 160 | # Create the rotation matrix 161 | rotation_matrix = cv2.getRotationMatrix2D((0, 0), np.rad2deg(self.box_angle_rad), 1) 162 | 163 | # Rotate the local vertices 164 | rotated_vertices = cv2.transform(np.array([local_vertices]), rotation_matrix)[0] 165 | 166 | # Translate the rotated vertices to the center 167 | translated_vertices = rotated_vertices + [center_x, center_y] 168 | 169 | # Initialize the parent class (StaticRefractiveIndexPolygon) with the vertices 170 | super().__init__(translated_vertices, refractive_index) 171 | -------------------------------------------------------------------------------- /wave_sim2d/scene_objects/strain_refractive_index.py: -------------------------------------------------------------------------------- 1 | from wave_sim2d.wave_simulation import SceneObject 2 | import cupy as cp 3 | import cupyx.scipy.signal 4 | import numpy as np 5 | 6 | class StrainRefractiveIndex(SceneObject): 7 | """ 8 | Implements a dynamic refractive index field that linearly depends on the strain of the current field. 
9 | The refractive index within the entire domain is overwritten 10 | """ 11 | 12 | def __init__(self, refractive_index_offset, coupling_constant): 13 | """ 14 | Creates a strain refractive index field object 15 | :param coupling_constant: coupling constant between the strain and the refractive index 16 | """ 17 | self.coupling_constant = coupling_constant 18 | self.refractive_index_offset = refractive_index_offset 19 | 20 | self.du_dx_kernel = cp.array([[-1, 0.0, 1]]) 21 | self.du_dy_kernel = cp.array([[-1], [0.0], [1]]) 22 | 23 | self.strain_field = None 24 | 25 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray): 26 | # compute strain 27 | du_dx = cupyx.scipy.signal.convolve2d(field, self.du_dx_kernel, mode='same', boundary='fill') 28 | du_dy = cupyx.scipy.signal.convolve2d(field, self.du_dy_kernel, mode='same', boundary='fill') 29 | 30 | self.strain_field = cp.sqrt(du_dx**2 + du_dy**2) 31 | 32 | # compute refractive index from strain 33 | refractive_index_field = self.refractive_index_offset + self.strain_field*self.coupling_constant 34 | 35 | # assign wave speed using refractive index from above 36 | wave_speed_field[:] = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0) 37 | 38 | def update_field(self, field: cp.ndarray, t): 39 | pass 40 | 41 | def render_visualization(self, image: np.ndarray): 42 | """ renders a visualization of the scene object to the image """ 43 | pass 44 | -------------------------------------------------------------------------------- /wave_sim2d/wave_simulation.py: -------------------------------------------------------------------------------- 1 | import cupy 2 | import numpy as np 3 | import cupy as cp 4 | import cupyx.scipy.signal 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class SceneObject(ABC): 9 | """ 10 | Interface for simulation scene objects. A scene object is anything defining or modifying the simulation scene. 
11 | For example: light sources, absorbers or regions with a specific refractive index. Scene objects can change the 12 | simulated field and draw their contribution to the wave speed field and dampening field each frame. """ 13 | 14 | @abstractmethod 15 | def render(self, field: cupy.ndarray, wave_speed_field: cupy.ndarray, dampening_field: cupy.ndarray): 16 | """ renders the scene object's contribution to the wave speed field and dampening field """ 17 | pass 18 | 19 | @abstractmethod 20 | def update_field(self, field: cupy.ndarray, t): 21 | """ performs updates to the field itself, e.g. for adding sources """ 22 | pass 23 | 24 | @abstractmethod 25 | def render_visualization(self, image: np.ndarray): 26 | """ renders a visualization of the scene object to the image """ 27 | pass 28 | 29 | 30 | class WaveSimulator2D: 31 | """ 32 | Simulates the 2D wave equation. 33 | The system assumes units in which the wave speed is 1.0 pixel/timestep; 34 | source frequencies should be adjusted accordingly. 35 | """ 36 | def __init__(self, w, h, scene_objects, initial_field=None): 37 | """ 38 | Initialize the 2D wave simulator. 39 | @param w: Width of the simulation grid. 40 | @param h: Height of the simulation grid.
41 | """ 42 | self.global_dampening = 1.0 43 | self.c = cp.ones((h, w), dtype=cp.float32) # wave speed field (from refractive indices) 44 | self.d = cp.ones((h, w), dtype=cp.float32) # dampening field 45 | self.u = cp.zeros((h, w), dtype=cp.float32) # field values 46 | self.u_prev = cp.zeros((h, w), dtype=cp.float32) # field values of the previous frame 47 | 48 | if initial_field is not None: 49 | assert h == initial_field.shape[0] and w == initial_field.shape[1], 'width/height of initial field invalid' 50 | self.u[:] = initial_field 51 | self.u_prev[:] = initial_field 52 | 53 | # Define Laplacian kernel 54 | self.laplacian_kernel = cp.array([[0.066, 0.184, 0.066], 55 | [0.184, -1.0, 0.184], 56 | [0.066, 0.184, 0.066]]) 57 | 58 | # self.laplacian_kernel = cp.array([[0.05, 0.2, 0.05], 59 | # [0.2, -1.0, 0.2], 60 | # [0.05, 0.2, 0.05]]) 61 | 62 | # self.laplacian_kernel = cp.array([[0.103, 0.147, 0.103], 63 | # [0.147, -1.0, 0.147], 64 | # [0.103, 0.147, 0.103]]) 65 | 66 | self.t = 0.0 67 | self.dt = 1.0 68 | 69 | self.scene_objects = scene_objects if scene_objects is not None else [] 70 | 71 | def reset_time(self): 72 | """ 73 | Reset the simulation time to zero. 74 | """ 75 | self.t = 0.0 76 | 77 | def update_field(self): 78 | """ 79 | Update the simulation field based on the wave equation.
80 | """ 81 | # calculate laplacian using convolution 82 | laplacian = cupyx.scipy.signal.convolve2d(self.u, self.laplacian_kernel, mode='same', boundary='fill') 83 | 84 | # leapfrog update: u_next = u + damping*(u - u_prev) + (c*dt)**2 * laplacian(u) 85 | v = (self.u - self.u_prev) * self.d * self.global_dampening 86 | r = (self.u + v + laplacian * (self.c * self.dt)**2) 87 | 88 | self.u_prev[:] = self.u 89 | self.u[:] = r 90 | 91 | self.t += self.dt 92 | 93 | def update_scene(self): 94 | # clear wave speed field and dampening field 95 | self.c.fill(1.0) 96 | self.d.fill(1.0) 97 | 98 | for obj in self.scene_objects: 99 | obj.render(self.u, self.c, self.d) 100 | 101 | for obj in self.scene_objects: 102 | obj.update_field(self.u, self.t) 103 | 104 | def get_field(self): 105 | """ 106 | Get the current state of the simulation field. 107 | @return: A 2D array representing the simulation field. 108 | """ 109 | return self.u 110 | 111 | def render_visualization(self, image=None): 112 | # create an empty RGB image if none is passed in 113 | if image is None: 114 | image = np.zeros((self.c.shape[0], self.c.shape[1], 3), dtype=np.uint8) 115 | 116 | for obj in self.scene_objects: 117 | obj.render_visualization(image) 118 | 119 | return image 120 | 121 | 122 | -------------------------------------------------------------------------------- /wave_sim2d/wave_visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cupy as cp 3 | import cv2 4 | import matplotlib.pyplot 5 | 6 | colormap_icefire = [[179, 224, 216], [178, 223, 216], [176, 222, 215], [175, 221, 215], [173, 219, 214], [171, 218, 214], [169, 217, 214], [167, 215, 213], [165, 214, 213], [162, 212, 212], [160, 210, 212], [157, 209, 211], [154, 207, 211], [151, 205, 210], [148, 203, 210], [146, 201, 209], [143, 199, 209], [140, 198, 208], [137, 196, 208], [134, 194, 208], [131, 192, 207], [128, 190, 207], [125, 188, 207], [122, 187, 207], [119, 185, 206], [116, 183, 206], [113, 181, 206], [110, 179, 206], [108, 177, 
206], [105, 176, 205], [102, 174, 205], [99, 172, 205], [97, 170, 205], [94, 168, 205], [91, 166, 205], [89, 164, 205], [86, 162, 205], [84, 161, 205], [82, 159, 205], [79, 157, 205], [77, 155, 205], [75, 153, 206], [73, 151, 206], [71, 149, 206], [69, 147, 206], [68, 145, 206], [66, 143, 206], [65, 140, 206], [64, 138, 206], [63, 136, 206], [62, 134, 206], [61, 132, 206], [61, 130, 205], [61, 127, 205], [60, 125, 205], [60, 123, 204], [60, 121, 203], [60, 118, 203], [61, 116, 202], [61, 114, 201], [61, 112, 200], [62, 109, 198], [62, 107, 197], [63, 105, 195], [64, 103, 194], [65, 100, 192], [65, 98, 190], [66, 96, 187], [67, 94, 185], [67, 92, 183], [68, 90, 180], [68, 88, 177], [69, 86, 174], [69, 85, 171], [69, 83, 168], [70, 81, 165], [70, 79, 162], [70, 78, 158], [69, 76, 155], [69, 75, 151], [69, 73, 148], [68, 72, 144], [68, 70, 141], [67, 69, 137], [66, 67, 134], [66, 66, 130], [65, 65, 127], [64, 63, 123], [63, 62, 120], [62, 61, 116], [61, 60, 113], [60, 59, 109], [59, 57, 106], [58, 56, 103], [57, 55, 99], [55, 54, 96], [54, 53, 93], [53, 52, 90], [52, 50, 87], [51, 49, 84], [50, 48, 81], [48, 47, 78], [47, 46, 75], [46, 45, 72], [45, 44, 70], [44, 43, 67], [43, 42, 65], [42, 41, 62], [41, 40, 60], [40, 39, 57], [39, 38, 55], [38, 37, 53], [37, 37, 51], [37, 36, 49], [36, 35, 47], [35, 35, 45], [35, 34, 44], [34, 33, 42], [34, 33, 41], [33, 32, 39], [33, 32, 38], [33, 32, 37], [33, 31, 36], [33, 31, 35], [33, 31, 35], [34, 30, 34], [34, 30, 33], [34, 30, 33], [35, 30, 32], [36, 30, 32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 40, 54], [100, 41, 55], [103, 41, 
56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], [252, 198, 154], [252, 200, 156], [252, 201, 158], [253, 203, 160]] 7 | colormap_wave1 = [[255, 255, 255], [254, 254, 253], [254, 253, 252], [253, 252, 250], [253, 250, 248], [252, 249, 246], [252, 248, 244], [251, 246, 242], [251, 245, 240], [250, 243, 237], [250, 242, 235], [249, 240, 232], [248, 238, 230], [248, 237, 227], [247, 235, 224], [247, 233, 221], [246, 231, 218], [245, 229, 215], [245, 227, 212], [244, 225, 209], [243, 223, 206], [242, 221, 203], [242, 219, 200], [241, 217, 196], [240, 215, 193], [239, 213, 190], [239, 211, 186], [238, 208, 183], [237, 206, 179], [236, 204, 176], [235, 202, 172], [234, 199, 169], 
[233, 197, 165], [232, 195, 162], [231, 192, 158], [230, 190, 155], [230, 188, 151], [228, 185, 148], [227, 183, 144], [226, 181, 141], [225, 178, 137], [224, 176, 134], [223, 174, 130], [222, 171, 127], [221, 169, 124], [219, 167, 120], [218, 164, 117], [217, 162, 114], [216, 160, 111], [214, 157, 108], [213, 155, 105], [212, 153, 102], [210, 151, 99], [209, 149, 96], [207, 146, 93], [206, 144, 91], [204, 142, 88], [203, 140, 85], [201, 138, 83], [199, 136, 80], [198, 134, 78], [196, 132, 76], [194, 130, 74], [193, 128, 72], [191, 127, 70], [189, 125, 69], [187, 123, 67], [185, 121, 65], [183, 119, 63], [180, 117, 62], [178, 115, 60], [176, 113, 58], [173, 111, 56], [170, 109, 55], [168, 107, 53], [165, 105, 52], [162, 102, 50], [159, 100, 49], [156, 98, 47], [154, 96, 46], [151, 94, 44], [148, 92, 43], [144, 90, 41], [141, 87, 40], [138, 85, 39], [135, 83, 37], [132, 81, 36], [129, 79, 35], [125, 76, 34], [122, 74, 33], [119, 72, 31], [115, 70, 30], [112, 68, 29], [109, 66, 28], [105, 64, 27], [102, 62, 27], [99, 60, 26], [96, 58, 25], [93, 56, 25], [89, 54, 25], [86, 52, 25], [83, 51, 25], [80, 49, 25], [77, 47, 25], [74, 45, 25], [71, 44, 25], [68, 42, 25], [65, 41, 25], [62, 39, 25], [60, 38, 25], [57, 37, 25], [54, 35, 25], [52, 34, 25], [49, 33, 25], [47, 32, 25], [45, 31, 25], [43, 30, 25], [40, 29, 25], [39, 28, 25], [37, 28, 25], [35, 27, 25], [33, 27, 25], [32, 26, 25], [30, 26, 25], [29, 25, 25], [28, 25, 25], [27, 25, 25], [26, 25, 25], [26, 25, 26], [26, 26, 27], [26, 26, 28], [26, 26, 30], [26, 27, 31], [26, 27, 33], [26, 28, 34], [26, 29, 36], [26, 30, 38], [26, 31, 40], [26, 32, 42], [26, 33, 44], [26, 34, 47], [26, 35, 49], [26, 37, 51], [26, 38, 54], [26, 40, 56], [26, 41, 59], [26, 43, 62], [26, 44, 64], [26, 46, 67], [26, 48, 70], [26, 50, 73], [27, 51, 76], [28, 53, 79], [28, 55, 82], [29, 57, 85], [30, 59, 88], [31, 61, 91], [32, 64, 94], [33, 66, 97], [35, 68, 101], [36, 70, 104], [37, 72, 107], [38, 74, 110], [40, 77, 113], [41, 79, 117], 
[42, 81, 120], [44, 84, 123], [45, 86, 126], [47, 88, 130], [48, 91, 133], [50, 93, 136], [51, 95, 139], [53, 98, 142], [54, 100, 145], [56, 102, 148], [58, 104, 151], [59, 107, 154], [61, 109, 157], [63, 111, 160], [64, 114, 163], [66, 116, 165], [68, 118, 168], [70, 120, 171], [71, 122, 173], [73, 125, 176], [75, 127, 178], [77, 129, 181], [78, 131, 183], [80, 133, 185], [82, 135, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 250], [239, 245, 250], [241, 246, 251], [243, 247, 251], [245, 249, 252], [247, 250, 252], [249, 251, 253], [251, 252, 253], [252, 253, 254], [254, 254, 254]] 8 | colormap_wave2 = [[255, 255, 255], [253, 254, 254], [252, 254, 253], [250, 253, 252], [248, 253, 252], [246, 252, 251], [244, 252, 250], [242, 251, 249], [240, 251, 247], [237, 250, 246], [235, 250, 245], [232, 249, 244], [230, 248, 243], [227, 248, 242], [224, 247, 240], [221, 247, 239], [218, 246, 238], [215, 245, 236], [212, 245, 235], [209, 244, 234], [206, 243, 232], [203, 242, 231], [200, 242, 229], [196, 241, 228], [193, 240, 226], [190, 239, 225], [186, 239, 223], [183, 238, 222], [179, 237, 
220], [176, 236, 218], [172, 235, 217], [169, 234, 215], [165, 233, 213], [162, 232, 212], [158, 231, 210], [155, 231, 208], [151, 230, 206], [148, 228, 205], [144, 227, 203], [141, 226, 201], [137, 225, 199], [134, 224, 198], [130, 223, 196], [127, 222, 194], [124, 221, 192], [120, 219, 190], [117, 218, 188], [114, 217, 187], [111, 216, 185], [108, 214, 183], [105, 213, 181], [102, 212, 179], [99, 210, 177], [96, 209, 176], [93, 207, 174], [91, 206, 172], [88, 204, 170], [85, 203, 168], [83, 201, 166], [80, 199, 164], [78, 198, 163], [76, 196, 161], [74, 194, 159], [72, 193, 157], [70, 191, 155], [69, 189, 153], [67, 187, 152], [65, 185, 149], [63, 183, 147], [62, 180, 145], [60, 178, 143], [59, 175, 141], [57, 173, 138], [56, 170, 136], [54, 167, 133], [53, 164, 131], [52, 161, 128], [50, 158, 125], [49, 155, 123], [48, 152, 120], [47, 149, 117], [46, 146, 115], [45, 143, 112], [43, 140, 109], [42, 136, 106], [41, 133, 103], [40, 130, 101], [39, 126, 98], [39, 123, 95], [38, 119, 92], [37, 116, 89], [36, 112, 86], [35, 109, 84], [35, 106, 81], [34, 102, 78], [33, 99, 75], [33, 95, 73], [32, 92, 70], [31, 89, 67], [31, 86, 65], [30, 82, 62], [30, 79, 60], [29, 76, 57], [29, 73, 55], [28, 70, 52], [28, 67, 50], [28, 64, 48], [27, 61, 46], [27, 58, 44], [27, 55, 42], [26, 53, 40], [26, 50, 38], [26, 48, 37], [26, 46, 35], [26, 43, 34], [26, 41, 32], [26, 39, 31], [26, 37, 30], [26, 35, 29], [26, 34, 28], [26, 32, 27], [26, 31, 26], [26, 29, 26], [26, 28, 25], [26, 27, 25], [26, 27, 25], [26, 26, 25], [26, 25, 25], [26, 25, 26], [26, 25, 26], [26, 25, 27], [26, 25, 27], [26, 25, 28], [27, 25, 30], [27, 25, 31], [27, 26, 32], [28, 27, 34], [28, 27, 35], [28, 28, 37], [29, 29, 39], [29, 30, 41], [30, 31, 43], [30, 32, 45], [31, 34, 48], [31, 35, 50], [32, 36, 53], [32, 38, 55], [33, 40, 58], [33, 41, 61], [34, 43, 64], [35, 45, 66], [36, 47, 69], [36, 49, 73], [37, 51, 76], [38, 53, 79], [39, 55, 82], [39, 57, 85], [40, 59, 89], [41, 61, 92], [42, 64, 95], [43, 66, 
99], [44, 68, 102], [45, 71, 105], [46, 73, 109], [47, 76, 112], [48, 78, 116], [49, 80, 119], [50, 83, 122], [52, 86, 126], [53, 88, 129], [54, 91, 133], [55, 93, 136], [57, 96, 139], [58, 98, 143], [59, 100, 146], [60, 103, 149], [62, 105, 152], [63, 108, 155], [65, 110, 158], [66, 113, 161], [68, 115, 164], [69, 117, 167], [71, 120, 170], [72, 122, 173], [74, 124, 175], [75, 126, 178], [77, 128, 180], [79, 131, 183], [80, 133, 185], [82, 134, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 250], [239, 245, 250], [241, 246, 251], [243, 247, 251], [245, 249, 252], [247, 250, 252], [249, 251, 253], [251, 252, 253], [252, 253, 254], [254, 254, 254]] 9 | colormap_wave3 = [[253, 203, 160], [252, 201, 158], [252, 200, 156], [252, 198, 154], [251, 196, 152], [251, 194, 149], [251, 192, 147], [250, 190, 145], [250, 189, 142], [249, 187, 139], [249, 185, 135], [248, 182, 132], [248, 179, 129], [247, 177, 125], [247, 175, 122], [247, 172, 119], [246, 168, 115], [246, 166, 112], [245, 164, 108], [245, 161, 105], [244, 158, 102], [243, 155, 98], [243, 152, 95], [242, 150, 92], [242, 147, 88], 
[241, 145, 86], [240, 142, 83], [240, 139, 80], [239, 137, 78], [238, 134, 75], [237, 131, 73], [237, 129, 70], [236, 126, 68], [235, 124, 67], [234, 121, 65], [233, 118, 63], [232, 115, 61], [231, 112, 61], [230, 110, 59], [229, 107, 57], [228, 104, 56], [228, 102, 55], [226, 100, 55], [224, 97, 55], [224, 94, 54], [222, 92, 54], [221, 89, 53], [219, 87, 53], [218, 84, 53], [217, 83, 53], [215, 80, 54], [213, 78, 54], [211, 76, 54], [209, 74, 55], [208, 72, 55], [206, 70, 56], [203, 68, 57], [201, 66, 57], [199, 65, 58], [198, 64, 59], [196, 62, 59], [193, 60, 60], [191, 59, 61], [188, 57, 61], [186, 56, 62], [183, 55, 62], [181, 55, 63], [179, 54, 63], [176, 53, 63], [173, 52, 64], [171, 51, 64], [168, 50, 65], [165, 49, 65], [162, 49, 65], [159, 48, 66], [157, 48, 66], [155, 47, 66], [152, 47, 66], [149, 46, 66], [146, 46, 65], [143, 46, 65], [140, 45, 65], [138, 45, 65], [135, 45, 64], [132, 44, 64], [130, 44, 63], [127, 44, 63], [124, 44, 62], [121, 43, 61], [118, 43, 60], [115, 43, 60], [112, 42, 60], [110, 42, 59], [108, 42, 58], [105, 42, 57], [102, 41, 56], [99, 41, 55], [97, 40, 54], [94, 40, 53], [91, 40, 52], [89, 39, 51], [86, 39, 50], [84, 38, 49], [82, 38, 48], [79, 37, 47], [77, 37, 46], [74, 36, 45], [72, 36, 44], [70, 35, 43], [68, 35, 42], [65, 34, 41], [64, 34, 40], [62, 33, 39], [60, 33, 38], [58, 33, 38], [56, 32, 38], [54, 32, 36], [52, 32, 35], [50, 32, 35], [48, 31, 35], [47, 31, 34], [45, 31, 34], [43, 31, 33], [42, 30, 33], [41, 30, 32], [40, 30, 32], [39, 30, 32], [38, 30, 32], [37, 30, 32], [36, 30, 32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 
40, 54], [100, 41, 55], [103, 41, 56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], [252, 198, 154], [252, 200, 156], [252, 201, 158], [253, 203, 160]] 10 | colormap_wave4 = [[246, 230, 183], [246, 229, 182], [246, 227, 180], [246, 226, 178], [246, 224, 176], [245, 222, 173], [245, 219, 170], [244, 217, 167], [244, 214, 163], [244, 211, 160], [243, 209, 156], [243, 206, 152], [242, 203, 148], [242, 200, 144], [241, 196, 140], [241, 193, 136], [241, 190, 132], [240, 186, 128], [240, 183, 124], [239, 180, 120], [239, 176, 116], [238, 173, 112], [238, 170, 108], [237, 166, 104], [237, 163, 100], [236, 160, 97], [236, 156, 93], [236, 153, 90], [235, 150, 87], [235, 147, 84], [235, 
144, 81], [234, 140, 78], [234, 137, 76], [234, 134, 74], [234, 131, 71], [233, 127, 69], [233, 124, 67], [233, 121, 65], [233, 118, 64], [232, 115, 62], [232, 112, 61], [232, 109, 60], [232, 106, 59], [232, 103, 58], [232, 101, 58], [232, 98, 57], [232, 95, 57], [231, 93, 57], [230, 90, 57], [230, 88, 57], [229, 85, 57], [227, 83, 57], [226, 81, 57], [224, 78, 57], [222, 76, 58], [220, 74, 58], [217, 72, 59], [215, 70, 59], [212, 67, 60], [210, 65, 60], [207, 63, 60], [204, 62, 61], [201, 60, 61], [199, 58, 61], [196, 56, 61], [193, 55, 61], [189, 53, 61], [186, 52, 62], [183, 50, 61], [180, 49, 61], [176, 48, 61], [173, 46, 61], [170, 45, 61], [166, 44, 61], [163, 43, 61], [159, 42, 60], [156, 40, 60], [152, 39, 60], [149, 39, 59], [146, 38, 58], [142, 37, 58], [139, 36, 57], [135, 35, 56], [132, 34, 55], [128, 34, 54], [125, 33, 53], [122, 32, 52], [118, 31, 51], [115, 31, 50], [111, 30, 49], [108, 29, 48], [105, 28, 46], [101, 28, 45], [98, 27, 44], [94, 26, 43], [91, 25, 41], [88, 24, 40], [85, 24, 38], [82, 23, 37], [78, 22, 35], [75, 21, 34], [72, 20, 32], [69, 19, 31], [66, 18, 29], [63, 18, 28], [60, 16, 26], [57, 16, 25], [55, 15, 24], [52, 14, 22], [49, 13, 21], [46, 12, 19], [44, 11, 18], [41, 11, 17], [39, 10, 16], [36, 9, 14], [34, 8, 13], [31, 8, 12], [29, 7, 11], [27, 7, 10], [25, 6, 10], [23, 5, 9], [21, 5, 8], [19, 4, 7], [17, 4, 7], [16, 4, 6], [14, 3, 6], [13, 3, 5], [13, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 5], [13, 3, 5], [14, 3, 5], [15, 3, 6], [16, 4, 6], [18, 4, 7], [20, 4, 7], [21, 5, 8], [23, 5, 9], [26, 6, 10], [28, 7, 11], [30, 7, 12], [33, 8, 13], [35, 9, 14], [38, 9, 15], [40, 10, 17], [43, 11, 18], [46, 12, 19], [48, 13, 21], [51, 14, 22], [54, 15, 24], [57, 16, 25], [60, 17, 27], [63, 18, 28], [67, 19, 30], [70, 20, 31], [73, 20, 33], [76, 21, 34], [80, 22, 36], [83, 23, 37], [86, 24, 39], [90, 25, 40], [93, 26, 42], [96, 27, 43], [100, 27, 44], [103, 28, 46], [107, 29, 47], [110, 30, 48], [114, 30, 50], [117, 31, 51], 
[121, 32, 52], [124, 33, 53], [128, 34, 54], [131, 34, 55], [135, 35, 56], [138, 36, 57], [142, 37, 57], [146, 38, 58], [149, 39, 59], [153, 39, 59], [156, 40, 60], [160, 42, 60], [164, 43, 61], [167, 44, 61], [171, 45, 61], [174, 46, 61], [178, 48, 62], [181, 49, 62], [184, 51, 62], [188, 52, 62], [191, 54, 62], [194, 56, 61], [197, 57, 61], [200, 59, 61], [203, 61, 61], [206, 63, 60], [209, 65, 60], [212, 67, 59], [215, 69, 59], [217, 72, 59], [220, 74, 58], [222, 76, 58], [224, 78, 57], [226, 81, 57], [227, 83, 57], [229, 86, 57], [230, 88, 57], [230, 91, 57], [231, 94, 57], [232, 97, 57], [232, 99, 57], [232, 102, 58], [232, 105, 59], [232, 108, 60], [232, 111, 61], [232, 114, 62], [232, 117, 63], [233, 120, 65], [233, 123, 66], [233, 127, 68], [233, 130, 71], [234, 133, 73], [234, 137, 75], [234, 140, 78], [235, 143, 81], [235, 147, 84], [235, 150, 87], [236, 153, 90], [236, 157, 94], [236, 160, 97], [237, 164, 101], [237, 167, 105], [238, 170, 109], [238, 174, 113], [239, 178, 117], [239, 181, 121], [240, 184, 126], [240, 188, 130], [241, 192, 134], [241, 195, 138], [242, 198, 142], [242, 202, 147], [243, 205, 151], [243, 208, 155], [244, 211, 159], [244, 214, 162], [244, 216, 166], [245, 219, 169], [245, 221, 173], [246, 224, 175], [246, 226, 178], [246, 227, 180], [246, 229, 182], [246, 230, 183]]
11 |
12 |
13 | class WaveVisualizer:
14 |     def __init__(self, field_colormap, intensity_colormap):
15 |         self.field_colormap = field_colormap
16 |         self.intensity_colormap = intensity_colormap
17 |         self.intensity = None
18 |         self.intensity_exp_average_factor = 0.98
19 |         self.field = None
20 |         self.visualization_image = None
21 |
22 |     def update(self, wave_sim):
23 |         self.field = wave_sim.get_field()
24 |
25 |         if self.intensity is None:
26 |             self.intensity = cp.zeros_like(self.field)
27 |
28 |         t = self.intensity_exp_average_factor
29 |         self.intensity = self.intensity*t + (self.field**2)*(1.0-t)
30 |         self.visualization_image = wave_sim.render_visualization()
31 |
32 |
def render_intensity(self, brightness_scale=1.0, exp=0.5, overlay_visualization=True):
33 |         gray = (cp.clip((self.intensity**exp)*brightness_scale, 0.0, 1.0) * 254.0).astype(np.uint8)
34 |         img = self.intensity_colormap[gray].get() if self.intensity_colormap is not None else gray.get()
35 |         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
36 |         if overlay_visualization:
37 |             img = cv2.add(img, self.visualization_image)
38 |         return img
39 |
40 |     def render_field(self, brightness_scale=1.0, overlay_visualization=True):
41 |         gray = (cp.clip(self.field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)
42 |         img = self.field_colormap[gray].get() if self.field_colormap is not None else gray.get()
43 |         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44 |         if overlay_visualization:
45 |             img = cv2.add(img, self.visualization_image)
46 |         return img
47 |
48 |
49 | def get_colormap_lut(name, invert, black_level=0.0, make_symmetric=False):
50 |     if name == 'icefire': color_values = np.array(colormap_icefire)/255
51 |     elif name == 'colormap_wave1': color_values = np.array(colormap_wave1)/255
52 |     elif name == 'colormap_wave2':color_values = np.array(colormap_wave2) / 255
53 |     elif name == 'colormap_wave3': color_values = np.array(colormap_wave3) / 255
54 |     elif name == 'colormap_wave4': color_values = np.array(colormap_wave4) / 255
55 |     else:
56 |         colormap = matplotlib.pyplot.get_cmap(name)
57 |         color_values = colormap(np.linspace(0, 1, 255))
58 |
59 |     if invert:
60 |         color_values = 1.0-color_values
61 |
62 |     if make_symmetric:
63 |         src = color_values.copy()
64 |         color_values[255:126:-1, :] = src[0:255:2, :]
65 |         color_values[0:128, :] = src[0:255:2, :]
66 |
67 |     color_values = np.clip(color_values*(1.0-black_level)+black_level, 0, 255)
68 |
69 |     return cp.asarray((color_values*255).astype(np.uint8))
---------------------------------------------------------------------------------
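`WaveVisualizer.update` keeps a running intensity estimate as an exponential moving average of `field**2`, with `intensity_exp_average_factor = 0.98` as the decay factor. The sketch below reproduces that single update line; it assumes NumPy in place of CuPy (the calls used behave the same in both), and `ema_intensity` plus the sine test signal are hypothetical illustrations, not part of the repository. With a factor of 0.98 the estimate responds on a time scale of roughly 1 / (1 - 0.98) = 50 calls to `update`.

```python
import numpy as np  # assumption: NumPy stands in for CuPy; the calls used are the same in both


def ema_intensity(intensity, field, t=0.98):
    """One intensity update, written like the line in WaveVisualizer.update.

    The previous estimate decays by t and the new squared field enters with
    weight (1 - t), i.e. an exponential moving average of field**2.
    """
    return intensity * t + (field ** 2) * (1.0 - t)


# Hypothetical test signal: a unit-amplitude sine sampled 40 times per period.
steps = np.arange(2000)
field_over_time = np.sin(2 * np.pi * steps / 40)

intensity = np.zeros(1)
for value in field_over_time:
    intensity = ema_intensity(intensity, np.array([value]))

# Close to 0.5 (the time average of sin^2), with a small leftover ripple.
print(intensity[0])
```

Starting from zeros, a constant unit field gives 1 - 0.98**n after n updates, so the intensity view reaches about 87% of its final value after 100 updates rather than appearing at once.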
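`render_field` turns the signed field into colormap indices: after scaling by `brightness_scale`, values are clipped to [-1, 1] and mapped linearly onto 0..254, so zero field lands on index 127. The lookup table from `get_colormap_lut` is then applied by integer-array indexing (`self.field_colormap[gray]`). The sketch below assumes NumPy instead of CuPy; `field_to_indices` and the ramp table are hypothetical.

```python
import numpy as np  # assumption: NumPy stands in for CuPy in this sketch


def field_to_indices(field, brightness_scale=1.0):
    """Hypothetical helper mirroring the index computation in render_field.

    Scaled values are clipped to [-1, 1] and mapped to 0..254; astype(np.uint8)
    truncates, so 0.5 -> 190.5 -> 190.
    """
    return (np.clip(field * brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)


# Hypothetical 255-row RGB lookup table: row i holds (i, 0, 254 - i).
rows = np.arange(255)
lut = np.stack([rows, np.zeros(255), 254 - rows], axis=1).astype(np.uint8)

field = np.array([[-2.0, -1.0, 0.0],
                  [0.5, 1.0, 3.0]])
idx = field_to_indices(field)
rgb = lut[idx]  # integer-array indexing: (2, 3) indices -> (2, 3, 3) colors

print(idx.tolist())  # [[0, 0, 127], [190, 254, 254]]
print(rgb.shape)     # (2, 3, 3)
```

Because the largest index is 254, any table passed as `field_colormap` needs at least 255 rows.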
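In `render_intensity` the intensity is raised to the power `exp` (default 0.5) before it is scaled, clipped to [0, 1] and mapped onto 0..254. An exponent below 1 acts like a gamma curve that lifts weak regions relative to strong ones. A small NumPy sketch of that mapping, with a hypothetical helper `intensity_to_indices` and NumPy assumed in place of CuPy:

```python
import numpy as np  # assumption: NumPy stands in for CuPy in this sketch


def intensity_to_indices(intensity, brightness_scale=1.0, exp=0.5):
    """Hypothetical helper mirroring the index computation in render_intensity."""
    return (np.clip((intensity ** exp) * brightness_scale, 0.0, 1.0) * 254.0).astype(np.uint8)


intensity = np.array([0.0, 0.01, 0.25, 1.0, 4.0])
print(intensity_to_indices(intensity).tolist())           # [0, 25, 127, 254, 254]
print(intensity_to_indices(intensity, exp=1.0).tolist())  # [0, 2, 63, 254, 254]
```

With `exp=0.5` an intensity of 0.01 already maps to index 25, while the linear mapping leaves it at 2, so faint interference fringes stay visible next to bright spots.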
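With `make_symmetric=True`, `get_colormap_lut` folds the table so that row i and row 254 - i carry the same color. Since `render_field` maps zero field to index 127, positive and negative values of equal magnitude then receive the same color (up to the truncation in `astype`), and the image shows the field magnitude rather than its sign. `black_level` afterwards lifts the darkest color to `black_level`. The sketch reproduces both steps on a hypothetical 255-row grey ramp; `fold_symmetric` and `apply_black_level` are illustrative names. The slicing only lines up for 255 rows, which is what the matplotlib branch builds with `np.linspace(0, 1, 255)`: with 256 rows the left-hand slice `[255:126:-1]` would select 129 rows against 128 source rows.

```python
import numpy as np


def fold_symmetric(color_values):
    """Hypothetical helper using the same slicing as the make_symmetric branch.

    Every second source row is written once from the top (rows 0..127) and once
    mirrored from the bottom (rows 254..127), so row i equals row 254 - i.
    """
    src = color_values.copy()
    out = color_values.copy()
    out[255:126:-1, :] = src[0:255:2, :]  # for 255 rows, start 255 is clipped to 254
    out[0:128, :] = src[0:255:2, :]
    return out


def apply_black_level(color_values, black_level):
    """Hypothetical helper: lift colors in [0, 1] so that black maps to black_level."""
    return np.clip(color_values * (1.0 - black_level) + black_level, 0.0, 1.0)


# Hypothetical 255-row RGB ramp in [0, 1], standing in for a sampled matplotlib colormap.
ramp = np.repeat(np.linspace(0.0, 1.0, 255)[:, None], 3, axis=1)
sym = fold_symmetric(ramp)
print(np.allclose(sym, sym[::-1]))  # True: row i equals row 254 - i

lifted = apply_black_level(sym, 0.1)
print(lifted.min(), lifted.max())
```

The sketch clips to [0, 1], the range the values occupy before the final `* 255`; the repository's `np.clip(..., 0, 255)` is applied at the same point, so its upper bound never takes effect.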