├── src └── topsy │ ├── __main__.py │ ├── canvas │ ├── qt │ │ ├── icons │ │ │ ├── rgb.png │ │ │ ├── stop.png │ │ │ ├── camera.png │ │ │ ├── linked.png │ │ │ ├── movie.png │ │ │ ├── record.png │ │ │ ├── unlinked.png │ │ │ ├── load_script.png │ │ │ └── save_script.png │ │ ├── lineedit.py │ │ ├── recording.py │ │ ├── __init__.py │ │ └── colormap.py │ ├── offscreen.py │ ├── __init__.py │ └── jupyter.py │ ├── drawreason.py │ ├── performance.py │ ├── text.py │ ├── shaders │ ├── overlay.wgsl │ ├── smooth.wgsl │ ├── line.wgsl │ ├── surface.wgsl │ ├── colormap.wgsl │ └── sph.wgsl │ ├── simcube.py │ ├── config.py │ ├── colorbar.py │ ├── periodic_sph.py │ ├── util.py │ ├── cell_layout.py │ ├── recorder │ ├── interpolator.py │ └── __init__.py │ ├── view_synchronizer.py │ ├── particle_buffers.py │ ├── colormap │ └── __init__.py │ ├── split_buffers.py │ ├── line.py │ ├── scalebar.py │ └── __init__.py ├── tests ├── run_tests_in_docker.sh ├── test_arg_parse.py ├── test_cell_layout.py ├── test_render_mode.py ├── test_split_buffers.py ├── test_interpolation.py ├── test_jupyter_interface.py ├── test_synchronizer.py ├── test_smooth.py ├── test_scalebar.py ├── test_colormap.py └── test_progression.py ├── Dockerfile ├── .github └── workflows │ ├── publish.yaml │ └── build-test.yaml ├── pyproject.toml ├── LICENSE └── README.md /src/topsy/__main__.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from . import main 3 | main() 4 | -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/rgb.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/stop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/stop.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/camera.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/camera.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/linked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/linked.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/movie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/movie.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/record.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/record.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/unlinked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/unlinked.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/load_script.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/load_script.png -------------------------------------------------------------------------------- /src/topsy/canvas/qt/icons/save_script.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pynbody/topsy/HEAD/src/topsy/canvas/qt/icons/save_script.png -------------------------------------------------------------------------------- /src/topsy/canvas/offscreen.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from rendercanvas.offscreen import OffscreenRenderCanvas, loop 4 | 5 | from . import VisualizerCanvasBase 6 | 7 | 8 | class VisualizerCanvas(VisualizerCanvasBase, OffscreenRenderCanvas): 9 | 10 | @classmethod 11 | def call_later(cls, delay, fn, *args): 12 | loop.call_later(delay, fn, *args) 13 | 14 | -------------------------------------------------------------------------------- /tests/run_tests_in_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # It may be useful to run the tests in docker to track down issues using a standardised environment 4 | # This script will build the docker image, run the tests and copy the output to the local machine 5 | 6 | rm -rf docker_test_output 7 | docker build .. -t topsy 8 | docker run --name running-tests topsy -c 'pytest' 9 | docker cp running-tests:/app/tests/output ./docker_test_output 10 | docker rm running-tests 11 | 12 | -------------------------------------------------------------------------------- /src/topsy/canvas/qt/lineedit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from PySide6 import QtWidgets, QtCore 4 | 5 | 6 | class MyLineEdit(QtWidgets.QLineEdit): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._timer = QtCore.QTimer() 10 | self._timer.setSingleShot(True) 11 | self._timer.timeout.connect(self.selectAll) 12 | 13 | def focusInEvent(self, event): 14 | super().focusInEvent(event) 15 | self._timer.start(0) 16 | -------------------------------------------------------------------------------- /src/topsy/drawreason.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | class DrawReason(enum.Enum): 4 | """Enum to specify the reason for a draw (which may affect detailed behaviour)""" 5 | INITIAL_UPDATE = 1 # render from scratch 6 | CHANGE = 2 # a change has occurred, possibly from the UI 7 | REFINE = 3 # render the SPH at full resolution, within an interactive context 8 | PRESENTATION_CHANGE = 4 # i.e. don't rerender SPH 9 | EXPORT = 5 # full rendering, always at full resolution 10 | -------------------------------------------------------------------------------- /src/topsy/performance.py: -------------------------------------------------------------------------------- 1 | """Tools for measuring performance""" 2 | 3 | try: 4 | from os_signpost import Signposter 5 | signposter = Signposter("com.pynbody.topsy", Signposter.Category.PointsOfInterest) 6 | except ImportError: 7 | # Most of the time we won't have this module, so make a dummy 8 | class DummySignposter: 9 | def __init__(self): 10 | pass 11 | 12 | def begin_interval(self, *args, **kwargs): 13 | pass 14 | 15 | def emit_event(self, *args, **kwargs): 16 | pass 17 | 18 | def use_interval(*args, **kwds): 19 | pass 20 | 21 | signposter = DummySignposter() -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:24.04 2 | 3 | 4 | RUN apt update && apt install -y python3 python3-pip libgl1-mesa-dri libxcb-xfixes0-dev \ 5 | libglib2.0-0 mesa-vulkan-drivers libegl1 libgl1 libglx-mesa0 libxkbcommon0 libdbus-1-3 6 | 7 | # manually install some dependencies to speed up 8 | RUN pip3 install --break-system-packages numpy pynbody matplotlib pillow tqdm opencv-python PySide6 pytest pytest-ipywidgets 9 | RUN playwright install 10 | RUN playwright install-deps 11 | 12 | COPY src /app/src 13 | COPY tests /app/tests 14 | COPY pyproject.toml /app/ 15 | COPY README.md /app/ 16 | 17 | WORKDIR /app 18 | 19 | RUN pip3 install --break-system-packages .[test] 20 | 21 | ENTRYPOINT ["/bin/bash"] -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.12' 21 | 22 | - name: Install flit 23 | run: pip install flit 24 | 25 | - name: Build package 26 | run: flit build 27 | 28 | - name: Publish to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /tests/test_arg_parse.py: -------------------------------------------------------------------------------- 1 | 2 | import topsy 3 | 4 | def test_simple_arg_parse(): 5 | args = topsy.parse_args(["test://1000","-q","test-quantity"]) 6 | assert len(args)==1 7 | args = args[0] 8 | assert args.filename == "test://1000" 9 | assert args.quantity == "test-quantity" 10 | assert args.resolution == topsy.config.DEFAULT_RESOLUTION 11 | assert args.colormap == topsy.config.DEFAULT_COLORMAP 12 | 13 | def test_multi_arg_parse(): 14 | args = topsy.parse_args(["file1","-q","test-quantity","-p","dm","+","file2","-q","test-quantity2"]) 15 | assert len(args)==2 16 | 17 | a = args[0] 18 | assert a.filename == "file1" 19 | assert a.quantity == "test-quantity" 20 | assert a.particle == "dm" 21 | 22 | a = args[1] 23 | assert a.filename == "file2" 24 | assert a.quantity == "test-quantity2" 25 | assert a.particle == "dm" 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/build-test.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Test 2 | 3 | on: [push, pull_request] 4 | 5 | defaults: 6 | run: 7 | shell: bash 8 | 9 | jobs: 10 | 11 | build: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: ["3.11", "3.12", "3.13"] 17 | runs-on: ${{ matrix.os }} 18 | 19 | steps: 20 | - name: Install llvmpipe and lavapipe for offscreen canvas 21 | if: matrix.os == 'ubuntu-latest' 22 | run: | 23 | sudo apt-get update -y -qq 24 | sudo apt install -y libegl1-mesa-dev libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers 25 | - name: Install Python 26 | uses: actions/setup-python@v2 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - uses: actions/checkout@v2 30 | - name: Install 31 | run: | 32 | pip install .[test] 33 | playwright install 34 | sudo $(which playwright) install-deps 35 | - name: Run all tests 36 | working-directory: tests 37 | run: python -m pytest 38 | - uses: actions/upload-artifact@v4 39 | if: always() 40 | with: 41 | name: Outputs from tests on Python ${{ matrix.python-version }} 42 | path: tests/output/ 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.8,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "topsy" 7 | authors = [{name = "Andrew Pontzen", email = "a.pontzen@ucl.ac.uk"}] 8 | dynamic = ["version", "description"] 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "Intended Audience :: Science/Research", 12 | "License :: OSI Approved :: BSD License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Topic :: Scientific/Engineering :: Visualization", 16 | ] 17 | readme = 'README.md' 18 | requires-python = ">=3.11" 19 | 20 | dependencies = [ 21 | "numpy >=1.16.0", 22 | "pynbody >=2.1.1", 23 | "matplotlib >=3.6.0", 24 | "pillow >=9.5.0", # 9.5.0 needed for Image.frombytes accepting memoryview 25 | "wgpu >= 0.22, <0.23", 26 | "jupyter_rfb >=0.4.1", 27 | "tqdm >=4.62.0", 28 | "opencv-python >=4.8.0", 29 | "PySide6 >= 6.5.0", # windowing 30 | "superqt>=0.7.3", # rangeslider 31 | "rendercanvas>=2.0.3", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | test = [ 36 | "pytest >=6.2.0", 37 | "pytest-ipywidgets", 38 | "tifffile", 39 | ] 40 | 41 | [project.scripts] 42 | topsy = "topsy:main" 43 | 44 | [project.urls] 45 | Home = "https://github.com/pynbody/topsy" 46 | -------------------------------------------------------------------------------- /tests/test_cell_layout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from topsy import cell_layout 3 | 4 | def test_randomization(): 5 | offsets = np.array([0, 10, 30]) 6 | lengths = np.array([10, 20, 20]) 7 | centres = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0,0.0,0.0]]) 8 | cl = cell_layout.CellLayout(centres, offsets, lengths) 9 | 10 | order = cl.randomize_within_cells() 11 | assert (order[:10]<10).all() 12 | assert ((order[10:30]<30) & (order[10:30]>=10)).all() 13 | assert ((order[30:50]<50) & (order[30:50]>=30)).all() 14 | 15 | assert (order != np.arange(50)).any() 16 | 17 | def test_from_positions(): 18 | Npart = 10000 19 | Nside = 10 20 | Ncells_to_test = 100 21 | min_pos = -1.0 22 | max_pos = 1.0 23 | 24 | np.random.seed(1337) 25 | 26 | pos = np.random.uniform(min_pos, max_pos, (Npart, 3)) 27 | 28 | cl, order = cell_layout.CellLayout.from_positions(pos, min_pos, max_pos, Nside) 29 | 30 | pos = pos[order] 31 | 32 | for test_cell in np.random.randint(0, Nside**3, Ncells_to_test): 33 | 34 | # get the cell slice 35 | cell_slice = cl.cell_slice(test_cell) 36 | 37 | # get the positions of the particles in the cell 38 | test_pos = pos[cell_slice] 39 | 40 | # get the cell centre 41 | cell_centre = cl._centres[test_cell] 42 | 43 | cell_size = (max_pos - min_pos) / Nside 44 | 45 | # check that all the particles are within the cell 46 | assert ((test_pos > (cell_centre - 0.5*cell_size)) & (test_pos < (cell_centre + 0.5*cell_size))).all() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Andrew Pontzen 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /src/topsy/text.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import matplotlib 4 | from matplotlib.figure import Figure 5 | import matplotlib.pyplot as plt 6 | 7 | from .overlay import Overlay 8 | 9 | 10 | class TextOverlay(Overlay): 11 | def __init__(self, visualizer, text, clipspace_origin, logical_pixels_height, *, dpi=200, **kwargs): 12 | self.text = text 13 | self.dpi = dpi 14 | self.clipspace_origin = clipspace_origin 15 | self.pixelspace_height = logical_pixels_height 16 | self.kwargs = kwargs 17 | 18 | super().__init__(visualizer) 19 | 20 | def get_clipspace_coordinates(self, width, height): 21 | im = self.get_contents() 22 | x,y = self.clipspace_origin 23 | height = self.pixelspace_height*self._visualizer.canvas.pixel_ratio/height 24 | width = self.pixelspace_height*self._visualizer.canvas.pixel_ratio*im.shape[1]/im.shape[0]/width 25 | return x, y, width, height 26 | 27 | def render_contents(self): 28 | return self.text_to_rgba(self.text, dpi=self.dpi, **self.kwargs) 29 | 30 | @staticmethod 31 | def text_to_rgba(s, *, dpi, **kwargs): 32 | """Render text to RGBA image. 33 | 34 | Based on 35 | https://matplotlib.org/stable/gallery/text_labels_and_annotations/mathtext_asarray.html""" 36 | fig = Figure(facecolor="none") 37 | fig.text(0, 0, s, **kwargs) 38 | with BytesIO() as buf: 39 | fig.savefig(buf, dpi=dpi, format="png", bbox_inches="tight", 40 | pad_inches=0) 41 | buf.seek(0) 42 | rgba = plt.imread(buf) 43 | return rgba -------------------------------------------------------------------------------- /src/topsy/shaders/overlay.wgsl: -------------------------------------------------------------------------------- 1 | struct OverlayParams { 2 | clipspace_origin: vec2, 3 | clipspace_extent: vec2, 4 | texturespace_origin: vec2, 5 | texturespace_extent: vec2, 6 | }; 7 | 8 | @group(0) @binding(0) 9 | var overlay_params: OverlayParams; 10 | 11 | struct VertexOutput { 12 | @builtin(position) pos: vec4, 13 | @location(0) texcoord: vec2, 14 | @location(1) weight: f32, 15 | }; 16 | 17 | 18 | @vertex 19 | fn vertex_main(@builtin(vertex_index) vertexIndex : u32, 20 | @location(0) instanceOffset : vec2, 21 | @location(1) instanceWeight : f32) -> VertexOutput { 22 | var offsets = array, 4>( 23 | vec2(0.0, 0.0), 24 | vec2(0.0, 1.0), 25 | vec2(1.0, 0.0), 26 | vec2(1.0, 1.0) 27 | ); 28 | 29 | var output: VertexOutput; 30 | 31 | var vertexOffset = offsets[vertexIndex]; 32 | vertexOffset.y = 1.0 - vertexOffset.y; 33 | vertexOffset *= overlay_params.clipspace_extent; 34 | 35 | output.pos = vec4(overlay_params.clipspace_origin + instanceOffset + vertexOffset, 0.0, 1.0); 36 | output.texcoord = overlay_params.texturespace_origin + offsets[vertexIndex]*overlay_params.texturespace_extent; 37 | output.weight = instanceWeight; 38 | return output; 39 | } 40 | 41 | @group(0) @binding(1) 42 | var image_texture: texture_2d; 43 | 44 | @group(0) @binding(2) 45 | var image_sampler: sampler; 46 | 47 | 48 | @fragment 49 | fn fragment_main(input: VertexOutput) -> @location(0) vec4 { 50 | return textureSample(image_texture, image_sampler, input.texcoord) * input.weight; 51 | } 52 | -------------------------------------------------------------------------------- /src/topsy/simcube.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .line import Line 4 | import numpy as np 5 | 6 | class SimCube(Line): 7 | def __init__(self, visualizer, color, width): 8 | size = visualizer.data_loader.get_periodicity_scale() or 1.0 9 | line_starts_ends = [[0,0,0], [0,0,1], 10 | [0,0,0], [0,1,0], 11 | [0,0,0], [1,0,0], 12 | [1,1,1], [1,1,0], 13 | [1,1,1], [1,0,1], 14 | [1,1,1], [0,1,1], 15 | [0,1,0], [0,1,1], 16 | [0,1,0], [1,1,0], 17 | [1,0,1], [1,0,0], 18 | [1,0,1], [0,0,1], 19 | [1,0,0], [1,1,0], 20 | [0,1,1], [0,0,1] 21 | ] 22 | 23 | line_starts_ends = np.array(line_starts_ends, dtype=np.float32) 24 | line_starts_ends -= 0.5 25 | 26 | line_starts_ends *= size 27 | 28 | line_starts_ends = np.concatenate([line_starts_ends, np.ones((line_starts_ends.shape[0], 1))], axis=1) 29 | 30 | self._line_starts = np.ascontiguousarray(line_starts_ends[::2,:], dtype=np.float32) 31 | self._line_ends = np.ascontiguousarray(line_starts_ends[1::2,:], dtype=np.float32) 32 | 33 | 34 | super().__init__(visualizer, None, color, width) 35 | 36 | def encode_render_pass(self, command_encoder: wgpu.GPUCommandEncoder, 37 | target_texture_view: wgpu.GPUTextureView): 38 | self._params["transform"] = self._visualizer._sph.last_transform_params["transform"] @ self._visualizer.sph_clipspace_to_screen_clipspace_matrix() 39 | super().encode_render_pass(command_encoder, target_texture_view) 40 | 41 | 42 | -------------------------------------------------------------------------------- /src/topsy/shaders/smooth.wgsl: -------------------------------------------------------------------------------- 1 | struct SmoothingParams { 2 | spatial_sigma: f32, 3 | range_sigma: f32, 4 | kernel_size: i32, 5 | padding: i32, // For 16-byte alignment 6 | } 7 | 8 | @group(0) @binding(0) var input_depth: texture_2d; 9 | @group(0) @binding(1) var output_depth: texture_storage_2d; 10 | @group(0) @binding(2) var params: SmoothingParams; 11 | 12 | @compute @workgroup_size(8, 8) 13 | fn bilateral_filter_main(@builtin(global_invocation_id) global_id: vec3) { 14 | let coord = vec2(global_id.xy); 15 | let center_sample = textureLoad(input_depth, coord, 0); 16 | let center_depth = center_sample.g; 17 | 18 | var weighted_sum = 0.0; 19 | var weight_sum = 0.0; 20 | 21 | let half_kernel = params.kernel_size / 2; 22 | 23 | let tex_size = textureDimensions(input_depth, 0); 24 | 25 | // Sample neighborhood 26 | for (var dy = -half_kernel; dy <= half_kernel; dy++) { 27 | for (var dx = -half_kernel; dx <= half_kernel; dx++) { 28 | let sample_coord = clamp(coord + vec2(dx, dy), vec2(0, 0), vec2(tex_size) - vec2(1, 1)); 29 | let sample_depth = textureLoad(input_depth, sample_coord, 0).g; 30 | 31 | // Spatial weight (Gaussian based on distance) 32 | let spatial_dist = sqrt(f32(dx*dx + dy*dy)); 33 | let w_spatial = exp(-(spatial_dist * spatial_dist) / (2.0 * params.spatial_sigma * params.spatial_sigma)); 34 | 35 | // Range weight (Gaussian based on depth difference) 36 | let depth_diff = abs(sample_depth - center_depth); 37 | let w_range = exp(-(depth_diff * depth_diff) / (2.0 * params.range_sigma * params.range_sigma)); 38 | 39 | let total_weight = w_spatial * w_range; 40 | weighted_sum += sample_depth * total_weight; 41 | weight_sum += total_weight; 42 | } 43 | } 44 | 45 | 46 | let filtered_depth = weighted_sum / weight_sum; 47 | textureStore(output_depth, coord, vec4(center_sample.r, filtered_depth, 0.0, 1.0)); 48 | } 49 | -------------------------------------------------------------------------------- /src/topsy/config.py: -------------------------------------------------------------------------------- 1 | DEFAULT_RESOLUTION = 1024 2 | DEFAULT_COLORMAP = 'twilight_shifted' 3 | 4 | DEFAULT_SCALE = 200.0 # viewport width in kpc 5 | 6 | TARGET_FPS = 30 # will use downsampling to achieve this 7 | INITIAL_PARTICLES_TO_RENDER = 1e5 # number of particles to render at first 8 | STATUS_LINE_UPDATE_INTERVAL = 0.2 # seconds 9 | STATUS_LINE_UPDATE_INTERVAL_RAPID = 0.05 # when time-critical information is being displayed 10 | 11 | GLIDE_TIME = 0.3 # seconds after double click to reach destination 12 | 13 | COLORBAR_ASPECT_RATIO = 0.15 14 | COLORMAP_NUM_SAMPLES = 1000 15 | 16 | TEST_DATA_NUM_PARTICLES_DEFAULT = int(1e6) 17 | 18 | MAX_PARTICLES_PER_BUFFER = 2**27 19 | # arbitrary number, but small enough that GPU memory fragmentation not a huge issue hopefully, while 20 | # large enough to not cause too much overhead 21 | 22 | MAX_PARTICLES_PER_EXPORT_RENDERCALL = 2 ** 25 23 | # again, a arbitraryish number, determined by experimentation. Profiling finds that calling render on a large number of particles 24 | # causes a lot of overhead - perhaps due to pipeline stalls within the GPU. This becomes particularly 25 | # noticeable on EXPORT rendering large simulations. Splitting into a sequence of smaller calls. 26 | 27 | DEFAULT_CELLS_NSIDE = 16 28 | # To provide a quick way to remove unneeded verticies, the simulation is ordered into cells 29 | # by default, use this number of cells on a side (so DEFAULT_CELLS_NSIDE^3 in total). 30 | # High numbers enable more precise geometric selections, but at a general performance penalty 31 | # due to the complexity of the GPU buffer organization. 32 | 33 | CELL_LAYOUT_FRACTIONAL_PADDING = 1e-5 34 | # Fractional padding to add to the overall cell cube, beyond the max/min particle positions 35 | 36 | JUPYTER_UI_LAG = 0.05 37 | # time over which to spread jupyter UI updates, notably for sliders where updating the range and value 38 | # simultaneously seems to lead to problems 39 | 40 | # special name for projected density in UI 41 | PROJECTED_DENSITY_NAME = "Projected density" 42 | 43 | # Maximum number of pixels to use when smoothing z maps for surfaces 44 | MAX_SURFACE_SMOOTH_PIXELS = 100 -------------------------------------------------------------------------------- /src/topsy/shaders/line.wgsl: -------------------------------------------------------------------------------- 1 | struct LineRenderParams { 2 | transform: mat4x4, 3 | color: vec4, 4 | vp_size_pix: vec2, 5 | linewidth_pix: f32, 6 | }; 7 | 8 | @group(0) @binding(0) 9 | var render_params: LineRenderParams; 10 | 11 | struct VertexOutput { 12 | @builtin(position) pos: vec4, 13 | @location(0) color: vec4, 14 | }; 15 | 16 | 17 | @vertex 18 | fn vertex_main(@location(0) instanceStart : vec4, 19 | @location(1) instanceEnd : vec4, 20 | @builtin(vertex_index) vertexIndex : u32) -> VertexOutput { 21 | 22 | var output: VertexOutput; 23 | 24 | var instanceStartTransformed : vec2 = (render_params.transform * instanceStart).xy * render_params.vp_size_pix; 25 | var instanceEndTransformed : vec2 = (render_params.transform * instanceEnd).xy * render_params.vp_size_pix; 26 | var normalizedOffsetVector : vec2 = normalize(instanceEndTransformed - instanceStartTransformed); 27 | var normalToLine : vec2 = vec2(-normalizedOffsetVector.y, normalizedOffsetVector.x); 28 | 29 | switch vertexIndex { 30 | case 0u { 31 | output.pos = vec4(instanceStartTransformed - normalToLine * render_params.linewidth_pix * 0.5, 0.0, 1.0); 32 | } 33 | case 1u: { 34 | output.pos = vec4(instanceStartTransformed + normalToLine * render_params.linewidth_pix * 0.5, 0.0, 1.0); 35 | } 36 | case 2u: { 37 | output.pos = vec4(instanceEndTransformed - normalToLine * render_params.linewidth_pix * 0.5, 0.0, 1.0); 38 | } 39 | case 3u: { 40 | output.pos = vec4(instanceEndTransformed + normalToLine * render_params.linewidth_pix * 0.5, 0.0, 1.0); 41 | } 42 | default: { 43 | output.pos = vec4(0.0); 44 | } 45 | } 46 | 47 | output.pos/=vec4(render_params.vp_size_pix, 1.0, 1.0); 48 | 49 | output.color = render_params.color; 50 | 51 | return output; 52 | } 53 | 54 | 55 | 56 | @fragment 57 | fn fragment_main(input: VertexOutput) -> @location(0) vec4 { 58 | return input.color; 59 | } 60 | -------------------------------------------------------------------------------- /tests/test_render_mode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import topsy 4 | 5 | from topsy.canvas import offscreen 6 | 7 | def test_render_mode_switching(): 8 | """Test that render mode can be switched on-the-fly without errors""" 9 | vis = topsy.test(1000, render_resolution=200, canvas_class=offscreen.VisualizerCanvas, 10 | render_mode='univariate') 11 | vis.scale = 20.0 12 | 13 | mode_sequence = 'univariate', 'bivariate', 'rgb', 'rgb-hdr', 'surface' 14 | 15 | for mode in mode_sequence: 16 | vis.render_mode = mode 17 | _check_vis_output_matches_mode(vis, mode) 18 | 19 | 20 | def test_render_mode_invalid(): 21 | """Test that invalid render modes are properly rejected""" 22 | vis = topsy.test(100, render_resolution=50, canvas_class=offscreen.VisualizerCanvas) 23 | 24 | # Test setting invalid render mode 25 | with pytest.raises(ValueError, match="Invalid render_mode 'invalid'"): 26 | vis.render_mode = 'invalid' 27 | 28 | # Verify the original render mode is unchanged 29 | assert vis.render_mode == 'univariate' 30 | 31 | def test_render_mode_reinitialization(): 32 | """Test that render mode can be set during initialization""" 33 | modes_to_test = ['univariate', 'bivariate', 'rgb', 'rgb-hdr', 'surface'] 34 | 35 | for mode in modes_to_test: 36 | vis = topsy.test(100, render_resolution=50, canvas_class=offscreen.VisualizerCanvas, 37 | render_mode=mode) 38 | assert vis.render_mode == mode 39 | 40 | _check_vis_output_matches_mode(vis, mode) 41 | 42 | class RestrictedModeOffscreenCanvas(offscreen.VisualizerCanvas): 43 | """A custom canvas that prevents hdr rendering""" 44 | def _rc_get_present_methods(self): 45 | return { 46 | "bitmap": { 47 | "formats": ["rgba-u8"], 48 | } 49 | } 50 | 51 | def test_render_mode_fail(): 52 | """Tests that if a particular render mode fails, the original render mode is restored""" 53 | vis = topsy.test(100, render_resolution=50, canvas_class=RestrictedModeOffscreenCanvas, 54 | render_mode='univariate') 55 | 56 | original_mode = vis.render_mode 57 | 58 | # Attempt to set an invalid render mode 59 | with pytest.raises(ValueError): 60 | vis.render_mode = 'rgb-hdr' # valid, but we are forcing a failure (as will happen in Jupyter currently) 61 | 62 | # Verify that the original render mode is still intact 63 | assert vis.render_mode == original_mode 64 | 65 | def _check_vis_output_matches_mode(vis, mode): 66 | result = vis.get_sph_image() 67 | result_presentation = vis.get_sph_presentation_image() 68 | 69 | # check type based on mode: 70 | if mode.endswith('hdr'): 71 | assert result_presentation.dtype == np.float16 72 | else: 73 | assert result_presentation.dtype == np.uint8 74 | 75 | res = vis._render_resolution 76 | 77 | assert result_presentation.shape == (res, res, 4) 78 | 79 | if mode in ['rgb', 'rgb-hdr']: 80 | assert result.shape == (res, res, 3) 81 | elif mode in ['bivariate', 'surface']: 82 | assert result.shape == (res, res, 2) 83 | else: # univariate 84 | assert result.shape == (res, res) 85 | -------------------------------------------------------------------------------- /tests/test_split_buffers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import wgpu 4 | 5 | from topsy import split_buffers 6 | 7 | @pytest.fixture 8 | def device(): 9 | adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance") 10 | return adapter.request_device_sync() 11 | 12 | @pytest.fixture 13 | def sb(max_per_buffer=15, num_particles=50): 14 | return split_buffers.SplitBuffers(num_particles, max_per_buffer) 15 | 16 | def test_global_to_split(sb): 17 | 18 | assert sb.global_to_split(0, 10) == ([0], [0], [10]) 19 | assert sb.global_to_split(0, 20) == ([0, 1], [0, 0], [15, 5]) 20 | assert sb.global_to_split(0, 45) == ([0, 1, 2], [0, 0, 0], [15, 15, 15]) 21 | assert sb.global_to_split(15, 10) == ([1], [0], [10]) 22 | assert sb.global_to_split(14, 2) == ([0, 1], [14, 0], [1, 1]) 23 | assert sb.global_to_split(20, 20) == ([1, 2], [5, 0], [10, 10]) 24 | assert sb.global_to_split(49, 1) == ([3], [4], [1]) 25 | assert sb.global_to_split(0, 50) == ([0, 1, 2, 3], [0, 0, 0, 0], [15, 15, 15, 5]) 26 | 27 | with pytest.raises(ValueError): 28 | # too long: 29 | sb.global_to_split(0, 100) 30 | 31 | with pytest.raises(ValueError): 32 | sb.global_to_split(49,2) 33 | 34 | def test_global_to_split_monotonic(sb): 35 | def generate_test_case(): 36 | starts = np.random.randint(0, 50, size=5) 37 | starts = np.sort(starts) 38 | lengths = np.diff(starts) 39 | starts = starts[:-1] 40 | lengths = np.random.randint(lengths+1) 41 | 42 | # remove any cases i with lengths[i] = 0 from the starts and lengths lists: 43 | mask = lengths!=0 44 | starts = starts[mask] 45 | lengths = lengths[mask] 46 | 47 | return (starts, lengths) 48 | 49 | np.random.seed(1337) 50 | for i in range(100): 51 | starts, lengths = generate_test_case() 52 | 53 | # faster version to be used in code: 54 | results = sb.global_to_split_monotonic(starts, lengths) 55 | 56 | # slow version to act as check: 57 | results_slow = [([], []) for _ in range(sb._num_buffers)] 58 | 59 | for global_s, global_l in zip(starts, lengths): 60 | local_bufs, local_starts, local_lengths = sb.global_to_split(global_s, global_l) 61 | for lb, ls, ll in zip(local_bufs, local_starts, local_lengths): 62 | results_slow[lb][0].append(ls) 63 | results_slow[lb][1].append(ll) 64 | 65 | assert results == results_slow 66 | 67 | 68 | def test_create_buffers(device, sb): 69 | # Create a buffer with 4 buffers of size 15 and one of size 5 70 | buffers = sb.create_buffers(device, 4, wgpu.BufferUsage.UNIFORM) 71 | assert len(buffers) == 4 72 | assert all([buf.size == 15*4 for buf in buffers[:-1]]) 73 | assert buffers[-1].size == 5*4 74 | 75 | def test_write_buffers(device, sb): 76 | buffers = sb.create_buffers(device, 4, wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST) 77 | 78 | data = np.arange(50, dtype=np.float32) 79 | 80 | with pytest.raises(ValueError): 81 | # wrong number of buffers 82 | sb.write_buffers(device, buffers[:-1], data) 83 | 84 | with pytest.raises(ValueError): 85 | # wrong number of particles 86 | sb.write_buffers(device, buffers, data[:-1]) 87 | 88 | # should succeed 89 | sb.write_buffers(device, buffers, data) -------------------------------------------------------------------------------- /src/topsy/colorbar.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib, matplotlib.backends.backend_agg 3 | import matplotlib.figure as figure 4 | import matplotlib.colors as colors 5 | import numpy as np 6 | import wgpu 7 | 8 | from .overlay import Overlay 9 | 10 | 11 | class ColorbarOverlay(Overlay): 12 | def __init__(self, visualizer, vmin, vmax, colormap, label, *, dpi_logical=72, **kwargs): 13 | self.dpi_logical = dpi_logical 14 | self.kwargs = kwargs 15 | self._aspect_ratio = 0.2 16 | 17 | params = visualizer.colormap.get_parameters() 18 | 19 | self._vmin = params['vmin'] 20 | self._vmax = params['vmax'] 21 | self._colormap = params['colormap_name'] 22 | 23 | self.label = label 24 | self._last_width = None 25 | self._last_height = None 26 | 27 | super().__init__(visualizer) 28 | 29 | def get_clipspace_coordinates(self, pixel_width, pixel_height): 30 | im = self.get_contents() 31 | height = 2.0 32 | width = 2.0*pixel_height*im.shape[1]/im.shape[0]/pixel_width 33 | x,y = 1.0-width,-1.0 34 | if self._last_width!=pixel_width or self._last_height!=pixel_height: 35 | # contents is the wrong size 36 | self.update() 37 | self._last_width = pixel_width 38 | self._last_height = pixel_height 39 | return x, y, width, height 40 | 41 | def encode_render_pass(self, command_encoder: wgpu.GPUCommandEncoder, target_texture_view: wgpu.GPUTextureView, 42 | clear=False): 43 | 44 | self._ensure_texture_is_current() 45 | super().encode_render_pass(command_encoder, target_texture_view, clear) 46 | 47 | def _ensure_texture_is_current(self): 48 | params = self._visualizer.colormap.get_parameters() 49 | changed = (self._vmin != params['vmin'] or 50 | self._vmax != params['vmax'] or 51 | self._colormap != params['colormap_name']) 52 | if changed: 53 | self._vmin = params['vmin'] 54 | self._vmax = params['vmax'] 55 | self._colormap = params['colormap_name'] 56 | self.update() 57 | 58 | def render_contents(self): 59 | dpi_physical = self.dpi_logical*self._visualizer.canvas.pixel_ratio 60 | 61 | fig = figure.Figure(figsize=(self._visualizer.canvas.height_physical * self._aspect_ratio/dpi_physical, 62 | self._visualizer.canvas.height_physical/dpi_physical), 63 | dpi=dpi_physical, 64 | facecolor=(1.0, 1.0, 1.0, 0.5)) 65 | 66 | canvas = matplotlib.backends.backend_agg.FigureCanvasAgg(fig) 67 | 68 | cmap = matplotlib.colormaps[self._colormap] 69 | cNorm = colors.Normalize(vmin=self._vmin, vmax=self._vmax) 70 | cb1 = matplotlib.colorbar.ColorbarBase(fig.add_axes([0.05, 0.05, 0.3, 0.9]), 71 | cmap=cmap, norm=cNorm, orientation='vertical') 72 | cb1.set_label(self.label) 73 | 74 | fig.canvas.draw() 75 | width,height = fig.canvas.get_width_height(physical=True) 76 | 77 | result: np.ndarray = np.frombuffer(fig.canvas.buffer_rgba(),dtype=np.uint8).reshape((height,width,4)).transpose((1,0,2)) 78 | result = result.swapaxes(0,1).astype(np.float32)/256 79 | 80 | return result 81 | -------------------------------------------------------------------------------- /tests/test_interpolation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from topsy.recorder import interpolator 4 | 5 | def test_step_interpolator(): 6 | timestream = [(0.4, None), (1.0, "Hello"), (2.0, "World")] 7 | interp = interpolator.StepInterpolator(timestream) 8 | assert interp(0.0) is interp.no_value 9 | assert interp(0.5) is None 10 | assert interp(1.0) == "Hello" 11 | assert interp(1.5) is interp.no_value 12 | assert interp(1.9) is interp.no_value 13 | assert interp(2.1) == "World" 14 | 15 | def test_linear_interpolator(): 16 | timestream = [(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (4.0, 0.0)] 17 | interp = interpolator.LinearInterpolator(timestream) 18 | assert interp(0.0) == 0.0 19 | assert interp(0.5) == 0.5 20 | assert interp(1.0) == 1.0 21 | assert interp(1.5) == 2.5 22 | assert interp(2.0) == 4.0 23 | assert interp(3.0) == 2.0 24 | assert interp(4.0) == 0.0 25 | assert interp(4.01) is interp.no_value 26 | 27 | def test_smoothed_step_interpolation_with_none(): 28 | timestream = [(0.0, 0.0), (1.0, None), (2.0, 4.0), (4.0, 0.0)] 29 | interp = interpolator.SmoothedStepInterpolator(timestream) 30 | assert interp(0.0) == 0.0 31 | assert interp(0.5) is interp.no_value 32 | assert interp(1.99) is interp.no_value 33 | assert interp(2.01) == 4.0 34 | assert interp(3.0) is interp.no_value 35 | assert interp(4.0) == 4.0 36 | assert interp(4.125) == 2.0 37 | assert interp(4.25) == 0.0 38 | assert interp(4.5) is interp.no_value 39 | 40 | def test_rotation_interpolator(): 41 | timestream = [(0.0, np.eye(3)), (1.0, np.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]))] 42 | interp = interpolator.RotationInterpolator(timestream) 43 | assert np.allclose(interp(0.0), np.eye(3)) 44 | midway = interp(0.5) 45 | assert np.allclose(midway @ midway.T, np.eye(3)) 46 | assert midway[0,0]<1.0 and midway[0,0]>0.0 47 | assert midway[0,1]<1.0 and midway[0,1]>0.0 48 | assert np.allclose(interp(1.0), np.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])) 49 | 50 | def test_smoothed_linear_interpolator(): 51 | timestream = [(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (4.0, 0.0)] 52 | interp = interpolator.SmoothedLinearInterpolator(timestream, smoothing=0.5) 53 | assert np.allclose(interp(0.0), 0.18728488638447055) 54 | assert np.allclose(interp(0.5), 0.5833226799824336) 55 | assert np.allclose(interp(1.0), 1.3206482380039408) 56 | assert np.allclose(interp(1.5), 2.3157963036465143) 57 | assert np.allclose(interp(3.9), 0.5695905129936958) 58 | assert np.allclose(interp(4.0), 0.4616432412285651) 59 | assert interp(4.1) is interp.no_value 60 | # check smoothness 61 | assert abs(np.diff(np.diff([interp(x) for x in np.arange(0.0,4.0,0.05)]))).max()<0.02 62 | 63 | def test_smoothed_rotation_interpolator(): 64 | timestream = [(0.0, np.eye(3)), (1.0, np.array([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]))] 65 | interp = interpolator.SmoothedRotationInterpolator(timestream, smoothing=0.5) 66 | for x in np.arange(0.0,1.0,0.1): 67 | assert np.allclose(interp(x) @ interp(x).T, np.eye(3)) 68 | 69 | def test_smoothed_step_interpolator(): 70 | timestream = [(0.0, 0.0), (1.0, 5.0), (2.0, 0.0)] 71 | interp = interpolator.SmoothedStepInterpolator(timestream, smoothing=0.5) 72 | assert interp(0.1) == 0.0 73 | assert interp(0.5) is interp.no_value 74 | assert interp(1.0) == 0.0 75 | assert interp(1.125) == 1.25 76 | assert interp(1.25) == 2.5 77 | assert interp(1.5) == 5.0 78 | assert interp(1.75) is interp.no_value 79 | assert interp(1.99) is interp.no_value 80 | assert interp(2.0) == 5.0 81 | assert interp(2.25) == 2.5 82 | -------------------------------------------------------------------------------- /src/topsy/periodic_sph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from topsy.drawreason import DrawReason 4 | 5 | from . import sph 6 | from . import overlay 7 | 8 | from typing import TYPE_CHECKING 9 | 10 | if TYPE_CHECKING: 11 | from .visualizer import Visualizer 12 | 13 | import numpy as np 14 | import wgpu 15 | 16 | class PeriodicSPHAccumulationOverlay(overlay.Overlay): 17 | _blending = { 18 | "src_factor": wgpu.BlendFactor.one, 19 | "dst_factor": wgpu.BlendFactor.one, 20 | "operation": wgpu.BlendOperation.add 21 | } 22 | 23 | def __init__(self, visualizer: Visualizer, source_texture: wgpu.GPUTexture): 24 | self._texture = source_texture 25 | self.num_repetitions = 0 26 | self.panel_scale = 1.0 27 | self.rotation_matrix = np.eye(3) 28 | super().__init__(visualizer, source_texture.format) 29 | 30 | def _setup_texture(self): 31 | pass 32 | 33 | def get_clipspace_coordinates(self, width, height) -> tuple[float, float, float, float]: 34 | return -1.0, -1.0, 2.0, 2.0 35 | 36 | def get_instance_offsets_and_weights(self): 37 | offsets = [] 38 | weights = [] 39 | 40 | for xoff in range(-self.num_repetitions, self.num_repetitions + 1): 41 | for yoff in range(-self.num_repetitions, self.num_repetitions + 1): 42 | for zoff in range(-self.num_repetitions, self.num_repetitions + 1): 43 | offset = self.rotation_matrix @ np.array([xoff,yoff,zoff], dtype=np.float32) 44 | if abs(offset[2]) < 1.0: 45 | offsets.append(offset[:2]) 46 | z = abs(offset[2]) 47 | # the weight be 1 for 00.5: 49 | weight = 1.0 - 2.0*(z-0.5) 50 | else: 51 | weight = 1.0 52 | weights.append(weight) 53 | 54 | return np.array(offsets, dtype=np.float32) * self.panel_scale, np.array(weights, dtype=np.float32) 55 | 56 | def render_contents(self) -> np.ndarray: 57 | # must be implemented, but should never be called because texture is provided externally 58 | raise RuntimeError("SPHAccumulationOverlay.render_contents() should never be called") 59 | 60 | class PeriodicSPH(sph.SPH): 61 | def __init__(self, visualizer, render_size): 62 | super().__init__(visualizer, render_size, wrapping=True) 63 | self._periodic_texture = visualizer.device.create_texture( 64 | size=self._render_texture.size, 65 | format=self._render_texture.format, 66 | usage=self._render_texture.usage, 67 | label=f"proxy_sph" 68 | ) 69 | self._accumulator = PeriodicSPHAccumulationOverlay(visualizer, self._render_texture) 70 | 71 | def get_output_texture(self) -> wgpu.Texture: 72 | return self._periodic_texture 73 | 74 | def render(self, draw_reason=DrawReason.CHANGE): 75 | if draw_reason == DrawReason.PRESENTATION_CHANGE: 76 | return 77 | 78 | super().render(draw_reason) 79 | 80 | command_encoder = self._visualizer.device.create_command_encoder(label="PeriodicSPH") 81 | self._accumulator.num_repetitions = 2 82 | self._accumulator.rotation_matrix = self.rotation_matrix 83 | self._accumulator.panel_scale = self._visualizer.periodicity_scale / self._visualizer.scale 84 | 85 | self._accumulator.encode_render_pass(command_encoder, self._periodic_texture.create_view(), True) 86 | 87 | encoded_render_pass = command_encoder.finish() 88 | self._device.queue.submit([encoded_render_pass]) 89 | 90 | -------------------------------------------------------------------------------- /src/topsy/shaders/surface.wgsl: -------------------------------------------------------------------------------- 1 | // Vertex and fragment shaders for surface rendering 2 | 3 | struct VertexOutput { 4 | @builtin(position) pos: vec4, 5 | @location(0) texCoord: vec2, 6 | } 7 | 8 | // Fragment shader 9 | @group(0) @binding(0) var colorTexture: texture_2d; 10 | @group(0) @binding(1) var textureSampler: sampler; 11 | 12 | struct Uniforms { 13 | depthScale: f32, // Scale factor for depth values 14 | lightDirection: vec3, // Direction to light source 15 | lightColor: vec3, // Light color 16 | ambientColor: vec3, // Ambient light color 17 | texelSize: vec2, // 1.0 / texture dimensions 18 | windowAspectRatio: f32, // Window aspect ratio for proper scaling 19 | vmin: f32, 20 | vmax: f32 21 | } 22 | 23 | @group(0) @binding(2) var uniforms: Uniforms; 24 | 25 | @group(0) @binding(3) 26 | var colormap_texture: texture_1d; 27 | 28 | 29 | fn sampleDepth(coord: vec2) -> f32 { 30 | let samp = textureSample(colorTexture, textureSampler, coord); 31 | return samp.g * uniforms.depthScale; // Using alpha channel for depth 32 | } 33 | 34 | // Compute surface normal using finite differences 35 | fn computeNormal(texCoord: vec2) -> vec3 { 36 | let texelSize = uniforms.texelSize; 37 | 38 | // Sample depth at neighboring pixels 39 | 40 | let depthLeft = sampleDepth(texCoord + vec2(-texelSize.x, 0.0)); 41 | let depthRight = sampleDepth(texCoord + vec2(texelSize.x, 0.0)); 42 | let depthUp = sampleDepth(texCoord + vec2(0.0, -texelSize.y)); 43 | let depthDown = sampleDepth(texCoord + vec2(0.0, texelSize.y)); 44 | 45 | // Compute gradients using central differences 46 | let dX = (depthRight - depthLeft) * 0.5; 47 | let dY = (depthDown - depthUp) * 0.5; 48 | 49 | // Construct normal vector 50 | // The normal points "outward" from the surface 51 | let normal = normalize(vec3(-dX, -dY, texelSize.x)); 52 | 53 | return normal; 54 | } 55 | 56 | 57 | fn computeLighting(texCoord: vec2, materialColor: vec3) -> vec3 { 58 | let normal = computeNormal(texCoord); 59 | let depthCenter = sampleDepth(texCoord); 60 | let lightDir = uniforms.lightDirection; 61 | let NdotL = max(dot(normal, lightDir), 0.0); 62 | 63 | let diffuse = uniforms.lightColor * NdotL * materialColor; 64 | let ambient = uniforms.ambientColor * materialColor; 65 | 66 | return (diffuse + ambient)*clamp(depthCenter, 0.0, 0.5)*2.0; 67 | } 68 | 69 | @vertex 70 | fn vertex_main(@builtin(vertex_index) vertexIndex : u32) -> VertexOutput { 71 | var pos = array, 4>( 72 | vec2(-1.0, -1.0), 73 | vec2(-1.0, 1.0), 74 | vec2(1.0, -1.0), 75 | vec2(1.0, 1.0) 76 | ); 77 | 78 | // Apply aspect ratio scaling 79 | if (uniforms.windowAspectRatio > 1.0) { 80 | for (var i = 0u; i < 4u; i = i + 1u) { 81 | pos[i].y = pos[i].y * uniforms.windowAspectRatio; 82 | } 83 | } else { 84 | for (var i = 0u; i < 4u; i = i + 1u) { 85 | pos[i].x = pos[i].x / uniforms.windowAspectRatio; 86 | } 87 | } 88 | 89 | var texc = array, 4>( 90 | vec2(0.0, 1.0), 91 | vec2(0.0, 0.0), 92 | vec2(1.0, 1.0), 93 | vec2(1.0, 0.0) 94 | ); 95 | 96 | var output: VertexOutput; 97 | output.pos = vec4(pos[vertexIndex], 0.0, 1.0); 98 | output.texCoord = texc[vertexIndex]; 99 | 100 | return output; 101 | } 102 | 103 | fn log10(value: f32) -> f32 { 104 | return log(value)/2.30258509; 105 | } 106 | 107 | @fragment 108 | fn fs_main(input: VertexOutput) -> @location(0) vec4 { 109 | #ifdef MATERIAL_COLORMAP 110 | var value = textureSample(colorTexture, textureSampler, input.texCoord).r; 111 | // NB above could be optimized by combining it with the g sample taken for the depth elsewhere 112 | 113 | #ifdef MATERIAL_LOG 114 | value = log10(value); 115 | #endif 116 | value = clamp((value - uniforms.vmin) / (uniforms.vmax - uniforms.vmin), 0.0, 1.0); 117 | let materialColor = textureSample(colormap_texture, textureSampler, value).rgb; 118 | #else 119 | let materialColor = vec3(1.0, 1.0, 1.0); 120 | #endif 121 | let lighting = computeLighting(input.texCoord, materialColor); 122 | return vec4(lighting, 1.0); 123 | } 124 | -------------------------------------------------------------------------------- /tests/test_jupyter_interface.py: -------------------------------------------------------------------------------- 1 | import time 2 | import ipywidgets as widgets 3 | from playwright.sync_api import Page 4 | import pytest 5 | 6 | from IPython.display import display 7 | 8 | 9 | from typing import Callable 10 | 11 | import topsy, topsy.canvas.jupyter 12 | 13 | 14 | def poll_until_true(assertion: Callable, timeout=2, iteration_delay=0.01): 15 | start = time.time() 16 | while time.time() - start < timeout: 17 | if assertion(): 18 | return True 19 | time.sleep(iteration_delay) 20 | return False 21 | 22 | @pytest.fixture 23 | def jupyter_vis(solara_test): 24 | vis = topsy.test(100, canvas_class = topsy.canvas.jupyter.VisualizerCanvas) 25 | display(vis) 26 | return vis 27 | 28 | @pytest.fixture 29 | def jupyter_vis_surface(solara_test): 30 | vis = topsy.test(100, canvas_class = topsy.canvas.jupyter.VisualizerCanvas, 31 | render_mode='surface') 32 | display(vis) 33 | return vis 34 | 35 | 36 | def test_colormap_name_select(jupyter_vis, page_session: Page): 37 | vis = jupyter_vis 38 | 39 | assert vis.colormap.get_parameter('colormap_name') == "twilight_shifted" 40 | 41 | sel = page_session.locator("select:has-text('twilight_shifted')") 42 | sel.wait_for() 43 | sel.select_option("twilight") 44 | 45 | assert poll_until_true(lambda: vis.colormap.get_parameter('colormap_name') == "twilight") 46 | 47 | assert vis.quantity_name == None 48 | 49 | def test_quantity_name_select(jupyter_vis, page_session: Page): 50 | sel = page_session.locator(f"select:has-text('{topsy.config.PROJECTED_DENSITY_NAME}')") 51 | cb = page_session.locator("input[type='checkbox']") 52 | 53 | sel.wait_for() 54 | cb.wait_for() 55 | 56 | assert cb.is_checked() 57 | sel.select_option("test-quantity") 58 | 59 | 60 | assert poll_until_true(lambda: jupyter_vis.quantity_name == "test-quantity") 61 | 62 | # check that log quantity is no longer selected 63 | assert poll_until_true(lambda: not cb.is_checked()) 64 | 65 | def test_alter_range(jupyter_vis, page_session: Page): 66 | vis = jupyter_vis 67 | 68 | min_slider = page_session.locator("div.noUi-handle-lower") 69 | max_slider = page_session.locator("div.noUi-handle-upper") 70 | 71 | min_slider.wait_for() 72 | max_slider.wait_for() 73 | 74 | vmin_orig = vis.colormap.get_parameter('vmin') 75 | vmax_orig = vis.colormap.get_parameter('vmax') 76 | 77 | # Use keyboard to move the slider instead of drag_to to avoid pointer event interception issues 78 | min_slider.focus() 79 | page_session.keyboard.press("ArrowLeft") 80 | page_session.keyboard.press("ArrowLeft") 81 | max_slider.focus() 82 | page_session.keyboard.press("ArrowRight") 83 | page_session.keyboard.press("ArrowRight") 84 | 85 | assert poll_until_true(lambda: vis.colormap.get_parameter('vmin') < vmin_orig) 86 | assert poll_until_true(lambda: vis.colormap.get_parameter('vmax') > vmax_orig) 87 | 88 | def test_rgb_map(solara_test, page_session: Page): 89 | vis = topsy.test(100, canvas_class = topsy.canvas.jupyter.VisualizerCanvas, 90 | render_mode='rgb') 91 | display(vis) 92 | 93 | # at the moment we just check this actually gives the alternative panel 94 | sel = page_session.locator("text=gamma") 95 | sel.wait_for() 96 | 97 | 98 | def test_quantity_bar_adapting(jupyter_vis_surface, page_session: Page): 99 | 100 | sel = page_session.locator("select:has-text('Projected density')") 101 | sel.wait_for() 102 | 103 | # Check that no vmin/vmax slider exists initially 104 | assert page_session.locator("div.noUi-handle-upper").count() == 0 105 | 106 | sel.select_option("test-quantity") 107 | 108 | # Wait for vmin/vmax slider to appear. NB there's other sliders, just not range sliders, so here 109 | # we look for the 'upper' handle (the 'lower' handles exist in single-value sliders) 110 | assert poll_until_true(lambda: page_session.locator("div.noUi-handle-upper").count() > 0) 111 | 112 | # Change quantity back 113 | sel = page_session.locator("select:has-text('test-quantity')") 114 | sel.wait_for() 115 | sel.select_option("Projected density") 116 | 117 | # Wait for vmin/vmax sliders to disappear 118 | assert poll_until_true(lambda: page_session.locator("div.noUi-handle-upper").count() == 0) -------------------------------------------------------------------------------- /src/topsy/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import wgpu 4 | import re 5 | import os 6 | 7 | def load_shader(name): 8 | from importlib import resources 9 | with open(resources.files("topsy.shaders") / name, "r") as f: 10 | return f.read() 11 | 12 | def preprocess_shader(shader_code, active_flags): 13 | """A hacky preprocessor for WGSL shaders. 14 | 15 | Any line in shader_code containing [[FLAG]] will be removed if FLAG is not in active_flags. 16 | Otherwise, the string [[FLAG]] will be removed, leaving just valid syntax. 17 | 18 | This is needed because we can't use 19 | const values in the shader yet, so we need to use something like #ifdefs instead. 20 | In final version of webgpu doesn't look like this will be needed""" 21 | for f in active_flags: 22 | shader_code = re.sub(f"^.*\[\[{f}]](.*)$", r"\1", shader_code, flags=re.MULTILINE) 23 | shader_code = re.sub(r"^.*\[\[[A-Z_]+]].*$", "", shader_code, flags=re.MULTILINE) 24 | 25 | # process #ifdef / #else / #endif 26 | lines = shader_code.splitlines() 27 | output_lines: list[str] = [] 28 | include_stack: list[tuple[bool, bool]] = [] 29 | current_include = True 30 | 31 | for line in lines: 32 | stripped = line.lstrip() 33 | if stripped.startswith('#ifdef'): 34 | _, flag = stripped.split(None, 1) 35 | should_include = current_include and (flag in active_flags) 36 | include_stack.append((current_include, should_include)) 37 | current_include = should_include 38 | 39 | elif stripped.startswith('#else'): 40 | parent_include, prev_include = include_stack[-1] 41 | new_include = parent_include and not prev_include 42 | include_stack[-1] = (parent_include, new_include) 43 | current_include = new_include 44 | 45 | elif stripped.startswith('#endif'): 46 | parent_include, _ = include_stack.pop() 47 | current_include = parent_include 48 | 49 | elif current_include: 50 | output_lines.append(line) 51 | 52 | result = "\n".join(output_lines) 53 | 54 | return result 55 | 56 | def is_inside_ipython(): 57 | try: 58 | __IPYTHON__ 59 | return True 60 | except NameError: 61 | return False 62 | 63 | def is_inside_jupyter_notebook(): 64 | return "JPY_SESSION_NAME" in os.environ 65 | 66 | def is_ipython_running_qt_event_loop(): 67 | if not is_inside_ipython(): 68 | return False 69 | import IPython.lib.guisupport 70 | return IPython.lib.guisupport.is_event_loop_running_qt4() 71 | 72 | def determine_backend(): 73 | if is_inside_ipython(): 74 | pass 75 | 76 | class TimeGpuOperation: 77 | """Context manager for timing GPU operations""" 78 | def __init__(self, device, n_frames_smooth=10): 79 | self.device = device 80 | self.n_frames_smooth = n_frames_smooth 81 | self._recent_times = [] 82 | self._current_frame_duration = 0.0 83 | 84 | def __enter__(self): 85 | self.device.queue.on_submitted_work_done_sync() 86 | self.__block_start = time.time() 87 | return self 88 | 89 | def __exit__(self, *args): 90 | # Now, we want to measure how much time the render has taken so that we can adapt 91 | # for the next frame if needed. However, the GPU is asynchronous. In the long term 92 | # there should be facilities like callbacks or querysets to help with this, but 93 | # right now these don't seem to be implemented. So we need to make something block 94 | # until the current queue is complete. The hack here is to do a trivial read 95 | # operation 96 | 97 | self.device.queue.on_submitted_work_done_sync() 98 | block_end = time.time() 99 | self._current_frame_duration += block_end - self.__block_start 100 | 101 | def end_frame(self): 102 | self.last_duration = self._current_frame_duration 103 | self._current_frame_duration = 0.0 104 | 105 | self._recent_times.append(self.last_duration) 106 | if len(self._recent_times) > self.n_frames_smooth: 107 | self._recent_times.pop(0) 108 | 109 | def total_time_in_frame(self): 110 | """Return the time elapsed in GPU operations since the last call to end_frame""" 111 | return self._current_frame_duration 112 | 113 | @property 114 | def running_mean_duration(self): 115 | return np.mean(self._recent_times) 116 | -------------------------------------------------------------------------------- /src/topsy/shaders/colormap.wgsl: -------------------------------------------------------------------------------- 1 | struct ColormapParams { 2 | vmin: f32, 3 | vmax: f32, 4 | density_vmin: f32, // used only in bivariate case 5 | density_vmax: f32, // used only in bivariate case 6 | window_aspect_ratio: f32, 7 | gamma: f32 8 | }; 9 | 10 | struct VertexOutput { 11 | @builtin(position) pos: vec4, 12 | @location(0) texcoord: vec2, 13 | } 14 | 15 | struct FragmentOutput { 16 | @location(0) color: vec4 17 | } 18 | 19 | 20 | @group(0) @binding(0) 21 | var image_texture: texture_2d; 22 | 23 | @group(0) @binding(1) 24 | var image_sampler: sampler; 25 | 26 | #ifdef BIVARIATE 27 | @group(0) @binding(2) 28 | var colormap_texture: texture_2d; 29 | #else 30 | @group(0) @binding(2) 31 | var colormap_texture: texture_1d; 32 | #endif 33 | 34 | @group(0) @binding(3) 35 | var colormap_sampler: sampler; 36 | 37 | @group(0) @binding(4) 38 | var colormap_params: ColormapParams; 39 | 40 | 41 | @vertex 42 | fn vertex_main(@builtin(vertex_index) vertexIndex : u32) -> VertexOutput { 43 | var pos = array, 4>( 44 | vec2(-1.0, -1.0), 45 | vec2(-1.0, 1.0), 46 | vec2(1.0, -1.0), 47 | vec2(1.0, 1.0) 48 | ); 49 | 50 | if(colormap_params.window_aspect_ratio>1.0) { 51 | for(var i = 0u; i<4u; i=i+1u) { 52 | pos[i].y = pos[i].y*colormap_params.window_aspect_ratio; 53 | } 54 | } else { 55 | for(var i = 0u; i<4u; i=i+1u) { 56 | pos[i].x = pos[i].x/colormap_params.window_aspect_ratio; 57 | } 58 | } 59 | 60 | var texc = array, 4>( 61 | vec2(0.0, 1.0), 62 | vec2(0.0, 0.0), 63 | vec2(1.0, 1.0), 64 | vec2(1.0, 0.0) 65 | ); 66 | 67 | var output: VertexOutput; 68 | 69 | output.pos = vec4(pos[vertexIndex], 0.0, 1.0); 70 | output.texcoord = texc[vertexIndex]; 71 | 72 | return output; 73 | } 74 | 75 | fn log10(value: f32) -> f32 { 76 | return log(value)/2.30258509; 77 | } 78 | 79 | @fragment 80 | fn fragment_main(input: VertexOutput) -> FragmentOutput { 81 | var output: FragmentOutput; 82 | 83 | var values = textureSample(image_texture, image_sampler, input.texcoord); 84 | 85 | var value : f32; 86 | 87 | // Note the following lines are selected by python before compile time 88 | // ultimately, this should be possible within wgsl itself by using a const, but this doesn't 89 | // seem to be supported at present 90 | 91 | #ifdef BIVARIATE 92 | var value_2d : vec2; 93 | value_2d.x = log10(values.r); 94 | value_2d.x -= colormap_params.density_vmin; 95 | value_2d.x /= (colormap_params.density_vmax - colormap_params.density_vmin); 96 | 97 | #ifdef WEIGHTED_MEAN 98 | value_2d.y = values.g / values.r; 99 | #else 100 | value_2d.y = values.r; 101 | #endif 102 | 103 | #ifdef LOG_SCALE 104 | value_2d.y = log10(value_2d.y); 105 | #endif 106 | 107 | value_2d.y -= colormap_params.vmin; 108 | value_2d.y /= (colormap_params.vmax - colormap_params.vmin); 109 | 110 | value_2d = clamp(value_2d, vec2(0.0, 0.0), vec2(1.0, 1.0)); 111 | output.color = textureSample(colormap_texture, colormap_sampler, value_2d); 112 | 113 | #else // not BIVARIATE 114 | 115 | #ifdef WEIGHTED_MEAN 116 | value = values.g/values.r; 117 | #else 118 | value = values.r; 119 | #endif 120 | 121 | #ifdef LOG_SCALE 122 | value = log10(value); 123 | #endif 124 | 125 | value = clamp((value-colormap_params.vmin)/(colormap_params.vmax-colormap_params.vmin), 0.0, 1.0); 126 | output.color = textureSample(colormap_texture, colormap_sampler, value); 127 | #endif // not BIVARIATE 128 | return output; 129 | } 130 | 131 | fn gamma_map(value: f32, vmin: f32, vmax: f32, gamma: f32) -> f32 { 132 | return pow(max((value - vmin)/(vmax - vmin), 0.0), gamma); 133 | } 134 | 135 | @fragment 136 | fn fragment_main_tri(input: VertexOutput) -> FragmentOutput { 137 | var output: FragmentOutput; 138 | var value_r: f32; 139 | var value_g: f32; 140 | var value_b: f32; 141 | 142 | value_r = textureSample(image_texture, image_sampler, input.texcoord).r; 143 | value_g = textureSample(image_texture, image_sampler, input.texcoord).g; 144 | value_b = textureSample(image_texture, image_sampler, input.texcoord).b; 145 | 146 | #ifdef LOG_SCALE 147 | value_r = log10(value_r); 148 | value_g = log10(value_g); 149 | value_b = log10(value_b); 150 | #endif 151 | 152 | value_r = gamma_map(value_r, colormap_params.vmin, colormap_params.vmax, colormap_params.gamma); 153 | value_g = gamma_map(value_g, colormap_params.vmin, colormap_params.vmax, colormap_params.gamma); 154 | value_b = gamma_map(value_b, colormap_params.vmin, colormap_params.vmax, colormap_params.gamma); 155 | 156 | output.color = vec4(value_r, value_g, value_b, 1.0); 157 | 158 | return output; 159 | } -------------------------------------------------------------------------------- /tests/test_synchronizer.py: -------------------------------------------------------------------------------- 1 | from topsy.view_synchronizer import ViewSynchronizer 2 | 3 | class DummyTarget: 4 | def __init__(self): 5 | self.reset_update_count() 6 | 7 | def __setattr__(self, __name, __value): 8 | super().__setattr__("_num_updates", self._num_updates + 1) 9 | super().__setattr__(__name, __value) 10 | 11 | def reset_update_count(self): 12 | super().__setattr__("_num_updates", 0) 13 | 14 | @property 15 | def update_count(self): 16 | return self._num_updates 17 | 18 | class SubObject: 19 | def __init__(self): 20 | self.value = 0 21 | 22 | def test_synchronizer(): 23 | """Test the Synchronizer class to ensure it correctly synchronizes attributes between objects.""" 24 | source = DummyTarget() 25 | target1 = DummyTarget() 26 | target2 = DummyTarget() 27 | 28 | synchronizer = ViewSynchronizer(['attr1', 'attr2']) 29 | synchronizer.add_view(source) 30 | synchronizer.add_view(target1) 31 | synchronizer.add_view(target2) 32 | 33 | # Set initial values 34 | source.attr1 = 10 35 | source.attr2 = 20 36 | 37 | source.reset_update_count() 38 | 39 | synchronizer.perpetuate_update(source) 40 | 41 | assert source.update_count == 0 42 | assert target1.attr1 == 10 43 | assert target1.attr2 == 20 44 | assert target2.attr1 == 10 45 | assert target2.attr2 == 20 46 | 47 | def test_synchronizer_subobjects(): 48 | """Test the Synchronizer class with nested objects.""" 49 | 50 | 51 | source = DummyTarget() 52 | source.sub = SubObject() 53 | target1 = DummyTarget() 54 | target1.sub = SubObject() 55 | target2 = DummyTarget() 56 | target2.sub = SubObject() 57 | 58 | synchronizer = ViewSynchronizer(['sub.value']) 59 | synchronizer.add_view(source) 60 | synchronizer.add_view(target1) 61 | synchronizer.add_view(target2) 62 | 63 | # Set initial values 64 | source.sub.value = 42 65 | 66 | source.reset_update_count() 67 | 68 | synchronizer.perpetuate_update(source) 69 | 70 | assert source.update_count == 0 71 | assert target1.sub.value == 42 72 | assert target2.sub.value == 42 73 | 74 | assert target1.update_count > 0 75 | assert target2.update_count > 0 76 | 77 | target1.reset_update_count() 78 | target2.reset_update_count() 79 | 80 | synchronizer.perpetuate_update(target1) 81 | synchronizer.perpetuate_update(target2) 82 | 83 | # neither of the above should have generated any ops either on targets or source, because 84 | # they are expected to be "acknowledging receipt" of the update 85 | assert source.update_count == 0 86 | assert target1.update_count == 0 87 | assert target2.update_count == 0 88 | 89 | def test_synchronizer_custom_setter(): 90 | """Test the Synchronizer class with a custom setter.""" 91 | class CustomTarget: 92 | def __init__(self): 93 | self.result_dict = {} 94 | 95 | def update_value(self, name, value): 96 | self.result_dict[name] = value 97 | 98 | def get_value(self, name): 99 | return self.result_dict.get(name, None) 100 | 101 | def __getitem__(self, item): 102 | return self.get_value(item) 103 | 104 | def __setitem__(self, key, value): 105 | self.update_value(key, value) 106 | 107 | 108 | 109 | 110 | source = DummyTarget() 111 | target1 = CustomTarget() 112 | 113 | source.sub = SubObject() 114 | source.sub.value = 42 115 | source.value = 1 116 | 117 | synchronizer = ViewSynchronizer(['value', 'sub.value']) 118 | synchronizer.add_view(source) 119 | synchronizer.add_view(target1, setter = CustomTarget.update_value, getter = CustomTarget.get_value) 120 | 121 | source.reset_update_count() 122 | 123 | synchronizer.perpetuate_update(source) 124 | 125 | assert source.update_count == 0 126 | assert not hasattr(target1, 'value') # Should not have a direct attribute 'value' 127 | 128 | assert target1['value'] == 1 129 | assert target1['sub.value'] == 42 130 | 131 | synchronizer.perpetuate_update(target1) 132 | # ^ this is actually just "acknowledging receipt". Otherwise the next update is ignored (part of the infinite 133 | # loop protection) 134 | 135 | target1['value'] = 2 136 | target1['sub.value'] = 84 137 | 138 | synchronizer.perpetuate_update(target1) 139 | 140 | assert source.value == 2 141 | assert source.sub.value == 84 142 | 143 | 144 | def test_synchronize_with_dict(): 145 | source = DummyTarget() 146 | source.data = {'key1': 1, 'key2': 2} 147 | target1 = DummyTarget() 148 | target1.data = {'key1': 0, 'key2': 0} 149 | 150 | synchronizer = ViewSynchronizer(['data[key1]', 'data[key2]']) 151 | synchronizer.add_view(source) 152 | synchronizer.add_view(target1) 153 | 154 | assert target1.data['key1'] == 0 155 | assert target1.data['key2'] == 0 156 | synchronizer.perpetuate_update(source) 157 | 158 | assert target1.data['key1'] == 1 159 | assert target1.data['key2'] == 2 160 | synchronizer.perpetuate_update(target1) 161 | 162 | assert source.data['key1'] == 1 163 | assert source.data['key2'] == 2 164 | -------------------------------------------------------------------------------- /src/topsy/cell_layout.py: -------------------------------------------------------------------------------- 1 | """Classes to keep track of the cellular layout of a simulation""" 2 | 3 | import numpy as np 4 | import pynbody 5 | 6 | from pynbody.filt import geometry_selection 7 | 8 | class CellLayout: 9 | """Class to keep track of segmentation of a simulation into cells""" 10 | def __init__(self, centres: np.ndarray, offsets: np.ndarray, lengths: np.ndarray): 11 | self._centres = np.ascontiguousarray(centres) 12 | self._offsets = offsets 13 | self._lengths = lengths 14 | self._num_particles = lengths.sum() 15 | self._cell_size = np.linalg.norm(self._centres[1]-self._centres[0]) 16 | 17 | def randomize_within_cells(self): 18 | """Get a reordering of the particles which randomizes the order within cells, but leaves the cell structure""" 19 | total_len = self._lengths.sum() 20 | reordering = np.empty(total_len, dtype=np.uintp) 21 | for offset, length in zip(self._offsets, self._lengths): 22 | # randomize the order of the particles within this cell 23 | reordering[offset:offset+length] = np.random.permutation(length) + offset 24 | return reordering 25 | 26 | def cells_in_sphere(self, centre: tuple[float, float, float], radius: float) -> np.ndarray: 27 | """Get the indices of the cells that are within a sphere of given centre and radius""" 28 | expand_radius = self._cell_size*np.sqrt(3.0) 29 | offsets = self._centres - centre 30 | selection = np.linalg.norm(offsets, axis=1) < (radius + expand_radius) 31 | return np.where(selection)[0] 32 | 33 | def cell_index_from_offset(self, offset: int) -> int: 34 | """Get the cell index from the offset of a particle""" 35 | 36 | cell_index = np.searchsorted(self._offsets, offset, side='right') - 1 37 | if cell_index < 0 or cell_index >= len(self._lengths): 38 | raise ValueError("Offset is out of bounds") 39 | return cell_index 40 | 41 | def cell_slice(self, cell_index: int) -> slice: 42 | """Get the indices of the particles in a given cell""" 43 | start = self._offsets[cell_index] 44 | end = start + self._lengths[cell_index] 45 | return slice(start, end) 46 | 47 | def get_num_cells(self): 48 | """Get the total number of cells""" 49 | return len(self._lengths) 50 | 51 | def get_num_particles(self): 52 | """Get the total number of particles""" 53 | return self._num_particles 54 | 55 | def get_cell_length(self, cell_index: int | np.ndarray[int]) -> int | np.ndarray[int]: 56 | """Get the length of a given cell""" 57 | return self._lengths[cell_index] 58 | 59 | def get_cell_offset(self, cell_index: int) -> int: 60 | """Get the offset of a given cell""" 61 | return self._offsets[cell_index] 62 | 63 | @classmethod 64 | def from_positions(cls, particle_positions: np.ndarray, box_min: float, box_max: float, nside: int): 65 | """Create a CellLayout object from the positions of the particles with arbitrary ordering 66 | 67 | Parameters 68 | ---------- 69 | 70 | particle_positions: array of the positions of the particles (Nx3) 71 | box_min: minimum coordinate of the box (for all 3 dimensions) 72 | box_max: maximum coordinate of the box (for all 3 dimensions) 73 | nside: number of cells in each of the 3 dimensions, e.g. nside=10 implies 10^3 cells in total 74 | 75 | Returns 76 | ------- 77 | 78 | cell_layout, particle_ordering: 79 | cell_layout: CellLayout object 80 | particle_ordering: array of the ordering of the particles to put them into the cells 81 | """ 82 | 83 | if particle_positions.min()=box_max: 84 | raise ValueError("Particle positions are outside the box") 85 | 86 | # get the cell size 87 | cell_size = (box_max - box_min) / nside 88 | 89 | # get the centre of the first cell 90 | cell_cen0 = box_min + cell_size / 2 91 | 92 | # get the cell centres 93 | centres = np.mgrid[cell_cen0:box_max:cell_size, 94 | cell_cen0:box_max:cell_size, 95 | cell_cen0:box_max:cell_size].reshape(3, -1).T 96 | 97 | # figure out the cell x,y,z indices of each particle 98 | pos_indices = np.floor((particle_positions - box_min) / cell_size).astype(np.intp) 99 | 100 | if pos_indices.min() < 0 or pos_indices.max() >= nside: 101 | raise ValueError("Particle positions are too close to edge of box; expand box size") 102 | 103 | # get the cell index of each particle 104 | cell_indices = pos_indices[:, 2] + nside * (pos_indices[:, 1] + nside * pos_indices[:, 0]) 105 | # sort the particles by cell index 106 | ordering = np.argsort(cell_indices) 107 | 108 | # figure out the segmentation 109 | lengths = np.bincount(cell_indices, minlength=nside ** 3) 110 | assert len(lengths) == len(centres), "Logic error within from_positions" 111 | 112 | offsets = np.cumsum(lengths) - lengths 113 | 114 | return cls(centres, offsets, lengths), ordering -------------------------------------------------------------------------------- /src/topsy/shaders/sph.wgsl: -------------------------------------------------------------------------------- 1 | struct TransformParams { 2 | transform: mat4x4, 3 | scale_factor: f32, 4 | clipspace_size_min: f32, 5 | clipspace_size_max: f32, 6 | boxsize_by_2_clipspace: f32, 7 | density_cut: f32 8 | }; 9 | 10 | struct VertexInput { 11 | @location(0) pos: vec4, // NB w is used for the smoothing length 12 | @location(1) quantities: vec3, 13 | @builtin(vertex_index) vertexIndex: u32, 14 | @builtin(instance_index) instanceIndex: u32 15 | } 16 | 17 | struct VertexOutput { 18 | @builtin(position) pos: vec4, 19 | @location(0) texcoord: vec2, 20 | @location(1) intensities: vec3 21 | } 22 | 23 | @group(0) @binding(0) 24 | var trans_params: TransformParams; 25 | 26 | @group(0) @binding(1) 27 | var kernel_texture: texture_2d; 28 | 29 | @group(0) @binding(2) 30 | var kernel_sampler: sampler; 31 | 32 | 33 | 34 | // triangle position offsets for making a square of 2 units side length 35 | const posOffset = array, 6>( 36 | vec2(-1.0, -1.0), 37 | vec2(-1.0, 1.0), 38 | vec2(1.0, 1.0), 39 | vec2(1.0, -1.0), 40 | vec2(-1.0, -1.0), 41 | vec2(1.0, 1.0) 42 | ); 43 | 44 | // corresponding texture coordinates 45 | const texCoords = array, 6>( 46 | vec2(0.0, 0.0), 47 | vec2(0.0, 1.0), 48 | vec2(1.0, 1.0), 49 | vec2(1.0, 0.0), 50 | vec2(0.0, 0.0), 51 | vec2(1.0, 1.0) 52 | ); 53 | 54 | fn vertex_calculate_positions(input: VertexInput) -> VertexOutput { 55 | var output: VertexOutput; 56 | 57 | // factor 2: going out to 2h. 58 | var clipspace_size = trans_params.scale_factor*input.pos.w*2.0; 59 | 60 | output.pos = input.pos; 61 | output.pos.w = 1.0; 62 | output.pos = (trans_params.transform * output.pos); 63 | output.pos += vec4(clipspace_size*posOffset[input.vertexIndex],0.0,0.0); 64 | output.texcoord = texCoords[input.vertexIndex]; 65 | return output; 66 | } 67 | 68 | @vertex 69 | fn vertex_rgb(input: VertexInput) -> VertexOutput { 70 | var output: VertexOutput = vertex_calculate_positions(input); 71 | output.intensities = input.quantities/(input.pos.w * input.pos.w); 72 | return output; 73 | } 74 | 75 | @vertex 76 | fn vertex_weighting(input: VertexInput) -> VertexOutput { 77 | var output: VertexOutput = vertex_calculate_positions(input); 78 | 79 | output.intensities.x = input.quantities.x/(input.pos.w * input.pos.w); 80 | output.intensities.y = input.quantities.y; 81 | 82 | return output; 83 | } 84 | 85 | @vertex 86 | fn vertex_depth(input: VertexInput) -> VertexOutput { 87 | var output: VertexOutput = vertex_calculate_positions(input); 88 | output.intensities.x = input.quantities.x/(input.pos.w * input.pos.w); 89 | output.intensities.y = output.pos.z; 90 | return output; 91 | } 92 | 93 | @vertex 94 | fn vertex_depth_with_cut(input: VertexInput) -> VertexOutput { 95 | // This could be made more efficient by a compute shader passing once through the buffer 96 | // which would only need to be updated when the user changed the density threshold 97 | var result: VertexOutput; 98 | 99 | var rho: f32 = input.quantities.x / pow(input.pos.w,3.0f); 100 | 101 | if(rho > trans_params.density_cut) { 102 | result = vertex_calculate_positions(input); 103 | result.intensities.x = input.quantities.y; // quantity value 104 | result.intensities.y = result.pos.z; // depth value 105 | 106 | // "z" component of "intensities" is actually the depth scale of the sphere to be 107 | // rendered on this tile 108 | // 109 | // Factors: Sphere extends to 2*h, but that's already baked into the kernel image 110 | // input.pos.w*trans_params.scale_factor gives the extent of h in (x,y) clip space, 111 | // but note the z direction is squsiehd into (0,1) while (x,y) are in (-1,1) 112 | // so there is a factor of 0.5 in the z direction. 113 | result.intensities.z = input.pos.w * trans_params.scale_factor*0.5; 114 | } else { 115 | // put somewhere out of the clip space: 116 | result.pos.x = 100; 117 | result.pos.y = 100; 118 | result.pos.w = 100; 119 | } 120 | 121 | return result; 122 | } 123 | 124 | 125 | struct FragmentOutputWeighting { 126 | @location(0) output: vec2 127 | } 128 | 129 | struct FragmentOutputRGB { 130 | @location(0) output: vec4 131 | } 132 | 133 | struct FragmentOutputRaw { 134 | @location(0) output: vec2, 135 | @builtin(frag_depth) depth: f32, 136 | } 137 | 138 | @fragment 139 | fn fragment_weighting(input: VertexOutput) -> FragmentOutputWeighting { 140 | var value = textureSample(kernel_texture, kernel_sampler, input.texcoord).r; 141 | 142 | value *= input.intensities.x; 143 | var output = FragmentOutputWeighting(vec2(value, value*input.intensities.y)); 144 | 145 | return output; 146 | } 147 | 148 | @fragment 149 | fn fragment_raw(input: VertexOutput) -> FragmentOutputRaw { 150 | var value = textureSample(kernel_texture, kernel_sampler, input.texcoord).r; 151 | var depth: f32 = input.intensities.y + input.intensities.z*value; 152 | 153 | if (value<0.0) { 154 | discard; 155 | } 156 | 157 | return FragmentOutputRaw(vec2(input.intensities.x, depth), depth); 158 | } 159 | 160 | @fragment 161 | fn fragment_rgb(input: VertexOutput) -> FragmentOutputRGB { 162 | var value = textureSample(kernel_texture, kernel_sampler, input.texcoord).r; 163 | var output = FragmentOutputRGB(vec4(input.intensities * value, 1.0)); 164 | return output; 165 | } 166 | -------------------------------------------------------------------------------- /src/topsy/recorder/interpolator.py: -------------------------------------------------------------------------------- 1 | """Interpolation between frames for the motion recorder""" 2 | 3 | import math 4 | import numpy as np 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | class Interpolator(ABC): 9 | """ABC for interpolating a timestream. 10 | 11 | The timestream is a list of (time, value) pairs, where time is a float and value is any type.""" 12 | 13 | no_value = object() 14 | 15 | def __init__(self, timestream): 16 | self._timestream = timestream 17 | 18 | @abstractmethod 19 | def __call__(self, t): 20 | pass 21 | 22 | 23 | class LinearInterpolator(Interpolator): 24 | """Returns the linearly interpolated value, or None if no value is available""" 25 | 26 | def __call__(self, t): 27 | stream = self._timestream 28 | for i, (t_ev, val_ev) in enumerate(stream): 29 | if t_ev >= t: 30 | if i == 0: 31 | return val_ev 32 | else: 33 | t0, val0 = stream[i - 1] 34 | assert t0 < t 35 | return val0 + (val_ev - val0) * (t - t0) / (t_ev - t0) 36 | return self.no_value 37 | 38 | class SmoothedInterpolatorMixin: 39 | def __init__(self, timestream, smoothing=0.25, fps=30): 40 | """Create a linear interpolator with gaussian smoothing over the specified period 41 | 42 | Args: 43 | timestream: the timestream to interpolate 44 | smoothing: the standard deviation of the gaussian smoothing kernel, in seconds 45 | fps: the number of samples per second in the smoothed timestream (doesn't have to match video fps) 46 | """ 47 | super().__init__(timestream) 48 | tmax = timestream[-1][0] 49 | self._smoothing = smoothing 50 | 51 | interpolated_timestream = [] 52 | for i in range(math.floor(tmax*fps)): 53 | interpolated_timestream.append(super().__call__(i/fps)) 54 | 55 | 56 | kernel = np.exp(-np.arange(-3*smoothing*fps, 3*smoothing*fps)**2/(2*smoothing**2*fps**2)) 57 | kernel/=kernel.sum() 58 | interpolated_timestream = np.concatenate( 59 | ([interpolated_timestream[0]]*(len(kernel)//2), 60 | interpolated_timestream, 61 | [interpolated_timestream[-1]]*(len(kernel)//2)) 62 | ) 63 | 64 | if len(interpolated_timestream.shape)==1: 65 | smoothed_timestream = np.convolve(interpolated_timestream, kernel, mode='valid') 66 | else: 67 | smoothed_timestream = None 68 | for index in np.ndindex(interpolated_timestream.shape[1:]): 69 | index_c = (slice(None),)+index # py3.11+ supports [:, *index] but not py3.10- 70 | result = np.convolve(interpolated_timestream[index_c], kernel, mode='valid') 71 | if smoothed_timestream is None: 72 | smoothed_timestream = np.empty((len(result),)+interpolated_timestream.shape[1:]) 73 | smoothed_timestream[index_c] = result 74 | 75 | self._timestream = [(i/fps, val) for i, val in enumerate(smoothed_timestream)] 76 | 77 | 78 | class SmoothedLinearInterpolator(SmoothedInterpolatorMixin, LinearInterpolator): 79 | pass 80 | 81 | class RotationInterpolator(LinearInterpolator): 82 | """Returns an interpolated rotation matrix""" 83 | 84 | def __call__(self, t): 85 | matr = super().__call__(t) 86 | if matr is self.no_value: 87 | return matr 88 | 89 | # orthogonalise matr: 90 | u, s, vh = np.linalg.svd(matr) 91 | return u @ vh 92 | 93 | class SmoothedRotationInterpolator(SmoothedInterpolatorMixin, RotationInterpolator): 94 | pass 95 | 96 | 97 | class StepInterpolator(Interpolator): 98 | """Only returns a value when it has changed. Assumes it is being acccessed sequentially""" 99 | 100 | def __init__(self, timestream): 101 | super().__init__(timestream) 102 | self._last_value = self.no_value 103 | self._last_t = None 104 | 105 | def __call__(self, t): 106 | if self._last_t is not None and t=self._transition_end: 133 | tv = self._target_value 134 | self._start_value = None 135 | self._target_value = None 136 | self._transition_start = None 137 | self._transition_end = None 138 | return tv 139 | else: 140 | return self._start_value + (self._target_value-self._start_value)*(t-self._transition_start)/(self._transition_end-self._transition_start) 141 | else: 142 | last_value = self._last_value 143 | new_value = super().__call__(t) 144 | if new_value is self.no_value or new_value is None or new_value == last_value: 145 | return self.no_value 146 | elif last_value is self.no_value or last_value is None: 147 | return new_value 148 | else: 149 | self._start_value = last_value 150 | self._target_value = new_value 151 | self._transition_start = t 152 | self._transition_end = t + self._smoothing 153 | return last_value 154 | 155 | -------------------------------------------------------------------------------- /src/topsy/canvas/qt/recording.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import time 4 | 5 | from PySide6 import QtWidgets, QtCore 6 | 7 | from ...recorder import VisualizationRecorder 8 | 9 | from typing import TYPE_CHECKING 10 | 11 | if TYPE_CHECKING: 12 | from ...visualizer import Visualizer 13 | 14 | 15 | class RecordingSettingsDialog(QtWidgets.QDialog): 16 | def __init__(self, *args): 17 | super().__init__(*args) 18 | self.setWindowTitle("Recording settings") 19 | self._layout = QtWidgets.QVBoxLayout() 20 | self.setLayout(self._layout) 21 | 22 | # checkbox for smoothing: 23 | self._smooth_checkbox = QtWidgets.QCheckBox("Smooth timestream camera movements") 24 | self._smooth_checkbox.setChecked(True) 25 | self._layout.addWidget(self._smooth_checkbox) 26 | 27 | # leave some space: 28 | self._layout.addSpacing(10) 29 | 30 | # checkbox for including vmin/vmax: 31 | self._vmin_vmax_checkbox = QtWidgets.QCheckBox("Set vmin/vmax from timestream") 32 | self._vmin_vmax_checkbox.setChecked(True) 33 | self._layout.addWidget(self._vmin_vmax_checkbox) 34 | 35 | # checkbox for changing quantity: 36 | self._quantity_checkbox = QtWidgets.QCheckBox("Set quantity from timestream") 37 | self._quantity_checkbox.setChecked(True) 38 | self._layout.addWidget(self._quantity_checkbox) 39 | 40 | self._layout.addSpacing(10) 41 | 42 | # checkbox for showing colorbar: 43 | self._colorbar_checkbox = QtWidgets.QCheckBox("Show colorbar") 44 | self._colorbar_checkbox.setChecked(True) 45 | self._layout.addWidget(self._colorbar_checkbox) 46 | 47 | # checkbox for showing scalebar: 48 | self._scalebar_checkbox = QtWidgets.QCheckBox("Show scalebar") 49 | self._scalebar_checkbox.setChecked(True) 50 | self._layout.addWidget(self._scalebar_checkbox) 51 | 52 | self._layout.addSpacing(10) 53 | 54 | 55 | # select resolution from dropdown, with options half HD, HD, 4K 56 | self._resolution_dropdown = QtWidgets.QComboBox() 57 | self._resolution_dropdown.addItems(["Half HD (960x540)", "HD (1920x1080)", "4K (3840x2160)"]) 58 | self._resolution_dropdown.setCurrentIndex(1) 59 | 60 | # select fps from dropdown, with options 24, 30, 60 61 | self._fps_dropdown = QtWidgets.QComboBox() 62 | self._fps_dropdown.addItems(["24 fps", "30 fps", "60 fps"]) 63 | self._fps_dropdown.setCurrentIndex(1) 64 | 65 | # put resolution/fps next to each other horizontally: 66 | self._resolution_fps_layout = QtWidgets.QHBoxLayout() 67 | self._resolution_fps_layout.addWidget(self._resolution_dropdown) 68 | self._resolution_fps_layout.addWidget(self._fps_dropdown) 69 | self._layout.addLayout(self._resolution_fps_layout) 70 | 71 | self._layout.addSpacing(10) 72 | 73 | # cancel and save.. buttons: 74 | self._cancel_save_layout = QtWidgets.QHBoxLayout() 75 | self._cancel_button = QtWidgets.QPushButton("Cancel") 76 | self._cancel_button.clicked.connect(self.reject) 77 | self._save_button = QtWidgets.QPushButton("Save") 78 | # save button should be default: 79 | self._save_button.setDefault(True) 80 | self._save_button.clicked.connect(self.accept) 81 | self._cancel_save_layout.addWidget(self._cancel_button) 82 | self._cancel_save_layout.addWidget(self._save_button) 83 | self._layout.addLayout(self._cancel_save_layout) 84 | 85 | # show as a sheet on macos: 86 | #self.setWindowModality(QtCore.Qt.WindowModality.WindowModal) 87 | self.setWindowFlags(QtCore.Qt.WindowType.Sheet) 88 | 89 | @property 90 | def fps(self): 91 | return float(self._fps_dropdown.currentText().split()[0]) 92 | 93 | @property 94 | def resolution(self): 95 | import re 96 | # use regexp 97 | # e.g. the string 'blah (123x456)' should map to tuple (123,456) 98 | match = re.match(r".*\((\d+)x(\d+)\)", self._resolution_dropdown.currentText()) 99 | return int(match.group(1)), int(match.group(2)) 100 | 101 | @property 102 | def smooth(self): 103 | return self._smooth_checkbox.isChecked() 104 | 105 | @property 106 | def set_vmin_vmax(self): 107 | return self._vmin_vmax_checkbox.isChecked() 108 | 109 | @property 110 | def set_quantity(self): 111 | return self._quantity_checkbox.isChecked() 112 | 113 | @property 114 | def show_colorbar(self): 115 | return self._colorbar_checkbox.isChecked() 116 | 117 | @property 118 | def show_scalebar(self): 119 | return self._scalebar_checkbox.isChecked() 120 | 121 | 122 | class VisualizationRecorderWithQtProgressbar(VisualizationRecorder): 123 | 124 | def __init__(self, visualizer: Visualizer, parent_widget: QtWidgets.QWidget): 125 | super().__init__(visualizer) 126 | self._parent_widget = parent_widget 127 | 128 | def _progress_iterator(self, ntot): 129 | progress_bar = QtWidgets.QProgressDialog("Rendering to mp4...", "Stop", 0, ntot, self._parent_widget) 130 | progress_bar.setWindowModality(QtCore.Qt.WindowModality.WindowModal) 131 | progress_bar.forceShow() 132 | 133 | last_update = 0 134 | 135 | loop = QtCore.QEventLoop() 136 | 137 | try: 138 | for i in range(ntot): 139 | # updating the progress bar triggers a render in the main window, which 140 | # in turn is quite slow (because it can trigger software rendering 141 | # of resizable elements like the colorbar). So only update every half second or so. 142 | if time.time() - last_update > 0.5: 143 | last_update = time.time() 144 | progress_bar.setValue(i) 145 | 146 | with self._visualizer.prevent_sph_rendering(): 147 | loop.processEvents(QtCore.QEventLoop.ProcessEventsFlag.AllEvents) 148 | 149 | if progress_bar.wasCanceled(): 150 | break 151 | yield i 152 | 153 | finally: 154 | progress_bar.close() 155 | -------------------------------------------------------------------------------- /src/topsy/view_synchronizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import weakref 4 | 5 | from .drawreason import DrawReason 6 | 7 | from typing import TYPE_CHECKING, Optional, Callable 8 | if TYPE_CHECKING: 9 | from .visualizer import Visualizer 10 | 11 | 12 | class ViewSynchronizer: 13 | """A class to manage reflecting rotation/scale/offset changes from one view to another. 14 | 15 | Best used by creating two visualizers, visualizer1 and visualizer2, and then calling 16 | visualizer1.synchronize_with(visualizer2) 17 | 18 | The two visualizers will then be kept in sync 19 | """ 20 | def __init__(self, synchronize=['rotation_matrix', 'scale', 'position_offset']): 21 | self._views : list[weakref.ReferenceType[Visualizer]] = [] 22 | self._requires_update : list[weakref.ReferenceType[Visualizer]] = [] 23 | self._synchronize = synchronize 24 | self._setters = {} 25 | self._getters = {} 26 | 27 | 28 | @staticmethod 29 | def _default_getter(source, var): 30 | """Default getter for the variables to synchronize""" 31 | # If the variable is a dot-separated path, we return the nested attributes 32 | path = var.split('.') 33 | value = source 34 | for p in path: 35 | if '[' in p: 36 | # parse array-like access, e.g. name[key] should get attr name and use key to index it 37 | attr, key = p.split('[') 38 | if key.endswith(']'): 39 | key = key[:-1] 40 | value = getattr(value, attr)[key] 41 | else: 42 | value = getattr(value, p) 43 | return value 44 | 45 | @staticmethod 46 | def _default_setter(source, var, value): 47 | """Default setter for the variables to synchronize""" 48 | # If the variable is a dot-separated path, we set the nested attributes 49 | path = var.split('.') 50 | target = source 51 | for p in path[:-1]: 52 | target = getattr(target, p) 53 | p = path[-1] 54 | if '[' in p: 55 | # parse array-like access, e.g. name[key] should get attr name and use key to index it 56 | attr, key = p.split('[') 57 | if key.endswith(']'): 58 | key = key[:-1] 59 | target = getattr(target, attr) 60 | target[key] = value 61 | else: 62 | # normal attribute assignment 63 | setattr(target, path[-1], value) 64 | 65 | def perpetuate_update(self, source): 66 | """Called when a view has been updated and the update needs to be perpetuated to other views. 67 | 68 | Note that there is built-in protection against infinite loops. If view A calls this method 69 | which causes view B to update, if view B calls this method nothing will be issued back to view A.""" 70 | sources_needing_update = [view_weakref() for view_weakref in self._requires_update] 71 | if source in sources_needing_update: 72 | # OK the update has happened! Great, but don't broadcast it again 73 | del self._requires_update[sources_needing_update.index(source)] 74 | return 75 | 76 | getter = self._getters[id(source)] 77 | 78 | for view_weakref in self._views: 79 | view = view_weakref() 80 | setter = self._setters[id(view)] 81 | if (view is not source and view is not None) and (view_weakref not in self._requires_update): 82 | self._requires_update.append(view_weakref) 83 | for var in self._synchronize: 84 | setter(view, var, getter(source, var)) 85 | 86 | def update_completed(self, view: Visualizer): 87 | """Called when a view knows it will not be attempting to perpetuate an update it received 88 | 89 | See note about infinite loops in perpetuate_update. This method is used when a view will 90 | not be acting on the update it received, so it must be removed from the exclusion list 91 | that perpetuate_update maintains.""" 92 | sources_needing_update = [view_weakref() for view_weakref in self._requires_update] 93 | if view in sources_needing_update: 94 | del self._requires_update[sources_needing_update.index(view)] 95 | 96 | def add_view(self, view: Visualizer, setter: Optional[Callable] = None, getter: Optional[Callable] = None): 97 | self._views.append(weakref.ref(view)) 98 | view._view_synchronizer = self 99 | self._setters[id(view)] = setter or self._default_setter 100 | self._getters[id(view)] = getter or self._default_getter 101 | 102 | def remove_view(self, view: Visualizer): 103 | self._views.remove(weakref.ref(view)) 104 | del view._view_synchronizer 105 | del self._setters[id(view)] 106 | del self._getters[id(view)] 107 | 108 | class SynchronizationMixin: 109 | """Mixin class for Visualizer to allow it to synchronize with other views""" 110 | def draw(self, reason, render_texture_view=None): 111 | super().draw(reason, render_texture_view) 112 | if hasattr(self, "_view_synchronizer") and reason not in (DrawReason.REFINE, DrawReason.PRESENTATION_CHANGE): 113 | self._view_synchronizer.perpetuate_update(self) 114 | 115 | def synchronize_with(self, other: Visualizer): 116 | """Start synchronizing this visualizer with another""" 117 | if hasattr(self, "_view_synchronizer") and hasattr(other, "_view_synchronizer"): 118 | raise RuntimeError("Both these visualizers are already synchronizing with others") 119 | 120 | if hasattr(self, "_view_synchronizer"): 121 | self._view_synchronizer.add_view(other) 122 | elif hasattr(other, "_view_synchronizer"): 123 | other._view_synchronizer.add_view(self) 124 | else: 125 | vs = ViewSynchronizer() 126 | vs.add_view(self) 127 | vs.add_view(other) 128 | 129 | def stop_synchronizing(self): 130 | """Stop synchronizing this visualizer with any other""" 131 | if hasattr(self, "_view_synchronizer"): 132 | self._view_synchronizer.remove_view(self) 133 | 134 | def is_synchronizing(self): 135 | return hasattr(self, "_view_synchronizer") 136 | -------------------------------------------------------------------------------- /src/topsy/particle_buffers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import wgpu 4 | 5 | from . import loader, split_buffers 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | _UNSET = object() 10 | 11 | class ParticleBuffers: 12 | def __init__(self, loader: loader.AbstractDataLoader, device: wgpu.GPUDevice, max_draw_calls_per_buffer: int): 13 | self.buffers = {} 14 | self._split_buffers = split_buffers.SplitBuffers(len(loader)) 15 | self._device = device 16 | self._loader = loader 17 | 18 | self.quantity_name = None 19 | self._mass_and_quantity_buffers = None 20 | self._quantity_buffer_is_for_name = _UNSET # can't use None here because None is valid (means 'density render') 21 | self._current_vertex_buffers = [] 22 | 23 | self._create_indirect_draw_buffers(max_draw_calls_per_buffer) 24 | 25 | self._last_bufnum = -1 26 | 27 | def _create_indirect_draw_buffers(self, max_draw_calls_per_buffer: int): 28 | self._indirect_buffers = [] # for indirect draw calls, one needed per physical buffer 29 | self._indirect_count_buffers = [] 30 | self._indirect_buffers_npy = [] 31 | self._indirect_count_buffers_npy = [] 32 | self._max_draw_calls_per_buffer = max_draw_calls_per_buffer 33 | 34 | for i in range(self._split_buffers.num_buffers): 35 | self._indirect_buffers.append( 36 | self._device.create_buffer(size=max_draw_calls_per_buffer * 4 * np.dtype(np.uint32).itemsize, 37 | usage = wgpu.BufferUsage.INDIRECT | wgpu.BufferUsage.COPY_DST) 38 | ) 39 | #self._indirect_count_buffers.append( 40 | # self._device.create_buffer(size=np.dtype(np.uint32).itemsize, 41 | # usage=wgpu.BufferUsage.INDIRECT | wgpu.BufferUsage.COPY_DST) 42 | #) 43 | self._indirect_buffers_npy.append(np.zeros((max_draw_calls_per_buffer, 4), dtype=np.uint32)) 44 | #self._indirect_count_buffers_npy.append(np.zeros((1,), dtype=np.uint32)) 45 | self._indirect_buffers_npy[-1][:, 0] = 6 # vertex count 46 | 47 | 48 | def specify_vertex_buffer_assignment(self, buffer_names): 49 | buffers = [] 50 | for name in buffer_names: 51 | match name: 52 | case "pos_smooth": 53 | buffers.append(self.get_pos_smooth_buffers()) 54 | case "mass_and_quantity": 55 | buffers.append(self.get_mass_and_quantity_buffers()) 56 | case "rgb": 57 | buffers.append(self.get_rgb_buffers()) 58 | case _: 59 | raise ValueError(f"Unknown buffer name: {name}") 60 | self._current_vertex_buffers = buffers 61 | self._last_bufnum = -1 62 | 63 | def set_vertex_buffers(self, bufnum: int, render_pass: wgpu.GPURenderPassEncoder): 64 | if bufnum == self._last_bufnum: 65 | return 66 | for i, buffers in enumerate(self._current_vertex_buffers): 67 | render_pass.set_vertex_buffer(i, buffers[bufnum]) 68 | self._last_bufnum = bufnum 69 | 70 | def issue_draw_indirect(self, sph_render_pass: wgpu.GPURenderPassEncoder): 71 | 72 | for bufnum in range(self._split_buffers.num_buffers): 73 | self.set_vertex_buffers(bufnum, sph_render_pass) 74 | sph_render_pass._multi_draw_indirect(self._indirect_buffers[bufnum], 0, self._max_draw_calls_per_buffer) 75 | 76 | def update_particle_ranges(self, particle_mins: list[int], particle_lens: list[int]): 77 | per_buf_start_lens = self._split_buffers.global_to_split_monotonic(particle_mins, particle_lens) 78 | for bufnum, (particle_min, particle_len) in enumerate(per_buf_start_lens): 79 | self._indirect_buffers_npy[bufnum][len(particle_min):,1] = 0 80 | self._indirect_buffers_npy[bufnum][:len(particle_min), 1] = particle_len # instance count 81 | self._indirect_buffers_npy[bufnum][:len(particle_min), 3] = particle_min # first instance 82 | self._device.queue.write_buffer(self._indirect_buffers[bufnum], 0, self._indirect_buffers_npy[bufnum]) 83 | 84 | def get_pos_smooth_buffers(self): 85 | if not hasattr(self, "_pos_smooth_buffers"): 86 | logger.info("Creating position+smoothing buffer") 87 | data = self._loader.get_pos_smooth().astype(np.float32) 88 | self._pos_smooth_buffers = self._split_buffers.create_buffers(self._device, 4 * 4, 89 | wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST) 90 | self._split_buffers.write_buffers(self._device, self._pos_smooth_buffers, data) 91 | return self._pos_smooth_buffers 92 | 93 | def get_mass_and_quantity_buffers(self): 94 | if self._quantity_buffer_is_for_name != self.quantity_name: 95 | self._create_mass_and_quantity_buffers_if_needed() 96 | data = np.zeros((len(self._loader), 3), dtype=np.float32) 97 | data[:, 0] = self._loader.get_mass() 98 | if self.quantity_name is not None: 99 | data[:, 1] = self._loader.get_named_quantity(self.quantity_name) 100 | self._split_buffers.write_buffers(self._device, self._mass_and_quantity_buffers, data) 101 | self._quantity_buffer_is_for_name = self.quantity_name 102 | return self._mass_and_quantity_buffers 103 | 104 | def get_rgb_buffers(self): 105 | if not hasattr(self, "_rgb_masses_buffers"): 106 | logger.info("Creating rgb buffer") 107 | data = self._loader.get_rgb_masses().view(np.float32) 108 | self._rgb_masses_buffers = self._split_buffers.create_buffers(self._device, 4 * 3, 109 | wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST) 110 | self._split_buffers.write_buffers(self._device, self._rgb_masses_buffers, data) 111 | return self._rgb_masses_buffers 112 | 113 | def _create_mass_and_quantity_buffers_if_needed(self): 114 | if self._mass_and_quantity_buffers is not None: 115 | return 116 | logger.info("Creating quantity buffer") 117 | self._mass_and_quantity_buffers = self._split_buffers.create_buffers(self._device, 4 * 3, 118 | wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST) -------------------------------------------------------------------------------- /tests/test_smooth.py: -------------------------------------------------------------------------------- 1 | """Test the bilateral filtering smoothing operation in ColorAsSurfaceMap.""" 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | from pathlib import Path 6 | import topsy 7 | from topsy.canvas import offscreen 8 | from topsy.colormap.surface import ColorAsSurfaceMap 9 | 10 | 11 | def create_test_image(width=256, height=256): 12 | """Create a two-channel test image with noise, gradient, and discontinuity.""" 13 | 14 | np.random.seed(1337) 15 | 16 | # Create coordinate grids 17 | x = np.linspace(0, 1, width) 18 | y = np.linspace(0, 1, height) 19 | X, Y = np.meshgrid(x, y) 20 | 21 | # Initialize two-channel image 22 | test_image = np.zeros((height, width, 2), dtype=np.float32) 23 | 24 | # Channel 0: depth/height values 25 | # Add gradient 26 | gradient = X * 0.5 + Y * 0.3 27 | 28 | # Add discontinuity in the middle 29 | discontinuity = np.zeros_like(gradient) 30 | discontinuity[height//4:3*height//4, width//4:3*width//4] = 0.5 31 | 32 | # Add gaussian noise 33 | noise = np.random.normal(0, 0.05, (height, width)) 34 | 35 | test_image[:, :, 0] = gradient + discontinuity + noise 36 | 37 | # Channel 1: density/mass values 38 | # Similar structure but different values 39 | gradient2 = Y * 0.4 + X * 0.2 40 | discontinuity2 = np.zeros_like(gradient2) 41 | discontinuity2[height//3:2*height//3, width//3:2*width//3] = 0.3 42 | noise2 = np.random.normal(0, 0.03, (height, width)) 43 | 44 | test_image[:, :, 1] = gradient2 + discontinuity2 + noise2 45 | 46 | # Ensure all values are positive (typical for SPH data) 47 | test_image = np.abs(test_image) + 0.01 48 | 49 | return test_image 50 | 51 | 52 | def test_smoothing_operation(): 53 | """Test the smoothing operation using ColorAsSurfaceMap._smooth_numpy.""" 54 | # Create output folder 55 | folder = Path(__file__).parent / "output" 56 | folder.mkdir(exist_ok=True) 57 | 58 | test_image = create_test_image() 59 | 60 | vis = topsy.test(100, render_resolution=test_image.shape[0], canvas_class=offscreen.VisualizerCanvas) 61 | 62 | vis.colormap.update_parameters({ 63 | 'type': 'surface', 64 | 'smoothing_scale': 0.02, 65 | }) 66 | 67 | # Get the surface colormap instance 68 | surface_map = vis.colormap._impl 69 | 70 | smoothed_output = surface_map._smooth_numpy(test_image) 71 | 72 | # Save outputs 73 | np.save(folder / 'test_smooth_input.npy', test_image) 74 | np.save(folder / 'test_smooth_output.npy', smoothed_output) 75 | 76 | # no smoothing on channel 0 77 | npt.assert_allclose(test_image[..., 0], smoothed_output[..., 0], atol=1e-7) 78 | 79 | # channel 1 is smoothed but hard edges are still there 80 | expected_global_samples = [0.04350269, 0.03492163, 0.03985117, 0.06765869, 0.09567888, 81 | 0.08533357, 0.10505654, 0.10955958, 0.12479778, 0.14756411, 82 | 0.15113021, 0.193778 , 0.17523961, 0.03672419, 0.04862463, 83 | 0.06353201, 0.07946779, 0.10003348, 0.11344124, 0.1363642 , 84 | 0.14499053, 0.14822552, 0.18607135, 0.19857994, 0.21028055, 85 | 0.2280225 , 0.04706344, 0.08421917, 0.11555166, 0.10795519, 86 | 0.13088886, 0.14922458, 0.16380104, 0.18379168, 0.21592017, 87 | 0.21860206, 0.22268587, 0.26771224, 0.25449008, 0.11944734, 88 | 0.10907468, 0.14002527, 0.14063616, 0.16814029, 0.199773 , 89 | 0.1959068 , 0.21331303, 0.230805 , 0.2333972 , 0.25765777, 90 | 0.2804561 , 0.279846 , 0.12494649, 0.1561663 , 0.17562656, 91 | 0.18248414, 0.19247337, 0.21708317, 0.23262993, 0.24562259, 92 | 0.25853062, 0.28892142, 0.28311318, 0.29793793, 0.33640784, 93 | 0.17010446, 0.18930109, 0.21303204, 0.23668505, 0.2234434 , 94 | 0.53724605, 0.5470026 , 0.59243613, 0.58226967, 0.30523527, 95 | 0.32288015, 0.34913924, 0.3622239 , 0.19013162, 0.22471175, 96 | 0.23146637, 0.24505465, 0.24869259, 0.56404704, 0.577596 , 97 | 0.60139513, 0.63722885, 0.33896458, 0.3310112 , 0.36426947, 98 | 0.37940887, 0.24438018, 0.23389104, 0.27845004, 0.27066812, 99 | 0.30080408, 0.61517453, 0.6194309 , 0.6452768 , 0.65323645, 100 | 0.36640742, 0.4021577 , 0.39369363, 0.40901196, 0.2726592 , 101 | 0.27804396, 0.27932608, 0.327759 , 0.32077065, 0.6381103 , 102 | 0.65829104, 0.6525516 , 0.67309 , 0.40183058, 0.4239184 , 103 | 0.42731968, 0.4529117 , 0.29238752, 0.31701306, 0.33301896, 104 | 0.3448711 , 0.34038064, 0.37747476, 0.38808158, 0.39679116, 105 | 0.4076333 , 0.43214443, 0.46319935, 0.4567135 , 0.47022748, 106 | 0.3328091 , 0.34471515, 0.35241255, 0.38199195, 0.40559343, 107 | 0.3941699 , 0.41470632, 0.4342157 , 0.44876787, 0.45069033, 108 | 0.4712601 , 0.4940555 , 0.49560383, 0.3539791 , 0.35014838, 109 | 0.374198 , 0.41407102, 0.40864894, 0.4354274 , 0.45889753, 110 | 0.4607982 , 0.48360315, 0.49325395, 0.5230214 , 0.5123342 , 111 | 0.54918534, 0.3842421 , 0.40161213, 0.41294298, 0.4332912 , 112 | 0.4545948 , 0.47875527, 0.47641772, 0.4919833 , 0.52749985, 113 | 0.53773326, 0.53975433, 0.55892164, 0.56311625] 114 | 115 | global_check = smoothed_output[::20, ::20, 1].ravel() 116 | 117 | expected_edge_check = [0.19247337, 0.19723694, 0.1849934 , 0.19642891, 0.19979529, 118 | 0.1993649 , 0.20303836, 0.2173118 , 0.1889234 , 0.20993778, 119 | 0.19258004, 0.22300835, 0.20747848, 0.20263639, 0.2036718 , 120 | 0.20567518, 0.21924324, 0.20507486, 0.19026951, 0.20912749, 121 | 0.19608739, 0.20039833, 0.19389133, 0.19785273, 0.19580497, 122 | 0.20818928, 0.20516331, 0.20875177, 0.21691433, 0.18723641, 123 | 0.21353379, 0.19767466, 0.20573969, 0.1855796 , 0.19924074, 124 | 0.21442725, 0.1919996 , 0.17996098, 0.20208283, 0.21387406, 125 | 0.24663831, 0.20913196, 0.19462162, 0.21180561, 0.16858205, 126 | 0.21117128, 0.20315671, 0.20511323, 0.21663508, 0.20262529, 127 | 0.20380434, 0.19074719, 0.1645996 , 0.2216465 , 0.2202986 , 128 | 0.51710486, 0.5235424 , 0.51853245, 0.51737845, 0.5172093 , 129 | 0.20266144, 0.19300404, 0.19269636, 0.20673202, 0.20537308, 130 | 0.5252804 , 0.5193447 , 0.5337641 , 0.51419616, 0.5214026 , 131 | 0.20528564, 0.1887608 , 0.22220144, 0.20611644, 0.2162794 , 132 | 0.5282587 , 0.5235715 , 0.5250429 , 0.532619 , 0.53551286, 133 | 0.20328672, 0.20438206, 0.20458573, 0.2203121 , 0.22026433, 134 | 0.5077367 , 0.5264619 , 0.52011055, 0.5161084 , 0.5056762 , 135 | 0.2132591 , 0.21822827, 0.19445635, 0.21045099, 0.22532488, 136 | 0.5191735 , 0.530404 , 0.52163655, 0.5298376 , 0.5205669 ] 137 | 138 | edge_check = smoothed_output[80:90, 80:90, 1].ravel() 139 | 140 | npt.assert_allclose(global_check, expected_global_samples, atol=1e-6) 141 | npt.assert_allclose(edge_check, expected_edge_check, atol=1e-6) 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | test_smoothing_operation() -------------------------------------------------------------------------------- /src/topsy/colormap/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import wgpu 3 | 4 | from . import implementation, surface 5 | 6 | from .implementation import ColormapBase, NoColormap, Colormap, RGBColormap, RGBHDRColormap, BivariateColormap 7 | from .ui import ColorMapController, BivariateColorMapController, RGBMapController, GenericController, SurfaceMapController 8 | from .. import config 9 | 10 | from typing import Iterator, Optional 11 | 12 | class ColormapHolder: 13 | """ 14 | A class to hold and update the color map for a visualizer. 15 | 16 | This is necessary because the color map may change during the lifetime of a visualizer, and 17 | the logic for updating the color map is encapsulated here rather than in the visualizer itself. 18 | """ 19 | 20 | def __init__(self, device: wgpu.GPUDevice, input_texture: wgpu.GPUTexture, output_format: wgpu.TextureFormat): 21 | self._device = device 22 | self._input_texture = input_texture 23 | self._output_format = output_format 24 | self._impl: ColormapBase = self.instance_from_parameters( 25 | { 26 | 'colormap_name': config.DEFAULT_COLORMAP, 27 | 'vmin': None, 28 | 'vmax': None, 29 | 'log': False, 30 | 'type': 'none', 31 | }, device, input_texture, output_format, 32 | ) 33 | 34 | def _check_valid(self): 35 | if self._impl is None or isinstance(self._impl, NoColormap): 36 | raise ValueError("ColormapHolder is not fully initialized") 37 | 38 | @classmethod 39 | def _iter_classes(cls, base_class=ColormapBase) -> Iterator[ColormapBase] : 40 | """ 41 | Iterate over all subclasses of ColormapBase that match the given parameters. 42 | """ 43 | for subclass in base_class.__subclasses__(): 44 | yield subclass 45 | yield from cls._iter_classes(subclass) 46 | 47 | 48 | @classmethod 49 | def _class_from_parameters(cls, parameters) -> Optional[type[ColormapBase]]: 50 | for cl in cls._iter_classes(): 51 | if cl.accepts_parameters(parameters): 52 | return cl 53 | 54 | return None 55 | 56 | @classmethod 57 | def instance_from_parameters(cls, parameters, device: wgpu.GPUDevice, input_texture: wgpu.GPUTexture, 58 | output_format: wgpu.TextureFormat) -> ColormapBase: 59 | colormap_class = cls._class_from_parameters(parameters) 60 | if colormap_class is None: 61 | raise ValueError(f"No colormap class found for parameters: {parameters}") 62 | return colormap_class(device, input_texture, output_format, parameters) 63 | 64 | def update_parameters(self, parameters: dict): 65 | """ 66 | Update the colormap parameters and recreate the colormap if necessary. 67 | 68 | Returns True if the colormap was recreated, False if it was updated in place. 69 | """ 70 | all_parameters = self.get_parameters() | parameters # merge with existing parameters 71 | if self._impl is None and self._class_from_parameters(all_parameters) is None: 72 | return # we are in an initialization phase and it's fine to have no colormap yet 73 | if self._impl is None or not self._impl.accepts_parameters(all_parameters): 74 | self._impl = self.instance_from_parameters(all_parameters, self._device, self._input_texture, 75 | self._output_format) 76 | return True 77 | else: 78 | # Update the existing colormap parameters, without passing back in already known parameters 79 | self._impl.update_parameters(parameters) 80 | return False 81 | 82 | def get_parameter(self, name: str): 83 | """ 84 | Get a parameter from the colormap. 85 | """ 86 | return self._impl.get_parameter(name) 87 | 88 | def get_parameters(self) -> dict: 89 | """ 90 | Get all parameters from the colormap. 91 | """ 92 | return self._impl.get_parameters() 93 | 94 | def autorange(self, sph_render_output: np.ndarray): 95 | """Update the colormap ranges based on the provided SPH render output.""" 96 | self._check_valid() 97 | self._impl.autorange_vmin_vmax(sph_render_output) 98 | 99 | def encode_render_pass(self, command_encoder, target_texture_view): 100 | """ 101 | Encode the render pass for the colormap. 102 | This will set up the necessary buffers and shaders for rendering the colormap. 103 | """ 104 | self._check_valid() 105 | 106 | self._impl.encode_render_pass(command_encoder, target_texture_view) 107 | 108 | def set_scaling(self, width, height, mass_scaling): 109 | """ 110 | Set the scaling for the colormap. 111 | """ 112 | self._check_valid() 113 | self._impl.set_scaling(width, height, mass_scaling) 114 | 115 | def sph_raw_output_to_image(self, sph_raw_output: np.ndarray) -> np.ndarray: 116 | """ 117 | Convert SPH raw output to an image using the colormap. 118 | """ 119 | self._check_valid() 120 | return self._impl.sph_raw_output_to_image(sph_raw_output) 121 | 122 | def sph_raw_output_to_content(self, sph_raw_output: np.ndarray) -> np.ndarray: 123 | """ 124 | Convert SPH raw output to the logical content represented by the colormap. 125 | 126 | This is typically used for debugging or analysis purposes. 127 | """ 128 | self._check_valid() 129 | return self._impl.sph_raw_output_to_content(sph_raw_output) 130 | 131 | def make_ui_controller(self, visualizer, refresh_ui_callback: Optional[callable] = None) -> GenericController: 132 | """ 133 | Make a UI controller for the currently instantiated colormap. 134 | 135 | This is used to interact with the colormap in a user interface. The controller is an abstract 136 | description of the UI elements and their behavior, allowing for different implementations 137 | (specifically Qt or Jupyter) to render the UI accordingly. 138 | """ 139 | self._check_valid() 140 | if isinstance(self._impl, BivariateColormap): 141 | return BivariateColorMapController(visualizer, refresh_ui_callback) 142 | elif isinstance(self._impl, RGBColormap): 143 | return RGBMapController(visualizer, refresh_ui_callback) 144 | elif isinstance(self._impl, surface.ColorAsSurfaceMap): 145 | return SurfaceMapController(visualizer, refresh_ui_callback) 146 | else: 147 | return ColorMapController(visualizer, refresh_ui_callback) 148 | 149 | def __getitem__(self, key: str): 150 | """ 151 | Allow dictionary-like access to colormap parameters. 152 | """ 153 | return self.get_parameter(key) 154 | 155 | def __setitem__(self, key: str, value): 156 | """ 157 | Allow dictionary-like setting of colormap parameters. 158 | """ 159 | self.update_parameters({key: value}) -------------------------------------------------------------------------------- /src/topsy/canvas/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import rendercanvas.jupyter, rendercanvas.auto 5 | import time 6 | import copy 7 | 8 | from .. import config 9 | 10 | from typing import TYPE_CHECKING 11 | if TYPE_CHECKING: 12 | from ..visualizer import Visualizer 13 | 14 | 15 | 16 | class VisualizerCanvasBase: 17 | def __init__(self, *args, **kwargs): 18 | self._visualizer : Visualizer = kwargs.pop("visualizer") 19 | 20 | self._last_x = 0 21 | self._last_y = 0 22 | # The below are dummy values that will be updated by the initial resize event 23 | self.width_physical, self.height_physical = 640,480 24 | self.pixel_ratio = 1 25 | 26 | super().__init__(*args, **kwargs) 27 | 28 | self.add_event_handler(self.event_handler, "*") 29 | 30 | def event_handler(self, event): 31 | if event['event_type']=='pointer_move': 32 | if len(event['buttons'])>0: 33 | if len(event['modifiers'])==0: 34 | self.drag(event['x']-self._last_x, event['y']-self._last_y) 35 | else: 36 | self.shift_drag(event['x']-self._last_x, event['y']-self._last_y) 37 | self._last_x = event['x'] 38 | self._last_y = event['y'] 39 | elif event['event_type']=='wheel': 40 | self.mouse_wheel(event['dx'], event['dy']) 41 | elif event['event_type']=='key_up': 42 | self.key_up(event['key']) 43 | elif event['event_type']=='resize': 44 | self.resize_complete(event['width'], event['height'], event['pixel_ratio']) 45 | elif event['event_type']=='double_click': 46 | self.double_click(event['x'], event['y']) 47 | elif event['event_type']=='pointer_up': 48 | self.release_drag() 49 | else: 50 | pass 51 | 52 | 53 | def drag(self, dx, dy): 54 | self._visualizer.rotate(dx*0.01, dy*0.01) 55 | 56 | def shift_drag(self, dx, dy): 57 | biggest_dimension = max(self.width_physical, self.height_physical) 58 | 59 | displacement = 2.*self.pixel_ratio*np.array([dx, -dy, 0], dtype=np.float32) / biggest_dimension * self._visualizer.scale 60 | self._visualizer.position_offset += self._visualizer.rotation_matrix.T @ displacement 61 | 62 | self._visualizer.display_status("centre = [{:.2f}, {:.2f}, {:.2f}]".format(*self._visualizer._sph.position_offset)) 63 | 64 | self._visualizer.crosshairs_visible = True 65 | 66 | 67 | def key_up(self, key): 68 | if key=='s': 69 | self._visualizer.save() 70 | elif key=='r': 71 | self._visualizer.colormap_autorange() 72 | elif key=='h': 73 | self._visualizer.reset_view() 74 | elif key=='w': 75 | offset = self._visualizer.position_offset 76 | rotation_matrix = self._visualizer.rotation_matrix 77 | offset_string = np.array2string(offset, separator=",") 78 | rotation_matrix_string = np.array2string(rotation_matrix, separator=",") 79 | print(f".translate({offset_string}).transform(np.array({rotation_matrix_string}))") 80 | 81 | def mouse_wheel(self, delta_x, delta_y): 82 | if isinstance(self, rendercanvas.jupyter.JupyterRenderCanvas): 83 | # scroll events are much smaller from the web browser, for 84 | # some reason, compared with native windowing 85 | delta_y *= 10 86 | delta_x *= 10 87 | 88 | self._visualizer.scale*=np.exp(delta_y/1000) 89 | 90 | def release_drag(self): 91 | if self._visualizer.crosshairs_visible: 92 | self._visualizer.crosshairs_visible = False 93 | self._visualizer.invalidate() 94 | 95 | 96 | def resize(self, *args): 97 | # putting this here as a reminder that the resize method must be passed to the base class 98 | super().resize(*args) 99 | 100 | def resize_complete(self, width, height, pixel_ratio=1): 101 | self.width_physical = int(width*pixel_ratio) 102 | self.height_physical = int(height*pixel_ratio) 103 | self.pixel_ratio = pixel_ratio 104 | 105 | def double_click(self, x, y): 106 | original_position = copy.copy(self._visualizer.position_offset) 107 | 108 | biggest_dimension = max(self.width_physical, self.height_physical) 109 | 110 | 111 | centre_physical_x = self.width_physical / (2*self.pixel_ratio) 112 | centre_physical_y = self.height_physical / (2*self.pixel_ratio) 113 | 114 | xy_displacement = 2. * self.pixel_ratio * np.array([centre_physical_x-x, 115 | y-centre_physical_y, 116 | 0], dtype=np.float32) / biggest_dimension * self._visualizer.scale 117 | 118 | 119 | self._visualizer.position_offset += self._visualizer.rotation_matrix.T @ xy_displacement 120 | 121 | 122 | depth_im = self._visualizer.get_depth_image() 123 | central_depth = depth_im[depth_im.shape[0]//2, depth_im.shape[1]//2] 124 | 125 | if ~np.isnan(central_depth): 126 | z_displacement = np.array([0, 0, -central_depth], dtype=np.float32) 127 | self._visualizer.position_offset += self._visualizer.rotation_matrix.T @ z_displacement 128 | 129 | final_position = self._visualizer.position_offset 130 | 131 | # the actual work is done - now animate it so it looks understandable 132 | self._visualizer.position_offset = original_position 133 | 134 | #def interpolate_position(t): 135 | # return original_position + (final_position - original_position) * t 136 | 137 | def interpolate_position(t): 138 | w1 = np.arctan(5*(t*2-1))/np.pi+0.5 139 | w2 = 1-w1 140 | return w2 * original_position + w1 * final_position 141 | 142 | start = time.time() 143 | 144 | def glide(): 145 | t = (time.time()-start)/config.GLIDE_TIME 146 | if t>1: 147 | self._visualizer.position_offset = final_position 148 | else: 149 | self.call_later(0.0, glide) 150 | self._visualizer.position_offset = interpolate_position(t) 151 | 152 | 153 | 154 | self.call_later(1. / config.TARGET_FPS, glide) 155 | 156 | 157 | @classmethod 158 | def call_later(cls, delay, fn, *args): 159 | raise NotImplementedError() 160 | 161 | 162 | 163 | 164 | 165 | # Now we are going to select a specific backend 166 | # 167 | # we don't use rendercanvas.auto directly because it prefers the glfw backend over qt 168 | # whereas we want to use qt 169 | # 170 | # Note also that is_jupyter as implemented fails to distinguish correctly if we are 171 | # running inside a kernel that isn't attached to a notebook. There doesn't seem to 172 | # be any way to distinguish this, so we live with it for now. 173 | 174 | 175 | from .. import is_jupyter 176 | 177 | if is_jupyter(): 178 | from .jupyter import VisualizerCanvas 179 | else: 180 | from .qt import VisualizerCanvas 181 | 182 | -------------------------------------------------------------------------------- /src/topsy/split_buffers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import wgpu 4 | 5 | from . import config, performance 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class SplitBuffers: 10 | """Manages splitting buffers into smaller buffers for GPU hardware. 11 | 12 | This is needed because some/most GPUs have a limit on the size of buffers which is less than the actual total 13 | RAM available. 14 | 15 | We adopt terminology of a 'global' address (in terms of particles) which can then be mapped onto a 'split' 16 | address which is a tuple of (buffer number, buffer offset).""" 17 | 18 | def __init__(self, num_particles: int, max_particles_per_buffer: int | None = None): 19 | if max_particles_per_buffer is None: 20 | max_particles_per_buffer = config.MAX_PARTICLES_PER_BUFFER 21 | 22 | self._num_particles = num_particles 23 | self._max_particles_per_buffer = max_particles_per_buffer 24 | self._calculate_splits() 25 | 26 | def _calculate_splits(self): 27 | if self._num_particles > self._max_particles_per_buffer: 28 | num_buffers = int(np.ceil(self._num_particles / self._max_particles_per_buffer)) 29 | else: 30 | num_buffers = 1 31 | 32 | self._num_buffers = num_buffers 33 | self._buffer_particle_sizes = np.empty(num_buffers, dtype=np.intp) 34 | self._buffer_particle_sizes.fill(self._max_particles_per_buffer) 35 | self._buffer_particle_sizes[-1] = self._num_particles - (len(self._buffer_particle_sizes)-1) * self._max_particles_per_buffer 36 | 37 | self._buffer_particle_starts = np.cumsum(self._buffer_particle_sizes) - self._buffer_particle_sizes 38 | logger.info(f"Splitting {self._num_particles} particles into {self._num_buffers} buffer(s)") 39 | 40 | 41 | def _global_to_split_address(self, address: int) -> (int, int): 42 | """Given a logical buffer particle offset, returns the physical buffer number and address""" 43 | bufnum = np.searchsorted(self._buffer_particle_starts, address, side='right')-1 44 | return bufnum, address - self._buffer_particle_starts[bufnum] 45 | 46 | @property 47 | def num_buffers(self) -> int: 48 | """Number of buffers per global buffer""" 49 | return self._num_buffers 50 | 51 | def global_to_split(self, start: int, length: int) -> tuple[list[int], list[int], list[int]]: 52 | """Map global start and length to split buffer numbers, starts and lengths.""" 53 | bufs = [] 54 | starts = [] 55 | lengths = [] 56 | 57 | global_start = start 58 | global_length_remaining = length 59 | bufnum, local_start = self._global_to_split_address(global_start) 60 | 61 | while global_length_remaining>0 and bufnum 0: 74 | raise ValueError(f"Requested length {length} starting at {start} exceeds available buffers") 75 | 76 | return bufs, starts, lengths 77 | 78 | def global_to_split_monotonic(self, start: list[int], length: list[int]) -> list[tuple[list[int], list[int]]]: 79 | """Map global start and length to starts and lengths for each buffer. Addressing must be monotonically increasing.""" 80 | performance.signposter.emit_event("global_to_split_monotonic") 81 | starts = [] 82 | lengths = [] 83 | 84 | cur_buf = 0 85 | cur_buf_start = 0 86 | cur_buf_end = self._buffer_particle_sizes[cur_buf] 87 | 88 | all_buf_results = [(starts, lengths)] 89 | 90 | for global_start, global_length in zip(start, length): 91 | while global_length > 0: 92 | while global_start >= cur_buf_end: 93 | # move to next buffer 94 | cur_buf += 1 95 | if cur_buf>=self._num_buffers: 96 | raise ValueError(f"Requested length {global_length} starting at {global_start} exceeds available buffers") 97 | cur_buf_start = self._buffer_particle_starts[cur_buf] 98 | cur_buf_end = cur_buf_start + self._buffer_particle_sizes[cur_buf] 99 | starts = [] 100 | lengths = [] 101 | all_buf_results.append((starts, lengths)) 102 | 103 | this_buf_start = global_start - cur_buf_start 104 | this_buf_length = min(global_length, cur_buf_end - global_start) 105 | starts.append(this_buf_start) 106 | lengths.append(this_buf_length) 107 | global_length -= this_buf_length 108 | global_start += this_buf_length 109 | 110 | if cur_buf < self._num_buffers-1: 111 | for bufnum in range(cur_buf+1, self._num_buffers): 112 | all_buf_results.append(([], [])) 113 | 114 | performance.signposter.emit_event("end global_to_split_monotonic") 115 | 116 | return all_buf_results 117 | 118 | 119 | def create_buffers(self, wgpu_device: wgpu.GPUDevice, item_size: int, usage: wgpu.BufferUsage) -> list[wgpu.GPUBuffer]: 120 | """Create a set of split buffers 121 | 122 | Parameters 123 | ---------- 124 | wgpu_device : wgpu.GPUDevice 125 | The GPU device to create the buffer on. 126 | item_size : int 127 | The size of each item in the buffer. 128 | usage : int 129 | The usage flags for the buffer. 130 | 131 | """ 132 | buffers = [] 133 | for this_size in self._buffer_particle_sizes: 134 | size = this_size * item_size 135 | buffer = wgpu_device.create_buffer( 136 | size=size, 137 | usage=usage, 138 | ) 139 | buffers.append(buffer) 140 | return buffers 141 | 142 | def write_buffers(self, wgpu_device: wgpu.GPUDevice, buffers: list[wgpu.GPUBuffer], data: np.ndarray) -> None: 143 | """Write data to the split buffers. 144 | 145 | Parameters 146 | ---------- 147 | wgpu_device : wgpu.GPUDevice 148 | The GPU device to write the buffer on. 149 | buffers : list[wgpu.GPUBuffer] 150 | The buffers to write to. 151 | data : np.ndarray 152 | The data to write. 153 | """ 154 | if len(buffers) != self._num_buffers: 155 | raise ValueError(f"Number of buffers {len(buffers)} does not match number of split buffers {self._num_buffers}") 156 | if len(data) != self._num_particles: 157 | raise ValueError(f"Data size {len(data)} does not match number of particles {self._num_particles}") 158 | 159 | for bufnum, buf in enumerate(buffers): 160 | start = self._buffer_particle_starts[bufnum] 161 | length = self._buffer_particle_sizes[bufnum] 162 | wgpu_device.queue.write_buffer(buf, 0, data[start:start+length]) -------------------------------------------------------------------------------- /src/topsy/recorder/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import copy 4 | import time 5 | import tqdm 6 | import wgpu 7 | import numpy as np 8 | import logging 9 | import pickle 10 | 11 | from . interpolator import (Interpolator, StepInterpolator, LinearInterpolator, RotationInterpolator, 12 | SmoothedRotationInterpolator, SmoothedLinearInterpolator, SmoothedStepInterpolator) 13 | from ..drawreason import DrawReason 14 | from ..view_synchronizer import ViewSynchronizer 15 | 16 | from typing import TYPE_CHECKING 17 | 18 | if TYPE_CHECKING: 19 | from ..visualizer import Visualizer 20 | 21 | logger = logging.getLogger(__name__) 22 | logger.setLevel(logging.INFO) 23 | 24 | 25 | 26 | class VisualizationRecorder: 27 | _record_properties = ['colormap[type]', 'quantity_name', 'colormap[log]', 'colormap[vmin]', 'colormap[vmax]', 'colormap[gamma]', # NB ordering is important to prevent triggering auto-scaling 28 | 'colormap[density_vmin]', 'colormap[density_vmax]', 'rotation_matrix', 'scale', 'position_offset'] 29 | _record_interpolation_class_smoothed = [StepInterpolator, StepInterpolator, StepInterpolator, SmoothedStepInterpolator, SmoothedStepInterpolator, 30 | SmoothedStepInterpolator, SmoothedStepInterpolator, SmoothedStepInterpolator, SmoothedRotationInterpolator, SmoothedLinearInterpolator, SmoothedLinearInterpolator] 31 | _record_interpolation_class_unsmoothed = [StepInterpolator, StepInterpolator, StepInterpolator, StepInterpolator, StepInterpolator, 32 | StepInterpolator, StepInterpolator, StepInterpolator, RotationInterpolator, LinearInterpolator, LinearInterpolator] 33 | 34 | 35 | def __init__(self, visualizer: Visualizer): 36 | vs = ViewSynchronizer(synchronize=self._record_properties) 37 | vs.add_view(visualizer) 38 | vs.add_view(self, setter = VisualizationRecorder._add_event) 39 | self._recording = False 40 | self._playback = False 41 | self._recording_ends_at = None 42 | self._visualizer = visualizer 43 | self._reset_timestream() 44 | 45 | def _add_event(self, key, value): 46 | if key in self._record_properties: 47 | self._view_synchronizer.update_completed(self) # this marks the update as done 48 | if self._recording: 49 | self._timestream[key].append((self._time_elapsed(), copy.copy(value))) 50 | 51 | def _time_elapsed(self): 52 | return time.time() - self._t0 53 | 54 | def _reset_timestream(self): 55 | self._timestream = {r: [(0.0, copy.copy( 56 | self._view_synchronizer._default_getter(self._visualizer, r)))] for r in self._record_properties} 57 | 58 | def record(self): 59 | self._t0 = time.time() 60 | self._reset_timestream() 61 | self._recording = True 62 | self._playback = False 63 | 64 | def stop(self): 65 | if self._recording: 66 | self._recording_ends_at = self._time_elapsed() 67 | self._recording = False 68 | self._playback = False 69 | 70 | 71 | def _get_value_at_time(self, property, time): 72 | return self._interpolators[property](time) 73 | 74 | def _progress_iterator(self, ntot): 75 | """Return an iterator that displays progress in an appropriate way 76 | 77 | Overriden for the qt gui""" 78 | return tqdm.tqdm(range(ntot), unit="frame") 79 | 80 | def _replay(self, fps=30.0, resolution=(1920, 1080), show_colorbar=True, 81 | show_scalebar=True, smooth=True, set_vmin_vmax=True, 82 | set_quantity=True): 83 | if self._recording: 84 | self.stop() 85 | if self._recording_ends_at is None: 86 | raise RuntimeError("Can't playback before recording") 87 | 88 | self._recording = False 89 | self._playback = True 90 | 91 | exclude = [] 92 | 93 | if not set_vmin_vmax: 94 | exclude.extend(['vmin', 'vmax']) 95 | if not set_quantity: 96 | exclude.append('quantity_name') 97 | 98 | 99 | try: 100 | self._visualizer.show_colorbar = show_colorbar 101 | self._visualizer.show_scalebar = show_scalebar 102 | if smooth: 103 | self._interpolators = {r: c(self._timestream[r]) 104 | for c, r in zip(self._record_interpolation_class_smoothed, 105 | self._record_properties) 106 | if r not in exclude} 107 | else: 108 | self._interpolators = {r: c(self._timestream[r]) 109 | for c, r in zip(self._record_interpolation_class_unsmoothed, 110 | self._record_properties) 111 | if r not in exclude} 112 | 113 | device = self._visualizer.device 114 | 115 | render_texture: wgpu.GPUTexture = device.create_texture( 116 | size=(resolution[0], resolution[1], 1), 117 | usage=wgpu.TextureUsage.RENDER_ATTACHMENT | 118 | wgpu.TextureUsage.COPY_SRC, 119 | format=self._visualizer.canvas_format, 120 | label="output_texture", 121 | ) 122 | 123 | num_frames = int(self._recording_ends_at * fps) 124 | for i in self._progress_iterator(num_frames): 125 | t = i / fps 126 | for p in self._record_properties: 127 | if p not in exclude: 128 | val = self._get_value_at_time(p, t) 129 | if val is not Interpolator.no_value: 130 | self._view_synchronizer._default_setter(self._visualizer, p, val) 131 | 132 | self._visualizer.display_status("github.com/pynbody/topsy/", timeout=1e6) 133 | self._visualizer.draw(DrawReason.EXPORT, render_texture.create_view()) 134 | im = device.queue.read_texture({'texture': render_texture, 'origin': (0, 0, 0)}, 135 | {'bytes_per_row': 4 * resolution[0]}, 136 | (resolution[0], resolution[1], 1)) 137 | im_npy = np.frombuffer(im, dtype=np.uint8).reshape((resolution[1], resolution[0], 4)) 138 | im_npy = im_npy[:, :,:3] 139 | yield im_npy 140 | 141 | self.playback = False 142 | finally: 143 | self._visualizer.show_colorbar = True 144 | self._visualizer.show_scalebar = True 145 | self._visualizer.display_status("Complete", timeout=1.0) 146 | 147 | def save_mp4(self, filename, fps, resolution, *args, **kwargs): 148 | import cv2 149 | writer = cv2.VideoWriter(filename, cv2.VideoWriter.fourcc(*'mp4v'), fps, 150 | resolution) 151 | 152 | for image in self._replay(fps, resolution, *args, **kwargs): 153 | writer.write(image) 154 | 155 | writer.release() 156 | 157 | def save_timestream(self, fname): 158 | pickle.dump((self._timestream, self._recording_ends_at), open(fname, 'wb')) 159 | 160 | def load_timestream(self, fname): 161 | self._timestream, self._recording_ends_at = pickle.load(open(fname, 'rb')) 162 | 163 | 164 | @property 165 | def recording(self): 166 | return self._recording 167 | 168 | -------------------------------------------------------------------------------- /tests/test_scalebar.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from topsy.scalebar import BarLengthRecommender 4 | 5 | 6 | def test_very_small_scales_parsecs(): 7 | """Test scalebar recommendations for very small scales (sub-parsec).""" 8 | # 10^-3 pc = 1e-6 kpc 9 | window_width_kpc = 1e-6 10 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 11 | length = recommender.physical_scalebar_length_base_units 12 | label = recommender.label 13 | 14 | # Should recommend something reasonable and small 15 | assert length <= window_width_kpc / 2 16 | assert "au" in label # At this scale, should use AU units 17 | 18 | 19 | def test_small_parsec_scales(): 20 | """Test scalebar recommendations for parsec-scale windows.""" 21 | # 1 pc = 1e-3 kpc 22 | window_width_kpc = 1e-3 23 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 24 | length = recommender.physical_scalebar_length_base_units 25 | label = recommender.label 26 | 27 | assert length <= window_width_kpc / 2 28 | assert "pc" in label 29 | 30 | # 10 pc = 0.01 kpc 31 | window_width_kpc = 0.01 32 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 33 | length = recommender.physical_scalebar_length_base_units 34 | label = recommender.label 35 | 36 | assert length <= window_width_kpc / 2 37 | assert "pc" in label 38 | 39 | 40 | def test_kiloparsec_scales(): 41 | """Test scalebar recommendations for kiloparsec-scale windows.""" 42 | # 1 kpc 43 | window_width_kpc = 1.0 44 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 45 | length = recommender.physical_scalebar_length_base_units 46 | label = recommender.label 47 | 48 | assert length <= window_width_kpc / 2 49 | assert "pc" in label or "kpc" in label 50 | 51 | # 100 kpc 52 | window_width_kpc = 100.0 53 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 54 | length = recommender.physical_scalebar_length_base_units 55 | label = recommender.label 56 | 57 | assert length <= window_width_kpc / 2 58 | assert "kpc" in label 59 | 60 | 61 | def test_megaparsec_scales(): 62 | """Test scalebar recommendations for megaparsec-scale windows.""" 63 | # 1 Mpc = 1000 kpc 64 | window_width_kpc = 1100.0 65 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 66 | length = recommender.physical_scalebar_length_base_units 67 | label = recommender.label 68 | 69 | assert length <= window_width_kpc / 2 70 | assert "kpc" in label or "Mpc" in label 71 | 72 | # 100 Mpc 73 | window_width_kpc = 110000.0 74 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 75 | length = recommender.physical_scalebar_length_base_units 76 | label = recommender.label 77 | 78 | assert length <= window_width_kpc / 2 79 | assert "Mpc" in label 80 | 81 | # 2000 Mpc 82 | window_width_kpc = 2100000.0 83 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 84 | length = recommender.physical_scalebar_length_base_units 85 | label = recommender.label 86 | 87 | assert length <= window_width_kpc / 2 88 | assert "Mpc" in label 89 | 90 | def test_alternative_base_units(): 91 | recommender = BarLengthRecommender(base_units="au") 92 | recommender.update_window_width(1e6) # 1 million AU ~ 4.85 pc 93 | length = recommender.physical_scalebar_length_base_units 94 | label = recommender.label 95 | assert length <= 5e5 and length > 4e5 96 | assert label == '2 pc' 97 | 98 | def test_quantization_logic(): 99 | """Test that recommended lengths follow 1, 2, 5 × 10^n pattern in their appropriate units.""" 100 | test_windows = [1e-6, 1e-3, 0.01, 1.0, 10.0, 100.0, 1000.0, 10000.0] 101 | 102 | for window_width_kpc in test_windows: 103 | recommender = BarLengthRecommender(initial_window_width_in_base_units=window_width_kpc) 104 | length = recommender._physical_scalebar_length_in_chosen_unit 105 | 106 | power_of_ten = np.floor(np.log10(length)) 107 | mantissa = length / (10 ** power_of_ten) 108 | assert any(abs(mantissa - target) < 1e-10 for target in [1.0, 2.0, 5.0]) 109 | 110 | def test_update_window_width(): 111 | """Test that updating the window width recalculates the recommendation correctly.""" 112 | recommender = BarLengthRecommender(initial_window_width_in_base_units=1.0) 113 | initial_length = recommender.physical_scalebar_length_base_units 114 | initial_label = recommender.label 115 | 116 | # Update to a larger window 117 | recommender.update_window_width(100.0) 118 | new_length = recommender.physical_scalebar_length_base_units 119 | new_label = recommender.label 120 | 121 | assert new_length != initial_length 122 | assert new_label != initial_label 123 | assert new_length <= 100.0 / 2 124 | 125 | 126 | def test_label_formatting(): 127 | """Test that labels are formatted correctly for different ranges.""" 128 | 129 | # Test very small values that should use scientific notation 130 | recommender = BarLengthRecommender(initial_window_width_in_base_units=1e-6) # 0.001 pc window 131 | label = recommender.label 132 | if "pc" in label and recommender.physical_scalebar_length_base_units * 1000 < 0.01: # If in parsecs and very small 133 | assert "$" in label and "\\times 10^{" in label 134 | 135 | # Test normal parsec values 136 | recommender = BarLengthRecommender(initial_window_width_in_base_units=0.01) # 10 pc window 137 | label = recommender.label 138 | if "pc" in label: 139 | # Should be normal formatting, not scientific 140 | value_in_pc = recommender.physical_scalebar_length_base_units * 1000 141 | if 0.01 <= value_in_pc <= 1000: 142 | assert "$" not in label 143 | 144 | # Test kpc values 145 | recommender = BarLengthRecommender(initial_window_width_in_base_units=10.0) # 10 kpc window 146 | label = recommender.label 147 | if "kpc" in label: 148 | assert "$" not in label # Normal formatting 149 | 150 | # Test Mpc values 151 | recommender = BarLengthRecommender(initial_window_width_in_base_units=10000.0) # 10 Mpc window 152 | label = recommender.label 153 | if "Mpc" in label: 154 | assert "$" not in label # Normal formatting 155 | 156 | 157 | def test_format_scientific_latex(): 158 | """Test the LaTeX scientific notation formatter.""" 159 | 160 | # Test normal range values (no scientific notation) 161 | result = BarLengthRecommender._format_scientific_latex(0.1, "pc") 162 | assert result == "0.1 pc" 163 | 164 | result = BarLengthRecommender._format_scientific_latex(1.0, "pc") 165 | assert result == "1 pc" 166 | 167 | result = BarLengthRecommender._format_scientific_latex(10.5, "kpc") 168 | assert result == "10.5 kpc" 169 | 170 | # Test very small values (scientific notation) 171 | result = BarLengthRecommender._format_scientific_latex(0.005, "pc") 172 | assert result == "$5 \\times 10^{-3}$ pc" 173 | 174 | result = BarLengthRecommender._format_scientific_latex(0.002, "pc") 175 | assert result == "$2 \\times 10^{-3}$ pc" 176 | 177 | # Test very large values (scientific notation) 178 | result = BarLengthRecommender._format_scientific_latex(2000, "Mpc") 179 | assert result == "$2 \\times 10^{3}$ Mpc" 180 | 181 | # Test zero 182 | result = BarLengthRecommender._format_scientific_latex(0, "pc") 183 | assert result == "0 pc" 184 | 185 | 186 | -------------------------------------------------------------------------------- /src/topsy/line.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import wgpu 5 | 6 | from .util import load_shader 7 | 8 | from typing import TYPE_CHECKING 9 | if TYPE_CHECKING: 10 | from .visualizer import Visualizer 11 | class Line: 12 | def __init__(self, visualizer: Visualizer, path, color, width): 13 | self._visualizer = visualizer 14 | 15 | if path is not None: 16 | path = np.asarray(path, dtype=np.float32) 17 | assert path.ndim == 2, "Path must be an array of points, each with 4 (xyzw) coordinates" 18 | assert path.shape[1] == 4, "Path must be an array of points, each with 4 (xyzw) coordinates" 19 | self._line_starts = path[:-1] 20 | self._line_ends = path[1:] 21 | else: 22 | assert hasattr(self, "_line_starts") and hasattr(self, "_line_ends"), \ 23 | "Either path must be provided, or _line_starts and _line_ends must be defined by a subclass" 24 | assert len(self._line_starts) == len(self._line_ends), \ 25 | "Number of line starts must equal number of line ends" 26 | 27 | self._color = color 28 | self._width = width 29 | self._device = visualizer.device 30 | self._target_canvas_format = visualizer.canvas_format 31 | 32 | self._setup_shader_module() 33 | self._setup_buffers() 34 | self._setup_render_pipeline() 35 | 36 | def _setup_shader_module(self): 37 | self._shader_module = self._visualizer.device.create_shader_module( 38 | code=load_shader("line.wgsl"), 39 | label="line_shader_module" 40 | ) 41 | 42 | 43 | def _setup_buffers(self): 44 | self._vertex_buffer_starts = self._visualizer.device.create_buffer_with_data( 45 | usage=wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.COPY_DST, 46 | label="line_vertex_buffer_start", 47 | data=self._line_starts 48 | ) 49 | 50 | self._vertex_buffer_ends = self._visualizer.device.create_buffer_with_data( 51 | usage=wgpu.BufferUsage.VERTEX | wgpu.BufferUsage.COPY_DST, 52 | label="line_vertex_buffer_start", 53 | data=self._line_ends 54 | ) 55 | 56 | _param_dtype = np.dtype([ 57 | ("transform", np.float32, (4, 4)), 58 | ("color", np.float32, 4), 59 | ("vp_size_pix", np.float32, 2), 60 | ("width_pix", np.float32), 61 | ("padding", np.float32, 3) 62 | ]) 63 | 64 | self._param_buffer = self._device.create_buffer( 65 | label="line_param_buffer", 66 | size=_param_dtype.itemsize, 67 | usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST 68 | ) 69 | 70 | self._params = np.zeros(1, dtype=_param_dtype) 71 | self._params["transform"] = np.eye(4) 72 | self._params["color"] = self._color 73 | self._params["width_pix"] = self._width 74 | 75 | def _setup_render_pipeline(self): 76 | self._bind_group_layout = self._device.create_bind_group_layout( 77 | label="line_bind_group_layout", 78 | entries=[ 79 | { 80 | "binding": 0, 81 | "visibility": wgpu.ShaderStage.VERTEX, 82 | "buffer": { 83 | "type": wgpu.BufferBindingType.uniform, 84 | } 85 | } 86 | ] 87 | ) 88 | 89 | self._bind_group = self._device.create_bind_group( 90 | label="line_bind_group", 91 | layout=self._bind_group_layout, 92 | entries=[ 93 | { 94 | "binding": 0, 95 | "resource": { 96 | "buffer": self._param_buffer, 97 | "offset": 0, 98 | "size": self._param_buffer.size 99 | } 100 | } 101 | ] 102 | ) 103 | 104 | self._pipeline_layout = self._device.create_pipeline_layout( 105 | label="line_pipeline_layout", 106 | bind_group_layouts=[self._bind_group_layout] 107 | ) 108 | 109 | self._render_pipeline = self._device.create_render_pipeline( 110 | label="line_render_pipeline", 111 | layout=self._pipeline_layout, 112 | vertex={ 113 | "module": self._shader_module, 114 | "entry_point": "vertex_main", 115 | "buffers": [ 116 | { # start of line segment 117 | "array_stride": 4*4, 118 | "step_mode": wgpu.VertexStepMode.instance, 119 | "attributes": [ 120 | { 121 | "format": wgpu.VertexFormat.float32x4, 122 | "offset": 0, 123 | "shader_location": 0 124 | } 125 | ] 126 | }, 127 | { # end of line segment 128 | "array_stride": 4 * 4, 129 | "step_mode": wgpu.VertexStepMode.instance, 130 | "attributes": [ 131 | { 132 | "format": wgpu.VertexFormat.float32x4, 133 | "offset": 0, 134 | "shader_location": 1 135 | } 136 | ] 137 | } 138 | ] 139 | }, 140 | primitive={ 141 | "topology": wgpu.PrimitiveTopology.triangle_strip, 142 | }, 143 | depth_stencil=None, 144 | multisample=None, 145 | fragment={ 146 | "module": self._shader_module, 147 | "entry_point": "fragment_main", 148 | "targets": [ 149 | { 150 | "format": self._target_canvas_format, 151 | "blend": { 152 | "color": { 153 | "src_target": wgpu.BlendFactor.src_alpha, 154 | "dst_target": wgpu.BlendFactor.one_minus_src_alpha, 155 | "operation": wgpu.BlendOperation.add, 156 | }, 157 | "alpha": { 158 | "src_target": wgpu.BlendFactor.src_alpha, 159 | "dst_target": wgpu.BlendFactor.one_minus_src_alpha, 160 | "operation": wgpu.BlendOperation.add, 161 | }, 162 | "write_mask": wgpu.ColorWrite.ALL, 163 | } 164 | } 165 | ] 166 | } 167 | ) 168 | 169 | 170 | def encode_render_pass(self, command_encoder: wgpu.GPUCommandEncoder, 171 | target_texture_view: wgpu.GPUTextureView): 172 | 173 | self._params["vp_size_pix"] = target_texture_view.size[:2] 174 | 175 | self._device.queue.write_buffer(self._param_buffer, 0, self._params) 176 | 177 | render_pass = command_encoder.begin_render_pass( 178 | color_attachments=[{ 179 | "view": target_texture_view, 180 | "resolve_target": None, 181 | "clear_value": (0.0, 0.0, 0.0, 1.0), 182 | "load_op": wgpu.LoadOp.load, 183 | "store_op": wgpu.StoreOp.store, 184 | }], 185 | ) 186 | render_pass.set_pipeline(self._render_pipeline) 187 | render_pass.set_bind_group(0, self._bind_group, [], 0, 999999) 188 | render_pass.set_vertex_buffer(0, self._vertex_buffer_starts) 189 | render_pass.set_vertex_buffer(1, self._vertex_buffer_ends) 190 | render_pass.draw(4, len(self._line_starts), 0, 0) 191 | render_pass.end() 192 | -------------------------------------------------------------------------------- /src/topsy/scalebar.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pynbody 5 | 6 | from . import text 7 | from . import overlay 8 | 9 | 10 | from typing import TYPE_CHECKING 11 | if TYPE_CHECKING: 12 | from .visualizer import Visualizer 13 | 14 | class BarLengthRecommender: 15 | """Class to recommend a physical length for a scalebar, given the window width in kpc. 16 | 17 | The recommended length will be a "nice" number (1, 2, or 5 times a power of ten), along with a unit. 18 | 19 | The unit will be chosen from km, AU, pc, kpc, or Mpc depending on the length itself and the window width. 20 | 21 | """ 22 | 23 | acceptable_units = "km", "au", "pc", "kpc", "Mpc" 24 | 25 | 26 | def __init__(self, initial_window_width_in_base_units=1.0, base_units="kpc"): 27 | self.unit_conversion_to_base = np.array([ 28 | pynbody.units.Unit(u).in_units(base_units) for u in self.acceptable_units 29 | ]) 30 | self._window_width_in_base_units = initial_window_width_in_base_units 31 | self._update_recommendation() 32 | self._update_label() 33 | 34 | def _update_recommendation(self): 35 | magnitude_in_each_unit = abs(np.log10(self._window_width_in_base_units / self.unit_conversion_to_base) - 0.5) 36 | chosen_unit_index = np.argmin(magnitude_in_each_unit) 37 | chosen_unit = self.acceptable_units[chosen_unit_index] 38 | chosen_unit_conversion = self.unit_conversion_to_base[chosen_unit_index] 39 | target_scalebar_length_in_chosen_unit = (self._window_width_in_base_units / 2.0) / chosen_unit_conversion 40 | quantized_length_in_chosen_unit = self._quantize_length(target_scalebar_length_in_chosen_unit) 41 | self._physical_scalebar_length_in_chosen_unit = quantized_length_in_chosen_unit 42 | self._physical_scalebar_length_unit_name = chosen_unit 43 | self._physical_scalebar_length_base_units = quantized_length_in_chosen_unit * chosen_unit_conversion 44 | 45 | @classmethod 46 | def _quantize_length(cls, physical_scalebar_length): 47 | """Find a length less than or equal to physical_scalebar_length, that is 1, 2, or 5 times a power of ten.""" 48 | power_of_ten = np.floor(np.log10(physical_scalebar_length)) 49 | mantissa = physical_scalebar_length / 10 ** power_of_ten 50 | if mantissa < 2.0: 51 | physical_scalebar_length = 10.0 ** power_of_ten 52 | elif mantissa < 5.0: 53 | physical_scalebar_length = 2.0 * 10.0 ** power_of_ten 54 | else: 55 | physical_scalebar_length = 5.0 * 10.0 ** power_of_ten 56 | return physical_scalebar_length 57 | 58 | @classmethod 59 | def _format_scientific_latex(cls, value, unit): 60 | """Format a number in scientific notation with LaTeX rendering.""" 61 | if value == 0: 62 | return f"0 {unit}" 63 | 64 | # Only use scientific notation for very small or very large numbers 65 | if 0.01 <= abs(value) <= 1000: 66 | if value == int(value): 67 | return f"{int(value)} {unit}" 68 | else: 69 | return f"{value:.2f}".rstrip('0').rstrip('.') + f" {unit}" 70 | 71 | exponent = int(np.floor(np.log10(abs(value)))) 72 | mantissa = value / (10 ** exponent) 73 | 74 | return f"${mantissa:.0f} \\times 10^{{{exponent}}}$ {unit}" 75 | 76 | def _update_label(self): 77 | self._label = self._format_scientific_latex(self._physical_scalebar_length_in_chosen_unit, 78 | self._physical_scalebar_length_unit_name) 79 | self._label_is_for = (self._physical_scalebar_length_in_chosen_unit, self._physical_scalebar_length_unit_name) 80 | 81 | def update_window_width(self, window_width_in_base_units): 82 | """Update the window width in base units, and recalculate the recommended scalebar length if it has changed.""" 83 | if window_width_in_base_units != self._window_width_in_base_units: 84 | self._window_width_in_base_units = window_width_in_base_units 85 | self._update_recommendation() 86 | 87 | @property 88 | def label(self): 89 | """Get the label for the current recommended scalebar length.""" 90 | if self._label_is_for != (self._physical_scalebar_length_in_chosen_unit, 91 | self._physical_scalebar_length_unit_name): 92 | self._update_label() 93 | return self._label 94 | 95 | @property 96 | def physical_scalebar_length_base_units(self): 97 | """Get the recommended physical scalebar length in base units.""" 98 | return self._physical_scalebar_length_base_units 99 | 100 | 101 | class BarOverlay(overlay.Overlay): 102 | """Overlay that implements a bar.""" 103 | 104 | def __init__(self, *args, x0=0.1, y0=0.1, height_pixels=20, color=(1, 1, 1, 1), initial_length=0.2, **kwargs): 105 | self.x0 = x0 106 | self.y0 = y0 107 | self.height_pixels = height_pixels 108 | self.color = color 109 | self.length = initial_length # update just by directly setting self.length in clipspace coords - simple :-) 110 | 111 | super().__init__(*args, **kwargs) 112 | 113 | def render_contents(self) -> np.ndarray: 114 | # just a single pixel of the right color 115 | pixel = np.ones((1, 1, 4), dtype=np.float32) 116 | pixel[0,0,:] = self.color 117 | return pixel 118 | 119 | def get_clipspace_coordinates(self, window_pixel_width, window_pixel_height): 120 | height_clipspace = 2.0 * self.height_pixels / window_pixel_height 121 | return self.x0, self.y0, self.length, height_clipspace 122 | 123 | 124 | class ScalebarOverlay: 125 | def __init__(self, visualizer: Visualizer): 126 | self._label = text.TextOverlay(visualizer, "Scalebar", (-0.9, -0.85), 40, color=(1, 1, 1, 1)) 127 | self._bar = BarOverlay(visualizer, x0=-0.9, y0=-0.9, height_pixels=10, color=(1, 1, 1, 1)) 128 | self._recommender = BarLengthRecommender(1.0, visualizer.data_loader.get_position_units()) # will be updated immediately 129 | self._visualizer = visualizer 130 | 131 | def encode_render_pass(self, command_encoder: wgpu.GPUCommandEncoder, target_texture_view: wgpu.GPUTextureView): 132 | self._update_length() 133 | self._bar.length = self._physical_scalebar_length / self._visualizer.scale 134 | # note that the visualizer scale refers to a square rendering target 135 | # however only part of this is shown in the final window if the window 136 | # aspect ratio isn't 1:1. So we now need to correct for this effect. 137 | # The full x extent is shown if the width is greater than the height, so 138 | # no correction is needed then. If the height is greater than the width, 139 | # then the x extent is scaled by the ratio of the height to the width. 140 | 141 | if self._visualizer.canvas.width_physical < self._visualizer.canvas.height_physical: 142 | self._bar.length *= self._visualizer.canvas.height_physical / self._visualizer.canvas.width_physical 143 | 144 | self._label.encode_render_pass(command_encoder, target_texture_view) 145 | self._bar.encode_render_pass(command_encoder, target_texture_view) 146 | 147 | def _update_scalebar_label(self, physical_scalebar_length): 148 | if getattr(self, "_scalebar_label_is_for_length", None) != physical_scalebar_length: 149 | self._label.text = self._recommender.label 150 | self._scalebar_label_is_for_length = physical_scalebar_length 151 | self._label.update() 152 | 153 | def _update_length(self): 154 | # target is for the scalebar to be no more than 1/2 the viewport 155 | # (but not too much less either); however the length is to be 10^n or 5*10^n 156 | # in world coordinates the viewport is 2 * self.scale kpc wide 157 | # so the maximum scalebar length should be self.scale kpc 158 | window_width_kpc = 2.0 * self._visualizer.scale 159 | self._recommender.update_window_width(window_width_kpc) 160 | self._physical_scalebar_length = self._recommender.physical_scalebar_length_base_units 161 | self._update_scalebar_label(self._physical_scalebar_length) 162 | 163 | 164 | -------------------------------------------------------------------------------- /src/topsy/canvas/jupyter.py: -------------------------------------------------------------------------------- 1 | from IPython.display import display 2 | import ipywidgets as widgets 3 | from typing import Callable, Any 4 | 5 | from rendercanvas.jupyter import RenderCanvas, loop 6 | from . import VisualizerCanvasBase 7 | from ..config import JUPYTER_UI_LAG 8 | from ..colormap.ui import ControlSpec, UnifiedColorMapController 9 | 10 | class VisualizerCanvas(VisualizerCanvasBase, RenderCanvas): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self._allow_events = True 14 | 15 | def request_draw(self, function=None): 16 | 17 | # As a side effect, wgpu gui layer stores our function call, to enable it to be 18 | # repainted later. But we want to distinguish such repaints and handle them 19 | # differently, so we need to replace the function with our own 20 | 21 | #def function_wrapper(): 22 | # function() 23 | # self._subwidget.draw_frame = lambda: self._visualizer.draw(DrawReason.PRESENTATION_CHANGE) 24 | # TODO: above needs to be hacked for jupyter 25 | 26 | super().request_draw(function) 27 | 28 | 29 | 30 | @classmethod 31 | def call_later(cls, delay, fn, *args): 32 | loop.call_later(delay, fn, *args) 33 | 34 | def ipython_display_with_widgets(self): 35 | """Display the canvas in a Jupyter notebook with widgets.""" 36 | color_controls = self.build_color_controls() 37 | 38 | # stack canvas, dropdown and range slider 39 | display(widgets.VBox([self, color_controls])) 40 | 41 | 42 | 43 | 44 | def build_color_controls(self) -> widgets.Widget: 45 | """ 46 | Return a nested ipywidget (HBox/VBox) tree driven by the generic ColorMapController.get_layout() spec. 47 | """ 48 | self._controller_box = widgets.Box() 49 | self._controller = UnifiedColorMapController(self._visualizer, self._refresh_ui) 50 | 51 | if self._controller: 52 | self._rebuild_ui(self._controller.get_layout()) 53 | else: 54 | self._rebuild_ui(None) 55 | 56 | return self._controller_box 57 | 58 | def _callback(self, callback: Callable[[Any], None], value: Any): 59 | if not self._allow_events: 60 | return 61 | callback(value) 62 | 63 | 64 | def _refresh_ui(self, root_spec, new_widgets): 65 | """Walk the layout and update all values, including slider ranges.""" 66 | if not hasattr(self, "_controller"): 67 | return 68 | 69 | if new_widgets: 70 | self._rebuild_ui(root_spec) 71 | else: 72 | self._update_ui(root_spec) 73 | 74 | def _update_ui(self, root_spec): 75 | self._allow_events = False 76 | try: 77 | self.update_widget(root_spec, self._controls) 78 | finally: 79 | # re-enable events after a delay, to allow the UI to settle (eugh! surely must be a better way?) 80 | self.call_later(JUPYTER_UI_LAG, lambda: setattr(self, "_allow_events", True)) 81 | 82 | def _rebuild_ui(self, root_spec): 83 | if root_spec is not None: 84 | self._controls = self.convert_layout_to_widget(root_spec) 85 | else: 86 | self._controls = widgets.HTML("No colormap controls available") 87 | self._controller_box.children = [self._controls] 88 | 89 | def convert_layout_to_widget(self, spec) -> widgets.Widget: 90 | children = [] 91 | for child in spec.children: 92 | if isinstance(child, ControlSpec): 93 | children.append(self.make_widget(child)) 94 | else: 95 | children.append(self.convert_layout_to_widget(child)) 96 | if spec.type == "hbox": 97 | return widgets.HBox(children) 98 | else: 99 | return widgets.VBox(children) 100 | 101 | def make_widget(self, spec): 102 | if spec.type == "combo" or spec.type == 'combo-edit': # can't get a good implementation of combo editing in ipython currently 103 | w = widgets.Dropdown( 104 | options=spec.options or [], 105 | value=spec.value, 106 | description=spec.label or "", 107 | layout=widgets.Layout(width="200px") 108 | ) 109 | w.observe(lambda change, cb=spec.callback: self._callback(cb, change["new"]), names="value") 110 | 111 | elif spec.type == "checkbox": 112 | w = widgets.Checkbox( 113 | value=bool(spec.value), 114 | description=spec.label or "" 115 | ) 116 | w.observe(lambda change, cb=spec.callback: self._callback(cb, change["new"]), names="value") 117 | 118 | elif spec.type == "range_slider": 119 | lo, hi = spec.range or (0.0, 1.0) 120 | w = widgets.FloatRangeSlider( 121 | value=tuple(spec.value), 122 | min=lo, max=hi, 123 | step=None, 124 | description=spec.label or "", 125 | layout=widgets.Layout(width="400px") 126 | ) 127 | w.observe(lambda change, cb=spec.callback: self._callback(cb,change["new"]), names="value") 128 | 129 | elif spec.type == "slider": 130 | lo, hi = spec.range or (0.0, 1.0) 131 | w = widgets.FloatSlider( 132 | value=spec.value, 133 | min=lo, max=hi, 134 | step=None, 135 | description=spec.label or "", 136 | layout=widgets.Layout(width="400px") 137 | ) 138 | w.observe(lambda change, cb=spec.callback: self._callback(cb, change["new"]), names="value") 139 | 140 | elif spec.type == "button": 141 | w = widgets.Button(description=spec.label or "") 142 | w.on_click(lambda btn, cb=spec.callback: self._callback(cb, None)) 143 | elif spec.type == "color_picker": 144 | w = widgets.ColorPicker(concise=True, description=spec.label or "", value=spec.value) 145 | w.observe(lambda change, cb=spec.callback: self._callback(cb, change["new"]), names="value") 146 | else: 147 | w = widgets.HTML(f"Unknown control {spec.name}") 148 | 149 | return w 150 | 151 | 152 | def update_widget(self, spec, widget): 153 | if isinstance(spec, ControlSpec): 154 | if spec.type in {"combo", "combo-edit"}: 155 | widget.value = spec.value 156 | elif spec.type == "checkbox": 157 | widget.value = bool(spec.value) 158 | elif spec.type == "range_slider": 159 | lo, hi = spec.range or (0.0, 1.0) 160 | self.safe_update_slider_range(widget, lo, hi) 161 | wlo, whi = spec.value 162 | 163 | # seemingly need to set this after the range update has gone through, otherwise get 164 | # nonsense results in some cases 165 | self.call_later(JUPYTER_UI_LAG/2, lambda: setattr(widget, "value", (wlo, whi))) 166 | 167 | elif spec.type == "slider": 168 | lo, hi = spec.range or (0.0, 1.0) 169 | self.safe_update_slider_range(widget, lo, hi) 170 | 171 | # seemingly need to set this after the range update has gone through, otherwise get 172 | # nonsense results in some cases 173 | self.call_later(JUPYTER_UI_LAG/2, lambda: setattr(widget, "value", spec.value)) 174 | else: 175 | for child_spec, child_widget in zip(spec.children, widget.children): 176 | self.update_widget(child_spec, child_widget) 177 | 178 | @classmethod 179 | def safe_update_slider_range(cls, slider, min_, max_): 180 | # sliders in ipywidgets seem to offer no option to update the range atomically. If one naively sets 181 | # min and max, the intermediate state can be invalid and raise an exception. Therefore, one needs 182 | # to first set a bounding range, then narrow back down to min, max. 183 | if slider.min == min_ and slider.max == max_: 184 | return 185 | bounding_min = min(min_, slider.min) 186 | bounding_max = max(max_, slider.max) 187 | slider.min = bounding_min 188 | slider.max = bounding_max 189 | slider.min = min_ 190 | slider.max = max_ 191 | 192 | -------------------------------------------------------------------------------- /tests/test_colormap.py: -------------------------------------------------------------------------------- 1 | import topsy 2 | import pytest 3 | import pylab as p 4 | import numpy as np 5 | import numpy.testing as npt 6 | 7 | from pathlib import Path 8 | from matplotlib import colors, cm 9 | 10 | from topsy import colormap 11 | from topsy.canvas import offscreen 12 | 13 | @pytest.fixture 14 | def folder(): 15 | folder = Path(__file__).parent / "output" 16 | folder.mkdir(exist_ok=True) 17 | return folder 18 | 19 | 20 | @pytest.fixture 21 | def vis(request): 22 | vis = topsy.test(100, render_resolution=200, canvas_class = offscreen.VisualizerCanvas) 23 | vis.scale = 200.0 24 | return vis 25 | 26 | @pytest.fixture 27 | def input_image(): 28 | """A dummy output from the SPH renderer, where the density varies in the x direction logarithmically between 10^-3 29 | and 1 and the weighted average varies in the y direction linearly between 0 and 1.""" 30 | input_image = np.empty((200, 200, 2), dtype=np.float32) 31 | input_image[:, :, 0] = np.logspace(-3, 0, 200) 32 | input_image[:, :, 1] = np.linspace(0, 1, 200)[:, np.newaxis] * input_image[:, :, 0] 33 | return input_image 34 | 35 | @pytest.mark.parametrize("mode", ['density', 'weighted-average', 'bivariate']) 36 | @pytest.mark.parametrize("log_scale", [True, False], ids=["log", "linear"]) 37 | def test_colormap(vis, input_image, mode, log_scale, folder): 38 | cmap = vis.colormap 39 | 40 | if mode == 'density': 41 | weighted_average = False 42 | type = 'density' 43 | if log_scale: 44 | vmin, vmax = -3.0, 0.0 45 | else: 46 | vmin, vmax = 0.0, 1.0 47 | elif mode == 'weighted-average': 48 | weighted_average = True 49 | type = 'density' 50 | if log_scale: 51 | vmin, vmax = -2.0, 0.0 52 | else: 53 | vmin, vmax = 0.0, 1.0 54 | elif mode == 'bivariate': 55 | weighted_average = True 56 | type = 'bivariate' 57 | if log_scale: 58 | vmin, vmax = -2.0, 0.0 59 | else: 60 | vmin, vmax = 0.0, 1.0 61 | else: 62 | raise ValueError("Invalid mode: {}".format(mode)) 63 | 64 | cmap.update_parameters({ 65 | 'type': type, 66 | 'weighted_average': weighted_average, 67 | 'vmin': vmin, 68 | 'vmax': vmax, 69 | 'density_vmin': -3.0, 70 | 'density_vmax': 0.0, 71 | 'log': log_scale, 72 | }) 73 | 74 | image = cmap.sph_raw_output_to_image(input_image) 75 | 76 | assert image.shape == (200, 200, 4) 77 | 78 | p.imsave(folder / f"test_colormap_{mode}_{log_scale}.png", image) 79 | 80 | image_via_mpl = _colormap_in_software(input_image, cmap, log_scale, vmax, vmin) 81 | p.imsave(folder / f"test_colormap_software_{mode}_{log_scale}.png", image_via_mpl) 82 | 83 | npt.assert_allclose(image, image_via_mpl, atol=5) 84 | 85 | 86 | def _colormap_in_software(input_image, cmap, log_scale, vmax, vmin): 87 | if cmap.get_parameter("type") == "bivariate": 88 | return _bivariate_colormap_in_software(input_image, cmap) 89 | else: 90 | return _univariate_colormap_in_software(input_image, cmap) 91 | 92 | def _univariate_colormap_in_software(input_image, cmap): 93 | mpl_cmap_name = cmap.get_parameter("colormap_name") 94 | vmin = cmap.get_parameter("vmin") 95 | vmax = cmap.get_parameter("vmax") 96 | log_scale = cmap.get_parameter("log") 97 | 98 | norm = colors.Normalize(vmin=vmin, vmax=vmax) 99 | mpl_cmap = cm.ScalarMappable(norm=norm, cmap=cm.get_cmap(mpl_cmap_name)).to_rgba 100 | 101 | content = cmap.sph_raw_output_to_content(input_image) 102 | if log_scale: 103 | content = np.log10(content) 104 | image_via_mpl = (mpl_cmap(content) * 255).astype(np.uint8) 105 | return image_via_mpl 106 | 107 | def _bivariate_colormap_in_software(input_image, cmap): 108 | den_vmin = cmap.get_parameter("density_vmin") 109 | den_vmax = cmap.get_parameter("density_vmax") 110 | vmin = cmap.get_parameter("vmin") 111 | vmax = cmap.get_parameter("vmax") 112 | vlog = cmap.get_parameter("log") 113 | 114 | # generate a 1000 x 1000 grid of points. The mapping is a 2D grid of points in the unit square 115 | mapping = cmap._impl._generate_mapping_rgba_f32(1000) 116 | 117 | # for each point in the input image, figure out the coordinate in the mapping 118 | density = input_image[:, :, 0] 119 | value_times_density = input_image[:, :, 1] 120 | weighted_value = value_times_density / density 121 | 122 | scaled_density = (np.log10(density) - den_vmin) / (den_vmax - den_vmin) 123 | if vlog: 124 | scaled_weighted_value = (np.log10(weighted_value) - vmin) / (vmax - vmin) 125 | else: 126 | scaled_weighted_value = (weighted_value - vmin) / (vmax - vmin) 127 | 128 | # set up an interpolator for the mapping on the unit square. Out-of-bounds values should map to nearest 129 | from scipy.interpolate import RegularGridInterpolator 130 | points = np.linspace(0, 1, 1000) 131 | interpolator = RegularGridInterpolator((points, points), mapping, bounds_error=True, method='linear') 132 | 133 | # now interpolate the mapping for each point in the input image 134 | coords = np.stack((scaled_weighted_value, scaled_density), axis=-1) 135 | coords = np.clip(coords, 0, 1) # ensure coordinates are within [0, 1] 136 | image = np.clip(interpolator(coords), 0, 1) 137 | 138 | # convert to 8-bit RGBA 139 | image = (image * 255).astype(np.uint8) 140 | 141 | return image 142 | 143 | 144 | 145 | 146 | def test_colormap_holder_instantiation(vis): 147 | from topsy.colormap import ColormapHolder 148 | from topsy.colormap.implementation import RGBColormap, RGBHDRColormap, BivariateColormap, Colormap 149 | specs = [ 150 | {"params": {"type": "rgb", "hdr": True}, 151 | "expected": RGBHDRColormap}, 152 | {"params": {"type": "rgb", "hdr": False}, 153 | "expected": RGBColormap}, 154 | {"params": {"type": "bivariate", "hdr": False}, 155 | "expected": BivariateColormap}, 156 | {"params": {"type": "density", "hdr": False}, 157 | "expected": Colormap}, 158 | ] 159 | 160 | for spec in specs: 161 | colormap_class = ColormapHolder.instance_from_parameters(spec["params"], vis.device, 162 | vis._sph.get_output_texture(), 163 | vis.canvas_format) 164 | assert type(colormap_class) == spec["expected"] 165 | 166 | 167 | def test_colormap_updating(vis): 168 | """Test that updating the colormap correctly decides whether to create a new implementation or not""" 169 | cmap = vis.colormap 170 | cmap.update_parameters({'type': 'density'}) 171 | assert isinstance(cmap._impl, colormap.implementation.Colormap) 172 | impl_id = id(cmap._impl) 173 | 174 | cmap.update_parameters({'vmin': 0.0, 'vmax': 20.0}) 175 | assert impl_id == id(cmap._impl) # should not create a new implementation 176 | 177 | 178 | cmap.update_parameters({'type': 'bivariate'}) 179 | assert isinstance(cmap._impl, colormap.implementation.BivariateColormap) 180 | assert impl_id != id(cmap._impl) # should create a new implementation 181 | 182 | def test_rgb_colormap_vmin_vmax(): 183 | """Test that RGB colormap can be updated either with vmin/vmax or with min_mag/max_mag""" 184 | vis = topsy.test(100, render_resolution=200, canvas_class=offscreen.VisualizerCanvas, 185 | render_mode='rgb') 186 | 187 | vis.colormap.update_parameters({'vmin': 1.0, 'vmax': 2.0}) 188 | assert vis.colormap.get_parameter('vmin') == 1.0 189 | assert vis.colormap.get_parameter('vmax') == 2.0 190 | assert np.allclose(vis.colormap.get_parameter('min_mag'), 31.57212566586528) 191 | assert np.allclose(vis.colormap.get_parameter('max_mag'), 34.07212566586528) 192 | 193 | vis.colormap.update_parameters({'min_mag': 1.0, 'max_mag': 2.0}) 194 | assert np.allclose(vis.colormap.get_parameter('min_mag'), 1.0) 195 | assert np.allclose(vis.colormap.get_parameter('max_mag'), 2.0) 196 | assert np.allclose(vis.colormap.get_parameter('vmin'), 13.828850266346112) 197 | assert np.allclose(vis.colormap.get_parameter('vmax'), 14.228850266346113) 198 | 199 | 200 | def test_colormap_dict_access(vis): 201 | """Test that the colormap can be accessed as a dictionary""" 202 | cmap = vis.colormap 203 | cmap.update_parameters({'type': 'density', 'vmin': 0.0, 'vmax': 20.0}) 204 | 205 | assert cmap['type'] == 'density' 206 | assert cmap['vmin'] == 0.0 207 | assert cmap['vmax'] == 20.0 208 | 209 | cmap['vmin'] = 5.0 210 | assert cmap['vmin'] == 5.0 211 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | https://github.com/user-attachments/assets/b185c3b8-8658-4f7d-96de-e976959e7ad6 2 | 3 | 4 | topsy 5 | ===== 6 | 7 | [![Build Status](https://github.com/pynbody/topsy/actions/workflows/build-test.yaml/badge.svg)](https://github.com/pynbody/topsy/actions) 8 | 9 | This package visualises simulations, and is an add-on to the [pynbody](https://github.com/pynbody/pynbody) analysis package. 10 | Its name nods to the [TIPSY](https://github.com/N-BodyShop/tipsy) project. 11 | It is built using [wgpu](https://wgpu.rs), which is a future-facing GPU standard (with thanks to the [python wgpu bindings](https://wgpu-py.readthedocs.io/en/stable/guide.html)). 12 | 13 | At the moment, `topsy` is experimental, but has proven to work well in a variety of environments. It is mainly 14 | developed and optimized on Apple M-series chips, but has also been shown to work on NVidia GPUs. 15 | 16 | The future development path will depend on the level of interest from the community. 17 | 18 | Installing 19 | ---------- 20 | 21 | You will need python 3.11 or later, running in a UNIX variant (basically MacOS, Linux or if you're on Windows you need [WSL](https://learn.microsoft.com/en-us/windows/wsl/install)). You can then install `topsy` using `pip` 22 | as usual: 23 | 24 | ``` 25 | pip install topsy 26 | ``` 27 | 28 | This will install topsy and its dependencies (including `pynbody` itself) into 29 | your current python environment. (If it fails, check that you have python 3.11 30 | or later, and `pip` is itself up-to-date using `pip install -U pip`.) 31 | 32 | ### Alternative 1: install into isolated environment using pipx 33 | 34 | You can also install `topsy` into its own isolated environment using [pipx](https://pypi.org/project/pipx/): 35 | 36 | ``` 37 | pipx install topsy 38 | ``` 39 | 40 | The command line tool will now be available, but you won't have access to the `topsy` package from your existing python environment. This can be useful if you don't want to risk disturbing anything. 41 | 42 | ### Alternative 2: install into new environment using venv and pip 43 | 44 | If you want to play with `topsy` without disturbing your existing installation, but also want to be able to use `topsy` from python scripts or jupyter etc, I recommend using `venv`: 45 | 46 | ``` 47 | # create a toy environment 48 | python -m venv visualiser-env 49 | 50 | # activate the new environment 51 | source visualiser-env/bin/activate 52 | 53 | # install 54 | pip install topsy 55 | 56 | ... other commands ... 57 | 58 | # get your old environment back: 59 | deactivate 60 | ``` 61 | 62 | For more information about venv, see its 63 | [tutorial page](https://docs.python.org/3/library/venv.html). 64 | 65 | ### Alternative 3: install unreleased versions or contribute to development 66 | 67 | As usual, you can also install direct from github, e.g. 68 | 69 | ``` 70 | pip install git+https://github.com/pynbody/topsy 71 | ``` 72 | 73 | Or clone the repository and install for development using 74 | 75 | ``` 76 | pip install -e . 77 | ``` 78 | 79 | from inside the cloned repository. 80 | 81 | 82 | 83 | 84 | Trying it out 85 | ------------- 86 | 87 | ### Very quick start 88 | 89 | Once `topsy` is installed, if you just want to try it out and you don't have a 90 | suitable simulation snapshot to hand, you can download some 91 | from the [tangos tutorial datasets (5.1GB)](https://zenodo.org/records/5959983/files/tutorial_changa.tar.gz?download=1). 92 | You need to untar them (`tar -xzf tutorial_changa.tar.gz` from your command line), then 93 | you can type `topsy pioneer50h128.1536gst1.bwK1.000832` to visualise that file's 94 | dark matter content. 95 | 96 | ### More detailed description 97 | 98 | If using from the command line, pass `topsy` the path to the simulation that you wish to visualise. 99 | 100 | You can (and probably should) also 101 | tell it what to center on using the `-c` flag, to which valid arguments are: 102 | 103 | * `-c none` (just loads the file without changing the centering) 104 | * `-c halo-1` (uses the shrink sphere center of halo 1; or you can change 1 to any other number) 105 | * `-c zoom` (uses the shrink sphere center on the highest resolution particles, without loading a halo catalogue) 106 | * `-c all` (uses the shrink sphere center on all particles in the file) 107 | 108 | By default, it will show you dark matter particles. To change this pass `-p gas` to show gas particles or `-p star` for 109 | stars. Note that the particle type _cannot_ be changed once the window is open (although you can open a separate window for each particle type; see below). 110 | 111 | If your particles have other quantities defined on them (such as `temp` for gas particles), you can view the 112 | density-weighted average quantity by passing `-q temp`. The quantity to visualise can also be changed by selecting it via the main window controls 113 | (see below). 114 | 115 | To open more than one visualisation window on different files or with different parameters, you can 116 | pass multiple groups of parameters separated by `+`, for example to see separate views of the gas and 117 | dark matter you could launch `topsy` with: 118 | 119 | ``` 120 | topsy -c halo-1 -p gas my_simulation + -c halo-1 -p dm my_simulation 121 | ``` 122 | 123 | You can choose to link the rotation/zoom of multiple views using the toolbar (see below). 124 | 125 | Using SSPs 126 | ---------- 127 | 128 | If you have stars in your simulation, you can try rendering using pynbody's SSP tables, using the command-line 129 | flag `--render-mode rgb`, or by selecting the RGB option in the colormap. Make sure you are visualising stars rather than 130 | any other particles e.g. 131 | 132 | ``` 133 | topsy -c halo-1 -p s --render-mode rgb my_simulation 134 | ``` 135 | 136 | Even better, if you have an HDR display (e.g. recent Macbook Pros), you can use the `--render-mode rgb-hdr` flag to render in HDR mode. 137 | Note in HDR mode that the magnitude range specified applies to the SDR range, i.e. HDR brightnesses extend beyond the specified maximum surface brightness limit. The exact brightest magntiude that can be displayed will depend on your display hardware. 138 | 139 | 140 | Controls in the main window 141 | --------------------------- 142 | 143 | The view in the `topsy` window can be manipulated as follows: 144 | 145 | * To spin around the centre, **drag** the mouse. 146 | * To zoom in and out, use the mouse **scroll** wheel. 147 | * To move the centre, **double click** on a target (topsy will determine its depth), or **shift-drag** to move in x-y plane. 148 | * To rescale the colours to an appropriate range for the current view, press `r`(ange) 149 | * To return the view to the original orientation and zoom, press `h`(ome) 150 | 151 | There is also a toolbar at the bottom of the window with some buttons: 152 | 153 | * 154 | - start recording actions (rotations, scalings, movements and more). Press again to stop. 155 | * 156 | - render the recorded actions into an mp4 file. You will be prompted about various options and a filename. 157 | * 158 | 159 | - load and save the recorded actions to a file for later use. 160 | * 161 | - save a snapshot of the current view to an image file. 162 | * 163 | - link this window to other topsy windows, so that rotating, scaling or moving one does the same to the other 164 | * - open colormap control; this lets you select the rendering mode, min/max values, the quantity to visualise, the matplotlib colormap and more. 165 | 166 | Using from jupyter 167 | ------------------ 168 | 169 | It is possible to use `topsy` within a jupyter notebook. Graphics are rendered on the jupyter server, and displayed within the notebook. 170 | 171 | To open a topsy view within your jupyter notebook, try 172 | 173 | ```python 174 | import topsy 175 | topsy.load("/path/to/simulation", particle="gas") 176 | ``` 177 | Note that you can interact with this widget in exactly the same way as the native window produced by `topsy`. Most of 178 | the same options you can pass on the command line are also available via this `load` function (type 179 | `help(topsy.load)` for details). 180 | -------------------------------------------------------------------------------- /src/topsy/__init__.py: -------------------------------------------------------------------------------- 1 | """topsy - An astrophysics simulation visualization package based on webgpu, using pynbody for reading data""" 2 | 3 | from __future__ import annotations 4 | 5 | __version__ = "0.8.1" 6 | 7 | import argparse 8 | import logging 9 | import sys 10 | 11 | from typing import TYPE_CHECKING 12 | if TYPE_CHECKING: 13 | import pynbody 14 | from .visualizer import Visualizer 15 | 16 | 17 | from . import config 18 | 19 | logger = None 20 | 21 | def parse_args(args=None): 22 | """Create arguments and kwargs to pass to the visualizer, from sys.argv""" 23 | 24 | # create the argument parser, add arguments for filename, resolution, and colormap, and parse the arguments 25 | argparser = argparse.ArgumentParser(description="Visualize an astrophysics simulation. Multiple windows can be opened by separating groups of arguments with *.") 26 | 27 | argparser.add_argument("filename", help="Specify path to a simulation file to be visualized") 28 | argparser.add_argument("--resolution", "-r", help="Specify the resolution of the visualization", 29 | default=config.DEFAULT_RESOLUTION, type=int) 30 | argparser.add_argument("--colormap", "-m", help="Specify the matplotlib colormap to be used", 31 | default=config.DEFAULT_COLORMAP, type=str) 32 | argparser.add_argument("--particle", "-p", help="Specify the particle type to visualise", 33 | default="dm", type=str) 34 | argparser.add_argument("--center", "-c", help="Specify the centering method: 'halo-', 'all', 'zoom' or 'none'", 35 | default="none", type=str) 36 | argparser.add_argument("--quantity", "-q", help="Specify a quantity to render instead of density", 37 | default=None, type=str) 38 | argparser.add_argument("--tile", "-t", help="Wrap and tile the simulation box using its periodicity", 39 | default=False, action="store_true") 40 | argparser.add_argument("--render-mode", help="Rendering mode: univariate (default), bivariate, rgb, rgb-hdr, surface", 41 | default="univariate", choices=['univariate', 'bivariate', 'rgb', 'rgb-hdr', 'surface'], dest='render_mode') 42 | argparser.add_argument("--load-sphere", nargs='+', help="Load a sphere of particles with the given " 43 | "radius and, optionally, centre in simulation units. " 44 | "e.g. --load-sphere 5.0 to load a sphere of radius 5.0 about" 45 | "the centre of the simulation, or 5.0 3.0 1.0 2.0 to load a " 46 | "sphere of radius 5.0 about the point (3.0, 1.0, 2.0)." 47 | "Supported only for swift simulations. Units are simulation units.", 48 | metavar=("_"), 49 | default=None, type=float) 50 | 51 | if args is None: 52 | args = sys.argv[1:] 53 | arg_batches = [] 54 | # split args into batches separated by '+' 55 | while len(args) > 0: 56 | try: 57 | split_index = args.index("+") 58 | except ValueError: 59 | split_index = len(args) 60 | 61 | this_args = argparser.parse_args(args[:split_index]) 62 | 63 | if this_args.load_sphere is not None and len(this_args.load_sphere) != 1 and len(this_args.load_sphere) != 4: 64 | argparser.error("Invalid number of arguments for --load-sphere. Must be 1 or 4.") 65 | arg_batches.append(this_args) 66 | args = args[split_index+1:] 67 | 68 | 69 | return arg_batches 70 | 71 | def setup_logging(): 72 | global logger 73 | if logger is not None: 74 | return 75 | logger = logging.getLogger(__name__) 76 | logger.setLevel(logging.DEBUG) 77 | ch = logging.StreamHandler() 78 | ch.setLevel(logging.DEBUG) 79 | ch.setFormatter( 80 | logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 81 | ) 82 | logger.addHandler(ch) 83 | 84 | def main(): 85 | all_args = parse_args() 86 | 87 | for args in all_args: 88 | # Convert CLI args to parameters dict for cleaner interface 89 | vis = load(args.filename, center=args.center, resolution=args.resolution, 90 | particle=args.particle, tile=args.tile, 91 | sphere_radius=args.load_sphere[0] if args.load_sphere is not None else None, 92 | sphere_center=tuple(args.load_sphere[1:]) if args.load_sphere is not None and len(args.load_sphere) == 4 else None, 93 | render_mode=args.render_mode) 94 | vis.quantity_name = args.quantity 95 | vis.canvas.show() 96 | 97 | from rendercanvas import qt # has to be imported here so that underlying qt toolkit has been autoselected 98 | qt.loop.run() 99 | 100 | def topsy(snapshot: pynbody.snapshot.SimSnap, quantity: str | None = None, parameters: dict = None, **kwargs): 101 | from . import visualizer, loader 102 | vis = visualizer.Visualizer(data_loader_class=loader.PynbodyDataInMemory, 103 | data_loader_args=(snapshot,), 104 | parameters=parameters, 105 | **kwargs) 106 | vis.quantity_name = quantity 107 | return vis 108 | 109 | def load(filename: str, center: str = "none", particle: str = "gas", 110 | resolution: int = config.DEFAULT_RESOLUTION, tile: bool = False, 111 | sphere_radius: float | None = None, sphere_center: tuple[float, float, float] | None = None, 112 | render_mode: str = None) -> Visualizer: 113 | """ 114 | Load a simulation file (currently using pynbody) and return a visualizer object. 115 | 116 | Parameters 117 | ---------- 118 | 119 | filename : str 120 | Path to the simulation file. You can also specify test:// to generate a test dataset with N particles. 121 | 122 | center : str 123 | Centering method. Can be 'halo-', 'all', 'zoom' or 'none'. 124 | 125 | particle : str 126 | Particle type to visualize. Default is 'gas'; other options include 'dm' and 'star'. 127 | 128 | resolution : int 129 | Resolution of the visualization in pixels. 130 | 131 | sphere_radius : float | None 132 | If specified, load a sphere of particles with the given radius. Units are simulation units. 133 | 134 | sphere_center : tuple[float, float, float] | None 135 | If specified, load a sphere of particles with the given center. Units are simulation units. 136 | Must be a tuple of three floats (x, y, z). 137 | 138 | tile : bool 139 | If True, wrap and tile the simulation box using its periodicity. Default is False. 140 | 141 | render_mode : str 142 | Visualization mode. Should be one of 'univariate', 'bivariate', 'rgb', 'rgb-hdr', 'surface', etc. 143 | 144 | Returns 145 | ------- 146 | visualizer.Visualizer 147 | A visualizer object that can be used to render the simulation data. 148 | 149 | """ 150 | from . import visualizer, loader 151 | setup_logging() 152 | 153 | if "test://" in filename: 154 | loader_class = loader.TestDataLoader 155 | try: 156 | n_part = int(float(filename[7:])) # going through float allows scientific notation 157 | except ValueError: 158 | n_part = config.TEST_DATA_NUM_PARTICLES_DEFAULT 159 | logger.info(f"Using test data with {n_part} particles") 160 | loader_args = (n_part,) 161 | else: 162 | import pynbody 163 | loader_class = loader.PynbodyDataLoader 164 | if sphere_radius is not None: 165 | if sphere_center is not None: 166 | loader_args = (filename, center, particle, pynbody.filt.Sphere(sphere_radius, sphere_center)) 167 | else: 168 | loader_args = (filename, center, particle, pynbody.filt.Sphere(sphere_radius)) 169 | else: 170 | loader_args = (filename, center, particle) 171 | 172 | vis = visualizer.Visualizer(data_loader_class=loader_class, 173 | data_loader_args=loader_args, 174 | periodic_tiling=tile, 175 | render_resolution=resolution, 176 | render_mode=render_mode) 177 | 178 | return vis 179 | 180 | def test(nparticle=config.TEST_DATA_NUM_PARTICLES_DEFAULT, **kwargs) -> Visualizer: 181 | from . import visualizer, loader 182 | vis = visualizer.Visualizer(data_loader_class=loader.TestDataLoader, 183 | data_loader_args=(nparticle,), 184 | data_loader_kwargs={'with_cells': kwargs.pop('with_cells', False), 185 | 'periodic': kwargs.get('periodic_tiling', False)}, 186 | **kwargs) 187 | return vis 188 | 189 | 190 | 191 | 192 | _force_is_jupyter = False 193 | 194 | def is_jupyter(): 195 | """Determine whether the user is executing in a Jupyter Notebook / Lab. 196 | 197 | This has been pasted from an old version of wgpu.gui.auto.is_jupyter; the function was removed""" 198 | global _force_is_jupyter 199 | if _force_is_jupyter: 200 | return True 201 | from IPython import get_ipython 202 | try: 203 | ip = get_ipython() 204 | if ip is None: 205 | return False 206 | if ip.has_trait("kernel"): 207 | return True 208 | else: 209 | return False 210 | except NameError: 211 | return False 212 | 213 | def force_jupyter(): 214 | """Force the return from is_jupyter() to be True; used in testing""" 215 | global _force_is_jupyter 216 | _force_is_jupyter = True 217 | -------------------------------------------------------------------------------- /src/topsy/canvas/qt/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | import PySide6 # noqa: F401 (need to import to select the qt backend) 5 | from PySide6 import QtWidgets, QtGui, QtCore 6 | 7 | from rendercanvas.qt import RenderCanvas, loop 8 | 9 | from .colormap import ColorMapControls 10 | from .lineedit import MyLineEdit 11 | from .recording import RecordingSettingsDialog, VisualizationRecorderWithQtProgressbar 12 | from .. import VisualizerCanvasBase 13 | from ...drawreason import DrawReason 14 | 15 | import os 16 | import logging 17 | 18 | from typing import TYPE_CHECKING 19 | 20 | if TYPE_CHECKING: 21 | pass 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | logger.setLevel(logging.INFO) 26 | 27 | def _get_icon(name): 28 | this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | return QtGui.QIcon(os.path.join(this_dir, "icons", name)) 30 | 31 | 32 | class VisualizerCanvas(VisualizerCanvasBase, RenderCanvas): 33 | 34 | _all_instances = [] 35 | def __init__(self, **kwargs): 36 | super().__init__(**kwargs) 37 | self._all_instances.append(self) 38 | 39 | self._toolbar = QtWidgets.QToolBar() 40 | self._toolbar.setIconSize(QtCore.QSize(16, 16)) 41 | 42 | # setup toolbar to show text and icons 43 | self._toolbar.setToolButtonStyle(QtCore.Qt.ToolButtonStyle.ToolButtonTextBesideIcon) 44 | 45 | self._load_icons() 46 | 47 | self._record_action = QtGui.QAction(self._record_icon, "Record", self) 48 | self._record_action.triggered.connect(self.on_click_record) 49 | 50 | self._save_action = QtGui.QAction(self._save_icon, "Snapshot", self) 51 | self._save_action.triggered.connect(self.on_click_save) 52 | 53 | self._save_movie_action = QtGui.QAction(self._save_movie_icon, "Save mp4", self) 54 | self._save_movie_action.triggered.connect(self.on_click_save_movie) 55 | self._save_movie_action.setDisabled(True) 56 | 57 | self._save_script_action = QtGui.QAction(self._export_icon, "Save timestream", self) 58 | self._save_script_action.triggered.connect(self.on_click_save_script) 59 | self._save_script_action.setDisabled(True) 60 | 61 | self._load_script_action = QtGui.QAction(self._import_icon, "Load timestream", self) 62 | self._load_script_action.triggered.connect(self.on_click_load_script) 63 | 64 | self._link_action = QtGui.QAction(self._unlinked_icon, "Link to other windows", self) 65 | self._link_action.setIconText("Link") 66 | self._link_action.triggered.connect(self.on_click_link) 67 | 68 | 69 | 70 | 71 | self._cmap_icon = _get_icon("rgb.png") 72 | self._open_cmap = QtGui.QAction(self._cmap_icon, "Color", self) 73 | 74 | 75 | 76 | self._toolbar.addAction(self._load_script_action) 77 | self._toolbar.addAction(self._save_script_action) 78 | self._toolbar.addAction(self._record_action) 79 | self._toolbar.addAction(self._save_movie_action) 80 | 81 | self._toolbar.addSeparator() 82 | self._toolbar.addAction(self._save_action) 83 | self._toolbar.addSeparator() 84 | 85 | self._toolbar.addAction(self._open_cmap) 86 | 87 | self._toolbar.addSeparator() 88 | 89 | 90 | self._toolbar.addAction(self._link_action) 91 | self._recorder = None 92 | 93 | 94 | 95 | # now replace the wgpu layout with our own 96 | layout = self.layout() 97 | layout.removeWidget(self._subwidget) 98 | 99 | our_layout = PySide6.QtWidgets.QVBoxLayout() 100 | our_layout.addWidget(self._subwidget) 101 | our_layout.addWidget(self._toolbar) 102 | our_layout.setContentsMargins(0, 0, 0, 0) 103 | our_layout.setSpacing(0) 104 | 105 | self._toolbar.adjustSize() 106 | 107 | self._toolbar_update_timer = QtCore.QTimer(self) 108 | self._toolbar_update_timer.timeout.connect(self._update_toolbar) 109 | self._toolbar_update_timer.start(100) 110 | 111 | layout.addLayout(our_layout) 112 | self.call_later(0, self._prepare_colormap_pane) 113 | 114 | def _prepare_colormap_pane(self): 115 | if hasattr(self, '_cmap_connection'): 116 | self._open_cmap.disconnect(self._cmap_connection) 117 | 118 | self._colormap_controls = ColorMapControls(self) 119 | 120 | self._cmap_connection = self._open_cmap.triggered.connect(self._colormap_controls.open) 121 | 122 | 123 | def __del__(self): 124 | try: 125 | self._all_instances.remove(self) 126 | except ValueError: 127 | pass 128 | super().__del__() 129 | 130 | def _load_icons(self): 131 | self._record_icon = _get_icon("record.png") 132 | self._stop_icon = _get_icon("stop.png") 133 | self._save_icon = _get_icon("camera.png") 134 | self._linked_icon = _get_icon("linked.png") 135 | self._unlinked_icon = _get_icon("unlinked.png") 136 | self._save_movie_icon = _get_icon("movie.png") 137 | self._import_icon = _get_icon("load_script.png") 138 | self._export_icon = _get_icon("save_script.png") 139 | 140 | def on_click_record(self): 141 | 142 | if self._recorder is None or not self._recorder.recording: 143 | logger.info("Starting recorder") 144 | self._recorder = VisualizationRecorderWithQtProgressbar(self._visualizer, self) 145 | self._recorder.record() 146 | self._record_action.setIconText("Stop") 147 | self._record_action.setIcon(self._stop_icon) 148 | else: 149 | logger.info("Stopping recorder") 150 | self._recorder.stop() 151 | self._record_action.setIconText("Record") 152 | self._record_action.setIcon(self._record_icon) 153 | 154 | def on_click_save_movie(self): 155 | # show the options dialog first: 156 | dialog = RecordingSettingsDialog(self) 157 | dialog.exec() 158 | if dialog.result() == QtWidgets.QDialog.DialogCode.Accepted: 159 | fd = QtWidgets.QFileDialog(self) 160 | fname, _ = fd.getSaveFileName(self, "Save video", "", "MP4 (*.mp4)") 161 | if fname: 162 | logger.info("Saving video to %s", fname) 163 | self._recorder.save_mp4(fname, show_colorbar=dialog.show_colorbar, 164 | show_scalebar=dialog.show_scalebar, 165 | fps=dialog.fps, 166 | resolution=dialog.resolution, 167 | smooth=dialog.smooth, 168 | set_vmin_vmax=dialog.set_vmin_vmax, 169 | set_quantity=dialog.set_quantity) 170 | QtGui.QDesktopServices.openUrl(QtCore.QUrl.fromLocalFile(fname)) 171 | 172 | def on_click_save(self): 173 | fd = QtWidgets.QFileDialog(self) 174 | fname, _ = fd.getSaveFileName(self, "Save snapshot", "", "PNG (*.png);; PDF (*.pdf);; numpy (*.npy)") 175 | if fname: 176 | logger.info("Saving snapshot to %s", fname) 177 | self._visualizer.save(fname) 178 | QtGui.QDesktopServices.openUrl(QtCore.QUrl.fromLocalFile(fname)) 179 | 180 | def on_click_save_script(self): 181 | fd = QtWidgets.QFileDialog(self) 182 | fname, _ = fd.getSaveFileName(self, "Save camera movements", "", "Python Pickle (*.pickle)") 183 | if fname: 184 | logger.info("Saving timestream to %s", fname) 185 | self._recorder.save_timestream(fname) 186 | 187 | def on_click_load_script(self): 188 | fd = QtWidgets.QFileDialog(self) 189 | fname, _ = fd.getOpenFileName(self, "Load camera movements", "", "Python Pickle (*.pickle)") 190 | if fname: 191 | logger.info("Loading timestream from %s", fname) 192 | self._recorder = VisualizationRecorderWithQtProgressbar(self._visualizer, self) 193 | self._recorder.load_timestream(fname) 194 | 195 | 196 | def on_click_link(self): 197 | if self._visualizer.is_synchronizing(): 198 | logger.info("Stop synchronizing") 199 | self._visualizer.stop_synchronizing() 200 | else: 201 | logger.info("Start synchronizing") 202 | from ... import view_synchronizer 203 | synchronizer = view_synchronizer.ViewSynchronizer() 204 | for instance in self._all_instances: 205 | synchronizer.add_view(instance._visualizer) 206 | 207 | def _update_toolbar(self): 208 | if self._recorder is not None or len(self._all_instances)<2: 209 | self._link_action.setDisabled(True) 210 | else: 211 | self._link_action.setDisabled(False) 212 | if self._visualizer.is_synchronizing(): 213 | self._link_action.setIcon(self._linked_icon) 214 | self._link_action.setIconText("Unlink") 215 | else: 216 | self._link_action.setIcon(self._unlinked_icon) 217 | self._link_action.setIconText("Link") 218 | if self._recorder is not None and not self._recorder.recording: 219 | self._save_movie_action.setDisabled(False) 220 | self._save_script_action.setDisabled(False) 221 | else: 222 | self._save_movie_action.setDisabled(True) 223 | self._save_script_action.setDisabled(True) 224 | 225 | 226 | 227 | 228 | def request_draw(self, function=None): 229 | # As a side effect, wgpu gui layer stores our function call, to enable it to be 230 | # repainted later. But we want to distinguish such repaints and handle them 231 | # differently, so we need to replace the function with our own 232 | call_count = 0 233 | def function_wrapper(): 234 | nonlocal call_count 235 | if call_count == 0: 236 | function() 237 | else: 238 | # we have been cached! 239 | self._visualizer.draw(DrawReason.PRESENTATION_CHANGE) 240 | call_count += 1 241 | 242 | super().request_draw(function_wrapper) 243 | 244 | @classmethod 245 | def call_later(cls, delay, fn, *args): 246 | loop.call_later(delay, fn, *args) 247 | -------------------------------------------------------------------------------- /tests/test_progression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from topsy import progressive_render, config 5 | from topsy.drawreason import DrawReason 6 | 7 | def _get_single_block(blocks): 8 | """Helper function to convert from list of blocks to single block""" 9 | assert len(blocks)==2 10 | assert len(blocks[0])==1 11 | assert len(blocks[1])==1 12 | return blocks[0][0], blocks[1][0] 13 | 14 | def test_initial_recommendations(): 15 | # Test the initial recommendation for a small number of particles 16 | render_progression = progressive_render.RenderProgression(config.INITIAL_PARTICLES_TO_RENDER//2) 17 | render_progression.start_frame(DrawReason.INITIAL_UPDATE) 18 | assert _get_single_block(render_progression.get_block(0.0)) == (0, config.INITIAL_PARTICLES_TO_RENDER//2) 19 | 20 | # Test the initial recommendation for a large number of particles 21 | render_progression = progressive_render.RenderProgression(config.INITIAL_PARTICLES_TO_RENDER*2) 22 | render_progression.start_frame(DrawReason.INITIAL_UPDATE) 23 | assert _get_single_block(render_progression.get_block(0.0)) == (0, config.INITIAL_PARTICLES_TO_RENDER) 24 | 25 | def test_export_recommendations(): 26 | render_progression = progressive_render.RenderProgression(config.INITIAL_PARTICLES_TO_RENDER * 2) 27 | render_progression.start_frame(DrawReason.EXPORT) 28 | assert _get_single_block(render_progression.get_block(0.0)) == (0, config.INITIAL_PARTICLES_TO_RENDER*2) 29 | render_progression.end_block(0.1) 30 | assert render_progression.get_block(1.0) is None 31 | 32 | def test_progression(): 33 | # Test the progression of recommendations 34 | render_progression = progressive_render.RenderProgression(1000, 100) 35 | render_progression.start_frame(DrawReason.CHANGE) 36 | 37 | # Simulate rendering a block of particles 38 | start_index, num_to_render = _get_single_block(render_progression.get_block(0.0)) 39 | assert start_index == 0 40 | assert num_to_render == 100 41 | 42 | # Simulate ending the block and reporting time taken 43 | render_progression.end_block(0.5/config.TARGET_FPS) 44 | 45 | # Check the next recommendation. We're half way through. 46 | start_index, num_to_render = _get_single_block(render_progression.get_block(0.5/config.TARGET_FPS)) 47 | assert start_index == 100 48 | assert num_to_render == 50 # hasn't updated the expected rendering number, just looks at time remaining 49 | 50 | render_progression.end_block(1./config.TARGET_FPS) 51 | 52 | # Check the end of the frame 53 | assert render_progression.get_block(1./config.TARGET_FPS) is None 54 | 55 | # We've rendered 150 particles. 56 | assert render_progression.end_frame_get_scalefactor() == 1000./150 57 | 58 | 59 | def test_timeout_and_progression(): 60 | # Test the timeout and progression of recommendations 61 | render_progression = progressive_render.RenderProgression(1000, 100) 62 | render_progression.start_frame(DrawReason.CHANGE) 63 | 64 | # Simulate a long time elapsed 65 | block = _get_single_block(render_progression.get_block(0.0)) 66 | assert block is not None 67 | 68 | render_progression.end_block(1.0) # far too long! 69 | # Check that the next recommendation is None 70 | block = render_progression.get_block(1.0) 71 | assert block is None 72 | 73 | sf = render_progression.end_frame_get_scalefactor() 74 | assert sf == 10.0 75 | 76 | assert render_progression.needs_refine() 77 | 78 | render_progression.start_frame(DrawReason.REFINE) 79 | start, num = _get_single_block(render_progression.get_block(0.0)) 80 | assert start == 100 81 | assert num == int(100/config.TARGET_FPS) # took one second to render 100 particles, so recommendation should be this 82 | 83 | 84 | def test_always_one_block(): 85 | render_progression = progressive_render.RenderProgression(1000, 100) 86 | render_progression.start_frame(DrawReason.CHANGE) 87 | 88 | block = _get_single_block(render_progression.get_block(1.0)) # simulate a long time elapsed, but should still render at least one block 89 | assert block is not None 90 | 91 | def test_no_render_on_presentation_change(): 92 | render_progression = progressive_render.RenderProgression(1000, 100) 93 | render_progression.start_frame(DrawReason.CHANGE) 94 | 95 | # Simulate rendering the first frame 96 | t = 0.0 97 | while (block := render_progression.get_block(t)) is not None: 98 | t+=1e-5 99 | render_progression.end_block(t) 100 | 101 | render_progression.end_frame_get_scalefactor() 102 | assert not render_progression.needs_refine() 103 | 104 | render_progression.start_frame(DrawReason.PRESENTATION_CHANGE) 105 | 106 | block = render_progression.get_block(0.0) 107 | assert block is None 108 | 109 | render_progression.end_frame_get_scalefactor() 110 | assert not render_progression.needs_refine() 111 | 112 | def test_no_frame_exception(): 113 | render_progression = progressive_render.RenderProgression(1000, 100) 114 | with pytest.raises(RuntimeError): 115 | _get_single_block(render_progression.get_block(0.0)) 116 | 117 | def test_export(): 118 | """Test that export always recommends the full resolution""" 119 | 120 | render_progression = progressive_render.RenderProgression(1000, 100) 121 | render_progression.start_frame(DrawReason.EXPORT) 122 | block = _get_single_block(render_progression.get_block(0.0)) 123 | assert block == (0, 1000) 124 | 125 | def test_always_one_particle(): 126 | """Test that always at least one particle is recommended (even if it takes a long time)""" 127 | 128 | render_progression = progressive_render.RenderProgression(1000, 3) 129 | render_progression.start_frame(DrawReason.CHANGE) 130 | 131 | # Simulate a long time elapsed 132 | block = _get_single_block(render_progression.get_block(0.0)) 133 | assert block is not None 134 | 135 | render_progression.end_block(1.0) # far too long! 136 | 137 | assert render_progression.get_block(1.0) is None # end of this frame 138 | render_progression.end_frame_get_scalefactor() 139 | assert render_progression.needs_refine() 140 | 141 | # Check that the refinement consists of at least one particle, even though technically 142 | # we would have rounded this down to zero particles 143 | render_progression.start_frame(DrawReason.REFINE) 144 | block = _get_single_block(render_progression.get_block(1.0)) 145 | assert block == (3,1) 146 | 147 | @pytest.fixture 148 | def cell_progressive_render(num_particles=100000, num_side=10, num_part_to_take = 100): 149 | np.random.seed(1337) 150 | pos = np.random.uniform(0.0, 1.0, (num_particles, 3)) 151 | 152 | cell_layout, order = progressive_render.CellLayout.from_positions(pos, 0.0, 1.0, num_side) 153 | pos = pos[order] 154 | 155 | render_progression = progressive_render.RenderProgressionWithCells(cell_layout, len(pos), num_part_to_take) 156 | 157 | return render_progression, pos 158 | 159 | def test_blocks_with_layout(cell_progressive_render): 160 | 161 | render_progression, pos = cell_progressive_render 162 | cell_layout = render_progression._cell_layout 163 | 164 | total_particles = 0 165 | 166 | rendered = np.zeros(len(pos), dtype=np.int32) 167 | 168 | render_progression.start_frame(DrawReason.CHANGE) 169 | first_render = True 170 | 171 | while True: 172 | 173 | block = render_progression.get_block(0.0) 174 | 175 | for start, length in zip(*block): 176 | # find which cell this block belongs to 177 | cell_index = cell_layout.cell_index_from_offset(start) 178 | assert length!=0 # should not return zero length blocks 179 | cell_index_at_end = cell_layout.cell_index_from_offset(start+length-1) 180 | assert cell_index == cell_index_at_end 181 | total_particles+=length 182 | rendered[start:start+length]+=1 183 | 184 | if first_render: 185 | assert total_particles > 95 and total_particles < 105 186 | 187 | render_progression.end_block(0.0001) 188 | render_progression.end_frame_get_scalefactor() 189 | 190 | if render_progression.needs_refine(): 191 | first_render = False 192 | render_progression.start_frame(DrawReason.REFINE) 193 | else: 194 | break 195 | 196 | assert (rendered==1).all() 197 | 198 | # check that render_progression returns None at end of frame 199 | 200 | render_progression.start_frame(DrawReason.CHANGE) 201 | npart = 0 202 | while (block := render_progression.get_block(0.0)): 203 | starts, lens = block 204 | npart+=lens.sum() 205 | render_progression.end_block(0.0) 206 | assert npart == len(pos) 207 | 208 | def test_spatial_limits(cell_progressive_render): 209 | render_progression, pos = cell_progressive_render 210 | 211 | render_progression.select_sphere((0.5, 0.5, 0.5), 0.1) 212 | 213 | render_progression.start_frame(DrawReason.CHANGE) 214 | rendered = np.zeros(len(pos), dtype=np.int32) 215 | while (block:=render_progression.get_block(0.0)): 216 | for start, length in zip(*block): 217 | rendered[start:start+length]+=1 218 | 219 | render_progression.end_block(0.0) 220 | 221 | assert rendered.max() == 1 222 | 223 | r = np.linalg.norm(pos-0.5, axis=1) 224 | rendered_r = r[rendered==1] 225 | unrendered_r = r[rendered==0] 226 | 227 | assert (rendered_r<0.4).all() 228 | assert (unrendered_r>0.1).all() 229 | 230 | def test_export_very_large(): 231 | num_renders = 5 232 | render_progression = progressive_render.RenderProgression(config.MAX_PARTICLES_PER_EXPORT_RENDERCALL * num_renders) 233 | render_progression.start_frame(DrawReason.EXPORT) 234 | 235 | for blocknum in range(num_renders): 236 | # pretend we have a crazy long-running render. We should render the whole thing, but in a series 237 | # of blocks 238 | block = render_progression.get_block(100.0*blocknum) 239 | assert block is not None 240 | assert block[0][0] == config.MAX_PARTICLES_PER_EXPORT_RENDERCALL * blocknum 241 | assert block[1][0] == config.MAX_PARTICLES_PER_EXPORT_RENDERCALL 242 | render_progression.end_block(100.0 * (blocknum + 1)) 243 | 244 | assert render_progression.get_block(100.0*num_renders) is None # finished now! 245 | 246 | clear = render_progression.start_frame(DrawReason.EXPORT) 247 | 248 | assert clear 249 | 250 | -------------------------------------------------------------------------------- /src/topsy/canvas/qt/colormap.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Dict, Union, Callable 4 | 5 | from PySide6 import QtWidgets, QtCore 6 | from superqt import QLabeledDoubleRangeSlider, QLabeledDoubleSlider 7 | from rendercanvas import BaseRenderCanvas 8 | 9 | from .lineedit import MyLineEdit 10 | 11 | from ...colormap.ui import LayoutSpec, GenericController, ControlSpec, UnifiedColorMapController 12 | 13 | import math 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | class MapControlsBase(QtWidgets.QDialog): 20 | def __init__(self, parent: BaseRenderCanvas): 21 | self._parent = parent 22 | super().__init__() 23 | 24 | self.setWindowTitle("Color controls") 25 | self.setWindowFlags(QtCore.Qt.WindowType.Popup | QtCore.Qt.WindowType.FramelessWindowHint) 26 | self.resize(400, 0) 27 | 28 | def open(self): 29 | action_rect = self._parent._toolbar.actionGeometry(self._parent._open_cmap) # EEK! 30 | popoverPosition = self._parent._toolbar.mapToGlobal(action_rect.topLeft()) 31 | super().show() 32 | self.move(popoverPosition - QtCore.QPoint(self.width()//2, self.height())) 33 | 34 | 35 | class QLabeledDoubleRangeSliderWithAutoscale(QLabeledDoubleRangeSlider): 36 | def __init__(self, *args: Any, **kwargs: Any) -> None: 37 | self._scale_exponent = 0 38 | super().__init__(*args, **kwargs) 39 | 40 | def _scale_float(self, value: float) -> float: 41 | return value / 10**self._scale_exponent 42 | 43 | def _unscale_float(self, value: float) -> float: 44 | return value * 10**self._scale_exponent 45 | 46 | def _repr_value_to_scale_exponent(self, value: float) -> int: 47 | if value == 0.0: 48 | return 0 49 | try: 50 | exponent = math.floor(math.log10(abs(value))) 51 | except ValueError: 52 | return 0 53 | if exponent < -2 or exponent > 2: 54 | return exponent 55 | else: 56 | return 0 57 | 58 | def setRange(self, vmin: float, vmax: float) -> None: 59 | if vmin == 0.0 and vmax == 0.0: 60 | repr_val = 1.0 61 | elif vmin==0.0: 62 | repr_val = vmax 63 | elif vmax==0.0: 64 | repr_val = vmin 65 | else: 66 | repr_val = max(abs(vmin), abs(vmax)) 67 | 68 | self._scale_exponent = self._repr_value_to_scale_exponent(repr_val) 69 | scaled_min = self._scale_float(vmin) 70 | scaled_max = self._scale_float(vmax) 71 | 72 | super().setRange(scaled_min, scaled_max) 73 | 74 | def setValue(self, value: tuple[float, float]) -> None: 75 | scaled_value = (self._scale_float(value[0]), self._scale_float(value[1])) 76 | super().setValue(scaled_value) 77 | 78 | def value(self) -> tuple[float, float]: 79 | scaled_value = super().value() 80 | return (self._unscale_float(scaled_value[0]), self._unscale_float(scaled_value[1])) 81 | 82 | 83 | 84 | class ColorMapControls(QtWidgets.QDialog): 85 | def __init__(self, canvas: BaseRenderCanvas): 86 | super().__init__(canvas) 87 | self.setWindowTitle("Color controls") 88 | self.setWindowFlags(QtCore.Qt.WindowType.Popup 89 | | QtCore.Qt.WindowType.FramelessWindowHint) 90 | 91 | self.controller: GenericController = UnifiedColorMapController(canvas._visualizer, self._refresh_ui) 92 | 93 | # build UI 94 | self._widgets: Dict[str, QtWidgets.QWidget] = {} 95 | root_spec = self.controller.get_layout() 96 | self._layout = self._build_layout(root_spec) 97 | self.setLayout(self._layout) 98 | 99 | def open(self): 100 | # position next to toolbar 101 | self.controller.refresh_ui() 102 | super().show() 103 | self._update_screen_size_and_position() 104 | 105 | def _update_screen_size_and_position(self): 106 | self.resize(400, 0) 107 | self.updateGeometry() 108 | action_rect = self.parent()._toolbar.actionGeometry( 109 | self.parent()._open_cmap 110 | ) 111 | pos = self.parent()._toolbar.mapToGlobal(action_rect.topLeft()) 112 | self.move(pos - QtCore.QPoint(self.width()//2, self.height())) 113 | 114 | def _build_layout(self, spec: LayoutSpec) -> QtWidgets.QLayout: 115 | if spec.type == "vbox": 116 | layout = QtWidgets.QVBoxLayout() 117 | else: 118 | layout = QtWidgets.QHBoxLayout() 119 | 120 | for child in spec.children: 121 | if isinstance(child, ControlSpec): 122 | label_layout = self._make_widget(child) 123 | self._widgets[child.name] = label_layout 124 | 125 | if child.label is not None and child.type != "button" and child.type != "checkbox": 126 | w_inner = label_layout 127 | label_layout = QtWidgets.QHBoxLayout() 128 | label_layout.addWidget(QtWidgets.QLabel(child.label)) 129 | label_layout.addWidget(w_inner) 130 | label_layout.setAlignment(QtCore.Qt.AlignmentFlag.AlignRight) 131 | layout.addLayout(label_layout) 132 | else: 133 | layout.addWidget(label_layout) 134 | else: 135 | layout.addLayout(self._build_layout(child)) 136 | return layout 137 | 138 | def _make_widget(self, spec: ControlSpec) -> QtWidgets.QWidget: 139 | if spec.type == "combo" or spec.type == "combo-edit": 140 | w = QtWidgets.QComboBox() 141 | w.setEditable(spec.name == "quantity") 142 | 143 | w.addItems(spec.options or []) 144 | w.setCurrentText(spec.value) 145 | edited_callback = lambda: self._on_changed(spec.callback, w.currentText()) 146 | w.currentIndexChanged.connect(edited_callback) 147 | if spec.type == "combo-edit": 148 | w.setLineEdit(MyLineEdit()) 149 | w.lineEdit().editingFinished.connect(edited_callback) 150 | 151 | elif spec.type == "checkbox": 152 | w = QtWidgets.QCheckBox(spec.label or "") 153 | w.setChecked(bool(spec.value)) 154 | w.stateChanged.connect( 155 | lambda st, cb=spec.callback: self._on_changed(cb, bool(st)) 156 | ) 157 | elif spec.type == "range_slider": 158 | w = QLabeledDoubleRangeSliderWithAutoscale() 159 | w.setRange(*(spec.range or (0.0, 1.0))) 160 | w.setValue(tuple(spec.value)) 161 | w.valueChanged.connect( 162 | lambda _, cb=spec.callback, widget=w: 163 | self._on_changed(cb, widget.value()) 164 | ) 165 | elif spec.type == "slider": 166 | w = QLabeledDoubleSlider() 167 | w.setRange(*(spec.range or (0.0, 1.0))) 168 | w.setValue(spec.value) 169 | w.valueChanged.connect( 170 | lambda _, cb=spec.callback, widget=w: 171 | self._on_changed(cb, widget.value()) 172 | ) 173 | elif spec.type == "button": 174 | w = QtWidgets.QPushButton(spec.label or "") 175 | w.setStyleSheet("color: black;") # unclear why this is necessary 176 | w.pressed.connect(lambda cb=spec.callback: self._on_changed(cb, None)) 177 | elif spec.type == "color_picker": 178 | w = QtWidgets.QPushButton() 179 | w.setText("") 180 | w.setStyleSheet(f"background-color: {spec.value};") 181 | original_color = spec.value 182 | def pick_color(): 183 | dialog = QtWidgets.QColorDialog(self) 184 | dialog.setCurrentColor(spec.value) 185 | dialog.setWindowTitle(spec.name) 186 | dialog.setOption(QtWidgets.QColorDialog.ColorDialogOption.ShowAlphaChannel, False) 187 | def on_color_changed(color): 188 | if color.isValid(): 189 | w.setStyleSheet(f"background-color: {color.name()};") 190 | self._on_changed(spec.callback, color.name()) 191 | dialog.currentColorChanged.connect(on_color_changed) 192 | if not dialog.exec(): 193 | w.setStyleSheet(f"background-color: {original_color};") 194 | self._on_changed(spec.callback, original_color) 195 | 196 | w.clicked.connect(pick_color) 197 | else: 198 | w = QtWidgets.QLabel(f"Unknown control {spec.name}") 199 | 200 | return w 201 | 202 | def _on_changed(self, callback: Callable[[Any], None], value: Any): 203 | callback(value) 204 | 205 | @classmethod 206 | def _clear_layout(cls, layout): 207 | while layout.count(): 208 | item = layout.takeAt(0) 209 | widget = item.widget() 210 | child_layout = item.layout() 211 | if widget is not None: 212 | widget.setParent(None) 213 | widget.deleteLater() 214 | elif child_layout is not None: 215 | cls._clear_layout(child_layout) 216 | child_layout.setParent(None) 217 | 218 | def _rebuild_ui(self, root: LayoutSpec): 219 | self._clear_layout(self._layout) 220 | self._widgets = {} 221 | QtWidgets.QWidget().setLayout(self._layout) 222 | self._layout = self._build_layout(root) 223 | self.setLayout(self._layout) 224 | self._update_screen_size_and_position() 225 | 226 | def _update_ui(self, root: LayoutSpec): 227 | if isinstance(root, ControlSpec): 228 | w = self._widgets.get(root.name) 229 | if not w: 230 | return 231 | if root.type == "combo" or root.type == 'combo-edit': 232 | w.blockSignals(True) 233 | w.setCurrentText(root.value) 234 | w.blockSignals(False) 235 | elif root.type == "checkbox": 236 | w.blockSignals(True) 237 | w.setChecked(root.value) 238 | w.blockSignals(False) 239 | elif root.type == "range_slider": 240 | w.blockSignals(True) 241 | w.setRange(*(root.range or (0, 1))) 242 | w.setValue(tuple(root.value)) 243 | w.blockSignals(False) 244 | elif root.type == "slider": 245 | w.blockSignals(True) 246 | w.setRange(*(root.range or (0, 1))) 247 | w.setValue(root.value) 248 | w.blockSignals(False) 249 | else: 250 | for c in root.children: 251 | self._update_ui(c) 252 | 253 | def _refresh_ui(self, root: LayoutSpec, new_widgets: bool): 254 | if new_widgets: 255 | self._rebuild_ui(root) 256 | else: 257 | self._update_ui(root) 258 | 259 | --------------------------------------------------------------------------------