├── tests ├── __init__.py ├── data │ ├── air.zarr │ │ ├── .zgroup │ │ ├── lat │ │ │ ├── 0 │ │ │ ├── .zattrs │ │ │ └── .zarray │ │ ├── lon │ │ │ ├── 0 │ │ │ ├── .zattrs │ │ │ └── .zarray │ │ ├── time │ │ │ ├── 0 │ │ │ ├── .zattrs │ │ │ └── .zarray │ │ ├── air │ │ │ ├── 0.0.0 │ │ │ ├── .zarray │ │ │ └── .zattrs │ │ ├── .zattrs │ │ └── .zmetadata │ ├── air.nc │ ├── gapminder.parquet │ ├── N_19781028_conc_v3.0.png │ ├── 1978_10_Oct_N_19781026_conc_v3.0.png │ ├── 1978_10_Oct_N_19781027_conc_v3.0.png │ └── gapminder.csv ├── test_polars.py ├── test_settings.py ├── test_pandas.py ├── test_xarray.py ├── test_renderers.py ├── conftest.py ├── test_models.py ├── test_wrappers.py ├── test_core.py ├── test_serializers.py └── test_streams.py ├── docs ├── index.md ├── assets │ ├── logo.png │ └── logo_bw.png ├── api_reference │ ├── core.md │ ├── models.md │ ├── streams.md │ ├── wrappers.md │ ├── renderers.md │ └── settings.md ├── stylesheets │ └── extra.css ├── example_recipes │ ├── air_temperature.md │ ├── sine_wave.md │ ├── sea_ice.md │ ├── nmme_forecast.md │ ├── oisst_globe.md │ ├── gender_gapminder.md │ ├── nice_orbit.md │ ├── stream_code.md │ ├── co2_timeseries.md │ └── temperature_anomaly.md ├── package_design.md ├── supported_formats.md └── how_do_i.md ├── streamjoy ├── .gitignore ├── cli.py ├── pandas.py ├── xarray.py ├── polars.py ├── __init__.py ├── settings.py ├── core.py ├── models.py ├── renderers.py ├── wrappers.py ├── ui.py ├── _utils.py └── serializers.py ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── build.yml │ ├── documentation.yml │ └── release.yml ├── HOWTORELEASE.md ├── LICENSE ├── CONTRIBUTING.md ├── HOWTOCONTRIBUTE.md ├── .gitignore ├── pyproject.toml ├── mkdocs.yml └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --8<-- "README.md" 2 | -------------------------------------------------------------------------------- /tests/data/air.zarr/.zgroup: -------------------------------------------------------------------------------- 1 | { 2 | "zarr_format": 2 3 | } -------------------------------------------------------------------------------- /tests/data/air.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.nc -------------------------------------------------------------------------------- /docs/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/docs/assets/logo.png -------------------------------------------------------------------------------- /docs/assets/logo_bw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/docs/assets/logo_bw.png -------------------------------------------------------------------------------- /tests/data/air.zarr/lat/0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/lat/0 -------------------------------------------------------------------------------- /tests/data/air.zarr/lon/0: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/lon/0 -------------------------------------------------------------------------------- /tests/data/air.zarr/time/0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/time/0 -------------------------------------------------------------------------------- /tests/data/air.zarr/air/0.0.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/air/0.0.0 -------------------------------------------------------------------------------- /tests/data/gapminder.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/gapminder.parquet -------------------------------------------------------------------------------- /tests/data/N_19781028_conc_v3.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/N_19781028_conc_v3.0.png -------------------------------------------------------------------------------- /docs/api_reference/core.md: -------------------------------------------------------------------------------- 1 | # Core 2 | 3 | ::: streamjoy.core 4 | options: 5 | show_root_heading: false 6 | show_source: true 7 | -------------------------------------------------------------------------------- /tests/data/1978_10_Oct_N_19781026_conc_v3.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/1978_10_Oct_N_19781026_conc_v3.0.png -------------------------------------------------------------------------------- /tests/data/1978_10_Oct_N_19781027_conc_v3.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/1978_10_Oct_N_19781027_conc_v3.0.png -------------------------------------------------------------------------------- /docs/api_reference/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | ::: streamjoy.models 4 | options: 5 | show_root_heading: false 6 | show_source: true 7 | -------------------------------------------------------------------------------- /docs/api_reference/streams.md: -------------------------------------------------------------------------------- 1 | # Streams 2 | 3 | ::: streamjoy.streams 4 | options: 5 | show_root_heading: false 6 | show_source: true 7 | -------------------------------------------------------------------------------- /docs/api_reference/wrappers.md: -------------------------------------------------------------------------------- 1 | # Wrappers 2 | 3 | ::: streamjoy.wrappers 4 | options: 5 | show_root_heading: false 6 | show_source: true 7 | -------------------------------------------------------------------------------- /docs/api_reference/renderers.md: -------------------------------------------------------------------------------- 1 | # Renderers 2 | 3 | ::: streamjoy.renderers 4 | options: 5 | show_root_heading: false 6 | show_source: true 7 | -------------------------------------------------------------------------------- /streamjoy/.gitignore: 
-------------------------------------------------------------------------------- 1 | .DS_Store 2 | .vscode 3 | build 4 | dist 5 | docs/build/* 6 | __pycache__ 7 | *.egg* 8 | *.ipynb_checkpoints/ 9 | *.jpg 10 | *.png 11 | *.gif 12 | *.mp4 13 | *.ipynb 14 | -------------------------------------------------------------------------------- /tests/data/air.zarr/lat/.zattrs: -------------------------------------------------------------------------------- 1 | { 2 | "_ARRAY_DIMENSIONS": [ 3 | "lat" 4 | ], 5 | "axis": "Y", 6 | "long_name": "Latitude", 7 | "standard_name": "latitude", 8 | "units": "degrees_north" 9 | } -------------------------------------------------------------------------------- /tests/data/air.zarr/lon/.zattrs: -------------------------------------------------------------------------------- 1 | { 2 | "_ARRAY_DIMENSIONS": [ 3 | "lon" 4 | ], 5 | "axis": "X", 6 | "long_name": "Longitude", 7 | "standard_name": "longitude", 8 | "units": "degrees_east" 9 | } -------------------------------------------------------------------------------- /tests/data/air.zarr/time/.zattrs: -------------------------------------------------------------------------------- 1 | { 2 | "_ARRAY_DIMENSIONS": [ 3 | "time" 4 | ], 5 | "calendar": "standard", 6 | "long_name": "Time", 7 | "standard_name": "time", 8 | "units": "hours since 1800-01-01" 9 | } -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | -------------------------------------------------------------------------------- /tests/data/air.zarr/.zattrs: -------------------------------------------------------------------------------- 1 | { 2 | "Conventions": "COARDS", 3 | "description": "Data is from NMC initialized reanalysis\n(4x/day). These are the 0.9950 sigma level values.", 4 | "platform": "Model", 5 | "references": "http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html", 6 | "title": "4x daily NMC reanalysis (1948)" 7 | } -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Expectation / Proposal 4 | 5 | # Traceback / Example 6 | 7 | - [ ] I would like to [help contribute](https://github.com/ahuang11/streamjoy/blob/main/HOWTOCONTRIBUTE.md) a pull request to resolve this! 8 | -------------------------------------------------------------------------------- /HOWTORELEASE.md: -------------------------------------------------------------------------------- 1 | 1. Increment version in `package_name/__init__.py` 2 | 2. Push changes 3 | 3. Create a new tag 4 | 5 | image 6 | 7 | image 8 | -------------------------------------------------------------------------------- /tests/data/air.zarr/lat/.zarray: -------------------------------------------------------------------------------- 1 | { 2 | "chunks": [ 3 | 25 4 | ], 5 | "compressor": { 6 | "blocksize": 0, 7 | "clevel": 5, 8 | "cname": "lz4", 9 | "id": "blosc", 10 | "shuffle": 1 11 | }, 12 | "dtype": " 4 | 5 | 6 | 7 | Super barebones example of rendering air temperature data from xarray. 8 | 9 | Highlights: 10 | 11 | - Imports `streamjoy.xarray` to use the `stream` accessor. 
12 | - Passes the `uri` to `stream` as the first argument to save the animation to disk. 13 | 14 | ```python hl_lines="2 5" 15 | import xarray as xr 16 | import streamjoy.xarray 17 | 18 | if __name__ == "__main__": 19 | ds = xr.tutorial.open_dataset("air_temperature") 20 | ds.streamjoy("air_temperature.mp4") 21 | ``` 22 | -------------------------------------------------------------------------------- /streamjoy/xarray.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | try: 4 | import xarray as xr 5 | except Exception as exc: 6 | raise ImportError( 7 | "Could not patch plotting API onto xarray. xarray could not be imported." 8 | ) from exc 9 | 10 | from .core import stream 11 | 12 | 13 | def patch(name="streamjoy"): 14 | class StreamAccessor: 15 | def __init__(self, resources: xr.Dataset | xr.DataArray): 16 | self._resources = resources 17 | 18 | def __call__(self, *args, **kwargs): 19 | return stream(self._resources, *args, **kwargs) 20 | 21 | StreamAccessor.__doc__ = stream.__doc__ 22 | 23 | xr.register_dataset_accessor(name)(StreamAccessor) 24 | xr.register_dataarray_accessor(name)(StreamAccessor) 25 | 26 | 27 | patch() 28 | -------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | from imageio.v3 import improps 2 | 3 | from streamjoy import config, stream 4 | 5 | 6 | class TestConfig: 7 | def test_update_defaults(self, df): 8 | config["max_frames"] = 3 9 | config["ending_pause"] = 0 10 | 11 | sj = stream(df) 12 | assert sj.max_frames == 3 13 | assert sj.ending_pause == 0 14 | buf = sj.write() 15 | props = improps(buf) 16 | assert props.n_images == 3 17 | 18 | def test_override_defaults(self, df): 19 | config["max_frames"] = 3 20 | config["ending_pause"] = 0 21 | 22 | sj = stream(df, max_frames=5) 23 | assert sj.max_frames == 5 24 | assert sj.ending_pause == 0 25 | buf = sj.write() 26 | props = improps(buf) 27 | assert props.n_images == 5 28 | -------------------------------------------------------------------------------- /streamjoy/polars.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | try: 4 | import polars as pl 5 | except Exception as exc: 6 | raise ImportError( 7 | "Could not patch streamjoy API onto polars. Polars could not be imported." 
8 | ) from exc 9 | 10 | from .core import stream 11 | 12 | 13 | def patch(name="streamjoy"): 14 | class StreamAccessor: 15 | def __init__(self, resources: pl.DataFrame | pl.Series | pl.LazyFrame): 16 | self._resources = resources 17 | 18 | def __call__(self, *args, **kwargs): 19 | return stream(self._resources, *args, **kwargs) 20 | 21 | pl.api.register_dataframe_namespace(name)(StreamAccessor) 22 | pl.api.register_series_namespace(name)(StreamAccessor) 23 | pl.api.register_lazyframe_namespace(name)(StreamAccessor) 24 | 25 | 26 | patch() 27 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | test: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python_version: ['3.10'] 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: ${{ matrix.python_version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install hatch 23 | hatch env create 24 | # - name: Lint and typecheck 25 | # run: | 26 | # hatch run lint-check 27 | - name: Test 28 | run: | 29 | hatch run test-cov-xml 30 | - uses: codecov/codecov-action@v4 31 | with: 32 | token: ${{ secrets.CODECOV_TOKEN }} 33 | fail_ci_if_error: true 34 | verbose: true 35 | -------------------------------------------------------------------------------- /tests/test_pandas.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import streamjoy.pandas # noqa: F401 4 | 5 | 6 | class TestPandas: 7 | def test_dataframe(self, df): 8 | stream = df.streamjoy(groupby="Country") 9 | assert stream.renderer_kwargs == { 10 | "groupby": "Country", 11 | "x": "Year", 12 | "y": "fertility", 13 | "xlabel": "Year", 14 | "ylabel": "Fertility", 15 | } 16 | assert isinstance(stream.write(), BytesIO) 17 | 18 | def test_series(self, df): 19 | stream = df.set_index("Year")[["Country", "life"]].streamjoy(groupby="Country") 20 | assert stream.renderer_kwargs == { 21 | "groupby": "Country", 22 | "x": "Year", 23 | "y": "life", 24 | "xlabel": "Year", 25 | "ylabel": "Life", 26 | } 27 | assert isinstance(stream.write(), BytesIO) 28 | -------------------------------------------------------------------------------- /docs/example_recipes/sine_wave.md: -------------------------------------------------------------------------------- 1 | # Sine wave 2 | 3 | 4 | 5 | Example of how to use `stream` alongside a custom `renderer` to create a sine wave animation. 6 | 7 | Highlights: 8 | 9 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure. 10 | - Uses a custom `renderer` function to create each frame of the animation. 
11 | 12 | ```python hl_lines="5-6 14" 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | from streamjoy import stream, wrap_matplotlib 16 | 17 | @wrap_matplotlib() 18 | def plot_frame(i): 19 | x = np.linspace(0, 2, 1000) 20 | y = np.sin(2 * np.pi * (x - 0.01 * i)) 21 | fig, ax = plt.subplots() 22 | ax.plot(x, y) 23 | return fig 24 | 25 | if __name__ == "__main__": 26 | stream(list(range(10)), uri="sine_wave.gif", renderer=plot_frame) 27 | ``` 28 | -------------------------------------------------------------------------------- /tests/test_xarray.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import streamjoy.xarray # noqa: F401 4 | 5 | 6 | class TestXArray: 7 | def test_dataset_3d(self, ds): 8 | sj = ds.streamjoy() 9 | assert "vmin" in sj.renderer_kwargs 10 | assert "vmax" in sj.renderer_kwargs 11 | assert isinstance(sj.write(), BytesIO) 12 | 13 | def test_dataarray_3d(self, ds): 14 | sj = ds["air"].streamjoy() 15 | assert "vmin" in sj.renderer_kwargs 16 | assert "vmax" in sj.renderer_kwargs 17 | assert isinstance(sj.write(), BytesIO) 18 | 19 | def test_dataset_2d(self, ds): 20 | sj = ds.mean("lat").streamjoy() 21 | assert "ylim" in sj.renderer_kwargs 22 | assert isinstance(sj.write(), BytesIO) 23 | 24 | def test_dataarray_2d(self, ds): 25 | sj = ds["air"].mean("lat").streamjoy() 26 | assert "ylim" in sj.renderer_kwargs 27 | assert isinstance(sj.write(), BytesIO) 28 | -------------------------------------------------------------------------------- /streamjoy/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from ._utils import update_logger 4 | from .core import connect, stream 5 | from .models import ImageText, Paused 6 | from .renderers import ( 7 | default_holoviews_renderer, 8 | default_pandas_renderer, 9 | default_xarray_renderer, 10 | ) 11 | from .settings import config, file_handlers, obj_handlers 12 | from .streams import GifStream, HtmlStream, Mp4Stream 13 | from .wrappers import wrap_holoviews, wrap_matplotlib 14 | 15 | __version__ = "0.0.10" 16 | 17 | __all__ = [ 18 | "GifStream", 19 | "HtmlStream", 20 | "ImageText", 21 | "Mp4Stream", 22 | "Paused", 23 | "config", 24 | "connect", 25 | "default_holoviews_renderer", 26 | "default_pandas_renderer", 27 | "default_xarray_renderer", 28 | "file_handlers", 29 | "obj_handlers", 30 | "stream", 31 | "wrap_holoviews", 32 | "wrap_matplotlib", 33 | ] 34 | 35 | logging.basicConfig( 36 | level=config["logging_level"], 37 | format=config["logging_format"], 38 | datefmt=config["logging_datefmt"], 39 | ) 40 | 41 | update_logger() 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024, Andrew Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Development 2 | 3 | ### Setup environment 4 | 5 | We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system. 6 | 7 | ### Run unit tests 8 | 9 | You can run all the tests with: 10 | 11 | ```bash 12 | hatch run test 13 | ``` 14 | 15 | ### Format the code 16 | 17 | Execute the following command to apply linting and check typing: 18 | 19 | ```bash 20 | hatch run lint 21 | ``` 22 | 23 | ### Publish a new version 24 | 25 | You can bump the version, create a commit and associated tag with one command: 26 | 27 | ```bash 28 | hatch version patch 29 | ``` 30 | 31 | ```bash 32 | hatch version minor 33 | ``` 34 | 35 | ```bash 36 | hatch version major 37 | ``` 38 | 39 | Your default Git text editor will open so you can add information about the release. 40 | 41 | When you push the tag on GitHub, the workflow will automatically publish it on PyPi and a GitHub release will be created as draft. 42 | 43 | ## Serve the documentation 44 | 45 | You can serve the Mkdocs documentation with: 46 | 47 | ```bash 48 | hatch run docs-serve 49 | ``` 50 | 51 | It'll automatically watch for changes in your code. 52 | -------------------------------------------------------------------------------- /HOWTOCONTRIBUTE.md: -------------------------------------------------------------------------------- 1 | ## Development 2 | 3 | ### Setup environment 4 | 5 | We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system. 6 | 7 | ### Run unit tests 8 | 9 | You can run all the tests with: 10 | 11 | ```bash 12 | hatch run test 13 | ``` 14 | 15 | ### Format the code 16 | 17 | Execute the following command to apply linting and check typing: 18 | 19 | ```bash 20 | hatch run lint 21 | ``` 22 | 23 | ### Publish a new version 24 | 25 | You can bump the version, create a commit and associated tag with one command: 26 | 27 | ```bash 28 | hatch version patch 29 | ``` 30 | 31 | ```bash 32 | hatch version minor 33 | ``` 34 | 35 | ```bash 36 | hatch version major 37 | ``` 38 | 39 | Your default Git text editor will open so you can add information about the release. 40 | 41 | When you push the tag on GitHub, the workflow will automatically publish it on PyPi and a GitHub release will be created as draft. 42 | 43 | ## Serve the documentation 44 | 45 | You can serve the Mkdocs documentation with: 46 | 47 | ```bash 48 | hatch run docs-serve 49 | ``` 50 | 51 | It'll automatically watch for changes in your code. 
52 | -------------------------------------------------------------------------------- /docs/example_recipes/sea_ice.md: -------------------------------------------------------------------------------- 1 | # Sea ice 2 | 3 | 6 | 7 | Compares sea ice concentration data from the NOAA G02135 dataset for August 15th in 1989 and 2023. 8 | 9 | Highlights: 10 | 11 | - Downloads images directly from the NSIDC FTP server. 12 | - Uses `connect` to concatenate the two homogeneous streams together (same keyword arguments, different resources). 13 | - Uses `pattern` to filter for only the sea ice concentration images. 14 | - Uses `intro_title` and `intro_subtitle` to provide context at the beginning of the animation. 15 | 16 | ```python hl_lines="3 6-9" 17 | from streamjoy import stream, connect 18 | 19 | connect( 20 | [ 21 | stream( 22 | f"https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/{year}/08_Aug/", 23 | pattern=f"N_*_conc_v3.0.png", 24 | intro_title=f"August 15, {year}", 25 | intro_subtitle="Sea Ice Concentration", 26 | max_files=31, 27 | ) 28 | for year in [1989, 2023] 29 | ] 30 | ).write("sea_ice.mp4") 31 | ``` -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | # Allow one concurrent deployment 15 | concurrency: 16 | group: "pages" 17 | cancel-in-progress: true 18 | 19 | # Default to bash 20 | defaults: 21 | run: 22 | shell: bash 23 | 24 | jobs: 25 | build: 26 | 27 | runs-on: ubuntu-latest 28 | 29 | steps: 30 | - uses: actions/checkout@v4 31 | - name: Set up Python 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: '3.11' 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install hatch 39 | hatch env create 40 | - name: Build 41 | run: hatch run docs-build 42 | - name: Upload artifact 43 | uses: actions/upload-pages-artifact@v3 44 | with: 45 | path: ./site 46 | 47 | deploy: 48 | environment: 49 | name: github-pages 50 | url: ${{ steps.deployment.outputs.page_url }} 51 | runs-on: ubuntu-latest 52 | needs: build 53 | steps: 54 | - name: Deploy to GitHub Pages 55 | id: deployment 56 | uses: actions/deploy-pages@v4 57 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Build & Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" 7 | 8 | jobs: 9 | build-release: 10 | name: Build Release 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.9' 19 | 20 | - name: Install packages 21 | run: | 22 | python -m pip install --upgrade pip build 23 | python -m pip install --upgrade --upgrade-strategy eager -e . 
24 | 25 | - name: Build a binary wheel and a source tarball 26 | run: | 27 | python -m build --sdist --wheel --outdir dist/ 28 | 29 | - name: Publish build artifacts 30 | uses: actions/upload-artifact@v3 31 | with: 32 | name: built-package 33 | path: "./dist" 34 | 35 | publish-release: 36 | name: Publish release to PyPI 37 | needs: [build-release] 38 | environment: "prod" 39 | runs-on: ubuntu-latest 40 | 41 | steps: 42 | - name: Download build artifacts 43 | uses: actions/download-artifact@v3 44 | with: 45 | name: built-package 46 | path: './dist' 47 | 48 | - name: Publish distribution to PyPI 49 | uses: pypa/gh-action-pypi-publish@release/v1 50 | with: 51 | password: ${{ secrets.PYPI_API_TOKEN }} 52 | verbose: true 53 | -------------------------------------------------------------------------------- /tests/test_renderers.py: -------------------------------------------------------------------------------- 1 | import holoviews as hv 2 | import matplotlib.pyplot as plt 3 | import pytest 4 | 5 | from streamjoy.renderers import ( 6 | default_holoviews_renderer, 7 | default_pandas_renderer, 8 | default_polars_renderer, 9 | default_xarray_renderer, 10 | ) 11 | 12 | 13 | class TestDefaultRenderer: 14 | @pytest.mark.parametrize("title", ["Constant", "{Year}", None]) 15 | def test_pandas(self, df, title): 16 | fig = default_pandas_renderer( 17 | df, x="Year", y="life", groupby="Country", title=title 18 | ) 19 | assert isinstance(fig, plt.Figure) 20 | 21 | @pytest.mark.parametrize("title", ["Constant", "{Year}", None]) 22 | def test_polars(self, pl_df, title): 23 | rendered_obj = default_polars_renderer( 24 | pl_df, x="Year", y="life", groupby="Country", title=title 25 | ) 26 | assert isinstance(rendered_obj, hv.NdOverlay) 27 | 28 | @pytest.mark.parametrize("title", ["Constant", "{time}", None]) 29 | def test_xarray(self, ds, title): 30 | da = ds.air.isel(time=0) 31 | fig = default_xarray_renderer(da, title=title) 32 | assert isinstance(fig, plt.Figure) 33 | 34 | @pytest.mark.parametrize("title", ["Constant", "{x}", None]) 35 | def test_holoviews(self, title): 36 | hv_obj = hv.Curve([1, 2, 3]) 37 | rendered_obj = default_holoviews_renderer(hv_obj, title=title) 38 | assert isinstance(rendered_obj, hv.Element) 39 | -------------------------------------------------------------------------------- /docs/example_recipes/nmme_forecast.md: -------------------------------------------------------------------------------- 1 | # NMME forecast 2 | 3 | 10 | 11 | Highlights: 12 | 13 | - Writes to memory to later use in another Panel component 14 | - Sets `extension` to hint at the desired output format 15 | - Appends two streams in a `Tabs` layout 16 | - Links the Players' value from the first tab to the second tab 17 | 18 | ```python hl_lines="30 36 37 38" 19 | import panel as pn 20 | import pandas as pd 21 | from streamjoy import stream 22 | 23 | pn.extension() 24 | 25 | URL_FMT = ( 26 | "https://www.cpc.ncep.noaa.gov/products/NMME/archive/{dt:%Y%m}0800/" 27 | "current/images/NMME_ensemble_{var}_us_lead{i}.png" 28 | ) 29 | 30 | VARS = { 31 | "prate": "Precipitation Rate", 32 | "tmp2m": "2m Temperature", 33 | } 34 | LEADS = 7 35 | 36 | var_tabs = pn.Tabs() 37 | for var in VARS.keys(): 38 | dt_range = [ 39 | pd.to_datetime("2024-03-08") - pd.DateOffset(months=lead) 40 | for lead in range(LEADS) 41 | ] 42 | urls = [ 43 | URL_FMT.format(i=i, dt=dt, var=var) 44 | for i, dt in enumerate(dt_range, 1) 45 | ] 46 | col = stream( 47 | urls, 48 | extension=".html", 49 | fps=1, 50 | ending_pause=0, 51 | display=False, 52 | 
sizing_mode="stretch_width", 53 | height=400, 54 | ).write() 55 | var_tabs.append((VARS[var], col)) 56 | var_tabs[0][1].jslink(var_tabs[1][1], value="value") 57 | var_tabs.save("nmme_forecast.html") 58 | ``` -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hvplot.xarray # noqa: F401 4 | import imageio.v3 as iio 5 | import pandas as pd 6 | import polars as pl 7 | import pytest 8 | import xarray as xr 9 | 10 | from streamjoy._utils import get_distributed_client 11 | from streamjoy.settings import config 12 | 13 | DATA_DIR = Path(__file__).parent / "data" 14 | NC_PATH = DATA_DIR / "air.nc" 15 | ZARR_PATH = DATA_DIR / "air.zarr" 16 | CSV_PATH = DATA_DIR / "gapminder.csv" 17 | PARQUET_PATH = DATA_DIR / "gapminder.parquet" 18 | 19 | 20 | @pytest.fixture 21 | def array(): 22 | return iio.imread("imageio:newtonscradle.gif") 23 | 24 | 25 | @pytest.fixture 26 | def ds(): 27 | return xr.open_zarr(ZARR_PATH) 28 | 29 | 30 | @pytest.fixture 31 | def df(): 32 | return pd.read_parquet(PARQUET_PATH) 33 | 34 | 35 | @pytest.fixture 36 | def pl_df(): 37 | return pl.read_parquet(PARQUET_PATH) 38 | 39 | 40 | @pytest.fixture 41 | def dmap(ds): 42 | return ds.hvplot("lon", "lat", dynamic=True) 43 | 44 | 45 | @pytest.fixture 46 | def hmap(ds): 47 | return ds.hvplot("lon", "lat", dynamic=False) 48 | 49 | 50 | @pytest.fixture(autouse=True, scope="session") 51 | def client(): 52 | return get_distributed_client() 53 | 54 | 55 | @pytest.fixture(autouse=True, scope="session") 56 | def default_config(): 57 | config["fps"] = 1 58 | config["max_frames"] = 3 59 | config["max_files"] = 3 60 | config["ending_pause"] = 0 61 | 62 | 63 | @pytest.fixture(scope="session") 64 | def data_dir(): 65 | return DATA_DIR 66 | 67 | 68 | @pytest.fixture(scope="session") 69 | def fsspec_fs(): 70 | try: 71 | import fsspec 72 | 73 | return fsspec.filesystem("file") 74 | except ImportError: 75 | pytest.skip("fsspec not installed") 76 | -------------------------------------------------------------------------------- /docs/example_recipes/oisst_globe.md: -------------------------------------------------------------------------------- 1 | # OISST globe 2 | 3 | 4 | 7 | 8 | Render sea surface temperature anomaly data from the NOAA OISST v2.1 dataset on a globe. 9 | 10 | Highlights: 11 | 12 | - Concatenates multiple homogeneous streams together (same keyword arguments, different resources) by summing them. 13 | - Uses the built-in `default_xarray_renderer` under the hood 14 | - Uses `renderer_kwargs` to pass keyword arguments to the underlying `ds.plot` method. 
15 | 16 | ```python hl_lines="19 32" 17 | import cartopy.crs as ccrs 18 | from streamjoy import stream 19 | 20 | YEAR = 2023 21 | URL_FMT = "https://www.ncei.noaa.gov/data/sea-surface-temperature-optimum-interpolation/v2.1/access/avhrr/{year}{month:02}/" 22 | 23 | if __name__ == "__main__": 24 | streams = [] 25 | for month in range(1, 13): 26 | url = URL_FMT.format(year=YEAR, month=month) 27 | streams.append( 28 | stream( 29 | url, 30 | pattern="oisst-avhrr-v02r01*.nc", 31 | var="anom", 32 | dim="time", 33 | max_files=29, 34 | max_frames=-1, 35 | renderer_kwargs=dict( 36 | cmap="RdBu_r", 37 | vmin=-5, 38 | vmax=5, 39 | subplot_kws=dict( 40 | projection=ccrs.Orthographic(central_longitude=-150), 41 | facecolor="gray", 42 | ), 43 | transform=ccrs.PlateCarree(), 44 | ), 45 | ) 46 | ) 47 | 48 | joined_stream = sum(streams) 49 | joined_stream.write("oisst_globe.mp4") 50 | ``` -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _NOTEBOOKS/ 2 | streamjoy_scratch 3 | *.ipynb 4 | 5 | *.gif 6 | *.mp4 7 | *.html 8 | *.png 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | env/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | junit/ 59 | junit.xml 60 | test.db 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # dotenv 96 | .env 97 | 98 | # virtualenv 99 | .venv 100 | venv/ 101 | ENV/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | 116 | # OS files 117 | .DS_Store 118 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | 3 | from streamjoy.models import ImageText, Paused 4 | 5 | 6 | class TestPaused: 7 | def test_paused_initialization(self): 8 | output = "some_output" 9 | seconds = 5 10 | paused_instance = Paused(output=output, seconds=seconds) 11 | assert paused_instance.output == output 12 | assert paused_instance.seconds == seconds 13 | 14 | 15 | class TestImageText: 16 | def test_image_text_initialization(self): 17 | 
text = "Hello, World!" 18 | font = "Arial" 19 | size = 24 20 | color = "black" 21 | anchor = "mm" 22 | x = 100 23 | y = 100 24 | kwargs = {"width": 500} 25 | 26 | image_text_instance = ImageText( 27 | text=text, 28 | font=font, 29 | size=size, 30 | color=color, 31 | anchor=anchor, 32 | x=x, 33 | y=y, 34 | kwargs=kwargs, 35 | ) 36 | 37 | assert image_text_instance.text == text 38 | assert image_text_instance.font == font 39 | assert image_text_instance.size == size 40 | assert image_text_instance.color == color 41 | assert image_text_instance.anchor == anchor 42 | assert image_text_instance.x == x 43 | assert image_text_instance.y == y 44 | assert image_text_instance.kwargs == kwargs 45 | 46 | def test_image_text_render(self): 47 | text = "Test Render" 48 | image_text_instance = ImageText( 49 | text=text, font="Arial", size=24, color="black", anchor="mm", x=50, y=50 50 | ) 51 | img = Image.new("RGB", (100, 100), color="white") 52 | draw = ImageDraw.Draw(img) 53 | image_text_instance.render(draw) 54 | 55 | # Since it's difficult to assert the actual drawing, we check if the method runs without errors 56 | assert True # This is a placeholder to indicate the test passed by reaching this point without errors 57 | -------------------------------------------------------------------------------- /docs/example_recipes/gender_gapminder.md: -------------------------------------------------------------------------------- 1 | # Gender gapminder 2 | 3 | 6 | 7 | This example demonstrates how to use `stream` and `connect` to create a video 8 | comparing gender population data from the Gapminder dataset. 9 | 10 | Highlights: 11 | 12 | - Uses `intro_title` and `intro_subtitle` to set the title and subtitle of the video. 13 | - Uses `renderer_kwargs` to pass keyword arguments to the custom `renderer` function. 14 | - Updates `fps` to 30 to create a smoother animation. 15 | - Uses `connect` to concatenate the two heterogeneous streams (different keyword arguments with different titles) together. 16 | 17 | ```python hl_lines="21-23 31 34" 18 | import pandas as pd 19 | from streamjoy import stream, connect 20 | 21 | if __name__ == "__main__": 22 | url_fmt = ( 23 | "https://raw.githubusercontent.com/open-numbers/ddf--gapminder--systema_globalis/master/" 24 | "countries-etc-datapoints/ddf--datapoints--{gender}_population_with_projections--by--geo--time.csv" 25 | ) 26 | df = pd.concat(( 27 | pd.read_csv(url_fmt.format(gender=gender)).set_index(["geo", "time"]) 28 | for gender in ["male", "female"]), 29 | axis=1, 30 | ) 31 | 32 | streams = [] 33 | for country in ["usa", "chn"]: 34 | df_sub = df.loc[country].reset_index().melt("time") 35 | streamed = stream( 36 | df_sub, 37 | groupby="variable", 38 | intro_title="Gapminder", 39 | intro_subtitle=f"{country.upper()} Male vs Female Population", 40 | renderer_kwargs={ 41 | "x": "time", 42 | "y": "value", 43 | "xlabel": "Year", 44 | "ylabel": "Population", 45 | "title": f"{country.upper()} {{time}}", 46 | }, 47 | max_frames=-1, 48 | fps=30, 49 | ) 50 | streams.append(streamed) 51 | connect(streams).write("gender_gapminder.mp4") 52 | ``` -------------------------------------------------------------------------------- /docs/package_design.md: -------------------------------------------------------------------------------- 1 | # Package design 2 | 3 | ## 🪪 Naming 4 | 5 | StreamJoy stems from the idea of streaming parallelized, output images to GIF or MP4, *as they get serialized*. 
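In other words, each frame is appended to the open GIF/MP4 writer the moment it is ready, rather than being collected on disk first and stitched together at the end. Below is a minimal sketch of that idea using plain `imageio` (one of StreamJoy's dependencies) -- it is not StreamJoy's actual implementation, and the synthetic frames and output filename are placeholders:

```python
# Minimal sketch of the streaming idea (not StreamJoy's internal code):
# append each frame to an already-open writer as soon as it is ready,
# instead of waiting for every frame to be rendered and saved to disk first.
import imageio.v2 as iio
import numpy as np

with iio.get_writer("example.gif", mode="I") as writer:  # open the output once
    for i in range(30):
        # stand-in for a rendered frame; StreamJoy renders frames in parallel with Dask
        frame = np.full((128, 128, 3), i * 8, dtype=np.uint8)
        writer.append_data(frame)  # streamed into the animation immediately
```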
6 | 7 | This was a mini breakthrough for me, as I had written other packages to try to animate data efficiently (e.g. [`enjoyn`](https://enjoyn.readthedocs.io/en/latest/) and [`ahlive`](https://ahlive.readthedocs.io/en/latest/)). However, both of these packages suffered from the bottleneck of having to wait for all the images to get written out to disk before they could start generating the animation. 8 | 9 | Discovering this breakthrough brought me joy, and I wanted to share that joy with others by writing a package that reduces the boilerplate and time needed to work on animations, bringing that joy to the user. 10 | 11 | Coincidentally, SJ is also my wife's initials, so it was a perfect fit! :D 12 | 13 | I also considered naming this `streamio` or `streamit`, but the former was already taken and the latter was too close to `streamlit`. 14 | 15 | ## 📶 Diagram 16 | 17 | Below is a diagram of the package design. The animation part is actually quite simple--most of the complexity comes from handling various input types, e.g. URLs, files, and datasets. 18 | 19 | ```mermaid 20 | graph TD 21 | A[Start] --> Z{Input Type} 22 | Z -->|URL| U[Download and Assess Content] 23 | Z -->|Direct Input| V{Determine Content Type} 24 | U --> V 25 | V -->|DataFrame| B[Split pandas DataFrame] 26 | V -->|XArray Dataset/DataArray| X[Split XArray by dim] 27 | V -->|HoloViews HoloMap/DynamicMap| H[Split HoloViews by kdim] 28 | V -->|Directory of Images| Y[Glob all files] 29 | V -->|List of Images| D[Open MP4/GIF buffer and start Dask client] 30 | B --> D 31 | X --> D 32 | H --> D 33 | Y --> D 34 | D --> E{Process each frame in parallel with Dask} 35 | E -->|Renderer available| F[Call custom/default renderer] 36 | E -->|Renderer not available| G[Convert to np.array with imread] 37 | F --> G 38 | G --> I[Stream to MP4/GIF buffer with imwrite] 39 | I --> J{All units processed?} 40 | J -->|Yes| K[Save MP4/GIF] 41 | J -->|No| E 42 | K -->|Optimize GIF?| N[Optimize GIF] 43 | N --> L[End] 44 | K --> L 45 | ``` -------------------------------------------------------------------------------- /tests/test_wrappers.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from pathlib import Path 3 | 4 | import holoviews as hv 5 | from matplotlib import pyplot as plt 6 | 7 | from streamjoy.models import Paused 8 | from streamjoy.wrappers import wrap_holoviews, wrap_matplotlib 9 | 10 | 11 | class TestWrapMatplotlib: 12 | def test_wrap_matplotlib_figure_to_file(self, tmp_path): 13 | @wrap_matplotlib(in_memory=False, scratch_dir=tmp_path) 14 | def render_figure(): 15 | fig, ax = plt.subplots() 16 | ax.plot([1, 2, 3], [1, 2, 3]) 17 | return fig 18 | 19 | path = render_figure() 20 | assert Path(path).exists() 21 | 22 | def test_wrap_matplotlib_figure_to_memory(self): 23 | @wrap_matplotlib(in_memory=True) 24 | def render_figure(): 25 | fig, ax = plt.subplots() 26 | ax.plot([1, 2, 3], [1, 2, 3]) 27 | return fig 28 | 29 | output = render_figure() 30 | assert isinstance(output, BytesIO) 31 | 32 | def test_wrap_matplotlib_with_paused(self): 33 | @wrap_matplotlib(in_memory=True) 34 | def render_figure(): 35 | fig, ax = plt.subplots() 36 | ax.plot([1, 2, 3], [1, 2, 3]) 37 | return Paused(output=fig, seconds=5) 38 | 39 | output = render_figure() 40 | assert isinstance(output, Paused) 41 | assert isinstance(output.output, BytesIO) 42 | assert output.seconds == 5 43 | 44 | 45 | class TestWrapHoloViews: 46 | def test_wrap_holoviews_to_file(self, tmp_path): 47 |
@wrap_holoviews(in_memory=False, scratch_dir=tmp_path) 48 | def render_hv(): 49 | curve = hv.Curve((range(10), range(10))) 50 | return curve 51 | 52 | path = render_hv() 53 | assert Path(path).exists() 54 | 55 | def test_wrap_holoviews_with_paused(self, tmp_path): 56 | @wrap_holoviews(in_memory=False, scratch_dir=tmp_path) 57 | def render_hv(): 58 | curve = hv.Curve((range(10), range(10))) 59 | return Paused(output=curve, seconds=5) 60 | 61 | output = render_hv() 62 | assert isinstance(output, Paused) 63 | assert output.output.exists() 64 | assert output.seconds == 5 65 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from streamjoy.core import connect, stream 7 | from streamjoy.streams import AnyStream, ConnectedStreams, GifStream, Mp4Stream 8 | 9 | 10 | class TestStream: 11 | def test_no_uri(self): 12 | result = stream(resources=[0, 1, 2]) 13 | assert isinstance( 14 | result, AnyStream 15 | ), "Expected an instance of AnyStream or its subclass" 16 | 17 | def test_uri(self, df, tmp_path): 18 | uri = tmp_path / "test.mp4" 19 | stream(resources=df, uri=uri) 20 | assert uri.exists(), "Stream file should exist" 21 | 22 | def test_uri_with_bytesio(self, df): 23 | uri = BytesIO() 24 | result = stream(resources=df, uri=uri, extension=".mp4") 25 | assert isinstance( 26 | result, BytesIO 27 | ), "Expected a BytesIO object when URI is a BytesIO instance" 28 | 29 | def test_uri_with_str_path(self, df, tmp_path): 30 | uri = str(tmp_path / "test.gif") 31 | stream(resources=df, uri=uri) 32 | assert Path(uri).exists(), "Stream file should exist when URI is a string path" 33 | 34 | @pytest.mark.parametrize("extension", [".mp4", ".gif"]) 35 | def test_stream_with_different_extensions(self, df, extension): 36 | result = stream(resources=df, extension=extension) 37 | expected_class = Mp4Stream if extension == ".mp4" else GifStream 38 | assert isinstance( 39 | result, expected_class 40 | ), f"Expected an instance of {expected_class.__name__}" 41 | 42 | 43 | class TestConnect: 44 | def test_connect_streams(self): 45 | stream1 = stream(resources=[0, 1, 2]) 46 | stream2 = stream(resources=[3, 4, 5]) 47 | result = connect(streams=[stream1, stream2]) 48 | assert isinstance( 49 | result, ConnectedStreams 50 | ), "Expected an instance of ConnectedStreams" 51 | 52 | def test_connect_streams_with_uri(self, df, tmp_path): 53 | stream1 = stream(resources=df) 54 | stream2 = stream(resources=df) 55 | uri = tmp_path / "connected.mp4" 56 | connect(streams=[stream1, stream2], uri=uri) 57 | assert uri.exists(), "Connected stream file should exist" 58 | -------------------------------------------------------------------------------- /docs/api_reference/settings.md: -------------------------------------------------------------------------------- 1 | # Settings 2 | 3 | ```python 4 | config = { 5 | # animation 6 | "fps": 8, 7 | "max_frames": 50, 8 | # dask 9 | "batch_size": 10, 10 | "processes": True, 11 | "threads_per_worker": None, 12 | # intro 13 | "intro_pause": 2, 14 | "intro_watermark": "made with streamjoy", 15 | "intro_background": "black", 16 | # from_url 17 | "max_files": 2, 18 | # matplotlib 19 | "max_open_warning": 100, 20 | # output 21 | "in_memory": False, 22 | "scratch_dir": "streamjoy_scratch", 23 | "uri": None, 24 | # imageio 25 | "codec": "libx264", 26 | "loop": 0, 27 | "ending_pause": 2, 28 | 
# gif 29 | "optimize": False, 30 | # image text 31 | "image_text_font": "Avenir.ttc", 32 | "image_text_size": 20, 33 | "image_text_color": "white", 34 | "image_text_background": "black", 35 | # notebook 36 | "display": True, 37 | # logging 38 | "logging_success_level": 25, 39 | "logging_level": 25, 40 | "logging_format": "[%(levelname)s] %(asctime)s: %(message)s", 41 | "logging_datefmt": "%I:%M%p", 42 | "logging_warning_color": "\x1b[31;1m", 43 | "logging_success_color": "\x1b[32;1m", 44 | "logging_reset_color": "\x1b[0m", 45 | } 46 | ``` 47 | 48 | ```python 49 | obj_handlers = { 50 | "xarray.Dataset": "_expand_from_xarray", 51 | "xarray.DataArray": "_expand_from_xarray", 52 | "pandas.DataFrame": "_expand_from_pandas", 53 | "pandas.Series": "_expand_from_pandas", 54 | "holoviews": "_expand_from_holoviews", 55 | } 56 | ``` 57 | 58 | ```python 59 | file_handlers = { 60 | ".nc": { 61 | "import_path": "xarray.open_mfdataset", 62 | }, 63 | ".nc4": { 64 | "import_path": "xarray.open_mfdataset", 65 | }, 66 | ".zarr": { 67 | "import_path": "xarray.open_zarr", 68 | }, 69 | ".grib": { 70 | "import_path": "xarray.open_mfdataset", 71 | "kwargs": {"engine": "cfgrib"}, 72 | }, 73 | ".csv": { 74 | "import_path": "pandas.read_csv", 75 | "concat_path": "pandas.concat", 76 | }, 77 | ".parquet": { 78 | "import_path": "pandas.read_parquet", 79 | "concat_path": "pandas.concat", 80 | }, 81 | ".html": { 82 | "import_path": "pandas.read_html", 83 | "concat_path": "pandas.concat", 84 | }, 85 | } 86 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | target-version = "py310" 3 | 4 | [tool.ruff.lint] 5 | extend-select = ["I", "UP"] 6 | extend-ignore = ["TRY003"] 7 | 8 | [tool.ruff.lint.flake8-type-checking] 9 | quote-annotations = true 10 | 11 | [tool.pytest.ini_options] 12 | addopts = "--cov=streamjoy/ --cov-report=term-missing" 13 | 14 | 15 | [tool.hatch] 16 | 17 | [tool.hatch.metadata] 18 | allow-direct-references = true 19 | 20 | [tool.hatch.version] 21 | source = "regex_commit" 22 | commit_extra_args = ["-e"] 23 | path = "streamjoy/__init__.py" 24 | 25 | [tool.hatch.envs.default] 26 | python = "3.9" 27 | dependencies = [ 28 | "ruff", 29 | "pytest", 30 | "pytest-cov", 31 | "mkdocs-material", 32 | "mkdocstrings[python]", 33 | "xarray", 34 | "pandas", 35 | "polars", 36 | "zarr", 37 | "netcdf4", 38 | "matplotlib", 39 | "pyarrow", 40 | "hvplot", 41 | "bs4", 42 | "selenium", 43 | "webdriver_manager", 44 | "panel", 45 | ] 46 | 47 | [tool.hatch.envs.default.scripts] 48 | test = "pytest" 49 | test-cov-xml = "pytest --cov-report=xml" 50 | lint = [ 51 | "ruff format .", 52 | "ruff check --fix .", 53 | ] 54 | lint-check = [ 55 | "ruff format --check .", 56 | "ruff check .", 57 | ] 58 | docs-serve = "mkdocs serve" 59 | docs-build = "mkdocs build" 60 | 61 | [build-system] 62 | requires = ["hatchling", "hatch-regex-commit"] 63 | build-backend = "hatchling.build" 64 | 65 | [project] 66 | name = "streamjoy" 67 | authors = [ 68 | { name = "streamjoy", email = "hey.at.py@gmail.com" } 69 | ] 70 | description = "Enjoy animating images into GIFs and MP4s!" 
71 | readme = "README.md" 72 | dynamic = ["version"] 73 | classifiers = [ 74 | "Programming Language :: Python :: 3 :: Only", 75 | ] 76 | requires-python = ">=3.9" 77 | dependencies = [ 78 | "param>2", 79 | "dask[distributed]", 80 | "imageio[pyav]>=2.34.0", 81 | "pygifsicle==1.0.5", 82 | "requests", 83 | "bs4", 84 | ] 85 | 86 | [project.urls] 87 | Documentation = "https://ahuang11.github.io/streamjoy/" 88 | Source = "https://github.com/ahuang11/streamjoy" 89 | 90 | [project.scripts] 91 | streamjoy = "streamjoy.cli:main" 92 | 93 | [project.optional-dependencies] 94 | ui = [ 95 | "panel", 96 | "param", 97 | "requests", 98 | "xarray", 99 | "netcdf4", 100 | ] -------------------------------------------------------------------------------- /streamjoy/settings.py: -------------------------------------------------------------------------------- 1 | config = { 2 | # animation 3 | "fps": 8, 4 | "max_frames": 50, 5 | # dask 6 | "batch_size": 10, 7 | "processes": True, 8 | "threads_per_worker": None, 9 | # intro 10 | "intro_pause": 2, 11 | "intro_watermark": "made with streamjoy", 12 | "intro_background": "black", 13 | # from_url 14 | "max_files": 2, 15 | # download 16 | "parent_depth": 4, 17 | # matplotlib 18 | "max_open_warning": 100, 19 | # holoviews 20 | "webdriver": "firefox", 21 | "num_retries": 5, 22 | # output 23 | "in_memory": False, 24 | "scratch_dir": "streamjoy_scratch", 25 | "uri": None, 26 | # imageio 27 | "codec": "libx264", 28 | "loop": 0, 29 | "ending_pause": 2, 30 | # gif 31 | "optimize": False, 32 | # image text 33 | "image_text_font": "Avenir.ttc", 34 | "image_text_size": 20, 35 | "image_text_color": "white", 36 | "image_text_background": "black", 37 | # notebook 38 | "display": True, 39 | # logging 40 | "logging_success_level": 25, 41 | "logging_level": 20, 42 | "logging_format": "[%(levelname)s] %(asctime)s: %(message)s", 43 | "logging_datefmt": "%I:%M%p", 44 | "logging_warning_color": "\x1b[31;1m", 45 | "logging_success_color": "\x1b[32;1m", 46 | "logging_reset_color": "\x1b[0m", 47 | } 48 | 49 | extension_handlers = { 50 | None: "AnyStream", 51 | ".mp4": "Mp4Stream", 52 | ".gif": "GifStream", 53 | ".html": "HtmlStream", 54 | } 55 | 56 | obj_handlers = { 57 | "xarray.Dataset": "serialize_xarray", 58 | "xarray.DataArray": "serialize_xarray", 59 | "pandas.DataFrame": "serialize_pandas", 60 | "pandas.Series": "serialize_pandas", 61 | "holoviews": "serialize_holoviews", 62 | "polars.DataFrame": "serialize_polars", 63 | "numpy.ndarray": "serialize_numpy", 64 | } 65 | 66 | file_handlers = { 67 | ".nc": { 68 | "import_path": "xarray.open_mfdataset", 69 | }, 70 | ".nc4": { 71 | "import_path": "xarray.open_mfdataset", 72 | }, 73 | ".zarr": { 74 | "import_path": "xarray.open_zarr", 75 | }, 76 | ".grib": { 77 | "import_path": "xarray.open_mfdataset", 78 | "kwargs": {"engine": "cfgrib"}, 79 | }, 80 | ".csv": { 81 | "import_path": "pandas.read_csv", 82 | "concat_path": "pandas.concat", 83 | }, 84 | ".parquet": { 85 | "import_path": "pandas.read_parquet", 86 | "concat_path": "pandas.concat", 87 | }, 88 | ".html": { 89 | "import_path": "pandas.read_html", 90 | "concat_path": "pandas.concat", 91 | }, 92 | } 93 | -------------------------------------------------------------------------------- /tests/test_serializers.py: -------------------------------------------------------------------------------- 1 | from streamjoy.models import Serialized 2 | from streamjoy.serializers import ( 3 | serialize_holoviews, 4 | serialize_numpy, 5 | serialize_pandas, 6 | serialize_polars, 7 | serialize_xarray, 8 
| ) 9 | 10 | 11 | class TestSerializeNumpy: 12 | def test_serialize_numpy(self, array): 13 | serialized = serialize_numpy(None, array) 14 | assert isinstance(serialized, Serialized) 15 | assert len(serialized.resources) == 36 16 | assert isinstance(serialized.resources, list) 17 | assert not serialized.renderer 18 | assert serialized.renderer_iterables is None 19 | assert isinstance(serialized.renderer_kwargs, dict) 20 | assert isinstance(serialized.kwargs, dict) 21 | 22 | 23 | class TestSerializeXarray: 24 | def test_serialize_xarray(self, ds): 25 | serialized = serialize_xarray(None, ds) 26 | assert isinstance(serialized, Serialized) 27 | assert len(serialized.resources) == 3 28 | assert isinstance(serialized.resources, list) 29 | assert callable(serialized.renderer) 30 | assert serialized.renderer_iterables is None 31 | assert isinstance(serialized.renderer_kwargs, dict) 32 | assert isinstance(serialized.kwargs, dict) 33 | 34 | 35 | class TestSerializePandas: 36 | def test_serialize_pandas(self, df): 37 | serialized = serialize_pandas(None, df) 38 | assert isinstance(serialized, Serialized) 39 | assert len(serialized.resources) == 3 40 | assert isinstance(serialized.resources, list) 41 | assert callable(serialized.renderer) 42 | assert serialized.renderer_iterables is None 43 | assert isinstance(serialized.renderer_kwargs, dict) 44 | assert isinstance(serialized.kwargs, dict) 45 | 46 | 47 | class TestSerializePolars: 48 | def test_serialize_polars(self, pl_df): 49 | serialized = serialize_polars(None, pl_df) 50 | assert isinstance(serialized, Serialized) 51 | assert len(serialized.resources) == 3 52 | assert isinstance(serialized.resources, list) 53 | assert callable(serialized.renderer) 54 | assert serialized.renderer_iterables is None 55 | assert isinstance(serialized.renderer_kwargs, dict) 56 | assert isinstance(serialized.kwargs, dict) 57 | 58 | 59 | class TestSerializeHoloviews: 60 | def test_serialize_holoviews(self, hmap): 61 | serialized = serialize_holoviews(None, hmap) 62 | assert isinstance(serialized, Serialized) 63 | assert len(serialized.resources) == 20 64 | assert isinstance(serialized.resources, list) 65 | assert callable(serialized.renderer) 66 | assert serialized.renderer_iterables is None 67 | assert isinstance(serialized.renderer_kwargs, dict) 68 | assert isinstance(serialized.kwargs, dict) 69 | -------------------------------------------------------------------------------- /streamjoy/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from io import BytesIO 4 | from pathlib import Path 5 | from typing import Any, Callable, Literal 6 | 7 | from . import streams 8 | from .serializers import serialize_appropriately 9 | from .settings import extension_handlers 10 | from .streams import AnyStream, ConnectedStreams, GifStream, HtmlStream, Mp4Stream 11 | 12 | 13 | def stream( 14 | resources: Any, 15 | uri: str | Path | BytesIO | None = None, 16 | renderer: Callable | None = None, 17 | renderer_iterables: list | None = None, 18 | renderer_kwargs: dict | None = None, 19 | extension: Literal[".mp4", ".gif"] | None = None, 20 | **kwargs: dict[str, Any], 21 | ) -> AnyStream | GifStream | Mp4Stream | HtmlStream | Path: 22 | """ 23 | Create a stream from the given resources. 24 | 25 | Args: 26 | resources: The resources to create a stream from. 27 | uri: The destination to write the stream to. If None, the stream is returned. 28 | renderer: The renderer to use. 
If None, the default renderer is used. 29 | renderer_iterables: Additional positional arguments to map over the renderer. 30 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 31 | extension: The extension to use; useful if uri is a file-like object. 32 | **kwargs: Additional keyword arguments to pass. 33 | 34 | Returns: 35 | The stream if uri is None, otherwise the uri. 36 | """ 37 | if isinstance(uri, str): 38 | uri = Path(uri) 39 | 40 | extension = extension or (uri and uri.suffix) 41 | if isinstance(extension, str) and extension not in extension_handlers: 42 | raise ValueError(f"Unsupported extension: {extension}") 43 | 44 | stream_cls = getattr(streams, extension_handlers.get(extension), AnyStream) 45 | serialized = serialize_appropriately( 46 | stream_cls, 47 | resources, 48 | renderer, 49 | renderer_iterables, 50 | renderer_kwargs or {}, 51 | **kwargs, 52 | ) 53 | stream = stream_cls(**serialized.param.values(), **serialized.kwargs) 54 | 55 | if uri: 56 | return stream.write(uri=uri, extension=extension) 57 | return stream 58 | 59 | 60 | def connect( 61 | streams: list[AnyStream | GifStream | Mp4Stream | HtmlStream], 62 | uri: str | Path | BytesIO | None = None, 63 | ) -> ConnectedStreams | Path: 64 | """ 65 | Connect heterogeneous streams into a single stream. 66 | 67 | Unlike `stream.join`, this function can connect streams 68 | with unique params, such as different renderers. 69 | 70 | Args: 71 | streams: The streams to connect. 72 | uri: The destination to write the connected streams to. 73 | If None, the connected streams are returned. 74 | 75 | Returns: 76 | The connected streams if uri is None, otherwise the uri. 77 | """ 78 | stream = ConnectedStreams(streams=streams) 79 | if uri: 80 | return stream.write(uri=uri) 81 | return stream 82 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: streamjoy 2 | site_description: Enjoy animating images into GIFs and MP4s!
3 | 4 | repo_url: https://github.com/ahuang11/streamjoy 5 | repo_name: ahuang11/streamjoy 6 | 7 | theme: 8 | name: material 9 | features: 10 | - content.code.copy 11 | logo: assets/logo.png 12 | palette: 13 | # Palette toggle for automatic mode 14 | - media: "(prefers-color-scheme)" 15 | toggle: 16 | icon: material/brightness-auto 17 | name: Switch to light mode 18 | 19 | # Palette toggle for light mode 20 | - media: "(prefers-color-scheme: light)" 21 | scheme: default 22 | primary: custom 23 | accent: custom 24 | toggle: 25 | icon: material/brightness-7 26 | name: Switch to dark mode 27 | 28 | # Palette toggle for dark mode 29 | - media: "(prefers-color-scheme: dark)" 30 | scheme: slate 31 | primary: custom 32 | accent: custom 33 | toggle: 34 | icon: material/brightness-4 35 | name: Switch to light mode 36 | 37 | markdown_extensions: 38 | - toc: 39 | permalink: true 40 | - pymdownx.highlight: 41 | anchor_linenums: true 42 | - pymdownx.tasklist: 43 | custom_checkbox: true 44 | - admonition 45 | - pymdownx.details 46 | - pymdownx.tabbed 47 | - pymdownx.magiclink 48 | - pymdownx.inlinehilite 49 | - pymdownx.snippets 50 | - pymdownx.superfences: 51 | custom_fences: 52 | - name: mermaid 53 | class: mermaid 54 | format: !!python/name:pymdownx.superfences.fence_code_format 55 | 56 | plugins: 57 | - search 58 | - mkdocstrings: 59 | handlers: 60 | python: 61 | options: 62 | docstring_style: google 63 | find_stubs_package: true 64 | 65 | watch: 66 | - docs 67 | - streamjoy 68 | 69 | nav: 70 | - Read me!: index.md 71 | - Supported formats: supported_formats.md 72 | - How do I...: how_do_i.md 73 | - Package design: package_design.md 74 | - Example recipes: 75 | - Air temperature: example_recipes/air_temperature.md 76 | - Sine wave: example_recipes/sine_wave.md 77 | - CO2 timeseries: example_recipes/co2_timeseries.md 78 | - Temperature anomaly: example_recipes/temperature_anomaly.md 79 | - Sea ice: example_recipes/sea_ice.md 80 | - OISST globe: example_recipes/oisst_globe.md 81 | - Gender gapminder: example_recipes/gender_gapminder.md 82 | - Stream code: example_recipes/stream_code.md 83 | - Nice orbit: example_recipes/nice_orbit.md 84 | - NMME forecast: example_recipes/nmme_forecast.md 85 | - API reference: 86 | - Core: api_reference/core.md 87 | - Models: api_reference/models.md 88 | - Streams: api_reference/streams.md 89 | - Renderers: api_reference/renderers.md 90 | - Wrappers: api_reference/wrappers.md 91 | - Settings: api_reference/settings.md 92 | 93 | extra_css: 94 | - stylesheets/extra.css 95 | -------------------------------------------------------------------------------- /docs/example_recipes/nice_orbit.md: -------------------------------------------------------------------------------- 1 | # Nice orbit 2 | 3 | 6 | 7 | Creates a visually appealing, nice orbits of a 2d dynamical system. 8 | 9 | Code adapted from [Nice_orbits.ipynb](https://github.com/profConradi/Python_Simulations/blob/main/Nice_orbits.ipynb). 10 | All credits go to [Simone Conradi](https://github.com/profConradi); the only addition here was wrapping the code into a function and using `streamjoy` to create the animation. Please consider giving the [Python_Simulations](https://github.com/profConradi/Python_Simulations/tree/main) repo a star! 11 | 12 | Highlights: 13 | 14 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure. 15 | - Uses a custom `renderer` function to create each frame of the animation. 
16 | 17 | ```python hl_lines="45 46" 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | from numba import njit 21 | from streamjoy import stream, wrap_matplotlib 22 | 23 | @njit 24 | def meshgrid(x, y): 25 | """ 26 | This function replace np.meshgrid that is not supported by numba 27 | """ 28 | xx = np.empty(shape=(x.size, y.size), dtype=x.dtype) 29 | yy = np.empty(shape=(x.size, y.size), dtype=y.dtype) 30 | for j in range(y.size): 31 | for k in range(x.size): 32 | xx[j, k] = k # change to x[k] if indexing xy 33 | yy[j, k] = j # change to y[j] if indexing xy 34 | return xx, yy 35 | 36 | @njit 37 | def calc_orbit(n_points, a, b, n_iter): 38 | """ 39 | This function calculate orbits in a vectorized fashion. 40 | 41 | -n_points: lattice of initial conditions, n_points x n_points in [-1,1]x[-1,1] 42 | -a: first parameter of the dynamical system 43 | -b: second parameter of the dynamical system 44 | -n_iter: number of iterations 45 | 46 | Return: two ndarrays: x and y coordinates of every point of every orbit. 47 | """ 48 | area = [[-1, 1], [-1, 1]] 49 | x = np.linspace(area[0][0], area[0][1], n_points) 50 | y = np.linspace(area[1][0], area[1][1], n_points) 51 | xx, yy = meshgrid(x, y) 52 | l_cx, l_cy = np.zeros(n_iter * n_points**2), np.zeros(n_iter * n_points**2) 53 | for i in range(n_iter): 54 | xx_new = np.sin(xx**2 - yy**2 + a) 55 | yy_new = np.cos(2 * xx * yy + b) 56 | xx = xx_new 57 | yy = yy_new 58 | l_cx[i * n_points**2 : (i + 1) * n_points**2] = xx.flatten() 59 | l_cy[i * n_points**2 : (i + 1) * n_points**2] = yy.flatten() 60 | return l_cx, l_cy 61 | 62 | @wrap_matplotlib() 63 | def plot_frame(n): 64 | l_cx, l_cy = calc_orbit(n_points, a + 0.002 * n, b - 0.001 * n, n) 65 | area = [[-1, 1], [-1, 1]] 66 | h, _, _ = np.histogram2d(l_cx, l_cy, bins=3000, range=area) 67 | fig, ax = plt.subplots(figsize=(10, 10)) 68 | fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None) 69 | ax.imshow(np.log(h + 1), vmin=0, vmax=5, cmap="magma") 70 | plt.xticks([]), plt.yticks([]) 71 | return fig 72 | 73 | if __name__ == "__main__": 74 | n_points = 500 75 | a, b = 5.48, 4.28 76 | stream(np.arange(1, 100).tolist(), renderer=plot_frame, uri="nice_orbit.mp4") 77 | ``` 78 | -------------------------------------------------------------------------------- /streamjoy/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import param 6 | from PIL import ImageDraw, ImageFont 7 | 8 | from . import _utils 9 | 10 | 11 | class Paused(param.Parameterized): 12 | """ 13 | A data model for pausing a stream. 14 | 15 | Expand Source code to see all the parameters and descriptions. 16 | """ 17 | 18 | output = param.Parameter(doc="The output to pause for.") 19 | 20 | seconds = param.Number(doc="The number of seconds to pause for.") 21 | 22 | def __init__(self, output: Any, seconds: int, **params): 23 | self.output = output 24 | self.seconds = seconds 25 | super().__init__(**params) 26 | 27 | 28 | class ImageText(param.Parameterized): 29 | """ 30 | A data model for rendering text on an image. 31 | 32 | Expand Source code to see all the parameters and descriptions. 
33 | """ 34 | 35 | text = param.String( 36 | doc="The text to render.", 37 | ) 38 | 39 | font = param.String( 40 | doc="The font to use for the text.", 41 | ) 42 | 43 | size = param.Integer( 44 | doc="The font size to use for the text.", 45 | ) 46 | 47 | color = param.String( 48 | doc="The color to use for the text.", 49 | ) 50 | 51 | anchor = param.String( 52 | doc="The anchor to use for the text.", 53 | ) 54 | 55 | x = param.Integer( 56 | doc="The x-coordinate to use for the text.", 57 | ) 58 | 59 | y = param.Integer( 60 | doc="The y-coordinate to use for the text.", 61 | ) 62 | 63 | kwargs = param.Dict( 64 | default={}, 65 | doc="Additional keyword arguments to pass to the text renderer.", 66 | ) 67 | 68 | def __init__(self, text: str, **params) -> None: 69 | params["text"] = text 70 | params = _utils.populate_config_defaults( 71 | params, self.param, config_prefix="image_text" 72 | ) 73 | super().__init__(**params) 74 | 75 | def render( 76 | self, 77 | draw: ImageDraw, 78 | width: int | None = None, 79 | height: int | None = None, 80 | ) -> None: 81 | x = self.x or width // 2 82 | y = self.y or height // 2 83 | try: 84 | font = ImageFont.truetype(self.font, self.size) 85 | except Exception: 86 | font = ImageFont.load_default() 87 | draw.text( 88 | (x, y), 89 | self.text, 90 | font=font, 91 | fill=self.color, 92 | anchor=self.anchor, 93 | **self.kwargs, 94 | ) 95 | 96 | 97 | class Serialized(param.Parameterized): 98 | resources = param.List(doc="The list of resources.") 99 | renderer = param.Callable(doc="The rendering function to use on the resources.") 100 | renderer_iterables = param.List( 101 | doc="Additional iterable arguments to pass to the renderer.", 102 | allow_None=True, 103 | ) 104 | renderer_kwargs = param.Dict( 105 | doc="Additional keyword arguments to pass to the renderer.", 106 | allow_None=True, 107 | ) 108 | 109 | def __init__( 110 | self, resources, renderer, renderer_iterables, renderer_kwargs, kwargs, **params 111 | ): 112 | super().__init__( 113 | resources=resources, 114 | renderer=renderer, 115 | renderer_iterables=renderer_iterables, 116 | renderer_kwargs=renderer_kwargs, 117 | **params, 118 | ) 119 | self.kwargs = kwargs 120 | -------------------------------------------------------------------------------- /tests/data/gapminder.csv: -------------------------------------------------------------------------------- 1 | ,Year,Country,fertility,life,population,child_mortality,gdp,region 2 | 0,1964,Bangladesh,6.853,49.297,56071080.0,240.4,1138.0,South Asia 3 | 1,1965,Bangladesh,6.877999999999999,49.652,57791778.0,236.0,1177.0,South Asia 4 | 2,1966,Bangladesh,6.901,49.734,59681186.0,232.3,1169.0,South Asia 5 | 3,1967,Bangladesh,6.92,49.50899999999999,61702785.0,229.4,1122.0,South Asia 6 | 4,1968,Bangladesh,6.935,49.008,63703327.0,227.1,1202.0,South Asia 7 | 5,1969,Bangladesh,6.945,48.307,65474474.0,225.4,1195.0,South Asia 8 | 6,1970,Bangladesh,6.947,47.58,66881158.0,224.1,1226.0,South Asia 9 | 7,1971,Bangladesh,6.942,47.04600000000001,67849312.0,223.0,1142.0,South Asia 10 | 8,1972,Bangladesh,6.928,46.874,68461849.0,222.0,986.0,South Asia 11 | 9,1973,Bangladesh,6.904,47.162,68933357.0,220.7,971.0,South Asia 12 | 50,1964,China,6.12,53.32072,696171650.0,130.77,713.0,East Asia & Pacific 13 | 51,1965,China,6.022,55.6468,710290299.0,115.43,772.0,East Asia & Pacific 14 | 52,1966,China,6.211,56.8032,727601056.0,120.8,826.0,East Asia & Pacific 15 | 53,1967,China,5.252000000000001,58.38112,747678772.0,126.42,719.0,East Asia & Pacific 16 | 
54,1968,China,6.37,59.4052,769666505.0,132.3,669.0,East Asia & Pacific 17 | 55,1969,China,5.67,60.9652,792308749.0,119.1,732.0,East Asia & Pacific 18 | 56,1970,China,5.746,62.6508,814622841.0,113.3,848.0,East Asia & Pacific 19 | 57,1971,China,5.396,63.73736,836431505.0,107.7,876.0,East Asia & Pacific 20 | 58,1972,China,4.92,63.11888,857804035.0,102.1,843.0,East Asia & Pacific 21 | 59,1973,China,4.506,62.7808,878305091.0,96.5,894.0,East Asia & Pacific 22 | 100,1964,India,5.807,44.375,486038945.0,232.9,1125.0,South Asia 23 | 101,1965,India,5.781000000000001,45.141000000000005,496400381.0,229.6,1053.0,South Asia 24 | 102,1966,India,5.747000000000001,45.903,507115411.0,226.3,1037.0,South Asia 25 | 103,1967,India,5.7010000000000005,46.655,518192403.0,223.1,1096.0,South Asia 26 | 104,1968,India,5.642,47.394,529658233.0,219.8,1095.0,South Asia 27 | 105,1969,India,5.573,48.119,541544619.0,216.6,1141.0,South Asia 28 | 106,1970,India,5.494,48.836000000000006,553873890.0,213.3,1170.0,South Asia 29 | 107,1971,India,5.41,49.554,566651479.0,209.9,1154.0,South Asia 30 | 108,1972,India,5.323,50.281000000000006,579871075.0,206.1,1125.0,South Asia 31 | 109,1973,India,5.238,51.015,593526633.0,202.2,1151.0,South Asia 32 | 150,1964,South Africa,5.984,50.574,19308185.0,174.55,9004.0,Sub-Saharan Africa 33 | 151,1965,South Africa,5.9110000000000005,50.96,19813932.0,169.51,9255.0,Sub-Saharan Africa 34 | 152,1966,South Africa,5.836,51.348,20325179.0,164.61,9371.0,Sub-Saharan Africa 35 | 153,1967,South Africa,5.765,51.73,20843695.0,159.85,9720.0,Sub-Saharan Africa 36 | 154,1968,South Africa,5.7,52.104,21374801.0,155.24,9849.0,Sub-Saharan Africa 37 | 155,1969,South Africa,5.643,52.47,21926000.0,150.75,10156.0,Sub-Saharan Africa 38 | 156,1970,South Africa,5.591,52.825,22502306.0,146.4,10394.0,Sub-Saharan Africa 39 | 157,1971,South Africa,5.539,53.168,23106584.0,142.17,10654.0,Sub-Saharan Africa 40 | 158,1972,South Africa,5.482,53.504,23736249.0,138.06,10615.0,Sub-Saharan Africa 41 | 159,1973,South Africa,5.415,53.838,24384286.0,134.07,10813.0,Sub-Saharan Africa 42 | 200,1964,United States,3.222,70.33,197094531.0,27.7,20338.0,America 43 | 201,1965,United States,2.926,70.41,199452508.0,27.1,21361.0,America 44 | 202,1966,United States,2.714,70.43,201657141.0,26.4,22495.0,America 45 | 203,1967,United States,2.564,70.76,203717833.0,25.7,22803.0,America 46 | 204,1968,United States,2.467,70.42,205672498.0,24.9,23647.0,America 47 | 205,1969,United States,2.457,70.66,207573866.0,24.1,24147.0,America 48 | 206,1970,United States,2.461,70.92,209463865.0,23.3,23908.0,America 49 | 207,1971,United States,2.268,71.24,211355529.0,22.4,24350.0,America 50 | 208,1972,United States,2.008,71.34,213250350.0,21.5,25374.0,America 51 | 209,1973,United States,1.871,71.54,215164616.0,20.6,26567.0,America 52 | -------------------------------------------------------------------------------- /docs/example_recipes/stream_code.md: -------------------------------------------------------------------------------- 1 | # Stream code 2 | 3 | 4 | 5 | Generates an animation of a code snippet being written character by character. 6 | 7 | Highlights: 8 | 9 | - Uses a custom `renderer` function to create each frame of the animation. 10 | - Propagates `formatter`, `max_line_length`, and `max_line_number` to the custom `renderer` function. 
11 | 12 | ```python hl_lines="51 102-104" 13 | from textwrap import dedent 14 | 15 | import numpy as np 16 | from PIL import Image, ImageDraw 17 | from pygments import lex 18 | from pygments.formatters import ImageFormatter 19 | from pygments.lexers import get_lexer_by_name 20 | from streamjoy import stream 21 | 22 | def _custom_format( 23 | formatter: ImageFormatter, 24 | tokensource: list[tuple], 25 | max_line_length: int = None, 26 | max_line_number: int = None, 27 | ) -> Image: 28 | formatter._create_drawables(tokensource) 29 | formatter._draw_line_numbers() 30 | max_line_length = max_line_length or formatter.maxlinelength 31 | max_line_number = max_line_number or formatter.maxlineno 32 | 33 | image = Image.new( 34 | "RGB", 35 | formatter._get_image_size(max_line_length, max_line_number), 36 | formatter.background_color, 37 | ) 38 | formatter._paint_line_number_bg(image) 39 | draw = ImageDraw.Draw(image) 40 | # Highlight 41 | if formatter.hl_lines: 42 | x = ( 43 | formatter.image_pad 44 | + formatter.line_number_width 45 | - formatter.line_number_pad 46 | + 1 47 | ) 48 | recth = formatter._get_line_height() 49 | rectw = image.size[0] - x 50 | for linenumber in formatter.hl_lines: 51 | y = formatter._get_line_y(linenumber - 1) 52 | draw.rectangle([(x, y), (x + rectw, y + recth)], fill=formatter.hl_color) 53 | for pos, value, font, text_fg, text_bg in formatter.drawables: 54 | if text_bg: 55 | text_size = draw.textsize(text=value, font=font) 56 | draw.rectangle( 57 | [pos[0], pos[1], pos[0] + text_size[0], pos[1] + text_size[1]], 58 | fill=text_bg, 59 | ) 60 | draw.text(pos, value, font=font, fill=text_fg) 61 | return np.asarray(image) 62 | 63 | def render_frame( 64 | code: str, 65 | formatter: ImageFormatter, 66 | max_line_length: int = None, 67 | max_line_number: int = None, 68 | ) -> Image: 69 | lexer = get_lexer_by_name("python") 70 | return _custom_format( 71 | formatter, 72 | lex(code, lexer), 73 | max_line_length=max_line_length, 74 | max_line_number=max_line_number, 75 | ) 76 | 77 | if __name__ == "__main__": 78 | code = dedent( 79 | """ 80 | import matplotlib.pyplot as plt 81 | import numpy as np 82 | 83 | from streamjoy import stream, wrap_matplotlib 84 | 85 | @wrap_matplotlib() 86 | def plot_frame(i): 87 | x = np.linspace(0, 2, 1000) 88 | y = np.sin(2 * np.pi * (x - 0.01 * i)) 89 | fig, ax = plt.subplots() 90 | ax.plot(x, y) 91 | return fig 92 | 93 | stream(list(range(10)), uri="sine_wave.mp4", renderer=plot_frame) 94 | """ 95 | ) 96 | 97 | formatter = ImageFormatter( 98 | image_format="gif", 99 | line_pad=8, 100 | line_number_bg=None, 101 | line_number_fg=None, 102 | encoding="utf-8", 103 | ) 104 | longest_line = max(code.splitlines(), key=len) + " " * 12 105 | max_line_length, _ = formatter.fonts.get_text_size(longest_line) 106 | max_line_number = code.count("\n") + 1 107 | items = [code[:i] for i in range(0, len(code) + 3, 3)] 108 | 109 | stream( 110 | items, 111 | ending_pause=20, 112 | uri="stream_code.gif", 113 | renderer=render_frame, 114 | formatter=formatter, 115 | max_line_length=max_line_length, 116 | max_line_number=max_line_number, 117 | ) 118 | ``` 119 | -------------------------------------------------------------------------------- /tests/data/air.zarr/.zmetadata: -------------------------------------------------------------------------------- 1 | { 2 | "metadata": { 3 | ".zattrs": { 4 | "Conventions": "COARDS", 5 | "description": "Data is from NMC initialized reanalysis\n(4x/day). 
These are the 0.9950 sigma level values.", 6 | "platform": "Model", 7 | "references": "http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html", 8 | "title": "4x daily NMC reanalysis (1948)" 9 | }, 10 | ".zgroup": { 11 | "zarr_format": 2 12 | }, 13 | "air/.zarray": { 14 | "chunks": [ 15 | 20, 16 | 25, 17 | 53 18 | ], 19 | "compressor": { 20 | "blocksize": 0, 21 | "clevel": 5, 22 | "cname": "lz4", 23 | "id": "blosc", 24 | "shuffle": 1 25 | }, 26 | "dtype": " 4 | 5 | 6 | 7 | Shows the yearly CO2 measurements from the Mauna Loa Observatory in Hawaii. 8 | 9 | The data is sourced from the [datasets/co2-ppm-daily](https://github.com/datasets/co2-ppm-daily/blob/master/co2-ppm-daily-flow.py). 10 | 11 | Highlights: 12 | 13 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure. 14 | - Uses a custom `renderer` function to create each frame of the animation. 15 | - Uses `Paused` to pause the animation at notable dates. 16 | 17 | ```python hl_lines="4 19 114" 18 | import pandas as pd 19 | import matplotlib.pyplot as plt 20 | from matplotlib.ticker import AutoMinorLocator 21 | from streamjoy import stream, wrap_matplotlib, Paused 22 | 23 | URL = "https://raw.githubusercontent.com/datasets/co2-ppm-daily/master/data/co2-ppm-daily.csv" 24 | NOTABLE_YEARS = { 25 | 1958: "Mauna Loa measurements begin", 26 | 1979: "1st World Climate Conference", 27 | 1997: "Kyoto Protocol", 28 | 2005: "exceeded 380 ppm", 29 | 2010: "exceeded 390 ppm", 30 | 2013: "exceeded 400 ppm", 31 | 2015: "Paris Agreement", 32 | } 33 | 34 | 35 | @wrap_matplotlib() 36 | def renderer(df): 37 | plt.style.use("dark_background") 38 | 39 | fig, ax = plt.subplots(figsize=(7, 5)) 40 | fig.patch.set_facecolor("#1b1e23") 41 | ax.set_facecolor("#1b1e23") 42 | ax.set_frame_on(False) 43 | ax.axis("off") 44 | ax.set_title( 45 | "CO2 Yearly Max", 46 | fontsize=20, 47 | loc="left", 48 | fontname="Courier New", 49 | color="lightgrey", 50 | ) 51 | 52 | # draw line 53 | df.plot( 54 | y="value", 55 | color="lightgrey", # Line color 56 | legend=False, 57 | ax=ax, 58 | ) 59 | 60 | # max date 61 | max_date = df["value"].idxmax() 62 | max_co2 = df["value"].max() 63 | ax.text( 64 | 0.0, 65 | 0.92, 66 | f"{max_co2:.0f} ppm", 67 | va="bottom", 68 | ha="left", 69 | transform=ax.transAxes, 70 | fontsize=25, 71 | color="lightgrey", 72 | ) 73 | ax.text( 74 | 0.0, 75 | 0.91, 76 | f"Peaked in {max_date.year}", 77 | va="top", 78 | ha="left", 79 | transform=ax.transAxes, 80 | fontsize=12, 81 | color="lightgrey", 82 | fontname="Courier New", 83 | ) 84 | 85 | # draw end point 86 | date = df.index[-1] 87 | co2 = df["value"].values[-1] 88 | diff = df["diff"].fillna(0).values[-1] 89 | diff = f"+{diff:.0f}" if diff >= 0 else f"{diff:.0f}" 90 | ax.scatter(date, co2, color="red", zorder=999) 91 | ax.annotate( 92 | f"{diff} ppm", 93 | (date, co2), 94 | textcoords="offset points", 95 | xytext=(-10, 5), 96 | fontsize=12, 97 | ha="right", 98 | va="bottom", 99 | color="lightgrey", 100 | ) 101 | 102 | # draw source label 103 | ax.text( 104 | 0.0, 105 | 0.03, 106 | f"Source: {URL}", 107 | va="bottom", 108 | ha="left", 109 | transform=ax.transAxes, 110 | fontsize=8, 111 | color="lightgrey", 112 | ) 113 | 114 | # properly tighten layout 115 | plt.subplots_adjust(bottom=0, top=0.9, right=0.9, left=0.05) 116 | 117 | # pause at notable years 118 | year = date.year 119 | if year in NOTABLE_YEARS: 120 | ax.annotate( 121 | f"{NOTABLE_YEARS[year]} - {year}", 122 | (date, co2), 123 | textcoords="offset points", 124 | xytext=(-10, 3), 125 | fontsize=10.5, 126 
| ha="right", 127 | va="top", 128 | color="lightgrey", 129 | fontname="Courier New", 130 | ) 131 | return Paused(ax, 2.8) 132 | else: 133 | ax.annotate( 134 | year, 135 | (date, co2), 136 | textcoords="offset points", 137 | xytext=(-10, 3), 138 | fontsize=10.5, 139 | ha="right", 140 | va="top", 141 | color="lightgrey", 142 | fontname="Courier New", 143 | ) 144 | return ax 145 | 146 | 147 | if __name__ == "__main__": 148 | df = ( 149 | pd.read_csv(URL, parse_dates=True, index_col="date") 150 | .resample("1YE") 151 | .max() 152 | .interpolate() 153 | .assign( 154 | diff=lambda df: df["value"].diff(), 155 | ) 156 | ) 157 | stream(df, renderer=renderer, max_frames=-1, threads_per_worker=1).write("co2_emissions.mp4") 158 | ``` -------------------------------------------------------------------------------- /docs/example_recipes/temperature_anomaly.md: -------------------------------------------------------------------------------- 1 | # Temperature anomaly 2 | 3 | 6 | 7 | Shows the global temperature anomaly from 1995 to 2024 using the HadCRUT5 dataset. The video pauses at notable dates. 8 | 9 | Highlights: 10 | 11 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure. 12 | - Uses a custom `renderer` function to create each frame of the animation. 13 | - Uses `Paused` to pause the animation at notable dates. 14 | 15 | ```python hl_lines="16 17 112" 16 | import pandas as pd 17 | import matplotlib.pyplot as plt 18 | from streamjoy import stream, wrap_matplotlib, Paused 19 | 20 | URL = "https://climexp.knmi.nl/data/ihadcrut5_global.dat" 21 | NOTABLE_DATES = { 22 | "1997-12": "Kyoto Protocol adopted", 23 | "2005-01": "Exceeded 380 ppm", 24 | "2010-01": "Exceeded 390 ppm", 25 | "2013-05": "Exceeded 400 ppm", 26 | "2015-12": "Paris Agreement signed", 27 | "2016-01": "CO2 permanently over 400 ppm", 28 | } 29 | 30 | 31 | @wrap_matplotlib() 32 | def renderer(df): 33 | plt.style.use("dark_background") # Setting the style for dark mode 34 | 35 | fig, ax = plt.subplots() 36 | fig.patch.set_facecolor("#1b1e23") 37 | ax.set_facecolor("#1b1e23") 38 | ax.set_frame_on(False) 39 | ax.axis("off") 40 | 41 | # Set title 42 | year = df["year"].iloc[-1] 43 | ax.set_title( 44 | f"Global Temperature Anomaly {year} [HadCRUT5]", 45 | fontsize=15, 46 | loc="left", 47 | fontname="Courier New", 48 | color="lightgrey", 49 | ) 50 | 51 | # draw line 52 | df.groupby("year")["anom"].plot( 53 | y="anom", color="lightgrey", legend=False, ax=ax, lw=0.5 54 | ) 55 | 56 | # add source text at bottom right 57 | ax.text( 58 | 0.01, 59 | 0.05, 60 | f"Source: {URL}", 61 | va="bottom", 62 | ha="left", 63 | transform=ax.transAxes, 64 | fontsize=8, 65 | color="lightgrey", 66 | fontname="Courier New", 67 | ) 68 | 69 | # draw end point 70 | jday = df.index.values[-1] 71 | anom = df["anom"].values[-1] 72 | ax.scatter(jday, anom, color="red", zorder=999) 73 | anom_label = f"+{anom:.1f} K" if anom > 0 else f"{anom:.1f} K" 74 | ax.annotate( 75 | anom_label, 76 | (jday, anom), 77 | textcoords="offset points", 78 | xytext=(-10, 5), 79 | fontsize=12, 80 | ha="right", 81 | va="bottom", 82 | color="lightgrey", 83 | ) 84 | 85 | # draw yearly labels 86 | for year, df_year in df.reset_index().groupby("year").last().iloc[-5:].iterrows(): 87 | if df_year["month"] != 12: 88 | continue 89 | ax.annotate( 90 | year, 91 | (df_year["jday"], df_year["anom"]), 92 | fontsize=12, 93 | ha="left", 94 | va="center", 95 | color="lightgrey", 96 | fontname="Courier New", 97 | ) 98 | 99 | plt.subplots_adjust(bottom=0, top=0.9, left=0.05) 100 | 101 
| month = df["date"].iloc[-1].strftime("%b") 102 | ax.annotate( 103 | month, 104 | (jday, anom), 105 | textcoords="offset points", 106 | xytext=(-10, 3), 107 | fontsize=12, 108 | ha="right", 109 | va="top", 110 | color="lightgrey", 111 | fontname="Courier New", 112 | ) 113 | date_string = df["date"].iloc[-1].strftime("%Y-%m") 114 | if date_string in NOTABLE_DATES: 115 | ax.annotate( 116 | f"{NOTABLE_DATES[date_string]}", 117 | xy=(0, 1), 118 | xycoords="axes fraction", 119 | xytext=(0, -5), 120 | textcoords="offset points", 121 | fontsize=12, 122 | ha="left", 123 | va="top", 124 | color="lightgrey", 125 | fontname="Courier New", 126 | ) 127 | return Paused(fig, 3) 128 | return fig 129 | 130 | 131 | df = ( 132 | pd.read_csv( 133 | URL, 134 | comment="#", 135 | header=None, 136 | sep="\s+", 137 | na_values=[-999.9], 138 | ) 139 | .rename(columns={0: "year"}) 140 | .melt(id_vars="year", var_name="month", value_name="anom") 141 | ) 142 | df.index = pd.to_datetime( 143 | df["year"].astype(str) + df["month"].astype(str), format="%Y%m" 144 | ) 145 | df = df.sort_index()["1995":"2024"] 146 | df["jday"] = df.index.dayofyear 147 | df = df.rename_axis("date").reset_index().set_index("jday") 148 | df_list = [df[:i] for i in range(1, len(df) + 1)] 149 | 150 | stream(df_list, renderer=renderer, threads_per_worker=1).write( 151 | "temperature_anomaly.mp4" 152 | ) 153 | ``` -------------------------------------------------------------------------------- /streamjoy/renderers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING, Any 4 | 5 | from . import _utils 6 | 7 | if TYPE_CHECKING: 8 | try: 9 | import matplotlib.pyplot as plt 10 | except ImportError: 11 | plt = None 12 | 13 | try: 14 | import pandas as pd 15 | except ImportError: 16 | pd = None 17 | 18 | try: 19 | import polars as pl 20 | except ImportError: 21 | pl = None 22 | 23 | try: 24 | import xarray as xr 25 | except ImportError: 26 | xr = None 27 | 28 | try: 29 | import holoviews as hv 30 | except ImportError: 31 | hv = None 32 | 33 | 34 | def default_pandas_renderer( 35 | df_sub: pd.DataFrame, *args: tuple[Any], **kwargs: dict[str, Any] 36 | ) -> plt.Figure: 37 | """ 38 | Render a pandas DataFrame using matplotlib. 39 | 40 | Args: 41 | df_sub: The DataFrame to render. 42 | *args: Additional positional arguments to pass to the renderer. 43 | **kwargs: Additional keyword arguments to pass to the renderer. 44 | 45 | Returns 46 | A matplotlib figure. 47 | """ 48 | import matplotlib.pyplot as plt 49 | 50 | df_sub = df_sub.reset_index() 51 | 52 | fig, ax = plt.subplots() 53 | 54 | title = kwargs.get("title") 55 | if title: 56 | title = title.format(**df_sub.iloc[-1]) 57 | elif title is None: 58 | title = df_sub[kwargs["x"]].iloc[-1] 59 | kwargs["title"] = title 60 | 61 | groupby = kwargs.pop("groupby", None) 62 | if groupby: 63 | for group, df_group in df_sub.groupby(groupby): 64 | df_group.plot(*args, ax=ax, label=group, **kwargs) 65 | else: 66 | df_sub.plot(*args, ax=ax, **kwargs) 67 | 68 | return fig 69 | 70 | 71 | def default_polars_renderer( 72 | df_sub: pl.DataFrame, *args: tuple[Any], **kwargs: dict[str, Any] 73 | ) -> hv.Element: 74 | """ 75 | Render a polars DataFrame using HoloViews. 76 | 77 | Args: 78 | df_sub: The DataFrame to render. 79 | *args: Additional positional arguments to pass to the renderer. 80 | **kwargs: Additional keyword arguments to pass to the renderer. 81 | 82 | Returns: 83 | The rendered HoloViews Element. 
84 | """ 85 | backend = kwargs.pop("backend", None) 86 | by = kwargs.pop("groupby", None) 87 | 88 | title = kwargs.get("title") 89 | if title: 90 | title = title.format(**df_sub.tail(1).to_pandas().to_dict("records")[0]) 91 | elif title is None: 92 | title = df_sub[kwargs["x"]].tail(1)[0] 93 | kwargs["title"] = str(title) 94 | 95 | if by: 96 | kwargs["by"] = by 97 | hv_obj = df_sub.plot(*args, **kwargs) 98 | return default_holoviews_renderer(hv_obj, backend=backend) 99 | 100 | 101 | def default_xarray_renderer( 102 | da_sel: xr.DataArray, *args: tuple[Any], **kwargs: dict[str, Any] 103 | ) -> plt.Figure: 104 | """ 105 | Render an xarray DataArray using matplotlib. 106 | 107 | Args: 108 | da_sel: The DataArray to render. 109 | *args: Additional positional arguments to pass to the renderer. 110 | **kwargs: Additional keyword arguments to pass to the renderer. 111 | 112 | Returns: 113 | A matplotlib figure. 114 | """ 115 | import matplotlib.pyplot as plt 116 | 117 | da_sel = _utils.validate_xarray(da_sel, warn=False) 118 | 119 | fig = plt.figure() 120 | ax = plt.axes(**kwargs.pop("subplot_kws", {})) 121 | title = kwargs.pop("title", None) 122 | 123 | try: 124 | da_sel.plot(ax=ax, extend="both", *args, **kwargs) 125 | except Exception: 126 | da_sel.plot(ax=ax, *args, **kwargs) 127 | 128 | if title: 129 | title_format = {coord: da_sel[coord].values for coord in da_sel.coords} 130 | ax.set_title(title.format(**title_format)) 131 | 132 | return fig 133 | 134 | 135 | def default_holoviews_renderer( 136 | hv_obj: hv.Element, *args: tuple[Any], **kwargs: dict[str, Any] 137 | ) -> hv.Element: 138 | """ 139 | Render a HoloViews Element using the default backend. 140 | 141 | Args: 142 | hv_obj: The HoloViews Element to render. 143 | *args: Additional positional arguments to pass to the renderer. 144 | **kwargs: Additional keyword arguments to pass to the renderer. 145 | 146 | Returns: 147 | The rendered HoloViews Element. 148 | """ 149 | import holoviews as hv 150 | 151 | backend = kwargs.get("backend", hv.Store.current_backend) 152 | 153 | clims = kwargs.pop("clims", {}) 154 | for hv_el in hv_obj.traverse(full_breadth=False): 155 | try: 156 | vdim = hv_el.vdims[0].name 157 | except IndexError: 158 | continue 159 | if vdim in clims: 160 | hv_el.opts(clim=clims[vdim], backend=backend) 161 | 162 | if backend == "bokeh": 163 | kwargs["toolbar"] = None 164 | elif backend == "matplotlib": 165 | kwargs["cbar_extend"] = kwargs.get("cbar_extend", "both") 166 | 167 | if isinstance(hv_obj, hv.Overlay): 168 | for hv_el in hv_obj: 169 | try: 170 | hv_el.opts(**kwargs) 171 | except Exception: 172 | pass 173 | else: 174 | hv_obj.opts(**kwargs) 175 | 176 | return hv_obj 177 | -------------------------------------------------------------------------------- /docs/supported_formats.md: -------------------------------------------------------------------------------- 1 | # Supported formats 2 | 3 | StreamJoy supports a variety of input types! 
4 | 5 | ## 📋 List of Images, GIFs, Videos, or URLs 6 | 7 | ```python 8 | from streamjoy import stream 9 | 10 | URL_FMT = "https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/2024/01_Jan/N_202401{day:02d}_conc_v3.0.png" 11 | 12 | stream([URL_FMT.format(day=day) for day in range(1, 31)], uri="2024_jan_sea_ice.mp4") 13 | ``` 14 | 17 | 18 | ## 📁 Directory of Images, GIFs, Videos, or URLs 19 | 20 | ```python 21 | from streamjoy import stream 22 | 23 | URL_DIR = "https://downloads.psl.noaa.gov/Datasets/ncep.reanalysis/Dailies/surface/" 24 | 25 | stream(URL_DIR, uri="air_temperature.mp4", pattern="air.sig995.194*.nc") 26 | ``` 27 | 28 | 31 | 32 | ## 🧮 Numpy NdArray 33 | 34 | ```python 35 | from streamjoy import stream 36 | import imageio.v3 as iio 37 | 38 | array = iio.imread("imageio:newtonscradle.gif") # is a 4D numpy array 39 | stream(array, max_frames=-1).write("newtonscradle.mp4") 40 | ``` 41 | 42 | 45 | 46 | ## 🐼 Pandas DataFrame or Series 47 | 48 | !!! note "Additional Requirements" 49 | 50 | You will need to additionally install `pandas` and `matplotlib` to support this format: 51 | 52 | ```bash 53 | pip install pandas matplotlib 54 | ``` 55 | 56 | ```python 57 | from streamjoy import stream 58 | import pandas as pd 59 | 60 | df = pd.read_csv( 61 | "https://raw.githubusercontent.com/franlopezguzman/gapminder-with-bokeh/master/gapminder_tidy.csv" 62 | ).set_index("Year") 63 | df = df.query("Country in ['United States', 'China', 'South Africa']") 64 | stream(df, uri="gapminder.mp4", groupby="Country", title="{Year}") 65 | ``` 66 | 67 | 70 | 71 | ## 🐻‍❄️ Polars DataFrame 72 | 73 | !!! note "Additional Requirements" 74 | 75 | You will need to additionally install `polars`, `pyarrow`, `hvplot`, `selenium`, and `webdriver-manager` to support this format: 76 | 77 | ```bash 78 | pip install polars pyarrow hvplot selenium webdriver-manager 79 | ``` 80 | 81 | You must also have `firefox` or `chromedriver` installed on your system. 82 | 83 | ```bash 84 | conda install -c conda-forge firefox 85 | ``` 86 | 87 | ```python 88 | from streamjoy import stream 89 | import polars as pl 90 | 91 | df = pl.read_csv( 92 | "https://raw.githubusercontent.com/franlopezguzman/gapminder-with-bokeh/master/gapminder_tidy.csv" 93 | ) 94 | df = df.filter(pl.col("Country").is_in(['United States', 'China', 'South Africa'])) 95 | stream(df, uri="gapminder.mp4", groupby="Country", title="{Year}") 96 | ``` 97 | 98 | 101 | 102 | ## 🗄️ XArray Dataset or DataArray 103 | 104 | !!! note "Additional Requirements" 105 | 106 | You will need to additionally install `xarray` and `matplotlib` to support this format: 107 | 108 | ```bash 109 | pip install xarray matplotlib 110 | ``` 111 | 112 | For this example, you will also need to install `pooch` and `netcdf4`: 113 | 114 | ```bash 115 | pip install pooch netcdf4 116 | ``` 117 | 118 | ```python 119 | from streamjoy import stream 120 | import xarray as xr 121 | 122 | ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 100)) 123 | stream(ds, uri="air.mp4", cmap="RdBu_r") 124 | ``` 125 | 126 | 129 | 130 | ## 📊 HoloViews HoloMap or DynamicMap 131 | 132 | !!! 
note "Additional Requirements" 133 | 134 | You will need to additionally install `holoviews` to support this format: 135 | 136 | ```bash 137 | pip install holoviews 138 | ``` 139 | 140 | For the bokeh backend, you will need to install `bokeh`, `selenium`, and `webdriver-manager`: 141 | 142 | ```bash 143 | pip install bokeh selenium webdriver-manager 144 | ``` 145 | 146 | For the matplotlib backend, you will need to install `matplotlib`: 147 | 148 | ```bash 149 | pip install matplotlib 150 | ``` 151 | 152 | For this example, you will also need to install `pooch`, `netcdf4, `hvplot`, and `xarray`: 153 | 154 | ```bash 155 | pip install pooch netcdf4 hvplot xarray 156 | ``` 157 | 158 | You must also have `firefox` or `chromedriver` installed on your system. 159 | 160 | ```bash 161 | conda install -c conda-forge firefox 162 | ``` 163 | 164 | ```python 165 | import xarray as xr 166 | import hvplot.xarray 167 | from streamjoy import stream 168 | 169 | ds = xr.tutorial.open_dataset("rasm").isel(time=slice(10)) 170 | stream(ds.hvplot.image("x", "y"), uri="rasm.mp4") 171 | ``` 172 | 173 | 176 | -------------------------------------------------------------------------------- /tests/test_streams.py: -------------------------------------------------------------------------------- 1 | import panel as pn 2 | import pytest 3 | from imageio.v3 import improps 4 | 5 | from streamjoy.models import Paused 6 | from streamjoy.streams import GifStream, HtmlStream, Mp4Stream 7 | from streamjoy.wrappers import wrap_matplotlib 8 | 9 | 10 | class AbstractTestMediaStream: 11 | def _assert_stream_and_props(self, sj, stream_cls, max_frames=3): 12 | assert isinstance(sj, stream_cls) 13 | buf = sj.write() 14 | props = improps(buf) 15 | props.n_images == max_frames 16 | return props 17 | 18 | def test_from_numpy(self, stream_cls, array): 19 | sj = stream_cls.from_numpy(array) 20 | self._assert_stream_and_props(sj, stream_cls) 21 | 22 | def test_from_pandas(self, stream_cls, df): 23 | sj = stream_cls.from_pandas(df) 24 | self._assert_stream_and_props(sj, stream_cls) 25 | 26 | def test_from_polars(self, stream_cls, pl_df): 27 | sj = stream_cls.from_polars(pl_df) 28 | self._assert_stream_and_props(sj, stream_cls) 29 | 30 | def test_from_xarray(self, stream_cls, ds): 31 | sj = stream_cls.from_xarray(ds) 32 | self._assert_stream_and_props(sj, stream_cls) 33 | 34 | def test_from_holoviews_hmap(self, stream_cls, hmap): 35 | sj = stream_cls.from_holoviews(hmap) 36 | self._assert_stream_and_props(sj, stream_cls) 37 | 38 | def test_from_holoviews_dmap(self, stream_cls, dmap): 39 | sj = stream_cls.from_holoviews(dmap) 40 | self._assert_stream_and_props(sj, stream_cls) 41 | 42 | def test_from_url_dir(self, stream_cls): 43 | sj = stream_cls.from_url( 44 | "https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/1978/10_Oct/", 45 | pattern="N_197810*_conc_v3.0.png", 46 | ) 47 | self._assert_stream_and_props(sj, stream_cls) 48 | 49 | def test_from_url_path(self, stream_cls): 50 | sj = stream_cls.from_url( 51 | "https://github.com/ahuang11/streamjoy/raw/main/tests/data/gapminder.parquet", 52 | ) 53 | self._assert_stream_and_props(sj, stream_cls) 54 | 55 | def test_from_directory(self, stream_cls, data_dir): 56 | sj = stream_cls.from_directory(data_dir, pattern="*.png") 57 | self._assert_stream_and_props(sj, stream_cls) 58 | 59 | def test_fsspec_fs(self, stream_cls, df, fsspec_fs): 60 | sj = stream_cls.from_pandas(df, fsspec_fs=fsspec_fs) 61 | self._assert_stream_and_props(sj, stream_cls) 62 | 63 | def 
test_holoviews_matplotlib_backend(self, stream_cls, ds): 64 | sj = stream_cls.from_holoviews( 65 | ds.hvplot("lon", "lat", fig_size=200, backend="matplotlib") 66 | ) 67 | props = self._assert_stream_and_props(sj, stream_cls) 68 | assert props.shape[1] == 300 69 | 70 | def test_holoviews_bokeh_backend(self, stream_cls, ds): 71 | sj = stream_cls.from_holoviews( 72 | ds.hvplot("lon", "lat", width=300, backend="bokeh") 73 | ) 74 | props = self._assert_stream_and_props(sj, stream_cls) 75 | assert props.shape[1] == 300 76 | 77 | def test_write_max_frames(self, stream_cls, df): 78 | sj = stream_cls.from_pandas(df, max_frames=3) 79 | self._assert_stream_and_props(sj, stream_cls, max_frames=3) 80 | 81 | 82 | class TestGifStream(AbstractTestMediaStream): 83 | @pytest.fixture(scope="class") 84 | def stream_cls(self): 85 | return GifStream 86 | 87 | def test_paused(self, stream_cls, df): 88 | @wrap_matplotlib() 89 | def renderer(df, groupby=None): # TODO: fix bug groupby not needed 90 | return Paused(df.plot(), seconds=2) 91 | 92 | buf = stream_cls.from_pandas(df, renderer=renderer).write() 93 | props = improps(buf) 94 | assert props.n_images == 3 95 | 96 | 97 | class TestMp4Stream(AbstractTestMediaStream): 98 | @pytest.fixture(scope="class") 99 | def stream_cls(self): 100 | return Mp4Stream 101 | 102 | def test_paused(self, stream_cls, df): 103 | @wrap_matplotlib() 104 | def renderer(df, groupby=None): # TODO: fix bug groupby not needed 105 | return Paused(df.plot(), seconds=2) 106 | 107 | buf = stream_cls.from_pandas(df, renderer=renderer).write() 108 | props = improps(buf) 109 | assert props.n_images == 9 110 | 111 | 112 | class TestHtmlStream(AbstractTestMediaStream): 113 | def _assert_stream_and_props(self, sj, stream_cls, max_frames=3): 114 | assert isinstance(sj, stream_cls) 115 | buf = sj.write() 116 | assert isinstance(buf, pn.Column) 117 | tabs = buf[0] 118 | assert isinstance(tabs, pn.Tabs) 119 | image = tabs[0] 120 | assert isinstance(image, pn.pane.Image) 121 | assert len(tabs) == max_frames 122 | player = buf[1] 123 | assert isinstance(player, pn.widgets.Player) 124 | return image 125 | 126 | @pytest.fixture(scope="class") 127 | def stream_cls(self): 128 | return HtmlStream 129 | 130 | def test_holoviews_matplotlib_backend(self, stream_cls, ds): 131 | sj = stream_cls.from_holoviews( 132 | ds.hvplot("lon", "lat", fig_size=200, backend="matplotlib") 133 | ) 134 | image = self._assert_stream_and_props(sj, stream_cls) 135 | assert image.width is None 136 | 137 | def test_holoviews_bokeh_backend(self, stream_cls, ds): 138 | sj = stream_cls.from_holoviews( 139 | ds.hvplot("lon", "lat", width=300, backend="bokeh") 140 | ) 141 | image = self._assert_stream_and_props(sj, stream_cls) 142 | assert image.width is None 143 | 144 | def test_fixed_width_height(self, stream_cls, df): 145 | sj = stream_cls.from_pandas(df, width=300, height=300, sizing_mode="fixed") 146 | image = self._assert_stream_and_props(sj, stream_cls) 147 | assert image.width == 300 148 | assert image.height == 300 149 | 150 | def test_stretch_width(self, stream_cls, df): 151 | sj = stream_cls.from_pandas(df, height=300, sizing_mode="stretch_width") 152 | image = self._assert_stream_and_props(sj, stream_cls) 153 | assert image.width is None 154 | assert image.height is None 155 | -------------------------------------------------------------------------------- /streamjoy/wrappers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 
import time 5 | from functools import wraps 6 | from io import BytesIO 7 | from pathlib import Path 8 | from typing import Any, Callable 9 | 10 | from . import _utils 11 | from .models import Paused 12 | from .settings import config 13 | 14 | 15 | def wrap_matplotlib( 16 | in_memory: bool = False, 17 | scratch_dir: str | Path | None = None, 18 | fsspec_fs: Any | None = None, 19 | ) -> Callable: 20 | """ 21 | Wraps a function used to render a matplotlib figure so that 22 | it automatically saves the figure and closes it. 23 | 24 | Args: 25 | in_memory: Whether to render the figure in-memory. 26 | scratch_dir: The scratch directory to use. 27 | fsspec_fs: The fsspec filesystem to use. 28 | 29 | Returns: 30 | The wrapped function. 31 | """ 32 | 33 | def wrapper(renderer): 34 | @wraps(renderer) 35 | def wrapped(*args, **kwargs) -> Path | BytesIO: 36 | import matplotlib 37 | 38 | matplotlib.use("Agg") 39 | import matplotlib.pyplot as plt 40 | 41 | plt.rcParams.update({"figure.max_open_warning": config["max_open_warning"]}) 42 | 43 | output = renderer(*args, **kwargs) 44 | 45 | fig = output 46 | return_paused = False 47 | if isinstance(output, Paused): 48 | return_paused = True 49 | fig = output.output 50 | 51 | if isinstance(fig, plt.Axes): 52 | fig = fig.figure 53 | elif not isinstance(fig, plt.Figure): 54 | raise ValueError("Renderer must return a Figure or Axes object.") 55 | 56 | uri = _utils.resolve_uri( 57 | file_name=f"{hash(fig)}.jpg", 58 | scratch_dir=scratch_dir, 59 | in_memory=in_memory, 60 | fsspec_fs=fsspec_fs, 61 | ) 62 | if fsspec_fs: 63 | with fsspec_fs.open(uri, "wb") as f: 64 | buf = BytesIO() 65 | fig.savefig(buf, format="jpg") 66 | buf.seek(0) 67 | f.write(buf.read()) 68 | else: 69 | fig.savefig(uri, format="jpg") 70 | plt.close(fig) 71 | return ( 72 | uri if not return_paused else Paused(output=uri, seconds=output.seconds) 73 | ) 74 | 75 | return wrapped 76 | 77 | return wrapper 78 | 79 | 80 | def wrap_holoviews( 81 | in_memory: bool = False, 82 | scratch_dir: str | Path | None = None, 83 | fsspec_fs: Any | None = None, 84 | webdriver: str | Callable | None = None, 85 | num_retries: int | None = None, 86 | ) -> Callable: 87 | """ 88 | Wraps a function used to render a holoviews object so that 89 | it automatically saves the object. 90 | 91 | Args: 92 | in_memory: Whether to render the object in-memory. 93 | scratch_dir: The scratch directory to use. 94 | fsspec_fs: The fsspec filesystem to use. 95 | webdriver: The webdriver to use. 96 | num_retries: The number of retries to use. 97 | 98 | Returns: 99 | The wrapped function. 
100 | """ 101 | 102 | webdriver = _utils.get_config_default("webdriver", webdriver, warn=False) 103 | if isinstance(webdriver, str): 104 | webdriver = (webdriver, _utils.get_webdriver_path(webdriver)) 105 | 106 | if in_memory: 107 | raise ValueError("Holoviews renderer does not support in-memory rendering.") 108 | 109 | def wrapper(renderer): 110 | @wraps(renderer) 111 | def wrapped(*args, **kwargs) -> Path | BytesIO: 112 | import holoviews as hv 113 | 114 | backend = kwargs.get("backend", hv.Store.current_backend) 115 | output = renderer(*args, **kwargs) 116 | 117 | hv_obj = output 118 | return_paused = False 119 | if isinstance(output, Paused): 120 | return_paused = True 121 | hv_obj = output.output 122 | 123 | uri = _utils.resolve_uri( 124 | file_name=f"{hash(hv_obj)}.png", 125 | scratch_dir=scratch_dir, 126 | in_memory=in_memory, 127 | fsspec_fs=fsspec_fs, 128 | ) 129 | if backend == "bokeh": 130 | from bokeh.io.export import get_screenshot_as_png 131 | 132 | retries = _utils.get_config_default( 133 | "num_retries", num_retries, warn=False 134 | ) 135 | for r in range(retries): 136 | try: 137 | driver = _utils.get_webdriver(webdriver) 138 | with driver: 139 | image = get_screenshot_as_png( 140 | hv.render(hv_obj, backend=backend), driver=driver 141 | ) 142 | if fsspec_fs: 143 | with fsspec_fs.open(uri, "wb") as f: 144 | image.save(f, format="png") 145 | else: 146 | image.save(uri, format="png") 147 | break 148 | except Exception as e: 149 | logging.warning( 150 | f"Failed to save image: {e}, retrying in {r * 2}s" 151 | ) 152 | time.sleep(r * 2) 153 | if r == retries - 1: 154 | raise e 155 | else: 156 | if fsspec_fs: 157 | with fsspec_fs.open(uri, "wb") as f: 158 | buf = BytesIO() 159 | hv.save(hv_obj, buf, fmt="png") 160 | buf.seek(0) 161 | f.write(buf.read()) 162 | else: 163 | hv.save(hv_obj, uri, fmt="png", backend=backend) 164 | 165 | return ( 166 | uri if not return_paused else Paused(output=uri, seconds=output.seconds) 167 | ) 168 | 169 | return wrapped 170 | 171 | return wrapper 172 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🌈 StreamJoy 😊 2 | 3 | --- 4 | 5 | [![build](https://github.com/ahuang11/streamjoy/workflows/Build/badge.svg)](https://github.com/ahuang11/streamjoy/actions) [![codecov](https://codecov.io/gh/ahuang11/streamjoy/branch/master/graph/badge.svg)](https://codecov.io/gh/ahuang11/streamjoy) [![PyPI version](https://badge.fury.io/py/streamjoy.svg)](https://badge.fury.io/py/streamjoy) 6 | 7 | [![Downloads](https://pepy.tech/badge/streamjoy)](https://pepy.tech/project/streamjoy) [![GitHub stars](https://img.shields.io/github/stars/ahuang11/streamjoy?style=flat-square)](https://img.shields.io/github/stars/ahuang11/streamjoy?style=flat-square) 8 | 9 | --- 10 | 11 | ## 🔥 Enjoy animating! 12 | 13 | Streamjoy turns your images into animations using sensible defaults for fun, hassle-free creation. 14 | 15 | It cuts down the boilerplate and time to work on animations, and it's simple to start with just a few lines of code. 16 | 17 | Install it with just pip to start, blazingly fast! 
18 | 19 | ```python 20 | pip install streamjoy 21 | ``` 22 | 23 | 24 | 25 | Or, try out a basic web app version here: 26 | 27 | https://huggingface.co/spaces/ahuang11/streamjoy 28 | 29 | 30 | 31 | ## 🛠️ Built-in features 32 | 33 | - 🌐 Animate from URLs, files, and datasets 34 | - 🎨 Render images with default or custom renderers 35 | - 🎬 Provide context with a short intro splash 36 | - ⏸ Add pauses at the beginning, end, or between frames 37 | - ⚡ Execute read, render, and write in parallel 38 | - 🔗 Connect multiple animations together 39 | 40 | ## 🚀 Quick start 41 | 42 | ### 🐤 Absolute basics 43 | 44 | Stream from a list of images--local files work too! 45 | 46 | ```python 47 | from streamjoy import stream 48 | 49 | if __name__ == "__main__": 50 | URL_FMT = "https://www.cpc.ncep.noaa.gov/products/NMME/archive/2025090800/current/images/NMME_ensemble_tmpsfc_lead{i}.png" 51 | resources = [URL_FMT.format(i=i) for i in range(1, 8)] 52 | stream(resources, uri="nmme.gif") # .gif, .mp4, and .html supported 53 | ``` 54 | 55 | ![nmme](https://github.com/user-attachments/assets/ab9a3b5e-3b5c-4deb-b093-891adb936f0c) 56 | 57 | ### 💅 Polish up 58 | 59 | Specify a few more keywords to: 60 | 61 | 1. add an intro title and subtitle 62 | 2. adjust the pauses 63 | 3. optimize the GIF thru pygifsicle 64 | 65 | Note: This example no longer works because URL changed! 66 | 67 | ```python 68 | from streamjoy import stream 69 | 70 | if __name__ == "__main__": 71 | URL_FMT = "https://www.goes.noaa.gov/dimg/jma/fd/vis/{i}.gif" 72 | resources = [URL_FMT.format(i=i) for i in range(1, 11)] 73 | himawari_stream = stream( 74 | resources, 75 | uri="goes_custom.gif", 76 | intro_title="Himawari Visible", 77 | intro_subtitle="10 Hours Loop", 78 | intro_pause=1, 79 | ending_pause=1, 80 | optimize=True, 81 | ) 82 | ``` 83 | 84 | 85 | 86 | ### 👀 Preview inputs 87 | 88 | If you'd like to preview the `repr` before writing, drop `uri`. 89 | 90 | Note: This example no longer works because URL changed! 91 | 92 | Output: 93 | ```yaml 94 | 95 | --- 96 | Output: 97 | max_frames: 50 98 | fps: 10 99 | display: True 100 | scratch_dir: streamjoy_scratch 101 | in_memory: False 102 | --- 103 | Intro: 104 | intro_title: Himawari Visible 105 | intro_subtitle: 10 Hours Loop 106 | intro_watermark: made with streamjoy 107 | intro_pause: 1 108 | intro_background: black 109 | --- 110 | Client: 111 | batch_size: 10 112 | processes: True 113 | threads_per_worker: None 114 | --- 115 | Resources: (10 frames to stream) 116 | https://www.goes.noaa.gov/dimg/jma/fd/vis/1.gif 117 | ... 118 | https://www.goes.noaa.gov/dimg/jma/fd/vis/10.gif 119 | --- 120 | ``` 121 | 122 | Then, when ready, call the `write` method to save the animation! 123 | 124 | ```python 125 | himawari_stream.write() 126 | ``` 127 | 128 | ### 🖇️ Connect streams 129 | 130 | Connect multiple streams together to provide further context. 131 | 132 | Note: This example no longer works because URL changed! 
133 | 134 | ```python 135 | from streamjoy import stream, connect 136 | 137 | URL_FMTS = { 138 | "visible": "https://www.goes.noaa.gov/dimg/jma/fd/vis/{i}.gif", 139 | "infrared": "https://www.goes.noaa.gov/dimg/jma/fd/rbtop/{i}.gif", 140 | } 141 | 142 | if __name__ == "__main__": 143 | visible_stream = stream( 144 | [URL_FMTS["visible"].format(i=i) for i in range(1, 11)], 145 | intro_title="Himawari Visible", 146 | intro_subtitle="10 Hours Loop", 147 | ) 148 | infrared_stream = stream( 149 | [URL_FMTS["infrared"].format(i=i) for i in range(1, 11)], 150 | intro_title="Himawari Infrared", 151 | intro_subtitle="10 Hours Loop", 152 | ) 153 | connect([visible_stream, infrared_stream], uri="goes_connected.gif") 154 | ``` 155 | 156 | 157 | 158 | ### 📷 Render datasets 159 | 160 | You can also render images directly from datasets, either through a custom renderer or a built-in one, and they'll also run in parallel! 161 | 162 | The following example requires xarray, cartopy, matplotlib, and netcdf4. 163 | 164 | ```bash 165 | pip install xarray cartopy matplotlib netcdf4 166 | ``` 167 | 168 | ```python 169 | import numpy as np 170 | import cartopy.crs as ccrs 171 | import matplotlib.pyplot as plt 172 | from streamjoy import stream, wrap_matplotlib 173 | 174 | @wrap_matplotlib() 175 | def plot(da, central_longitude, **plot_kwargs): 176 | time = da["time"].dt.strftime("%b %d %Y").values.item() 177 | projection = ccrs.Orthographic(central_longitude=central_longitude) 178 | subplot_kw = dict(projection=projection, facecolor="gray") 179 | fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=subplot_kw) 180 | im = da.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, **plot_kwargs) 181 | ax.set_title(f"Sea Surface Temperature Anomaly\n{time}", loc="left", transform=ax.transAxes) 182 | ax.set_title("Source: NOAA OISST v2.1", loc="right", size=5, y=-0.01) 183 | ax.set_title("", loc="center") # suppress default title 184 | plt.colorbar(im, ax=ax, label="°C", shrink=0.8) 185 | return fig 186 | 187 | if __name__ == "__main__": 188 | url = ( 189 | "https://www.ncei.noaa.gov/data/sea-surface-temperature-" 190 | "optimum-interpolation/v2.1/access/avhrr/201008/" 191 | ) 192 | pattern = "oisst-avhrr-v02r01.*.nc" 193 | stream( 194 | url, 195 | uri="oisst.gif", 196 | pattern=pattern, # GifStream.from_url kwargs 197 | max_files=30, 198 | renderer=plot, # renderer related kwargs 199 | renderer_iterables=[np.linspace(-140, -150, 30)], # iterables; central longitude per frame (30 frames) 200 | renderer_kwargs=dict(cmap="RdBu_r", vmin=-5, vmax=5), # renderer kwargs 201 | # cmap="RdBu_r", # renderer_kwargs can also be propagated for convenience 202 | # vmin=-5, 203 | # vmax=5, 204 | ) 205 | ``` 206 | 207 | 208 | 209 | Check out all the supported formats [here](https://ahuang11.github.io/streamjoy/supported_formats/) or best practices [here](https://ahuang11.github.io/streamjoy/best_practices/). (Or maybe you're interested in the design--[here](https://ahuang11.github.io/streamjoy/package_design/)) 210 | 211 | --- 212 | 213 | ❤️ Made with considerable passion. 214 | 215 | 🌟 Appreciate the project? Consider giving a star! 216 | 217 | -------------------------------------------------------------------------------- /docs/how_do_i.md: -------------------------------------------------------------------------------- 1 | # How do I... 2 | 3 | ## 🖼️ Use all resources with `max_frames=-1` 4 | 5 | By default, StreamJoy only renders the first 50 frames to prevent accidentally rendering a large dataset. 
6 | 7 | To render all frames, set `max_frames=-1`. 8 | 9 | ```python 10 | from streamjoy import stream 11 | 12 | stream(..., max_frames=-1) 13 | ``` 14 | 15 | ## ⏸️ Pause animations with `Paused`, `intro_pause`, and `ending_pause` 16 | 17 | Animations can be good, but sometimes you want to pause at various points of the animation to provide context or to emphasize a point. 18 | 19 | To pause at a given frame using a custom `renderer`, wrap `Paused` around the output: 20 | 21 | ```python 22 | from streamjoy import Paused, stream 23 | 24 | def plot_frame(time): 25 | important_time = ... 26 | if time == important_time: 27 | return Paused(fig, seconds=3) 28 | else: 29 | return fig 30 | 31 | stream(..., renderer=plot_frame) 32 | ``` 33 | 34 | Don't forget there's also `intro_pause` and `ending_pause` to pause at the beginning and end of the animation! 35 | 36 | ## 📊 Reduce boilerplate code with `wrap_*` decorators 37 | 38 | If you're using a custom `renderer`, you can use `wrap_matplotlib` and `wrap_holoviews` to automatically handle saving and closing the figure. 39 | 40 | ```python 41 | from streamjoy import stream, wrap_matplotlib 42 | 43 | @wrap_matplotlib() 44 | def plot_frame(time): 45 | ... 46 | 47 | stream(..., renderer=plot_frame) 48 | ``` 49 | 50 | ## 🗣️ Provide context with `intro_title` and `intro_subtitle` 51 | 52 | Use `intro_title` and `intro_subtitle` to provide context at the beginning of the animation. 53 | 54 | ```python 55 | from streamjoy import stream 56 | 57 | stream(..., intro_title="Himawari Visible", intro_subtitle="10 Hours Loop") 58 | ``` 59 | 60 | ## 💾 Write animation to memory instead of file 61 | 62 | If you're just testing out the animation, you can save it to memory instead of to disk by calling write without specifying a uri. 63 | 64 | ```python 65 | from streamjoy import stream 66 | 67 | stream(...).write() 68 | ``` 69 | 70 | ## 🚪 Use as a method of `pandas` and `xarray` objects 71 | 72 | StreamJoy can be used directly from `pandas` and `xarray` objects as an accessor. 73 | 74 | ```python 75 | import pandas as pd 76 | import streamjoy.pandas 77 | 78 | df = pd.DataFrame(...) 79 | 80 | # equivalent to streamjoy.stream(df) 81 | df.streamjoy(...) 82 | 83 | # series also works! 84 | df["col"].streamjoy(...) 85 | ``` 86 | 87 | ```python 88 | import xarray as xr 89 | import streamjoy.xarray 90 | 91 | ds = xr.Dataset(...) 92 | 93 | # equivalent to streamjoy.stream(ds) 94 | ds.streamjoy(...) 95 | 96 | # dataarray also works! 97 | ds["var"].streamjoy(...) 98 | ``` 99 | 100 | ## ⛓️ Join streams with `sum` and `connect` 101 | 102 | Use `sum` to join homogeneous streams, i.e. streams that have the same keyword arguments. 103 | 104 | ```python 105 | from streamjoy import stream, sum 106 | 107 | sum([stream(..., **same_kwargs) for i in range(10)]) 108 | ``` 109 | 110 | Use `connect` to join heterogeneous streams, i.e. streams that have different keyword arguments, like different `intro_title` and `intro_subtitle`. 111 | 112 | ```python 113 | from streamjoy import stream, connect 114 | 115 | connect([stream(..., **kwargs1), stream(..., **kwargs2)]) 116 | ``` 117 | 118 | ## 🎥 Decide between writing as `.mp4` vs `.gif` 119 | 120 | If you need a comprehensive color palette, use `.mp4` as it supports more colors. 121 | 122 | For automatic playing and looping, use `.gif`. To reduce the `.gif` file size, set `optimize=True`, which runs `pygifsicle` on the output.
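For example, here is a minimal sketch of both options (illustrative only; `...` stands in for your resources, and the output file names are placeholders):

```python
from streamjoy import stream

# .mp4 supports a fuller color palette
stream(..., uri="output.mp4")

# .gif plays and loops automatically; optimize=True shrinks it with pygifsicle
stream(..., uri="output.gif", optimize=True)
```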
123 | 124 | ## 📦 Prevent `RuntimeError` by using `__name__ == "__main__"` 125 | 126 | If you run a `.py` script without it, you might encounter the following `RuntimeError`: 127 | 128 | ```python 129 | RuntimeError: 130 | An attempt has been made to start a new process before the 131 | current process has finished its bootstrapping phase. 132 | 133 | This probably means that you are not using fork to start your 134 | child processes and you have forgotten to use the proper idiom 135 | in the main module: 136 | 137 | if __name__ == '__main__': 138 | freeze_support() 139 | ... 140 | 141 | The "freeze_support()" line can be omitted if the program 142 | is not going to be frozen to produce an executable. 143 | ``` 144 | 145 | To patch, simply wrap your `stream` call in `if __name__ == "__main__":`. 146 | 147 | ```python 148 | if __name__ == "__main__": 149 | stream(...) 150 | ``` 151 | 152 | It's fine without it in notebooks though. 153 | 154 | ## ⚙️ Set your own default settings with `config` 155 | 156 | StreamJoy uses a simple `config` dict to store settings. You can change the default settings by modifying the `streamjoy.config` object. 157 | 158 | To see the available options: 159 | ```python 160 | import streamjoy 161 | 162 | print(streamjoy.config) 163 | ``` 164 | 165 | To change the settings, it's simply updating the key-value pair. 166 | ```python 167 | import streamjoy 168 | 169 | streamjoy.config["max_frames"] = -1 170 | ``` 171 | 172 | Be wary of completely overwriting the `config` object, as it might break the functionality; do not do this! 173 | ```python 174 | import streamjoy 175 | 176 | streamjoy.config = {"max_frames": -1} 177 | ``` 178 | 179 | ## 🔧 Use custom values instead of the defaults 180 | 181 | Much of StreamJoy is based on sensible defaults to get you started quickly, but you should override them. 182 | 183 | For example, `max_frames` is set to 50 by default so you can quickly preview the animation. If you want to render the entire animation, set `max_frames=-1`. 184 | 185 | StreamJoy will warn you on some settings if you don't override them: 186 | 187 | ```python 188 | No 'max_frames' specified; using the default 50 / 100 frames. Pass `-1` to use all frames. Suppress this by passing 'max_frames'. 189 | ``` 190 | 191 | ## 🧩 Render HoloViews objects with `processes=False` 192 | 193 | This is done automatically! However, in case there's an edge case, note that the kdims/vdims don't seem to carry over properly to the subprocesses when rendering HoloViews objects. It might complain that it can't find the desired dimensions. 194 | 195 | ## 📚 Prevent flickering by setting `threads_per_worker` 196 | 197 | Matplotlib is not always thread-safe, so if you're seeing flickering, set `threads_per_worker=1`. 198 | 199 | ```python 200 | from streamjoy import stream 201 | 202 | stream(..., threads_per_worker=1) 203 | ``` 204 | 205 | ## 🖥️ Provide `client` if using a remote cluster 206 | 207 | If you're using a remote cluster, specify the `client` argument to use the Dask client. 208 | 209 | ```python 210 | from dask.distributed import Client 211 | 212 | client = Client() 213 | stream(..., client=client) 214 | ``` 215 | 216 | ## 🪣 Read & write files on a remote filesystem with `fsspec_fs` 217 | 218 | To read and write files on a remote filesystem, use `fsspec_fs` to specify the filesystem. 219 | 220 | A scratch directory must be provided; be sure to prefix the bucket name. 
221 | 222 | ```python 223 | import fsspec 224 | 225 | fs = fsspec.filesystem('s3', anon=False) 226 | stream(..., fsspec_fs=fs, scratch_dir="bucket-name/streamjoy_scratch") 227 | ``` 228 | 229 | ## 🚗 Use a custom webdriver to render HoloViews 230 | 231 | By default, StreamJoy uses a headless Firefox webdriver to render HoloViews objects into images. 232 | 233 | If you want to use Chrome instead, you can pass `webdriver="chrome"`. 234 | 235 | If you want to use a different webdriver, you can pass a custom function to `webdriver`. 236 | 237 | ```python 238 | def get_webdriver(): 239 | from selenium.webdriver.firefox.options import Options 240 | from selenium.webdriver.firefox.webdriver import Service, WebDriver 241 | from webdriver_manager.firefox import GeckoDriverManager 242 | 243 | options = Options() 244 | options.add_argument("--headless") 245 | options.add_argument("--disable-extensions") 246 | executable_path = GeckoDriverManager().install() 247 | driver = WebDriver( 248 | service=Service(executable_path), options=options 249 | ) 250 | return driver 251 | 252 | stream(..., webdriver=get_webdriver) 253 | ``` 254 | -------------------------------------------------------------------------------- /streamjoy/ui.py: -------------------------------------------------------------------------------- 1 | import re 2 | from io import BytesIO 3 | 4 | try: 5 | import param 6 | import panel as pn 7 | 8 | pn.extension(notifications=True) 9 | except ImportError: 10 | raise ImportError( 11 | "StreamJoy UI additionally requires panel; " 12 | "run `pip install 'streamjoy[ui]'` to install." 13 | ) 14 | 15 | from .core import stream 16 | 17 | 18 | class App(pn.viewable.Viewer): 19 | 20 | url = param.String( 21 | label="URL", 22 | default="https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/2024/01_Jan/", 23 | ) 24 | 25 | max_files = param.Integer(bounds=(0, 1000), default=10) 26 | 27 | pattern = param.String(default="N_202401{DAY:02d}_conc_v3.0.png") 28 | 29 | pattern_inputs_start = param.Integer(bounds=(0, 1000), default=1) 30 | 31 | pattern_inputs_end = param.Integer(bounds=(0, 1000), default=10) 32 | 33 | pattern_inputs = param.Dict() 34 | 35 | extension = param.Selector(objects=[".gif", ".html"], default=".html") 36 | 37 | def __init__(self, **params): 38 | super().__init__(**params) 39 | url_input = pn.widgets.TextInput.from_param( 40 | self.param.url, placeholder="Enter URL" 41 | ) 42 | max_files_input = pn.widgets.Spinner.from_param( 43 | self.param.max_files, name="Max Files" 44 | ) 45 | pattern_input = pn.widgets.TextInput.from_param( 46 | self.param.pattern, placeholder="Enter pattern (e.g. 
*.png or {0}.png)" 47 | ) 48 | pattern_inputs_simple = pn.WidgetBox( 49 | pn.widgets.Spinner.from_param( 50 | self.param.pattern_inputs_start, name="Start of {}" 51 | ), 52 | pn.widgets.Spinner.from_param( 53 | self.param.pattern_inputs_end, name="End of {}" 54 | ), 55 | ) 56 | pattern_inputs_editor = pn.widgets.JSONEditor.from_param( 57 | self.param.pattern_inputs, 58 | mode="form", 59 | value={"i": [0]}, 60 | sizing_mode="stretch_width", 61 | search=False, 62 | menu=False, 63 | ) 64 | self._pattern_inputs_tabs = pn.Tabs( 65 | ("Simple", pattern_inputs_simple), 66 | # ("Editor", pattern_inputs_editor), 67 | ) 68 | self._pattern_preview = pn.pane.HTML("Preview: ") 69 | self._pattern_view = pn.Column( 70 | pn.pane.HTML("Pattern Inputs"), 71 | self._pattern_inputs_tabs, 72 | self._pattern_preview, 73 | ) 74 | input_widgets = pn.Card( 75 | url_input, 76 | pattern_input, 77 | max_files_input, 78 | self._pattern_view, 79 | title="Inputs", 80 | sizing_mode="stretch_width" 81 | ) 82 | submit_button = pn.widgets.Button( 83 | on_click=self._on_submit, 84 | name="Submit", 85 | sizing_mode="stretch_width", 86 | button_type="success", 87 | ) 88 | self._download_button = pn.widgets.FileDownload( 89 | filename="streamjoy.html", 90 | callback=self._download, 91 | sizing_mode="stretch_width", 92 | button_type="primary", 93 | disabled=True, 94 | ) 95 | extension_input = pn.widgets.Select.from_param( 96 | self.param.extension, sizing_mode="stretch_width" 97 | ) 98 | self._sidebar = pn.Column( 99 | pn.Row(submit_button, self._download_button), 100 | extension_input, 101 | input_widgets, 102 | ) 103 | self._main = pn.Column() 104 | self._dashboard = pn.template.FastListTemplate( 105 | title="StreamJoy", 106 | sidebar=[self._sidebar], 107 | main=[self._main], 108 | ) 109 | self._update_pattern_preview() 110 | 111 | def _extract_templates(self, pattern): 112 | pattern_formats = re.search(r"{(\w+)", pattern) 113 | return pattern_formats 114 | 115 | @param.depends("pattern", watch=True) 116 | def _update_pattern_inputs(self): 117 | pattern = self.pattern 118 | pattern_formats = self._extract_templates(pattern) 119 | 120 | if pattern_formats is not None: 121 | self._pattern_view.visible = True 122 | pattern_inputs_simple = self._pattern_inputs_tabs[0] 123 | # pattern_inputs_editor = self._pattern_inputs_tabs[1] 124 | try: 125 | pattern_formats.group(2) 126 | pn.state.notifications.error(f"Only one pattern format is allowed.") 127 | pattern_inputs_simple.disabled = True 128 | # pattern_inputs_editor.disabled = True 129 | except IndexError: 130 | pass 131 | pattern_format_key = pattern_formats.group(1) 132 | pattern_inputs_simple.disabled = False 133 | # pattern_inputs_editor.disabled = False 134 | pattern_inputs_simple[0].name = f"Start of {pattern_format_key}" 135 | # pattern_inputs_simple[1].name = f"End of {pattern_format_key}" 136 | else: 137 | self._pattern_view.visible = False 138 | 139 | @param.depends("pattern", "pattern_inputs_start", "pattern_inputs_end", watch=True) 140 | def _update_pattern_preview(self): 141 | pattern = self.pattern 142 | pattern_formats = self._extract_templates(pattern) 143 | if pattern_formats is not None: 144 | pattern_format_key = pattern_formats.group(1) 145 | pattern_inputs_start = self.pattern_inputs_start 146 | pattern_inputs_end = self.pattern_inputs_end 147 | pattern_start = pattern.format(**{pattern_format_key: pattern_inputs_start}) 148 | pattern_end = pattern.format(**{pattern_format_key: pattern_inputs_end}) 149 | self._pattern_preview.object = ( 150 | f"Preview:
" 151 | f"{pattern_start}" 152 | f"
...to...
" 153 | f"{pattern_end}" 154 | ) 155 | 156 | def _on_submit(self, event): 157 | with self._sidebar.param.update(loading=True): 158 | if self.url: 159 | stream_kwargs = {} 160 | if self._pattern_view.visible: 161 | pattern = self.pattern 162 | pattern_formats = self._extract_templates(pattern) 163 | if pattern_formats is not None: 164 | url = self.url 165 | if not url.endswith("/"): 166 | url += "/" 167 | pattern_format_key = pattern_formats.group(1) 168 | pattern_inputs_start = self.pattern_inputs_start 169 | pattern_inputs_end = self.pattern_inputs_end 170 | resources = [] 171 | for pattern_input in range( 172 | pattern_inputs_start, pattern_inputs_end + 1 173 | ): 174 | resource = self.url + pattern.format( 175 | **{pattern_format_key: pattern_input} 176 | ) 177 | resources.append(resource) 178 | else: 179 | resources = self.url 180 | stream_kwargs["pattern"] = self.pattern 181 | stream_kwargs["max_files"] = self.max_files 182 | 183 | if self.extension == ".html": 184 | stream_kwargs["ending_pause"] = 0 185 | 186 | output = stream( 187 | resources, extension=self.extension, **stream_kwargs 188 | ).write() 189 | 190 | if self.extension == ".html": 191 | self._main.objects = [output] 192 | buf = BytesIO() 193 | output.save(buf) 194 | self._buf = buf 195 | else: 196 | self._main.objects = [pn.pane.GIF(output)] 197 | self._buf = output 198 | self._download_button.disabled = False 199 | 200 | def _download(self): 201 | self._download_button.filename = f"streamjoy{self.extension}" 202 | self._buf.seek(0) 203 | return self._buf 204 | 205 | def serve(self, port: int = 8888, show: bool = True, **kwargs): 206 | pn.serve(self.__panel__(), port=port, show=show, **kwargs) 207 | 208 | def __panel__(self): 209 | return self._dashboard 210 | -------------------------------------------------------------------------------- /streamjoy/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import logging 5 | import os 6 | from collections.abc import Iterable 7 | from io import BytesIO 8 | from itertools import islice 9 | from pathlib import Path 10 | from typing import TYPE_CHECKING, Any, Callable 11 | 12 | import imageio.v3 as iio 13 | import numpy as np 14 | import param 15 | from dask.distributed import Client, Future, get_client 16 | 17 | from .models import Paused 18 | from .settings import config 19 | 20 | if TYPE_CHECKING: 21 | try: 22 | import xarray as xr 23 | except ImportError: 24 | xr = None 25 | 26 | try: 27 | from selenium.webdriver.remote.webdriver import BaseWebDriver 28 | except ImportError: 29 | BaseWebDriver = None 30 | 31 | 32 | def update_logger( 33 | level: str | None = None, 34 | format: str | None = None, 35 | datefmt: str | None = None, 36 | ) -> logging.Logger: 37 | success_level = config["logging_success_level"] 38 | 39 | class CustomLogger(logging.Logger): 40 | def success(self, message, *args, **kws): 41 | if self.isEnabledFor(success_level): 42 | self._log(success_level, message, args, **kws) 43 | 44 | color = config["logging_success_color"] 45 | reset = config["logging_reset_color"] 46 | logging.setLoggerClass(CustomLogger) 47 | logging.addLevelName(success_level, f"{color}SUCCESS{reset}") 48 | logger = logging.getLogger(__name__) 49 | 50 | level = level or config["logging_level"] 51 | format = format or config["logging_format"] 52 | datefmt = datefmt or config["logging_datefmt"] 53 | for handler in logging.getLogger().handlers: 54 | handler.setLevel(level) 55 | 
handler.setFormatter(logging.Formatter(format, datefmt)) 56 | 57 | return logger 58 | 59 | 60 | def warn_default_used( 61 | key: str, default_value: Any, total_value: Any | None = None, suffix: str = "" 62 | ) -> None: 63 | color = config["logging_warning_color"] 64 | reset = config["logging_reset_color"] 65 | message = ( 66 | f"No {color}{key!r}{reset} specified; using the default " 67 | f"{color}{default_value!r}{reset}" 68 | ) 69 | if total_value is not None: 70 | message += f" / {total_value!r}" 71 | if suffix: 72 | message += f" {suffix}" 73 | message += f". Suppress this by passing {key!r}." 74 | 75 | if total_value is not None and default_value < total_value: 76 | logging.warning(message) 77 | elif total_value is None: 78 | logging.warning(message) 79 | 80 | 81 | def get_config_default( 82 | key: str, 83 | value: Any, 84 | warn: bool = True, 85 | require: bool = True, 86 | config_prefix: str = "", 87 | **warn_kwargs, 88 | ) -> Any: 89 | config_alias = f"{config_prefix}_{key}" if config_prefix else key 90 | if require and config_alias not in config: 91 | raise ValueError(f"Missing required config key: {config_alias}") 92 | 93 | if value is None: 94 | default = config[config_alias] 95 | if warn: 96 | value = warn_default_used(key, default, **warn_kwargs) 97 | value = default 98 | return value 99 | 100 | 101 | def populate_config_defaults( 102 | params: dict[str, Any], 103 | keys: param.Parameterized, 104 | warn_on: list[str] | None = None, 105 | config_prefix: str = "", 106 | ): 107 | for key in keys: 108 | config_alias = f"{config_prefix}_{key}" if config_prefix else key 109 | if config_alias not in config.keys(): 110 | continue 111 | warn = key in (warn_on or []) 112 | value = get_config_default( 113 | key, params.get(key), warn=warn, config_prefix=config_prefix 114 | ) 115 | params[key] = value 116 | params = {key: value for key, value in params.items() if value is not None} 117 | return params 118 | 119 | 120 | def get_distributed_client(client: Client | None = None, **kwargs) -> Client: 121 | if client is not None: 122 | return client 123 | 124 | try: 125 | client = get_client() 126 | except ValueError: 127 | client = Client(**kwargs) 128 | return client 129 | 130 | 131 | def download_file( 132 | url: str, 133 | scratch_dir: Path | None = None, 134 | in_memory: bool = False, 135 | parent_depth: int | None = None, 136 | ) -> str: 137 | try: 138 | import requests 139 | except ImportError: 140 | raise ImportError("To directly read from a URL, `pip install requests`") 141 | 142 | url_path = Path(url) 143 | file_name = url_path.name 144 | 145 | parent_depth = get_config_default("parent_depth", parent_depth, warn=False) 146 | for _ in range(parent_depth): 147 | url_path = url_path.parent 148 | file_name = f"{url_path.name}_{file_name}" 149 | uri = resolve_uri(file_name=file_name, scratch_dir=scratch_dir, in_memory=in_memory) 150 | if not isinstance(uri, BytesIO) and os.path.exists(uri): 151 | return uri 152 | 153 | response = requests.get(url, stream=True) 154 | response.raise_for_status() 155 | with open(uri, "wb") as f: 156 | for chunk in response.iter_content(chunk_size=8192): 157 | if chunk: 158 | f.write(chunk) 159 | return uri 160 | 161 | 162 | def get_max_frames(total_frames: int, max_frames: int) -> int: 163 | default_max_frames = config["max_frames"] 164 | if max_frames is None and total_frames > default_max_frames: 165 | warn_default_used( 166 | "max_frames", 167 | default_max_frames, 168 | total_value=total_frames, 169 | suffix="frames. 
Pass `-1` to use all frames", 170 | ) 171 | max_frames = default_max_frames 172 | elif max_frames is None: 173 | max_frames = total_frames 174 | elif max_frames > total_frames: 175 | max_frames = total_frames 176 | elif max_frames == -1: 177 | max_frames = total_frames 178 | return max_frames 179 | 180 | 181 | def get_first(iterable): 182 | if isinstance(iterable, (list, tuple)): 183 | return iterable[0] 184 | return next(islice(iterable, 0, 1), None) 185 | 186 | 187 | def get_result(future: Future) -> Any: 188 | if isinstance(future, Future): 189 | return future.result() 190 | elif hasattr(future, "compute"): 191 | return future.compute() 192 | else: 193 | return future 194 | 195 | 196 | def using_notebook(): 197 | try: 198 | from IPython import get_ipython 199 | 200 | if "IPKernelApp" not in get_ipython().config: # Check if under IPython kernel 201 | return False 202 | except Exception: 203 | return False 204 | return True 205 | 206 | 207 | def resolve_uri( 208 | file_name: str | None = None, 209 | scratch_dir: str | Path | None = None, 210 | in_memory: bool = False, 211 | fsspec_fs: Any | None = None, 212 | ) -> str | Path | BytesIO: 213 | if in_memory: 214 | return BytesIO() 215 | 216 | output_dir = get_config_default("scratch_dir", scratch_dir, warn=False) 217 | if fsspec_fs: 218 | try: 219 | fsspec_fs.mkdir(output_dir, exist_ok=True, parents=True) 220 | except FileExistsError: 221 | pass 222 | uri = os.path.join(output_dir, file_name) 223 | else: 224 | output_dir = Path(output_dir) 225 | output_dir.mkdir(exist_ok=True, parents=True) 226 | uri = output_dir / file_name 227 | return uri 228 | 229 | 230 | def pop_kwargs(callable: Callable, kwargs: dict) -> None: 231 | args_spec = inspect.getfullargspec(callable) 232 | return {arg: kwargs.pop(arg) for arg in args_spec.args if arg in kwargs} 233 | 234 | 235 | def pop_from_cls(cls: type, kwargs: dict) -> dict: 236 | return { 237 | key: kwargs.pop(key) for key in set(kwargs) if key not in cls.param.values() 238 | } 239 | 240 | 241 | def import_function(import_path: str) -> Callable: 242 | module, function = import_path.rsplit(".", 1) 243 | module = __import__(module, fromlist=[function]) 244 | return getattr(module, function) 245 | 246 | 247 | def validate_xarray( 248 | ds: xr.Dataset | xr.DataArray, 249 | dim: str | None = None, 250 | var: str | None = None, 251 | warn: bool = True, 252 | raise_ndim: bool = True, 253 | ): 254 | import xarray as xr 255 | 256 | if var: 257 | ds = ds[var] 258 | elif isinstance(ds, xr.Dataset): 259 | var = list(ds.data_vars)[0] 260 | if warn: 261 | warn_default_used("var", var, suffix="from the dataset") 262 | ds = ds[var] 263 | 264 | squeeze_dims = [d for d in ds.dims if d != dim and ds.sizes[d] == 1] 265 | ds = ds.squeeze(squeeze_dims) 266 | if ds.ndim > 3 and raise_ndim: 267 | raise ValueError(f"Can only handle 3D arrays; {ds.ndim}D array found") 268 | return ds 269 | 270 | 271 | def validate_renderer_iterables( 272 | resources: list[Any], 273 | iterables: list[list[Any]], 274 | ): 275 | if iterables is None: 276 | return 277 | 278 | num_iterables = len(iterables) 279 | num_resources = len(resources) 280 | 281 | if num_iterables == num_resources: 282 | logging.warning( 283 | "The length of the iterables matches the length of the resources. " 284 | "This is likely not what you want; the iterables should be a list of lists, " 285 | "where each inner list corresponds to the arguments for each frame." 
286 | ) 287 | 288 | if not isinstance(iterables[0], Iterable) or isinstance(iterables[0], str): 289 | raise TypeError( 290 | "Iterables should be like a list of lists, where each inner list corresponds " 291 | "to the arguments for each frame; e.g. `[[arg1_for_frame1, arg1_for_frame2], " 292 | "[arg2_for_frame_1, arg2_for_frame2]]`" 293 | ) 294 | 295 | 296 | def map_over(client, func, resources, batch_size, *args, **kwargs): 297 | try: 298 | return client.map(func, resources, *args, batch_size=batch_size, **kwargs) 299 | except TypeError: 300 | return [ 301 | client.submit(func, resource, *args, **kwargs) for resource in resources 302 | ] 303 | 304 | 305 | def repeat_frame( 306 | write: Callable, image: np.ndarray, seconds: int, fps: int, **write_kwargs 307 | ) -> np.ndarray: 308 | if seconds == 0: 309 | return image 310 | 311 | repeat = int(seconds * fps) 312 | for _ in range(repeat): 313 | write(image, **write_kwargs) 314 | return image 315 | 316 | 317 | def imread_with_pause( 318 | uri: Any | Paused, 319 | extension: str | None = None, 320 | plugin: str | None = None, 321 | fsspec_fs: Any | None = None, 322 | ) -> np.ndarray | Paused: 323 | imread_kwargs = dict(extension=extension, plugin=plugin) 324 | seconds = None 325 | if isinstance(uri, Paused): 326 | seconds = uri.seconds 327 | uri = uri.output 328 | if fsspec_fs: 329 | with fsspec_fs.open(uri, "rb") as f: 330 | image = iio.imread(f, **imread_kwargs) 331 | else: 332 | image = iio.imread(uri, **imread_kwargs).squeeze() 333 | image = image.squeeze() 334 | if seconds is None: 335 | return image 336 | else: 337 | return Paused(output=image, seconds=seconds) 338 | 339 | 340 | def subset_resources_renderer_iterables( 341 | resources: Any, renderer_iterables: list[Any], max_frames: int 342 | ): 343 | resources = resources[: max_frames or max_frames] 344 | renderer_iterables = [ 345 | iterable[: len(resources)] for iterable in renderer_iterables or [] 346 | ] 347 | return resources, renderer_iterables 348 | 349 | 350 | def get_webdriver_path(webdriver: str): 351 | if webdriver.lower() == "chrome": 352 | from webdriver_manager.chrome import ChromeDriverManager 353 | 354 | webdriver_path = ChromeDriverManager().install() 355 | elif webdriver.lower() == "firefox": 356 | from webdriver_manager.firefox import GeckoDriverManager 357 | 358 | webdriver_path = GeckoDriverManager().install() 359 | return webdriver_path 360 | 361 | 362 | def get_webdriver(webdriver: tuple[str, str] | Callable) -> BaseWebDriver: 363 | if isinstance(webdriver, Callable): 364 | return webdriver() 365 | 366 | webdriver_key, webdriver_path = webdriver 367 | if webdriver_key.lower() == "chrome": 368 | from selenium.webdriver.chrome.options import Options 369 | from selenium.webdriver.chrome.webdriver import Service, WebDriver 370 | 371 | options = Options() 372 | options.add_argument("--headless") 373 | options.add_argument("--disable-extensions") 374 | webdriver_path = webdriver_path or get_webdriver_path("chrome") 375 | driver = WebDriver(service=Service(webdriver_path), options=options) 376 | 377 | elif webdriver_key.lower() == "firefox": 378 | from selenium.webdriver.firefox.options import Options 379 | from selenium.webdriver.firefox.webdriver import Service, WebDriver 380 | 381 | options = Options() 382 | options.add_argument("--headless") 383 | options.add_argument("--disable-extensions") 384 | webdriver_path = webdriver_path or get_webdriver_path("firefox") 385 | driver = WebDriver(service=Service(webdriver_path), options=options) 386 | 387 | else: 388 | raise 
NotImplementedError( 389 | f"Webdriver {webdriver_key} not supported; " 390 | f"use 'chrome' or 'firefox', or pass a custom callable." 391 | ) 392 | 393 | return driver 394 | -------------------------------------------------------------------------------- /streamjoy/serializers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from inspect import isgenerator 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Any, Callable 7 | 8 | import numpy as np 9 | 10 | from . import _utils 11 | from .models import Serialized 12 | from .renderers import ( 13 | default_holoviews_renderer, 14 | default_pandas_renderer, 15 | default_polars_renderer, 16 | default_xarray_renderer, 17 | ) 18 | from .settings import file_handlers, obj_handlers 19 | from .wrappers import wrap_holoviews, wrap_matplotlib 20 | 21 | if TYPE_CHECKING: 22 | try: 23 | import pandas as pd 24 | except ImportError: 25 | pd = None 26 | 27 | try: 28 | import polars as pl 29 | except ImportError: 30 | pl = None 31 | 32 | try: 33 | import xarray as xr 34 | except ImportError: 35 | xr = None 36 | 37 | try: 38 | import holoviews as hv 39 | except ImportError: 40 | hv = None 41 | 42 | from .streams import MediaStream 43 | 44 | 45 | def _select_obj_handler(resources: Any) -> MediaStream: 46 | if isinstance(resources, str) and "://" in resources: 47 | return serialize_url 48 | if isinstance(resources, (Path, str)): 49 | return serialize_paths 50 | 51 | resources_type = type(resources) 52 | module = getattr(resources_type, "__module__").split(".", maxsplit=1)[0] 53 | type_ = resources_type.__name__ 54 | for class_or_package_name, function_name in obj_handlers.items(): 55 | if ( 56 | f"{module}.{type_}" == class_or_package_name 57 | or module == class_or_package_name 58 | ): 59 | return globals()[function_name] 60 | 61 | raise ValueError( 62 | f"Could not find a method to handle {resources_type}; " 63 | f"supported classes/packages are {list(obj_handlers.keys())}." 64 | ) 65 | 66 | 67 | def serialize_numpy( 68 | stream_cls, 69 | resources: np.ndarray, 70 | renderer: Callable | None = None, 71 | renderer_iterables: list[Any] | None = None, 72 | renderer_kwargs: dict | None = None, 73 | **kwargs, 74 | ) -> Serialized: 75 | """ 76 | Serialize numpy arrays for streaming or rendering. 77 | 78 | Args: 79 | stream_cls: The class reference used for logging and utility functions. 80 | resources: The numpy array to be serialized. 81 | renderer: The rendering function to use on the array. 82 | renderer_iterables: Additional iterable arguments to pass to the renderer. 83 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 84 | **kwargs: Additional keyword arguments, including 'dim' and 'var' for xarray selection. 85 | 86 | Returns: 87 | A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments. 
88 | """ 89 | resources = [resource for resource in resources] 90 | renderer_kwargs = renderer_kwargs or {} 91 | renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) 92 | return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) 93 | 94 | 95 | def serialize_xarray( 96 | stream_cls, 97 | resources: xr.Dataset | xr.DataArray, 98 | renderer: Callable | None = None, 99 | renderer_iterables: list[Any] | None = None, 100 | renderer_kwargs: dict | None = None, 101 | **kwargs, 102 | ) -> Serialized: 103 | """ 104 | Serialize xarray datasets or data arrays for streaming or rendering. 105 | 106 | Args: 107 | stream_cls: The class reference used for logging and utility functions. 108 | resources: The xarray dataset or data array to be serialized. 109 | renderer: The rendering function to use on the dataset. 110 | renderer_iterables: Additional iterable arguments to pass to the renderer. 111 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 112 | **kwargs: Additional keyword arguments, including 'dim' and 'var' for xarray selection. 113 | 114 | Returns: 115 | A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments. 116 | """ 117 | 118 | ds = resources 119 | dim = kwargs.pop("dim", None) 120 | var = kwargs.pop("var", None) 121 | 122 | ds = _utils.validate_xarray(ds, dim=dim, var=var, raise_ndim=renderer is None) 123 | if not dim: 124 | dim = list(ds.dims)[0] 125 | _utils.warn_default_used("dim", dim, suffix="from the dataset") 126 | elif dim not in ds.dims: 127 | raise ValueError(f"{dim!r} not in {ds.dims!r}") 128 | 129 | total_frames = len(ds[dim]) 130 | max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames")) 131 | resources = [ds.isel({dim: i}) for i in range(max_frames)] 132 | 133 | renderer_kwargs = renderer_kwargs or {} 134 | renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) 135 | 136 | if renderer is None: 137 | renderer = wrap_matplotlib( 138 | in_memory=kwargs.get("in_memory"), 139 | scratch_dir=kwargs.get("scratch_dir"), 140 | fsspec_fs=kwargs.get("fsspec_fs"), 141 | )(default_xarray_renderer) 142 | ds_0 = resources[0] 143 | if ds_0.ndim >= 2: 144 | renderer_kwargs["vmin"] = renderer_kwargs.get( 145 | "vmin", _utils.get_result(ds_0.min()).item() 146 | ) 147 | renderer_kwargs["vmax"] = renderer_kwargs.get( 148 | "vmax", _utils.get_result(ds_0.max()).item() 149 | ) 150 | else: 151 | renderer_kwargs["ylim"] = renderer_kwargs.get( 152 | "ylim", 153 | ( 154 | _utils.get_result(ds_0.min()).item(), 155 | _utils.get_result(ds_0.max()).item(), 156 | ), 157 | ) 158 | return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) 159 | 160 | 161 | def serialize_pandas( 162 | stream_cls, 163 | resources: pd.DataFrame, 164 | renderer: Callable | None = None, 165 | renderer_iterables: list[Any] | None = None, 166 | renderer_kwargs: dict | None = None, 167 | **kwargs, 168 | ) -> Serialized: 169 | """ 170 | Serialize pandas DataFrame for streaming or rendering. 171 | 172 | Args: 173 | stream_cls: The class reference used for logging and utility functions. 174 | resources: The pandas DataFrame to be serialized. 175 | renderer: The rendering function to use on the DataFrame. 176 | renderer_iterables: Additional iterable arguments to pass to the renderer. 177 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 178 | **kwargs: Additional keyword arguments, including 'groupby' for DataFrame grouping. 
179 | 180 | Returns: 181 | A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments. 182 | """ 183 | import pandas as pd 184 | 185 | df = resources 186 | groupby = kwargs.get("groupby") 187 | 188 | total_frames = df.groupby(groupby).size().max() if groupby else len(df) 189 | max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames")) 190 | resources = [ 191 | df.groupby(groupby, as_index=False).head(i) if groupby else df.head(i) 192 | for i in range(1, max_frames + 1) 193 | ] 194 | 195 | renderer_kwargs = renderer_kwargs or {} 196 | renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) 197 | 198 | if renderer is None: 199 | renderer = wrap_matplotlib( 200 | in_memory=kwargs.get("in_memory"), 201 | scratch_dir=kwargs.get("scratch_dir"), 202 | fsspec_fs=kwargs.get("fsspec_fs"), 203 | )(default_pandas_renderer) 204 | if "x" not in renderer_kwargs: 205 | if df.index.name or isinstance(df, pd.Series): 206 | renderer_kwargs["x"] = df.index.name 207 | else: 208 | for col in df.columns: 209 | if col != groupby: 210 | break 211 | renderer_kwargs["x"] = col 212 | _utils.warn_default_used( 213 | "x", renderer_kwargs["x"], suffix="from the dataframe" 214 | ) 215 | if "y" not in renderer_kwargs: 216 | if isinstance(df, pd.Series): 217 | col = df.name 218 | else: 219 | numeric_cols = df.select_dtypes(include="number").columns 220 | for col in numeric_cols: 221 | if col not in (renderer_kwargs["x"], groupby): 222 | break 223 | renderer_kwargs["y"] = col 224 | _utils.warn_default_used( 225 | "y", renderer_kwargs["y"], suffix="from the dataframe" 226 | ) 227 | if "xlabel" not in renderer_kwargs: 228 | renderer_kwargs["xlabel"] = renderer_kwargs["x"].title().replace("_", " ") 229 | if "ylabel" not in renderer_kwargs: 230 | renderer_kwargs["ylabel"] = renderer_kwargs["y"].title().replace("_", " ") 231 | 232 | return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) 233 | 234 | 235 | def serialize_polars( 236 | stream_cls, 237 | resources: pl.DataFrame, 238 | renderer: Callable | None = None, 239 | renderer_iterables: list[Any] | None = None, 240 | renderer_kwargs: dict | None = None, 241 | **kwargs, 242 | ) -> Serialized: 243 | """ 244 | Serialize Polars DataFrame for streaming or rendering. 245 | 246 | Args: 247 | stream_cls: The class reference used for logging and utility functions. 248 | resources: The Polars DataFrame to be serialized. 249 | renderer: The rendering function to use on the DataFrame. 250 | renderer_iterables: Additional iterable arguments to pass to the renderer. 251 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 252 | **kwargs: Additional keyword arguments, including 'groupby' for DataFrame grouping. 253 | 254 | Returns: 255 | A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments. 
256 | """ 257 | import polars as pl 258 | 259 | groupby = kwargs.get("groupby") 260 | 261 | if groupby: 262 | group_sizes = resources.groupby(groupby).agg(pl.len()) 263 | total_frames = group_sizes.select(pl.col("len").max()).to_numpy()[0, 0] 264 | else: 265 | total_frames = len(resources) 266 | 267 | max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames")) 268 | resources_expanded = [ 269 | resources.groupby(groupby).head(i) if groupby else resources.head(i) 270 | for i in range(1, max_frames + 1) 271 | ] 272 | 273 | renderer_kwargs = renderer_kwargs or {} 274 | renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) 275 | 276 | if renderer is None: 277 | renderer = wrap_holoviews( 278 | in_memory=kwargs.get("in_memory"), 279 | scratch_dir=kwargs.get("scratch_dir"), 280 | fsspec_fs=kwargs.get("fsspec_fs"), 281 | webdriver=renderer_kwargs.pop("webdriver", None), 282 | )(default_polars_renderer) 283 | numeric_cols = [ 284 | col 285 | for col in resources.columns 286 | if resources[col].dtype in [pl.Float64, pl.Int64, pl.Float32, pl.Int32] 287 | ] 288 | if "x" not in renderer_kwargs: 289 | for col in numeric_cols: 290 | if col != groupby: 291 | renderer_kwargs["x"] = col 292 | break 293 | _utils.warn_default_used( 294 | "x", renderer_kwargs["x"], suffix="from the dataframe" 295 | ) 296 | if "y" not in renderer_kwargs: 297 | for col in numeric_cols: 298 | if col not in (renderer_kwargs["x"], groupby): 299 | renderer_kwargs["y"] = col 300 | break 301 | _utils.warn_default_used( 302 | "y", renderer_kwargs["y"], suffix="from the dataframe" 303 | ) 304 | if "xlabel" not in renderer_kwargs: 305 | renderer_kwargs["xlabel"] = renderer_kwargs["x"].title().replace("_", " ") 306 | if "ylabel" not in renderer_kwargs: 307 | renderer_kwargs["ylabel"] = renderer_kwargs["y"].title().replace("_", " ") 308 | 309 | if kwargs.get("processes"): 310 | logging.warning( 311 | "Polars (HoloViews) rendering does not support processes; " 312 | "setting processes=False." 313 | ) 314 | kwargs["processes"] = False 315 | return Serialized( 316 | resources_expanded, renderer, renderer_iterables, renderer_kwargs, kwargs 317 | ) 318 | 319 | 320 | def serialize_holoviews( 321 | stream_cls, 322 | resources: hv.HoloMap | hv.DynamicMap, 323 | renderer: Callable | None = None, 324 | renderer_iterables: list[Any] | None = None, 325 | renderer_kwargs: dict | None = None, 326 | **kwargs, 327 | ) -> Serialized: 328 | """ 329 | Serialize HoloViews objects for streaming or rendering. 330 | 331 | Args: 332 | stream_cls: The class reference used for logging and utility functions. 333 | resources: The HoloViews object to be serialized. 334 | renderer: The rendering function to use on the HoloViews object. 335 | renderer_iterables: Additional iterable arguments to pass to the renderer. 336 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 337 | **kwargs: Additional keyword arguments for HoloViews object customization. 338 | 339 | Returns: 340 | A tuple containing the serialized resources, renderer, renderer_iterables, renderer_kwargs, and any additional keyword arguments. 
341 | """ 342 | import holoviews as hv 343 | 344 | backend = kwargs.get("backend", hv.Store.current_backend) 345 | 346 | def _select_element(hv_obj, key): 347 | try: 348 | resource = hv_obj[key] 349 | except Exception: 350 | resource = hv_obj.select(**{kdims[0].name: key}) 351 | return resource 352 | 353 | hv_obj = resources 354 | if isinstance(hv_obj, (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)): 355 | hv_map = hv_obj 356 | elif issubclass( 357 | type(hv_obj), (hv.core.layout.Layoutable, hv.core.overlay.Overlayable) 358 | ): 359 | hv_map = hv_obj[0] 360 | 361 | if not isinstance(hv_map, (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)): 362 | raise ValueError("Can only handle HoloMap and DynamicMap objects.") 363 | elif isinstance(hv_map, hv.core.spaces.DynamicMap): 364 | kdims = hv_map.kdims 365 | keys = hv_map.kdims[0].values 366 | else: 367 | kdims = hv_map.kdims 368 | keys = hv_map.keys() 369 | 370 | if len(kdims) > 1: 371 | raise ValueError("Can only handle 1D HoloViews objects.") 372 | 373 | resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys] 374 | 375 | renderer_kwargs = renderer_kwargs or {} 376 | renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs)) 377 | 378 | if renderer is None: 379 | renderer = wrap_holoviews( 380 | in_memory=kwargs.get("in_memory"), 381 | scratch_dir=kwargs.get("scratch_dir"), 382 | fsspec_fs=kwargs.get("fsspec_fs"), 383 | webdriver=renderer_kwargs.pop("webdriver", None), 384 | )(default_holoviews_renderer) 385 | clims = {} 386 | for hv_el in hv_obj.traverse(full_breadth=False): 387 | if isinstance(hv_el, hv.DynamicMap): 388 | hv.render(hv_el, backend=backend) 389 | 390 | if isinstance(hv_el, hv.Element): 391 | if hv_el.ndims > 1: 392 | vdim = hv_el.vdims[0].name 393 | array = hv_el.dimension_values(vdim) 394 | clim = (np.nanmin(array), np.nanmax(array)) 395 | clims[vdim] = clim 396 | 397 | renderer_kwargs.update( 398 | backend=backend, 399 | clims=clims, 400 | ) 401 | 402 | if kwargs.get("processes"): 403 | logging.warning( 404 | "HoloViews rendering does not support processes; " 405 | "setting processes=False." 406 | ) 407 | kwargs["processes"] = False 408 | return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) 409 | 410 | 411 | def serialize_url( 412 | stream_cls, 413 | resources: str, 414 | renderer: Callable | None = None, 415 | renderer_iterables: list[Any] | None = None, 416 | renderer_kwargs: dict | None = None, 417 | **kwargs, 418 | ) -> Serialized: 419 | """ 420 | Serialize resources from a URL for streaming or rendering. 421 | 422 | Args: 423 | stream_cls: The class reference used for logging and utility functions. 424 | resources: The URL of the resources to be serialized. 425 | renderer: The rendering function to use on the resources. 426 | renderer_iterables: Additional iterable arguments to pass to the renderer. 427 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 428 | **kwargs: Additional keyword arguments, including 'pattern', 'sort_key', 'max_files', 'file_handler', and 'file_handler_kwargs'. 429 | 430 | Returns: 431 | A MediaStream object containing the serialized resources. 
432 | """ 433 | import re 434 | 435 | import requests 436 | from bs4 import BeautifulSoup 437 | 438 | base_url = resources 439 | pattern = kwargs.pop("pattern", None) 440 | sort_key = kwargs.pop("sort_key", None) 441 | max_files = kwargs.pop("max_files", None) 442 | file_handler = kwargs.pop("file_handler", None) 443 | file_handler_kwargs = kwargs.pop("file_handler_kwargs", None) 444 | 445 | max_files = _utils.get_config_default("max_files", max_files, suffix="links") 446 | 447 | logging.info(f"Retrieving resources from {base_url!r}.") 448 | 449 | with requests.get(resources, stream=True) as response: 450 | response.raise_for_status() 451 | 452 | partial_html = "" 453 | content_type = response.headers.get("Content-Type") 454 | for chunk in response.iter_content(chunk_size=1024, decode_unicode=True): 455 | if chunk: 456 | if pattern is not None: 457 | partial_html += chunk 458 | soup = BeautifulSoup(partial_html, "html.parser") 459 | href = re.compile(pattern.replace("*", ".*")) 460 | links = soup.find_all("a", href=href) 461 | if max_files > 0 and len(links) >= max_files: 462 | break 463 | else: 464 | if content_type.startswith("text"): 465 | raise ValueError( 466 | f"A pattern must be provided if the URL is a directory of files; " 467 | f"got {resources!r}." 468 | ) 469 | links = [{"href": ""}] 470 | 471 | if max_files > 0: 472 | links = links[:max_files] 473 | 474 | if len(links) == 0: 475 | raise ValueError(f"No links found with pattern {pattern!r} at {base_url!r}.") 476 | 477 | # download files 478 | urls = [base_url + link.get("href") for link in links] 479 | client = _utils.get_distributed_client( 480 | client=kwargs.get("client"), 481 | processes=kwargs.get("processes"), 482 | threads_per_worker=kwargs.get("threads_per_worker"), 483 | ) 484 | 485 | logging.info(f"Downloading {len(urls)} files from {base_url!r}.") 486 | futures = _utils.map_over( 487 | client, 488 | _utils.download_file, 489 | urls, 490 | kwargs.get("batch_size"), 491 | scratch_dir=kwargs.get("scratch_dir"), 492 | in_memory=kwargs.get("in_memory"), 493 | ) 494 | paths = client.gather(futures) 495 | return serialize_paths( 496 | stream_cls, 497 | paths, 498 | sort_key=sort_key, 499 | max_files=max_files, 500 | file_handler=file_handler, 501 | file_handler_kwargs=file_handler_kwargs, 502 | renderer=renderer, 503 | renderer_iterables=renderer_iterables, 504 | renderer_kwargs=renderer_kwargs, 505 | **kwargs, 506 | ) 507 | 508 | 509 | def serialize_paths( 510 | stream_cls, 511 | resources: list[str | Path] | str, 512 | renderer: Callable | None = None, 513 | renderer_iterables: list[Any] | None = None, 514 | renderer_kwargs: dict | None = None, 515 | **kwargs, 516 | ) -> Serialized: 517 | """ 518 | Serialize resources from file paths for streaming or rendering. 519 | 520 | Args: 521 | resources: A list of file paths or a single file path string to be serialized. 522 | renderer: The rendering function to use on the resources. 523 | renderer_iterables: Additional iterable arguments to pass to the renderer. 524 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 525 | **kwargs: Additional keyword arguments, including 'sort_key', 'max_files', 'file_handler', and 'file_handler_kwargs'. 526 | 527 | Returns: 528 | A MediaStream object containing the serialized resources. 
529 | """ 530 | if isinstance(resources, str): 531 | resources = [resources] 532 | 533 | sort_key = kwargs.pop("sort_key", None) 534 | max_files = kwargs.pop("max_files", None) 535 | file_handler = kwargs.pop("file_handler", None) 536 | file_handler_kwargs = kwargs.pop("file_handler_kwargs", None) 537 | 538 | paths = sorted(resources, key=sort_key) 539 | 540 | max_files = _utils.get_config_default( 541 | "max_files", max_files, total_value=len(paths), suffix="paths" 542 | ) 543 | if max_files > 0: 544 | paths = paths[:max_files] 545 | 546 | # find a file handler 547 | extension = Path(paths[0]).suffix 548 | file_handler_meta = file_handlers.get(extension, {}) 549 | file_handler_import_path = file_handler_meta.get("import_path") 550 | file_handler_concat_path = file_handler_meta.get("concat_path") 551 | if file_handler is None and file_handler_import_path is not None: 552 | file_handler = _utils.import_function(file_handler_import_path) 553 | if file_handler_concat_path is not None: 554 | file_handler_concat = _utils.import_function(file_handler_concat_path) 555 | 556 | # read as objects 557 | if file_handler is not None: 558 | if file_handler_concat_path is not None: 559 | resources = file_handler_concat( 560 | file_handler(path, **(file_handler_kwargs or {})) for path in paths 561 | ) 562 | else: 563 | resources = file_handler(paths, **(file_handler_kwargs or {})) 564 | return serialize_appropriately( 565 | stream_cls, 566 | resources, 567 | renderer, 568 | renderer_iterables, 569 | renderer_kwargs, 570 | **kwargs, 571 | ) 572 | 573 | # or simply return image paths 574 | return Serialized(paths, renderer, renderer_iterables, renderer_kwargs, kwargs) 575 | 576 | 577 | def serialize_appropriately( 578 | stream_cls, 579 | resources: Any, 580 | renderer: Callable, 581 | renderer_iterables: list[Any], 582 | renderer_kwargs: dict[str, Any], 583 | **kwargs, 584 | ) -> Serialized: 585 | """ 586 | Automatically select the appropriate serialization method based on the type of resources. 587 | 588 | Args: 589 | stream_cls: The class reference used for logging and utility functions. 590 | resources: The resources to be serialized, which can be of any type. 591 | renderer: The rendering function to use on the resources. 592 | renderer_iterables: Additional iterable arguments to pass to the renderer. 593 | renderer_kwargs: Additional keyword arguments to pass to the renderer. 594 | **kwargs: Additional keyword arguments for further customization. 595 | 596 | Returns: 597 | A Serialized object containing the serialized resources. 
598 | """ 599 | if not (isinstance(resources, (list, tuple)) or isgenerator(resources)): 600 | obj_handler = _select_obj_handler(resources) 601 | _utils.validate_renderer_iterables(resources, renderer_iterables) 602 | 603 | serialized = obj_handler( 604 | stream_cls, 605 | resources, 606 | renderer=renderer, 607 | renderer_iterables=renderer_iterables, 608 | renderer_kwargs=renderer_kwargs, 609 | **kwargs, 610 | ) 611 | resources = serialized.resources 612 | renderer = serialized.renderer 613 | renderer_iterables = serialized.renderer_iterables 614 | renderer_kwargs = serialized.renderer_kwargs 615 | kwargs = serialized.kwargs 616 | max_frames = _utils.get_max_frames(len(resources), kwargs.get("max_frames")) 617 | kwargs["max_frames"] = max_frames 618 | resources, renderer_iterables = _utils.subset_resources_renderer_iterables( 619 | resources, renderer_iterables, max_frames 620 | ) 621 | _utils.pop_kwargs(obj_handler, renderer_kwargs) 622 | _utils.pop_kwargs(obj_handler, kwargs) 623 | return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs) 624 | --------------------------------------------------------------------------------