├── README.md
├── example_data
│   ├── scene_lens_doubleslit.png
│   └── scene_optical_fibers.png
├── images
│   ├── optical_cavity.jpg
│   ├── simulation_1.jpg
│   ├── simulation_2.jpg
│   ├── simulation_3.jpg
│   ├── simulation_4.jpg
│   └── source_antialiasing.png
├── requirements.txt
└── wave_sim2d
    ├── __init__.py
    ├── develop_tests.py
    ├── examples
    │   ├── example0.py
    │   ├── example1.py
    │   ├── example2.py
    │   ├── example3.py
    │   └── example4.py
    ├── main.py
    ├── scene_objects
    │   ├── __pycache__
    │   │   └── border_absorber.cpython-310.pyc
    │   ├── source.py
    │   ├── static_dampening.py
    │   ├── static_image_scene.py
    │   ├── static_refractive_index.py
    │   └── strain_refractive_index.py
    ├── wave_simulation.py
    └── wave_visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 | # 2D Wave Simulation on the GPU
2 |
3 | This repository contains a lightweight 2D wave simulator that runs on the GPU using the CuPy library (probably requires an NVIDIA GPU). It can be used for 2D light and sound simulations.
4 | A simple visualizer shows the field and its intensity on the screen and writes a movie file for each to disk. The goal is to provide a fast, easy-to-use, yet flexible wave simulator.
5 |
6 |
7 |

8 |

9 |
10 |
11 | ### Update 06.04.2025
12 |
13 | * Scene objects can now draw to a visualization layer (most of them do not yet, feel free to contribute)!
14 | * Example 4 now shows a two-mirror optical cavity and how standing waves emerge.
15 | * Added new line sources (LineSource)
16 | * Added a refractive index polygon object (StaticRefractiveIndexPolygon)
17 | * Added a refractive index box object (StaticRefractiveIndexBox)
18 | * Fixed some issues with the examples
19 |
20 |
21 |

22 |
23 |
24 | ### Update 01.04.2024
25 |
26 | * Refactored the code to support a more flexible scene description. A simulation scene now consists of a list of objects that add their contribution to the fields.
27 | They can be combined to build complex and time dependent simulations. The refactoring also made the core simulation code even simpler.
28 | * Added a few new custom colormaps that work well for wave simulations.
29 | * Added new examples, which should make it easier to understand how to use the program and how to set up your own simulations: [examples](wave_sim2d/examples).
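
The "list of scene objects" idea described above can be sketched roughly as follows. This is only an illustration using NumPy on the CPU, not the simulator's actual code: the classes `ConstantIndex` and `SinePoint` and the helper `apply_scene` are hypothetical, the method names `render` and `update_field` mirror the ones used by the scene objects in `wave_sim2d/scene_objects`, and the relation wave speed = 1 / refractive index is an assumption made for this sketch.

```python
import math
import numpy as np


class ConstantIndex:
    """Hypothetical static object: writes a uniform wave speed into the scene."""

    def __init__(self, refractive_index):
        self.refractive_index = refractive_index

    def render(self, field, wave_speed_field, dampening_field):
        # Assumption for this sketch: wave speed is 1 / refractive index.
        wave_speed_field[:] = 1.0 / self.refractive_index

    def update_field(self, field, t):
        pass  # a static object does not touch the wave field


class SinePoint:
    """Hypothetical time-dependent object: drives one pixel sinusoidally."""

    def __init__(self, x, y, frequency, amplitude=1.0):
        self.x, self.y = x, y
        self.frequency = frequency
        self.amplitude = amplitude

    def render(self, field, wave_speed_field, dampening_field):
        pass  # a source does not change the material fields

    def update_field(self, field, t):
        field[self.y, self.x] = math.sin(self.frequency * t) * self.amplitude


def apply_scene(scene_objects, field, wave_speed_field, dampening_field, t):
    """Let every object add its contribution, in list order."""
    for obj in scene_objects:
        obj.render(field, wave_speed_field, dampening_field)
    for obj in scene_objects:
        obj.update_field(field, t)


field = np.zeros((8, 8))
wave_speed = np.ones((8, 8))
dampening = np.ones((8, 8))
scene = [ConstantIndex(1.5), SinePoint(2, 4, frequency=0.5, amplitude=2.0)]
apply_scene(scene, field, wave_speed, dampening, t=math.pi)
```

Because objects are applied in list order, a layer that overwrites the whole domain belongs at the start of the list, and localized objects such as sources or lenses come after it.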
30 |
31 |
32 |

33 |

34 |
35 |
36 | The old image-based scene description is still available as a scene object. You can keep using image editing software to create simulations
37 | without much programming.
38 |
39 | ### Image Scene Description Usage ###
40 |
41 | When using the 'StaticImageScene' class, simulation scenes can be given as an 8-bit RGB image with the following channel semantics:
42 | * Red: The refractive index times 100 (for a refractive index of 1.5, use the value 150)
43 | * Green: Each pixel with a green value above 0 is a sinusoidal wave source. The green value defines its frequency.
44 | * Blue: Absorption field. Larger values correspond to stronger dampening of the waves; use graduated transitions to avoid reflections.
45 |
46 | WARNING: Do not use anti-aliasing for the green channel! The shades it produces are interpreted as different source frequencies, which yields weird results.
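
As an alternative to an image editor, a scene image following the channel semantics above could also be generated with NumPy. The following is a minimal sketch under stated assumptions: the function `make_scene_image`, its parameters, and the concrete values (lens position, source frequency value 40, border width) are made up for illustration and are not part of the repository.

```python
import numpy as np


def make_scene_image(width, height, border=16):
    """Build an 8-bit RGB scene image (hypothetical helper, illustration only)."""
    img = np.zeros((height, width, 3), dtype=np.uint8)

    # Red: refractive index times 100. Background index 1.0, a slab with index 1.5.
    img[:, :, 0] = 100
    img[height // 4:3 * height // 4, width // 2:width // 2 + width // 8, 0] = 150

    # Green: one hard-edged source pixel; its value (40) encodes the frequency.
    # No anti-aliasing: every other green value stays exactly 0.
    img[height // 2, width // 4, 1] = 40

    # Blue: absorption that ramps smoothly from 255 at the border to 0 inside,
    # i.e. a graduated transition instead of a hard absorbing edge.
    ys, xs = np.mgrid[0:height, 0:width]
    dist = np.minimum(np.minimum(xs, width - 1 - xs), np.minimum(ys, height - 1 - ys))
    ramp = np.clip(1.0 - dist / border, 0.0, 1.0)
    img[:, :, 2] = (ramp * 255).astype(np.uint8)

    return img


scene_image = make_scene_image(128, 128)
```

The array is in RGB order, which matches what the examples pass to `StaticImageScene` after converting with `cv2.COLOR_BGR2RGB`. To save it with `cv2.imwrite`, convert it back with `cv2.cvtColor(scene_image, cv2.COLOR_RGB2BGR)` first, and use a lossless format such as PNG so the channel values survive unchanged.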
47 |
48 |
49 |

50 |
51 |
52 | ### Recommended Installation ###
53 |
54 | 1. Install Python and the PyCharm IDE
55 | 2. Clone the project to your hard disk
56 | 3. Open the folder as a project in PyCharm
57 | 4. If prompted to install the requirements, accept (or install them with `pip install -r requirements.txt`)
58 | 5. Right-click one of the examples in wave_sim2d/examples and select Run
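
If an example fails to start after the steps above, a quick check of the CuPy installation can help narrow down the cause. The sketch below is not part of the repository; the helper `cupy_status` is hypothetical and only reports whether CuPy can be imported and whether a CUDA device is visible.

```python
import importlib.util


def cupy_status():
    """Return a short description of the local CuPy setup (hypothetical helper)."""
    if importlib.util.find_spec("cupy") is None:
        return "cupy is not installed"
    try:
        import cupy as cp
        count = cp.cuda.runtime.getDeviceCount()
    except Exception as exc:  # import problems, missing driver, no device, ...
        return f"cupy is installed but not usable: {exc}"
    return f"cupy is usable, {count} CUDA device(s) visible"


print(cupy_status())
```

If CuPy turns out to be missing or unusable, it may be worth checking the CuPy installation guide for a prebuilt package that matches the installed CUDA version, since installing the plain `cupy` package can trigger a build from source.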
59 |
60 |
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/example_data/scene_lens_doubleslit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/example_data/scene_lens_doubleslit.png
--------------------------------------------------------------------------------
/example_data/scene_optical_fibers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/example_data/scene_optical_fibers.png
--------------------------------------------------------------------------------
/images/optical_cavity.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/optical_cavity.jpg
--------------------------------------------------------------------------------
/images/simulation_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_1.jpg
--------------------------------------------------------------------------------
/images/simulation_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_2.jpg
--------------------------------------------------------------------------------
/images/simulation_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_3.jpg
--------------------------------------------------------------------------------
/images/simulation_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/simulation_4.jpg
--------------------------------------------------------------------------------
/images/source_antialiasing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/images/source_antialiasing.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | matplotlib
4 | cupy
--------------------------------------------------------------------------------
/wave_sim2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/wave_sim2d/__init__.py
--------------------------------------------------------------------------------
/wave_sim2d/develop_tests.py:
--------------------------------------------------------------------------------
1 | import wave_visualizer
2 | import wave_visualizer as vis
3 | import wave_simulation as sim
4 | import numpy as np
5 | import cv2
6 | import math
7 | import json
8 | from scene_objects.static_dampening import StaticDampening
9 | from scene_objects.static_refractive_index import StaticRefractiveIndex
10 | from scene_objects.static_image_scene import StaticImageScene
11 | from scene_objects.source import PointSource, ModulatorSmoothSquare, ModulatorDiscreteSignal
12 |
13 |
14 | def build_example_scene1(scene_image):
15 |     """
16 |     This example uses the old image scene description. See 'StaticImageScene' for more information.
17 |     """
18 |     scene_objects = [StaticImageScene(scene_image)]
19 |     return scene_objects
20 |
21 |
22 | def build_example_scene2(width, height):
23 |     """
24 |     In this example, a new scene is created from scratch and a few emitters are placed manually.
25 |     One of the emitters uses an amplitude modulation object to change brightness over time
26 |     """
27 |     objects = []
28 |
29 |     # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)
30 |     # However a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')
31 |     objects.append(StaticDampening(np.ones((height, width)), 48))
32 |
33 |     # add a constant refractive index field
34 |     objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))
35 |
36 |     # add a simple point source
37 |     objects.append(PointSource(200, 250, 0.19, 5))
38 |
39 |     # add a point source with an amplitude modulator
40 |     amplitude_modulator = ModulatorDiscreteSignal(np.random.randint(2, size=64), 0.0006)
41 |     objects.append(PointSource(200, 350, 0.19, 5, amp_modulator=amplitude_modulator))
42 |
43 |     return objects
44 |
45 |
46 | def simulate(scene_image_fn, num_iterations,
47 |              simulation_steps_per_frame, write_videos,
48 |              field_colormap, intensity_colormap,
49 |              background_image_fn=None):
50 |     # reset random number generator
51 |     np.random.seed(0)
52 |
53 |     # load scene image
54 |     scene_image = cv2.cvtColor(cv2.imread(scene_image_fn), cv2.COLOR_BGR2RGB)
55 |
56 |     background_image = None
57 |     if background_image_fn is not None:
58 |         background_image = cv2.imread(background_image_fn)
59 |         background_image = cv2.resize(background_image, (scene_image.shape[1], scene_image.shape[0]))
60 |
61 |     # create simulator and visualizer objects
62 |     simulator = sim.WaveSimulator2D(scene_image.shape[1], scene_image.shape[0])
63 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
64 |
65 |     # build simulation scene
66 |     simulator.scene_objects = build_example_scene2(scene_image.shape[1], scene_image.shape[0])
67 |
68 |     # create video writers
69 |     if write_videos:
70 |         video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'),
71 |                                         60, (scene_image.shape[1], scene_image.shape[0]))
72 |         video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'),
73 |                                         60, (scene_image.shape[1], scene_image.shape[0]))
74 |
75 |     # run simulation
76 |     for i in range(num_iterations):
77 |         simulator.update_scene()
78 |         simulator.update_field()
79 |         visualizer.update(simulator)
80 |
81 |         if i % simulation_steps_per_frame == 0:
82 |             frame_int = visualizer.render_intensity(1.0)
83 |             frame_field = visualizer.render_field(1.0)
84 |
85 |             if background_image is not None:
86 |                 frame_int = cv2.add(background_image, frame_int)
87 |                 frame_field = cv2.add(background_image, frame_field)
88 |
89 |             # frame_int = cv2.pyrDown(frame_int)
90 |             # frame_field = cv2.pyrDown(frame_field)
91 |             cv2.imshow("Wave Simulation", frame_field)  # cv2.resize(frame_int, dsize=(1024, 1024)))
92 |             cv2.waitKey(1)
93 |
94 |             if write_videos:
95 |                 video_writer1.write(frame_field)
96 |                 video_writer2.write(frame_int)
97 |
98 |         if i % 128 == 0:
99 |             print(f'{int((i+1)/num_iterations*100)}%')
100 |
101 |
102 | if __name__ == "__main__":
103 |     print('This file contains tests for development and you may not be able to run it without errors')
104 |     print('Please take a look at the provided examples')
105 |
106 |     # increase simulation_steps_per_frame to better utilize GPU
107 |     # good colormaps for field: RdBu[invert=True], colormap_wave1, colormap_wave2, colormap_wave4, icefire
108 |     simulate('../example_data/scene_lens_doubleslit.png',
109 |              20000,
110 |              simulation_steps_per_frame=16,
111 |              write_videos=True,
112 |              field_colormap=vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05),
113 |              # field_colormap=vis.get_colormap_lut('RdBu', invert=True, make_symmetric=True),
114 |              intensity_colormap=vis.get_colormap_lut('afmhot', invert=False, black_level=0.0),
115 |              background_image_fn=None)
116 |
117 |
--------------------------------------------------------------------------------
/wave_sim2d/examples/example0.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) # noqa
4 |
5 | import cv2
6 | import wave_sim2d.wave_visualizer as vis
7 | import wave_sim2d.wave_simulation as sim
8 | from wave_sim2d.scene_objects.source import *
9 | from wave_sim2d.scene_objects.static_refractive_index import *
10 |
11 | def build_scene():
12 |     """
13 |     This example creates the simplest possible simulation using a single emitter.
14 |     """
15 |     width = 512
16 |     height = 512
17 |     objects = [PointSource(200, 256, 0.1, 5)]
18 |     # objects.append(StaticRefractiveIndexPolygon([[400, 255], [300, 200], [300, 300]], 1.5))
19 |     # objects = [LineSource((200, 265), (250, 105), 0.2, 0.5)]
20 |
21 |     return objects, width, height
22 |
23 |
24 | def main():
25 |     # create colormaps
26 |     field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
27 |     intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)
28 |
29 |     # build simulation scene
30 |     scene_objects, w, h = build_scene()
31 |
32 |     # create simulator and visualizer objects
33 |     simulator = sim.WaveSimulator2D(w, h, scene_objects)
34 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
35 |
36 |     # run simulation
37 |     for i in range(1000):
38 |         simulator.update_scene()
39 |         simulator.update_field()
40 |         visualizer.update(simulator)
41 |
42 |         # show field
43 |         frame_field = visualizer.render_field(1.0)
44 |         cv2.imshow("Wave Simulation Field", frame_field)
45 |
46 |         # show intensity
47 |         # frame_int = visualizer.render_intensity(1.0)
48 |         # cv2.imshow("Wave Simulation Intensity", frame_int)
49 |
50 |         cv2.waitKey(1)
51 |
52 |
53 | if __name__ == "__main__":
54 |     main()
55 |
56 |
--------------------------------------------------------------------------------
/wave_sim2d/examples/example1.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) # noqa
4 |
5 | import numpy as np
6 | import cv2
7 |
8 | import wave_sim2d.wave_visualizer as vis
9 | import wave_sim2d.wave_simulation as sim
10 | from wave_sim2d.scene_objects.static_image_scene import StaticImageScene
11 |
12 |
13 | def build_scene(scene_image_path):
14 |     """
15 |     This example uses the 'old' image scene description. See 'StaticImageScene' for more information.
16 |     """
17 |     # load scene image
18 |     scene_image = cv2.cvtColor(cv2.imread(scene_image_path), cv2.COLOR_BGR2RGB)
19 |
20 |     # create the scene object list with a 'StaticImageScene' entry as the only scene object
21 |     # more scene objects can be added to the list to build more complex scenes
22 |     scene_objects = [StaticImageScene(scene_image, source_fequency_scale=2.0)]
23 |
24 |     return scene_objects, scene_image.shape[1], scene_image.shape[0]
25 |
26 |
27 | def main():
28 |     # Set scene image path. The image encodes refractive index, dampening and emitters in its color channels
29 |     # see 'static_image_scene.StaticImageScene' class for a more detailed description.
30 |     # please take a look at the image to understand what is happening in the simulation
31 |     scene_image_path = '../../example_data/scene_lens_doubleslit.png'
32 |
33 |     # create colormaps
34 |     field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
35 |     intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)
36 |
37 |     # reset random number generator
38 |     np.random.seed(0)
39 |
40 |     # build simulation scene
41 |     scene_objects, w, h = build_scene(scene_image_path)
42 |
43 |     # create simulator and visualizer objects
44 |     simulator = sim.WaveSimulator2D(w, h, scene_objects)
45 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
46 |
47 |     # run simulation
48 |     for i in range(2000):
49 |         simulator.update_scene()
50 |         simulator.update_field()
51 |         visualizer.update(simulator)
52 |
53 |         # visualize every N frames
54 |         if (i % 4) == 0:
55 |             # show field
56 |             frame_field = visualizer.render_field(1.0)
57 |             cv2.imshow("Wave Simulation Field", frame_field)
58 |
59 |             # show intensity
60 |             # frame_int = visualizer.render_intensity(1.0)
61 |             # cv2.imshow("Wave Simulation Intensity", frame_int)
62 |
63 |             cv2.waitKey(1)
64 |
65 |
66 | if __name__ == "__main__":
67 |     main()
68 |
69 |
--------------------------------------------------------------------------------
/wave_sim2d/examples/example2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) # noqa
4 |
5 | import numpy as np
6 | import cv2
7 | import wave_sim2d.wave_visualizer as vis
8 | import wave_sim2d.wave_simulation as sim
9 | from wave_sim2d.scene_objects.static_dampening import StaticDampening
10 | from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex
11 | from wave_sim2d.scene_objects.source import PointSource, ModulatorSmoothSquare
12 |
13 |
14 | def build_scene():
15 |     """
16 |     In this example, a new scene is created from scratch and a few emitters are placed manually.
17 |     One of the emitters uses an amplitude modulation object to change brightness over time
18 |     """
19 |     width = 600
20 |     height = 600
21 |     objects = []
22 |
23 |     # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)
24 |     # However a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')
25 |     objects.append(StaticDampening(np.ones((height, width)), 32))
26 |
27 |     # add a constant refractive index field
28 |     objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))
29 |
30 |     # add a simple point source
31 |     objects.append(PointSource(200, 220, 0.2, 8))
32 |
33 |     # add a point source with an amplitude modulator
34 |     amplitude_modulator = ModulatorSmoothSquare(0.025, 0.0, smoothness=0.5)
35 |     objects.append(PointSource(200, 380, 0.2, 8, amp_modulator=amplitude_modulator))
36 |
37 |     return objects, width, height
38 |
39 |
40 | def main():
41 |     # create colormaps
42 |     field_colormap = vis.get_colormap_lut('colormap_wave4', invert=False, black_level=-0.05)
43 |     intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)
44 |
45 |     # reset random number generator
46 |     np.random.seed(0)
47 |
48 |     # build simulation scene
49 |     scene_objects, w, h = build_scene()
50 |
51 |     # create simulator and visualizer objects
52 |     simulator = sim.WaveSimulator2D(w, h, scene_objects)
53 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
54 |
55 |     # run simulation
56 |     for i in range(2000):
57 |         simulator.update_scene()
58 |         simulator.update_field()
59 |         visualizer.update(simulator)
60 |
61 |         # visualize every N frames
62 |         if (i % 2) == 0:
63 |             # show field
64 |             frame_field = visualizer.render_field(1.0)
65 |             cv2.imshow("Wave Simulation Field", frame_field)
66 |
67 |             # show intensity
68 |             # frame_int = visualizer.render_intensity(1.0)
69 |             # cv2.imshow("Wave Simulation Intensity", frame_int)
70 |
71 |             cv2.waitKey(1)
72 |
73 |
74 | if __name__ == "__main__":
75 |     main()
76 |
77 |
--------------------------------------------------------------------------------
/wave_sim2d/examples/example3.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) # noqa
4 |
5 | import numpy as np
6 | import cupy as cp
7 | import math
8 | import cv2
9 |
10 | import wave_sim2d.wave_visualizer as vis
11 | import wave_sim2d.wave_simulation as sim
12 | from wave_sim2d.scene_objects.static_dampening import StaticDampening
13 | from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex
14 |
15 |
16 | def gaussian_kernel(size, sigma):
17 |     """
18 |     creates a normalized 2D Gaussian kernel with side length `size` and standard deviation `sigma`
19 |     """
20 |     ax = np.linspace(-(size - 1) / 2., (size - 1) / 2., size)
21 |     gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
22 |     kernel = np.outer(gauss, gauss)
23 |     return kernel / np.sum(kernel)
24 |
25 |
26 | class MovingCharge(sim.SceneObject):
27 |     """
28 |     Implements a moving charge scene object: a smooth Gaussian-shaped source that oscillates around a center position.
29 |     :param x: center position x.
30 |     :param y: center position y.
31 |     :param frequency: motion frequency
32 |     :param amplitude: motion amplitude
33 |     """
34 |     def __init__(self, x, y, frequency, amplitude):
35 |         self.x = x
36 |         self.y = y
37 |         self.frequency = frequency
38 |         self.amplitude = amplitude
39 |         self.size = 11
40 |
41 |         # create a smooth source shape
42 |         self.source_array = cp.array(gaussian_kernel(self.size, self.size/3))
43 |
44 |     def render(self, field, wave_speed_field, dampening_field):
45 |         # no changes to the refractive index or dampening field required for this class
46 |         pass
47 |
48 |     def update_field(self, field, t):
49 |         fade_in = math.sin(min(t*0.1, math.pi/2))
50 |
51 |         # compute the current charge position (cast to int so it can be used for slicing)
52 |         x = int(self.x + math.sin(self.frequency * t*0.05)*200)
53 |         y = int(self.y + math.sin(self.frequency * t)*self.amplitude)
54 |
55 |         # copy source shape to current position into field
56 |         wh = self.source_array.shape[1]//2
57 |         hh = self.source_array.shape[0]//2
58 |         field[y-hh:y+hh+1, x-wh:x+wh+1] += self.source_array * fade_in * 0.25
59 |
60 |
61 | def build_scene():
62 |     """
63 |     In this example, a custom scene object is implemented and used to simulate a moving field disturbance.
64 |     """
65 |     width = 600
66 |     height = 600
67 |     objects = []
68 |
69 |     # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)
70 |     # However a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')
71 |     objects.append(StaticDampening(np.ones((height, width)), 64))
72 |
73 |     # add a constant refractive index field
74 |     objects.append(StaticRefractiveIndex(np.full((height, width), 1.5)))
75 |
76 |     # add the moving charge (custom scene object defined above)
77 |     objects.append(MovingCharge(300, 300, 0.1, 10))
78 |
79 |     return objects, width, height
80 |
81 |
82 | def main():
83 |     # create colormaps
84 |     field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
85 |     intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)
86 |
87 |     # reset random number generator
88 |     np.random.seed(0)
89 |
90 |     # build simulation scene
91 |     scene_objects, w, h = build_scene()
92 |
93 |     # create simulator and visualizer objects
94 |     simulator = sim.WaveSimulator2D(w, h, scene_objects)
95 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
96 |
97 |     # run simulation
98 |     for i in range(8000):
99 |         simulator.update_scene()
100 |         simulator.update_field()
101 |         visualizer.update(simulator)
102 |
103 |         # visualize every N frames
104 |         if (i % 2) == 0:
105 |             # show field
106 |             frame_field = visualizer.render_field(1.0)
107 |             cv2.imshow("Wave Simulation Field", frame_field)
108 |
109 |             # show intensity
110 |             # frame_int = visualizer.render_intensity(1.0)
111 |             # cv2.imshow("Wave Simulation Intensity", frame_int)
112 |
113 |             cv2.waitKey(1)
114 |
115 |
116 | if __name__ == "__main__":
117 |     main()
118 |
119 |
--------------------------------------------------------------------------------
/wave_sim2d/examples/example4.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.join(os.path.dirname(__file__), '../')) # noqa
4 |
5 | import cv2
6 | import numpy as np
7 | import cupy as cp
8 | import wave_sim2d.wave_visualizer as vis
9 | import wave_sim2d.wave_simulation as sim
10 | from wave_sim2d.scene_objects.source import *
11 | from wave_sim2d.scene_objects.static_refractive_index import *
12 | from wave_sim2d.scene_objects.static_dampening import *
13 |
14 |
15 | def build_scene():
16 |     """
17 |     This example creates a Fabry-Pérot cavity and shows the standing waves
18 |     """
19 |     width = 768
20 |     height = 512
21 |     objects = []
22 |
23 |     # Add a static dampening field without any dampening in the interior (value 1.0 means no dampening)
24 |     # However a dampening layer at the border is added to avoid reflections (see parameter 'border_thickness')
25 |     objects.append(StaticDampening(np.ones((height, width)), 48))
26 |
27 |     # add two refractive index boxes that act as the cavity mirrors
28 |     objects.append(StaticRefractiveIndexBox((50, height//2), (50, int(height*0.8)), 0.0, 100.0))
29 |     objects.append(StaticRefractiveIndexBox((width-180, height//2), (40, int(height*0.8)), 0.0, 10.0))
30 |
31 |     # add a line source
32 |     # objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.0215, amplitude=0.5))
33 |     objects.append(LineSource((77, height//2-140), (77, height//2+140), 0.1003, amplitude=0.3))
34 |
35 |     return objects, width, height
36 |
37 |
38 | def show_field(field, brightness_scale):
39 |     gray = (cp.clip(field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)
40 |     img = gray.get()
41 |     cv2.imshow("Strain Simulation Field", img)  # single-channel image, no color conversion needed
42 |
43 |
44 | def main():
45 |     write_videos = False
46 |     write_video_frame_every = 2
47 |
48 |     # create colormaps
49 |     field_colormap = vis.get_colormap_lut('colormap_wave1', invert=False, black_level=-0.05)
50 |     intensity_colormap = vis.get_colormap_lut('afmhot', invert=False, black_level=0.0)
51 |
52 |     # build simulation scene
53 |     scene_objects, w, h = build_scene()
54 |
55 |     # create simulator and visualizer objects
56 |     simulator = sim.WaveSimulator2D(w, h, scene_objects)
57 |     visualizer = vis.WaveVisualizer(field_colormap=field_colormap, intensity_colormap=intensity_colormap)
58 |
59 |     # optionally create video writers
60 |     if write_videos:
61 |         video_writer1 = cv2.VideoWriter('simulation_field.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))
62 |         video_writer2 = cv2.VideoWriter('simulation_intensity.avi', cv2.VideoWriter_fourcc(*'FFV1'), 60, (w, h))
63 |
64 |     # run simulation
65 |     for i in range(100000):
66 |         simulator.update_scene()
67 |         simulator.update_field()
68 |
69 |         visualizer.update(simulator)
70 |         # show field
71 |         frame_field = visualizer.render_field(1.0)
72 |         cv2.imshow("Wave Simulation Field", frame_field)
73 |
74 |         # show intensity
75 |         frame_int = visualizer.render_intensity(1.0)
76 |         # cv2.imshow("Wave Simulation Intensity", frame_int)
77 |
78 |         if write_videos and (i % write_video_frame_every) == 0:
79 |             video_writer1.write(frame_field)
80 |             video_writer2.write(frame_int)
81 |
82 |         cv2.waitKey(1)
83 |
84 |
85 | if __name__ == "__main__":
86 |     main()
87 |
88 |
--------------------------------------------------------------------------------
/wave_sim2d/main.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 |     print('please run one of the examples from the wave_sim2d/examples folder...')
3 |
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/__pycache__/border_absorber.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0x23/WaveSimulator2D/16de78f7dd308f93fbd681fbafcbf7d2fc90b03a/wave_sim2d/scene_objects/__pycache__/border_absorber.cpython-310.pyc
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/source.py:
--------------------------------------------------------------------------------
1 | from wave_sim2d.wave_simulation import SceneObject
2 | import cupy as cp
3 | import numpy as np
4 | import math
5 |
6 |
7 | class PointSource(SceneObject):
8 |     """
9 |     Implements a point source scene object. The amplitude can be optionally modulated using a modulator object.
10 |     :param x: source position x.
11 |     :param y: source position y.
12 |     :param frequency: emitting frequency.
13 |     :param amplitude: emitting amplitude; multiplied by the modulator output when an amplitude modulator is given
14 |     :param phase: emitter phase
15 |     :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source
16 |                           over time.
17 |     """
18 |     def __init__(self, x, y, frequency, amplitude=1.0, phase=0, amp_modulator=None):
19 |         self.x = x
20 |         self.y = y
21 |         self.frequency = frequency
22 |         self.amplitude = amplitude
23 |         self.phase = phase
24 |         self.amplitude_modulator = amp_modulator
25 |
26 |     def set_amplitude_modulator(self, func):
27 |         self.amplitude_modulator = func
28 |
29 |     def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
30 |         pass
31 |
32 |     def update_field(self, field, t):
33 |         if self.amplitude_modulator is not None:
34 |             amplitude = self.amplitude_modulator(t) * self.amplitude
35 |         else:
36 |             amplitude = self.amplitude
37 |
38 |         v = cp.sin(self.phase + self.frequency * t) * amplitude
39 |         field[self.y, self.x] = v
40 |
41 |     def render_visualization(self, image: np.ndarray):
42 |         """ renders a visualization of the scene object to the image """
43 |         pass
44 |
45 |
46 | class LineSource(SceneObject):
47 |     """
48 |     Implements a line source scene object. The amplitude can be optionally modulated using a modulator object.
49 |     The source emits along a line defined by a start and end point.
50 |     :param start: starting (x, y) coordinates of the line as a tuple.
51 |     :param end: ending (x, y) coordinates of the line as a tuple.
52 |     :param frequency: emitting frequency.
53 |     :param amplitude: emitting amplitude; multiplied by the modulator output when an amplitude modulator is given
54 |     :param phase: emitter phase
55 |     :param amp_modulator: optional amplitude modulator. This can be used to change the amplitude of the source
56 |                           over time.
57 |     """
58 |     def __init__(self, start, end, frequency, amplitude=1.0, phase=0, amp_modulator=None):
59 |         self.start = start
60 |         self.end = end
61 |         self.frequency = frequency
62 |         self.amplitude = amplitude
63 |         self.phase = phase
64 |         self.amplitude_modulator = amp_modulator
65 |
66 |     def set_amplitude_modulator(self, func):
67 |         self.amplitude_modulator = func
68 |
69 |     def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
70 |         pass
71 |
72 |     def update_field(self, field, t):
73 |         if self.amplitude_modulator is not None:
74 |             amplitude = self.amplitude_modulator(t) * self.amplitude
75 |         else:
76 |             amplitude = self.amplitude
77 |
78 |         v = cp.sin(self.phase + self.frequency * t) * amplitude
79 |
80 |         # Determine the pixel coordinates along the line (roughly one sample per pixel of length)
81 |         x1, y1 = self.start
82 |         x2, y2 = self.end
83 |
84 |         distance = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
85 |         num_points = int(distance) + 1
86 |
87 |         if num_points > 0:
88 |             x_coords = cp.linspace(x1, x2, num_points).round().astype(int)
89 |             y_coords = cp.linspace(y1, y2, num_points).round().astype(int)
90 |
91 |             # Create boolean masks for valid indices
92 |             valid_x = (x_coords >= 0) & (x_coords < field.shape[1])
93 |             valid_y = (y_coords >= 0) & (y_coords < field.shape[0])
94 |             valid_indices = valid_x & valid_y
95 |
96 |             # Use these valid indices to update the field directly
97 |             valid_y_coords = y_coords[valid_indices]
98 |             valid_x_coords = x_coords[valid_indices]
99 |             field[valid_y_coords, valid_x_coords] = v
100 |
101 |     def render_visualization(self, image: np.ndarray):
102 |         """ renders a visualization of the scene object to the image """
103 |         pass
104 |
105 | # --- Modulators -------------------------------------------------------------------------------------------------------
106 |
107 |
108 | class ModulatorSmoothSquare:
109 |     """
110 |     A modulator that creates a smoothed square wave
111 |     """
112 |     def __init__(self, frequency, phase, smoothness=0.5):
113 |         self.frequency = frequency
114 |         self.phase = phase
115 |         self.smoothness = min(max(smoothness, 1e-4), 1.0)
116 |
117 |     def __call__(self, t):
118 |         s = math.pow(self.smoothness, 4.0)
119 |         a = (0.5 / math.atan(1.0/s)) * math.atan(math.sin(t * self.frequency + self.phase) / s)+0.5
120 |         return a
121 |
122 |
123 | class ModulatorDiscreteSignal:
124 | """
125 | A modulator that creates a smoothed binary signal
126 | """
127 | def __init__(self, signal_array, time_factor, transition_slope=8.0):
128 | self.signal_array = signal_array
129 | self.time_factor = time_factor
130 | self.transition_slope = transition_slope
131 |
132 | def __call__(self, t):
133 | def smooth_step(t):
134 | return t * t * (3 - 2 * t)
135 |
136 | # Wrap around the position if it's outside the array range
137 | sl = len(self.signal_array)
138 | t = math.fmod(t*self.time_factor, sl)
139 |
140 | # Find the indices of the neighboring values
141 | index_low = int(t)
142 | index_high = (index_low + 1) % sl
143 |
144 | # Calculate the interpolation factor
145 | tf = (t - index_low)
146 | tf = max(0.0, min(1.0, (tf-0.5)*self.transition_slope+0.5))
147 |
148 | # Use smooth step to interpolate between neighboring values
149 | l = smooth_step(tf)
150 | interpolated_value = (1 - l) * self.signal_array[index_low] + l * self.signal_array[index_high]
151 |
152 | return interpolated_value
153 |
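
The smoothed square wave above can be checked in isolation. The following is a minimal sketch that mirrors the formula in `ModulatorSmoothSquare.__call__` as a plain function; the name `smooth_square` and the sample frequency are illustrative assumptions, not part of the repository.

```python
import math


def smooth_square(t, frequency, phase=0.0, smoothness=0.5):
    """Smoothed square wave in [0, 1] (hypothetical helper mirroring ModulatorSmoothSquare)."""
    # Same clamping and shaping as the class above: smaller smoothness gives sharper edges.
    smoothness = min(max(smoothness, 1e-4), 1.0)
    s = math.pow(smoothness, 4.0)
    return (0.5 / math.atan(1.0 / s)) * math.atan(math.sin(t * frequency + phase) / s) + 0.5


frequency = 0.1  # illustrative value, in radians per timestep
quarter_period = (math.pi / 2) / frequency

high = smooth_square(quarter_period, frequency)      # underlying sine is +1
low = smooth_square(3 * quarter_period, frequency)   # underlying sine is -1
mid = smooth_square(0.0, frequency)                  # underlying sine is 0
```

The modulator therefore scales a source amplitude between 0 (off) and 1 (on), passing through 0.5 at the zero crossings of the underlying sine.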
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/static_dampening.py:
--------------------------------------------------------------------------------
1 | from wave_sim2d.wave_simulation import SceneObject
2 | import cupy as cp
3 | import numpy as np
4 |
5 |
6 | class StaticDampening(SceneObject):
7 | """
8 | Implements a static dampening field that overwrites the entire domain.
9 | Therefore, use this as the base layer in your scene.
10 | """
11 |
12 | def __init__(self, dampening_field, border_thickness):
13 | """
14 | Creates a static dampening field object
15 | @param dampening_field: An NxM array with dampening factors (1.0 equals no dampening) of the same size as the simulation domain.
16 | @param border_thickness: Thickness in pixels of the absorbing border at the domain boundaries, which reduces reflections.
17 | """
18 | w = dampening_field.shape[1]
19 | h = dampening_field.shape[0]
20 | self.d = cp.ones((h, w), dtype=cp.float32)
21 | self.d = cp.clip(cp.array(dampening_field), 0.0, 1.0)
22 |
23 | # apply border dampening
24 | for i in range(border_thickness):
25 | v = (i / border_thickness) ** 0.5
26 | self.d[i, i:w - i] = v
27 | self.d[-(1 + i), i:w - i] = v
28 | self.d[i:h - i, i] = v
29 | self.d[i:h - i, -(1 + i)] = v
30 |
31 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
32 | assert (dampening_field.shape == self.d.shape)
33 |
34 | # overwrite existing dampening field
35 | dampening_field[:] = self.d
36 |
37 | def update_field(self, field: cp.ndarray, t):
38 | pass
39 |
40 | def render_visualization(self, image: np.ndarray):
41 | """ renders a visualization of the scene object to the image """
42 | pass
43 |
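
To see what the border loop in `StaticDampening.__init__` produces, here is a NumPy-only sketch of the same ramp on a small grid. The function name `border_ramp` and the grid size are assumptions made for illustration; the class itself stores the result as a CuPy array.

```python
import numpy as np


def border_ramp(h, w, border_thickness):
    """Hypothetical CPU stand-in for the border dampening loop in StaticDampening."""
    d = np.ones((h, w), dtype=np.float32)
    for i in range(border_thickness):
        # Ring i (counted from the outer edge) gets a factor growing as sqrt(i / thickness).
        v = (i / border_thickness) ** 0.5
        d[i, i:w - i] = v          # top edge of ring i
        d[-(1 + i), i:w - i] = v   # bottom edge of ring i
        d[i:h - i, i] = v          # left edge of ring i
        d[i:h - i, -(1 + i)] = v   # right edge of ring i
    return d


d = border_ramp(16, 16, 4)
```

The outermost ring has a factor of 0.0 (full dampening), and the factor rises towards 1.0 (no dampening) at the inner edge of the border, which gives a gradual transition instead of a hard wall.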
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/static_image_scene.py:
--------------------------------------------------------------------------------
1 | from wave_sim2d.wave_simulation import SceneObject
2 |
3 | import numpy as np
4 | import cupy as cp
5 | from wave_sim2d.scene_objects.static_dampening import StaticDampening
6 | from wave_sim2d.scene_objects.static_refractive_index import StaticRefractiveIndex
7 |
8 |
9 | class StaticImageScene(SceneObject):
10 | """
11 | Implements a static scene, where the RGB channels of the input image encode the refractive index, the dampening and sources.
12 | This class lets you use an image editor to create scenes.
13 | """
14 | def __init__(self, scene_image, source_amplitude=1.0, source_fequency_scale=1.0):
15 | """
16 | Loads a scene from an image description.
17 | The simulation scene is given as an 8-bit RGB image with the following channel semantics:
18 | * Red: The refractive index times 100 (for a refractive index of 1.5, use the value 150)
19 | * Green: Each pixel with a green value above 0 is a sinusoidal wave source. The green value
20 | defines its frequency. WARNING: Do not use antialiasing for the green channel!
21 | * Blue: Absorption field. Larger values correspond to stronger dampening of the waves;
22 | use gradual transitions to avoid reflections.
23 | """
24 | # Set the opacity of source pixels to incoming waves. If the opacity is 0.0
25 | # the field will be completely overwritten by the source term
26 | # a nonzero value (e.g. 0.5) allows for antialiasing of sources to work
27 | self.source_opacity = 0.9
28 |
29 | # set refractive index field
30 | self.refractive_index = StaticRefractiveIndex(scene_image[:, :, 0] / 100)
31 |
32 | # set absorber field
33 | self.dampening = StaticDampening(1.0 - scene_image[:, :, 2] / 255, border_thickness=48)
34 |
35 | # set sources, each entry describes a source with the following parameters:
36 | # (x, y, phase, amplitude, frequency)
37 | sources_pos = np.flip(np.argwhere(scene_image[:, :, 1] > 0), axis=1)
38 | phase_amplitude_freq = np.tile(np.array([0, source_amplitude, 0.3]), (sources_pos.shape[0], 1))
39 | self.sources = np.concatenate((sources_pos, phase_amplitude_freq), axis=1)
40 |
41 | # set source frequency to channel value
42 | self.sources[:, 4] = scene_image[sources_pos[:, 1], sources_pos[:, 0], 1] / 255 * 0.5 * source_fequency_scale
43 | self.sources = cp.array(self.sources).astype(cp.float32)
44 |
45 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
46 | """
47 | Renders the static dampening and refractive index fields of the scene.
48 | """
49 | self.dampening.render(field, wave_speed_field, dampening_field)
50 | self.refractive_index.render(field, wave_speed_field, dampening_field)
51 |
52 | def update_field(self, field: cp.ndarray, t):
53 | # Update the sources in the simulation field based on their properties.
54 | v = cp.sin(self.sources[:, 2]+self.sources[:, 4]*t)*self.sources[:, 3]
55 | coords = self.sources[:, 0:2].astype(cp.int32)
56 |
57 | o = self.source_opacity
58 | field[coords[:, 1], coords[:, 0]] = field[coords[:, 1], coords[:, 0]]*o + v*(1.0-o)
59 |
60 | def render_visualization(self, image: np.ndarray):
61 | """ renders a visualization of the scene object to the image """
62 | pass
63 |
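
The channel semantics described in the `StaticImageScene` docstring can be tried without an image editor. The sketch below builds a small scene image in NumPy and decodes it with the same expressions the constructor uses; the image size, region coordinates and channel values are illustrative assumptions, and the frequency scale is left at its default of 1.0.

```python
import numpy as np

h, w = 64, 96
scene = np.zeros((h, w, 3), dtype=np.uint8)

# Red: refractive index times 100. Use 100 (n = 1.0) as background and a block with n = 1.5.
scene[:, :, 0] = 100
scene[20:44, 40:60, 0] = 150

# Green: a single source pixel; its value sets the frequency. No antialiasing here.
scene[32, 10, 1] = 128

# Blue: absorption ramp over the last 8 columns, rising gradually to avoid reflections.
scene[:, -8:, 2] = np.linspace(0, 255, 8).astype(np.uint8)

# Decode the channels the same way the constructor above does.
refractive_index = scene[:, :, 0] / 100
dampening = 1.0 - scene[:, :, 2] / 255
source_xy = np.flip(np.argwhere(scene[:, :, 1] > 0), axis=1)  # rows of (x, y)
source_freq = scene[source_xy[:, 1], source_xy[:, 0], 1] / 255 * 0.5
```

Note that a red value of 0 would decode to a refractive index of 0, which `StaticRefractiveIndex` clips to 0.9, so it is usually clearer to fill the background with 100 explicitly.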
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/static_refractive_index.py:
--------------------------------------------------------------------------------
1 | from wave_sim2d.wave_simulation import SceneObject
2 | import cupy as cp
3 | import numpy as np
4 | import cv2
5 |
6 |
7 | class StaticRefractiveIndex(SceneObject):
8 | """
9 | Implements a static refractive index field that overwrites the entire domain with the given IOR values.
10 | Use this as base layer in your scene.
11 | """
12 |
13 | def __init__(self, refractive_index_field):
14 | """
15 | Creates a static refractive index field object
16 | :param refractive_index_field: The refractive index field, same size as the simulation domain.
17 | Note that values are clipped to [0.9, 10.0] to prevent the simulation
18 | from becoming unstable
19 | """
20 | shape = refractive_index_field.shape
21 | self.c = cp.ones((shape[0], shape[1]), dtype=cp.float32)
22 | self.c = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0)
23 |
24 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
25 | assert (wave_speed_field.shape == self.c.shape)
26 | wave_speed_field[:] = self.c
27 |
28 | def update_field(self, field: cp.ndarray, t):
29 | pass
30 |
31 | def render_visualization(self, image: np.ndarray):
32 | """ renders a visualization of the scene object to the image """
33 | pass
34 |
35 |
36 | class StaticRefractiveIndexPolygon(SceneObject):
37 | """
38 | Draws a static polygon with a given refractive index into the wave_speed_field using an
39 | anti-aliased mask and indexing. Caches the pixel coordinates and mask values.
40 | """
41 |
42 | def __init__(self, vertices, refractive_index):
43 | """
44 | Initializes the StaticRefractiveIndexPolygon.
45 |
46 | Args:
47 | vertices (list or np.ndarray): A list or array of (x, y) coordinates defining the polygon.
48 | refractive_index (float): The refractive index of the polygon. Values are clamped to [0.9, 10.0].
49 | """
50 | self.vertices = np.array(vertices, dtype=np.float32)
51 | self.refractive_index = min(max(refractive_index, 0.9), 10.0)
52 | self._cached_coords = None
53 | self._cached_mask_values = None
54 | self._cached_field_shape = (0, 0)
55 |
56 | def _create_polygon_data(self, field_shape):
57 | """
58 | Creates and caches the pixel coordinates and anti-aliased mask values for the polygon.
59 |
60 | Args:
61 | field_shape (tuple): The shape (rows, cols) of the simulation field.
62 |
63 | Returns:
64 | tuple: A tuple containing:
65 | - coords (tuple of cp.ndarray): (y_coordinates, x_coordinates) of the polygon pixels within the field.
66 | - mask_values (cp.ndarray): Corresponding anti-aliased mask values (0.0 to 1.0).
67 | """
68 | if self._cached_coords is not None and self._cached_field_shape == field_shape:
69 | return self._cached_coords, self._cached_mask_values
70 |
71 | rows, cols = field_shape
72 |
73 | # Find the bounding box of the polygon
74 | min_x = np.min(self.vertices[:, 0])
75 | max_x = np.max(self.vertices[:, 0])
76 | min_y = np.min(self.vertices[:, 1])
77 | max_y = np.max(self.vertices[:, 1])
78 |
79 | mask_width = int(np.ceil(max_x - min_x)) + 1
80 | mask_height = int(np.ceil(max_y - min_y)) + 1
81 | offset_x = int(np.floor(min_x))
82 | offset_y = int(np.floor(min_y))
83 |
84 | # Create the mask
85 | mask = np.zeros((mask_height, mask_width), dtype=np.float32)
86 | translated_vertices = self.vertices - [offset_x, offset_y]
87 | translated_vertices_cv = np.round(translated_vertices).astype(np.int32)
88 | cv2.fillPoly(mask, [translated_vertices_cv], 1.0, lineType=cv2.LINE_AA)
89 |
90 | # Get coordinates and mask values of non-black pixels
91 | coords_y, coords_x = np.where(mask > 0)
92 | mask_values = mask[coords_y, coords_x]
93 |
94 | # Adjust coordinates to the position in the main field
95 | global_coords_y = coords_y + offset_y
96 | global_coords_x = coords_x + offset_x
97 |
98 | # Perform out-of-bounds check here
99 | in_bounds = (global_coords_y >= 0) & (global_coords_y < rows) & \
100 | (global_coords_x >= 0) & (global_coords_x < cols)
101 |
102 | valid_global_y = global_coords_y[in_bounds]
103 | valid_global_x = global_coords_x[in_bounds]
104 | valid_mask_values = mask_values[in_bounds]
105 |
106 | self._cached_coords = (cp.array(valid_global_y), cp.array(valid_global_x))
107 | self._cached_mask_values = cp.array(valid_mask_values, dtype=cp.float32)
108 | self._cached_field_shape = field_shape
109 | return self._cached_coords, self._cached_mask_values
110 |
111 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
112 | coords, mask_values = self._create_polygon_data(wave_speed_field.shape)
113 |
114 | # Use advanced indexing to update the field and perform alpha blending
115 | bg_wave_speed = wave_speed_field[coords[0], coords[1]]
116 | wave_speed_field[coords[0], coords[1]] = (bg_wave_speed * (1.0 - mask_values) +
117 | mask_values / self.refractive_index)
118 |
119 | def update_field(self, field: cp.ndarray, t):
120 | pass
121 |
122 | def render_visualization(self, image: np.ndarray):
123 | vertices = np.round(self.vertices).astype(np.int32)
124 | cv2.fillPoly(image, [vertices], (60, 60, 60), lineType=cv2.LINE_AA)
125 |
126 |
127 | class StaticRefractiveIndexBox(StaticRefractiveIndexPolygon):
128 | """
129 | Draws a static rotated box with a given refractive index into the wave_speed_field by
130 | inheriting from StaticRefractiveIndexPolygon.
131 | """
132 |
133 | def __init__(self, center: tuple, box_size: tuple, box_angle_rad: float, refractive_index: float):
134 | """
135 | Initializes the StaticRefractiveIndexBox.
136 |
137 | Args:
138 | center (tuple): A tuple (center_x, center_y) representing the box's center.
139 | box_size (tuple): A tuple (width, height) representing the box's dimensions.
140 | box_angle_rad (float): The rotation angle of the box in radians (counter-clockwise).
141 | refractive_index (float): The refractive index of the box. Values are clamped to [0.9, 10.0].
142 | """
143 | self.center = center
144 | self.box_size = box_size
145 | self.box_angle_rad = box_angle_rad
146 | refractive_index = min(max(refractive_index, 0.9), 10.0)
147 |
148 | # Unpack center and box size
149 | center_x, center_y = self.center
150 | width, height = self.box_size
151 |
152 | # Calculate the vertices of the rotated box
153 | half_width = width / 2
154 | half_height = height / 2
155 | local_vertices = np.array([[-half_width, -half_height],
156 | [half_width, -half_height],
157 | [half_width, half_height],
158 | [-half_width, half_height]], dtype=np.float32)
159 |
160 | # Create the rotation matrix
161 | rotation_matrix = cv2.getRotationMatrix2D((0, 0), np.rad2deg(self.box_angle_rad), 1)
162 |
163 | # Rotate the local vertices
164 | rotated_vertices = cv2.transform(np.array([local_vertices]), rotation_matrix)[0]
165 |
166 | # Translate the rotated vertices to the center
167 | translated_vertices = rotated_vertices + [center_x, center_y]
168 |
169 | # Initialize the parent class (StaticRefractiveIndexPolygon) with the vertices
170 | super().__init__(translated_vertices, refractive_index)
171 |
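
The vertex computation in `StaticRefractiveIndexBox.__init__` can be reproduced without OpenCV. This is a sketch under the assumption that `cv2.getRotationMatrix2D((0, 0), angle, 1)` yields the 2x2 rotation part `[[cos, sin], [-sin, cos]]`, as described in the OpenCV documentation; the function name `box_vertices` is hypothetical.

```python
import numpy as np


def box_vertices(center, box_size, box_angle_rad):
    """Hypothetical NumPy equivalent of the vertex setup in StaticRefractiveIndexBox."""
    cx, cy = center
    width, height = box_size
    hw, hh = width / 2, height / 2

    # Corners of the axis-aligned box around the origin, in the same order as above.
    local = np.array([[-hw, -hh], [hw, -hh], [hw, hh], [-hw, hh]], dtype=np.float64)

    # Rotation part of the OpenCV-style matrix for a rotation about (0, 0) with scale 1.
    a, b = np.cos(box_angle_rad), np.sin(box_angle_rad)
    rotation = np.array([[a, b], [-b, a]])

    # Rotate each corner (row vectors, hence the transpose), then move to the center.
    return local @ rotation.T + np.array([cx, cy])


axis_aligned = box_vertices((100, 50), (40, 20), 0.0)
rotated = box_vertices((100, 50), (40, 20), np.pi / 6)
```

The resulting array has the same layout as the `vertices` argument of `StaticRefractiveIndexPolygon`, so the box is simply a four-vertex polygon.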
--------------------------------------------------------------------------------
/wave_sim2d/scene_objects/strain_refractive_index.py:
--------------------------------------------------------------------------------
1 | from wave_sim2d.wave_simulation import SceneObject
2 | import cupy as cp
3 | import cupyx.scipy.signal
4 | import numpy as np
5 |
6 | class StrainRefractiveIndex(SceneObject):
7 | """
8 | Implements a dynamic refractive index field that linearly depends on the strain of the current field.
9 | The refractive index within the entire domain is overwritten
10 | """
11 |
12 | def __init__(self, refractive_index_offset, coupling_constant):
13 | """
14 | Creates a strain refractive index field object
15 | :param coupling_constant: coupling constant between the strain and the refractive index
16 | """
17 | self.coupling_constant = coupling_constant
18 | self.refractive_index_offset = refractive_index_offset
19 |
20 | self.du_dx_kernel = cp.array([[-1, 0.0, 1]])
21 | self.du_dy_kernel = cp.array([[-1], [0.0], [1]])
22 |
23 | self.strain_field = None
24 |
25 | def render(self, field: cp.ndarray, wave_speed_field: cp.ndarray, dampening_field: cp.ndarray):
26 | # compute strain
27 | du_dx = cupyx.scipy.signal.convolve2d(field, self.du_dx_kernel, mode='same', boundary='fill')
28 | du_dy = cupyx.scipy.signal.convolve2d(field, self.du_dy_kernel, mode='same', boundary='fill')
29 |
30 | self.strain_field = cp.sqrt(du_dx**2 + du_dy**2)
31 |
32 | # compute refractive index from strain
33 | refractive_index_field = self.refractive_index_offset + self.strain_field*self.coupling_constant
34 |
35 | # assign wave speed using refractive index from above
36 | wave_speed_field[:] = 1.0/cp.clip(cp.array(refractive_index_field), 0.9, 10.0)
37 |
38 | def update_field(self, field: cp.ndarray, t):
39 | pass
40 |
41 | def render_visualization(self, image: np.ndarray):
42 | """ renders a visualization of the scene object to the image """
43 | pass
44 |
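
The strain-to-wave-speed mapping in `StrainRefractiveIndex.render` can be illustrated on the CPU. The sketch below replaces the CuPy convolutions with zero-padded NumPy slicing; the function name `strain_wave_speed` and the test field are assumptions for illustration only.

```python
import numpy as np


def strain_wave_speed(field, refractive_index_offset, coupling_constant):
    """Hypothetical CPU stand-in for StrainRefractiveIndex.render."""
    # Zero padding matches boundary='fill' in the convolutions above.
    padded = np.pad(field, 1)

    # Central differences. The sign is opposite to convolving with [-1, 0, 1],
    # which does not matter because only the squared values are used.
    du_dx = padded[1:-1, 2:] - padded[1:-1, :-2]
    du_dy = padded[2:, 1:-1] - padded[:-2, 1:-1]

    strain = np.sqrt(du_dx ** 2 + du_dy ** 2)
    refractive_index = refractive_index_offset + strain * coupling_constant

    # Wave speed is the reciprocal of the clipped refractive index.
    return 1.0 / np.clip(refractive_index, 0.9, 10.0)


flat = strain_wave_speed(np.zeros((8, 8)), 1.5, 0.1)
ramp = strain_wave_speed(np.tile(np.arange(8.0), (8, 1)), 1.5, 0.1)
```

A flat field leaves the wave speed at `1 / refractive_index_offset`, while steeper gradients raise the refractive index and slow the wave locally.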
--------------------------------------------------------------------------------
/wave_sim2d/wave_simulation.py:
--------------------------------------------------------------------------------
1 | import cupy
2 | import numpy as np
3 | import cupy as cp
4 | import cupyx.scipy.signal
5 | from abc import ABC, abstractmethod
6 |
7 |
8 | class SceneObject(ABC):
9 | """
10 | Interface for simulation scene objects. A scene object is anything defining or modifying the simulation scene.
11 | For example: Light sources, Absorbers or regions with specific refractive index. Scene objects can change the
12 | simulated field and draw their contribution to the wave speed field and dampening field each frame """
13 |
14 | @abstractmethod
15 | def render(self, field: cupy.ndarray, wave_speed_field: cupy.ndarray, dampening_field: cupy.ndarray):
16 | """ renders the scene object's contribution to the wave speed field and dampening field """
17 | pass
18 |
19 | @abstractmethod
20 | def update_field(self, field: cupy.ndarray, t):
21 | """ performs updates to the field itself, e.g. for adding sources """
22 | pass
23 |
24 | @abstractmethod
25 | def render_visualization(self, image: np.ndarray):
26 | """ renders a visualization of the scene object to the image """
27 | pass
28 |
29 |
30 | class WaveSimulator2D:
31 | """
32 | Simulates the 2D wave equation
33 | The system assumes units in which the wave speed is 1.0 pixel/timestep;
34 | source frequencies should be adjusted accordingly.
35 | """
36 | def __init__(self, w, h, scene_objects, initial_field=None):
37 | """
38 | Initialize the 2D wave simulator.
39 | @param w: Width of the simulation grid.
40 | @param h: Height of the simulation grid.
41 | """
42 | self.global_dampening = 1.0
43 | self.c = cp.ones((h, w), dtype=cp.float32) # wave speed field (from refractive indices)
44 | self.d = cp.ones((h, w), dtype=cp.float32) # dampening field
45 | self.u = cp.zeros((h, w), dtype=cp.float32) # field values
46 | self.u_prev = cp.zeros((h, w), dtype=cp.float32) # field values of prev frame
47 |
48 | if initial_field is not None:
49 | assert w == initial_field.shape[1] and h == initial_field.shape[0], 'width/height of initial field invalid'
50 | self.u[:] = initial_field
51 | self.u_prev[:] = initial_field
52 |
53 | # Define Laplacian kernel
54 | self.laplacian_kernel = cp.array([[0.066, 0.184, 0.066],
55 | [0.184, -1.0, 0.184],
56 | [0.066, 0.184, 0.066]])
57 |
58 | # self.laplacian_kernel = cp.array([[0.05, 0.2, 0.05],
59 | # [0.2, -1.0, 0.2],
60 | # [0.05, 0.2, 0.05]])
61 |
62 | # self.laplacian_kernel = cp.array([[0.103, 0.147, 0.103],
63 | # [0.147, -1.0, 0.147],
64 | # [0.103, 0.147, 0.103]])
65 |
66 | self.t = 0
67 | self.dt = 1.0
68 |
69 | self.scene_objects = scene_objects if scene_objects is not None else []
70 |
71 | def reset_time(self):
72 | """
73 | Reset the simulation time to zero.
74 | """
75 | self.t = 0.0
76 |
77 | def update_field(self):
78 | """
79 | Update the simulation field based on the wave equation.
80 | """
81 | # calculate laplacian using convolution
82 | laplacian = cupyx.scipy.signal.convolve2d(self.u, self.laplacian_kernel, mode='same', boundary='fill')
83 |
84 | # update field
85 | v = (self.u - self.u_prev) * self.d * self.global_dampening
86 | r = (self.u + v + laplacian * (self.c * self.dt)**2)
87 |
88 | self.u_prev[:] = self.u
89 | self.u[:] = r
90 |
91 | self.t += self.dt
92 |
93 | def update_scene(self):
94 | # clear wave speed field and dampening field
95 | self.c.fill(1.0)
96 | self.d.fill(1.0)
97 |
98 | for obj in self.scene_objects:
99 | obj.render(self.u, self.c, self.d)
100 |
101 | for obj in self.scene_objects:
102 | obj.update_field(self.u, self.t)
103 |
104 | def get_field(self):
105 | """
106 | Get the current state of the simulation field.
107 | @return: A 2D array representing the simulation field.
108 | """
109 | return self.u
110 |
111 | def render_visualization(self, image=None):
112 | # create an empty image if none was provided
113 | if image is None:
114 | image = np.zeros((self.c.shape[0], self.c.shape[1], 3), dtype=np.uint8)
115 |
116 | for obj in self.scene_objects:
117 | obj.render_visualization(image)
118 |
119 | return image
120 |
121 |
122 |
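
The time step in `WaveSimulator2D.update_field` can be followed on a tiny grid. The following is a NumPy-only sketch of the same update rule, with the 3x3 stencil applied through zero-padded slicing instead of `cupyx.scipy.signal.convolve2d`; the helper names `laplacian` and `step` and the 9x9 grid are illustrative assumptions.

```python
import numpy as np

# Same stencil as the active laplacian_kernel above; its entries sum to zero.
KERNEL = np.array([[0.066, 0.184, 0.066],
                   [0.184, -1.0, 0.184],
                   [0.066, 0.184, 0.066]])


def laplacian(u):
    """Apply the 3x3 stencil with zero padding (the kernel is symmetric, so no flip is needed)."""
    h, w = u.shape
    padded = np.pad(u, 1)
    out = np.zeros_like(u)
    for dy in range(3):
        for dx in range(3):
            out += KERNEL[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return out


def step(u, u_prev, c, d, dt=1.0, global_dampening=1.0):
    """One update: damped velocity term plus the Laplacian scaled by (c * dt)^2."""
    v = (u - u_prev) * d * global_dampening
    u_next = u + v + laplacian(u) * (c * dt) ** 2
    return u_next, u.copy()


h, w = 9, 9
u = np.zeros((h, w))
u[4, 4] = 1.0            # single impulse at the center
u_prev = u.copy()        # zero initial velocity
c = np.ones((h, w))      # wave speed field
d = np.ones((h, w))      # dampening field (1.0 means no dampening)

u_next, u_prev_next = step(u, u_prev, c, d)
```

After one step the impulse has spread to its eight neighbours with exactly the stencil weights, which shows how the kernel choice shapes the propagation.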
--------------------------------------------------------------------------------
/wave_sim2d/wave_visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cupy as cp
3 | import cv2
4 | import matplotlib.pyplot
5 |
6 | colormap_icefire = [[179, 224, 216], [178, 223, 216], [176, 222, 215], [175, 221, 215], [173, 219, 214], [171, 218, 214], [169, 217, 214], [167, 215, 213], [165, 214, 213], [162, 212, 212], [160, 210, 212], [157, 209, 211], [154, 207, 211], [151, 205, 210], [148, 203, 210], [146, 201, 209], [143, 199, 209], [140, 198, 208], [137, 196, 208], [134, 194, 208], [131, 192, 207], [128, 190, 207], [125, 188, 207], [122, 187, 207], [119, 185, 206], [116, 183, 206], [113, 181, 206], [110, 179, 206], [108, 177, 206], [105, 176, 205], [102, 174, 205], [99, 172, 205], [97, 170, 205], [94, 168, 205], [91, 166, 205], [89, 164, 205], [86, 162, 205], [84, 161, 205], [82, 159, 205], [79, 157, 205], [77, 155, 205], [75, 153, 206], [73, 151, 206], [71, 149, 206], [69, 147, 206], [68, 145, 206], [66, 143, 206], [65, 140, 206], [64, 138, 206], [63, 136, 206], [62, 134, 206], [61, 132, 206], [61, 130, 205], [61, 127, 205], [60, 125, 205], [60, 123, 204], [60, 121, 203], [60, 118, 203], [61, 116, 202], [61, 114, 201], [61, 112, 200], [62, 109, 198], [62, 107, 197], [63, 105, 195], [64, 103, 194], [65, 100, 192], [65, 98, 190], [66, 96, 187], [67, 94, 185], [67, 92, 183], [68, 90, 180], [68, 88, 177], [69, 86, 174], [69, 85, 171], [69, 83, 168], [70, 81, 165], [70, 79, 162], [70, 78, 158], [69, 76, 155], [69, 75, 151], [69, 73, 148], [68, 72, 144], [68, 70, 141], [67, 69, 137], [66, 67, 134], [66, 66, 130], [65, 65, 127], [64, 63, 123], [63, 62, 120], [62, 61, 116], [61, 60, 113], [60, 59, 109], [59, 57, 106], [58, 56, 103], [57, 55, 99], [55, 54, 96], [54, 53, 93], [53, 52, 90], [52, 50, 87], [51, 49, 84], [50, 48, 81], [48, 47, 78], [47, 46, 75], [46, 45, 72], [45, 44, 70], [44, 43, 67], [43, 42, 65], [42, 41, 62], [41, 40, 60], [40, 39, 57], [39, 38, 55], [38, 37, 53], [37, 37, 51], [37, 36, 49], [36, 35, 47], [35, 35, 45], [35, 34, 44], [34, 33, 42], [34, 33, 41], [33, 32, 39], [33, 32, 38], [33, 32, 37], [33, 31, 36], [33, 31, 35], [33, 31, 35], [34, 30, 34], [34, 30, 33], [34, 
30, 33], [35, 30, 32], [36, 30, 32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 40, 54], [100, 41, 55], [103, 41, 56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], [252, 198, 154], [252, 200, 156], [252, 201, 158], [253, 203, 160]]
7 | colormap_wave1 = [[255, 255, 255], [254, 254, 253], [254, 253, 252], [253, 252, 250], [253, 250, 248], [252, 249, 246], [252, 248, 244], [251, 246, 242], [251, 245, 240], [250, 243, 237], [250, 242, 235], [249, 240, 232], [248, 238, 230], [248, 237, 227], [247, 235, 224], [247, 233, 221], [246, 231, 218], [245, 229, 215], [245, 227, 212], [244, 225, 209], [243, 223, 206], [242, 221, 203], [242, 219, 200], [241, 217, 196], [240, 215, 193], [239, 213, 190], [239, 211, 186], [238, 208, 183], [237, 206, 179], [236, 204, 176], [235, 202, 172], [234, 199, 169], [233, 197, 165], [232, 195, 162], [231, 192, 158], [230, 190, 155], [230, 188, 151], [228, 185, 148], [227, 183, 144], [226, 181, 141], [225, 178, 137], [224, 176, 134], [223, 174, 130], [222, 171, 127], [221, 169, 124], [219, 167, 120], [218, 164, 117], [217, 162, 114], [216, 160, 111], [214, 157, 108], [213, 155, 105], [212, 153, 102], [210, 151, 99], [209, 149, 96], [207, 146, 93], [206, 144, 91], [204, 142, 88], [203, 140, 85], [201, 138, 83], [199, 136, 80], [198, 134, 78], [196, 132, 76], [194, 130, 74], [193, 128, 72], [191, 127, 70], [189, 125, 69], [187, 123, 67], [185, 121, 65], [183, 119, 63], [180, 117, 62], [178, 115, 60], [176, 113, 58], [173, 111, 56], [170, 109, 55], [168, 107, 53], [165, 105, 52], [162, 102, 50], [159, 100, 49], [156, 98, 47], [154, 96, 46], [151, 94, 44], [148, 92, 43], [144, 90, 41], [141, 87, 40], [138, 85, 39], [135, 83, 37], [132, 81, 36], [129, 79, 35], [125, 76, 34], [122, 74, 33], [119, 72, 31], [115, 70, 30], [112, 68, 29], [109, 66, 28], [105, 64, 27], [102, 62, 27], [99, 60, 26], [96, 58, 25], [93, 56, 25], [89, 54, 25], [86, 52, 25], [83, 51, 25], [80, 49, 25], [77, 47, 25], [74, 45, 25], [71, 44, 25], [68, 42, 25], [65, 41, 25], [62, 39, 25], [60, 38, 25], [57, 37, 25], [54, 35, 25], [52, 34, 25], [49, 33, 25], [47, 32, 25], [45, 31, 25], [43, 30, 25], [40, 29, 25], [39, 28, 25], [37, 28, 25], [35, 27, 25], [33, 27, 25], [32, 26, 25], [30, 26, 25], [29, 25, 25], 
[28, 25, 25], [27, 25, 25], [26, 25, 25], [26, 25, 26], [26, 26, 27], [26, 26, 28], [26, 26, 30], [26, 27, 31], [26, 27, 33], [26, 28, 34], [26, 29, 36], [26, 30, 38], [26, 31, 40], [26, 32, 42], [26, 33, 44], [26, 34, 47], [26, 35, 49], [26, 37, 51], [26, 38, 54], [26, 40, 56], [26, 41, 59], [26, 43, 62], [26, 44, 64], [26, 46, 67], [26, 48, 70], [26, 50, 73], [27, 51, 76], [28, 53, 79], [28, 55, 82], [29, 57, 85], [30, 59, 88], [31, 61, 91], [32, 64, 94], [33, 66, 97], [35, 68, 101], [36, 70, 104], [37, 72, 107], [38, 74, 110], [40, 77, 113], [41, 79, 117], [42, 81, 120], [44, 84, 123], [45, 86, 126], [47, 88, 130], [48, 91, 133], [50, 93, 136], [51, 95, 139], [53, 98, 142], [54, 100, 145], [56, 102, 148], [58, 104, 151], [59, 107, 154], [61, 109, 157], [63, 111, 160], [64, 114, 163], [66, 116, 165], [68, 118, 168], [70, 120, 171], [71, 122, 173], [73, 125, 176], [75, 127, 178], [77, 129, 181], [78, 131, 183], [80, 133, 185], [82, 135, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 250], [239, 245, 250], [241, 246, 251], [243, 247, 251], [245, 249, 252], [247, 250, 252], [249, 251, 
253], [251, 252, 253], [252, 253, 254], [254, 254, 254]]
8 | colormap_wave2 = [[255, 255, 255], [253, 254, 254], [252, 254, 253], [250, 253, 252], [248, 253, 252], [246, 252, 251], [244, 252, 250], [242, 251, 249], [240, 251, 247], [237, 250, 246], [235, 250, 245], [232, 249, 244], [230, 248, 243], [227, 248, 242], [224, 247, 240], [221, 247, 239], [218, 246, 238], [215, 245, 236], [212, 245, 235], [209, 244, 234], [206, 243, 232], [203, 242, 231], [200, 242, 229], [196, 241, 228], [193, 240, 226], [190, 239, 225], [186, 239, 223], [183, 238, 222], [179, 237, 220], [176, 236, 218], [172, 235, 217], [169, 234, 215], [165, 233, 213], [162, 232, 212], [158, 231, 210], [155, 231, 208], [151, 230, 206], [148, 228, 205], [144, 227, 203], [141, 226, 201], [137, 225, 199], [134, 224, 198], [130, 223, 196], [127, 222, 194], [124, 221, 192], [120, 219, 190], [117, 218, 188], [114, 217, 187], [111, 216, 185], [108, 214, 183], [105, 213, 181], [102, 212, 179], [99, 210, 177], [96, 209, 176], [93, 207, 174], [91, 206, 172], [88, 204, 170], [85, 203, 168], [83, 201, 166], [80, 199, 164], [78, 198, 163], [76, 196, 161], [74, 194, 159], [72, 193, 157], [70, 191, 155], [69, 189, 153], [67, 187, 152], [65, 185, 149], [63, 183, 147], [62, 180, 145], [60, 178, 143], [59, 175, 141], [57, 173, 138], [56, 170, 136], [54, 167, 133], [53, 164, 131], [52, 161, 128], [50, 158, 125], [49, 155, 123], [48, 152, 120], [47, 149, 117], [46, 146, 115], [45, 143, 112], [43, 140, 109], [42, 136, 106], [41, 133, 103], [40, 130, 101], [39, 126, 98], [39, 123, 95], [38, 119, 92], [37, 116, 89], [36, 112, 86], [35, 109, 84], [35, 106, 81], [34, 102, 78], [33, 99, 75], [33, 95, 73], [32, 92, 70], [31, 89, 67], [31, 86, 65], [30, 82, 62], [30, 79, 60], [29, 76, 57], [29, 73, 55], [28, 70, 52], [28, 67, 50], [28, 64, 48], [27, 61, 46], [27, 58, 44], [27, 55, 42], [26, 53, 40], [26, 50, 38], [26, 48, 37], [26, 46, 35], [26, 43, 34], [26, 41, 32], [26, 39, 31], [26, 37, 30], [26, 35, 29], [26, 34, 28], [26, 32, 27], [26, 31, 26], [26, 29, 26], [26, 28, 25], [26, 
27, 25], [26, 27, 25], [26, 26, 25], [26, 25, 25], [26, 25, 26], [26, 25, 26], [26, 25, 27], [26, 25, 27], [26, 25, 28], [27, 25, 30], [27, 25, 31], [27, 26, 32], [28, 27, 34], [28, 27, 35], [28, 28, 37], [29, 29, 39], [29, 30, 41], [30, 31, 43], [30, 32, 45], [31, 34, 48], [31, 35, 50], [32, 36, 53], [32, 38, 55], [33, 40, 58], [33, 41, 61], [34, 43, 64], [35, 45, 66], [36, 47, 69], [36, 49, 73], [37, 51, 76], [38, 53, 79], [39, 55, 82], [39, 57, 85], [40, 59, 89], [41, 61, 92], [42, 64, 95], [43, 66, 99], [44, 68, 102], [45, 71, 105], [46, 73, 109], [47, 76, 112], [48, 78, 116], [49, 80, 119], [50, 83, 122], [52, 86, 126], [53, 88, 129], [54, 91, 133], [55, 93, 136], [57, 96, 139], [58, 98, 143], [59, 100, 146], [60, 103, 149], [62, 105, 152], [63, 108, 155], [65, 110, 158], [66, 113, 161], [68, 115, 164], [69, 117, 167], [71, 120, 170], [72, 122, 173], [74, 124, 175], [75, 126, 178], [77, 128, 180], [79, 131, 183], [80, 133, 185], [82, 134, 187], [84, 136, 189], [86, 138, 191], [87, 140, 193], [89, 142, 194], [91, 144, 196], [93, 146, 198], [96, 147, 199], [98, 149, 201], [100, 151, 203], [103, 153, 204], [105, 155, 206], [108, 157, 207], [110, 160, 209], [113, 162, 210], [116, 164, 212], [118, 166, 213], [121, 168, 214], [124, 170, 216], [127, 172, 217], [130, 174, 218], [133, 176, 219], [136, 178, 221], [139, 180, 222], [142, 183, 223], [145, 185, 224], [148, 187, 225], [152, 189, 226], [155, 191, 227], [158, 193, 228], [161, 195, 230], [164, 197, 231], [168, 200, 231], [171, 202, 232], [174, 204, 233], [177, 206, 234], [180, 208, 235], [183, 210, 236], [187, 212, 237], [190, 214, 238], [193, 216, 239], [196, 218, 239], [199, 220, 240], [202, 222, 241], [205, 223, 242], [208, 225, 242], [211, 227, 243], [214, 229, 244], [217, 231, 245], [219, 232, 245], [222, 234, 246], [225, 236, 247], [227, 237, 247], [230, 239, 248], [232, 240, 248], [234, 242, 249], [237, 243, 250], [239, 245, 250], [241, 246, 251], [243, 247, 251], [245, 249, 252], [247, 250, 252], [249, 
251, 253], [251, 252, 253], [252, 253, 254], [254, 254, 254]]
9 | colormap_wave3 = [[253, 203, 160], [252, 201, 158], [252, 200, 156], [252, 198, 154], [251, 196, 152], [251, 194, 149], [251, 192, 147], [250, 190, 145], [250, 189, 142], [249, 187, 139], [249, 185, 135], [248, 182, 132], [248, 179, 129], [247, 177, 125], [247, 175, 122], [247, 172, 119], [246, 168, 115], [246, 166, 112], [245, 164, 108], [245, 161, 105], [244, 158, 102], [243, 155, 98], [243, 152, 95], [242, 150, 92], [242, 147, 88], [241, 145, 86], [240, 142, 83], [240, 139, 80], [239, 137, 78], [238, 134, 75], [237, 131, 73], [237, 129, 70], [236, 126, 68], [235, 124, 67], [234, 121, 65], [233, 118, 63], [232, 115, 61], [231, 112, 61], [230, 110, 59], [229, 107, 57], [228, 104, 56], [228, 102, 55], [226, 100, 55], [224, 97, 55], [224, 94, 54], [222, 92, 54], [221, 89, 53], [219, 87, 53], [218, 84, 53], [217, 83, 53], [215, 80, 54], [213, 78, 54], [211, 76, 54], [209, 74, 55], [208, 72, 55], [206, 70, 56], [203, 68, 57], [201, 66, 57], [199, 65, 58], [198, 64, 59], [196, 62, 59], [193, 60, 60], [191, 59, 61], [188, 57, 61], [186, 56, 62], [183, 55, 62], [181, 55, 63], [179, 54, 63], [176, 53, 63], [173, 52, 64], [171, 51, 64], [168, 50, 65], [165, 49, 65], [162, 49, 65], [159, 48, 66], [157, 48, 66], [155, 47, 66], [152, 47, 66], [149, 46, 66], [146, 46, 65], [143, 46, 65], [140, 45, 65], [138, 45, 65], [135, 45, 64], [132, 44, 64], [130, 44, 63], [127, 44, 63], [124, 44, 62], [121, 43, 61], [118, 43, 60], [115, 43, 60], [112, 42, 60], [110, 42, 59], [108, 42, 58], [105, 42, 57], [102, 41, 56], [99, 41, 55], [97, 40, 54], [94, 40, 53], [91, 40, 52], [89, 39, 51], [86, 39, 50], [84, 38, 49], [82, 38, 48], [79, 37, 47], [77, 37, 46], [74, 36, 45], [72, 36, 44], [70, 35, 43], [68, 35, 42], [65, 34, 41], [64, 34, 40], [62, 33, 39], [60, 33, 38], [58, 33, 38], [56, 32, 38], [54, 32, 36], [52, 32, 35], [50, 32, 35], [48, 31, 35], [47, 31, 34], [45, 31, 34], [43, 31, 33], [42, 30, 33], [41, 30, 32], [40, 30, 32], [39, 30, 32], [38, 30, 32], [37, 30, 32], [36, 30, 
32], [36, 30, 32], [37, 30, 32], [38, 30, 32], [39, 30, 32], [40, 30, 32], [41, 30, 32], [42, 30, 33], [44, 31, 33], [46, 31, 34], [47, 31, 34], [49, 31, 35], [51, 32, 35], [53, 32, 36], [55, 32, 37], [57, 33, 38], [59, 33, 38], [61, 33, 39], [63, 34, 40], [65, 34, 41], [67, 35, 42], [70, 35, 43], [72, 36, 44], [74, 36, 45], [77, 37, 46], [79, 37, 47], [82, 38, 48], [84, 38, 49], [87, 39, 50], [90, 39, 51], [92, 40, 52], [95, 40, 53], [98, 40, 54], [100, 41, 55], [103, 41, 56], [106, 42, 57], [109, 42, 58], [111, 42, 59], [114, 43, 60], [117, 43, 60], [120, 43, 61], [123, 44, 62], [126, 44, 63], [129, 44, 63], [131, 44, 64], [134, 45, 64], [137, 45, 65], [140, 45, 65], [143, 46, 65], [146, 46, 65], [149, 46, 66], [152, 47, 66], [155, 47, 66], [158, 48, 66], [160, 48, 66], [163, 49, 65], [166, 49, 65], [169, 50, 65], [172, 51, 64], [174, 52, 64], [177, 53, 63], [180, 54, 63], [182, 55, 62], [185, 56, 62], [187, 57, 61], [190, 58, 61], [192, 60, 60], [195, 61, 59], [197, 63, 59], [199, 65, 58], [201, 66, 57], [203, 68, 57], [206, 70, 56], [208, 72, 55], [209, 74, 55], [211, 76, 54], [213, 78, 54], [215, 81, 54], [217, 83, 53], [218, 85, 53], [220, 88, 53], [221, 90, 53], [223, 93, 54], [224, 95, 54], [225, 98, 55], [227, 101, 55], [228, 103, 56], [229, 106, 57], [230, 109, 58], [231, 111, 60], [232, 114, 61], [233, 117, 62], [234, 120, 64], [235, 123, 66], [236, 125, 68], [237, 128, 70], [237, 131, 73], [238, 134, 75], [239, 137, 78], [240, 139, 80], [240, 142, 83], [241, 145, 86], [242, 148, 89], [242, 151, 93], [243, 153, 96], [243, 156, 99], [244, 159, 103], [245, 162, 106], [245, 165, 110], [246, 167, 113], [246, 170, 117], [247, 173, 120], [247, 176, 124], [248, 178, 127], [248, 181, 131], [249, 184, 134], [249, 186, 138], [250, 188, 141], [250, 190, 144], [251, 192, 147], [251, 194, 149], [251, 196, 152], [252, 198, 154], [252, 200, 156], [252, 201, 158], [253, 203, 160]]
10 | colormap_wave4 = [[246, 230, 183], [246, 229, 182], [246, 227, 180], [246, 226, 178], [246, 224, 176], [245, 222, 173], [245, 219, 170], [244, 217, 167], [244, 214, 163], [244, 211, 160], [243, 209, 156], [243, 206, 152], [242, 203, 148], [242, 200, 144], [241, 196, 140], [241, 193, 136], [241, 190, 132], [240, 186, 128], [240, 183, 124], [239, 180, 120], [239, 176, 116], [238, 173, 112], [238, 170, 108], [237, 166, 104], [237, 163, 100], [236, 160, 97], [236, 156, 93], [236, 153, 90], [235, 150, 87], [235, 147, 84], [235, 144, 81], [234, 140, 78], [234, 137, 76], [234, 134, 74], [234, 131, 71], [233, 127, 69], [233, 124, 67], [233, 121, 65], [233, 118, 64], [232, 115, 62], [232, 112, 61], [232, 109, 60], [232, 106, 59], [232, 103, 58], [232, 101, 58], [232, 98, 57], [232, 95, 57], [231, 93, 57], [230, 90, 57], [230, 88, 57], [229, 85, 57], [227, 83, 57], [226, 81, 57], [224, 78, 57], [222, 76, 58], [220, 74, 58], [217, 72, 59], [215, 70, 59], [212, 67, 60], [210, 65, 60], [207, 63, 60], [204, 62, 61], [201, 60, 61], [199, 58, 61], [196, 56, 61], [193, 55, 61], [189, 53, 61], [186, 52, 62], [183, 50, 61], [180, 49, 61], [176, 48, 61], [173, 46, 61], [170, 45, 61], [166, 44, 61], [163, 43, 61], [159, 42, 60], [156, 40, 60], [152, 39, 60], [149, 39, 59], [146, 38, 58], [142, 37, 58], [139, 36, 57], [135, 35, 56], [132, 34, 55], [128, 34, 54], [125, 33, 53], [122, 32, 52], [118, 31, 51], [115, 31, 50], [111, 30, 49], [108, 29, 48], [105, 28, 46], [101, 28, 45], [98, 27, 44], [94, 26, 43], [91, 25, 41], [88, 24, 40], [85, 24, 38], [82, 23, 37], [78, 22, 35], [75, 21, 34], [72, 20, 32], [69, 19, 31], [66, 18, 29], [63, 18, 28], [60, 16, 26], [57, 16, 25], [55, 15, 24], [52, 14, 22], [49, 13, 21], [46, 12, 19], [44, 11, 18], [41, 11, 17], [39, 10, 16], [36, 9, 14], [34, 8, 13], [31, 8, 12], [29, 7, 11], [27, 7, 10], [25, 6, 10], [23, 5, 9], [21, 5, 8], [19, 4, 7], [17, 4, 7], [16, 4, 6], [14, 3, 6], [13, 3, 5], [13, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 5], [12, 3, 
5], [13, 3, 5], [14, 3, 5], [15, 3, 6], [16, 4, 6], [18, 4, 7], [20, 4, 7], [21, 5, 8], [23, 5, 9], [26, 6, 10], [28, 7, 11], [30, 7, 12], [33, 8, 13], [35, 9, 14], [38, 9, 15], [40, 10, 17], [43, 11, 18], [46, 12, 19], [48, 13, 21], [51, 14, 22], [54, 15, 24], [57, 16, 25], [60, 17, 27], [63, 18, 28], [67, 19, 30], [70, 20, 31], [73, 20, 33], [76, 21, 34], [80, 22, 36], [83, 23, 37], [86, 24, 39], [90, 25, 40], [93, 26, 42], [96, 27, 43], [100, 27, 44], [103, 28, 46], [107, 29, 47], [110, 30, 48], [114, 30, 50], [117, 31, 51], [121, 32, 52], [124, 33, 53], [128, 34, 54], [131, 34, 55], [135, 35, 56], [138, 36, 57], [142, 37, 57], [146, 38, 58], [149, 39, 59], [153, 39, 59], [156, 40, 60], [160, 42, 60], [164, 43, 61], [167, 44, 61], [171, 45, 61], [174, 46, 61], [178, 48, 62], [181, 49, 62], [184, 51, 62], [188, 52, 62], [191, 54, 62], [194, 56, 61], [197, 57, 61], [200, 59, 61], [203, 61, 61], [206, 63, 60], [209, 65, 60], [212, 67, 59], [215, 69, 59], [217, 72, 59], [220, 74, 58], [222, 76, 58], [224, 78, 57], [226, 81, 57], [227, 83, 57], [229, 86, 57], [230, 88, 57], [230, 91, 57], [231, 94, 57], [232, 97, 57], [232, 99, 57], [232, 102, 58], [232, 105, 59], [232, 108, 60], [232, 111, 61], [232, 114, 62], [232, 117, 63], [233, 120, 65], [233, 123, 66], [233, 127, 68], [233, 130, 71], [234, 133, 73], [234, 137, 75], [234, 140, 78], [235, 143, 81], [235, 147, 84], [235, 150, 87], [236, 153, 90], [236, 157, 94], [236, 160, 97], [237, 164, 101], [237, 167, 105], [238, 170, 109], [238, 174, 113], [239, 178, 117], [239, 181, 121], [240, 184, 126], [240, 188, 130], [241, 192, 134], [241, 195, 138], [242, 198, 142], [242, 202, 147], [243, 205, 151], [243, 208, 155], [244, 211, 159], [244, 214, 162], [244, 216, 166], [245, 219, 169], [245, 221, 173], [246, 224, 175], [246, 226, 178], [246, 227, 180], [246, 229, 182], [246, 230, 183]]


class WaveVisualizer:
    """Renders the wave field and its time-averaged intensity as BGR images."""

    def __init__(self, field_colormap, intensity_colormap):
        self.field_colormap = field_colormap
        self.intensity_colormap = intensity_colormap
        self.intensity = None
        # Weight given to the previous value in the exponential moving average of the intensity
        self.intensity_exp_average_factor = 0.98
        self.field = None
        self.visualization_image = None

    def update(self, wave_sim):
        self.field = wave_sim.get_field()

        if self.intensity is None:
            self.intensity = cp.zeros_like(self.field)

        # Exponential moving average of the squared field
        t = self.intensity_exp_average_factor
        self.intensity = self.intensity*t + (self.field**2)*(1.0-t)
        self.visualization_image = wave_sim.render_visualization()

    def render_intensity(self, brightness_scale=1.0, exp=0.5, overlay_visualization=True):
        # Scale the clipped intensity to integer values in 0..254, used as colormap indices
        gray = (cp.clip((self.intensity**exp)*brightness_scale, 0.0, 1.0) * 254.0).astype(np.uint8)
        img = self.intensity_colormap[gray].get() if self.intensity_colormap is not None else gray.get()
        # Without a colormap the image is single-channel, so expand it to BGR instead of swapping channels
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR if img.ndim == 2 else cv2.COLOR_RGB2BGR)
        if overlay_visualization:
            img = cv2.add(img, self.visualization_image)
        return img

    def render_field(self, brightness_scale=1.0, overlay_visualization=True):
        # Map the clipped field from -1..1 to integer values in 0..254, with zero at 127
        gray = (cp.clip(self.field*brightness_scale, -1.0, 1.0) * 127 + 127).astype(np.uint8)
        img = self.field_colormap[gray].get() if self.field_colormap is not None else gray.get()
        # Without a colormap the image is single-channel, so expand it to BGR instead of swapping channels
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR if img.ndim == 2 else cv2.COLOR_RGB2BGR)
        if overlay_visualization:
            img = cv2.add(img, self.visualization_image)
        return img


def get_colormap_lut(name, invert, black_level=0.0, make_symmetric=False):
    # Built-in colormaps are stored as lists of 0..255 RGB values; normalize them to 0..1
    if name == 'icefire':
        color_values = np.array(colormap_icefire) / 255
    elif name == 'colormap_wave1':
        color_values = np.array(colormap_wave1) / 255
    elif name == 'colormap_wave2':
        color_values = np.array(colormap_wave2) / 255
    elif name == 'colormap_wave3':
        color_values = np.array(colormap_wave3) / 255
    elif name == 'colormap_wave4':
        color_values = np.array(colormap_wave4) / 255
    else:
        # Fall back to a matplotlib colormap, sampled at 255 points
        colormap = matplotlib.pyplot.get_cmap(name)
        color_values = colormap(np.linspace(0, 1, 255))

    if invert:
        color_values = 1.0-color_values

    if make_symmetric:
        # Compress the full colormap into each half so the table is mirrored around the center index
        src = color_values.copy()
        color_values[255:126:-1, :] = src[0:255:2, :]
        color_values[0:128, :] = src[0:255:2, :]

    # Values are in 0..1 at this point, so clip to that range before scaling to 8 bit
    color_values = np.clip(color_values*(1.0-black_level)+black_level, 0.0, 1.0)

    return cp.asarray((color_values*255).astype(np.uint8))
--------------------------------------------------------------------------------