├── tests
├── __init__.py
├── data
│ ├── air.zarr
│ │ ├── .zgroup
│ │ ├── lat
│ │ │ ├── 0
│ │ │ ├── .zattrs
│ │ │ └── .zarray
│ │ ├── lon
│ │ │ ├── 0
│ │ │ ├── .zattrs
│ │ │ └── .zarray
│ │ ├── time
│ │ │ ├── 0
│ │ │ ├── .zattrs
│ │ │ └── .zarray
│ │ ├── air
│ │ │ ├── 0.0.0
│ │ │ ├── .zarray
│ │ │ └── .zattrs
│ │ ├── .zattrs
│ │ └── .zmetadata
│ ├── air.nc
│ ├── gapminder.parquet
│ ├── N_19781028_conc_v3.0.png
│ ├── 1978_10_Oct_N_19781026_conc_v3.0.png
│ ├── 1978_10_Oct_N_19781027_conc_v3.0.png
│ └── gapminder.csv
├── test_polars.py
├── test_settings.py
├── test_pandas.py
├── test_xarray.py
├── test_renderers.py
├── conftest.py
├── test_models.py
├── test_wrappers.py
├── test_core.py
├── test_serializers.py
└── test_streams.py
├── docs
├── index.md
├── assets
│ ├── logo.png
│ └── logo_bw.png
├── api_reference
│ ├── core.md
│ ├── models.md
│ ├── streams.md
│ ├── wrappers.md
│ ├── renderers.md
│ └── settings.md
├── stylesheets
│ └── extra.css
├── example_recipes
│ ├── air_temperature.md
│ ├── sine_wave.md
│ ├── sea_ice.md
│ ├── nmme_forecast.md
│ ├── oisst_globe.md
│ ├── gender_gapminder.md
│ ├── nice_orbit.md
│ ├── stream_code.md
│ ├── co2_timeseries.md
│ └── temperature_anomaly.md
├── package_design.md
├── supported_formats.md
└── how_do_i.md
├── streamjoy
├── .gitignore
├── cli.py
├── pandas.py
├── xarray.py
├── polars.py
├── __init__.py
├── settings.py
├── core.py
├── models.py
├── renderers.py
├── wrappers.py
├── ui.py
├── _utils.py
└── serializers.py
├── .editorconfig
├── .github
├── ISSUE_TEMPLATE.md
└── workflows
│ ├── build.yml
│ ├── documentation.yml
│ └── release.yml
├── HOWTORELEASE.md
├── LICENSE
├── CONTRIBUTING.md
├── HOWTOCONTRIBUTE.md
├── .gitignore
├── pyproject.toml
├── mkdocs.yml
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | --8<-- "README.md"
2 |
--------------------------------------------------------------------------------
/tests/data/air.zarr/.zgroup:
--------------------------------------------------------------------------------
1 | {
2 | "zarr_format": 2
3 | }
--------------------------------------------------------------------------------
/tests/data/air.nc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.nc
--------------------------------------------------------------------------------
/docs/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/docs/assets/logo.png
--------------------------------------------------------------------------------
/docs/assets/logo_bw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/docs/assets/logo_bw.png
--------------------------------------------------------------------------------
/tests/data/air.zarr/lat/0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/lat/0
--------------------------------------------------------------------------------
/tests/data/air.zarr/lon/0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/lon/0
--------------------------------------------------------------------------------
/tests/data/air.zarr/time/0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/time/0
--------------------------------------------------------------------------------
/tests/data/air.zarr/air/0.0.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/air.zarr/air/0.0.0
--------------------------------------------------------------------------------
/tests/data/gapminder.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/gapminder.parquet
--------------------------------------------------------------------------------
/tests/data/N_19781028_conc_v3.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/N_19781028_conc_v3.0.png
--------------------------------------------------------------------------------
/docs/api_reference/core.md:
--------------------------------------------------------------------------------
1 | # Core
2 |
3 | ::: streamjoy.core
4 | options:
5 | show_root_heading: false
6 | show_source: true
7 |
--------------------------------------------------------------------------------
/tests/data/1978_10_Oct_N_19781026_conc_v3.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/1978_10_Oct_N_19781026_conc_v3.0.png
--------------------------------------------------------------------------------
/tests/data/1978_10_Oct_N_19781027_conc_v3.0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ahuang11/streamjoy/HEAD/tests/data/1978_10_Oct_N_19781027_conc_v3.0.png
--------------------------------------------------------------------------------
/docs/api_reference/models.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | ::: streamjoy.models
4 | options:
5 | show_root_heading: false
6 | show_source: true
7 |
--------------------------------------------------------------------------------
/docs/api_reference/streams.md:
--------------------------------------------------------------------------------
1 | # Streams
2 |
3 | ::: streamjoy.streams
4 | options:
5 | show_root_heading: false
6 | show_source: true
7 |
--------------------------------------------------------------------------------
/docs/api_reference/wrappers.md:
--------------------------------------------------------------------------------
1 | # Wrappers
2 |
3 | ::: streamjoy.wrappers
4 | options:
5 | show_root_heading: false
6 | show_source: true
7 |
--------------------------------------------------------------------------------
/docs/api_reference/renderers.md:
--------------------------------------------------------------------------------
1 | # Renderers
2 |
3 | ::: streamjoy.renderers
4 | options:
5 | show_root_heading: false
6 | show_source: true
7 |
--------------------------------------------------------------------------------
/streamjoy/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .vscode
3 | build
4 | dist
5 | docs/build/*
6 | __pycache__
7 | *.egg*
8 | *.ipynb_checkpoints/
9 | *.jpg
10 | *.png
11 | *.gif
12 | *.mp4
13 | *.ipynb
14 |
--------------------------------------------------------------------------------
/tests/data/air.zarr/lat/.zattrs:
--------------------------------------------------------------------------------
1 | {
2 | "_ARRAY_DIMENSIONS": [
3 | "lat"
4 | ],
5 | "axis": "Y",
6 | "long_name": "Latitude",
7 | "standard_name": "latitude",
8 | "units": "degrees_north"
9 | }
--------------------------------------------------------------------------------
/tests/data/air.zarr/lon/.zattrs:
--------------------------------------------------------------------------------
1 | {
2 | "_ARRAY_DIMENSIONS": [
3 | "lon"
4 | ],
5 | "axis": "X",
6 | "long_name": "Longitude",
7 | "standard_name": "longitude",
8 | "units": "degrees_east"
9 | }
--------------------------------------------------------------------------------
/tests/data/air.zarr/time/.zattrs:
--------------------------------------------------------------------------------
1 | {
2 | "_ARRAY_DIMENSIONS": [
3 | "time"
4 | ],
5 | "calendar": "standard",
6 | "long_name": "Time",
7 | "standard_name": "time",
8 | "units": "hours since 1800-01-01"
9 | }
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
--------------------------------------------------------------------------------
/tests/data/air.zarr/.zattrs:
--------------------------------------------------------------------------------
1 | {
2 | "Conventions": "COARDS",
3 | "description": "Data is from NMC initialized reanalysis\n(4x/day). These are the 0.9950 sigma level values.",
4 | "platform": "Model",
5 | "references": "http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html",
6 | "title": "4x daily NMC reanalysis (1948)"
7 | }
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Expectation / Proposal
4 |
5 | # Traceback / Example
6 |
7 | - [ ] I would like to [help contribute](https://github.com/ahuang11/streamjoy/blob/main/HOWTOCONTRIBUTE.md) a pull request to resolve this!
8 |
--------------------------------------------------------------------------------
/HOWTORELEASE.md:
--------------------------------------------------------------------------------
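1 | 1. Increment `__version__` in `streamjoy/__init__.py`
2 | 2. Push changes
3 | 3. Create and push a new tag starting with `v` (e.g. `v0.0.11`); pushing it triggers the `Build & Release` workflow, which publishes the package to PyPI.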
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/tests/data/air.zarr/lat/.zarray:
--------------------------------------------------------------------------------
1 | {
2 | "chunks": [
3 | 25
4 | ],
5 | "compressor": {
6 | "blocksize": 0,
7 | "clevel": 5,
8 | "cname": "lz4",
9 | "id": "blosc",
10 | "shuffle": 1
11 | },
12 | "dtype": "
4 |
5 |
6 |
7 | Super barebones example of rendering air temperature data from xarray.
8 |
9 | Highlights:
10 |
11 | - Imports `streamjoy.xarray` to register the `streamjoy` accessor on xarray objects.
12 | - Passes the `uri` to the `streamjoy` accessor as the first argument to save the animation to disk.
13 |
14 | ```python hl_lines="2 6"
15 | import xarray as xr
16 | import streamjoy.xarray
17 |
18 | if __name__ == "__main__":
19 | ds = xr.tutorial.open_dataset("air_temperature")
20 | ds.streamjoy("air_temperature.mp4")
21 | ```
22 |
--------------------------------------------------------------------------------
/streamjoy/xarray.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
try:
    import xarray as xr
except Exception as exc:
    # Any failure while importing xarray (not only ImportError) is re-raised as
    # an ImportError so callers see one consistent error type.
    raise ImportError(
        "Could not patch streamjoy API onto xarray. xarray could not be imported."
    ) from exc

from .core import stream


def patch(name="streamjoy"):
    """Register a ``stream`` accessor on xarray Datasets and DataArrays.

    After registration, ``obj.<name>(*args, **kwargs)`` calls
    ``stream(obj, *args, **kwargs)`` for any Dataset or DataArray ``obj``.

    Parameters
    ----------
    name : str
        Attribute name under which the accessor is registered.
    """

    class StreamAccessor:
        def __init__(self, resources: xr.Dataset | xr.DataArray):
            self._resources = resources

        def __call__(self, *args, **kwargs):
            return stream(self._resources, *args, **kwargs)

    # Reuse stream's docstring so the accessor class documents its arguments.
    StreamAccessor.__doc__ = stream.__doc__

    xr.register_dataset_accessor(name)(StreamAccessor)
    xr.register_dataarray_accessor(name)(StreamAccessor)


# Register the default ``streamjoy`` accessor as soon as this module is imported.
patch()
28 |
--------------------------------------------------------------------------------
/tests/test_settings.py:
--------------------------------------------------------------------------------
1 | from imageio.v3 import improps
2 |
3 | from streamjoy import config, stream
4 |
5 |
class TestConfig:
    """Check that ``config`` entries act as ``stream`` defaults and that
    keyword arguments given to ``stream`` take precedence over them."""

    @staticmethod
    def _apply_config():
        # NOTE(review): these writes to the shared ``config`` are never
        # reverted; they currently equal the session defaults set in conftest.
        config["max_frames"] = 3
        config["ending_pause"] = 0

    def test_update_defaults(self, df):
        self._apply_config()

        streamer = stream(df)
        assert streamer.max_frames == 3
        assert streamer.ending_pause == 0
        assert improps(streamer.write()).n_images == 3

    def test_override_defaults(self, df):
        self._apply_config()

        streamer = stream(df, max_frames=5)
        assert streamer.max_frames == 5
        assert streamer.ending_pause == 0
        assert improps(streamer.write()).n_images == 5
28 |
--------------------------------------------------------------------------------
/streamjoy/polars.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | try:
4 | import polars as pl
5 | except Exception as exc:
6 | raise ImportError(
7 | "Could not patch streamjoy API onto polars. Polars could not be imported."
8 | ) from exc
9 |
10 | from .core import stream
11 |
12 |
def patch(name="streamjoy"):
    """Register a ``stream`` namespace on polars DataFrames, Series and LazyFrames.

    After registration, ``obj.<name>(*args, **kwargs)`` calls
    ``stream(obj, *args, **kwargs)``.

    Parameters
    ----------
    name : str
        Namespace name under which the accessor is registered.
    """

    class StreamAccessor:
        def __init__(self, resources: pl.DataFrame | pl.Series | pl.LazyFrame):
            self._resources = resources

        def __call__(self, *args, **kwargs):
            return stream(self._resources, *args, **kwargs)

    # Reuse stream's docstring so the namespace class documents its arguments.
    StreamAccessor.__doc__ = stream.__doc__

    pl.api.register_dataframe_namespace(name)(StreamAccessor)
    pl.api.register_series_namespace(name)(StreamAccessor)
    pl.api.register_lazyframe_namespace(name)(StreamAccessor)


# Register the default ``streamjoy`` namespace as soon as this module is imported.
patch()
27 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
# CI for pull requests: install the hatch environment and run the test suite.
name: Build

on: [pull_request]

jobs:
  test:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML does not parse the version as the float 3.1.
        python_version: ['3.10']

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install hatch
          hatch env create
      # Lint/typecheck step is currently disabled.
      # - name: Lint and typecheck
      #   run: |
      #     hatch run lint-check
      # `test-cov-xml` is a hatch script; presumably it writes an XML coverage
      # report for the Codecov step below — confirm in pyproject.toml.
      - name: Test
        run: |
          hatch run test-cov-xml
      # Upload coverage; the job fails if the upload fails.
      - uses: codecov/codecov-action@v4
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          fail_ci_if_error: true
          verbose: true
35 |
--------------------------------------------------------------------------------
/tests/test_pandas.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import streamjoy.pandas # noqa: F401
4 |
5 |
class TestPandas:
    """Exercise the ``streamjoy`` accessor registered by ``streamjoy.pandas``."""

    def test_dataframe(self, df):
        expected_kwargs = dict(
            groupby="Country",
            x="Year",
            y="fertility",
            xlabel="Year",
            ylabel="Fertility",
        )
        sj = df.streamjoy(groupby="Country")
        assert sj.renderer_kwargs == expected_kwargs
        assert isinstance(sj.write(), BytesIO)

    def test_series(self, df):
        # NOTE(review): despite the test name, ``[["Country", "life"]]`` selects
        # a two-column DataFrame indexed by Year, not a Series.
        expected_kwargs = dict(
            groupby="Country",
            x="Year",
            y="life",
            xlabel="Year",
            ylabel="Life",
        )
        subset = df.set_index("Year")[["Country", "life"]]
        sj = subset.streamjoy(groupby="Country")
        assert sj.renderer_kwargs == expected_kwargs
        assert isinstance(sj.write(), BytesIO)
28 |
--------------------------------------------------------------------------------
/docs/example_recipes/sine_wave.md:
--------------------------------------------------------------------------------
1 | # Sine wave
2 |
3 |
4 |
5 | Example of how to use `stream` alongside a custom `renderer` to create a sine wave animation.
6 |
7 | Highlights:
8 |
9 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure.
10 | - Uses a custom `renderer` function to create each frame of the animation.
11 |
12 | ```python hl_lines="5-6 14"
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | from streamjoy import stream, wrap_matplotlib
16 |
17 | @wrap_matplotlib()
18 | def plot_frame(i):
19 | x = np.linspace(0, 2, 1000)
20 | y = np.sin(2 * np.pi * (x - 0.01 * i))
21 | fig, ax = plt.subplots()
22 | ax.plot(x, y)
23 | return fig
24 |
25 | if __name__ == "__main__":
26 | stream(list(range(10)), uri="sine_wave.gif", renderer=plot_frame)
27 | ```
28 |
--------------------------------------------------------------------------------
/tests/test_xarray.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import streamjoy.xarray # noqa: F401
4 |
5 |
class TestXArray:
    """Exercise the ``streamjoy`` accessor registered by ``streamjoy.xarray``."""

    @staticmethod
    def _assert_streams(sj, *required_keys):
        # Each key must appear in renderer_kwargs, and write() with no uri
        # must return an in-memory BytesIO buffer.
        for key in required_keys:
            assert key in sj.renderer_kwargs
        assert isinstance(sj.write(), BytesIO)

    def test_dataset_3d(self, ds):
        self._assert_streams(ds.streamjoy(), "vmin", "vmax")

    def test_dataarray_3d(self, ds):
        self._assert_streams(ds["air"].streamjoy(), "vmin", "vmax")

    def test_dataset_2d(self, ds):
        # Averaging over ``lat`` drops one dimension.
        self._assert_streams(ds.mean("lat").streamjoy(), "ylim")

    def test_dataarray_2d(self, ds):
        self._assert_streams(ds["air"].mean("lat").streamjoy(), "ylim")
28 |
--------------------------------------------------------------------------------
/streamjoy/__init__.py:
--------------------------------------------------------------------------------
"""Top-level streamjoy package: re-exports the public API."""

import logging

from ._utils import update_logger
from .core import connect, stream
from .models import ImageText, Paused
from .renderers import (
    default_holoviews_renderer,
    default_pandas_renderer,
    default_polars_renderer,
    default_xarray_renderer,
)
from .settings import config, file_handlers, obj_handlers
from .streams import GifStream, HtmlStream, Mp4Stream
from .wrappers import wrap_holoviews, wrap_matplotlib

__version__ = "0.0.10"

__all__ = [
    "GifStream",
    "HtmlStream",
    "ImageText",
    "Mp4Stream",
    "Paused",
    "config",
    "connect",
    "default_holoviews_renderer",
    "default_pandas_renderer",
    "default_polars_renderer",
    "default_xarray_renderer",
    "file_handlers",
    "obj_handlers",
    "stream",
    "wrap_holoviews",
    "wrap_matplotlib",
]

# NOTE(review): basicConfig configures the *root* logger at import time, which
# also affects the importing application's logging; libraries usually leave
# this to the application. Kept as-is to preserve current behavior.
logging.basicConfig(
    level=config["logging_level"],
    format=config["logging_format"],
    datefmt=config["logging_datefmt"],
)

update_logger()
42 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024, Andrew Huang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Development
2 |
3 | ### Setup environment
4 |
5 | We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system.
6 |
7 | ### Run unit tests
8 |
9 | You can run all the tests with:
10 |
11 | ```bash
12 | hatch run test
13 | ```
14 |
15 | ### Format the code
16 |
17 | Execute the following command to apply linting and check typing:
18 |
19 | ```bash
20 | hatch run lint
21 | ```
22 |
23 | ### Publish a new version
24 |
25 | You can bump the version, create a commit and associated tag with one command:
26 |
27 | ```bash
28 | hatch version patch
29 | ```
30 |
31 | ```bash
32 | hatch version minor
33 | ```
34 |
35 | ```bash
36 | hatch version major
37 | ```
38 |
39 | Your default Git text editor will open so you can add information about the release.
40 |
41 | When you push the tag to GitHub, the release workflow will automatically build the package and publish it to PyPI. The workflow only runs for tags starting with `v` (e.g. `v0.1.0`).
42 |
43 | ## Serve the documentation
44 |
45 | You can serve the Mkdocs documentation with:
46 |
47 | ```bash
48 | hatch run docs-serve
49 | ```
50 |
51 | It'll automatically watch for changes in your code.
52 |
--------------------------------------------------------------------------------
/HOWTOCONTRIBUTE.md:
--------------------------------------------------------------------------------
1 | ## Development
2 |
3 | ### Setup environment
4 |
5 | We use [Hatch](https://hatch.pypa.io/latest/install/) to manage the development environment and production build. Ensure it's installed on your system.
6 |
7 | ### Run unit tests
8 |
9 | You can run all the tests with:
10 |
11 | ```bash
12 | hatch run test
13 | ```
14 |
15 | ### Format the code
16 |
17 | Execute the following command to apply linting and check typing:
18 |
19 | ```bash
20 | hatch run lint
21 | ```
22 |
23 | ### Publish a new version
24 |
25 | You can bump the version, create a commit and associated tag with one command:
26 |
27 | ```bash
28 | hatch version patch
29 | ```
30 |
31 | ```bash
32 | hatch version minor
33 | ```
34 |
35 | ```bash
36 | hatch version major
37 | ```
38 |
39 | Your default Git text editor will open so you can add information about the release.
40 |
41 | When you push the tag to GitHub, the release workflow will automatically build the package and publish it to PyPI. The workflow only runs for tags starting with `v` (e.g. `v0.1.0`).
42 |
43 | ## Serve the documentation
44 |
45 | You can serve the Mkdocs documentation with:
46 |
47 | ```bash
48 | hatch run docs-serve
49 | ```
50 |
51 | It'll automatically watch for changes in your code.
52 |
--------------------------------------------------------------------------------
/docs/example_recipes/sea_ice.md:
--------------------------------------------------------------------------------
1 | # Sea ice
2 |
3 |
6 |
7 | Compares sea ice concentration data from the NOAA G02135 dataset for August 15th in 1989 and 2023.
8 |
9 | Highlights:
10 |
11 | - Downloads images directly from the NSIDC data server over HTTPS (`noaadata.apps.nsidc.org`).
12 | - Uses `connect` to concatenate the two homogeneous streams together (same keyword arguments, different resources).
13 | - Uses `pattern` to filter for only the sea ice concentration images.
14 | - Uses `intro_title` and `intro_subtitle` to provide context at the beginning of the animation.
15 |
16 | ```python hl_lines="3 6-9"
17 | from streamjoy import stream, connect
18 |
19 | connect(
20 | [
21 | stream(
22 | f"https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/{year}/08_Aug/",
23 | pattern=f"N_*_conc_v3.0.png",
24 | intro_title=f"August 15, {year}",
25 | intro_subtitle="Sea Ice Concentration",
26 | max_files=31,
27 | )
28 | for year in [1989, 2023]
29 | ]
30 | ).write("sea_ice.mp4")
31 | ```
--------------------------------------------------------------------------------
/.github/workflows/documentation.yml:
--------------------------------------------------------------------------------
# Build the documentation site and deploy it to GitHub Pages on every push to main.
name: Build documentation

on:
  push:
    branches:
      - main

# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages
permissions:
  contents: read
  pages: write
  id-token: write

# Allow one concurrent deployment
concurrency:
  group: "pages"
  cancel-in-progress: true

# Default to bash
defaults:
  run:
    shell: bash

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install hatch
          hatch env create
      # `docs-build` is a hatch script; presumably it runs the MkDocs build
      # into ./site, which the next step uploads — confirm in pyproject.toml.
      - name: Build
        run: hatch run docs-build
      - name: Upload artifact
        uses: actions/upload-pages-artifact@v3
        with:
          path: ./site

  # Publishes the Pages artifact uploaded by the build job.
  deploy:
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    runs-on: ubuntu-latest
    needs: build
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4
57 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
# Build sdist/wheel and publish to PyPI whenever a tag starting with "v" is pushed.
name: Build & Release

on:
  push:
    tags:
      - "v*"

jobs:
  build-release:
    name: Build Release
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # v5 matches the other workflows in this repository.
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install packages
        run: |
          python -m pip install --upgrade pip build
          python -m pip install --upgrade --upgrade-strategy eager -e .

      - name: Build a binary wheel and a source tarball
        run: |
          python -m build --sdist --wheel --outdir dist/

      # upload-artifact/download-artifact v3 are deprecated by GitHub and fail
      # on GitHub-hosted runners; v4 accepts the same `name`/`path` inputs.
      - name: Publish build artifacts
        uses: actions/upload-artifact@v4
        with:
          name: built-package
          path: "./dist"

  publish-release:
    name: Publish release to PyPI
    needs: [build-release]
    environment: "prod"
    runs-on: ubuntu-latest

    steps:
      - name: Download build artifacts
        uses: actions/download-artifact@v4
        with:
          name: built-package
          path: './dist'

      - name: Publish distribution to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_API_TOKEN }}
          verbose: true
53 |
--------------------------------------------------------------------------------
/tests/test_renderers.py:
--------------------------------------------------------------------------------
1 | import holoviews as hv
2 | import matplotlib.pyplot as plt
3 | import pytest
4 |
5 | from streamjoy.renderers import (
6 | default_holoviews_renderer,
7 | default_pandas_renderer,
8 | default_polars_renderer,
9 | default_xarray_renderer,
10 | )
11 |
12 |
class TestDefaultRenderer:
    """Smoke-test each default renderer with three title variants
    (constant text, a ``{...}`` placeholder, and ``None``)."""

    # Column mapping shared by the tabular (pandas / polars) renderer tests.
    _GAPMINDER_KWARGS = {"x": "Year", "y": "life", "groupby": "Country"}

    @pytest.mark.parametrize("title", ["Constant", "{Year}", None])
    def test_pandas(self, df, title):
        figure = default_pandas_renderer(df, title=title, **self._GAPMINDER_KWARGS)
        assert isinstance(figure, plt.Figure)

    @pytest.mark.parametrize("title", ["Constant", "{Year}", None])
    def test_polars(self, pl_df, title):
        overlay = default_polars_renderer(
            pl_df, title=title, **self._GAPMINDER_KWARGS
        )
        assert isinstance(overlay, hv.NdOverlay)

    @pytest.mark.parametrize("title", ["Constant", "{time}", None])
    def test_xarray(self, ds, title):
        first_frame = ds["air"].isel(time=0)
        figure = default_xarray_renderer(first_frame, title=title)
        assert isinstance(figure, plt.Figure)

    @pytest.mark.parametrize("title", ["Constant", "{x}", None])
    def test_holoviews(self, title):
        curve = hv.Curve([1, 2, 3])
        element = default_holoviews_renderer(curve, title=title)
        assert isinstance(element, hv.Element)
3 |
10 |
11 | Highlights:
12 |
13 | - Writes to memory to later use in another Panel component
14 | - Sets `extension` to hint at the desired output format
15 | - Appends two streams in a `Tabs` layout
16 | - Links the Players' value from the first tab to the second tab
17 |
18 | ```python hl_lines="30 36 37 38"
19 | import panel as pn
20 | import pandas as pd
21 | from streamjoy import stream
22 |
23 | pn.extension()
24 |
25 | URL_FMT = (
26 | "https://www.cpc.ncep.noaa.gov/products/NMME/archive/{dt:%Y%m}0800/"
27 | "current/images/NMME_ensemble_{var}_us_lead{i}.png"
28 | )
29 |
30 | VARS = {
31 | "prate": "Precipitation Rate",
32 | "tmp2m": "2m Temperature",
33 | }
34 | LEADS = 7
35 |
36 | var_tabs = pn.Tabs()
37 | for var in VARS.keys():
38 | dt_range = [
39 | pd.to_datetime("2024-03-08") - pd.DateOffset(months=lead)
40 | for lead in range(LEADS)
41 | ]
42 | urls = [
43 | URL_FMT.format(i=i, dt=dt, var=var)
44 | for i, dt in enumerate(dt_range, 1)
45 | ]
46 | col = stream(
47 | urls,
48 | extension=".html",
49 | fps=1,
50 | ending_pause=0,
51 | display=False,
52 | sizing_mode="stretch_width",
53 | height=400,
54 | ).write()
55 | var_tabs.append((VARS[var], col))
56 | var_tabs[0][1].jslink(var_tabs[1][1], value="value")
57 | var_tabs.save("nmme_forecast.html")
58 | ```
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import hvplot.xarray # noqa: F401
4 | import imageio.v3 as iio
5 | import pandas as pd
6 | import polars as pl
7 | import pytest
8 | import xarray as xr
9 |
10 | from streamjoy._utils import get_distributed_client
11 | from streamjoy.settings import config
12 |
# Sample data files shipped next to the tests.
DATA_DIR = Path(__file__).parent / "data"
NC_PATH = DATA_DIR.joinpath("air.nc")
ZARR_PATH = DATA_DIR.joinpath("air.zarr")
CSV_PATH = DATA_DIR.joinpath("gapminder.csv")
PARQUET_PATH = DATA_DIR.joinpath("gapminder.parquet")


@pytest.fixture
def array():
    """Image data read with imageio v3 from its bundled ``newtonscradle.gif``."""
    return iio.imread("imageio:newtonscradle.gif")


@pytest.fixture
def ds():
    """Dataset opened from the local ``air.zarr`` store."""
    return xr.open_zarr(ZARR_PATH)


@pytest.fixture
def df():
    """Gapminder table loaded from parquet as a pandas DataFrame."""
    return pd.read_parquet(PARQUET_PATH)


@pytest.fixture
def pl_df():
    """Gapminder table loaded from parquet as a polars DataFrame."""
    return pl.read_parquet(PARQUET_PATH)


@pytest.fixture
def dmap(ds):
    """hvplot of ``ds`` over lon/lat with ``dynamic=True`` (presumably a DynamicMap)."""
    return ds.hvplot("lon", "lat", dynamic=True)


@pytest.fixture
def hmap(ds):
    """hvplot of ``ds`` over lon/lat with ``dynamic=False`` (presumably a HoloMap)."""
    return ds.hvplot("lon", "lat", dynamic=False)
48 |
49 |
@pytest.fixture(autouse=True, scope="session")
def client():
    """Create the client from ``get_distributed_client`` once for the whole session."""
    # NOTE(review): the client is never closed at session end; consider a
    # yield fixture with teardown once its close API is confirmed.
    return get_distributed_client()


@pytest.fixture(autouse=True, scope="session")
def default_config():
    """Set small streamjoy defaults for the whole session.

    The defaults are 1 fps, at most 3 frames and 3 files, and no ending pause.
    """
    config["fps"] = 1
    config["max_frames"] = 3
    config["max_files"] = 3
    config["ending_pause"] = 0


@pytest.fixture(scope="session")
def data_dir():
    """Directory holding the sample data files."""
    return DATA_DIR


@pytest.fixture(scope="session")
def fsspec_fs():
    """Local-filesystem fsspec instance; skips the requesting test without fsspec."""
    # importorskip only guards the import itself, so an ImportError raised
    # inside fsspec.filesystem() surfaces as a real error instead of a skip.
    fsspec = pytest.importorskip("fsspec", reason="fsspec not installed")
    return fsspec.filesystem("file")
76 |
--------------------------------------------------------------------------------
/docs/example_recipes/oisst_globe.md:
--------------------------------------------------------------------------------
1 | # OISST globe
2 |
3 |
4 |
7 |
8 | Render sea surface temperature anomaly data from the NOAA OISST v2.1 dataset on a globe.
9 |
10 | Highlights:
11 |
12 | - Concatenates multiple homogeneous streams together (same keyword arguments, different resources) by summing them.
13 | - Uses the built-in `default_xarray_renderer` under the hood.
14 | - Uses `renderer_kwargs` to pass keyword arguments to the underlying `ds.plot` method.
15 |
16 | ```python hl_lines="19 32"
17 | import cartopy.crs as ccrs
18 | from streamjoy import stream
19 |
20 | YEAR = 2023
21 | URL_FMT = "https://www.ncei.noaa.gov/data/sea-surface-temperature-optimum-interpolation/v2.1/access/avhrr/{year}{month:02}/"
22 |
23 | if __name__ == "__main__":
24 | streams = []
25 | for month in range(1, 13):
26 | url = URL_FMT.format(year=YEAR, month=month)
27 | streams.append(
28 | stream(
29 | url,
30 | pattern="oisst-avhrr-v02r01*.nc",
31 | var="anom",
32 | dim="time",
33 | max_files=29,
34 | max_frames=-1,
35 | renderer_kwargs=dict(
36 | cmap="RdBu_r",
37 | vmin=-5,
38 | vmax=5,
39 | subplot_kws=dict(
40 | projection=ccrs.Orthographic(central_longitude=-150),
41 | facecolor="gray",
42 | ),
43 | transform=ccrs.PlateCarree(),
44 | ),
45 | )
46 | )
47 |
48 | joined_stream = sum(streams)
49 | joined_stream.write("oisst_globe.mp4")
50 | ```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | _NOTEBOOKS/
2 | streamjoy_scratch
3 | *.ipynb
4 |
5 | *.gif
6 | *.mp4
7 | *.html
8 | *.png
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | env/
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 | junit/
59 | junit.xml
60 | test.db
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # dotenv
96 | .env
97 |
98 | # virtualenv
99 | .venv
100 | venv/
101 | ENV/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 |
116 | # OS files
117 | .DS_Store
118 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw
2 |
3 | from streamjoy.models import ImageText, Paused
4 |
5 |
class TestPaused:
    def test_paused_initialization(self):
        """``Paused`` keeps the output and duration it was constructed with."""
        expected_output, expected_seconds = "some_output", 5
        paused = Paused(output=expected_output, seconds=expected_seconds)
        assert paused.output == expected_output
        assert paused.seconds == expected_seconds
13 |
14 |
class TestImageText:
    def test_image_text_initialization(self):
        """Every constructor argument is stored on the matching parameter."""
        expected = dict(
            text="Hello, World!",
            font="Arial",
            size=24,
            color="black",
            anchor="mm",
            x=100,
            y=100,
            kwargs={"width": 500},
        )
        image_text = ImageText(**expected)
        for name, value in expected.items():
            assert getattr(image_text, name) == value

    def test_image_text_render(self):
        """Rendering onto a blank canvas completes without raising."""
        image_text = ImageText(
            text="Test Render", font="Arial", size=24, color="black", anchor="mm", x=50, y=50
        )
        canvas = Image.new("RGB", (100, 100), color="white")
        image_text.render(ImageDraw.Draw(canvas))
        # Inspecting the drawn pixels is hard to do reliably, so reaching this
        # point without an exception is the success criterion.
57 |
--------------------------------------------------------------------------------
/docs/example_recipes/gender_gapminder.md:
--------------------------------------------------------------------------------
1 | # Gender gapminder
2 |
3 |
6 |
7 | This example demonstrates how to use `stream` and `connect` to create a video
8 | comparing gender population data from the Gapminder dataset.
9 |
10 | Highlights:
11 |
12 | - Uses `intro_title` and `intro_subtitle` to set the title and subtitle of the video.
13 | - Uses `renderer_kwargs` to pass keyword arguments to the custom `renderer` function.
14 | - Updates `fps` to 30 to create a smoother animation.
15 | - Uses `connect` to concatenate the two heterogeneous streams (different keyword arguments with different titles) together.
16 |
17 | ```python hl_lines="21-23 31 34"
18 | import pandas as pd
19 | from streamjoy import stream, connect
20 |
21 | if __name__ == "__main__":
22 | url_fmt = (
23 | "https://raw.githubusercontent.com/open-numbers/ddf--gapminder--systema_globalis/master/"
24 | "countries-etc-datapoints/ddf--datapoints--{gender}_population_with_projections--by--geo--time.csv"
25 | )
26 | df = pd.concat((
27 | pd.read_csv(url_fmt.format(gender=gender)).set_index(["geo", "time"])
28 | for gender in ["male", "female"]),
29 | axis=1,
30 | )
31 |
32 | streams = []
33 | for country in ["usa", "chn"]:
34 | df_sub = df.loc[country].reset_index().melt("time")
35 | streamed = stream(
36 | df_sub,
37 | groupby="variable",
38 | intro_title="Gapminder",
39 | intro_subtitle=f"{country.upper()} Male vs Female Population",
40 | renderer_kwargs={
41 | "x": "time",
42 | "y": "value",
43 | "xlabel": "Year",
44 | "ylabel": "Population",
45 | "title": f"{country.upper()} {{time}}",
46 | },
47 | max_frames=-1,
48 | fps=30,
49 | )
50 | streams.append(streamed)
51 | connect(streams).write("gender_gapminder.mp4")
52 | ```
--------------------------------------------------------------------------------
/docs/package_design.md:
--------------------------------------------------------------------------------
1 | # Package design
2 |
3 | ## 🪪 Naming
4 |
5 | StreamJoy stems from the idea of streaming output images, rendered in parallel, to GIF or MP4 *as they get serialized*.
6 |
7 | This was a mini breakthrough for me, as I had written other packages to try to animate data efficiently (e.g. [`enjoyn`](https://enjoyn.readthedocs.io/en/latest/) and [`ahlive`](https://ahlive.readthedocs.io/en/latest/)). However, both of these packages suffered from the same bottleneck: they had to wait for all the images to be written out to disk before they could start generating the animation.
8 |
9 | Discovering this breakthrough brought me joy, and I wanted to share that joy with others by writing a package that reduces the boilerplate and time spent working on animations, bringing joy to the user.
10 |
11 | Coincidentally, SJ is also my wife's initials, so it was a perfect fit! :D
12 |
13 | I also considered naming this `streamio` or `streamit`, but the former was already taken and the latter was too close to `streamlit`.
14 |
15 | ## 📶 Diagram
16 |
17 | Below is a diagram of the package design. The animation part is actually quite simple--most of the complexity comes with handling various input types, e.g. URLs, files, and datasets.
18 |
19 | ```mermaid
20 | graph TD
21 | A[Start] --> Z{Input Type}
22 | Z -->|URL| U[Download and Assess Content]
23 | Z -->|Direct Input| V{Determine Content Type}
24 | U --> V
25 | V -->|DataFrame| B[Split pandas DataFrame]
26 | V -->|XArray Dataset/DataArray| X[Split XArray by dim]
27 | V -->|HoloViews HoloMap/DynamicMap| H[Split HoloViews by kdim]
28 | V -->|Directory of Images| Y[Glob all files]
29 | V -->|List of Images| D[Open MP4/GIF buffer and start Dask client]
30 | B --> D
31 | X --> D
32 | H --> D
33 | Y --> D
34 | D --> E{Process each frame in parallel with Dask}
35 | E -->|Renderer available| F[Call custom/default renderer]
36 | E -->|Renderer not available| G[Convert to np.array with imread]
37 | F --> G
38 | G --> I[Stream to MP4/GIF buffer with imwrite]
39 | I --> J{All units processed?}
40 | J -->|Yes| K[Save MP4/GIF]
41 | J -->|No| E
42 | K -->|Optimize GIF?| N[Optimize GIF]
43 | N --> L[End]
44 | K --> L
45 | ```
--------------------------------------------------------------------------------
/tests/test_wrappers.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from pathlib import Path
3 |
4 | import holoviews as hv
5 | from matplotlib import pyplot as plt
6 |
7 | from streamjoy.models import Paused
8 | from streamjoy.wrappers import wrap_holoviews, wrap_matplotlib
9 |
10 |
class TestWrapMatplotlib:
    """``wrap_matplotlib`` serializes a returned figure to disk or to memory."""

    def test_wrap_matplotlib_figure_to_file(self, tmp_path):
        @wrap_matplotlib(in_memory=False, scratch_dir=tmp_path)
        def draw_line():
            figure, axes = plt.subplots()
            axes.plot([1, 2, 3], [1, 2, 3])
            return figure

        saved_path = draw_line()
        assert Path(saved_path).exists()

    def test_wrap_matplotlib_figure_to_memory(self):
        @wrap_matplotlib(in_memory=True)
        def draw_line():
            figure, axes = plt.subplots()
            axes.plot([1, 2, 3], [1, 2, 3])
            return figure

        buffer = draw_line()
        assert isinstance(buffer, BytesIO)

    def test_wrap_matplotlib_with_paused(self):
        @wrap_matplotlib(in_memory=True)
        def draw_line():
            figure, axes = plt.subplots()
            axes.plot([1, 2, 3], [1, 2, 3])
            return Paused(output=figure, seconds=5)

        paused = draw_line()
        # The Paused wrapper survives; only its inner output is serialized.
        assert isinstance(paused, Paused)
        assert isinstance(paused.output, BytesIO)
        assert paused.seconds == 5
43 |
44 |
class TestWrapHoloViews:
    """``wrap_holoviews`` saves a returned HoloViews element to the scratch dir."""

    def test_wrap_holoviews_to_file(self, tmp_path):
        @wrap_holoviews(in_memory=False, scratch_dir=tmp_path)
        def draw_curve():
            return hv.Curve((range(10), range(10)))

        saved_path = draw_curve()
        assert Path(saved_path).exists()

    def test_wrap_holoviews_with_paused(self, tmp_path):
        @wrap_holoviews(in_memory=False, scratch_dir=tmp_path)
        def draw_curve():
            return Paused(output=hv.Curve((range(10), range(10))), seconds=5)

        paused = draw_curve()
        # The Paused wrapper survives; its output becomes a saved path.
        assert isinstance(paused, Paused)
        assert paused.output.exists()
        assert paused.seconds == 5
65 |
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from streamjoy.core import connect, stream
7 | from streamjoy.streams import AnyStream, ConnectedStreams, GifStream, Mp4Stream
8 |
9 |
class TestStream:
    """Behaviour of ``stream`` with and without an output ``uri``."""

    def test_no_uri(self):
        streamed = stream(resources=[0, 1, 2])
        assert isinstance(
            streamed, AnyStream
        ), "Expected an instance of AnyStream or its subclass"

    def test_uri(self, df, tmp_path):
        target = tmp_path / "test.mp4"
        stream(resources=df, uri=target)
        assert target.exists(), "Stream file should exist"

    def test_uri_with_bytesio(self, df):
        buffer = BytesIO()
        written = stream(resources=df, uri=buffer, extension=".mp4")
        assert isinstance(
            written, BytesIO
        ), "Expected a BytesIO object when URI is a BytesIO instance"

    def test_uri_with_str_path(self, df, tmp_path):
        target = str(tmp_path / "test.gif")
        stream(resources=df, uri=target)
        assert Path(target).exists(), "Stream file should exist when URI is a string path"

    @pytest.mark.parametrize("extension", [".mp4", ".gif"])
    def test_stream_with_different_extensions(self, df, extension):
        expected_class = {".mp4": Mp4Stream, ".gif": GifStream}[extension]
        streamed = stream(resources=df, extension=extension)
        assert isinstance(
            streamed, expected_class
        ), f"Expected an instance of {expected_class.__name__}"
41 |
42 |
class TestConnect:
    """Behaviour of ``connect`` with and without an output ``uri``."""

    def test_connect_streams(self):
        first, second = stream(resources=[0, 1, 2]), stream(resources=[3, 4, 5])
        connected = connect(streams=[first, second])
        assert isinstance(
            connected, ConnectedStreams
        ), "Expected an instance of ConnectedStreams"

    def test_connect_streams_with_uri(self, df, tmp_path):
        pieces = [stream(resources=df) for _ in range(2)]
        target = tmp_path / "connected.mp4"
        connect(streams=pieces, uri=target)
        assert target.exists(), "Connected stream file should exist"
58 |
--------------------------------------------------------------------------------
/docs/api_reference/settings.md:
--------------------------------------------------------------------------------
1 | # Settings
2 |
3 | ```python
4 | config = {
5 | # animation
6 | "fps": 8,
7 | "max_frames": 50,
8 | # dask
9 | "batch_size": 10,
10 | "processes": True,
11 | "threads_per_worker": None,
12 | # intro
13 | "intro_pause": 2,
14 | "intro_watermark": "made with streamjoy",
15 | "intro_background": "black",
16 | # from_url
17 | "max_files": 2,
18 | # matplotlib
19 | "max_open_warning": 100,
20 | # output
21 | "in_memory": False,
22 | "scratch_dir": "streamjoy_scratch",
23 | "uri": None,
24 | # imageio
25 | "codec": "libx264",
26 | "loop": 0,
27 | "ending_pause": 2,
28 | # gif
29 | "optimize": False,
30 | # image text
31 | "image_text_font": "Avenir.ttc",
32 | "image_text_size": 20,
33 | "image_text_color": "white",
34 | "image_text_background": "black",
35 | # notebook
36 | "display": True,
37 | # logging
38 | "logging_success_level": 25,
39 | "logging_level": 20,
40 | "logging_format": "[%(levelname)s] %(asctime)s: %(message)s",
41 | "logging_datefmt": "%I:%M%p",
42 | "logging_warning_color": "\x1b[31;1m",
43 | "logging_success_color": "\x1b[32;1m",
44 | "logging_reset_color": "\x1b[0m",
45 | }
46 | ```
47 |
48 | ```python
49 | obj_handlers = {
50 | "xarray.Dataset": "serialize_xarray",
51 | "xarray.DataArray": "serialize_xarray",
52 | "pandas.DataFrame": "serialize_pandas",
53 | "pandas.Series": "serialize_pandas",
54 | "holoviews": "serialize_holoviews",
55 | }
56 | ```
57 |
58 | ```python
59 | file_handlers = {
60 | ".nc": {
61 | "import_path": "xarray.open_mfdataset",
62 | },
63 | ".nc4": {
64 | "import_path": "xarray.open_mfdataset",
65 | },
66 | ".zarr": {
67 | "import_path": "xarray.open_zarr",
68 | },
69 | ".grib": {
70 | "import_path": "xarray.open_mfdataset",
71 | "kwargs": {"engine": "cfgrib"},
72 | },
73 | ".csv": {
74 | "import_path": "pandas.read_csv",
75 | "concat_path": "pandas.concat",
76 | },
77 | ".parquet": {
78 | "import_path": "pandas.read_parquet",
79 | "concat_path": "pandas.concat",
80 | },
81 | ".html": {
82 | "import_path": "pandas.read_html",
83 | "concat_path": "pandas.concat",
84 | },
85 | }
86 | ```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | target-version = "py310"
3 |
4 | [tool.ruff.lint]
5 | extend-select = ["I", "UP"]
6 | extend-ignore = ["TRY003"]
7 |
8 | [tool.ruff.lint.flake8-type-checking]
9 | quote-annotations = true
10 |
11 | [tool.pytest.ini_options]
12 | addopts = "--cov=streamjoy/ --cov-report=term-missing"
13 |
14 |
15 | [tool.hatch]
16 |
17 | [tool.hatch.metadata]
18 | allow-direct-references = true
19 |
20 | [tool.hatch.version]
21 | source = "regex_commit"
22 | commit_extra_args = ["-e"]
23 | path = "streamjoy/__init__.py"
24 |
25 | [tool.hatch.envs.default]
26 | python = "3.9"
27 | dependencies = [
28 | "ruff",
29 | "pytest",
30 | "pytest-cov",
31 | "mkdocs-material",
32 | "mkdocstrings[python]",
33 | "xarray",
34 | "pandas",
35 | "polars",
36 | "zarr",
37 | "netcdf4",
38 | "matplotlib",
39 | "pyarrow",
40 | "hvplot",
41 | "bs4",
42 | "selenium",
43 | "webdriver_manager",
44 | "panel",
45 | ]
46 |
47 | [tool.hatch.envs.default.scripts]
48 | test = "pytest"
49 | test-cov-xml = "pytest --cov-report=xml"
50 | lint = [
51 | "ruff format .",
52 | "ruff check --fix .",
53 | ]
54 | lint-check = [
55 | "ruff format --check .",
56 | "ruff check .",
57 | ]
58 | docs-serve = "mkdocs serve"
59 | docs-build = "mkdocs build"
60 |
61 | [build-system]
62 | requires = ["hatchling", "hatch-regex-commit"]
63 | build-backend = "hatchling.build"
64 |
65 | [project]
66 | name = "streamjoy"
67 | authors = [
68 | { name = "streamjoy", email = "hey.at.py@gmail.com" }
69 | ]
70 | description = "Enjoy animating images into GIFs and MP4s!"
71 | readme = "README.md"
72 | dynamic = ["version"]
73 | classifiers = [
74 | "Programming Language :: Python :: 3 :: Only",
75 | ]
76 | requires-python = ">=3.9"
77 | dependencies = [
78 | "param>2",
79 | "dask[distributed]",
80 | "imageio[pyav]>=2.34.0",
81 | "pygifsicle==1.0.5",
82 | "requests",
83 | "bs4",
84 | ]
85 |
86 | [project.urls]
87 | Documentation = "https://ahuang11.github.io/streamjoy/"
88 | Source = "https://github.com/ahuang11/streamjoy"
89 |
90 | [project.scripts]
91 | streamjoy = "streamjoy.cli:main"
92 |
93 | [project.optional-dependencies]
94 | ui = [
95 | "panel",
96 | "param",
97 | "requests",
98 | "xarray",
99 | "netcdf4",
100 | ]
--------------------------------------------------------------------------------
/streamjoy/settings.py:
--------------------------------------------------------------------------------
# Package-wide defaults. This is a plain mutable dict, so callers may
# override entries in place.
config = dict(
    # animation
    fps=8,
    max_frames=50,
    # dask
    batch_size=10,
    processes=True,
    threads_per_worker=None,
    # intro
    intro_pause=2,
    intro_watermark="made with streamjoy",
    intro_background="black",
    # from_url
    max_files=2,
    # download
    parent_depth=4,
    # matplotlib
    max_open_warning=100,
    # holoviews
    webdriver="firefox",
    num_retries=5,
    # output
    in_memory=False,
    scratch_dir="streamjoy_scratch",
    uri=None,
    # imageio
    codec="libx264",
    loop=0,
    ending_pause=2,
    # gif
    optimize=False,
    # image text
    image_text_font="Avenir.ttc",
    image_text_size=20,
    image_text_color="white",
    image_text_background="black",
    # notebook
    display=True,
    # logging
    logging_success_level=25,
    logging_level=20,
    logging_format="[%(levelname)s] %(asctime)s: %(message)s",
    logging_datefmt="%I:%M%p",
    logging_warning_color="\x1b[31;1m",
    logging_success_color="\x1b[32;1m",
    logging_reset_color="\x1b[0m",
)

# Output extension -> name of the stream class to look up in the streams
# module; the ``None`` key covers "no extension given".
extension_handlers = dict(
    [
        (None, "AnyStream"),
        (".mp4", "Mp4Stream"),
        (".gif", "GifStream"),
        (".html", "HtmlStream"),
    ]
)

# Object type name -> name of the serializer function that handles it.
obj_handlers = {
    **dict.fromkeys(("xarray.Dataset", "xarray.DataArray"), "serialize_xarray"),
    **dict.fromkeys(("pandas.DataFrame", "pandas.Series"), "serialize_pandas"),
    "holoviews": "serialize_holoviews",
    "polars.DataFrame": "serialize_polars",
    "numpy.ndarray": "serialize_numpy",
}

# Input file extension -> loader description: ``import_path`` names the reader,
# ``concat_path`` (when present) names a combining function (presumably for
# multiple files), and ``kwargs`` (when present) holds extra reader arguments.
file_handlers = {
    ".nc": dict(import_path="xarray.open_mfdataset"),
    ".nc4": dict(import_path="xarray.open_mfdataset"),
    ".zarr": dict(import_path="xarray.open_zarr"),
    ".grib": dict(
        import_path="xarray.open_mfdataset",
        kwargs=dict(engine="cfgrib"),
    ),
    ".csv": dict(import_path="pandas.read_csv", concat_path="pandas.concat"),
    ".parquet": dict(import_path="pandas.read_parquet", concat_path="pandas.concat"),
    ".html": dict(import_path="pandas.read_html", concat_path="pandas.concat"),
}
93 |
--------------------------------------------------------------------------------
/tests/test_serializers.py:
--------------------------------------------------------------------------------
1 | from streamjoy.models import Serialized
2 | from streamjoy.serializers import (
3 | serialize_holoviews,
4 | serialize_numpy,
5 | serialize_pandas,
6 | serialize_polars,
7 | serialize_xarray,
8 | )
9 |
10 |
class TestSerializeNumpy:
    def test_serialize_numpy(self, array):
        """Numpy input yields 36 resources and no renderer."""
        result = serialize_numpy(None, array)
        assert isinstance(result, Serialized)
        resources = result.resources
        assert len(resources) == 36
        assert isinstance(resources, list)
        assert not result.renderer
        assert result.renderer_iterables is None
        for options in (result.renderer_kwargs, result.kwargs):
            assert isinstance(options, dict)
21 |
22 |
class TestSerializeXarray:
    def test_serialize_xarray(self, ds):
        """An xarray Dataset yields 3 resources and a callable renderer."""
        result = serialize_xarray(None, ds)
        assert isinstance(result, Serialized)
        resources = result.resources
        assert len(resources) == 3
        assert isinstance(resources, list)
        assert callable(result.renderer)
        assert result.renderer_iterables is None
        for options in (result.renderer_kwargs, result.kwargs):
            assert isinstance(options, dict)
33 |
34 |
class TestSerializePandas:
    def test_serialize_pandas(self, df):
        """A pandas DataFrame yields 3 resources and a callable renderer."""
        result = serialize_pandas(None, df)
        assert isinstance(result, Serialized)
        resources = result.resources
        assert len(resources) == 3
        assert isinstance(resources, list)
        assert callable(result.renderer)
        assert result.renderer_iterables is None
        for options in (result.renderer_kwargs, result.kwargs):
            assert isinstance(options, dict)
45 |
46 |
class TestSerializePolars:
    def test_serialize_polars(self, pl_df):
        """A polars DataFrame yields 3 resources and a callable renderer."""
        result = serialize_polars(None, pl_df)
        assert isinstance(result, Serialized)
        resources = result.resources
        assert len(resources) == 3
        assert isinstance(resources, list)
        assert callable(result.renderer)
        assert result.renderer_iterables is None
        for options in (result.renderer_kwargs, result.kwargs):
            assert isinstance(options, dict)
57 |
58 |
class TestSerializeHoloviews:
    def test_serialize_holoviews(self, hmap):
        """A HoloViews HoloMap yields 20 resources and a callable renderer."""
        result = serialize_holoviews(None, hmap)
        assert isinstance(result, Serialized)
        resources = result.resources
        assert len(resources) == 20
        assert isinstance(resources, list)
        assert callable(result.renderer)
        assert result.renderer_iterables is None
        for options in (result.renderer_kwargs, result.kwargs):
            assert isinstance(options, dict)
69 |
--------------------------------------------------------------------------------
/streamjoy/core.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from io import BytesIO
4 | from pathlib import Path
5 | from typing import Any, Callable, Literal
6 |
7 | from . import streams
8 | from .serializers import serialize_appropriately
9 | from .settings import extension_handlers
10 | from .streams import AnyStream, ConnectedStreams, GifStream, HtmlStream, Mp4Stream
11 |
12 |
def stream(
    resources: Any,
    uri: str | Path | BytesIO | None = None,
    renderer: Callable | None = None,
    renderer_iterables: list | None = None,
    renderer_kwargs: dict | None = None,
    extension: Literal[".mp4", ".gif", ".html"] | None = None,
    **kwargs: Any,
) -> AnyStream | GifStream | Mp4Stream | HtmlStream | Path | BytesIO:
    """
    Create a stream from the given resources.

    Args:
        resources: The resources to create a stream from.
        uri: The destination to write the stream to. If None, the stream is returned.
            A string is converted to a `Path`.
        renderer: The renderer to use. If None, the default renderer is used.
        renderer_iterables: Additional positional arguments to map over the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer.
        extension: The extension to use. If omitted, it is taken from the suffix
            of `uri`. It must be given when `uri` is a file-like object such as
            `BytesIO`, which has no suffix.
        **kwargs: Additional keyword arguments to pass.

    Returns:
        The stream if uri is None, otherwise the result of writing the stream to uri.

    Raises:
        ValueError: If the extension is not supported, or if `uri` is a
            file-like object and no `extension` was given.
    """
    if isinstance(uri, str):
        uri = Path(uri)

    if not extension:
        # Infer from a path-like uri; file-like objects have no `suffix`
        # attribute, and used to fail here with an opaque AttributeError.
        extension = None if uri is None else getattr(uri, "suffix", None)
        if uri is not None and extension is None:
            raise ValueError(
                "An extension must be provided when uri is a file-like object, "
                f"got uri of type {type(uri).__name__}"
            )

    if extension is not None and extension not in extension_handlers:
        raise ValueError(f"Unsupported extension: {extension}")

    # extension_handlers maps the extension (or None) to a class name in `streams`.
    stream_cls = getattr(streams, extension_handlers[extension], AnyStream)
    serialized = serialize_appropriately(
        stream_cls,
        resources,
        renderer,
        renderer_iterables,
        renderer_kwargs or {},
        **kwargs,
    )
    # Named `stream_obj` so it does not shadow this function inside its body.
    stream_obj = stream_cls(**serialized.param.values(), **serialized.kwargs)

    if uri is not None:
        return stream_obj.write(uri=uri, extension=extension)
    return stream_obj
66 |
67 | Unlike `stream.join`, this function can connect streams
68 | with unique params, such as different renderers.
69 |
70 | Args:
71 | streams: The streams to connect.
72 | uri: The destination to write the connected streams to.
73 | If None, the connected streams are returned.
74 |
75 | Returns:
76 | The connected streams if uri is None, otherwise the uri.
77 | """
78 | stream = ConnectedStreams(streams=streams)
79 | if uri:
80 | return stream.write(uri=uri)
81 | return stream
82 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: streamjoy
2 | site_description: Enjoy animating images into GIFs and MP4s!
3 |
4 | repo_url: https://github.com/ahuang11/streamjoy
5 | repo_name: ahuang11/streamjoy
6 |
7 | theme:
8 | name: material
9 | features:
10 | - content.code.copy
11 | logo: assets/logo.png
12 | palette:
13 | # Palette toggle for automatic mode
14 | - media: "(prefers-color-scheme)"
15 | toggle:
16 | icon: material/brightness-auto
17 | name: Switch to light mode
18 |
19 | # Palette toggle for light mode
20 | - media: "(prefers-color-scheme: light)"
21 | scheme: default
22 | primary: custom
23 | accent: custom
24 | toggle:
25 | icon: material/brightness-7
26 | name: Switch to dark mode
27 |
28 | # Palette toggle for dark mode
29 | - media: "(prefers-color-scheme: dark)"
30 | scheme: slate
31 | primary: custom
32 | accent: custom
33 | toggle:
34 | icon: material/brightness-4
35 | name: Switch to light mode
36 |
37 | markdown_extensions:
38 | - toc:
39 | permalink: true
40 | - pymdownx.highlight:
41 | anchor_linenums: true
42 | - pymdownx.tasklist:
43 | custom_checkbox: true
44 | - admonition
45 | - pymdownx.details
46 | - pymdownx.tabbed
47 | - pymdownx.magiclink
48 | - pymdownx.inlinehilite
49 | - pymdownx.snippets
50 | - pymdownx.superfences:
51 | custom_fences:
52 | - name: mermaid
53 | class: mermaid
54 | format: !!python/name:pymdownx.superfences.fence_code_format
55 |
56 | plugins:
57 | - search
58 | - mkdocstrings:
59 | handlers:
60 | python:
61 | options:
62 | docstring_style: google
63 | find_stubs_package: true
64 |
65 | watch:
66 | - docs
67 | - streamjoy
68 |
69 | nav:
70 | - Read me!: index.md
71 | - Supported formats: supported_formats.md
72 | - How do I...: how_do_i.md
73 | - Package design: package_design.md
74 | - Example recipes:
75 | - Air temperature: example_recipes/air_temperature.md
76 | - Sine wave: example_recipes/sine_wave.md
77 | - CO2 timeseries: example_recipes/co2_timeseries.md
78 | - Temperature anomaly: example_recipes/temperature_anomaly.md
79 | - Sea ice: example_recipes/sea_ice.md
80 | - OISST globe: example_recipes/oisst_globe.md
81 | - Gender gapminder: example_recipes/gender_gapminder.md
82 | - Stream code: example_recipes/stream_code.md
83 | - Nice orbit: example_recipes/nice_orbit.md
84 | - NMME forecast: example_recipes/nmme_forecast.md
85 | - API reference:
86 | - Core: api_reference/core.md
87 | - Models: api_reference/models.md
88 | - Streams: api_reference/streams.md
89 | - Renderers: api_reference/renderers.md
90 | - Wrappers: api_reference/wrappers.md
91 | - Settings: api_reference/settings.md
92 |
93 | extra_css:
94 | - stylesheets/extra.css
95 |
--------------------------------------------------------------------------------
/docs/example_recipes/nice_orbit.md:
--------------------------------------------------------------------------------
1 | # Nice orbit
2 |
3 |
6 |
7 | Creates visually appealing orbits of a 2D dynamical system.
8 |
9 | Code adapted from [Nice_orbits.ipynb](https://github.com/profConradi/Python_Simulations/blob/main/Nice_orbits.ipynb).
10 | All credits go to [Simone Conradi](https://github.com/profConradi); the only addition here was wrapping the code into a function and using `streamjoy` to create the animation. Please consider giving the [Python_Simulations](https://github.com/profConradi/Python_Simulations/tree/main) repo a star!
11 |
12 | Highlights:
13 |
14 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure.
15 | - Uses a custom `renderer` function to create each frame of the animation.
16 |
17 | ```python hl_lines="45 46"
18 | import numpy as np
19 | import matplotlib.pyplot as plt
20 | from numba import njit
21 | from streamjoy import stream, wrap_matplotlib
22 |
23 | @njit
24 | def meshgrid(x, y):
25 | """
26 | This function replace np.meshgrid that is not supported by numba
27 | """
28 | xx = np.empty(shape=(x.size, y.size), dtype=x.dtype)
29 | yy = np.empty(shape=(x.size, y.size), dtype=y.dtype)
30 | for j in range(y.size):
31 | for k in range(x.size):
32 | xx[j, k] = k # change to x[k] if indexing xy
33 | yy[j, k] = j # change to y[j] if indexing xy
34 | return xx, yy
35 |
36 | @njit
37 | def calc_orbit(n_points, a, b, n_iter):
38 | """
39 | This function calculate orbits in a vectorized fashion.
40 |
41 | -n_points: lattice of initial conditions, n_points x n_points in [-1,1]x[-1,1]
42 | -a: first parameter of the dynamical system
43 | -b: second parameter of the dynamical system
44 | -n_iter: number of iterations
45 |
46 | Return: two ndarrays: x and y coordinates of every point of every orbit.
47 | """
48 | area = [[-1, 1], [-1, 1]]
49 | x = np.linspace(area[0][0], area[0][1], n_points)
50 | y = np.linspace(area[1][0], area[1][1], n_points)
51 | xx, yy = meshgrid(x, y)
52 | l_cx, l_cy = np.zeros(n_iter * n_points**2), np.zeros(n_iter * n_points**2)
53 | for i in range(n_iter):
54 | xx_new = np.sin(xx**2 - yy**2 + a)
55 | yy_new = np.cos(2 * xx * yy + b)
56 | xx = xx_new
57 | yy = yy_new
58 | l_cx[i * n_points**2 : (i + 1) * n_points**2] = xx.flatten()
59 | l_cy[i * n_points**2 : (i + 1) * n_points**2] = yy.flatten()
60 | return l_cx, l_cy
61 |
@wrap_matplotlib()
def plot_frame(n):
    """
    Render frame ``n`` as a log-scaled 2-D histogram of the orbit points.

    The map parameters drift with ``n`` and ``n`` iterations are computed.
    Reads the module-level ``n_points``, ``a`` and ``b`` assigned in the
    ``__main__`` block.
    """
    bounds = [[-1, 1], [-1, 1]]
    orbit_x, orbit_y = calc_orbit(n_points, a + 0.002 * n, b - 0.001 * n, n)
    counts, _, _ = np.histogram2d(orbit_x, orbit_y, bins=3000, range=bounds)

    fig, ax = plt.subplots(figsize=(10, 10))
    # Make the image fill the whole canvas.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
    ax.imshow(np.log(counts + 1), vmin=0, vmax=5, cmap="magma")
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
72 |
if __name__ == "__main__":
    # Module-level globals read by plot_frame.
    n_points = 500
    a, b = 5.48, 4.28
    frames = list(range(1, 100))  # frame numbers 1..99
    stream(frames, renderer=plot_frame, uri="nice_orbit.mp4")
77 | ```
78 |
--------------------------------------------------------------------------------
/streamjoy/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any
4 |
5 | import param
6 | from PIL import ImageDraw, ImageFont
7 |
8 | from . import _utils
9 |
10 |
class Paused(param.Parameterized):
    """
    A data model for pausing a stream.

    Pairs a rendered ``output`` with the number of ``seconds`` to hold it for.

    Expand Source code to see all the parameters and descriptions.
    """

    output = param.Parameter(doc="The output to pause for.")

    seconds = param.Number(doc="The number of seconds to pause for.")

    def __init__(self, output: Any, seconds: float, **params):
        # Pass the values through Parameterized.__init__ (like Serialized does)
        # so they are validated as part of normal construction. ``seconds`` is
        # a param.Number, so fractional values such as 2.8 are valid.
        super().__init__(output=output, seconds=seconds, **params)
26 |
27 |
class ImageText(param.Parameterized):
    """
    A data model for rendering text on an image.

    Parameters not given explicitly are passed through
    ``_utils.populate_config_defaults`` with the ``image_text`` config prefix
    (presumably filling them from the package config; verify in _utils).

    Expand Source code to see all the parameters and descriptions.
    """

    text = param.String(
        doc="The text to render.",
    )

    font = param.String(
        doc="The font to use for the text.",
    )

    size = param.Integer(
        doc="The font size to use for the text.",
    )

    color = param.String(
        doc="The color to use for the text.",
    )

    anchor = param.String(
        doc="The anchor to use for the text.",
    )

    x = param.Integer(
        doc="The x-coordinate to use for the text.",
    )

    y = param.Integer(
        doc="The y-coordinate to use for the text.",
    )

    kwargs = param.Dict(
        default={},
        doc="Additional keyword arguments to pass to the text renderer.",
    )

    def __init__(self, text: str, **params) -> None:
        params["text"] = text
        params = _utils.populate_config_defaults(
            params, self.param, config_prefix="image_text"
        )
        super().__init__(**params)

    def render(
        self,
        draw: ImageDraw.ImageDraw,
        width: int | None = None,
        height: int | None = None,
    ) -> None:
        """
        Draw the text onto ``draw``.

        Args:
            draw: The Pillow drawing context to render into.
            width: Image width; used to center horizontally when ``x`` is
                falsy (``None`` or ``0``).
            height: Image height; used to center vertically when ``y`` is
                falsy (``None`` or ``0``).
        """
        # A falsy coordinate (None or 0) falls back to the image center, so
        # width/height must be supplied in that case.
        x = self.x or width // 2
        y = self.y or height // 2
        try:
            font = ImageFont.truetype(self.font, self.size)
        except Exception:
            # Best effort: use Pillow's built-in font when the configured one
            # cannot be loaded.
            font = ImageFont.load_default()
        draw.text(
            (x, y),
            self.text,
            font=font,
            fill=self.color,
            anchor=self.anchor,
            **self.kwargs,
        )
95 |
96 |
class Serialized(param.Parameterized):
    """
    Resources bundled with the renderer (and its extra arguments) to apply to them.

    ``kwargs`` is kept as a plain instance attribute rather than a declared
    parameter.
    """

    resources = param.List(doc="The list of resources.")
    renderer = param.Callable(doc="The rendering function to use on the resources.")
    renderer_iterables = param.List(
        doc="Additional iterable arguments to pass to the renderer.",
        allow_None=True,
    )
    renderer_kwargs = param.Dict(
        doc="Additional keyword arguments to pass to the renderer.",
        allow_None=True,
    )

    def __init__(
        self, resources, renderer, renderer_iterables, renderer_kwargs, kwargs, **params
    ):
        declared = dict(
            resources=resources,
            renderer=renderer,
            renderer_iterables=renderer_iterables,
            renderer_kwargs=renderer_kwargs,
        )
        super().__init__(**declared, **params)
        # Not a declared parameter; stored as-is after initialization.
        self.kwargs = kwargs
120 |
--------------------------------------------------------------------------------
/tests/data/gapminder.csv:
--------------------------------------------------------------------------------
1 | ,Year,Country,fertility,life,population,child_mortality,gdp,region
2 | 0,1964,Bangladesh,6.853,49.297,56071080.0,240.4,1138.0,South Asia
3 | 1,1965,Bangladesh,6.877999999999999,49.652,57791778.0,236.0,1177.0,South Asia
4 | 2,1966,Bangladesh,6.901,49.734,59681186.0,232.3,1169.0,South Asia
5 | 3,1967,Bangladesh,6.92,49.50899999999999,61702785.0,229.4,1122.0,South Asia
6 | 4,1968,Bangladesh,6.935,49.008,63703327.0,227.1,1202.0,South Asia
7 | 5,1969,Bangladesh,6.945,48.307,65474474.0,225.4,1195.0,South Asia
8 | 6,1970,Bangladesh,6.947,47.58,66881158.0,224.1,1226.0,South Asia
9 | 7,1971,Bangladesh,6.942,47.04600000000001,67849312.0,223.0,1142.0,South Asia
10 | 8,1972,Bangladesh,6.928,46.874,68461849.0,222.0,986.0,South Asia
11 | 9,1973,Bangladesh,6.904,47.162,68933357.0,220.7,971.0,South Asia
12 | 50,1964,China,6.12,53.32072,696171650.0,130.77,713.0,East Asia & Pacific
13 | 51,1965,China,6.022,55.6468,710290299.0,115.43,772.0,East Asia & Pacific
14 | 52,1966,China,6.211,56.8032,727601056.0,120.8,826.0,East Asia & Pacific
15 | 53,1967,China,5.252000000000001,58.38112,747678772.0,126.42,719.0,East Asia & Pacific
16 | 54,1968,China,6.37,59.4052,769666505.0,132.3,669.0,East Asia & Pacific
17 | 55,1969,China,5.67,60.9652,792308749.0,119.1,732.0,East Asia & Pacific
18 | 56,1970,China,5.746,62.6508,814622841.0,113.3,848.0,East Asia & Pacific
19 | 57,1971,China,5.396,63.73736,836431505.0,107.7,876.0,East Asia & Pacific
20 | 58,1972,China,4.92,63.11888,857804035.0,102.1,843.0,East Asia & Pacific
21 | 59,1973,China,4.506,62.7808,878305091.0,96.5,894.0,East Asia & Pacific
22 | 100,1964,India,5.807,44.375,486038945.0,232.9,1125.0,South Asia
23 | 101,1965,India,5.781000000000001,45.141000000000005,496400381.0,229.6,1053.0,South Asia
24 | 102,1966,India,5.747000000000001,45.903,507115411.0,226.3,1037.0,South Asia
25 | 103,1967,India,5.7010000000000005,46.655,518192403.0,223.1,1096.0,South Asia
26 | 104,1968,India,5.642,47.394,529658233.0,219.8,1095.0,South Asia
27 | 105,1969,India,5.573,48.119,541544619.0,216.6,1141.0,South Asia
28 | 106,1970,India,5.494,48.836000000000006,553873890.0,213.3,1170.0,South Asia
29 | 107,1971,India,5.41,49.554,566651479.0,209.9,1154.0,South Asia
30 | 108,1972,India,5.323,50.281000000000006,579871075.0,206.1,1125.0,South Asia
31 | 109,1973,India,5.238,51.015,593526633.0,202.2,1151.0,South Asia
32 | 150,1964,South Africa,5.984,50.574,19308185.0,174.55,9004.0,Sub-Saharan Africa
33 | 151,1965,South Africa,5.9110000000000005,50.96,19813932.0,169.51,9255.0,Sub-Saharan Africa
34 | 152,1966,South Africa,5.836,51.348,20325179.0,164.61,9371.0,Sub-Saharan Africa
35 | 153,1967,South Africa,5.765,51.73,20843695.0,159.85,9720.0,Sub-Saharan Africa
36 | 154,1968,South Africa,5.7,52.104,21374801.0,155.24,9849.0,Sub-Saharan Africa
37 | 155,1969,South Africa,5.643,52.47,21926000.0,150.75,10156.0,Sub-Saharan Africa
38 | 156,1970,South Africa,5.591,52.825,22502306.0,146.4,10394.0,Sub-Saharan Africa
39 | 157,1971,South Africa,5.539,53.168,23106584.0,142.17,10654.0,Sub-Saharan Africa
40 | 158,1972,South Africa,5.482,53.504,23736249.0,138.06,10615.0,Sub-Saharan Africa
41 | 159,1973,South Africa,5.415,53.838,24384286.0,134.07,10813.0,Sub-Saharan Africa
42 | 200,1964,United States,3.222,70.33,197094531.0,27.7,20338.0,America
43 | 201,1965,United States,2.926,70.41,199452508.0,27.1,21361.0,America
44 | 202,1966,United States,2.714,70.43,201657141.0,26.4,22495.0,America
45 | 203,1967,United States,2.564,70.76,203717833.0,25.7,22803.0,America
46 | 204,1968,United States,2.467,70.42,205672498.0,24.9,23647.0,America
47 | 205,1969,United States,2.457,70.66,207573866.0,24.1,24147.0,America
48 | 206,1970,United States,2.461,70.92,209463865.0,23.3,23908.0,America
49 | 207,1971,United States,2.268,71.24,211355529.0,22.4,24350.0,America
50 | 208,1972,United States,2.008,71.34,213250350.0,21.5,25374.0,America
51 | 209,1973,United States,1.871,71.54,215164616.0,20.6,26567.0,America
52 |
--------------------------------------------------------------------------------
/docs/example_recipes/stream_code.md:
--------------------------------------------------------------------------------
1 | # Stream code
2 |
3 |
4 |
5 | Generates an animation of a code snippet being written character by character.
6 |
7 | Highlights:
8 |
9 | - Uses a custom `renderer` function to create each frame of the animation.
10 | - Passes `formatter`, `max_line_length`, and `max_line_number` through `stream` as extra keyword arguments to the custom `renderer` function.
11 |
12 | ```python hl_lines="51 102-104"
13 | from textwrap import dedent
14 |
15 | import numpy as np
16 | from PIL import Image, ImageDraw
17 | from pygments import lex
18 | from pygments.formatters import ImageFormatter
19 | from pygments.lexers import get_lexer_by_name
20 | from streamjoy import stream
21 |
def _custom_format(
    formatter: ImageFormatter,
    tokensource: list[tuple],
    max_line_length: int = None,
    max_line_number: int = None,
) -> np.ndarray:
    """
    Render ``tokensource`` with ``formatter`` onto a canvas of a fixed size.

    This mirrors pygments' ``ImageFormatter.format`` but lets the caller pin
    the canvas size, so every animation frame has the same dimensions. It
    relies on private ``ImageFormatter`` helpers and may break across
    pygments versions.

    Args:
        formatter: The pygments image formatter holding fonts and colors.
        tokensource: Lexed tokens, e.g. from ``pygments.lex``.
        max_line_length: Width budget for the longest line (the recipe passes
            a width from ``formatter.fonts.get_text_size``); defaults to the
            formatter's own measurement of the current tokens.
        max_line_number: Number of lines to reserve; defaults to the
            formatter's own line count.

    Returns:
        The rendered image as a numpy array.
    """
    formatter._create_drawables(tokensource)
    formatter._draw_line_numbers()
    max_line_length = max_line_length or formatter.maxlinelength
    max_line_number = max_line_number or formatter.maxlineno

    image = Image.new(
        "RGB",
        formatter._get_image_size(max_line_length, max_line_number),
        formatter.background_color,
    )
    formatter._paint_line_number_bg(image)
    draw = ImageDraw.Draw(image)
    # Highlight
    if formatter.hl_lines:
        x = (
            formatter.image_pad
            + formatter.line_number_width
            - formatter.line_number_pad
            + 1
        )
        recth = formatter._get_line_height()
        rectw = image.size[0] - x
        for linenumber in formatter.hl_lines:
            y = formatter._get_line_y(linenumber - 1)
            draw.rectangle([(x, y), (x + rectw, y + recth)], fill=formatter.hl_color)
    for pos, value, font, text_fg, text_bg in formatter.drawables:
        if text_bg:
            # ImageDraw.textsize was removed in Pillow 10; fall back to the
            # font's bounding box (right, bottom) there, as pygments does.
            if hasattr(draw, "textsize"):
                text_size = draw.textsize(text=value, font=font)
            else:
                text_size = font.getbbox(value)[2:]
            draw.rectangle(
                [pos[0], pos[1], pos[0] + text_size[0], pos[1] + text_size[1]],
                fill=text_bg,
            )
        draw.text(pos, value, font=font, fill=text_fg)
    return np.asarray(image)
62 |
def render_frame(
    code: str,
    formatter: ImageFormatter,
    max_line_length: int = None,
    max_line_number: int = None,
) -> np.ndarray:
    """
    Lex ``code`` as Python and render it into one animation frame.

    ``formatter``, ``max_line_length`` and ``max_line_number`` arrive as extra
    keyword arguments passed to ``stream`` and are forwarded to
    ``_custom_format``.

    Returns:
        The rendered frame as a numpy array (what ``_custom_format`` returns).
    """
    lexer = get_lexer_by_name("python")
    return _custom_format(
        formatter,
        lex(code, lexer),
        max_line_length=max_line_length,
        max_line_number=max_line_number,
    )
76 |
if __name__ == "__main__":
    code = dedent(
        """
    import matplotlib.pyplot as plt
    import numpy as np

    from streamjoy import stream, wrap_matplotlib

    @wrap_matplotlib()
    def plot_frame(i):
        x = np.linspace(0, 2, 1000)
        y = np.sin(2 * np.pi * (x - 0.01 * i))
        fig, ax = plt.subplots()
        ax.plot(x, y)
        return fig

    stream(list(range(10)), uri="sine_wave.mp4", renderer=plot_frame)
    """
    )

    formatter = ImageFormatter(
        image_format="gif",
        line_pad=8,
        line_number_bg=None,
        line_number_fg=None,
        encoding="utf-8",
    )

    # Size the canvas once from the finished code (plus some right-hand
    # padding) so every partial frame renders at the same dimensions.
    widest = max(code.splitlines(), key=len) + " " * 12
    max_line_length, _ = formatter.fonts.get_text_size(widest)
    max_line_number = code.count("\n") + 1

    step = 3  # characters revealed per frame
    items = [code[:end] for end in range(0, len(code) + step, step)]

    stream(
        items,
        ending_pause=20,
        uri="stream_code.gif",
        renderer=render_frame,
        formatter=formatter,
        max_line_length=max_line_length,
        max_line_number=max_line_number,
    )
118 | ```
119 |
--------------------------------------------------------------------------------
/tests/data/air.zarr/.zmetadata:
--------------------------------------------------------------------------------
1 | {
2 | "metadata": {
3 | ".zattrs": {
4 | "Conventions": "COARDS",
5 | "description": "Data is from NMC initialized reanalysis\n(4x/day). These are the 0.9950 sigma level values.",
6 | "platform": "Model",
7 | "references": "http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanalysis.html",
8 | "title": "4x daily NMC reanalysis (1948)"
9 | },
10 | ".zgroup": {
11 | "zarr_format": 2
12 | },
13 | "air/.zarray": {
14 | "chunks": [
15 | 20,
16 | 25,
17 | 53
18 | ],
19 | "compressor": {
20 | "blocksize": 0,
21 | "clevel": 5,
22 | "cname": "lz4",
23 | "id": "blosc",
24 | "shuffle": 1
25 | },
26 | "dtype": "
4 |
5 |
6 |
7 | Shows the yearly maximum CO2 concentration measured at the Mauna Loa Observatory in Hawaii.
8 |
9 | The data is sourced from the [datasets/co2-ppm-daily](https://github.com/datasets/co2-ppm-daily/blob/master/co2-ppm-daily-flow.py).
10 |
11 | Highlights:
12 |
13 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure.
14 | - Uses a custom `renderer` function to create each frame of the animation.
15 | - Uses `Paused` to pause the animation at notable dates.
16 |
17 | ```python hl_lines="4 19 114"
18 | import pandas as pd
19 | import matplotlib.pyplot as plt
20 | from matplotlib.ticker import AutoMinorLocator
21 | from streamjoy import stream, wrap_matplotlib, Paused
22 |
23 | URL = "https://raw.githubusercontent.com/datasets/co2-ppm-daily/master/data/co2-ppm-daily.csv"
24 | NOTABLE_YEARS = {
25 | 1958: "Mauna Loa measurements begin",
26 | 1979: "1st World Climate Conference",
27 | 1997: "Kyoto Protocol",
28 | 2005: "exceeded 380 ppm",
29 | 2010: "exceeded 390 ppm",
30 | 2013: "exceeded 400 ppm",
31 | 2015: "Paris Agreement",
32 | }
33 |
34 |
@wrap_matplotlib()
def renderer(df):
    """
    Draw one frame: the yearly-max CO2 curve up to the last row of ``df``.

    Returns ``Paused(ax, 2.8)`` when the latest year is in NOTABLE_YEARS so
    the animation lingers there; otherwise returns the axes.
    """
    plt.style.use("dark_background")
    background = "#1b1e23"

    fig, ax = plt.subplots(figsize=(7, 5))
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.set_frame_on(False)
    ax.axis("off")
    ax.set_title(
        "CO2 Yearly Max",
        fontsize=20,
        loc="left",
        fontname="Courier New",
        color="lightgrey",
    )

    # the curve so far
    df.plot(y="value", color="lightgrey", legend=False, ax=ax)

    # headline: highest value so far and the year it occurred
    peak_date = df["value"].idxmax()
    peak_co2 = df["value"].max()
    ax.text(
        0.0,
        0.92,
        f"{peak_co2:.0f} ppm",
        va="bottom",
        ha="left",
        transform=ax.transAxes,
        fontsize=25,
        color="lightgrey",
    )
    ax.text(
        0.0,
        0.91,
        f"Peaked in {peak_date.year}",
        va="top",
        ha="left",
        transform=ax.transAxes,
        fontsize=12,
        color="lightgrey",
        fontname="Courier New",
    )

    # latest point, labelled with its change from the previous year
    date = df.index[-1]
    co2 = df["value"].values[-1]
    change = df["diff"].fillna(0).values[-1]
    change_label = f"+{change:.0f}" if change >= 0 else f"{change:.0f}"
    ax.scatter(date, co2, color="red", zorder=999)
    ax.annotate(
        f"{change_label} ppm",
        (date, co2),
        textcoords="offset points",
        xytext=(-10, 5),
        fontsize=12,
        ha="right",
        va="bottom",
        color="lightgrey",
    )

    # data source
    ax.text(
        0.0,
        0.03,
        f"Source: {URL}",
        va="bottom",
        ha="left",
        transform=ax.transAxes,
        fontsize=8,
        color="lightgrey",
    )

    plt.subplots_adjust(bottom=0, top=0.9, right=0.9, left=0.05)

    # Year label under the latest point; notable years get a caption and a pause.
    year = date.year
    notable = year in NOTABLE_YEARS
    label = f"{NOTABLE_YEARS[year]} - {year}" if notable else year
    ax.annotate(
        label,
        (date, co2),
        textcoords="offset points",
        xytext=(-10, 3),
        fontsize=10.5,
        ha="right",
        va="top",
        color="lightgrey",
        fontname="Courier New",
    )
    return Paused(ax, 2.8) if notable else ax
145 |
146 |
if __name__ == "__main__":
    # Daily readings -> yearly maxima, gaps interpolated, plus the
    # year-over-year change used for the endpoint label.
    daily = pd.read_csv(URL, parse_dates=True, index_col="date")
    yearly = daily.resample("1YE").max().interpolate()
    yearly = yearly.assign(diff=lambda frame: frame["value"].diff())
    sj = stream(yearly, renderer=renderer, max_frames=-1, threads_per_worker=1)
    sj.write("co2_emissions.mp4")
158 | ```
--------------------------------------------------------------------------------
/docs/example_recipes/temperature_anomaly.md:
--------------------------------------------------------------------------------
1 | # Temperature anomaly
2 |
3 |
6 |
7 | Shows the global temperature anomaly from 1995 to 2024 using the HadCRUT5 dataset. The video pauses at notable dates.
8 |
9 | Highlights:
10 |
11 | - Uses `wrap_matplotlib` to automatically handle saving and closing the figure.
12 | - Uses a custom `renderer` function to create each frame of the animation.
13 | - Uses `Paused` to pause the animation at notable dates.
14 |
15 | ```python hl_lines="16 17 112"
16 | import pandas as pd
17 | import matplotlib.pyplot as plt
18 | from streamjoy import stream, wrap_matplotlib, Paused
19 |
20 | URL = "https://climexp.knmi.nl/data/ihadcrut5_global.dat"
21 | NOTABLE_DATES = {
22 | "1997-12": "Kyoto Protocol adopted",
23 | "2005-01": "Exceeded 380 ppm",
24 | "2010-01": "Exceeded 390 ppm",
25 | "2013-05": "Exceeded 400 ppm",
26 | "2015-12": "Paris Agreement signed",
27 | "2016-01": "CO2 permanently over 400 ppm",
28 | }
29 |
30 |
@wrap_matplotlib()
def renderer(df):
    """
    Draw one frame of the anomaly animation from the rows of ``df`` so far.

    ``df`` is indexed by day of year (``jday``) and has ``date``, ``year``,
    ``month`` and ``anom`` columns, as built by the script below. One line is
    drawn per year and the latest point is highlighted. When the latest
    ``%Y-%m`` date is in NOTABLE_DATES, the event is captioned and the figure
    is returned wrapped in ``Paused(fig, 3)``.
    """
    plt.style.use("dark_background")  # dark mode
    background = "#1b1e23"

    fig, ax = plt.subplots()
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.set_frame_on(False)
    ax.axis("off")

    title_year = df["year"].iloc[-1]
    ax.set_title(
        f"Global Temperature Anomaly {title_year} [HadCRUT5]",
        fontsize=15,
        loc="left",
        fontname="Courier New",
        color="lightgrey",
    )

    # one thin line per year
    df.groupby("year")["anom"].plot(
        y="anom", color="lightgrey", legend=False, ax=ax, lw=0.5
    )

    # data source, bottom left
    ax.text(
        0.01,
        0.05,
        f"Source: {URL}",
        va="bottom",
        ha="left",
        transform=ax.transAxes,
        fontsize=8,
        color="lightgrey",
        fontname="Courier New",
    )

    # latest point, labelled with its anomaly
    jday = df.index.values[-1]
    anom = df["anom"].values[-1]
    ax.scatter(jday, anom, color="red", zorder=999)
    sign = "+" if anom > 0 else ""
    ax.annotate(
        f"{sign}{anom:.1f} K",
        (jday, anom),
        textcoords="offset points",
        xytext=(-10, 5),
        fontsize=12,
        ha="right",
        va="bottom",
        color="lightgrey",
    )

    # year labels at the end of each completed (December) year among the last five
    year_ends = df.reset_index().groupby("year").last().iloc[-5:]
    completed = year_ends[year_ends["month"] == 12]
    for label_year, row in completed.iterrows():
        ax.annotate(
            label_year,
            (row["jday"], row["anom"]),
            fontsize=12,
            ha="left",
            va="center",
            color="lightgrey",
            fontname="Courier New",
        )

    plt.subplots_adjust(bottom=0, top=0.9, left=0.05)

    last_date = df["date"].iloc[-1]
    ax.annotate(
        last_date.strftime("%b"),
        (jday, anom),
        textcoords="offset points",
        xytext=(-10, 3),
        fontsize=12,
        ha="right",
        va="top",
        color="lightgrey",
        fontname="Courier New",
    )

    event = NOTABLE_DATES.get(last_date.strftime("%Y-%m"))
    if event is None:
        return fig
    ax.annotate(
        event,
        xy=(0, 1),
        xycoords="axes fraction",
        xytext=(0, -5),
        textcoords="offset points",
        fontsize=12,
        ha="left",
        va="top",
        color="lightgrey",
        fontname="Courier New",
    )
    return Paused(fig, 3)
129 |
130 |
if __name__ == "__main__":
    # Guarded like the other recipes, so that re-importing this file (e.g. if
    # `stream` spawns worker processes) does not re-download the data and
    # start another stream.
    df = (
        pd.read_csv(
            URL,
            comment="#",
            header=None,
            sep=r"\s+",  # raw string: "\s" is an invalid escape in a plain literal
            na_values=[-999.9],
        )
        .rename(columns={0: "year"})
        .melt(id_vars="year", var_name="month", value_name="anom")
    )
    # Zero-pad the month so "%Y%m" parses unambiguously (e.g. "199501").
    df.index = pd.to_datetime(
        df["year"].astype(str) + df["month"].astype(str).str.zfill(2),
        format="%Y%m",
    )
    df = df.sort_index().loc["1995":"2024"]
    df["jday"] = df.index.dayofyear
    df = df.rename_axis("date").reset_index().set_index("jday")
    # Cumulative frames: frame i holds the first i months. Positional slicing
    # is explicit because the jday index repeats every year.
    df_list = [df.iloc[:i] for i in range(1, len(df) + 1)]

    stream(df_list, renderer=renderer, threads_per_worker=1).write(
        "temperature_anomaly.mp4"
    )
153 | ```
--------------------------------------------------------------------------------
/streamjoy/renderers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING, Any
4 |
5 | from . import _utils
6 |
7 | if TYPE_CHECKING:
8 | try:
9 | import matplotlib.pyplot as plt
10 | except ImportError:
11 | plt = None
12 |
13 | try:
14 | import pandas as pd
15 | except ImportError:
16 | pd = None
17 |
18 | try:
19 | import polars as pl
20 | except ImportError:
21 | pl = None
22 |
23 | try:
24 | import xarray as xr
25 | except ImportError:
26 | xr = None
27 |
28 | try:
29 | import holoviews as hv
30 | except ImportError:
31 | hv = None
32 |
33 |
def default_pandas_renderer(
    df_sub: pd.DataFrame, *args: tuple[Any], **kwargs: dict[str, Any]
) -> plt.Figure:
    """
    Render a pandas DataFrame using matplotlib.

    The index is reset into a column first. If no ``title`` is given, the last
    value of the ``x`` column is used; a non-empty string title is treated as a
    format template filled from the last row (e.g. ``"{Year}"``). A
    ``groupby`` keyword draws one labelled series per group on shared axes.

    Args:
        df_sub: The DataFrame to render.
        *args: Additional positional arguments to pass to the renderer.
        **kwargs: Additional keyword arguments to pass to the renderer.

    Returns:
        A matplotlib figure.
    """
    import matplotlib.pyplot as plt

    df_sub = df_sub.reset_index()

    fig, ax = plt.subplots()

    title = kwargs.get("title")
    if title is None:
        title = df_sub[kwargs["x"]].iloc[-1]
    elif title:
        title = title.format(**df_sub.iloc[-1])
    kwargs["title"] = title

    groupby = kwargs.pop("groupby", None)
    if not groupby:
        df_sub.plot(*args, ax=ax, **kwargs)
        return fig

    for group, df_group in df_sub.groupby(groupby):
        df_group.plot(*args, ax=ax, label=group, **kwargs)
    return fig
69 |
70 |
def default_polars_renderer(
    df_sub: pl.DataFrame, *args: tuple[Any], **kwargs: dict[str, Any]
) -> hv.Element:
    """
    Render a polars DataFrame using HoloViews.

    ``backend`` and ``groupby`` are consumed here: ``groupby`` becomes the
    ``by`` plotting argument and ``backend`` is forwarded to
    ``default_holoviews_renderer``. The title is always converted to a string.

    Args:
        df_sub: The DataFrame to render.
        *args: Additional positional arguments to pass to the renderer.
        **kwargs: Additional keyword arguments to pass to the renderer.

    Returns:
        The rendered HoloViews Element.
    """
    backend = kwargs.pop("backend", None)
    groupby = kwargs.pop("groupby", None)

    title = kwargs.get("title")
    if title is None:
        title = df_sub[kwargs["x"]].tail(1)[0]
    elif title:
        last_record = df_sub.tail(1).to_pandas().to_dict("records")[0]
        title = title.format(**last_record)
    kwargs["title"] = str(title)

    if groupby:
        kwargs["by"] = groupby
    hv_obj = df_sub.plot(*args, **kwargs)
    return default_holoviews_renderer(hv_obj, backend=backend)
99 |
100 |
def default_xarray_renderer(
    da_sel: xr.DataArray, *args: tuple[Any], **kwargs: dict[str, Any]
) -> plt.Figure:
    """
    Render an xarray DataArray using matplotlib.

    ``subplot_kws`` is used to create the axes, and ``title`` (if given) is a
    format template filled from the values of the array's coordinates.

    Args:
        da_sel: The DataArray to render.
        *args: Additional positional arguments to pass to the renderer.
        **kwargs: Additional keyword arguments to pass to the renderer.

    Returns:
        A matplotlib figure.
    """
    import matplotlib.pyplot as plt

    da_sel = _utils.validate_xarray(da_sel, warn=False)

    fig = plt.figure()
    subplot_kws = kwargs.pop("subplot_kws", {})
    ax = plt.axes(**subplot_kws)
    title = kwargs.pop("title", None)

    try:
        # Prefer colorbar extension arrows; if the plot call rejects
        # ``extend`` (presumably for plot kinds without a colorbar), retry
        # without it.
        da_sel.plot(*args, ax=ax, extend="both", **kwargs)
    except Exception:
        da_sel.plot(*args, ax=ax, **kwargs)

    if title:
        coord_values = {name: da_sel[name].values for name in da_sel.coords}
        ax.set_title(title.format(**coord_values))

    return fig
133 |
134 |
def default_holoviews_renderer(
    hv_obj: hv.Element, *args: tuple[Any], **kwargs: dict[str, Any]
) -> hv.Element:
    """
    Render a HoloViews Element using the default backend.

    ``clims`` maps value-dimension names to color limits applied to matching
    elements. With the bokeh backend the toolbar is hidden; with matplotlib
    ``cbar_extend`` defaults to ``"both"``. Remaining keyword arguments are
    applied via ``.opts``.

    Args:
        hv_obj: The HoloViews Element to render.
        *args: Additional positional arguments to pass to the renderer.
        **kwargs: Additional keyword arguments to pass to the renderer.

    Returns:
        The rendered HoloViews Element.
    """
    import holoviews as hv

    # An explicit ``backend=None`` (which default_polars_renderer passes when
    # no backend was requested) must still resolve to the active backend;
    # otherwise the backend-specific defaults below are silently skipped.
    backend = kwargs.get("backend") or hv.Store.current_backend

    clims = kwargs.pop("clims", {})
    for hv_el in hv_obj.traverse(full_breadth=False):
        try:
            vdim = hv_el.vdims[0].name
        except IndexError:
            # element without value dimensions
            continue
        if vdim in clims:
            hv_el.opts(clim=clims[vdim], backend=backend)

    if backend == "bokeh":
        kwargs["toolbar"] = None
    elif backend == "matplotlib":
        kwargs["cbar_extend"] = kwargs.get("cbar_extend", "both")

    if isinstance(hv_obj, hv.Overlay):
        for hv_el in hv_obj:
            try:
                hv_el.opts(**kwargs)
            except Exception:
                # Best effort: not every overlaid element accepts every option.
                pass
    else:
        hv_obj.opts(**kwargs)

    return hv_obj
177 |
--------------------------------------------------------------------------------
/docs/supported_formats.md:
--------------------------------------------------------------------------------
1 | # Supported formats
2 |
3 | StreamJoy supports a variety of input types!
4 |
5 | ## 📋 List of Images, GIFs, Videos, or URLs
6 |
7 | ```python
8 | from streamjoy import stream
9 |
10 | URL_FMT = "https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/2024/01_Jan/N_202401{day:02d}_conc_v3.0.png"
11 |
12 | stream([URL_FMT.format(day=day) for day in range(1, 31)], uri="2024_jan_sea_ice.mp4")
13 | ```
14 |
17 |
18 | ## 📁 Directory of Images, GIFs, Videos, or URLs
19 |
20 | ```python
21 | from streamjoy import stream
22 |
23 | URL_DIR = "https://downloads.psl.noaa.gov/Datasets/ncep.reanalysis/Dailies/surface/"
24 |
25 | stream(URL_DIR, uri="air_temperature.mp4", pattern="air.sig995.194*.nc")
26 | ```
27 |
28 |
31 |
32 | ## 🧮 Numpy NdArray
33 |
34 | ```python
35 | from streamjoy import stream
36 | import imageio.v3 as iio
37 |
38 | array = iio.imread("imageio:newtonscradle.gif") # is a 4D numpy array
39 | stream(array, max_frames=-1).write("newtonscradle.mp4")
40 | ```
41 |
42 |
45 |
46 | ## 🐼 Pandas DataFrame or Series
47 |
48 | !!! note "Additional Requirements"
49 |
50 | You will need to additionally install `pandas` and `matplotlib` to support this format:
51 |
52 | ```bash
53 | pip install pandas matplotlib
54 | ```
55 |
56 | ```python
57 | from streamjoy import stream
58 | import pandas as pd
59 |
60 | df = pd.read_csv(
61 | "https://raw.githubusercontent.com/franlopezguzman/gapminder-with-bokeh/master/gapminder_tidy.csv"
62 | ).set_index("Year")
63 | df = df.query("Country in ['United States', 'China', 'South Africa']")
64 | stream(df, uri="gapminder.mp4", groupby="Country", title="{Year}")
65 | ```
66 |
67 |
70 |
71 | ## 🐻‍❄️ Polars DataFrame
72 |
73 | !!! note "Additional Requirements"
74 |
75 | You will need to additionally install `polars`, `pyarrow`, `hvplot`, `selenium`, and `webdriver-manager` to support this format:
76 |
77 | ```bash
78 | pip install polars pyarrow hvplot selenium webdriver-manager
79 | ```
80 |
81 | You must also have `firefox` or `chromedriver` installed on your system.
82 |
83 | ```bash
84 | conda install -c conda-forge firefox
85 | ```
86 |
87 | ```python
88 | from streamjoy import stream
89 | import polars as pl
90 |
91 | df = pl.read_csv(
92 | "https://raw.githubusercontent.com/franlopezguzman/gapminder-with-bokeh/master/gapminder_tidy.csv"
93 | )
94 | df = df.filter(pl.col("Country").is_in(['United States', 'China', 'South Africa']))
95 | stream(df, uri="gapminder.mp4", groupby="Country", title="{Year}")
96 | ```
97 |
98 |
101 |
102 | ## 🗄️ XArray Dataset or DataArray
103 |
104 | !!! note "Additional Requirements"
105 |
106 | You will need to additionally install `xarray` and `matplotlib` to support this format:
107 |
108 | ```bash
109 | pip install xarray matplotlib
110 | ```
111 |
112 | For this example, you will also need to install `pooch` and `netcdf4`:
113 |
114 | ```bash
115 | pip install pooch netcdf4
116 | ```
117 |
118 | ```python
119 | from streamjoy import stream
120 | import xarray as xr
121 |
122 | ds = xr.tutorial.open_dataset("air_temperature").isel(time=slice(0, 100))
123 | stream(ds, uri="air.mp4", cmap="RdBu_r")
124 | ```
125 |
126 |
129 |
130 | ## 📊 HoloViews HoloMap or DynamicMap
131 |
132 | !!! note "Additional Requirements"
133 |
134 | You will need to additionally install `holoviews` to support this format:
135 |
136 | ```bash
137 | pip install holoviews
138 | ```
139 |
140 | For the bokeh backend, you will need to install `bokeh`, `selenium`, and `webdriver-manager`:
141 |
142 | ```bash
143 | pip install bokeh selenium webdriver-manager
144 | ```
145 |
146 | For the matplotlib backend, you will need to install `matplotlib`:
147 |
148 | ```bash
149 | pip install matplotlib
150 | ```
151 |
152 | For this example, you will also need to install `pooch`, `netcdf4`, `hvplot`, and `xarray`:
153 |
154 | ```bash
155 | pip install pooch netcdf4 hvplot xarray
156 | ```
157 |
158 | You must also have `firefox` or `chromedriver` installed on your system.
159 |
160 | ```bash
161 | conda install -c conda-forge firefox
162 | ```
163 |
164 | ```python
165 | import xarray as xr
166 | import hvplot.xarray
167 | from streamjoy import stream
168 |
169 | ds = xr.tutorial.open_dataset("rasm").isel(time=slice(10))
170 | stream(ds.hvplot.image("x", "y"), uri="rasm.mp4")
171 | ```
172 |
173 |
176 |
--------------------------------------------------------------------------------
/tests/test_streams.py:
--------------------------------------------------------------------------------
1 | import panel as pn
2 | import pytest
3 | from imageio.v3 import improps
4 |
5 | from streamjoy.models import Paused
6 | from streamjoy.streams import GifStream, HtmlStream, Mp4Stream
7 | from streamjoy.wrappers import wrap_matplotlib
8 |
9 |
10 | class AbstractTestMediaStream:
11 | def _assert_stream_and_props(self, sj, stream_cls, max_frames=3):
12 | assert isinstance(sj, stream_cls)
13 | buf = sj.write()
14 | props = improps(buf)
15 | props.n_images == max_frames
16 | return props
17 |
18 | def test_from_numpy(self, stream_cls, array):
19 | sj = stream_cls.from_numpy(array)
20 | self._assert_stream_and_props(sj, stream_cls)
21 |
22 | def test_from_pandas(self, stream_cls, df):
23 | sj = stream_cls.from_pandas(df)
24 | self._assert_stream_and_props(sj, stream_cls)
25 |
26 | def test_from_polars(self, stream_cls, pl_df):
27 | sj = stream_cls.from_polars(pl_df)
28 | self._assert_stream_and_props(sj, stream_cls)
29 |
30 | def test_from_xarray(self, stream_cls, ds):
31 | sj = stream_cls.from_xarray(ds)
32 | self._assert_stream_and_props(sj, stream_cls)
33 |
34 | def test_from_holoviews_hmap(self, stream_cls, hmap):
35 | sj = stream_cls.from_holoviews(hmap)
36 | self._assert_stream_and_props(sj, stream_cls)
37 |
38 | def test_from_holoviews_dmap(self, stream_cls, dmap):
39 | sj = stream_cls.from_holoviews(dmap)
40 | self._assert_stream_and_props(sj, stream_cls)
41 |
42 | def test_from_url_dir(self, stream_cls):
43 | sj = stream_cls.from_url(
44 | "https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/1978/10_Oct/",
45 | pattern="N_197810*_conc_v3.0.png",
46 | )
47 | self._assert_stream_and_props(sj, stream_cls)
48 |
49 | def test_from_url_path(self, stream_cls):
50 | sj = stream_cls.from_url(
51 | "https://github.com/ahuang11/streamjoy/raw/main/tests/data/gapminder.parquet",
52 | )
53 | self._assert_stream_and_props(sj, stream_cls)
54 |
55 | def test_from_directory(self, stream_cls, data_dir):
56 | sj = stream_cls.from_directory(data_dir, pattern="*.png")
57 | self._assert_stream_and_props(sj, stream_cls)
58 |
59 | def test_fsspec_fs(self, stream_cls, df, fsspec_fs):
60 | sj = stream_cls.from_pandas(df, fsspec_fs=fsspec_fs)
61 | self._assert_stream_and_props(sj, stream_cls)
62 |
def test_holoviews_matplotlib_backend(self, stream_cls, ds):
    """An hvplot rendered with matplotlib (fig_size=200) has shape[1] == 300."""
    plot = ds.hvplot("lon", "lat", fig_size=200, backend="matplotlib")
    props = self._assert_stream_and_props(stream_cls.from_holoviews(plot), stream_cls)
    assert props.shape[1] == 300
69 |
def test_holoviews_bokeh_backend(self, stream_cls, ds):
    """An hvplot rendered with bokeh (width=300) has shape[1] == 300."""
    plot = ds.hvplot("lon", "lat", width=300, backend="bokeh")
    props = self._assert_stream_and_props(stream_cls.from_holoviews(plot), stream_cls)
    assert props.shape[1] == 300
76 |
def test_write_max_frames(self, stream_cls, df):
    """``max_frames`` given at construction is forwarded to the shared checks."""
    frame_limit = 3
    animation = stream_cls.from_pandas(df, max_frames=frame_limit)
    self._assert_stream_and_props(animation, stream_cls, max_frames=frame_limit)
80 |
81 |
class TestGifStream(AbstractTestMediaStream):
    """Run the shared media-stream tests against ``GifStream``."""

    @pytest.fixture(scope="class")
    def stream_cls(self):
        return GifStream

    def test_paused(self, stream_cls, df):
        """A renderer returning ``Paused`` output yields a GIF with 3 images."""

        @wrap_matplotlib()
        def renderer(df, groupby=None):  # TODO: fix bug groupby not needed
            return Paused(df.plot(), seconds=2)

        gif_buffer = stream_cls.from_pandas(df, renderer=renderer).write()
        assert improps(gif_buffer).n_images == 3
95 |
96 |
class TestMp4Stream(AbstractTestMediaStream):
    """Run the shared media-stream tests against ``Mp4Stream``."""

    @pytest.fixture(scope="class")
    def stream_cls(self):
        return Mp4Stream

    def test_paused(self, stream_cls, df):
        """A renderer returning ``Paused`` output yields an MP4 with 9 images."""

        @wrap_matplotlib()
        def renderer(df, groupby=None):  # TODO: fix bug groupby not needed
            return Paused(df.plot(), seconds=2)

        mp4_buffer = stream_cls.from_pandas(df, renderer=renderer).write()
        assert improps(mp4_buffer).n_images == 9
110 |
111 |
class TestHtmlStream(AbstractTestMediaStream):
    """Run the shared media-stream tests against ``HtmlStream``.

    HTML output is a Panel layout rather than an encoded image, so the shared
    checks are overridden to inspect the layout structure instead.
    """

    def _assert_stream_and_props(self, sj, stream_cls, max_frames=3):
        """Check the written layout: a Column of (Tabs of Image panes, Player).

        Returns the first ``pn.pane.Image`` so callers can inspect its sizing.
        """
        assert isinstance(sj, stream_cls)
        layout = sj.write()
        assert isinstance(layout, pn.Column)
        frame_tabs = layout[0]
        assert isinstance(frame_tabs, pn.Tabs)
        first_image = frame_tabs[0]
        assert isinstance(first_image, pn.pane.Image)
        assert len(frame_tabs) == max_frames
        assert isinstance(layout[1], pn.widgets.Player)
        return first_image

    @pytest.fixture(scope="class")
    def stream_cls(self):
        return HtmlStream

    def test_holoviews_matplotlib_backend(self, stream_cls, ds):
        """Matplotlib-rendered frames leave the Image pane width unset."""
        plot = ds.hvplot("lon", "lat", fig_size=200, backend="matplotlib")
        first_image = self._assert_stream_and_props(
            stream_cls.from_holoviews(plot), stream_cls
        )
        assert first_image.width is None

    def test_holoviews_bokeh_backend(self, stream_cls, ds):
        """Bokeh-rendered frames leave the Image pane width unset."""
        plot = ds.hvplot("lon", "lat", width=300, backend="bokeh")
        first_image = self._assert_stream_and_props(
            stream_cls.from_holoviews(plot), stream_cls
        )
        assert first_image.width is None

    def test_fixed_width_height(self, stream_cls, df):
        """Fixed sizing propagates width and height to the Image pane."""
        animation = stream_cls.from_pandas(
            df, width=300, height=300, sizing_mode="fixed"
        )
        first_image = self._assert_stream_and_props(animation, stream_cls)
        assert (first_image.width, first_image.height) == (300, 300)

    def test_stretch_width(self, stream_cls, df):
        """With stretch_width sizing, the pane gets neither width nor height."""
        animation = stream_cls.from_pandas(
            df, height=300, sizing_mode="stretch_width"
        )
        first_image = self._assert_stream_and_props(animation, stream_cls)
        assert first_image.width is None
        assert first_image.height is None
155 |
--------------------------------------------------------------------------------
/streamjoy/wrappers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import time
5 | from functools import wraps
6 | from io import BytesIO
7 | from pathlib import Path
8 | from typing import Any, Callable
9 |
10 | from . import _utils
11 | from .models import Paused
12 | from .settings import config
13 |
14 |
def wrap_matplotlib(
    in_memory: bool = False,
    scratch_dir: str | Path | None = None,
    fsspec_fs: Any | None = None,
) -> Callable:
    """
    Wraps a function used to render a matplotlib figure so that
    it automatically saves the figure and closes it.

    The wrapped renderer may return a ``Figure``, an ``Axes`` (its parent
    figure is saved), or a ``Paused`` wrapping either. The wrapped function
    returns the saved JPEG location from ``_utils.resolve_uri``. When the
    renderer returned ``Paused``, that location is itself wrapped in a
    ``Paused`` with the same ``seconds``.

    Args:
        in_memory: Whether to render the figure in-memory.
        scratch_dir: The scratch directory to use.
        fsspec_fs: The fsspec filesystem to use.

    Returns:
        The wrapped function.

    Raises:
        ValueError: Raised by the wrapped function if the renderer does not
            return a Figure or Axes object.
    """

    def wrapper(renderer):
        @wraps(renderer)
        def wrapped(*args, **kwargs) -> Path | BytesIO:
            import uuid

            import matplotlib

            # Non-interactive backend: frames are only ever saved, never shown.
            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            plt.rcParams.update({"figure.max_open_warning": config["max_open_warning"]})

            output = renderer(*args, **kwargs)

            return_paused = isinstance(output, Paused)
            fig = output.output if return_paused else output

            if isinstance(fig, plt.Axes):
                fig = fig.figure
            elif not isinstance(fig, plt.Figure):
                raise ValueError("Renderer must return a Figure or Axes object.")

            try:
                # Use a random name rather than hash(fig). hash() is id-based,
                # and once an earlier frame's figure is closed and collected, a
                # later figure can reuse its id and overwrite that frame's file.
                uri = _utils.resolve_uri(
                    file_name=f"{uuid.uuid4().hex}.jpg",
                    scratch_dir=scratch_dir,
                    in_memory=in_memory,
                    fsspec_fs=fsspec_fs,
                )
                if fsspec_fs:
                    # Render fully before opening the remote file so a failed
                    # render does not leave an empty object behind.
                    buf = BytesIO()
                    fig.savefig(buf, format="jpg")
                    with fsspec_fs.open(uri, "wb") as f:
                        f.write(buf.getvalue())
                else:
                    fig.savefig(uri, format="jpg")
            finally:
                # Always release the figure, even when saving fails, so
                # failing frames don't accumulate open figures.
                plt.close(fig)

            if return_paused:
                return Paused(output=uri, seconds=output.seconds)
            return uri

        return wrapped

    return wrapper
78 |
79 |
def wrap_holoviews(
    in_memory: bool = False,
    scratch_dir: str | Path | None = None,
    fsspec_fs: Any | None = None,
    webdriver: str | Callable | None = None,
    num_retries: int | None = None,
) -> Callable:
    """
    Wraps a function used to render a holoviews object so that
    it automatically saves the object.

    The wrapped renderer may return a holoviews object or a ``Paused``
    wrapping one. The wrapped function returns the saved PNG location from
    ``_utils.resolve_uri``, wrapped in a ``Paused`` with the same ``seconds``
    when the renderer returned ``Paused``.

    How the image is produced depends on the backend:

    - bokeh: captured as a webdriver screenshot and retried on failure;
    - other backends: written with ``hv.save``.

    Args:
        in_memory: Whether to render the object in-memory. Not supported;
            passing True raises ValueError.
        scratch_dir: The scratch directory to use.
        fsspec_fs: The fsspec filesystem to use.
        webdriver: The webdriver to use (a name or a callable); falls back to
            the ``webdriver`` config value.
        num_retries: Number of screenshot attempts for the bokeh backend;
            falls back to the ``num_retries`` config value. At least one
            attempt is always made.

    Returns:
        The wrapped function.

    Raises:
        ValueError: If ``in_memory`` is True.
    """

    webdriver = _utils.get_config_default("webdriver", webdriver, warn=False)
    if isinstance(webdriver, str):
        # Resolve the driver path once, at decoration time, not per frame.
        webdriver = (webdriver, _utils.get_webdriver_path(webdriver))

    if in_memory:
        raise ValueError("Holoviews renderer does not support in-memory rendering.")

    def wrapper(renderer):
        @wraps(renderer)
        def wrapped(*args, **kwargs) -> Path | BytesIO:
            import uuid

            import holoviews as hv

            # An explicit ``backend`` kwarg wins over holoviews' current backend.
            backend = kwargs.get("backend", hv.Store.current_backend)
            output = renderer(*args, **kwargs)

            return_paused = isinstance(output, Paused)
            hv_obj = output.output if return_paused else output

            # Use a random name rather than hash(hv_obj). hash() is id-based, so
            # a later frame could reuse a collected object's id and overwrite
            # an earlier frame's file.
            uri = _utils.resolve_uri(
                file_name=f"{uuid.uuid4().hex}.png",
                scratch_dir=scratch_dir,
                in_memory=in_memory,
                fsspec_fs=fsspec_fs,
            )
            if backend == "bokeh":
                from bokeh.io.export import get_screenshot_as_png

                retries = _utils.get_config_default(
                    "num_retries", num_retries, warn=False
                )
                # A non-positive setting would otherwise skip saving entirely
                # and return the path of a file that was never written.
                attempts = max(1, retries)
                for attempt in range(attempts):
                    try:
                        driver = _utils.get_webdriver(webdriver)
                        with driver:
                            image = get_screenshot_as_png(
                                hv.render(hv_obj, backend=backend), driver=driver
                            )
                        if fsspec_fs:
                            with fsspec_fs.open(uri, "wb") as f:
                                image.save(f, format="png")
                        else:
                            image.save(uri, format="png")
                        break
                    except Exception as e:
                        # Out of attempts: re-raise right away instead of
                        # announcing (and sleeping for) a retry that won't happen.
                        if attempt == attempts - 1:
                            raise
                        delay = attempt * 2  # linear backoff: 0s, 2s, 4s, ...
                        logging.warning(
                            f"Failed to save image: {e}, retrying in {delay}s"
                        )
                        time.sleep(delay)
            elif fsspec_fs:
                buf = BytesIO()
                # Pass the backend here too, matching the local-file branch.
                hv.save(hv_obj, buf, fmt="png", backend=backend)
                with fsspec_fs.open(uri, "wb") as f:
                    f.write(buf.getvalue())
            else:
                hv.save(hv_obj, uri, fmt="png", backend=backend)

            if return_paused:
                return Paused(output=uri, seconds=output.seconds)
            return uri

        return wrapped

    return wrapper
172 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌈 StreamJoy 😊
2 |
3 | ---
4 |
5 | [](https://github.com/ahuang11/streamjoy/actions) [](https://codecov.io/gh/ahuang11/streamjoy) [](https://badge.fury.io/py/streamjoy)
6 |
7 | [](https://pepy.tech/project/streamjoy) [](https://img.shields.io/github/stars/ahuang11/streamjoy?style=flat-square)
8 |
9 | ---
10 |
11 | ## 🔥 Enjoy animating!
12 |
13 | StreamJoy turns your images into animations, using sensible defaults for fun, hassle-free creation.
14 |
15 | It cuts down the boilerplate and the time spent working on animations, and it's simple to start with just a few lines of code.
16 |
17 | Install it with just pip to start, blazingly fast!
18 |
19 | ```bash
20 | pip install streamjoy
21 | ```
22 |
23 |
24 |
25 | Or, try out a basic web app version here:
26 |
27 | https://huggingface.co/spaces/ahuang11/streamjoy
28 |
29 |
30 |
31 | ## 🛠️ Built-in features
32 |
33 | - 🌐 Animate from URLs, files, and datasets
34 | - 🎨 Render images with default or custom renderers
35 | - 🎬 Provide context with a short intro splash
36 | - ⏸ Add pauses at the beginning, end, or between frames
37 | - ⚡ Execute read, render, and write in parallel
38 | - 🔗 Connect multiple animations together
39 |
40 | ## 🚀 Quick start
41 |
42 | ### 🐤 Absolute basics
43 |
44 | Stream from a list of images--local files work too!
45 |
46 | ```python
47 | from streamjoy import stream
48 |
49 | if __name__ == "__main__":
50 | URL_FMT = "https://www.cpc.ncep.noaa.gov/products/NMME/archive/2025090800/current/images/NMME_ensemble_tmpsfc_lead{i}.png"
51 | resources = [URL_FMT.format(i=i) for i in range(1, 8)]
52 | stream(resources, uri="nmme.gif") # .gif, .mp4, and .html supported
53 | ```
54 |
55 | 
56 |
57 | ### 💅 Polish up
58 |
59 | Specify a few more keywords to:
60 |
61 | 1. add an intro title and subtitle
62 | 2. adjust the pauses
63 | 3. optimize the GIF through pygifsicle
64 |
65 | Note: this example no longer works because the source URL has changed.
66 |
67 | ```python
68 | from streamjoy import stream
69 |
70 | if __name__ == "__main__":
71 | URL_FMT = "https://www.goes.noaa.gov/dimg/jma/fd/vis/{i}.gif"
72 | resources = [URL_FMT.format(i=i) for i in range(1, 11)]
73 | himawari_stream = stream(
74 | resources,
75 | uri="goes_custom.gif",
76 | intro_title="Himawari Visible",
77 | intro_subtitle="10 Hours Loop",
78 | intro_pause=1,
79 | ending_pause=1,
80 | optimize=True,
81 | )
82 | ```
83 |
84 |
85 |
86 | ### 👀 Preview inputs
87 |
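88 | If you'd like to preview the `repr` before writing, drop `uri`; for example, rerun the previous snippet without `uri="goes_custom.gif"` and display `himawari_stream`.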
89 |
90 | Note: this example no longer works because the source URL has changed.
91 |
92 | Output:
93 | ```yaml
94 |
95 | ---
96 | Output:
97 | max_frames: 50
98 | fps: 10
99 | display: True
100 | scratch_dir: streamjoy_scratch
101 | in_memory: False
102 | ---
103 | Intro:
104 | intro_title: Himawari Visible
105 | intro_subtitle: 10 Hours Loop
106 | intro_watermark: made with streamjoy
107 | intro_pause: 1
108 | intro_background: black
109 | ---
110 | Client:
111 | batch_size: 10
112 | processes: True
113 | threads_per_worker: None
114 | ---
115 | Resources: (10 frames to stream)
116 | https://www.goes.noaa.gov/dimg/jma/fd/vis/1.gif
117 | ...
118 | https://www.goes.noaa.gov/dimg/jma/fd/vis/10.gif
119 | ---
120 | ```
121 |
122 | Then, when ready, call the `write` method to save the animation!
123 |
124 | ```python
125 | himawari_stream.write()
126 | ```
127 |
128 | ### 🖇️ Connect streams
129 |
130 | Connect multiple streams together to provide further context.
131 |
132 | Note: this example no longer works because the source URLs have changed.
133 |
134 | ```python
135 | from streamjoy import stream, connect
136 |
137 | URL_FMTS = {
138 | "visible": "https://www.goes.noaa.gov/dimg/jma/fd/vis/{i}.gif",
139 | "infrared": "https://www.goes.noaa.gov/dimg/jma/fd/rbtop/{i}.gif",
140 | }
141 |
142 | if __name__ == "__main__":
143 | visible_stream = stream(
144 | [URL_FMTS["visible"].format(i=i) for i in range(1, 11)],
145 | intro_title="Himawari Visible",
146 | intro_subtitle="10 Hours Loop",
147 | )
148 | infrared_stream = stream(
149 | [URL_FMTS["infrared"].format(i=i) for i in range(1, 11)],
150 | intro_title="Himawari Infrared",
151 | intro_subtitle="10 Hours Loop",
152 | )
153 | connect([visible_stream, infrared_stream], uri="goes_connected.gif")
154 | ```
155 |
156 |
157 |
158 | ### 📷 Render datasets
159 |
160 | You can also render images directly from datasets, either through a custom renderer or a built-in one, and they'll also run in parallel!
161 |
162 | The following example requires xarray, cartopy, matplotlib, and netcdf4.
163 |
164 | ```bash
165 | pip install xarray cartopy matplotlib netcdf4
166 | ```
167 |
168 | ```python
169 | import numpy as np
170 | import cartopy.crs as ccrs
171 | import matplotlib.pyplot as plt
172 | from streamjoy import stream, wrap_matplotlib
173 |
174 | @wrap_matplotlib()
175 | def plot(da, central_longitude, **plot_kwargs):
176 | time = da["time"].dt.strftime("%b %d %Y").values.item()
177 | projection = ccrs.Orthographic(central_longitude=central_longitude)
178 | subplot_kw = dict(projection=projection, facecolor="gray")
179 | fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=subplot_kw)
180 | im = da.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False, **plot_kwargs)
181 | ax.set_title(f"Sea Surface Temperature Anomaly\n{time}", loc="left", transform=ax.transAxes)
182 | ax.set_title("Source: NOAA OISST v2.1", loc="right", size=5, y=-0.01)
183 | ax.set_title("", loc="center") # suppress default title
184 | plt.colorbar(im, ax=ax, label="°C", shrink=0.8)
185 | return fig
186 |
187 | if __name__ == "__main__":
188 | url = (
189 | "https://www.ncei.noaa.gov/data/sea-surface-temperature-"
190 | "optimum-interpolation/v2.1/access/avhrr/201008/"
191 | )
192 | pattern = "oisst-avhrr-v02r01.*.nc"
193 | stream(
194 | url,
195 | uri="oisst.gif",
196 | pattern=pattern, # GifStream.from_url kwargs
197 | max_files=30,
198 | renderer=plot, # renderer related kwargs
199 | renderer_iterables=[np.linspace(-140, -150, 30)], # iterables; central longitude per frame (30 frames)
200 | renderer_kwargs=dict(cmap="RdBu_r", vmin=-5, vmax=5), # renderer kwargs
201 | # cmap="RdBu_r", # renderer_kwargs can also be propagated for convenience
202 | # vmin=-5,
203 | # vmax=5,
204 | )
205 | ```
206 |
207 |
208 |
209 | Check out all the supported formats [here](https://ahuang11.github.io/streamjoy/supported_formats/) or best practices [here](https://ahuang11.github.io/streamjoy/how_do_i/). (Or, if you're interested in the design, see [here](https://ahuang11.github.io/streamjoy/package_design/).)
210 |
211 | ---
212 |
213 | ❤️ Made with considerable passion.
214 |
215 | 🌟 Appreciate the project? Consider giving a star!
216 |
217 |
--------------------------------------------------------------------------------
/docs/how_do_i.md:
--------------------------------------------------------------------------------
1 | # How do I...
2 |
3 | ## 🖼️ Use all resources with `max_frames=-1`
4 |
5 | By default, StreamJoy only renders the first 50 frames to prevent accidentally rendering a large dataset.
6 |
7 | To render all frames, set `max_frames=-1`.
8 |
9 | ```python
10 | from streamjoy import stream
11 |
12 | stream(..., max_frames=-1)
13 | ```
14 |
15 | ## ⏸️ Pause animations with `Paused`, `intro_pause`, and `ending_pause`
16 |
17 | Animations can be good, but sometimes you want to pause at various points of the animation to provide context or to emphasize a point.
18 |
19 | To pause at a given frame using a custom `renderer`, wrap `Paused` around the output:
20 |
21 | ```python
22 | from streamjoy import Paused, stream
23 |
24 | def plot_frame(time, important_time=...):
25 | fig = ...
26 | if time == important_time:
27 | return Paused(fig, seconds=3)
28 | else:
29 | return fig
30 |
31 | stream(..., renderer=plot_frame)
32 | ```
33 |
34 | Don't forget there's also `intro_pause` and `ending_pause` to pause at the beginning and end of the animation!
35 |
36 | ## 📊 Reduce boilerplate code with `wrap_*` decorators
37 |
38 | If you're using a custom `renderer`, you can use `wrap_matplotlib` and `wrap_holoviews` to automatically handle saving and closing the figure.
39 |
40 | ```python
41 | from streamjoy import stream, wrap_matplotlib
42 |
43 | @wrap_matplotlib()
44 | def plot_frame(time):
45 | ...
46 |
47 | stream(..., renderer=plot_frame)
48 | ```
49 |
50 | ## 🗣️ Provide context with `intro_title` and `intro_subtitle`
51 |
52 | Use `intro_title` and `intro_subtitle` to provide context at the beginning of the animation.
53 |
54 | ```python
55 | from streamjoy import stream
56 |
57 | stream(..., intro_title="Himawari Visible", intro_subtitle="10 Hours Loop")
58 | ```
59 |
60 | ## 💾 Write animation to memory instead of file
61 |
62 | If you're just testing out the animation, you can save it to memory instead of to disk by calling `write` without specifying a `uri`.
63 |
64 | ```python
65 | from streamjoy import stream
66 |
67 | stream(...).write()
68 | ```
69 |
70 | ## 🚪 Use as a method of `pandas` and `xarray` objects
71 |
72 | StreamJoy can be used directly from `pandas` and `xarray` objects as an accessor.
73 |
74 | ```python
75 | import pandas as pd
76 | import streamjoy.pandas
77 |
78 | df = pd.DataFrame(...)
79 |
80 | # equivalent to streamjoy.stream(df)
81 | df.streamjoy(...)
82 |
83 | # series also works!
84 | df["col"].streamjoy(...)
85 | ```
86 |
87 | ```python
88 | import xarray as xr
89 | import streamjoy.xarray
90 |
91 | ds = xr.Dataset(...)
92 |
93 | # equivalent to streamjoy.stream(ds)
94 | ds.streamjoy(...)
95 |
96 | # dataarray also works!
97 | ds["var"].streamjoy(...)
98 | ```
99 |
100 | ## ⛓️ Join streams with `sum` and `connect`
101 |
102 | Use `sum` to join homogeneous streams, i.e. streams that have the same keyword arguments.
103 |
104 | ```python
105 | from streamjoy import stream, sum
106 |
107 | sum([stream(..., **same_kwargs) for i in range(10)])
108 | ```
109 |
110 | Use `connect` to join heterogeneous streams, i.e. streams that have different keyword arguments, like different `intro_title` and `intro_subtitle`.
111 |
112 | ```python
113 | from streamjoy import stream, connect
114 |
115 | connect([stream(..., **kwargs1), stream(..., **kwargs2)])
116 | ```
117 |
118 | ## 🎥 Decide between writing as `.mp4` vs `.gif`
119 |
120 | If you need a comprehensive color palette, use `.mp4` as it supports more colors.
121 |
122 | For automatic playing and looping, use `.gif`. To reduce the file size of the `.gif`, set `optimize=True`, which uses `pygifsicle` to reduce the file size.
123 |
124 | ## 📦 Prevent `RuntimeError` by using `__name__ == "__main__"`
125 |
126 | If you run a `.py` script without it, you might encounter the following `RuntimeError`:
127 |
128 | ```text
129 | RuntimeError:
130 | An attempt has been made to start a new process before the
131 | current process has finished its bootstrapping phase.
132 |
133 | This probably means that you are not using fork to start your
134 | child processes and you have forgotten to use the proper idiom
135 | in the main module:
136 |
137 | if __name__ == '__main__':
138 | freeze_support()
139 | ...
140 |
141 | The "freeze_support()" line can be omitted if the program
142 | is not going to be frozen to produce an executable.
143 | ```
144 |
145 | To patch, simply wrap your `stream` call in `if __name__ == "__main__":`.
146 |
147 | ```python
148 | if __name__ == "__main__":
149 | stream(...)
150 | ```
151 |
152 | It's fine without it in notebooks though.
153 |
154 | ## ⚙️ Set your own default settings with `config`
155 |
156 | StreamJoy uses a simple `config` dict to store settings. You can change the default settings by modifying the `streamjoy.config` object.
157 |
158 | To see the available options:
159 | ```python
160 | import streamjoy
161 |
162 | print(streamjoy.config)
163 | ```
164 |
165 | To change the settings, it's simply updating the key-value pair.
166 | ```python
167 | import streamjoy
168 |
169 | streamjoy.config["max_frames"] = -1
170 | ```
171 |
172 | Be wary of completely overwriting the `config` object, as it might break the functionality; do not do this!
173 | ```python
174 | import streamjoy
175 |
176 | streamjoy.config = {"max_frames": -1}
177 | ```
178 |
179 | ## 🔧 Use custom values instead of the defaults
180 |
181 | Much of StreamJoy is based on sensible defaults to get you started quickly, but you should override them when they don't suit your use case.
182 |
183 | For example, `max_frames` is set to 50 by default so you can quickly preview the animation. If you want to render the entire animation, set `max_frames=-1`.
184 |
185 | StreamJoy will warn you on some settings if you don't override them:
186 |
187 | ```text
188 | No 'max_frames' specified; using the default 50 / 100 frames. Pass `-1` to use all frames. Suppress this by passing 'max_frames'.
189 | ```
190 |
191 | ## 🧩 Render HoloViews objects with `processes=False`
192 |
193 | StreamJoy does this automatically. If you hit an edge case, note that kdims/vdims don't seem to carry over properly to subprocesses when rendering HoloViews objects, so StreamJoy may complain that it can't find the desired dimensions; pass `processes=False` explicitly in that case.
194 |
195 | ## 📚 Prevent flickering by setting `threads_per_worker`
196 |
197 | Matplotlib is not always thread-safe, so if you're seeing flickering, set `threads_per_worker=1`.
198 |
199 | ```python
200 | from streamjoy import stream
201 |
202 | stream(..., threads_per_worker=1)
203 | ```
204 |
205 | ## 🖥️ Provide `client` if using a remote cluster
206 |
207 | If you're using a remote cluster, specify the `client` argument to use the Dask client.
208 |
209 | ```python
210 | from dask.distributed import Client
211 |
212 | client = Client()
213 | stream(..., client=client)
214 | ```
215 |
216 | ## 🪣 Read & write files on a remote filesystem with `fsspec_fs`
217 |
218 | To read and write files on a remote filesystem, use `fsspec_fs` to specify the filesystem.
219 |
220 | A scratch directory must be provided; be sure to prefix the bucket name.
221 |
222 | ```python
223 | fs = fsspec.filesystem('s3', anon=False)
224 | stream(..., fsspec_fs=fs, scratch_dir="bucket-name/streamjoy_scratch")
225 | ```
226 |
227 | ## 🚗 Use a custom webdriver to render HoloViews
228 |
229 | By default, StreamJoy uses headless Firefox as the webdriver to render HoloViews objects into images.
230 |
231 | If you want to use Chrome instead, you can pass `webdriver="chrome"`.
232 |
233 | If you want to use a different webdriver, you can pass a custom function to `webdriver`.
234 |
235 | ```python
236 | def get_webdriver():
237 | from selenium.webdriver.firefox.options import Options
238 | from selenium.webdriver.firefox.webdriver import Service, WebDriver
239 | from webdriver_manager.firefox import GeckoDriverManager
240 |
241 | options = Options()
242 | options.add_argument("--headless")
243 | options.add_argument("--disable-extensions")
244 | executable_path = GeckoDriverManager().install()
245 | driver = WebDriver(
246 | service=Service(executable_path), options=options
247 | )
248 | return driver
249 |
250 | stream(..., webdriver=get_webdriver)
251 | ```
252 |
--------------------------------------------------------------------------------
/streamjoy/ui.py:
--------------------------------------------------------------------------------
1 | import re
2 | from io import BytesIO
3 |
# The UI needs the optional ``param``/``panel`` dependencies; fail with an
# actionable message if they are missing.
try:
    import param
    import panel as pn

    pn.extension(notifications=True)
except ImportError as err:
    raise ImportError(
        "StreamJoy UI additionally requires panel; "
        "run `pip install 'streamjoy[ui]'` to install."
    ) from err
14 |
15 | from .core import stream
16 |
17 |
class App(pn.viewable.Viewer):
    """
    Panel dashboard that builds a StreamJoy animation from a URL.

    The sidebar collects a base ``url``, a file-name ``pattern`` and an output
    ``extension``.

    - If ``pattern`` contains a single ``str.format`` field (named, such as
      ``{DAY:02d}``, or positional, such as ``{0}``), that field is filled with
      every integer from ``pattern_inputs_start`` to ``pattern_inputs_end``
      (inclusive). Each result is appended to ``url`` to form the resources.
    - Otherwise ``url`` is passed to ``stream`` together with ``pattern`` and
      ``max_files``.

    The result is shown in the main area and can be downloaded.
    """

    url = param.String(
        label="URL",
        default="https://noaadata.apps.nsidc.org/NOAA/G02135/north/daily/images/2024/01_Jan/",
    )

    max_files = param.Integer(bounds=(0, 1000), default=10)

    pattern = param.String(default="N_202401{DAY:02d}_conc_v3.0.png")

    pattern_inputs_start = param.Integer(bounds=(0, 1000), default=1)

    pattern_inputs_end = param.Integer(bounds=(0, 1000), default=10)

    pattern_inputs = param.Dict()

    extension = param.Selector(objects=[".gif", ".html"], default=".html")

    def __init__(self, **params):
        super().__init__(**params)
        # Output of the last successful submit and the extension it was
        # written with; the selector may change before the user downloads.
        self._buf = None
        self._buf_extension = self.extension

        url_input = pn.widgets.TextInput.from_param(
            self.param.url, placeholder="Enter URL"
        )
        max_files_input = pn.widgets.Spinner.from_param(
            self.param.max_files, name="Max Files"
        )
        pattern_input = pn.widgets.TextInput.from_param(
            self.param.pattern, placeholder="Enter pattern (e.g. *.png or {0}.png)"
        )
        pattern_inputs_simple = pn.WidgetBox(
            pn.widgets.Spinner.from_param(
                self.param.pattern_inputs_start, name="Start of {}"
            ),
            pn.widgets.Spinner.from_param(
                self.param.pattern_inputs_end, name="End of {}"
            ),
        )
        # Built but not displayed: its tab is disabled below.
        pattern_inputs_editor = pn.widgets.JSONEditor.from_param(
            self.param.pattern_inputs,
            mode="form",
            value={"i": [0]},
            sizing_mode="stretch_width",
            search=False,
            menu=False,
        )
        self._pattern_inputs_tabs = pn.Tabs(
            ("Simple", pattern_inputs_simple),
            # ("Editor", pattern_inputs_editor),
        )
        self._pattern_preview = pn.pane.HTML("Preview: ")
        self._pattern_view = pn.Column(
            pn.pane.HTML("Pattern Inputs"),
            self._pattern_inputs_tabs,
            self._pattern_preview,
        )
        input_widgets = pn.Card(
            url_input,
            pattern_input,
            max_files_input,
            self._pattern_view,
            title="Inputs",
            sizing_mode="stretch_width",
        )
        submit_button = pn.widgets.Button(
            on_click=self._on_submit,
            name="Submit",
            sizing_mode="stretch_width",
            button_type="success",
        )
        self._download_button = pn.widgets.FileDownload(
            filename="streamjoy.html",
            callback=self._download,
            sizing_mode="stretch_width",
            button_type="primary",
            disabled=True,
        )
        extension_input = pn.widgets.Select.from_param(
            self.param.extension, sizing_mode="stretch_width"
        )
        self._sidebar = pn.Column(
            pn.Row(submit_button, self._download_button),
            extension_input,
            input_widgets,
        )
        self._main = pn.Column()
        self._dashboard = pn.template.FastListTemplate(
            title="StreamJoy",
            sidebar=[self._sidebar],
            main=[self._main],
        )
        # The watchers below only fire on later edits, so sync the inputs and
        # the preview with the initial pattern once here.
        self._update_pattern_inputs()
        self._update_pattern_preview()

    @staticmethod
    def _notify_error(message):
        """Show ``message`` as an error notification when available."""
        # NOTE(review): pn.state.notifications is assumed to be None outside a
        # served session (e.g. while building the app in a script) — confirm.
        notifications = getattr(pn.state, "notifications", None)
        if notifications is not None:
            notifications.error(message)

    def _extract_templates(self, pattern):
        """Return the match for the first ``{name`` format field, or None."""
        pattern_formats = re.search(r"{(\w+)", pattern)
        return pattern_formats

    @staticmethod
    def _template_keys(pattern):
        """Return the distinct format-field names in ``pattern``, in order."""
        return list(dict.fromkeys(re.findall(r"{(\w+)", pattern)))

    @staticmethod
    def _format_pattern(pattern, key, value):
        """Fill the format field ``key`` of ``pattern`` with ``value``."""
        # Numeric fields such as ``{0}`` are positional in str.format and
        # cannot be supplied as keyword arguments.
        if key.isdigit():
            return pattern.format(*[value] * (int(key) + 1))
        return pattern.format(**{key: value})

    @param.depends("pattern", watch=True)
    def _update_pattern_inputs(self):
        """Show, relabel or disable the start/end inputs to match ``pattern``."""
        pattern = self.pattern
        pattern_formats = self._extract_templates(pattern)
        if pattern_formats is None:
            self._pattern_view.visible = False
            return

        self._pattern_view.visible = True
        pattern_inputs_simple = self._pattern_inputs_tabs[0]
        if len(self._template_keys(pattern)) > 1:
            # Only one field can be driven by the start/end range; keep the
            # inputs disabled until the pattern is fixed.
            self._notify_error("Only one pattern format is allowed.")
            pattern_inputs_simple.disabled = True
            return

        pattern_format_key = pattern_formats.group(1)
        pattern_inputs_simple.disabled = False
        pattern_inputs_simple[0].name = f"Start of {pattern_format_key}"
        pattern_inputs_simple[1].name = f"End of {pattern_format_key}"

    @param.depends("pattern", "pattern_inputs_start", "pattern_inputs_end", watch=True)
    def _update_pattern_preview(self):
        """Render the first and last file names the pattern will produce."""
        pattern = self.pattern
        pattern_formats = self._extract_templates(pattern)
        if pattern_formats is None:
            return
        pattern_format_key = pattern_formats.group(1)
        try:
            pattern_start = self._format_pattern(
                pattern, pattern_format_key, self.pattern_inputs_start
            )
            pattern_end = self._format_pattern(
                pattern, pattern_format_key, self.pattern_inputs_end
            )
        except (KeyError, IndexError, ValueError) as exc:
            # Partially typed patterns (e.g. an unclosed brace) are common
            # while editing; report them instead of raising in the callback.
            self._pattern_preview.object = f"Preview: invalid pattern ({exc})"
            return
        self._pattern_preview.object = (
            f"Preview:<br>"
            f"{pattern_start}"
            f"<br>...to...<br>"
            f"{pattern_end}"
        )

    def _templated_resources(self, pattern, pattern_format_key):
        """
        Expand ``pattern`` over the start/end range into full resource URLs.

        Returns None, after notifying the user, when the pattern cannot be
        expanded because:

        - it has several distinct fields;
        - the range is empty;
        - the format string is invalid.
        """
        if len(self._template_keys(pattern)) > 1:
            self._notify_error("Only one pattern format is allowed.")
            return None
        start, end = self.pattern_inputs_start, self.pattern_inputs_end
        if start > end:
            self._notify_error(
                "The start of the pattern inputs must not exceed the end."
            )
            return None
        # Join against a slash-terminated base URL so that URLs entered
        # without a trailing slash still produce valid resource paths.
        url = self.url if self.url.endswith("/") else f"{self.url}/"
        try:
            return [
                url + self._format_pattern(pattern, pattern_format_key, value)
                for value in range(start, end + 1)
            ]
        except (KeyError, IndexError, ValueError) as exc:
            self._notify_error(f"Invalid pattern: {exc}")
            return None

    def _on_submit(self, event):
        """Build and write the stream for the current inputs, then show it."""
        if not self.url:
            self._notify_error("Please enter a URL.")
            return

        with self._sidebar.param.update(loading=True):
            stream_kwargs = {}
            pattern = self.pattern
            pattern_formats = self._extract_templates(pattern)
            if pattern_formats is not None:
                resources = self._templated_resources(
                    pattern, pattern_formats.group(1)
                )
                if resources is None:
                    return
            else:
                resources = self.url
                stream_kwargs["pattern"] = pattern
                stream_kwargs["max_files"] = self.max_files

            extension = self.extension
            if extension == ".html":
                stream_kwargs["ending_pause"] = 0

            output = stream(resources, extension=extension, **stream_kwargs).write()

            if extension == ".html":
                self._main.objects = [output]
                buf = BytesIO()
                output.save(buf)
                self._buf = buf
            else:
                self._main.objects = [pn.pane.GIF(output)]
                self._buf = output
            self._buf_extension = extension
            self._download_button.filename = f"streamjoy{extension}"
            self._download_button.disabled = False

    def _download(self):
        """Return the last written animation, rewound for download."""
        # Name the file after the format it was written in, not the
        # selector's current value.
        self._download_button.filename = f"streamjoy{self._buf_extension}"
        self._buf.seek(0)
        return self._buf

    def serve(self, port: int = 8888, show: bool = True, **kwargs):
        """Serve the dashboard with ``pn.serve``."""
        pn.serve(self.__panel__(), port=port, show=show, **kwargs)

    def __panel__(self):
        return self._dashboard
210 |
--------------------------------------------------------------------------------
/streamjoy/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import inspect
4 | import logging
5 | import os
6 | from collections.abc import Iterable
7 | from io import BytesIO
8 | from itertools import islice
9 | from pathlib import Path
10 | from typing import TYPE_CHECKING, Any, Callable
11 |
12 | import imageio.v3 as iio
13 | import numpy as np
14 | import param
15 | from dask.distributed import Client, Future, get_client
16 |
17 | from .models import Paused
18 | from .settings import config
19 |
20 | if TYPE_CHECKING:
21 | try:
22 | import xarray as xr
23 | except ImportError:
24 | xr = None
25 |
26 | try:
27 | from selenium.webdriver.remote.webdriver import BaseWebDriver
28 | except ImportError:
29 | BaseWebDriver = None
30 |
31 |
def update_logger(
    level: str | None = None,
    format: str | None = None,
    datefmt: str | None = None,
) -> logging.Logger:
    """
    Configure StreamJoy logging and return this module's logger.

    The function:

    - installs (process-wide, via ``logging.setLoggerClass``) a logger class
      with a ``success`` method that logs at ``config["logging_success_level"]``;
    - registers a colored ``SUCCESS`` name for that level;
    - applies the level and formatter to every handler on the root logger.

    Args:
        level: Handler level; defaults to ``config["logging_level"]``.
        format: Log format string; defaults to ``config["logging_format"]``.
        datefmt: Date format string; defaults to ``config["logging_datefmt"]``.

    Returns:
        The logger named after this module.
    """
    success_level = config["logging_success_level"]

    class CustomLogger(logging.Logger):
        def success(self, message, *args, **kws):
            if not self.isEnabledFor(success_level):
                return
            self._log(success_level, message, args, **kws)

    success_color = config["logging_success_color"]
    reset_color = config["logging_reset_color"]
    # The class must be set before getLogger creates the logger for it to apply.
    logging.setLoggerClass(CustomLogger)
    logging.addLevelName(success_level, f"{success_color}SUCCESS{reset_color}")
    logger = logging.getLogger(__name__)

    resolved_level = level if level else config["logging_level"]
    resolved_format = format if format else config["logging_format"]
    resolved_datefmt = datefmt if datefmt else config["logging_datefmt"]
    for root_handler in logging.getLogger().handlers:
        root_handler.setLevel(resolved_level)
        root_handler.setFormatter(logging.Formatter(resolved_format, resolved_datefmt))

    return logger
58 |
59 |
def warn_default_used(
    key: str, default_value: Any, total_value: Any | None = None, suffix: str = ""
) -> None:
    """
    Log a warning that ``key`` was not given and its default is being used.

    When ``total_value`` is provided, the warning also mentions it and is only
    emitted if ``default_value < total_value``. Otherwise it is always emitted.

    Args:
        key: Name of the setting.
        default_value: Default value being applied.
        total_value: Optional total the default is compared against.
        suffix: Extra text appended to the message.
    """
    color = config["logging_warning_color"]
    reset = config["logging_reset_color"]
    pieces = [
        f"No {color}{key!r}{reset} specified; using the default "
        f"{color}{default_value!r}{reset}"
    ]
    if total_value is not None:
        pieces.append(f" / {total_value!r}")
    if suffix:
        pieces.append(f" {suffix}")
    pieces.append(f". Suppress this by passing {key!r}.")
    message = "".join(pieces)

    if total_value is None or default_value < total_value:
        logging.warning(message)
79 |
80 |
def get_config_default(
    key: str,
    value: Any,
    warn: bool = True,
    require: bool = True,
    config_prefix: str = "",
    **warn_kwargs,
) -> Any:
    """Return ``value``, or the configured default for ``key`` when it is None.

    The config entry looked up is ``f"{config_prefix}_{key}"`` when a prefix
    is given, otherwise ``key``.

    Args:
        key: Option name (used in the warning message).
        value: Explicitly supplied value; returned unchanged when not None.
        warn: Whether to log a warning when the default is used.
        require: If True, a missing config entry is an error. If False, a
            missing entry means there is no default, and ``value`` is
            returned as-is.
        config_prefix: Optional prefix for the config entry name.
        **warn_kwargs: Forwarded to ``warn_default_used`` (e.g.
            ``total_value``, ``suffix``).

    Returns:
        ``value`` if not None, else the config default.

    Raises:
        ValueError: If ``require`` is True and the config entry is missing.
    """
    config_alias = f"{config_prefix}_{key}" if config_prefix else key
    if config_alias not in config:
        if require:
            raise ValueError(f"Missing required config key: {config_alias}")
        # Optional key without a configured default: nothing to fall back to.
        return value

    if value is None:
        value = config[config_alias]
        if warn:
            warn_default_used(key, value, **warn_kwargs)
    return value
99 |
100 |
def populate_config_defaults(
    params: dict[str, Any],
    keys: param.Parameterized,
    warn_on: list[str] | None = None,
    config_prefix: str = "",
):
    """Fill unset entries of ``params`` from the config and drop None values.

    Each name in ``keys`` is processed only if its (optionally
    ``config_prefix``-ed) alias exists in ``config``. For those names,
    ``params[name]`` is set in place via ``get_config_default``: kept if not
    None, otherwise replaced by the configured default. A warning is logged
    only for names listed in ``warn_on``.

    NOTE(review): ``keys`` is annotated as ``param.Parameterized`` but is only
    iterated for key names — confirm the intended type.

    Returns:
        A new dict containing the entries of ``params`` whose value is not None.
    """
    warn_names = warn_on or []
    for name in keys:
        alias = f"{config_prefix}_{name}" if config_prefix else name
        if alias not in config.keys():
            continue
        params[name] = get_config_default(
            name,
            params.get(name),
            warn=name in warn_names,
            config_prefix=config_prefix,
        )
    return {name: val for name, val in params.items() if val is not None}
118 |
119 |
def get_distributed_client(client: Client | None = None, **kwargs) -> Client:
    """Return a dask distributed client.

    Resolution order:
    1. ``client`` itself, if given.
    2. The client found by ``get_client()``.
    3. A new ``Client(**kwargs)``, created when ``get_client()`` raises ValueError.
    """
    if client is None:
        try:
            return get_client()
        except ValueError:
            return Client(**kwargs)
    return client
129 |
130 |
def download_file(
    url: str,
    scratch_dir: Path | None = None,
    in_memory: bool = False,
    parent_depth: int | None = None,
) -> str | Path | BytesIO:
    """Download ``url`` into the scratch directory (or memory) and return its location.

    The local file name is the URL's last path component. It is prefixed by
    ``parent_depth`` parent components, joined with underscores; the depth
    falls back to the ``parent_depth`` config default. An existing local file
    with that name is returned without downloading again.

    Args:
        url: URL of the file to fetch.
        scratch_dir: Directory to write into; resolved by ``resolve_uri``.
        in_memory: If True, download into a ``BytesIO`` buffer instead of a file.
        parent_depth: Number of parent URL components to prefix onto the file name.

    Returns:
        The local path of the downloaded file. When ``in_memory`` is True, a
        ``BytesIO`` rewound to the start of the downloaded bytes.

    Raises:
        ImportError: If ``requests`` is not installed.
        requests.HTTPError: If the server responds with an error status.
    """
    try:
        import requests
    except ImportError as exc:
        raise ImportError("To directly read from a URL, `pip install requests`") from exc

    url_path = Path(url)
    file_name = url_path.name

    parent_depth = get_config_default("parent_depth", parent_depth, warn=False)
    for _ in range(parent_depth):
        url_path = url_path.parent
        file_name = f"{url_path.name}_{file_name}"
    uri = resolve_uri(file_name=file_name, scratch_dir=scratch_dir, in_memory=in_memory)
    if not isinstance(uri, BytesIO) and os.path.exists(uri):
        return uri

    # The context manager releases the connection even if raise_for_status raises.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        chunks = (chunk for chunk in response.iter_content(chunk_size=8192) if chunk)
        if isinstance(uri, BytesIO):
            # open() cannot take a BytesIO; write straight into the buffer.
            for chunk in chunks:
                uri.write(chunk)
            uri.seek(0)
        else:
            # Write under a temporary name and rename on success, so an
            # interrupted download is never mistaken for a cached file above.
            partial_uri = f"{uri}.part"
            with open(partial_uri, "wb") as f:
                for chunk in chunks:
                    f.write(chunk)
            os.replace(partial_uri, uri)
    return uri
160 |
161 |
def get_max_frames(total_frames: int, max_frames: int) -> int:
    """Resolve how many frames to render.

    Args:
        total_frames: Number of frames available.
        max_frames: Requested frame count. Interpreted as follows:
            - None: use ``config["max_frames"]`` if ``total_frames`` exceeds
              it (with a warning), otherwise ``total_frames``.
            - ``-1`` or any value above ``total_frames``: use ``total_frames``.
            - Any other value: returned as-is.

    Returns:
        The number of frames to use.
    """
    default_max_frames = config["max_frames"]
    if max_frames is None:
        if total_frames <= default_max_frames:
            return total_frames
        warn_default_used(
            "max_frames",
            default_max_frames,
            total_value=total_frames,
            suffix="frames. Pass `-1` to use all frames",
        )
        return default_max_frames
    if max_frames == -1 or max_frames > total_frames:
        return total_frames
    return max_frames
179 |
180 |
def get_first(iterable):
    """Return the first item of ``iterable``.

    Lists and tuples are indexed directly, so an empty one raises IndexError.
    Other iterables are advanced by one item (consuming it for iterators),
    and None is returned when they are empty.
    """
    if isinstance(iterable, (list, tuple)):
        return iterable[0]
    for item in islice(iterable, 1):
        return item
    return None
185 |
186 |
def get_result(future: Future) -> Any:
    """Materialize ``future`` into a concrete value.

    A dask ``Future`` is resolved with ``.result()``. Any other object that
    has a ``compute`` attribute is computed (presumably lazy dask/xarray
    objects). Everything else is returned unchanged.
    """
    if not isinstance(future, Future):
        return future.compute() if hasattr(future, "compute") else future
    return future.result()
194 |
195 |
def using_notebook():
    """Return True when running under an IPython kernel (e.g. Jupyter).

    Any failure returns False. This covers IPython not being installed and
    ``get_ipython()`` returning None outside an IPython shell.
    """
    try:
        from IPython import get_ipython

        in_kernel = "IPKernelApp" in get_ipython().config
    except Exception:
        return False
    return in_kernel
205 |
206 |
def resolve_uri(
    file_name: str | None = None,
    scratch_dir: str | Path | None = None,
    in_memory: bool = False,
    fsspec_fs: Any | None = None,
) -> str | Path | BytesIO:
    """Return where an output named ``file_name`` should be written.

    Args:
        file_name: Name of the file inside the output directory; required
            unless ``in_memory`` is True.
        scratch_dir: Output directory; falls back to the ``scratch_dir``
            config default. It is created if missing.
        in_memory: If True, return a fresh ``BytesIO`` and touch no filesystem.
        fsspec_fs: Optional fsspec filesystem. When given, the directory is
            created on it and a "/"-joined string path is returned.

    Returns:
        A ``BytesIO`` (in memory), a "/"-joined ``str`` (fsspec), or a local
        ``Path``.
    """
    if in_memory:
        return BytesIO()

    output_dir = get_config_default("scratch_dir", scratch_dir, warn=False)
    if fsspec_fs:
        import posixpath

        try:
            # makedirs(exist_ok=True) is the fsspec API for "create if missing";
            # mkdir takes no exist_ok/parents parameters.
            fsspec_fs.makedirs(output_dir, exist_ok=True)
        except FileExistsError:
            pass
        # fsspec paths always use "/" separators, independent of the local OS.
        uri = posixpath.join(output_dir, file_name)
    else:
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)
        uri = output_dir / file_name
    return uri
228 |
229 |
def pop_kwargs(callable: Callable, kwargs: dict) -> dict:
    """Remove from ``kwargs`` the entries named after ``callable``'s parameters.

    Only positional-or-keyword parameters (``getfullargspec(...).args``) are
    considered; keyword-only parameters and ``*args``/``**kwargs`` are not.

    Args:
        callable: Function whose parameter names are matched.
        kwargs: Dict modified in place.

    Returns:
        The removed entries as a new dict.
    """
    arg_names = inspect.getfullargspec(callable).args
    return {arg: kwargs.pop(arg) for arg in arg_names if arg in kwargs}
233 |
234 |
def pop_from_cls(cls: type, kwargs: dict) -> dict:
    """Move out of ``kwargs`` every entry that is not a parameter of ``cls``.

    Membership is tested against ``cls.param.values()`` (a param-style
    Parameterized class). ``kwargs`` is modified in place and keeps only
    ``cls``'s own parameters.

    Returns:
        A new dict with the removed, non-``cls`` entries.
    """
    moved = {}
    for name in set(kwargs):
        if name in cls.param.values():
            continue
        moved[name] = kwargs.pop(name)
    return moved
239 |
240 |
def import_function(import_path: str) -> Callable:
    """Import and return the attribute named by a dotted ``import_path``.

    For example, ``"json.dumps"`` imports ``json`` and returns ``json.dumps``.
    """
    module_path, attr_name = import_path.rsplit(".", maxsplit=1)
    return getattr(__import__(module_path, fromlist=[attr_name]), attr_name)
245 |
246 |
def validate_xarray(
    ds: xr.Dataset | xr.DataArray,
    dim: str | None = None,
    var: str | None = None,
    warn: bool = True,
    raise_ndim: bool = True,
):
    """Reduce ``ds`` to a single squeezed variable suitable for framing.

    Args:
        ds: Dataset or DataArray to validate.
        dim: Dimension to keep even when it has size 1.
        var: Variable to select. For a Dataset without ``var``, the first data
            variable is used (with a warning if ``warn``).
        warn: Whether to warn when defaulting ``var``.
        raise_ndim: Whether to reject results with more than 3 dimensions.

    Returns:
        The selected data with size-1 dimensions (other than ``dim``) squeezed.

    Raises:
        ValueError: If ``raise_ndim`` and the result has more than 3 dimensions.
    """
    import xarray as xr

    if not var and isinstance(ds, xr.Dataset):
        chosen_var = list(ds.data_vars)[0]
        if warn:
            warn_default_used("var", chosen_var, suffix="from the dataset")
        ds = ds[chosen_var]
    elif var:
        ds = ds[var]

    singleton_dims = [
        name for name in ds.dims if name != dim and ds.sizes[name] == 1
    ]
    ds = ds.squeeze(singleton_dims)
    if raise_ndim and ds.ndim > 3:
        raise ValueError(f"Can only handle 3D arrays; {ds.ndim}D array found")
    return ds
269 |
270 |
def validate_renderer_iterables(
    resources: list[Any],
    iterables: list[list[Any]],
):
    """Sanity-check the shape of ``iterables`` passed to a renderer.

    ``iterables`` should be a list of per-argument lists, where each inner
    list holds that argument's value for every frame.

    Validation is skipped when ``iterables`` is None or empty. A warning is
    logged when ``len(iterables)`` equals ``len(resources)``, which suggests
    the nesting is transposed. That check is skipped when ``resources`` has
    no length.

    Raises:
        TypeError: If the first entry is not a non-string iterable.
    """
    if iterables is None or len(iterables) == 0:
        # Nothing to validate (and iterables[0] below would fail).
        return

    try:
        num_resources = len(resources)
    except TypeError:
        # Some resource objects do not define len(); skip the length heuristic.
        num_resources = None

    if num_resources is not None and len(iterables) == num_resources:
        logging.warning(
            "The length of the iterables matches the length of the resources. "
            "This is likely not what you want; the iterables should be a list of lists, "
            "where each inner list corresponds to the arguments for each frame."
        )

    first = iterables[0]
    if not isinstance(first, Iterable) or isinstance(first, str):
        raise TypeError(
            "Iterables should be like a list of lists, where each inner list corresponds "
            "to the arguments for each frame; e.g. `[[arg1_for_frame1, arg1_for_frame2], "
            "[arg2_for_frame_1, arg2_for_frame2]]`"
        )
294 |
295 |
def map_over(client, func, resources, batch_size, *args, **kwargs):
    """Schedule ``func`` over ``resources`` on ``client``.

    First tries ``client.map(..., batch_size=batch_size)``. If that raises
    TypeError (presumably a client whose ``map`` lacks these arguments), one
    ``client.submit`` is issued per resource instead, with ``args`` and
    ``kwargs`` passed through.

    Returns:
        Whatever ``client.map`` returns, or a list of ``client.submit`` results.
    """
    try:
        futures = client.map(func, resources, *args, batch_size=batch_size, **kwargs)
    except TypeError:
        futures = [client.submit(func, item, *args, **kwargs) for item in resources]
    return futures
303 |
304 |
def repeat_frame(
    write: Callable, image: np.ndarray, seconds: int, fps: int, **write_kwargs
) -> np.ndarray:
    """Write ``image`` ``int(seconds * fps)`` times via ``write`` to hold it on screen.

    Nothing is written when ``seconds`` is 0.

    Returns:
        ``image`` unchanged.
    """
    if seconds != 0:
        remaining = int(seconds * fps)
        while remaining > 0:
            write(image, **write_kwargs)
            remaining -= 1
    return image
315 |
316 |
def imread_with_pause(
    uri: Any | Paused,
    extension: str | None = None,
    plugin: str | None = None,
    fsspec_fs: Any | None = None,
) -> np.ndarray | Paused:
    """Read an image with imageio, preserving a ``Paused`` wrapper if present.

    Args:
        uri: Image location, or a ``Paused`` whose ``output`` is the location.
        extension: Forwarded to ``iio.imread``.
        plugin: Forwarded to ``iio.imread``.
        fsspec_fs: If given, the image is opened through this filesystem.

    Returns:
        The squeezed image array. If ``uri`` was ``Paused``, the array is
        wrapped in a new ``Paused`` with the same ``seconds``.
    """
    pause_seconds = None
    if isinstance(uri, Paused):
        pause_seconds, uri = uri.seconds, uri.output

    read_kwargs = dict(extension=extension, plugin=plugin)
    if fsspec_fs:
        with fsspec_fs.open(uri, "rb") as handle:
            image = iio.imread(handle, **read_kwargs)
    else:
        image = iio.imread(uri, **read_kwargs)
    image = image.squeeze()

    if pause_seconds is None:
        return image
    return Paused(output=image, seconds=pause_seconds)
338 |
339 |
def subset_resources_renderer_iterables(
    resources: Any, renderer_iterables: list[Any], max_frames: int
):
    """Truncate ``resources`` to ``max_frames`` and align each renderer iterable.

    Args:
        resources: Sliceable sequence of frames.
        renderer_iterables: Per-argument sequences, or None.
        max_frames: Slice end applied to ``resources``; None keeps everything.

    Returns:
        Tuple ``(resources, renderer_iterables)``. Each iterable is cut to the
        new length of ``resources``, and ``renderer_iterables`` is always a
        list (empty when None was passed).
    """
    # The previous ``max_frames or max_frames`` was a no-op; slice directly.
    resources = resources[:max_frames]
    n_kept = len(resources)
    renderer_iterables = [iterable[:n_kept] for iterable in renderer_iterables or []]
    return resources, renderer_iterables
348 |
349 |
def get_webdriver_path(webdriver: str):
    """Install (via webdriver_manager) and return the driver path for a browser.

    Args:
        webdriver: ``"chrome"`` or ``"firefox"`` (case-insensitive).

    Returns:
        The path returned by the corresponding driver manager's ``install()``.

    Raises:
        NotImplementedError: For any other browser name. Previously this fell
            through to an UnboundLocalError.
    """
    browser = webdriver.lower()
    if browser == "chrome":
        from webdriver_manager.chrome import ChromeDriverManager

        return ChromeDriverManager().install()
    if browser == "firefox":
        from webdriver_manager.firefox import GeckoDriverManager

        return GeckoDriverManager().install()
    raise NotImplementedError(
        f"Webdriver {webdriver} not supported; use 'chrome' or 'firefox'."
    )
360 |
361 |
def get_webdriver(webdriver: tuple[str, str] | str | Callable) -> BaseWebDriver:
    """Create a headless Selenium WebDriver.

    Args:
        webdriver: One of:
            - a zero-argument callable that returns a ready driver;
            - a ``(browser, driver_path)`` tuple;
            - just the browser name.
            ``browser`` is ``"chrome"`` or ``"firefox"`` (case-insensitive).
            A falsy ``driver_path`` is resolved with ``get_webdriver_path``.

    Returns:
        The callable's result, or a new headless driver with extensions disabled.

    Raises:
        NotImplementedError: For browsers other than chrome/firefox.
    """
    if callable(webdriver):
        return webdriver()

    if isinstance(webdriver, str):
        # A bare browser name: let get_webdriver_path locate the driver.
        webdriver_key, webdriver_path = webdriver, None
    else:
        webdriver_key, webdriver_path = webdriver

    browser = webdriver_key.lower()
    if browser == "chrome":
        from selenium.webdriver.chrome.options import Options
        from selenium.webdriver.chrome.webdriver import Service, WebDriver
    elif browser == "firefox":
        from selenium.webdriver.firefox.options import Options
        from selenium.webdriver.firefox.webdriver import Service, WebDriver
    else:
        raise NotImplementedError(
            f"Webdriver {webdriver_key} not supported; "
            f"use 'chrome' or 'firefox', or pass a custom callable."
        )

    options = Options()
    options.add_argument("--headless")
    options.add_argument("--disable-extensions")
    webdriver_path = webdriver_path or get_webdriver_path(browser)
    return WebDriver(service=Service(webdriver_path), options=options)
394 |
--------------------------------------------------------------------------------
/streamjoy/serializers.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | from inspect import isgenerator
5 | from pathlib import Path
6 | from typing import TYPE_CHECKING, Any, Callable
7 |
8 | import numpy as np
9 |
10 | from . import _utils
11 | from .models import Serialized
12 | from .renderers import (
13 | default_holoviews_renderer,
14 | default_pandas_renderer,
15 | default_polars_renderer,
16 | default_xarray_renderer,
17 | )
18 | from .settings import file_handlers, obj_handlers
19 | from .wrappers import wrap_holoviews, wrap_matplotlib
20 |
21 | if TYPE_CHECKING:
22 | try:
23 | import pandas as pd
24 | except ImportError:
25 | pd = None
26 |
27 | try:
28 | import polars as pl
29 | except ImportError:
30 | pl = None
31 |
32 | try:
33 | import xarray as xr
34 | except ImportError:
35 | xr = None
36 |
37 | try:
38 | import holoviews as hv
39 | except ImportError:
40 | hv = None
41 |
42 | from .streams import MediaStream
43 |
44 |
def _select_obj_handler(resources: Any) -> Callable[..., Serialized]:
    """Pick the ``serialize_*`` function in this module suited to ``resources``.

    Resolution order:
    1. A ``str`` containing ``"://"`` is treated as a URL (``serialize_url``).
    2. Any other ``str`` or ``Path`` goes to ``serialize_paths``.
    3. Otherwise, the object's top-level package name, or "package.ClassName",
       is looked up in ``obj_handlers``. The mapped function name is resolved
       among this module's globals.

    Returns:
        The serializer function (not a stream object).

    Raises:
        ValueError: If no entry in ``obj_handlers`` matches.
    """
    if isinstance(resources, str) and "://" in resources:
        return serialize_url
    if isinstance(resources, (Path, str)):
        return serialize_paths

    resources_type = type(resources)
    # Top-level package only, e.g. "xarray" for xarray.core.dataset.Dataset.
    module = resources_type.__module__.split(".", maxsplit=1)[0]
    qualified_name = f"{module}.{resources_type.__name__}"
    for class_or_package_name, function_name in obj_handlers.items():
        if class_or_package_name in (qualified_name, module):
            return globals()[function_name]

    raise ValueError(
        f"Could not find a method to handle {resources_type}; "
        f"supported classes/packages are {list(obj_handlers.keys())}."
    )
65 |
66 |
def serialize_numpy(
    stream_cls,
    resources: np.ndarray,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize a numpy array into a list of frames along its first axis.

    Args:
        stream_cls: The stream class; keyword arguments that are not its
            parameters are moved into the renderer keyword arguments.
        resources: The numpy array to be serialized; iterated along axis 0.
        renderer: The rendering function to use on each frame (passed through).
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer;
            the caller's dict is not modified.
        **kwargs: Remaining keyword arguments.

    Returns:
        A Serialized object holding the frames, renderer, renderer iterables,
        renderer keyword arguments and the remaining keyword arguments.
    """
    resources = list(resources)
    # Copy so the caller's dict is not mutated by the update below.
    renderer_kwargs = dict(renderer_kwargs or {})
    renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))
    return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
93 |
94 |
def serialize_xarray(
    stream_cls,
    resources: xr.Dataset | xr.DataArray,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize an xarray Dataset or DataArray into one frame per step along ``dim``.

    Args:
        stream_cls: The stream class; keyword arguments that are not its
            parameters are moved into the renderer keyword arguments.
        resources: The xarray dataset or data array to be serialized.
        renderer: The rendering function. When None, the default xarray
            renderer (wrapped for matplotlib) is used. Limits are then filled
            in from the first frame unless already provided: ``vmin``/``vmax``
            for 2D+ frames, ``ylim`` for 1D frames.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer;
            the caller's dict is not modified.
        **kwargs: Additional keyword arguments, including 'dim' and 'var' for
            xarray selection and 'max_frames'.

    Returns:
        A Serialized object holding the per-step selections, renderer,
        renderer iterables, renderer keyword arguments and remaining kwargs.

    Raises:
        ValueError: If ``dim`` is given but is not a dimension of the data.
    """
    dim = kwargs.pop("dim", None)
    var = kwargs.pop("var", None)

    # The 3D limit is only enforced when the default renderer will be used.
    ds = _utils.validate_xarray(resources, dim=dim, var=var, raise_ndim=renderer is None)
    if not dim:
        dim = list(ds.dims)[0]
        _utils.warn_default_used("dim", dim, suffix="from the dataset")
    elif dim not in ds.dims:
        raise ValueError(f"{dim!r} not in {ds.dims!r}")

    total_frames = len(ds[dim])
    max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames"))
    resources = [ds.isel({dim: i}) for i in range(max_frames)]

    # Copy so the caller's dict is not mutated.
    renderer_kwargs = dict(renderer_kwargs or {})
    renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))

    if renderer is None:
        renderer = wrap_matplotlib(
            in_memory=kwargs.get("in_memory"),
            scratch_dir=kwargs.get("scratch_dir"),
            fsspec_fs=kwargs.get("fsspec_fs"),
        )(default_xarray_renderer)
        ds_0 = resources[0]
        # Limits come from the first frame. They are computed only when missing,
        # because get_result may trigger a (possibly lazy) computation.
        if ds_0.ndim >= 2:
            if "vmin" not in renderer_kwargs:
                renderer_kwargs["vmin"] = _utils.get_result(ds_0.min()).item()
            if "vmax" not in renderer_kwargs:
                renderer_kwargs["vmax"] = _utils.get_result(ds_0.max()).item()
        elif "ylim" not in renderer_kwargs:
            renderer_kwargs["ylim"] = (
                _utils.get_result(ds_0.min()).item(),
                _utils.get_result(ds_0.max()).item(),
            )
    return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
159 |
160 |
def serialize_pandas(
    stream_cls,
    resources: pd.DataFrame,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize a pandas DataFrame (or Series) into cumulative frames.

    Frame ``i`` holds the first ``i`` rows (per group when ``groupby`` is
    given), so the rendered data grows frame by frame.

    Args:
        stream_cls: The stream class; keyword arguments that are not its
            parameters are moved into the renderer keyword arguments.
        resources: The pandas DataFrame or Series to be serialized.
        renderer: The rendering function. When None, the default pandas
            renderer (wrapped for matplotlib) is used, and missing
            ``x``/``y``/``xlabel``/``ylabel`` are inferred.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer;
            the caller's dict is not modified.
        **kwargs: Additional keyword arguments, including 'groupby' and 'max_frames'.

    Returns:
        A Serialized object holding the frames, renderer, renderer iterables,
        renderer keyword arguments and the remaining keyword arguments.

    Raises:
        ValueError: If ``y`` must be inferred but the DataFrame has no numeric columns.
    """
    import pandas as pd

    def _axis_label(name):
        # Names need not be strings, and an unnamed index yields None.
        return "" if name is None else str(name).title().replace("_", " ")

    df = resources
    groupby = kwargs.get("groupby")

    total_frames = df.groupby(groupby).size().max() if groupby else len(df)
    max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames"))
    resources = [
        df.groupby(groupby, as_index=False).head(i) if groupby else df.head(i)
        for i in range(1, max_frames + 1)
    ]

    # Copy so the caller's dict is not mutated.
    renderer_kwargs = dict(renderer_kwargs or {})
    renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))

    if renderer is None:
        renderer = wrap_matplotlib(
            in_memory=kwargs.get("in_memory"),
            scratch_dir=kwargs.get("scratch_dir"),
            fsspec_fs=kwargs.get("fsspec_fs"),
        )(default_pandas_renderer)
        if "x" not in renderer_kwargs:
            if df.index.name or isinstance(df, pd.Series):
                renderer_kwargs["x"] = df.index.name
            else:
                # First non-grouping column; the last column if all are the grouping one.
                renderer_kwargs["x"] = next(
                    (col for col in df.columns if col != groupby), df.columns[-1]
                )
            _utils.warn_default_used(
                "x", renderer_kwargs["x"], suffix="from the dataframe"
            )
        if "y" not in renderer_kwargs:
            if isinstance(df, pd.Series):
                renderer_kwargs["y"] = df.name
            else:
                numeric_cols = list(df.select_dtypes(include="number").columns)
                if not numeric_cols:
                    raise ValueError(
                        "Could not infer 'y': the dataframe has no numeric columns; "
                        "pass `y` explicitly."
                    )
                renderer_kwargs["y"] = next(
                    (
                        col
                        for col in numeric_cols
                        if col not in (renderer_kwargs["x"], groupby)
                    ),
                    numeric_cols[-1],
                )
            _utils.warn_default_used(
                "y", renderer_kwargs["y"], suffix="from the dataframe"
            )
        if "xlabel" not in renderer_kwargs:
            renderer_kwargs["xlabel"] = _axis_label(renderer_kwargs["x"])
        if "ylabel" not in renderer_kwargs:
            renderer_kwargs["ylabel"] = _axis_label(renderer_kwargs["y"])

    return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
233 |
234 |
def serialize_polars(
    stream_cls,
    resources: pl.DataFrame,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize a Polars DataFrame into cumulative frames.

    Frame ``i`` holds the first ``i`` rows (of each group when ``groupby`` is
    given). Groups keep a stable order across frames.

    Args:
        stream_cls: The stream class; keyword arguments that are not its
            parameters are moved into the renderer keyword arguments.
        resources: The Polars DataFrame to be serialized.
        renderer: The rendering function. When None, the default Polars
            renderer (wrapped for HoloViews) is used, and missing
            ``x``/``y``/labels are inferred from numeric columns.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer;
            the caller's dict is not modified.
        **kwargs: Additional keyword arguments, including 'groupby',
            'max_frames' and 'processes'.

    Returns:
        A Serialized object holding the frames, renderer, renderer iterables,
        renderer keyword arguments and the remaining keyword arguments.

    Raises:
        ValueError: If ``x`` or ``y`` must be inferred but no suitable numeric
            column exists.
    """
    import polars as pl

    groupby = kwargs.get("groupby")

    if groupby:
        # Polars renamed DataFrame.groupby to group_by (removed the old name in 1.0).
        group_sizes = resources.group_by(groupby).agg(pl.len())
        total_frames = group_sizes.select(pl.col("len").max()).to_numpy()[0, 0]
    else:
        total_frames = len(resources)

    max_frames = _utils.get_max_frames(total_frames, kwargs.get("max_frames"))
    # maintain_order=True keeps group order deterministic from frame to frame.
    resources_expanded = [
        resources.group_by(groupby, maintain_order=True).head(i)
        if groupby
        else resources.head(i)
        for i in range(1, max_frames + 1)
    ]

    # Copy so the caller's dict is not mutated (including the pop below).
    renderer_kwargs = dict(renderer_kwargs or {})
    renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))

    if renderer is None:
        renderer = wrap_holoviews(
            in_memory=kwargs.get("in_memory"),
            scratch_dir=kwargs.get("scratch_dir"),
            fsspec_fs=kwargs.get("fsspec_fs"),
            webdriver=renderer_kwargs.pop("webdriver", None),
        )(default_polars_renderer)
        numeric_cols = [
            col
            for col in resources.columns
            if resources[col].dtype in [pl.Float64, pl.Int64, pl.Float32, pl.Int32]
        ]
        if "x" not in renderer_kwargs:
            x_candidates = [col for col in numeric_cols if col != groupby]
            if not x_candidates:
                raise ValueError(
                    "Could not infer 'x': no numeric column besides the groupby "
                    "column; pass `x` explicitly."
                )
            renderer_kwargs["x"] = x_candidates[0]
            _utils.warn_default_used(
                "x", renderer_kwargs["x"], suffix="from the dataframe"
            )
        if "y" not in renderer_kwargs:
            y_candidates = [
                col for col in numeric_cols if col not in (renderer_kwargs["x"], groupby)
            ]
            if not y_candidates:
                raise ValueError(
                    "Could not infer 'y': no numeric column besides 'x' and the "
                    "groupby column; pass `y` explicitly."
                )
            renderer_kwargs["y"] = y_candidates[0]
            _utils.warn_default_used(
                "y", renderer_kwargs["y"], suffix="from the dataframe"
            )
        if "xlabel" not in renderer_kwargs:
            renderer_kwargs["xlabel"] = renderer_kwargs["x"].title().replace("_", " ")
        if "ylabel" not in renderer_kwargs:
            renderer_kwargs["ylabel"] = renderer_kwargs["y"].title().replace("_", " ")

    if kwargs.get("processes"):
        logging.warning(
            "Polars (HoloViews) rendering does not support processes; "
            "setting processes=False."
        )
        kwargs["processes"] = False
    return Serialized(
        resources_expanded, renderer, renderer_iterables, renderer_kwargs, kwargs
    )
318 |
319 |
def serialize_holoviews(
    stream_cls,
    resources: hv.HoloMap | hv.DynamicMap,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize a 1D HoloMap/DynamicMap (or a layout/overlay of them) into frames.

    Each key of the single key dimension becomes one frame, titled with the key.

    Args:
        stream_cls: The stream class; keyword arguments that are not its
            parameters are moved into the renderer keyword arguments.
        resources: The HoloViews object to be serialized. For layouts and
            overlays, the first item supplies the key dimension.
        renderer: The rendering function. When None, the default HoloViews
            renderer is used, and ``backend`` plus per-vdim colour limits
            (``clims``) are added to the renderer keyword arguments.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer;
            the caller's dict is not modified.
        **kwargs: Additional keyword arguments, including 'backend' and 'processes'.

    Returns:
        A Serialized object holding the frames, renderer, renderer iterables,
        renderer keyword arguments and the remaining keyword arguments.

    Raises:
        ValueError: If the object is not (or does not contain) a HoloMap or
            DynamicMap, or has more than one key dimension.
    """
    import holoviews as hv

    backend = kwargs.get("backend", hv.Store.current_backend)
    map_types = (hv.core.spaces.DynamicMap, hv.core.spaces.HoloMap)

    def _select_element(hv_obj, key):
        try:
            resource = hv_obj[key]
        except Exception:
            resource = hv_obj.select(**{kdims[0].name: key})
        return resource

    hv_obj = resources
    # Default to the object itself so unsupported types reach the ValueError
    # below instead of failing on an unbound name.
    hv_map = hv_obj
    if not isinstance(hv_obj, map_types) and issubclass(
        type(hv_obj), (hv.core.layout.Layoutable, hv.core.overlay.Overlayable)
    ):
        hv_map = hv_obj[0]

    if not isinstance(hv_map, map_types):
        raise ValueError("Can only handle HoloMap and DynamicMap objects.")
    kdims = hv_map.kdims
    if isinstance(hv_map, hv.core.spaces.DynamicMap):
        keys = kdims[0].values
    else:
        keys = hv_map.keys()

    if len(kdims) > 1:
        raise ValueError("Can only handle 1D HoloViews objects.")

    resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys]

    # Copy so the caller's dict is not mutated (including the pop below).
    renderer_kwargs = dict(renderer_kwargs or {})
    renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))

    if renderer is None:
        renderer = wrap_holoviews(
            in_memory=kwargs.get("in_memory"),
            scratch_dir=kwargs.get("scratch_dir"),
            fsspec_fs=kwargs.get("fsspec_fs"),
            webdriver=renderer_kwargs.pop("webdriver", None),
        )(default_holoviews_renderer)
        clims = {}
        for hv_el in hv_obj.traverse(full_breadth=False):
            if isinstance(hv_el, hv.DynamicMap):
                # NOTE(review): presumably rendered to force evaluation of the
                # DynamicMap — confirm.
                hv.render(hv_el, backend=backend)

            # Elements like Points have several kdims but no vdims; skip them.
            if not isinstance(hv_el, hv.Element) or hv_el.ndims <= 1 or not hv_el.vdims:
                continue
            vdim = hv_el.vdims[0].name
            array = hv_el.dimension_values(vdim)
            low, high = np.nanmin(array), np.nanmax(array)
            if vdim in clims:
                # Widen instead of overwrite so every element sharing this vdim fits.
                low = min(low, clims[vdim][0])
                high = max(high, clims[vdim][1])
            clims[vdim] = (low, high)

        renderer_kwargs.update(
            backend=backend,
            clims=clims,
        )

    if kwargs.get("processes"):
        logging.warning(
            "HoloViews rendering does not support processes; "
            "setting processes=False."
        )
        kwargs["processes"] = False
    return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
409 |
410 |
def serialize_url(
    stream_cls,
    resources: str,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize resources from a URL for streaming or rendering.

    With a ``pattern``, the URL is parsed as an HTML page, and every
    ``<a href>`` matching the pattern (``*`` acts as a wildcard) is
    downloaded. Without a pattern, the URL itself is downloaded as a single
    file. The downloaded files are then serialized by ``serialize_paths``.

    Args:
        stream_cls: The stream class, forwarded to ``serialize_paths``.
        resources: The URL of the resources to be serialized.
        renderer: The rendering function to use on the resources.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer.
        **kwargs: Additional keyword arguments, including 'pattern',
            'sort_key', 'max_files', 'file_handler', 'file_handler_kwargs'
            and the dask client options.

    Returns:
        The Serialized object produced by ``serialize_paths`` for the
        downloaded files.

    Raises:
        ValueError: If no pattern is given for a ``text/*`` response, or if no
            links are found.
    """
    import re

    import requests
    from bs4 import BeautifulSoup

    base_url = resources
    pattern = kwargs.pop("pattern", None)
    sort_key = kwargs.pop("sort_key", None)
    max_files = kwargs.pop("max_files", None)
    file_handler = kwargs.pop("file_handler", None)
    file_handler_kwargs = kwargs.pop("file_handler_kwargs", None)

    max_files = _utils.get_config_default("max_files", max_files, suffix="links")

    logging.info(f"Retrieving resources from {base_url!r}.")

    # Compile once rather than on every chunk.
    href = re.compile(pattern.replace("*", ".*")) if pattern is not None else None
    links = []
    with requests.get(resources, stream=True) as response:
        response.raise_for_status()
        # The header may be absent; treat that as a non-text response.
        content_type = response.headers.get("Content-Type") or ""
        if href is None:
            if content_type.startswith("text"):
                raise ValueError(
                    f"A pattern must be provided if the URL is a directory of files; "
                    f"got {resources!r}."
                )
            # A single file: it is downloaded below, so its body is not read here.
            links = [{"href": ""}]
        else:
            partial_html = ""
            for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
                if not chunk:
                    continue
                partial_html += chunk
                # Re-parse what has arrived so far to stop early once enough links exist.
                soup = BeautifulSoup(partial_html, "html.parser")
                links = soup.find_all("a", href=href)
                if max_files > 0 and len(links) >= max_files:
                    break

    if max_files > 0:
        links = links[:max_files]

    if len(links) == 0:
        raise ValueError(f"No links found with pattern {pattern!r} at {base_url!r}.")

    # download files
    urls = [base_url + link.get("href") for link in links]
    client = _utils.get_distributed_client(
        client=kwargs.get("client"),
        processes=kwargs.get("processes"),
        threads_per_worker=kwargs.get("threads_per_worker"),
    )

    logging.info(f"Downloading {len(urls)} files from {base_url!r}.")
    futures = _utils.map_over(
        client,
        _utils.download_file,
        urls,
        kwargs.get("batch_size"),
        scratch_dir=kwargs.get("scratch_dir"),
        in_memory=kwargs.get("in_memory"),
    )
    paths = client.gather(futures)
    return serialize_paths(
        stream_cls,
        paths,
        sort_key=sort_key,
        max_files=max_files,
        file_handler=file_handler,
        file_handler_kwargs=file_handler_kwargs,
        renderer=renderer,
        renderer_iterables=renderer_iterables,
        renderer_kwargs=renderer_kwargs,
        **kwargs,
    )
507 |
508 |
def serialize_paths(
    stream_cls,
    resources: list[str | Path] | str,
    renderer: Callable | None = None,
    renderer_iterables: list[Any] | None = None,
    renderer_kwargs: dict | None = None,
    **kwargs,
) -> Serialized:
    """
    Serialize resources from file paths for streaming or rendering.

    Paths are sorted and truncated to ``max_files``. A file handler is chosen
    either from the ``file_handler`` argument or from ``file_handlers``,
    keyed by the first path's extension. If one is found, the files are read
    into objects and serialized with ``serialize_appropriately``; with a
    registered ``concat_path``, each file is read separately and the results
    are concatenated. Otherwise the paths themselves become the frames.

    Args:
        stream_cls: The stream class, forwarded to ``serialize_appropriately``.
        resources: A list of file paths, or a single path (``str`` or ``Path``).
        renderer: The rendering function to use on the resources.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer.
        **kwargs: Additional keyword arguments, including 'sort_key',
            'max_files', 'file_handler', and 'file_handler_kwargs'.

    Returns:
        A Serialized object containing the serialized resources.

    Raises:
        ValueError: If no paths are given.
    """
    # A single path (str or Path, as routed by _select_obj_handler) becomes a list.
    if isinstance(resources, (str, Path)):
        resources = [resources]

    sort_key = kwargs.pop("sort_key", None)
    max_files = kwargs.pop("max_files", None)
    file_handler = kwargs.pop("file_handler", None)
    file_handler_kwargs = kwargs.pop("file_handler_kwargs", None)

    paths = sorted(resources, key=sort_key)

    max_files = _utils.get_config_default(
        "max_files", max_files, total_value=len(paths), suffix="paths"
    )
    if max_files > 0:
        paths = paths[:max_files]
    if not paths:
        raise ValueError("No paths were provided to serialize.")

    # find a file handler
    extension = Path(paths[0]).suffix
    file_handler_meta = file_handlers.get(extension, {})
    file_handler_import_path = file_handler_meta.get("import_path")
    file_handler_concat_path = file_handler_meta.get("concat_path")
    if file_handler is None and file_handler_import_path is not None:
        file_handler = _utils.import_function(file_handler_import_path)

    # or simply return image paths
    if file_handler is None:
        return Serialized(paths, renderer, renderer_iterables, renderer_kwargs, kwargs)

    # read as objects
    handler_kwargs = file_handler_kwargs or {}
    if file_handler_concat_path is not None:
        # Imported here so it is always bound when needed, including when the
        # caller supplied its own file_handler.
        file_handler_concat = _utils.import_function(file_handler_concat_path)
        resources = file_handler_concat(
            file_handler(path, **handler_kwargs) for path in paths
        )
    else:
        resources = file_handler(paths, **handler_kwargs)
    return serialize_appropriately(
        stream_cls,
        resources,
        renderer,
        renderer_iterables,
        renderer_kwargs,
        **kwargs,
    )
575 |
576 |
def serialize_appropriately(
    stream_cls,
    resources: Any,
    renderer: Callable,
    renderer_iterables: list[Any],
    renderer_kwargs: dict[str, Any],
    **kwargs,
) -> Serialized:
    """
    Automatically select the appropriate serialization method based on the type of resources.

    Lists, tuples and generators are returned unchanged as frames. Any other
    object is handled in these steps:
    1. Dispatch to the serializer chosen by ``_select_obj_handler``.
    2. Truncate the resulting frames (and renderer iterables) to the resolved
       ``max_frames``, which is also stored in ``kwargs["max_frames"]``.
    3. Remove entries that name the handler's own parameters from the
       renderer keyword arguments and the remaining keyword arguments.

    Args:
        stream_cls: The class reference used for logging and utility functions.
        resources: The resources to be serialized, which can be of any type.
        renderer: The rendering function to use on the resources.
        renderer_iterables: Additional iterable arguments to pass to the renderer.
        renderer_kwargs: Additional keyword arguments to pass to the renderer.
        **kwargs: Additional keyword arguments for further customization.

    Returns:
        A Serialized object containing the serialized resources.

    Raises:
        ValueError: If no serializer supports ``resources``.
    """
    if isinstance(resources, (list, tuple)) or isgenerator(resources):
        return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)

    obj_handler = _select_obj_handler(resources)
    _utils.validate_renderer_iterables(resources, renderer_iterables)

    serialized = obj_handler(
        stream_cls,
        resources,
        renderer=renderer,
        renderer_iterables=renderer_iterables,
        renderer_kwargs=renderer_kwargs,
        **kwargs,
    )
    resources = serialized.resources
    renderer = serialized.renderer
    renderer_iterables = serialized.renderer_iterables
    renderer_kwargs = serialized.renderer_kwargs
    kwargs = serialized.kwargs
    max_frames = _utils.get_max_frames(len(resources), kwargs.get("max_frames"))
    kwargs["max_frames"] = max_frames
    resources, renderer_iterables = _utils.subset_resources_renderer_iterables(
        resources, renderer_iterables, max_frames
    )
    # serialize_paths can pass renderer_kwargs through as None; pop_kwargs needs a dict.
    if renderer_kwargs is not None:
        _utils.pop_kwargs(obj_handler, renderer_kwargs)
    _utils.pop_kwargs(obj_handler, kwargs)
    return Serialized(resources, renderer, renderer_iterables, renderer_kwargs, kwargs)
624 |
--------------------------------------------------------------------------------