├── tests ├── __init__.py ├── test_version.py ├── utils.py ├── example.py ├── test_deltalake.py ├── test_polars_parquet.py ├── conftest.py ├── test_upath_io_managers_lazy.py ├── test_polars_delta.py └── test_upath_io_managers.py ├── dagster_polars ├── py.typed ├── version.py ├── constants.py ├── io_managers │ ├── __init__.py │ ├── utils.py │ ├── bigquery.py │ ├── parquet.py │ ├── delta.py │ └── base.py ├── types.py └── __init__.py ├── README.md ├── tox.ini ├── .pre-commit-config.yaml ├── pyproject.toml ├── docs └── examples.md ├── .gitignore └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dagster_polars/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dagster_polars/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1!0+dev" 2 | -------------------------------------------------------------------------------- /dagster_polars/constants.py: -------------------------------------------------------------------------------- 1 | DAGSTER_POLARS_STORAGE_METADATA_KEY = "dagster_polars_metadata" 2 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | from dagster_polars.version import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Project moved 2 | 3 | `dagster-polars` integration has been moved to [dagster-io/community-integrations](https://github.com/dagster-io/community-integrations/tree/main/libraries/dagster-polars). Please create `dagster-polars` related issues there and tag `@danielgafni`. 4 | 5 | # Documentation 6 | 7 | API documentation can be found [here](https://docs.dagster.io/_apidocs/libraries/dagster-polars). 
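A minimal usage sketch for orientation, mirroring `tests/example.py` further down in this repository (the base directory is illustrative):

```python
import polars as pl
from dagster import Definitions, asset

from dagster_polars import PolarsParquetIOManager


# the IO manager persists the returned DataFrame as a Parquet file under base_dir
@asset(io_manager_def=PolarsParquetIOManager(base_dir="/tmp/dagster"))
def my_parquet_asset() -> pl.DataFrame:
    return pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]})


definitions = Definitions(assets=[my_parquet_asset])
```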
8 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from dagster import ExecuteInProcessResult 2 | 3 | 4 | def get_saved_path(result: ExecuteInProcessResult, asset_name: str) -> str: 5 | path = ( 6 | list(filter(lambda evt: evt.is_handled_output, result.events_for_node(asset_name)))[0] # noqa: RUF015 7 | .event_specific_data.metadata["path"] # type: ignore 8 | .value 9 | ) 10 | assert isinstance(path, str) 11 | return path 12 | -------------------------------------------------------------------------------- /tests/example.py: -------------------------------------------------------------------------------- 1 | # I'm just using it to view the description in Dagit 2 | 3 | import polars as pl 4 | from dagster import Definitions, asset 5 | 6 | from dagster_polars import PolarsDeltaIOManager, PolarsParquetIOManager 7 | 8 | 9 | @asset(io_manager_def=PolarsParquetIOManager(base_dir="/tmp/dagster")) 10 | def my_parquet_asset() -> pl.DataFrame: 11 | return pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) 12 | 13 | 14 | @asset(io_manager_def=PolarsDeltaIOManager(base_dir="/tmp/dagster")) 15 | def my_delta_asset() -> pl.DataFrame: 16 | return pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) 17 | 18 | 19 | definitions = Definitions(assets=[my_parquet_asset, my_delta_asset]) 20 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/__init__.py: -------------------------------------------------------------------------------- 1 | from dagster_polars.io_managers.base import BasePolarsUPathIOManager 2 | from dagster_polars.io_managers.parquet import PolarsParquetIOManager 3 | 4 | __all__ = [ 5 | "PolarsParquetIOManager", 6 | "BasePolarsUPathIOManager", 7 | ] 8 | 9 | 10 | try: 11 | # provided by dagster-polars[delta] 12 | from dagster_polars.io_managers.delta import DeltaWriteMode, PolarsDeltaIOManager # noqa 13 | 14 | __all__.extend(["DeltaWriteMode", "PolarsDeltaIOManager"]) 15 | except ImportError: 16 | pass 17 | 18 | 19 | try: 20 | # provided by dagster-polars[bigquery] 21 | from dagster_polars.io_managers.bigquery import PolarsBigQueryIOManager, PolarsBigQueryTypeHandler # noqa 22 | 23 | __all__.extend(["PolarsBigQueryIOManager", "PolarsBigQueryTypeHandler"]) 24 | except ImportError: 25 | pass 26 | -------------------------------------------------------------------------------- /dagster_polars/types.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import Any, Dict, Tuple 3 | 4 | if sys.version_info < (3, 10): 5 | from typing_extensions import TypeAlias 6 | else: 7 | from typing import TypeAlias 8 | 9 | import polars as pl 10 | 11 | StorageMetadata: TypeAlias = Dict[str, Any] 12 | DataFrameWithMetadata: TypeAlias = Tuple[pl.DataFrame, StorageMetadata] 13 | LazyFrameWithMetadata: TypeAlias = Tuple[pl.LazyFrame, StorageMetadata] 14 | DataFramePartitions: TypeAlias = Dict[str, pl.DataFrame] 15 | DataFramePartitionsWithMetadata: TypeAlias = Dict[str, DataFrameWithMetadata] 16 | LazyFramePartitions: TypeAlias = Dict[str, pl.LazyFrame] 17 | LazyFramePartitionsWithMetadata: TypeAlias = Dict[str, LazyFrameWithMetadata] 18 | 19 | __all__ = [ 20 | "StorageMetadata", 21 | "DataFrameWithMetadata", 22 | "LazyFrameWithMetadata", 23 | "DataFramePartitions", 24 | "DataFramePartitionsWithMetadata", 25 | "LazyFramePartitions", 26 | 
"LazyFramePartitionsWithMetadata", 27 | ] 28 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | ;multiple polars versions testing doesn't work with poetry 2 | ;this is done in CI 3 | 4 | ;[tox] 5 | ;min_version = 4.0 6 | ;isolated_build = True 7 | ;deps = 8 | ; polars0.15: polars >=0.15, <0.16 9 | ; polars0.16: polars >=0.16, <0.17 10 | ; polars0.17: polars >=0.17, <0.18 11 | ; polars0.18: polars >=0.18, <0.19 12 | ;env_list = 13 | ; py38-polars{0.15,0.16,0.17,0.18} 14 | ; py39-polars{0.15,0.16,0.17,0.18} 15 | ; py310-polars{0.15,0.16,0.17,0.18} 16 | ; py311-polars{0.15,0.16,0.17,0.18} 17 | ; 18 | ;[testenv] 19 | ;allowlist_externals = poetry 20 | ;;skip_install = true 21 | ;commands_pre = 22 | ; poetry install --all-extras --sync 23 | ;commands = 24 | ; poetry run pytest tests/ --import-mode importlib 25 | ; 26 | ;[testenv:type] 27 | ;allowlist_externals = poetry 28 | ;;skip_install = true 29 | ;commands_pre = 30 | ; poetry install --all-extras --sync 31 | ;commands = pyright . 32 | ; 33 | ;[gh-actions] 34 | ;python = 35 | ; 3.8: py38-polars{0.15,0.16,0.17,0.18} 36 | ; 3.9: py39-polars{0.15,0.16,0.17,0.18} 37 | ; 3.10: py310-polars{0.15,0.16,0.17,0.18} 38 | ; 3.11: py311-polars{0.15,0.16,0.17,0.18} 39 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-added-large-files 6 | args: [ '--maxkb=10000' ] 7 | - id: check-json 8 | - id: check-toml 9 | - id: check-yaml 10 | - id: forbid-new-submodules 11 | - id: mixed-line-ending 12 | args: [ '--fix=lf' ] 13 | - id: trailing-whitespace 14 | - id: check-docstring-first 15 | - id: check-merge-conflict 16 | - id: detect-private-key 17 | # - repo: https://github.com/asottile/pyupgrade 18 | # rev: v3.4.0 19 | # hooks: 20 | # - id: pyupgrade 21 | # entry: pyupgrade --py38-plus 22 | - repo: local 23 | hooks: 24 | - id: ruff 25 | name: ruff 26 | entry: ruff --fix . 27 | language: system 28 | pass_filenames: false 29 | - id: format 30 | name: format 31 | entry: ruff format . 32 | language: system 33 | pass_filenames: false 34 | - id: black-docs 35 | name: black-docs 36 | entry: blacken-docs 37 | language: system 38 | pass_filenames: true 39 | files: '\.md$' 40 | 41 | - id: pyright 42 | name: pyright 43 | entry: pyright . 
44 | language: system 45 | pass_filenames: false 46 | -------------------------------------------------------------------------------- /dagster_polars/__init__.py: -------------------------------------------------------------------------------- 1 | from dagster_polars.io_managers.base import BasePolarsUPathIOManager 2 | from dagster_polars.io_managers.parquet import PolarsParquetIOManager 3 | from dagster_polars.types import ( 4 | DataFramePartitions, 5 | DataFramePartitionsWithMetadata, 6 | DataFrameWithMetadata, 7 | LazyFramePartitions, 8 | LazyFramePartitionsWithMetadata, 9 | LazyFrameWithMetadata, 10 | StorageMetadata, 11 | ) 12 | from dagster_polars.version import __version__ 13 | 14 | __all__ = [ 15 | "PolarsParquetIOManager", 16 | "BasePolarsUPathIOManager", 17 | "StorageMetadata", 18 | "DataFrameWithMetadata", 19 | "LazyFrameWithMetadata", 20 | "DataFramePartitions", 21 | "LazyFramePartitions", 22 | "DataFramePartitionsWithMetadata", 23 | "LazyFramePartitionsWithMetadata", 24 | "__version__", 25 | ] 26 | 27 | 28 | try: 29 | # provided by dagster-polars[delta] 30 | from dagster_polars.io_managers.delta import DeltaWriteMode, PolarsDeltaIOManager # noqa 31 | 32 | __all__.extend(["DeltaWriteMode", "PolarsDeltaIOManager"]) # noqa 33 | except ImportError: 34 | pass 35 | 36 | 37 | try: 38 | # provided by dagster-polars[bigquery] 39 | from dagster_polars.io_managers.bigquery import ( 40 | PolarsBigQueryIOManager, # noqa 41 | PolarsBigQueryTypeHandler, # noqa 42 | ) 43 | 44 | __all__.extend(["PolarsBigQueryIOManager", "PolarsBigQueryTypeHandler"]) 45 | except ImportError: 46 | pass 47 | -------------------------------------------------------------------------------- /tests/test_deltalake.py: -------------------------------------------------------------------------------- 1 | # seems like the problems with reading/writing delta tables are only happening when 2 | # doing this very fast, i.e. in a test. 
3 | # commenting this for now 4 | 5 | import shutil 6 | 7 | import polars as pl 8 | import polars.testing as pl_testing 9 | from _pytest.tmpdir import TempPathFactory 10 | from hypothesis import given, settings 11 | from polars.testing.parametric import dataframes 12 | 13 | 14 | @given( 15 | df=dataframes( 16 | excluded_dtypes=[ 17 | pl.Categorical, # Unsupported type in delta protocol 18 | pl.Duration, # Unsupported type in delta protocol 19 | pl.Time, # Unsupported type in delta protocol 20 | pl.UInt8, # These get casted to int in deltalake whenever it fits 21 | pl.UInt16, # These get casted to int in deltalake whenever it fits 22 | pl.UInt32, # These get casted to int in deltalake whenever it fits 23 | pl.UInt64, # These get casted to int in deltalake whenever it fits 24 | pl.Datetime("ns", None), # These get casted to datetime('ms') 25 | ], 26 | min_size=5, 27 | allow_infinities=False, 28 | ) 29 | ) 30 | @settings(max_examples=20, deadline=None) 31 | def test_polars_delta_io(df: pl.DataFrame, tmp_path_factory: TempPathFactory): 32 | tmp_path = tmp_path_factory.mktemp("data") 33 | df.write_delta(str(tmp_path), delta_write_options={"engine": "rust"}) 34 | pl_testing.assert_frame_equal(df.with_columns(), pl.read_delta(str(tmp_path))) 35 | shutil.rmtree(str(tmp_path)) # cleanup manually because of hypothesis 36 | -------------------------------------------------------------------------------- /tests/test_polars_parquet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import polars as pl 4 | import polars.testing as pl_testing 5 | from dagster import asset, materialize 6 | from hypothesis import given, settings 7 | from polars.testing.parametric import dataframes 8 | 9 | from dagster_polars import PolarsParquetIOManager 10 | from tests.utils import get_saved_path 11 | 12 | 13 | # allowed_dtypes=[pl.List(inner) for inner in 14 | # list(pl.TEMPORAL_DTYPES | pl.FLOAT_DTYPES | pl.INTEGER_DTYPES) + [pl.Boolean, pl.Utf8]] 15 | @given(df=dataframes(excluded_dtypes=[pl.Categorical], min_size=5)) 16 | @settings(max_examples=100, deadline=None) 17 | def test_polars_parquet_io_manager_read_write( 18 | session_polars_parquet_io_manager: PolarsParquetIOManager, df: pl.DataFrame 19 | ): 20 | @asset(io_manager_def=session_polars_parquet_io_manager) 21 | def upstream() -> pl.DataFrame: 22 | return df 23 | 24 | @asset(io_manager_def=session_polars_parquet_io_manager) 25 | def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: 26 | return upstream.collect(streaming=True) 27 | 28 | result = materialize( 29 | [upstream, downstream], 30 | ) 31 | 32 | saved_path = get_saved_path(result, "upstream") 33 | pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) 34 | os.remove(saved_path) # cleanup manually because of hypothesis 35 | 36 | 37 | # allowed_dtypes=[pl.List(inner) for inner in 38 | # list(pl.TEMPORAL_DTYPES | pl.FLOAT_DTYPES | pl.INTEGER_DTYPES) + [pl.Boolean, pl.Utf8]] 39 | @given(df=dataframes(excluded_dtypes=[pl.Categorical], min_size=5)) 40 | @settings(max_examples=100, deadline=None) 41 | def test_polars_parquet_io_manager_read_write_full_lazy( 42 | session_polars_parquet_io_manager: PolarsParquetIOManager, df: pl.DataFrame 43 | ): 44 | @asset(io_manager_def=session_polars_parquet_io_manager) 45 | def upstream() -> pl.DataFrame: 46 | return df 47 | 48 | @asset(io_manager_def=session_polars_parquet_io_manager) 49 | def downstream(upstream: pl.LazyFrame) -> pl.LazyFrame: 50 | return upstream 51 | 52 | 
@asset(io_manager_def=session_polars_parquet_io_manager) 53 | def downstream2(downstream: pl.LazyFrame) -> pl.LazyFrame: 54 | return downstream 55 | 56 | result = materialize( 57 | [upstream, downstream, downstream2], 58 | ) 59 | 60 | saved_path = get_saved_path(result, "upstream") 61 | pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) 62 | os.remove(saved_path) # cleanup manually because of hypothesis 63 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "dagster-polars" 3 | version = "0.0.0" 4 | description = "Dagster integration library for Polars" 5 | authors = [ 6 | "Daniel Gafni " 7 | ] 8 | readme = "README.md" 9 | packages = [{include = "dagster_polars"}] 10 | repository = "https://github.com/danielgafni/dagster-polars" 11 | keywords = [ 12 | "dagster", 13 | "polars", 14 | "ETL", 15 | "dataframe", 16 | ] 17 | classifiers = [ 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "License :: OSI Approved :: Apache Software License", 23 | "Operating System :: OS Independent", 24 | "Topic :: Software Development :: Libraries :: Python Modules", 25 | ] 26 | license = "Apache-2.0" 27 | 28 | [tool.poetry.dependencies] 29 | python = "^3.8" 30 | dagster = "^1.5.1" 31 | polars = ">=0.20.0" 32 | pyarrow = ">=8.0.0" 33 | typing-extensions = "^4.7.1" 34 | 35 | deltalake = { version = ">=0.15.0", optional = true } 36 | dagster-gcp = { version = ">=0.19.5", optional = true } 37 | universal-pathlib = "^0.1.4" 38 | 39 | # pinned due to Dagster not working with Pendulum < 3.0.0 40 | # TODO: remove once Dagster supports Pendulum 3.0.0 41 | pendulum = "<3.0.0" 42 | 43 | [tool.poetry.extras] 44 | gcp = ["dagster-gcp"] 45 | deltalake = ["deltalake"] 46 | 47 | 48 | [tool.poetry.group.dev.dependencies] 49 | hypothesis = "^6.89.0" 50 | pytest = "^7.3.1" 51 | deepdiff = "^6.3.0" 52 | ruff = "^0.1.3" 53 | pyright = "^1.1.313" 54 | tox = "^4.6.0" 55 | tox-gh = "^1.0.0" 56 | pre-commit = "^3.3.2" 57 | dagit = "^1.3.9" 58 | black = "^23.3.0" 59 | pytest-cases = "^3.6.14" 60 | blacken-docs = "^1.16.0" 61 | pytest-rerunfailures = "^12.0" 62 | 63 | [build-system] 64 | requires = ["poetry-core"] 65 | build-backend = "poetry.core.masonry.api" 66 | 67 | [tool.poetry-dynamic-versioning] 68 | enable = true 69 | strict = false 70 | vcs = "git" 71 | style = "pep440" 72 | dirty = true 73 | bump = true 74 | metadata = false 75 | 76 | [tool.poetry-dynamic-versioning.substitution] 77 | files = [ 78 | "pyproject.toml", 79 | "dagster_polars/version.py" 80 | ] 81 | 82 | [tool.pytest.ini_options] 83 | log_cli = true 84 | log_level = "INFO" 85 | 86 | [tool.black] 87 | line-length = 120 88 | target-version = ['py39'] 89 | include = '\.pyi?$' 90 | exclude = ''' 91 | /( 92 | \.eggs 93 | | \.git 94 | | \.hg 95 | | \.mypy_cache 96 | | \.pytest_cache 97 | | \.ruff_cache 98 | | \.venv 99 | | build 100 | | dist 101 | )/ 102 | ''' 103 | 104 | [tool.ruff] 105 | extend-select = ["I"] 106 | line-length = 120 107 | src = [ 108 | "dagster_polars", 109 | "tests" 110 | ] 111 | exclude = [ 112 | ".bzr", 113 | ".direnv", 114 | ".eggs", 115 | ".git", 116 | ".hg", 117 | ".mypy_cache", 118 | ".nox", 119 | ".pants.d", 120 | ".ruff_cache", 121 | ".svn", 122 | ".tox", 123 | ".venv", 124 | "__pypackages__", 125 | "_build", 126 | "buck-out", 127 | 
"build", 128 | "dist", 129 | "node_modules", 130 | "venv", 131 | ] 132 | [tool.ruff.isort] 133 | known-first-party = ["dagster_polars", "tests"] 134 | 135 | [tool.pyright] 136 | reportPropertyTypeMismatch = true 137 | reportImportCycles = true 138 | reportWildcardImportFromLibrary = true 139 | reportUntypedFunctionDecorator = true 140 | reportUntypedClassDecorator = true 141 | reportUnnecessaryTypeIgnoreComment = "warning" 142 | 143 | include = [ 144 | "dagster_polars", 145 | "tests" 146 | ] 147 | exclude = [ 148 | ".bzr", 149 | ".direnv", 150 | ".eggs", 151 | ".git", 152 | ".hg", 153 | ".mypy_cache", 154 | ".nox", 155 | ".pants.d", 156 | ".ruff_cache", 157 | ".svn", 158 | ".tox", 159 | ".venv", 160 | "__pypackages__", 161 | "_build", 162 | "buck-out", 163 | "build", 164 | "dist", 165 | "node_modules", 166 | "venv", 167 | ] 168 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from datetime import date, datetime, timedelta 4 | from typing import Tuple, Type 5 | 6 | import dagster 7 | import polars as pl 8 | import pytest 9 | import pytest_cases 10 | from _pytest.tmpdir import TempPathFactory 11 | from dagster import DagsterInstance 12 | 13 | from dagster_polars import BasePolarsUPathIOManager, PolarsDeltaIOManager, PolarsParquetIOManager 14 | 15 | logging.getLogger("alembic.runtime.migration").setLevel(logging.WARNING) 16 | warnings.filterwarnings("ignore", category=dagster.ExperimentalWarning) 17 | 18 | 19 | @pytest.fixture 20 | def dagster_instance(tmp_path_factory: TempPathFactory) -> DagsterInstance: 21 | return DagsterInstance.ephemeral(tempdir=str(tmp_path_factory.mktemp("dagster_home"))) 22 | 23 | 24 | @pytest.fixture 25 | def polars_parquet_io_manager(dagster_instance: DagsterInstance) -> PolarsParquetIOManager: 26 | return PolarsParquetIOManager(base_dir=dagster_instance.storage_directory()) 27 | 28 | 29 | @pytest.fixture 30 | def polars_delta_io_manager(dagster_instance: DagsterInstance) -> PolarsDeltaIOManager: 31 | return PolarsDeltaIOManager(base_dir=dagster_instance.storage_directory()) 32 | 33 | 34 | @pytest.fixture(scope="session") 35 | def session_scoped_dagster_instance(tmp_path_factory: TempPathFactory) -> DagsterInstance: 36 | return DagsterInstance.ephemeral(tempdir=str(tmp_path_factory.mktemp("dagster_home_session"))) 37 | 38 | 39 | @pytest.fixture(scope="session") 40 | def session_polars_parquet_io_manager( 41 | session_scoped_dagster_instance: DagsterInstance, 42 | ) -> PolarsParquetIOManager: 43 | return PolarsParquetIOManager( 44 | base_dir=session_scoped_dagster_instance.storage_directory() 45 | ) # to use with hypothesis 46 | 47 | 48 | @pytest.fixture(scope="session") 49 | def session_polars_delta_io_manager( 50 | session_scoped_dagster_instance: DagsterInstance, 51 | ) -> PolarsDeltaIOManager: 52 | return PolarsDeltaIOManager(base_dir=session_scoped_dagster_instance.storage_directory()) # to use with hypothesis 53 | 54 | 55 | main_data = { 56 | "1": [0, 1, None], 57 | "2": [0.0, 1.0, None], 58 | "3": ["a", "b", None], 59 | "4": [[0, 1], [2, 3], None], 60 | "6": [{"a": 0}, {"a": 1}, None], 61 | "7": [datetime(2022, 1, 1), datetime(2022, 1, 2), None], 62 | "8": [date(2022, 1, 1), date(2022, 1, 2), None], 63 | } 64 | 65 | _df_for_delta = pl.DataFrame(main_data) 66 | 67 | _lazy_df_for_delta = pl.LazyFrame(main_data) 68 | 69 | parquet_data = main_data 70 | parquet_data["9"] = 
[timedelta(hours=1), timedelta(hours=2), None] 71 | 72 | _df_for_parquet = pl.DataFrame(parquet_data) 73 | _lazy_df_for_parquet = pl.LazyFrame(parquet_data) 74 | 75 | 76 | @pytest_cases.fixture(scope="session") 77 | def df_for_parquet() -> pl.DataFrame: 78 | return _df_for_parquet 79 | 80 | 81 | @pytest_cases.fixture(scope="session") 82 | def df_for_delta() -> pl.DataFrame: 83 | return _df_for_delta 84 | 85 | 86 | @pytest_cases.fixture(scope="session") 87 | def lazy_df_for_parquet() -> pl.LazyFrame: 88 | return _lazy_df_for_parquet 89 | 90 | 91 | @pytest_cases.fixture(scope="session") 92 | def lazy_df_for_delta() -> pl.LazyFrame: 93 | return _lazy_df_for_delta 94 | 95 | 96 | @pytest_cases.fixture 97 | @pytest_cases.parametrize( 98 | "io_manager,frame", 99 | [(PolarsParquetIOManager, _df_for_parquet), (PolarsDeltaIOManager, _df_for_delta)], 100 | ) 101 | def io_manager_and_df( # to use without hypothesis 102 | io_manager: Type[BasePolarsUPathIOManager], 103 | frame: pl.DataFrame, 104 | dagster_instance: DagsterInstance, 105 | ) -> Tuple[BasePolarsUPathIOManager, pl.DataFrame]: 106 | return io_manager(base_dir=dagster_instance.storage_directory()), frame 107 | 108 | 109 | @pytest_cases.fixture 110 | @pytest_cases.parametrize( 111 | "io_manager,frame", 112 | [(PolarsParquetIOManager, _lazy_df_for_parquet), (PolarsDeltaIOManager, _lazy_df_for_delta)], 113 | ) 114 | def io_manager_and_lazy_df( # to use without hypothesis 115 | io_manager: Type[BasePolarsUPathIOManager], 116 | frame: pl.LazyFrame, 117 | dagster_instance: DagsterInstance, 118 | ) -> Tuple[BasePolarsUPathIOManager, pl.LazyFrame]: 119 | return io_manager(base_dir=dagster_instance.storage_directory()), frame 120 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | 2 | # Examples 3 | 4 | ## Providing IOManagers to `Definitions` 5 | 6 | ```python 7 | from dagster import Definitions 8 | from dagster_polars import PolarsDeltaIOManager, PolarsParquetIOManager 9 | 10 | base_dir = ( 11 | "/remote/or/local/path" # s3://my-bucket/... or gs://my-bucket/... also works! 12 | ) 13 | 14 | definitions = Definitions( 15 | resources={ 16 | "polars_parquet_io_manager": PolarsParquetIOManager(base_dir=base_dir), 17 | "polars_delta_io_manager": PolarsDeltaIOManager(base_dir=base_dir), 18 | } 19 | ) 20 | ``` 21 | 22 | ## Reading specific columns 23 | ```python 24 | import polars as pl 25 | from dagster import AssetIn, asset 26 | 27 | 28 | @asset( 29 | io_manager_key="polars_parquet_io_manager", 30 | ins={ 31 | "upstream": AssetIn(metadata={"columns": ["a"]}) 32 | }, # explicitly specify which columns to load 33 | ) 34 | def downstream(upstream: pl.DataFrame): 35 | assert upstream.columns == ["a"] 36 | ``` 37 | 38 | ## Reading `LazyFrame` 39 | 40 | ```python 41 | import polars as pl 42 | from dagster import asset 43 | 44 | 45 | @asset( 46 | io_manager_key="polars_parquet_io_manager", 47 | ) 48 | def downstream( 49 | upstream: pl.LazyFrame, # the type annotation controls whether we load an eager or lazy DataFrame 50 | ) -> pl.DataFrame: 51 | df = ... 
# some lazy operations with `upstream`
52 |     return df.collect()
53 | ```
54 |
55 | ## Reading multiple partitions
56 | ```python
57 | import polars as pl
58 | from dagster import asset, StaticPartitionsDefinition
59 | from dagster_polars import DataFramePartitions, LazyFramePartitions
60 |
61 |
62 | @asset(
63 |     partitions_def=StaticPartitionsDefinition(["a", "b"]),
64 |     io_manager_key="polars_parquet_io_manager",
65 | )
66 | def upstream() -> pl.DataFrame:
67 |     return pl.DataFrame(...)
68 |
69 |
70 | @asset(
71 |     io_manager_key="polars_parquet_io_manager",
72 | )
73 | def downstream_eager(upstream: DataFramePartitions):
74 |     assert isinstance(upstream, dict)
75 |     assert isinstance(upstream["a"], pl.DataFrame)
76 |     assert isinstance(upstream["b"], pl.DataFrame)
77 |
78 |
79 | @asset(
80 |     io_manager_key="polars_parquet_io_manager",
81 | )
82 | def downstream_lazy(upstream: LazyFramePartitions):
83 |     assert isinstance(upstream, dict)
84 |     assert isinstance(upstream["a"], pl.LazyFrame)
85 |     assert isinstance(upstream["b"], pl.LazyFrame)
86 | ```
87 |
88 | ## Skipping missing input/output
89 |
90 | ```python
91 | from typing import Optional
92 |
93 | import polars as pl
94 | from dagster import asset
95 |
96 |
97 | @asset(
98 |     io_manager_key="polars_parquet_io_manager",
99 | )
100 | def downstream(upstream: Optional[pl.DataFrame]) -> Optional[pl.DataFrame]:
101 |     maybe_df: Optional[pl.DataFrame] = ...
102 |     return maybe_df
103 | ```
104 |
105 | ## Reading/writing custom metadata from/to storage
106 |
107 | It's possible to write any custom metadata dict into storage for some IOManagers. For example, `PolarsParquetIOManager` supports this feature.
108 |
109 |
110 | ```python
111 | import polars as pl
112 | from dagster import asset
113 | from dagster_polars import DataFrameWithMetadata
114 |
115 |
116 | @asset(
117 |     io_manager_key="polars_parquet_io_manager",
118 | )
119 | def upstream() -> DataFrameWithMetadata:
120 |     return pl.DataFrame(...), {"my_custom_metadata": "my_custom_value"}
121 |
122 |
123 | @asset(
124 |     io_manager_key="polars_parquet_io_manager",
125 | )
126 | def downstream(upstream: DataFrameWithMetadata):
127 |     df, metadata = upstream
128 |     assert metadata["my_custom_metadata"] == "my_custom_value"
129 | ```
130 |
131 | The metadata can be retrieved from the materialized asset outside of Dagster runtime.
132 |
133 | This can be done either by importing the `Definitions` object and referring to the asset by its key:
134 |
135 | ```python
136 | from dagster import DagsterInstance
137 | from dagster_polars import DataFrameWithMetadata
138 |
139 | from your_definitions import definitions  # noqa
140 |
141 | with DagsterInstance.ephemeral() as instance:
142 |     df, metadata = definitions.load_asset_value(
143 |         ["asset", "key"], python_type=DataFrameWithMetadata, instance=instance
144 |     )
145 | ```
146 |
147 | or directly from the serialized asset (depending on the IOManager metadata saving implementation).
For example, with `PolarsParquetIOManager`: 148 | 149 | ```python 150 | from dagster_polars import PolarsParquetIOManager 151 | from upath import UPath 152 | 153 | metadata = PolarsParquetIOManager.read_parquet_metadata(UPath("/asset/key.parquet")) 154 | ``` 155 | 156 | 157 | ## Append to DeltaLake table 158 | ```python 159 | import polars as pl 160 | from dagster import asset 161 | 162 | 163 | @asset(io_manager_key="polars_parquet_io_manager") 164 | def upstream() -> pl.DataFrame: 165 | return pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) 166 | 167 | 168 | @asset( 169 | io_manager_key="polars_delta_io_manager", 170 | metadata={ 171 | "mode": "append" # append to the existing table instead of overwriting it 172 | }, 173 | ) 174 | def downstream_append(upstream: pl.DataFrame) -> pl.DataFrame: 175 | return upstream 176 | ``` 177 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .idea 3 | junit.xml 4 | .ruff_cache 5 | .dagster_home 6 | 7 | ### JetBrains template 8 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 9 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 10 | 11 | # User-specific stuff 12 | .idea/**/workspace.xml 13 | .idea/**/tasks.xml 14 | .idea/**/usage.statistics.xml 15 | .idea/**/dictionaries 16 | .idea/**/shelf 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 
38 | # .idea/artifacts 39 | # .idea/compiler.xml 40 | # .idea/jarRepositories.xml 41 | # .idea/modules.xml 42 | # .idea/*.iml 43 | # .idea/modules 44 | # *.iml 45 | # *.ipr 46 | 47 | # CMake 48 | cmake-build-*/ 49 | 50 | # Mongo Explorer plugin 51 | .idea/**/mongoSettings.xml 52 | 53 | # File-based project format 54 | *.iws 55 | 56 | # IntelliJ 57 | out/ 58 | 59 | # mpeltonen/sbt-idea plugin 60 | .idea_modules/ 61 | 62 | # JIRA plugin 63 | atlassian-ide-plugin.xml 64 | 65 | # Cursive Clojure plugin 66 | .idea/replstate.xml 67 | 68 | # Crashlytics plugin (for Android Studio and IntelliJ) 69 | com_crashlytics_export_strings.xml 70 | crashlytics.properties 71 | crashlytics-build.properties 72 | fabric.properties 73 | 74 | # Editor-based Rest Client 75 | .idea/httpRequests 76 | 77 | # Android studio 3.1+ serialized cache file 78 | .idea/caches/build_file_checksums.ser 79 | 80 | ### macOS template 81 | # General 82 | .DS_Store 83 | .AppleDouble 84 | .LSOverride 85 | 86 | # Icon must end with two \r 87 | Icon 88 | 89 | # Thumbnails 90 | ._* 91 | 92 | # Files that might appear in the root of a volume 93 | .DocumentRevisions-V100 94 | .fseventsd 95 | .Spotlight-V100 96 | .TemporaryItems 97 | .Trashes 98 | .VolumeIcon.icns 99 | .com.apple.timemachine.donotpresent 100 | 101 | # Directories potentially created on remote AFP share 102 | .AppleDB 103 | .AppleDesktop 104 | Network Trash Folder 105 | Temporary Items 106 | .apdisk 107 | 108 | ### Python template 109 | # Byte-compiled / optimized / DLL files 110 | __pycache__/ 111 | *.py[cod] 112 | *$py.class 113 | 114 | # C extensions 115 | *.so 116 | 117 | # Distribution / packaging 118 | .Python 119 | build/ 120 | develop-eggs/ 121 | dist/ 122 | downloads/ 123 | eggs/ 124 | .eggs/ 125 | lib/ 126 | lib64/ 127 | parts/ 128 | sdist/ 129 | var/ 130 | wheels/ 131 | share/python-wheels/ 132 | *.egg-info/ 133 | .installed.cfg 134 | *.egg 135 | MANIFEST 136 | 137 | # PyInstaller 138 | # Usually these files are written by a python script from a template 139 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 140 | *.manifest 141 | *.spec 142 | 143 | # Installer logs 144 | pip-log.txt 145 | pip-delete-this-directory.txt 146 | 147 | # Unit test / coverage reports 148 | htmlcov/ 149 | .tox/ 150 | .nox/ 151 | .coverage 152 | .coverage.* 153 | .cache 154 | nosetests.xml 155 | coverage.xml 156 | *.cover 157 | *.py,cover 158 | .hypothesis/ 159 | .pytest_cache/ 160 | cover/ 161 | 162 | # Translations 163 | *.mo 164 | *.pot 165 | 166 | # Django stuff: 167 | *.log 168 | local_settings.py 169 | db.sqlite3 170 | db.sqlite3-journal 171 | 172 | # Flask stuff: 173 | instance/ 174 | .webassets-cache 175 | 176 | # Scrapy stuff: 177 | .scrapy 178 | 179 | # Sphinx documentation 180 | docs/_build/ 181 | 182 | # PyBuilder 183 | .pybuilder/ 184 | target/ 185 | 186 | # Jupyter Notebook 187 | .ipynb_checkpoints 188 | 189 | # IPython 190 | profile_default/ 191 | ipython_config.py 192 | 193 | # pyenv 194 | # For a library or package, you might want to ignore these files since the code is 195 | # intended to run in multiple environments; otherwise, check them in: 196 | # .python-version 197 | 198 | # pipenv 199 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 200 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 201 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 202 | # install all needed dependencies. 
203 | #Pipfile.lock 204 | 205 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 206 | __pypackages__/ 207 | 208 | # Celery stuff 209 | celerybeat-schedule 210 | celerybeat.pid 211 | 212 | # SageMath parsed files 213 | *.sage.py 214 | 215 | # Environments 216 | .env 217 | .venv 218 | env/ 219 | venv/ 220 | ENV/ 221 | env.bak/ 222 | venv.bak/ 223 | 224 | # Spyder project settings 225 | .spyderproject 226 | .spyproject 227 | 228 | # Rope project settings 229 | .ropeproject 230 | 231 | # mkdocs documentation 232 | /site 233 | 234 | # mypy 235 | .mypy_cache/ 236 | .dmypy.json 237 | dmypy.json 238 | 239 | # Pyre type checker 240 | .pyre/ 241 | 242 | # pytype static type analyzer 243 | .pytype/ 244 | 245 | # Cython debug symbols 246 | cython_debug/ 247 | 248 | ### Linux template 249 | *~ 250 | 251 | # temporary files which can be created if a process still has a handle open of a deleted file 252 | .fuse_hidden* 253 | 254 | # KDE directory preferences 255 | .directory 256 | 257 | # Linux trash folder which might appear on any partition or disk 258 | .Trash-* 259 | 260 | # .nfs files are created when an open file is removed but is still being accessed 261 | .nfs* 262 | 263 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from datetime import date, datetime, time, timedelta 4 | from pprint import pformat 5 | from typing import Any, Dict, Mapping, Optional, Union 6 | 7 | import polars as pl 8 | from dagster import ( 9 | MetadataValue, 10 | OutputContext, 11 | TableColumn, 12 | TableMetadataValue, 13 | TableRecord, 14 | TableSchema, 15 | ) 16 | from packaging.version import Version 17 | 18 | POLARS_DATA_FRAME_ANNOTATIONS = [ 19 | Any, 20 | pl.DataFrame, 21 | Dict[str, pl.DataFrame], 22 | Mapping[str, pl.DataFrame], 23 | type(None), 24 | None, 25 | ] 26 | 27 | POLARS_LAZY_FRAME_ANNOTATIONS = [ 28 | pl.LazyFrame, 29 | Dict[str, pl.LazyFrame], 30 | Mapping[str, pl.LazyFrame], 31 | ] 32 | 33 | 34 | if sys.version >= "3.9": 35 | POLARS_DATA_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame]) # type: ignore # ignore needed with Python < 3.9 36 | POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame]) # type: ignore # ignore needed with Python < 3.9 37 | 38 | 39 | def cast_polars_single_value_to_dagster_table_types(val: Any): 40 | if val is None: 41 | return "" 42 | elif isinstance(val, (date, datetime, time, timedelta, bytes)): 43 | return str(val) 44 | elif isinstance(val, (list, dict)): 45 | # default=str because sometimes the object can be a list of datetimes or something like this 46 | return json.dumps(val, default=str) 47 | else: 48 | return val 49 | 50 | 51 | def get_metadata_schema( 52 | df: Union[pl.DataFrame, pl.LazyFrame], 53 | descriptions: Optional[Dict[str, str]] = None, 54 | ) -> TableSchema: 55 | """Takes the schema from a dataframe or lazyframe and converts it a Dagster TableSchema. 56 | 57 | Args: 58 | df (Union[pl.DataFrame, pl.LazyFrame]): dataframe 59 | descriptions (Optional[Dict[str, str]], optional): column descriptions. Defaults to None. 
60 |
61 |     Returns:
62 |         TableSchema: dagster TableSchema
63 |     """
64 |     descriptions = descriptions or {}
65 |     return TableSchema(
66 |         columns=[
67 |             TableColumn(name=col, type=str(pl_type), description=descriptions.get(col))
68 |             for col, pl_type in df.schema.items()
69 |         ]
70 |     )
71 |
72 |
73 | def get_table_metadata(
74 |     context: OutputContext,
75 |     df: pl.DataFrame,
76 |     schema: TableSchema,
77 |     n_rows: Optional[int] = 5,
78 |     fraction: Optional[float] = None,
79 | ) -> Optional[TableMetadataValue]:
80 |     """Takes a sample of the data from the polars DataFrame and returns it as a TableMetadataValue.
81 |     For a LazyFrame this is not possible without a potentially very costly operation.
82 |
83 |     Args:
84 |         context (OutputContext): output context
85 |         df (pl.DataFrame): polars frame
86 |         schema (TableSchema): dataframe schema
87 |         n_rows (Optional[int], optional): number of rows to sample from. Defaults to 5.
88 |         fraction (Optional[float], optional): fraction of rows to sample from. Defaults to None.
89 |
90 |     Returns:
91 |         Optional[TableMetadataValue]: sample metadata, or None if the sample could not be converted
92 |     """
93 |     assert not fraction and n_rows, "only one of n_rows and fraction should be set"
94 |     n_rows = min(n_rows, len(df))
95 |     df_sample = df.sample(n=n_rows, fraction=fraction, shuffle=True)
96 |
97 |     try:
98 |         # this can fail sometimes
99 |         # because TableRecord doesn't support all python types
100 |         df_sample_dict = df_sample.to_dicts()
101 |         table = MetadataValue.table(
102 |             records=[
103 |                 TableRecord(
104 |                     {col: cast_polars_single_value_to_dagster_table_types(df_sample_dict[i][col]) for col in df.columns}
105 |                 )
106 |                 for i in range(len(df_sample))
107 |             ],
108 |             schema=schema,
109 |         )
110 |     except TypeError as e:
111 |         context.log.error(
112 |             f"Failed to create table sample metadata."
113 | f"Reason:\n{e}\n" 114 | f"Schema:\n{df.schema}\n" 115 | f"Polars sample:\n{df_sample}\n" 116 | f"dict sample:\n{pformat(df_sample.to_dicts())}" 117 | ) 118 | return None 119 | return table 120 | 121 | 122 | def get_polars_df_stats( 123 | df: pl.DataFrame, 124 | ) -> Dict[str, Dict[str, Union[str, int, float]]]: 125 | describe = df.describe().fill_null(pl.lit("null")) 126 | # TODO(ion): replace once there is a index column selector 127 | if Version(pl.__version__) >= Version("0.20.6"): 128 | col_name = "statistic" 129 | else: 130 | col_name = "describe" 131 | return { 132 | col: {stat: describe[col][i] for i, stat in enumerate(describe[col_name].to_list())} 133 | for col in describe.columns[1:] 134 | } 135 | 136 | 137 | def get_polars_metadata(context: OutputContext, df: Union[pl.DataFrame, pl.LazyFrame]) -> Dict[str, MetadataValue]: 138 | """Retrives some metadata on polars frames: 139 | - DataFrame: stats, row_count, table or schema 140 | - LazyFrame: schema 141 | 142 | Args: 143 | context (OutputContext): context 144 | df (Union[pl.DataFrame, pl.LazyFrame]): output dataframe 145 | 146 | Returns: 147 | Dict[str, MetadataValue]: metadata about df 148 | """ 149 | assert context.metadata is not None 150 | 151 | schema = get_metadata_schema(df, descriptions=context.metadata.get("descriptions")) 152 | 153 | metadata = {} 154 | 155 | if isinstance(df, pl.DataFrame): 156 | table = get_table_metadata( 157 | context=context, 158 | df=df, 159 | schema=schema, 160 | n_rows=context.metadata.get("n_rows", 5), 161 | fraction=context.metadata.get("fraction"), 162 | ) 163 | metadata["stats"] = MetadataValue.json(get_polars_df_stats(df)) 164 | metadata["row_count"] = MetadataValue.int(df.shape[0]) 165 | else: 166 | table = None 167 | 168 | if table is not None: 169 | metadata["table"] = table 170 | else: 171 | metadata["schema"] = schema 172 | 173 | return metadata 174 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/bigquery.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Type 2 | 3 | import polars as pl 4 | from dagster import InputContext, MetadataValue, OutputContext 5 | from dagster._annotations import experimental 6 | from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice 7 | 8 | from dagster_polars.io_managers.utils import get_polars_metadata 9 | 10 | try: 11 | from dagster_gcp.bigquery.io_manager import BigQueryClient, BigQueryIOManager 12 | from google.cloud import bigquery as bigquery 13 | except ImportError as e: 14 | raise ImportError("Install 'dagster-polars[gcp]' to use BigQuery functionality") from e 15 | 16 | 17 | @experimental 18 | class PolarsBigQueryTypeHandler(DbTypeHandler[pl.DataFrame]): 19 | """Plugin for the BigQuery I/O Manager that can store and load Polars DataFrames as BigQuery tables. 20 | 21 | Examples: 22 | .. code-block:: python 23 | 24 | from dagster_gcp import BigQueryIOManager 25 | from dagster_bigquery_polars import BigQueryPolarsTypeHandler 26 | from dagster import Definitions, EnvVar 27 | 28 | class MyBigQueryIOManager(BigQueryIOManager): 29 | @staticmethod 30 | def type_handlers() -> Sequence[DbTypeHandler]: 31 | return [PolarsBigQueryTypeHandler()] 32 | 33 | @asset( 34 | key_prefix=["my_dataset"] # my_dataset will be used as the dataset in BigQuery 35 | ) 36 | def my_table() -> pd.DataFrame: # the name of the asset will be the table name 37 | ... 
38 | 39 | defs = Definitions( 40 | assets=[my_table], 41 | resources={ 42 | "io_manager": MyBigQueryIOManager(project=EnvVar("GCP_PROJECT")) 43 | } 44 | ) 45 | 46 | """ 47 | 48 | def handle_output( 49 | self, 50 | context: OutputContext, 51 | table_slice: TableSlice, 52 | obj: Optional[pl.DataFrame], 53 | connection, 54 | ): 55 | """Stores the polars DataFrame in BigQuery.""" 56 | skip_upload = False 57 | if obj is None: 58 | context.log.warning("Skipping BigQuery output as the output is None") 59 | skip_upload = True 60 | elif len(obj) == 0: 61 | context.log.warning("Skipping BigQuery output as the output DataFrame is empty") 62 | skip_upload = True 63 | 64 | if skip_upload: 65 | context.add_output_metadata({"missing": MetadataValue.bool(True)}) 66 | return 67 | 68 | assert obj is not None 69 | assert isinstance(connection, bigquery.Client) 70 | assert context.metadata is not None 71 | job_config = bigquery.LoadJobConfig(write_disposition=context.metadata.get("write_disposition")) 72 | 73 | # FIXME: load_table_from_dataframe writes the dataframe to a temporary parquet file 74 | # and then calls load_table_from_file. This can cause problems in cloud environments 75 | # therefore, it's better to use load_table_from_uri with GCS, 76 | # but this requires the remote filesystem to be available in this code 77 | job = connection.load_table_from_dataframe( 78 | dataframe=obj.to_pandas(), 79 | destination=f"{table_slice.schema}.{table_slice.table}", 80 | project=table_slice.database, 81 | location=context.resource_config.get("location") if context.resource_config else None, # type: ignore 82 | timeout=context.resource_config.get("timeout") if context.resource_config else None, # type: ignore 83 | job_config=job_config, 84 | ) 85 | job.result() 86 | 87 | context.add_output_metadata(get_polars_metadata(context=context, df=obj)) 88 | 89 | def load_input(self, context: InputContext, table_slice: TableSlice, connection) -> pl.DataFrame: 90 | """Loads the input as a Polars DataFrame.""" 91 | assert isinstance(connection, bigquery.Client) 92 | 93 | if table_slice.partition_dimensions and len(context.asset_partition_keys) == 0: 94 | return pl.DataFrame() 95 | result = connection.query( 96 | query=BigQueryClient.get_select_statement(table_slice), 97 | project=table_slice.database, 98 | location=context.resource_config.get("location") if context.resource_config else None, 99 | timeout=context.resource_config.get("timeout") if context.resource_config else None, 100 | ).to_arrow() 101 | 102 | return pl.DataFrame(result) 103 | 104 | @property 105 | def supported_types(self): 106 | return [pl.DataFrame] 107 | 108 | 109 | class PolarsBigQueryIOManager(BigQueryIOManager): 110 | """Implements reading and writing Polars DataFrames from/to `BigQuery `_). 111 | 112 | Features: 113 | - All :py:class:`~dagster.DBIOManager` features 114 | - Supports writing partitioned tables (`"partition_expr"` input metadata key must be specified). 115 | 116 | Returns: 117 | IOManagerDefinition 118 | 119 | Examples: 120 | .. code-block:: python 121 | 122 | from dagster import Definitions, EnvVar 123 | from dagster_polars import PolarsBigQueryIOManager 124 | 125 | @asset( 126 | key_prefix=["my_dataset"] # will be used as the dataset in BigQuery 127 | ) 128 | def my_table() -> pl.DataFrame: # the name of the asset will be the table name 129 | ... 
130 | 131 | defs = Definitions( 132 | assets=[my_table], 133 | resources={ 134 | "io_manager": PolarsBigQueryIOManager(project=EnvVar("GCP_PROJECT")) 135 | } 136 | ) 137 | 138 | You can tell Dagster in which dataset to create tables by setting the "dataset" configuration value. 139 | If you do not provide a dataset as configuration to the I/O manager, Dagster will determine a dataset based 140 | on the assets and ops using the I/O Manager. For assets, the dataset will be determined from the asset key, 141 | as shown in the above example. The final prefix before the asset name will be used as the dataset. For example, 142 | if the asset "my_table" had the key prefix ["gcp", "bigquery", "my_dataset"], the dataset "my_dataset" will be 143 | used. For ops, the dataset can be specified by including a "schema" entry in output metadata. If "schema" is 144 | not provided via config or on the asset/op, "public" will be used for the dataset. 145 | 146 | .. code-block:: python 147 | 148 | @op( 149 | out={"my_table": Out(metadata={"schema": "my_dataset"})} 150 | ) 151 | def make_my_table() -> pl.DataFrame: 152 | # the returned value will be stored at my_dataset.my_table 153 | ... 154 | 155 | To only use specific columns of a table as input to a downstream op or asset, add the metadata "columns" to the 156 | In or AssetIn. 157 | 158 | .. code-block:: python 159 | 160 | @asset( 161 | ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})} 162 | ) 163 | def my_table_a(my_table: pl.DataFrame) -> pd.DataFrame: 164 | # my_table will just contain the data from column "a" 165 | ... 166 | 167 | If you cannot upload a file to your Dagster deployment, or otherwise cannot 168 | `authenticate with GCP `_ 169 | via a standard method, you can provide a service account key as the "gcp_credentials" configuration. 170 | Dagster will store this key in a temporary file and set GOOGLE_APPLICATION_CREDENTIALS to point to the file. 171 | After the run completes, the file will be deleted, and GOOGLE_APPLICATION_CREDENTIALS will be 172 | unset. The key must be base64 encoded to avoid issues with newlines in the keys. You can retrieve 173 | the base64 encoded key with this shell command: cat $GOOGLE_APPLICATION_CREDENTIALS | base64 174 | 175 | The "write_disposition" metadata key can be used to set the `write_disposition` parameter 176 | of `bigquery.JobConfig`. For example, set it to `"WRITE_APPEND"` to append to an existing table intead of 177 | overwriting it. 178 | 179 | Install `dagster-polars[gcp]` to use this IOManager. 180 | 181 | """ 182 | 183 | @staticmethod 184 | def type_handlers() -> Sequence[DbTypeHandler]: 185 | return [PolarsBigQueryTypeHandler()] 186 | 187 | @staticmethod 188 | def default_load_type() -> Optional[Type]: 189 | return pl.DataFrame 190 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Elementl, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/parquet.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import TYPE_CHECKING, Any, Literal, Optional, Union, overload 3 | 4 | import polars as pl 5 | import pyarrow.dataset as ds 6 | import pyarrow.parquet as pq 7 | from dagster import InputContext, OutputContext 8 | from dagster._annotations import experimental 9 | from fsspec.implementations.local import LocalFileSystem 10 | from packaging.version import Version 11 | from pyarrow import Table 12 | 13 | from dagster_polars.io_managers.base import BasePolarsUPathIOManager 14 | from dagster_polars.types import LazyFrameWithMetadata, StorageMetadata 15 | 16 | if TYPE_CHECKING: 17 | from upath import UPath 18 | 19 | 20 | DAGSTER_POLARS_STORAGE_METADATA_KEY = "dagster_polars_metadata" 21 | 22 | 23 | def get_pyarrow_dataset(path: "UPath", context: InputContext) -> ds.Dataset: 24 | context_metadata = context.metadata or {} 25 | 26 | fs = path.fs if hasattr(path, "fs") else None 27 | 28 | if context_metadata.get("partitioning") is not None: 29 | context.log.warning( 30 | f'"partitioning" metadata value for PolarsParquetIOManager is deprecated ' 31 | f'in favor of "partition_by" (loading from {path})' 32 | ) 33 | 34 | dataset = ds.dataset( 35 | str(path), 36 | filesystem=fs, 37 | format=context_metadata.get("format", "parquet"), 38 | partitioning=context_metadata.get("partitioning") or context_metadata.get("partition_by"), 39 | partition_base_dir=context_metadata.get("partition_base_dir"), 40 | exclude_invalid_files=context_metadata.get("exclude_invalid_files", True), 41 | ignore_prefixes=context_metadata.get("ignore_prefixes", [".", "_"]), 42 | ) 43 | 44 | return dataset 45 | 46 | 47 | def scan_parquet(path: "UPath", context: InputContext) -> pl.LazyFrame: 48 | """Scan a parquet file and return a lazy frame (uses polars native reader). 
49 | 50 | :param path: 51 | :param context: 52 | :return: 53 | """ 54 | context_metadata = context.metadata or {} 55 | 56 | storage_options: Optional[dict[str, Any]] = path.storage_options if hasattr(path, "storage_options") else None 57 | 58 | kwargs = dict( 59 | n_rows=context_metadata.get("n_rows", None), 60 | cache=context_metadata.get("cache", True), 61 | parallel=context_metadata.get("parallel", "auto"), 62 | rechunk=context_metadata.get("rechunk", True), 63 | low_memory=context_metadata.get("low_memory", False), 64 | use_statistics=context_metadata.get("use_statistics", True), 65 | hive_partitioning=context_metadata.get("hive_partitioning", True), 66 | retries=context_metadata.get("retries", 0), 67 | ) 68 | if Version(pl.__version__) >= Version("0.20.4"): 69 | kwargs["row_index_name"] = context_metadata.get("row_index_name", None) 70 | kwargs["row_index_offset"] = context_metadata.get("row_index_offset", 0) 71 | else: 72 | kwargs["row_count_name"] = context_metadata.get("row_count_name", None) 73 | kwargs["row_count_offset"] = context_metadata.get("row_count_offset", 0) 74 | 75 | return pl.scan_parquet(str(path), storage_options=storage_options, **kwargs) # type: ignore 76 | 77 | 78 | @experimental 79 | class PolarsParquetIOManager(BasePolarsUPathIOManager): 80 | """Implements reading and writing Polars DataFrames in Apache Parquet format. 81 | 82 | Features: 83 | - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`. 84 | - All read/write options can be set via corresponding metadata or config parameters (metadata takes precedence). 85 | - Supports reading partitioned Parquet datasets (for example, often produced by Spark). 86 | - Supports reading/writing custom metadata in the Parquet file's schema as json-serialized bytes at `"dagster_polars_metadata"` key. 87 | 88 | Examples: 89 | 90 | .. code-block:: python 91 | 92 | from dagster import asset 93 | from dagster_polars import PolarsParquetIOManager 94 | import polars as pl 95 | 96 | @asset( 97 | io_manager_key="polars_parquet_io_manager", 98 | key_prefix=["my_dataset"] 99 | ) 100 | def my_asset() -> pl.DataFrame: # data will be stored at /my_dataset/my_asset.parquet 101 | ... 102 | 103 | defs = Definitions( 104 | assets=[my_table], 105 | resources={ 106 | "polars_parquet_io_manager": PolarsParquetIOManager(base_dir="s3://my-bucket/my-dir") 107 | } 108 | ) 109 | 110 | Reading partitioned Parquet datasets: 111 | 112 | .. code-block:: python 113 | 114 | from dagster import SourceAsset 115 | 116 | my_asset = SourceAsset( 117 | key=["path", "to", "dataset"], 118 | io_manager_key="polars_parquet_io_manager", 119 | metadata={ 120 | "partition_by": ["year", "month", "day"] 121 | } 122 | ) 123 | 124 | Storing custom metadata in the Parquet file schema (this metadata can be read outside of Dagster with a helper function :py:meth:`dagster_polars.PolarsParquetIOManager.read_parquet_metadata`): 125 | 126 | .. 
code-block:: python 127 | 128 | from dagster_polars import DataFrameWithMetadata 129 | 130 | 131 | @asset( 132 | io_manager_key="polars_parquet_io_manager", 133 | ) 134 | def upstream() -> DataFrameWithMetadata: 135 | return pl.DataFrame(...), {"my_custom_metadata": "my_custom_value"} 136 | 137 | 138 | @asset( 139 | io_manager_key="polars_parquet_io_manager", 140 | ) 141 | def downsteam(upstream: DataFrameWithMetadata): 142 | df, metadata = upstream 143 | assert metadata["my_custom_metadata"] == "my_custom_value" 144 | """ 145 | 146 | extension: str = ".parquet" # type: ignore 147 | 148 | def sink_df_to_path( 149 | self, 150 | context: OutputContext, 151 | df: pl.LazyFrame, 152 | path: "UPath", 153 | metadata: Optional[StorageMetadata] = None, 154 | ): 155 | context_metadata = context.metadata or {} 156 | 157 | if metadata is not None: 158 | context.log.warning("Sink not possible with StorageMetadata, instead it's dispatched to pyarrow writer.") 159 | return self.write_df_to_path(context, df.collect(), path, metadata) 160 | else: 161 | fs = path.fs if hasattr(path, "fs") else None 162 | if isinstance(fs, LocalFileSystem): 163 | compression = context_metadata.get("compression", "zstd") 164 | compression_level = context_metadata.get("compression_level") 165 | statistics = context_metadata.get("statistics", False) 166 | row_group_size = context_metadata.get("row_group_size") 167 | 168 | df.sink_parquet( 169 | str(path), 170 | compression=compression, 171 | compression_level=compression_level, 172 | statistics=statistics, 173 | row_group_size=row_group_size, 174 | ) 175 | else: 176 | # TODO(ion): add sink_parquet once this PR gets merged: https://github.com/pola-rs/polars/pull/11519 177 | context.log.warning( 178 | "Cloud sink is not possible yet, instead it's dispatched to pyarrow writer which collects it into memory first.", 179 | ) 180 | return self.write_df_to_path(context, df.collect(), path, metadata) 181 | 182 | def write_df_to_path( 183 | self, 184 | context: OutputContext, 185 | df: pl.DataFrame, 186 | path: "UPath", 187 | metadata: Optional[StorageMetadata] = None, 188 | ): 189 | context_metadata = context.metadata or {} 190 | compression = context_metadata.get("compression", "zstd") 191 | compression_level = context_metadata.get("compression_level") 192 | statistics = context_metadata.get("statistics", False) 193 | row_group_size = context_metadata.get("row_group_size") 194 | pyarrow_options = context_metadata.get("pyarrow_options", None) 195 | 196 | fs = path.fs if hasattr(path, "fs") else None 197 | 198 | if metadata is not None: 199 | table: Table = df.to_arrow() 200 | context.log.warning("StorageMetadata is passed, so the PyArrow writer is used.") 201 | existing_metadata = table.schema.metadata.to_dict() if table.schema.metadata is not None else {} 202 | existing_metadata.update({DAGSTER_POLARS_STORAGE_METADATA_KEY: json.dumps(metadata)}) 203 | table = table.replace_schema_metadata(existing_metadata) 204 | 205 | if pyarrow_options is not None and pyarrow_options.get("partition_cols"): 206 | pyarrow_options["compression"] = None if compression == "uncompressed" else compression 207 | pyarrow_options["compression_level"] = compression_level 208 | pyarrow_options["write_statistics"] = statistics 209 | pyarrow_options["row_group_size"] = row_group_size 210 | pq.write_to_dataset( 211 | table=table, 212 | root_path=str(path), 213 | fs=fs, 214 | **(pyarrow_options or {}), 215 | ) 216 | else: 217 | pq.write_table( 218 | table=table, 219 | where=str(path), 220 | 
row_group_size=row_group_size, 221 | compression=None if compression == "uncompressed" else compression, # type: ignore 222 | compression_level=compression_level, 223 | write_statistics=statistics, 224 | filesystem=fs, 225 | **(pyarrow_options or {}), 226 | ) 227 | else: 228 | if pyarrow_options is not None: 229 | pyarrow_options["filesystem"] = fs 230 | df.write_parquet( 231 | str(path), 232 | compression=compression, # type: ignore 233 | compression_level=compression_level, 234 | statistics=statistics, 235 | row_group_size=row_group_size, 236 | use_pyarrow=True, 237 | pyarrow_options=pyarrow_options, 238 | ) 239 | elif fs is not None: 240 | with fs.open(str(path), mode="wb") as f: 241 | df.write_parquet( 242 | f, # type: ignore 243 | compression=compression, # type: ignore 244 | compression_level=compression_level, 245 | statistics=statistics, 246 | row_group_size=row_group_size, 247 | ) 248 | else: 249 | df.write_parquet( 250 | str(path), 251 | compression=compression, # type: ignore 252 | compression_level=compression_level, 253 | statistics=statistics, 254 | row_group_size=row_group_size, 255 | ) 256 | 257 | @overload 258 | def scan_df_from_path( 259 | self, path: "UPath", context: InputContext, with_metadata: Literal[None, False] 260 | ) -> pl.LazyFrame: 261 | ... 262 | 263 | @overload 264 | def scan_df_from_path( 265 | self, path: "UPath", context: InputContext, with_metadata: Literal[True] 266 | ) -> LazyFrameWithMetadata: 267 | ... 268 | 269 | def scan_df_from_path( 270 | self, 271 | path: "UPath", 272 | context: InputContext, 273 | with_metadata: Optional[bool] = False, 274 | partition_key: Optional[str] = None, 275 | ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]: 276 | ldf = scan_parquet(path, context) 277 | 278 | if not with_metadata: 279 | return ldf 280 | else: 281 | ds = get_pyarrow_dataset(path, context) 282 | dagster_polars_metadata = ( 283 | ds.schema.metadata.get(DAGSTER_POLARS_STORAGE_METADATA_KEY.encode("utf-8")) 284 | if ds.schema.metadata is not None 285 | else None 286 | ) 287 | 288 | metadata = json.loads(dagster_polars_metadata) if dagster_polars_metadata is not None else {} 289 | 290 | return ldf, metadata 291 | 292 | @classmethod 293 | def read_parquet_metadata(cls, path: "UPath") -> StorageMetadata: 294 | """Just a helper method to read metadata from a parquet file. 295 | 296 | Is not used internally, but is helpful for reading Parquet metadata from outside of Dagster. 
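        Example (illustrative; the path below is hypothetical):

        .. code-block:: python

            from upath import UPath

            metadata = PolarsParquetIOManager.read_parquet_metadata(
                UPath("/my_dataset/my_asset.parquet")
            )
            # returns the dict stored under the "dagster_polars_metadata" schema key, or {} if absent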
297 | :param path: 298 | :return: 299 | """ 300 | metadata = pq.read_metadata(str(path), filesystem=path.fs if hasattr(path, "fs") else None).metadata 301 | 302 | dagster_polars_metadata = ( 303 | metadata.get(DAGSTER_POLARS_STORAGE_METADATA_KEY.encode("utf-8")) if metadata is not None else None 304 | ) 305 | 306 | return json.loads(dagster_polars_metadata) if dagster_polars_metadata is not None else {} 307 | -------------------------------------------------------------------------------- /tests/test_upath_io_managers_lazy.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Dict, Optional, Tuple 3 | 4 | import polars as pl 5 | import polars.testing as pl_testing 6 | import pytest 7 | from dagster import ( 8 | AssetExecutionContext, 9 | AssetIn, 10 | DailyPartitionsDefinition, 11 | DimensionPartitionMapping, 12 | IdentityPartitionMapping, 13 | MultiPartitionKey, 14 | MultiPartitionMapping, 15 | MultiPartitionsDefinition, 16 | OpExecutionContext, 17 | StaticPartitionsDefinition, 18 | TimeWindowPartitionMapping, 19 | asset, 20 | materialize, 21 | ) 22 | 23 | from dagster_polars import ( 24 | BasePolarsUPathIOManager, 25 | LazyFramePartitions, 26 | PolarsDeltaIOManager, 27 | PolarsParquetIOManager, 28 | StorageMetadata, 29 | ) 30 | from tests.utils import get_saved_path 31 | 32 | 33 | def test_polars_upath_io_manager_stats_metadata( 34 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 35 | ): 36 | manager, _ = io_manager_and_lazy_df 37 | 38 | df = pl.LazyFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) 39 | 40 | @asset(io_manager_def=manager) 41 | def upstream() -> pl.LazyFrame: 42 | return df 43 | 44 | result = materialize( 45 | [upstream], 46 | ) 47 | 48 | handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) 49 | 50 | stats = handled_output_events[0].event_specific_data.metadata.get("stats") # type: ignore 51 | 52 | # TODO(ion): think about how we can store lazyframe stats without doing costly computations (likely not ever possible) 53 | assert stats is None 54 | 55 | 56 | def test_polars_upath_io_manager_type_annotations( 57 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 58 | ): 59 | manager, df = io_manager_and_lazy_df 60 | 61 | @asset(io_manager_def=manager) 62 | def upstream() -> pl.LazyFrame: 63 | return df 64 | 65 | @asset(io_manager_def=manager) 66 | def downstream_lazy(upstream: pl.LazyFrame) -> None: 67 | assert isinstance(upstream, pl.LazyFrame), type(upstream) 68 | 69 | partitions_def = StaticPartitionsDefinition(["a", "b"]) 70 | 71 | @asset(io_manager_def=manager, partitions_def=partitions_def) 72 | def upstream_partitioned(context: OpExecutionContext) -> pl.LazyFrame: 73 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 74 | 75 | @asset(io_manager_def=manager) 76 | def downstream_multi_partitioned_lazy(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: 77 | for _df in upstream_partitioned.values(): 78 | assert isinstance(_df, pl.LazyFrame), type(_df) 79 | assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() 80 | 81 | for partition_key in ["a", "b"]: 82 | materialize( 83 | [upstream_partitioned], 84 | partition_key=partition_key, 85 | ) 86 | 87 | materialize( 88 | [ 89 | upstream_partitioned.to_source_asset(), 90 | upstream, 91 | downstream_lazy, 92 | downstream_multi_partitioned_lazy, 93 | ], 94 | ) 95 | 96 | 97 | def 
test_polars_upath_io_manager_nested_dtypes( 98 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 99 | ): 100 | manager, df = io_manager_and_lazy_df 101 | 102 | @asset(io_manager_def=manager) 103 | def upstream() -> pl.LazyFrame: 104 | return df 105 | 106 | @asset(io_manager_def=manager) 107 | def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: 108 | return upstream.collect(streaming=True) 109 | 110 | result = materialize( 111 | [upstream, downstream], 112 | ) 113 | 114 | saved_path = get_saved_path(result, "upstream") 115 | 116 | if isinstance(manager, PolarsParquetIOManager): 117 | pl_testing.assert_frame_equal(df.collect(), pl.read_parquet(saved_path)) 118 | elif isinstance(manager, PolarsDeltaIOManager): 119 | pl_testing.assert_frame_equal(df.collect(), pl.read_delta(saved_path)) 120 | else: 121 | raise ValueError(f"Test not implemented for {type(manager)}") 122 | 123 | 124 | def test_polars_upath_io_manager_input_optional_lazy( 125 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 126 | ): 127 | manager, df = io_manager_and_lazy_df 128 | 129 | @asset(io_manager_def=manager) 130 | def upstream() -> pl.LazyFrame: 131 | return df 132 | 133 | @asset(io_manager_def=manager) 134 | def downstream(upstream: Optional[pl.LazyFrame]) -> pl.DataFrame: 135 | assert upstream is not None 136 | return upstream.collect() 137 | 138 | materialize( 139 | [upstream, downstream], 140 | ) 141 | 142 | 143 | def test_polars_upath_io_manager_input_optional_lazy_e2e( 144 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 145 | ): 146 | manager, df = io_manager_and_lazy_df 147 | 148 | @asset(io_manager_def=manager) 149 | def upstream() -> pl.LazyFrame: 150 | return df 151 | 152 | @asset(io_manager_def=manager) 153 | def downstream(upstream: Optional[pl.LazyFrame]) -> pl.LazyFrame: 154 | assert upstream is not None 155 | return upstream 156 | 157 | materialize( 158 | [upstream, downstream], 159 | ) 160 | 161 | 162 | def test_polars_upath_io_manager_input_dict_lazy( 163 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 164 | ): 165 | manager, df = io_manager_and_lazy_df 166 | 167 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 168 | def upstream(context: AssetExecutionContext) -> pl.LazyFrame: 169 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 170 | 171 | @asset(io_manager_def=manager) 172 | def downstream(upstream: Dict[str, pl.LazyFrame]) -> pl.LazyFrame: 173 | dfs = [] 174 | for df in upstream.values(): 175 | assert isinstance(df, pl.LazyFrame) 176 | dfs.append(df) 177 | return pl.concat(dfs) 178 | 179 | for partition_key in ["a", "b"]: 180 | materialize( 181 | [upstream], 182 | partition_key=partition_key, 183 | ) 184 | 185 | materialize( 186 | [upstream.to_source_asset(), downstream], 187 | ) 188 | 189 | 190 | def test_polars_upath_io_manager_input_lazy_frame_partitions_lazy( 191 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 192 | ): 193 | manager, df = io_manager_and_lazy_df 194 | 195 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 196 | def upstream(context: AssetExecutionContext) -> pl.LazyFrame: 197 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 198 | 199 | @asset(io_manager_def=manager) 200 | def downstream(upstream: LazyFramePartitions) -> pl.LazyFrame: 201 | dfs = [] 202 | for df in upstream.values(): 203 | assert isinstance(df, pl.LazyFrame) 204 | 
dfs.append(df) 205 | return pl.concat(dfs) 206 | 207 | for partition_key in ["a", "b"]: 208 | materialize( 209 | [upstream], 210 | partition_key=partition_key, 211 | ) 212 | 213 | materialize( 214 | [upstream.to_source_asset(), downstream], 215 | ) 216 | 217 | 218 | def test_polars_upath_io_manager_input_optional_lazy_return_none( 219 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 220 | ): 221 | manager, df = io_manager_and_lazy_df 222 | 223 | @asset(io_manager_def=manager) 224 | def upstream() -> pl.LazyFrame: 225 | return df 226 | 227 | @asset 228 | def downstream(upstream: Optional[pl.LazyFrame]): 229 | assert upstream is None 230 | 231 | materialize( 232 | [upstream.to_source_asset(), downstream], 233 | ) 234 | 235 | 236 | def test_polars_upath_io_manager_output_optional_lazy( 237 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 238 | ): 239 | manager, df = io_manager_and_lazy_df 240 | 241 | @asset(io_manager_def=manager) 242 | def upstream() -> Optional[pl.LazyFrame]: 243 | return None 244 | 245 | @asset(io_manager_def=manager) 246 | def downstream(upstream: Optional[pl.LazyFrame]) -> Optional[pl.LazyFrame]: 247 | assert upstream is None 248 | return upstream 249 | 250 | materialize( 251 | [upstream, downstream], 252 | ) 253 | 254 | 255 | IO_MANAGERS_SUPPORTING_STORAGE_METADATA = ( 256 | PolarsParquetIOManager, 257 | PolarsDeltaIOManager, 258 | ) 259 | 260 | 261 | def check_skip_storage_metadata_test(io_manager_def: BasePolarsUPathIOManager): 262 | if not isinstance(io_manager_def, IO_MANAGERS_SUPPORTING_STORAGE_METADATA): 263 | pytest.skip(f"Only {IO_MANAGERS_SUPPORTING_STORAGE_METADATA} support storage metadata") 264 | 265 | 266 | @pytest.fixture 267 | def metadata() -> StorageMetadata: 268 | return {"a": 1, "b": "2", "c": [1, 2, 3], "d": {"e": 1}, "f": [1, 2, 3, {"g": 1}]} 269 | 270 | 271 | def test_upath_io_manager_storage_metadata_lazy( 272 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], metadata: StorageMetadata 273 | ): 274 | io_manager_def, df = io_manager_and_lazy_df 275 | check_skip_storage_metadata_test(io_manager_def) 276 | 277 | @asset(io_manager_def=io_manager_def) 278 | def upstream() -> Tuple[pl.LazyFrame, StorageMetadata]: 279 | return df, metadata 280 | 281 | @asset(io_manager_def=io_manager_def) 282 | def downstream(upstream: Tuple[pl.LazyFrame, StorageMetadata]) -> None: 283 | loaded_df, upstream_metadata = upstream 284 | assert upstream_metadata == metadata 285 | pl_testing.assert_frame_equal(loaded_df.collect(), df.collect()) 286 | 287 | materialize( 288 | [upstream, downstream], 289 | ) 290 | 291 | 292 | def test_upath_io_manager_storage_metadata_optional_lazy_exists( 293 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], metadata: StorageMetadata 294 | ): 295 | io_manager_def, df = io_manager_and_lazy_df 296 | check_skip_storage_metadata_test(io_manager_def) 297 | 298 | @asset(io_manager_def=io_manager_def) 299 | def upstream() -> Optional[Tuple[pl.LazyFrame, StorageMetadata]]: 300 | return df, metadata 301 | 302 | @asset(io_manager_def=io_manager_def) 303 | def downstream(upstream: Optional[Tuple[pl.LazyFrame, StorageMetadata]]) -> None: 304 | assert upstream is not None 305 | df, upstream_metadata = upstream 306 | assert upstream_metadata == metadata 307 | 308 | materialize( 309 | [upstream, downstream], 310 | ) 311 | 312 | 313 | def test_upath_io_manager_storage_metadata_optional_lazy_missing( 314 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, 
pl.LazyFrame], metadata: StorageMetadata 315 | ): 316 | io_manager_def, df = io_manager_and_lazy_df 317 | check_skip_storage_metadata_test(io_manager_def) 318 | 319 | @asset(io_manager_def=io_manager_def) 320 | def upstream() -> Optional[Tuple[pl.LazyFrame, StorageMetadata]]: 321 | return None 322 | 323 | @asset(io_manager_def=io_manager_def) 324 | def downstream(upstream: Optional[Tuple[pl.LazyFrame, StorageMetadata]]) -> None: 325 | assert upstream is None 326 | 327 | materialize( 328 | [upstream, downstream], 329 | ) 330 | 331 | 332 | def test_upath_io_manager_multi_partitions_definition_load_multiple_partitions( 333 | io_manager_and_lazy_df: Tuple[BasePolarsUPathIOManager, pl.LazyFrame], 334 | ): 335 | io_manager_def, df = io_manager_and_lazy_df 336 | 337 | today = datetime.now().date() 338 | 339 | partitions_def = MultiPartitionsDefinition( 340 | { 341 | "time": DailyPartitionsDefinition(start_date=str(today - timedelta(days=3))), 342 | "static": StaticPartitionsDefinition(["a"]), 343 | } 344 | ) 345 | 346 | @asset(partitions_def=partitions_def, io_manager_def=io_manager_def) 347 | def upstream(context: AssetExecutionContext) -> pl.LazyFrame: 348 | return pl.LazyFrame({"partition": [str(context.partition_key)]}) 349 | 350 | # this asset will request 2 upstream partitions 351 | @asset( 352 | io_manager_def=io_manager_def, 353 | partitions_def=partitions_def, 354 | ins={ 355 | "upstream": AssetIn( 356 | partition_mapping=MultiPartitionMapping( 357 | { 358 | "time": DimensionPartitionMapping("time", TimeWindowPartitionMapping(start_offset=-1)), 359 | "static": DimensionPartitionMapping("static", IdentityPartitionMapping()), 360 | } 361 | ) 362 | ) 363 | }, 364 | ) 365 | def downstream(context: AssetExecutionContext, upstream: LazyFramePartitions) -> None: 366 | assert len(upstream.values()) == 2 367 | 368 | materialize( 369 | [upstream], 370 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=3)), "static": "a"}), 371 | ) 372 | materialize( 373 | [upstream], 374 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=2)), "static": "a"}), 375 | ) 376 | # materialize([upstream], partition_key=MultiPartitionKey({"time": str(today - timedelta(days=1)), "static": "a"})) 377 | 378 | materialize( 379 | [upstream.to_source_asset(), downstream], 380 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=2)), "static": "a"}), 381 | ) 382 | -------------------------------------------------------------------------------- /tests/test_polars_delta.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | from typing import Dict 4 | 5 | import polars as pl 6 | import polars.testing as pl_testing 7 | import pytest 8 | from dagster import ( 9 | AssetExecutionContext, 10 | AssetIn, 11 | Config, 12 | DagsterInstance, 13 | OpExecutionContext, 14 | RunConfig, 15 | StaticPartitionsDefinition, 16 | asset, 17 | materialize, 18 | ) 19 | from deltalake import DeltaTable 20 | from hypothesis import given, settings 21 | from polars.testing.parametric import dataframes 22 | 23 | from dagster_polars import PolarsDeltaIOManager 24 | from dagster_polars.io_managers.delta import DeltaWriteMode 25 | from tests.utils import get_saved_path 26 | 27 | # TODO: remove pl.Time once it's supported 28 | # TODO: remove pl.Duration pl.Duration once it's supported 29 | # https://github.com/pola-rs/polars/issues/9631 30 | # TODO: remove UInt types once they are supported 31 | # https://github.com/pola-rs/polars/issues/9627 
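# Illustrative only (not used by the tests below): until the dtypes listed above are
# supported, one possible workaround is to cast the affected columns before handing a
# frame to PolarsDeltaIOManager. The helper below is a hypothetical sketch that relies
# on the module-level `import polars as pl`; note that UInt64 values above the Int64
# maximum would still overflow.
def _cast_unsigned_to_signed(df: pl.DataFrame) -> pl.DataFrame:
    # Cast unsigned integer columns to Int64 so deltalake can store them.
    unsigned_dtypes = (pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64)
    return df.with_columns(
        [pl.col(name).cast(pl.Int64) for name, dtype in df.schema.items() if dtype in unsigned_dtypes]
    )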
32 | 33 | 34 | @pytest.mark.flaky(reruns=5) 35 | @given( 36 | df=dataframes( 37 | excluded_dtypes=[ 38 | pl.Categorical, 39 | pl.Duration, 40 | pl.Time, 41 | pl.UInt8, 42 | pl.UInt16, 43 | pl.UInt32, 44 | pl.UInt64, 45 | pl.Datetime("ns", None), 46 | ], 47 | min_size=5, 48 | allow_infinities=False, 49 | ) 50 | ) 51 | @settings(max_examples=50, deadline=None) 52 | def test_polars_delta_io_manager(session_polars_delta_io_manager: PolarsDeltaIOManager, df: pl.DataFrame): 53 | time.sleep(0.2) # too frequent writes mess up DeltaLake concurrent 54 | 55 | @asset(io_manager_def=session_polars_delta_io_manager, metadata={"overwrite_schema": True}) 56 | def upstream() -> pl.DataFrame: 57 | return df 58 | 59 | @asset(io_manager_def=session_polars_delta_io_manager, metadata={"overwrite_schema": True}) 60 | def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: 61 | return upstream.collect(streaming=True) 62 | 63 | result = materialize( 64 | [upstream, downstream], 65 | ) 66 | 67 | handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) 68 | 69 | saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore[index,union-attr] 70 | assert isinstance(saved_path, str) 71 | pl_testing.assert_frame_equal(df, pl.read_delta(saved_path)) 72 | shutil.rmtree(saved_path) # cleanup manually because of hypothesis 73 | 74 | 75 | def test_polars_delta_io_manager_append(polars_delta_io_manager: PolarsDeltaIOManager): 76 | df = pl.DataFrame( 77 | { 78 | "a": [1, 2, 3], 79 | } 80 | ) 81 | 82 | @asset(io_manager_def=polars_delta_io_manager, metadata={"mode": "append"}) 83 | def append_asset() -> pl.DataFrame: 84 | return df 85 | 86 | result = materialize( 87 | [append_asset], 88 | ) 89 | 90 | handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("append_asset"))) 91 | saved_path = handled_output_events[0].event_specific_data.metadata["path"].value # type: ignore 92 | assert handled_output_events[0].event_specific_data.metadata["row_count"].value == 3 # type: ignore 93 | assert handled_output_events[0].event_specific_data.metadata["append_row_count"].value == 3 # type: ignore 94 | assert isinstance(saved_path, str) 95 | 96 | result = materialize( 97 | [append_asset], 98 | ) 99 | handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("append_asset"))) 100 | assert handled_output_events[0].event_specific_data.metadata["row_count"].value == 6 # type: ignore 101 | assert handled_output_events[0].event_specific_data.metadata["append_row_count"].value == 3 # type: ignore 102 | 103 | pl_testing.assert_frame_equal(pl.concat([df, df]), pl.read_delta(saved_path)) 104 | 105 | 106 | def test_polars_delta_io_manager_overwrite_schema( 107 | polars_delta_io_manager: PolarsDeltaIOManager, dagster_instance: DagsterInstance 108 | ): 109 | @asset(io_manager_def=polars_delta_io_manager) 110 | def overwrite_schema_asset_1() -> pl.DataFrame: 111 | return pl.DataFrame( 112 | { 113 | "a": [1, 2, 3], 114 | } 115 | ) 116 | 117 | result = materialize( 118 | [overwrite_schema_asset_1], 119 | ) 120 | 121 | saved_path = get_saved_path(result, "overwrite_schema_asset_1") 122 | 123 | pl_testing.assert_frame_equal( 124 | pl.DataFrame( 125 | { 126 | "a": [1, 2, 3], 127 | } 128 | ), 129 | pl.read_delta(saved_path), 130 | ) 131 | 132 | @asset( 133 | io_manager_def=polars_delta_io_manager, 134 | metadata={"overwrite_schema": True, "mode": "overwrite"}, 135 | ) 136 | def overwrite_schema_asset_2() 
-> pl.DataFrame: 137 | return pl.DataFrame( 138 | { 139 | "b": ["1", "2", "3"], 140 | } 141 | ) 142 | 143 | result = materialize( 144 | [overwrite_schema_asset_2], 145 | ) 146 | 147 | saved_path = get_saved_path(result, "overwrite_schema_asset_2") 148 | 149 | pl_testing.assert_frame_equal( 150 | pl.DataFrame( 151 | { 152 | "b": ["1", "2", "3"], 153 | } 154 | ), 155 | pl.read_delta(saved_path), 156 | ) 157 | 158 | # test IOManager configuration works too 159 | @asset( 160 | io_manager_def=PolarsDeltaIOManager( 161 | base_dir=dagster_instance.storage_directory(), 162 | mode=DeltaWriteMode.overwrite, 163 | overwrite_schema=True, 164 | ) 165 | ) 166 | def overwrite_schema_asset_3() -> pl.DataFrame: 167 | return pl.DataFrame( 168 | { 169 | "a": [1, 2, 3], 170 | } 171 | ) 172 | 173 | result = materialize( 174 | [overwrite_schema_asset_3], 175 | ) 176 | 177 | saved_path = get_saved_path(result, "overwrite_schema_asset_3") 178 | 179 | pl_testing.assert_frame_equal( 180 | pl.DataFrame( 181 | { 182 | "a": [1, 2, 3], 183 | } 184 | ), 185 | pl.read_delta(saved_path), 186 | ) 187 | 188 | 189 | def test_polars_delta_io_manager_overwrite_schema_lazy( 190 | polars_delta_io_manager: PolarsDeltaIOManager, dagster_instance: DagsterInstance 191 | ): 192 | @asset(io_manager_def=polars_delta_io_manager) 193 | def overwrite_schema_asset_1() -> pl.LazyFrame: 194 | return pl.LazyFrame( 195 | { 196 | "a": [1, 2, 3], 197 | } 198 | ) 199 | 200 | result = materialize( 201 | [overwrite_schema_asset_1], 202 | ) 203 | 204 | saved_path = get_saved_path(result, "overwrite_schema_asset_1") 205 | 206 | pl_testing.assert_frame_equal( 207 | pl.DataFrame( 208 | { 209 | "a": [1, 2, 3], 210 | } 211 | ), 212 | pl.read_delta(saved_path), 213 | ) 214 | 215 | @asset( 216 | io_manager_def=polars_delta_io_manager, 217 | metadata={"overwrite_schema": True, "mode": "overwrite"}, 218 | ) 219 | def overwrite_schema_asset_2() -> pl.LazyFrame: 220 | return pl.LazyFrame( 221 | { 222 | "b": ["1", "2", "3"], 223 | } 224 | ) 225 | 226 | result = materialize( 227 | [overwrite_schema_asset_2], 228 | ) 229 | 230 | saved_path = get_saved_path(result, "overwrite_schema_asset_2") 231 | 232 | pl_testing.assert_frame_equal( 233 | pl.DataFrame( 234 | { 235 | "b": ["1", "2", "3"], 236 | } 237 | ), 238 | pl.read_delta(saved_path), 239 | ) 240 | 241 | # test IOManager configuration works too 242 | @asset( 243 | io_manager_def=PolarsDeltaIOManager( 244 | base_dir=dagster_instance.storage_directory(), 245 | mode=DeltaWriteMode.overwrite, 246 | overwrite_schema=True, 247 | ) 248 | ) 249 | def overwrite_schema_asset_3() -> pl.LazyFrame: 250 | return pl.LazyFrame( 251 | { 252 | "a": [1, 2, 3], 253 | } 254 | ) 255 | 256 | result = materialize( 257 | [overwrite_schema_asset_3], 258 | ) 259 | 260 | saved_path = get_saved_path(result, "overwrite_schema_asset_3") 261 | 262 | pl_testing.assert_frame_equal( 263 | pl.DataFrame( 264 | { 265 | "a": [1, 2, 3], 266 | } 267 | ), 268 | pl.read_delta(saved_path), 269 | ) 270 | 271 | 272 | def test_polars_delta_native_partitioning(polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame): 273 | manager = polars_delta_io_manager 274 | df = df_for_delta 275 | 276 | partitions_def = StaticPartitionsDefinition(["a", "b"]) 277 | 278 | @asset( 279 | io_manager_def=manager, 280 | partitions_def=partitions_def, 281 | metadata={"partition_by": "partition"}, 282 | ) 283 | def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: 284 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 
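    # Note (descriptive, not executed): because "partition_by" is set in the asset metadata,
    # every Dagster partition of `upstream_partitioned` is written into the same Delta table
    # at ".../upstream_partitioned.delta", and DeltaLake itself partitions the data by the
    # "partition" column. Outside of Dagster the table could be read back with, for example:
    #     pl.scan_delta(saved_path).filter(pl.col("partition") == "a").collect()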
285 | 286 | lenghts = {} 287 | 288 | @asset(io_manager_def=manager) 289 | def downstream_load_multiple_partitions(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: 290 | for partition, _ldf in upstream_partitioned.items(): 291 | assert isinstance(_ldf, pl.LazyFrame), type(_ldf) 292 | _df = _ldf.collect() 293 | assert (_df.select(pl.col("partition").eq(partition).alias("eq")))["eq"].all() 294 | lenghts[partition] = len(_df) 295 | 296 | assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() 297 | 298 | saved_path = None # noqa 299 | 300 | for partition_key in ["a", "b"]: 301 | result = materialize( 302 | [upstream_partitioned], 303 | partition_key=partition_key, 304 | ) 305 | saved_path = get_saved_path(result, "upstream_partitioned") 306 | assert saved_path.endswith("upstream_partitioned.delta"), saved_path # DeltaLake should handle partitioning! 307 | assert DeltaTable(saved_path).metadata().partition_columns == ["partition"] 308 | 309 | assert saved_path is not None 310 | written_df = pl.read_delta(saved_path) 311 | 312 | assert len(written_df) == len(df) * 2 313 | assert set(written_df["partition"].unique()) == {"a", "b"} 314 | 315 | materialize( 316 | [ 317 | upstream_partitioned.to_source_asset(), 318 | downstream_load_multiple_partitions, 319 | ], 320 | ) 321 | 322 | @asset(io_manager_def=manager) 323 | def downstream_load_multiple_partitions_as_single_df(upstream_partitioned: pl.DataFrame) -> None: 324 | assert set(upstream_partitioned["partition"].unique()) == {"a", "b"} 325 | 326 | materialize( 327 | [ 328 | upstream_partitioned.to_source_asset(), 329 | downstream_load_multiple_partitions_as_single_df, 330 | ], 331 | ) 332 | 333 | 334 | def test_polars_delta_native_partitioning_loading_single_partition( 335 | polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame 336 | ): 337 | manager = polars_delta_io_manager 338 | df = df_for_delta 339 | 340 | partitions_def = StaticPartitionsDefinition(["a", "b"]) 341 | 342 | @asset( 343 | io_manager_def=manager, 344 | partitions_def=partitions_def, 345 | metadata={"partition_by": "partition"}, 346 | ) 347 | def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: 348 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 349 | 350 | @asset(io_manager_def=manager, partitions_def=partitions_def) 351 | def downstream_partitioned(context: AssetExecutionContext, upstream_partitioned: pl.DataFrame) -> None: 352 | partitions = upstream_partitioned["partition"].unique().to_list() 353 | assert len(partitions) == 1 354 | assert partitions[0] == context.partition_key 355 | 356 | for partition_key in ["a", "b"]: 357 | materialize( 358 | [upstream_partitioned, downstream_partitioned], 359 | partition_key=partition_key, 360 | ) 361 | 362 | 363 | def test_polars_delta_time_travel(polars_delta_io_manager: PolarsDeltaIOManager, df_for_delta: pl.DataFrame): 364 | manager = polars_delta_io_manager 365 | df = df_for_delta 366 | 367 | class UpstreamConfig(Config): 368 | foo: str 369 | 370 | @asset(io_manager_def=manager) 371 | def upstream(context: OpExecutionContext, config: UpstreamConfig) -> pl.DataFrame: 372 | return df.with_columns(pl.lit(config.foo).alias("foo")) 373 | 374 | for foo in ["a", "b"]: 375 | materialize([upstream], run_config=RunConfig(ops={"upstream": UpstreamConfig(foo=foo)})) 376 | 377 | # get_saved_path(result, "upstream") 378 | 379 | @asset(ins={"upstream": AssetIn(metadata={"version": 0})}) 380 | def downstream_0(upstream: pl.DataFrame) -> None: 381 | 
assert upstream["foo"].head(1).item() == "a" 382 | 383 | materialize( 384 | [ 385 | upstream.to_source_asset(), 386 | downstream_0, 387 | ] 388 | ) 389 | 390 | @asset(ins={"upstream": AssetIn(metadata={"version": "1"})}) 391 | def downstream_1(upstream: pl.DataFrame) -> None: 392 | assert upstream["foo"].head(1).item() == "b" 393 | 394 | materialize( 395 | [ 396 | upstream.to_source_asset(), 397 | downstream_1, 398 | ] 399 | ) 400 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/delta.py: -------------------------------------------------------------------------------- 1 | import json 2 | from enum import Enum 3 | from pprint import pformat 4 | from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, overload 5 | 6 | import dagster._check as check 7 | import polars as pl 8 | from dagster import InputContext, MetadataValue, OutputContext 9 | from dagster._annotations import experimental 10 | from dagster._core.storage.upath_io_manager import is_dict_type 11 | 12 | from dagster_polars.io_managers.base import BasePolarsUPathIOManager 13 | from dagster_polars.types import DataFrameWithMetadata, LazyFrameWithMetadata, StorageMetadata 14 | 15 | try: 16 | from deltalake import DeltaTable 17 | from deltalake.exceptions import TableNotFoundError 18 | except ImportError as e: 19 | raise ImportError("Install 'dagster-polars[deltalake]' to use DeltaLake functionality") from e 20 | 21 | if TYPE_CHECKING: 22 | from upath import UPath 23 | 24 | 25 | DAGSTER_POLARS_STORAGE_METADATA_SUBDIR = ".dagster_polars_metadata" 26 | 27 | SINGLE_LOADING_TYPES = (pl.DataFrame, pl.LazyFrame, LazyFrameWithMetadata, DataFrameWithMetadata) 28 | 29 | 30 | class DeltaWriteMode(str, Enum): 31 | error = "error" 32 | append = "append" 33 | overwrite = "overwrite" 34 | ignore = "ignore" 35 | 36 | 37 | @experimental 38 | class PolarsDeltaIOManager(BasePolarsUPathIOManager): 39 | """Implements writing and reading DeltaLake tables. 40 | 41 | Features: 42 | - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`. 43 | - All read/write options can be set via corresponding metadata or config parameters (metadata takes precedence). 44 | - Supports native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table. 45 | To enable this behavior, set the `partition_by` metadata value or config parameter (it's passed to `delta_write_options` of `pl.DataFrame.write_delta`). 46 | Automatically filters loaded partitions, unless `MultiPartitionsDefinition` are used. 47 | In this case you are responsible for filtering the partitions in the downstream asset, as it's non-trivial to do so in the IOManager. 48 | When loading all available asset partitions, the whole table can be loaded in one go by using type annotations like `pl.DataFrame` and `pl.LazyFrame`. 49 | - Supports writing/reading custom metadata to/from `.dagster_polars_metadata/.json` file in the DeltaLake table directory. 50 | 51 | Install `dagster-polars[delta]` to use this IOManager. 52 | 53 | Examples: 54 | 55 | .. code-block:: python 56 | 57 | from dagster import asset 58 | from dagster_polars import PolarsDeltaIOManager 59 | import polars as pl 60 | 61 | @asset( 62 | io_manager_key="polars_delta_io_manager", 63 | key_prefix=["my_dataset"] 64 | ) 65 | def my_asset() -> pl.DataFrame: # data will be stored at /my_dataset/my_asset.delta 66 | ... 
67 | 68 | defs = Definitions( 69 | assets=[my_table], 70 | resources={ 71 | "polars_parquet_io_manager": PolarsDeltaIOManager(base_dir="s3://my-bucket/my-dir") 72 | } 73 | ) 74 | 75 | 76 | Appending to a DeltaLake table: 77 | 78 | .. code-block:: python 79 | 80 | @asset( 81 | io_manager_key="polars_delta_io_manager", 82 | metadata={ 83 | "mode": "append" 84 | }, 85 | ) 86 | def my_table() -> pl.DataFrame: 87 | ... 88 | 89 | Using native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table: 90 | 91 | .. code-block:: python 92 | 93 | from dagster import AssetExecutionContext, DailyPartitionedDefinition 94 | from dagster_polars import LazyFramePartitions 95 | 96 | @asset( 97 | io_manager_key="polars_delta_io_manager", 98 | metadata={ 99 | "partition_by": "partition_col" 100 | }, 101 | partitions_def=... 102 | ) 103 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 104 | df = ... 105 | 106 | # add partition to the DataFrame 107 | df = df.with_columns(pl.lit(context.partition_key).alias("partition_col")) 108 | return df 109 | 110 | @asset 111 | def downstream(upstream: LazyFramePartitions) -> pl.DataFrame: 112 | # concat LazyFrames, filter required partitions and call .collect() 113 | ... 114 | """ 115 | 116 | extension: str = ".delta" # type: ignore 117 | mode: DeltaWriteMode = DeltaWriteMode.overwrite.value # type: ignore 118 | overwrite_schema: bool = False 119 | version: Optional[int] = None 120 | 121 | # tmp fix until UPathIOManager supports this: added special handling for loading all partitions of an asset 122 | 123 | def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]: 124 | # If no asset key, we are dealing with an op output which is always non-partitioned 125 | if not context.has_asset_key or not context.has_asset_partitions: 126 | path = self._get_path(context) 127 | return self._load_single_input(path, context) 128 | else: 129 | asset_partition_keys = context.asset_partition_keys 130 | if len(asset_partition_keys) == 0: 131 | return None 132 | elif len(asset_partition_keys) == 1: 133 | paths = self._get_paths_for_partitions(context) 134 | check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}") 135 | path = next(iter(paths.values())) 136 | backcompat_paths = self._get_multipartition_backcompat_paths(context) 137 | backcompat_path = None if not backcompat_paths else next(iter(backcompat_paths.values())) 138 | 139 | return self._load_partition_from_path( 140 | context=context, 141 | partition_key=asset_partition_keys[0], 142 | path=path, 143 | backcompat_path=backcompat_path, 144 | ) 145 | else: # we are dealing with multiple partitions of an asset 146 | type_annotation = context.dagster_type.typing_type 147 | if type_annotation == Any or is_dict_type(type_annotation): 148 | return self._load_multiple_inputs(context) 149 | 150 | # special case of loading the whole DeltaLake table at once 151 | # when using AllPartitionMappings and native DeltaLake partitioning 152 | elif ( 153 | context.upstream_output is not None 154 | and context.upstream_output.metadata is not None 155 | and context.upstream_output.metadata.get("partition_by") is not None 156 | and type_annotation in SINGLE_LOADING_TYPES 157 | and context.upstream_output.asset_info is not None 158 | and context.upstream_output.asset_info.partitions_def is not None 159 | and set(asset_partition_keys) 160 | == set( 161 | context.upstream_output.asset_info.partitions_def.get_partition_keys( 162 | dynamic_partitions_store=context.instance 163 | ) 
164 | ) 165 | ): 166 | # load all partitions at once 167 | return self.load_from_path( 168 | context=context, 169 | path=self.get_path_for_partition( 170 | context=context, 171 | partition=asset_partition_keys[0], # any partition key works: they all resolve to the same table path 172 | path=self._get_paths_for_partitions(context)[asset_partition_keys[0]], # any partition key works here as well 173 | ), 174 | partition_key=None, 175 | ) 176 | else: 177 | check.failed( 178 | "Loading an input that corresponds to multiple partitions, but the" 179 | f" type annotation on the op input is not a dict, Dict, Mapping, one of {SINGLE_LOADING_TYPES}," 180 | f" or Any: is '{type_annotation}'." 181 | ) 182 | 183 | def sink_df_to_path( 184 | self, 185 | context: OutputContext, 186 | df: pl.LazyFrame, 187 | path: "UPath", 188 | metadata: Optional[StorageMetadata] = None, 189 | ): 190 | context_metadata = context.metadata or {} 191 | streaming = context_metadata.get("streaming", False) 192 | return self.write_df_to_path(context, df.collect(streaming=streaming), path, metadata) 193 | 194 | def write_df_to_path( 195 | self, 196 | context: OutputContext, 197 | df: pl.DataFrame, 198 | path: "UPath", 199 | metadata: Optional[StorageMetadata] = None, # why is metadata passed 200 | ): 201 | context_metadata = context.metadata or {} 202 | delta_write_options = context_metadata.get( 203 | "delta_write_options" 204 | ) # TODO: remove this indirection in favor of plain key/value options in the metadata 205 | 206 | if context.has_asset_partitions: 207 | delta_write_options = delta_write_options or {} 208 | partition_by = context_metadata.get( 209 | "partition_by" 210 | ) # this could be wrong, you could have partition_by in delta_write_options and in the metadata 211 | 212 | if partition_by is not None: 213 | assert context.partition_key is not None, 'can\'t set "partition_by" for an asset without partitions' 214 | 215 | delta_write_options["partition_by"] = partition_by 216 | delta_write_options["partition_filters"] = [(partition_by, "=", context.partition_key)] 217 | 218 | if delta_write_options is not None: 219 | context.log.debug(f"Writing with delta_write_options: {pformat(delta_write_options)}") 220 | 221 | storage_options = self.storage_options 222 | try: 223 | dt = DeltaTable(str(path), storage_options=storage_options) 224 | except TableNotFoundError: 225 | dt = str(path) 226 | 227 | df.write_delta( 228 | dt, 229 | mode=context_metadata.get("mode") or self.mode.value, 230 | overwrite_schema=context_metadata.get("overwrite_schema") or self.overwrite_schema, 231 | storage_options=storage_options, 232 | delta_write_options=delta_write_options, 233 | ) 234 | if isinstance(dt, DeltaTable): 235 | current_version = dt.version() 236 | else: 237 | current_version = DeltaTable(str(path), storage_options=storage_options, without_files=True).version() 238 | context.add_output_metadata({"version": current_version}) 239 | 240 | if metadata is not None: 241 | metadata_path = self.get_storage_metadata_path(path, current_version) 242 | metadata_path.parent.mkdir(parents=True, exist_ok=True) 243 | metadata_path.write_text(json.dumps(metadata)) 244 | 245 | @overload 246 | def scan_df_from_path( 247 | self, path: "UPath", context: InputContext, with_metadata: Literal[None, False] 248 | ) -> pl.LazyFrame: 249 | ... 250 | 251 | @overload 252 | def scan_df_from_path( 253 | self, path: "UPath", context: InputContext, with_metadata: Literal[True] 254 | ) -> LazyFrameWithMetadata: 255 | ...
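    # The overloads above only refine the return type: with_metadata=True makes scan_df_from_path
    # return a (LazyFrame, StorageMetadata) tuple, with the metadata read back from the
    # ".dagster_polars_metadata/<version>.json" sidecar written by write_df_to_path. In practice
    # this is driven purely by the asset's type annotation, e.g. an input annotated as
    # LazyFrameWithMetadata (illustrative sketch):
    #
    #     @asset(io_manager_key="polars_delta_io_manager")
    #     def downstream(upstream: LazyFrameWithMetadata) -> None:
    #         ldf, metadata = upstream
    #         ...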
256 | 257 | def scan_df_from_path( 258 | self, 259 | path: "UPath", 260 | context: InputContext, 261 | with_metadata: Optional[bool] = False, 262 | ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]: 263 | context_metadata = context.metadata or {} 264 | 265 | version = self.get_delta_version_to_load(path, context) 266 | 267 | context.log.debug(f"Reading Delta table with version: {version}") 268 | 269 | ldf = pl.scan_delta( 270 | str(path), 271 | version=version, 272 | delta_table_options=context_metadata.get("delta_table_options"), 273 | pyarrow_options=context_metadata.get("pyarrow_options"), 274 | storage_options=self.storage_options, 275 | ) 276 | 277 | if with_metadata: 278 | version = self.get_delta_version_to_load(path, context) 279 | metadata_path = self.get_storage_metadata_path(path, version) 280 | if metadata_path.exists(): 281 | metadata = json.loads(metadata_path.read_text()) 282 | else: 283 | metadata = {} 284 | return ldf, metadata 285 | 286 | else: 287 | return ldf 288 | 289 | def get_path_for_partition( 290 | self, context: Union[InputContext, OutputContext], path: "UPath", partition: str 291 | ) -> "UPath": 292 | if isinstance(context, InputContext): 293 | if ( 294 | context.upstream_output is not None 295 | and context.upstream_output.metadata is not None 296 | and context.upstream_output.metadata.get("partition_by") is not None 297 | ): 298 | # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself 299 | return path 300 | 301 | if isinstance(context, OutputContext): 302 | if context.metadata is not None and context.metadata.get("partition_by") is not None: 303 | # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself 304 | return path 305 | 306 | return path / partition # partitioning is handled by the IOManager 307 | 308 | def get_metadata( 309 | self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None] 310 | ) -> Dict[str, MetadataValue]: 311 | context_metadata = context.metadata or {} 312 | 313 | metadata = super().get_metadata(context, obj) 314 | 315 | if context.has_asset_partitions: 316 | partition_by = context_metadata.get("partition_by") 317 | if partition_by is not None: 318 | metadata["partition_by"] = partition_by 319 | 320 | if context_metadata.get("mode") == "append": 321 | # modify the metadata to reflect the fact that we are appending to the table 322 | 323 | if context.has_asset_partitions: 324 | # paths = self._get_paths_for_partitions(context) 325 | # assert len(paths) == 1 326 | # path = list(paths.values())[0] 327 | 328 | # FIXME: what to do about row_count metadata if we are appending to a partitioned table?
329 | # we should not be using the full table length, 330 | # but it's unclear how to get the length of the partition we are appending to 331 | pass 332 | else: 333 | metadata["append_row_count"] = metadata["row_count"] 334 | 335 | path = self._get_path(context) 336 | # we need to get row_count from the full table 337 | metadata["row_count"] = MetadataValue.int( 338 | DeltaTable(str(path), storage_options=self.storage_options).to_pyarrow_dataset().count_rows() 339 | ) 340 | 341 | return metadata 342 | 343 | def get_delta_version_to_load(self, path: "UPath", context: InputContext) -> int: 344 | context_metadata = context.metadata or {} 345 | version_from_metadata = context_metadata.get("version") 346 | 347 | version_from_config = self.version 348 | 349 | version: Optional[int] = None 350 | 351 | if version_from_metadata is not None and version_from_config is not None: 352 | context.log.warning( 353 | f"Both version from metadata ({version_from_metadata}) " 354 | f"and config ({version_from_config}) are set. Using version from metadata." 355 | ) 356 | version = int(version_from_metadata) 357 | elif version_from_metadata is None and version_from_config is not None: 358 | version = int(version_from_config) 359 | elif version_from_metadata is not None and version_from_config is None: 360 | version = int(version_from_metadata) 361 | 362 | if version is None: 363 | return DeltaTable(str(path), storage_options=self.storage_options, without_files=True).version() 364 | else: 365 | return version 366 | 367 | def get_storage_metadata_path(self, path: "UPath", version: int) -> "UPath": 368 | return path / DAGSTER_POLARS_STORAGE_METADATA_SUBDIR / f"{version}.json" 369 | -------------------------------------------------------------------------------- /dagster_polars/io_managers/base.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from abc import abstractmethod 3 | from typing import ( 4 | TYPE_CHECKING, 5 | Any, 6 | Dict, 7 | Literal, 8 | Mapping, 9 | Optional, 10 | Tuple, 11 | Union, 12 | cast, 13 | get_args, 14 | get_origin, 15 | overload, 16 | ) 17 | 18 | import polars as pl 19 | from dagster import ( 20 | ConfigurableIOManager, 21 | EnvVar, 22 | InitResourceContext, 23 | InputContext, 24 | MetadataValue, 25 | OutputContext, 26 | UPathIOManager, 27 | ) 28 | from dagster import ( 29 | _check as check, 30 | ) 31 | from dagster._core.storage.upath_io_manager import is_dict_type 32 | from pydantic import PrivateAttr 33 | from pydantic.fields import Field 34 | 35 | from dagster_polars.io_managers.utils import get_polars_metadata 36 | from dagster_polars.types import ( 37 | DataFramePartitions, 38 | DataFramePartitionsWithMetadata, 39 | LazyFramePartitions, 40 | LazyFramePartitionsWithMetadata, 41 | LazyFrameWithMetadata, 42 | StorageMetadata, 43 | ) 44 | 45 | if TYPE_CHECKING: 46 | from upath import UPath 47 | 48 | POLARS_EAGER_FRAME_ANNOTATIONS = [ 49 | pl.DataFrame, 50 | Optional[pl.DataFrame], 51 | # common default types 52 | Any, 53 | type(None), 54 | None, 55 | # multiple partitions 56 | Dict[str, pl.DataFrame], 57 | Mapping[str, pl.DataFrame], 58 | DataFramePartitions, 59 | # DataFrame + metadata 60 | Tuple[pl.DataFrame, StorageMetadata], 61 | Optional[Tuple[pl.DataFrame, StorageMetadata]], 62 | # multiple partitions + metadata 63 | DataFramePartitionsWithMetadata, 64 | ] 65 | 66 | POLARS_LAZY_FRAME_ANNOTATIONS = [ 67 | pl.LazyFrame, 68 | Optional[pl.LazyFrame], 69 | # multiple partitions 70 | Dict[str, pl.LazyFrame], 71 | Mapping[str, 
pl.LazyFrame], 72 | LazyFramePartitions, 73 | # LazyFrame + metadata 74 | Tuple[pl.LazyFrame, StorageMetadata], 75 | Optional[Tuple[pl.LazyFrame, StorageMetadata]], 76 | # multiple partitions + metadata 77 | LazyFramePartitionsWithMetadata, 78 | ] 79 | 80 | 81 | if sys.version_info >= (3, 9): 82 | POLARS_EAGER_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame]) 83 | POLARS_EAGER_FRAME_ANNOTATIONS.append(dict[str, Optional[pl.DataFrame]]) 84 | 85 | POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, pl.LazyFrame]) 86 | POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, Optional[pl.LazyFrame]]) 87 | 88 | 89 | def annotation_is_typing_optional(annotation) -> bool: 90 | return get_origin(annotation) == Union and type(None) in get_args(annotation) 91 | 92 | 93 | def annotation_is_tuple(annotation) -> bool: 94 | return get_origin(annotation) in (Tuple, tuple) 95 | 96 | 97 | def annotation_for_multiple_partitions(annotation) -> bool: 98 | if not annotation_is_typing_optional(annotation): 99 | return annotation_is_tuple(annotation) and get_origin(annotation) in (dict, Dict, Mapping) 100 | else: 101 | inner_annotation = get_args(annotation)[0] 102 | return annotation_is_tuple(inner_annotation) and get_origin(inner_annotation) in ( 103 | dict, 104 | Dict, 105 | Mapping, 106 | ) 107 | 108 | 109 | def annotation_is_tuple_with_metadata(annotation) -> bool: 110 | if annotation_is_typing_optional(annotation): 111 | annotation = get_args(annotation)[0] 112 | 113 | return annotation_is_tuple(annotation) and get_origin(get_args(annotation)[1]) in [ 114 | dict, 115 | Dict, 116 | Mapping, 117 | ] 118 | 119 | 120 | def annotation_for_storage_metadata(annotation) -> bool: 121 | # first unwrap the Optional type 122 | if annotation_is_typing_optional(annotation): 123 | annotation = get_args(annotation)[0] 124 | 125 | if not annotation_for_multiple_partitions(annotation): 126 | return annotation_is_tuple_with_metadata(annotation) 127 | else: 128 | # unwrap the partitions 129 | annotation = get_args(annotation)[1] 130 | return annotation_is_tuple_with_metadata(annotation) 131 | 132 | 133 | def _process_env_vars(config: Mapping[str, Any]) -> Dict[str, Any]: 134 | out = {} 135 | for key, value in config.items(): 136 | if isinstance(value, dict) and len(value) == 1 and next(iter(value.keys())) == "env": 137 | out[key] = EnvVar(next(iter(value.values()))).get_value() 138 | else: 139 | out[key] = value 140 | return out 141 | 142 | 143 | class BasePolarsUPathIOManager(ConfigurableIOManager, UPathIOManager): 144 | """Base class for `dagster-polars` IOManagers. 145 | 146 | Doesn't define a specific storage format. 147 | 148 | To implement a specific storage format (parquet, csv, etc), inherit from this class and implement the `write_df_to_path` and `scan_df_from_path` methods. 149 | 150 | Features: 151 | - All the features of :py:class:`~dagster.UPathIOManager` - works with local and remote filesystems (like S3), supports loading multiple partitions with respect to :py:class:`~dagster.PartitionMapping`, and more 152 | - returns the correct type - `polars.DataFrame`, `polars.LazyFrame`, or other types defined in :py:mod:`dagster_polars.types` - based on the input type annotation (or `dagster.DagsterType`'s `typing_type`) 153 | - handles `Nones` with `Optional` types by skipping loading missing inputs or saving `None` outputs 154 | - logs various metadata about the DataFrame - size, schema, sample, stats, ... 
155 | - the `"columns"` input metadata value can be used to select a subset of columns to load 156 | """ 157 | 158 | base_dir: Optional[str] = Field(default=None, description="Base directory for storing files.") 159 | cloud_storage_options: Optional[Mapping[str, Any]] = Field( 160 | default=None, description="Storage authentication for cloud object store", alias="storage_options" 161 | ) 162 | _base_path = PrivateAttr() 163 | 164 | def setup_for_execution(self, context: InitResourceContext) -> None: 165 | from upath import UPath 166 | 167 | sp = _process_env_vars(self.cloud_storage_options) if self.cloud_storage_options is not None else {} 168 | self._base_path = ( 169 | UPath(self.base_dir, **sp) 170 | if self.base_dir is not None 171 | else UPath(check.not_none(context.instance).storage_directory()) 172 | ) 173 | 174 | @abstractmethod 175 | def write_df_to_path( 176 | self, 177 | context: OutputContext, 178 | df: pl.DataFrame, 179 | path: "UPath", 180 | metadata: Optional[StorageMetadata] = None, 181 | ): 182 | ... 183 | 184 | @abstractmethod 185 | def sink_df_to_path( 186 | self, 187 | context: OutputContext, 188 | df: pl.LazyFrame, 189 | path: "UPath", 190 | metadata: Optional[StorageMetadata] = None, 191 | ): 192 | ... 193 | 194 | @overload 195 | @abstractmethod 196 | def scan_df_from_path( 197 | self, path: "UPath", context: InputContext, with_metadata: Literal[None, False] 198 | ) -> pl.LazyFrame: 199 | ... 200 | 201 | @overload 202 | @abstractmethod 203 | def scan_df_from_path( 204 | self, path: "UPath", context: InputContext, with_metadata: Literal[True] 205 | ) -> LazyFrameWithMetadata: 206 | ... 207 | 208 | @abstractmethod 209 | def scan_df_from_path( 210 | self, path: "UPath", context: InputContext, with_metadata: Optional[bool] = False 211 | ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]: 212 | ... 213 | 214 | # tmp fix until https://github.com/dagster-io/dagster/pull/19294 is merged 215 | def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]: 216 | # If no asset key, we are dealing with an op output which is always non-partitioned 217 | if not context.has_asset_key or not context.has_asset_partitions: 218 | path = self._get_path(context) 219 | return self._load_single_input(path, context) 220 | else: 221 | asset_partition_keys = context.asset_partition_keys 222 | if len(asset_partition_keys) == 0: 223 | return None 224 | elif len(asset_partition_keys) == 1: 225 | paths = self._get_paths_for_partitions(context) 226 | check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}") 227 | path = next(iter(paths.values())) 228 | backcompat_paths = self._get_multipartition_backcompat_paths(context) 229 | backcompat_path = None if not backcompat_paths else next(iter(backcompat_paths.values())) 230 | 231 | return self._load_partition_from_path( 232 | context=context, 233 | partition_key=asset_partition_keys[0], 234 | path=path, 235 | backcompat_path=backcompat_path, 236 | ) 237 | else: # we are dealing with multiple partitions of an asset 238 | type_annotation = context.dagster_type.typing_type 239 | if type_annotation != Any and not is_dict_type(type_annotation): 240 | check.failed( 241 | "Loading an input that corresponds to multiple partitions, but the" 242 | " type annotation on the op input is not a dict, Dict, Mapping, or" 243 | f" Any: is '{type_annotation}'." 
244 | ) 245 | 246 | return self._load_multiple_inputs(context) 247 | 248 | def dump_to_path( 249 | self, 250 | context: OutputContext, 251 | obj: Union[ 252 | pl.DataFrame, 253 | Optional[pl.DataFrame], 254 | Tuple[pl.DataFrame, Dict[str, Any]], 255 | pl.LazyFrame, 256 | Optional[pl.LazyFrame], 257 | Tuple[pl.LazyFrame, Dict[str, Any]], 258 | ], 259 | path: "UPath", 260 | partition_key: Optional[str] = None, 261 | ): 262 | typing_type = context.dagster_type.typing_type 263 | 264 | if annotation_is_typing_optional(typing_type) and ( 265 | obj is None or (annotation_for_storage_metadata(typing_type) and obj[0] is None) 266 | ): 267 | context.log.warning(self.get_optional_output_none_log_message(context, path)) 268 | return 269 | else: 270 | assert obj is not None, "output should not be None if its type is not Optional" 271 | if not annotation_for_storage_metadata(typing_type): 272 | if typing_type in POLARS_EAGER_FRAME_ANNOTATIONS: 273 | obj = cast(pl.DataFrame, obj) 274 | df = obj 275 | self.write_df_to_path(context=context, df=df, path=path) 276 | elif typing_type in POLARS_LAZY_FRAME_ANNOTATIONS: 277 | obj = cast(pl.LazyFrame, obj) 278 | df = obj 279 | self.sink_df_to_path(context=context, df=df, path=path) 280 | else: 281 | raise NotImplementedError(f"dump_to_path for {typing_type} is not implemented") 282 | else: 283 | if not annotation_is_typing_optional(typing_type): 284 | frame_type = get_args(typing_type)[0] 285 | else: 286 | frame_type = get_args(get_args(typing_type)[0])[0] 287 | 288 | if frame_type in POLARS_EAGER_FRAME_ANNOTATIONS: 289 | obj = cast(Tuple[pl.DataFrame, Dict[str, Any]], obj) 290 | df, metadata = obj 291 | self.write_df_to_path(context=context, df=df, path=path, metadata=metadata) 292 | elif frame_type in POLARS_LAZY_FRAME_ANNOTATIONS: 293 | obj = cast(Tuple[pl.LazyFrame, Dict[str, Any]], obj) 294 | df, metadata = obj 295 | self.sink_df_to_path(context=context, df=df, path=path, metadata=metadata) 296 | else: 297 | raise NotImplementedError(f"dump_to_path for {typing_type} is not implemented") 298 | 299 | def load_from_path( 300 | self, context: InputContext, path: "UPath", partition_key: Optional[str] = None 301 | ) -> Union[ 302 | pl.DataFrame, 303 | pl.LazyFrame, 304 | Tuple[pl.DataFrame, Dict[str, Any]], 305 | Tuple[pl.LazyFrame, Dict[str, Any]], 306 | None, 307 | ]: 308 | if annotation_is_typing_optional(context.dagster_type.typing_type) and not path.exists(): 309 | context.log.warning(self.get_missing_optional_input_log_message(context, path)) 310 | return None 311 | 312 | assert context.metadata is not None 313 | 314 | metadata: Optional[StorageMetadata] = None 315 | 316 | return_storage_metadata = annotation_for_storage_metadata(context.dagster_type.typing_type) 317 | 318 | if not return_storage_metadata: 319 | ldf = self.scan_df_from_path(path=path, context=context) # type: ignore 320 | else: 321 | ldf, metadata = self.scan_df_from_path(path=path, context=context, with_metadata=True) 322 | 323 | columns = context.metadata.get("columns") 324 | if columns is not None: 325 | context.log.debug(f"Loading {columns=}") 326 | ldf = ldf.select(columns) 327 | 328 | if ( 329 | context.upstream_output is not None 330 | and context.upstream_output.asset_info is not None 331 | and context.upstream_output.asset_info.partitions_def is not None 332 | and context.upstream_output.metadata is not None 333 | and partition_key is not None 334 | ): 335 | partition_by = context.upstream_output.metadata.get("partition_by") 336 | 337 | # we can only support automatically
filtering by 1 column 338 | # otherwise we would have been dealing with a multi-partition key 339 | # which is not straightforward to filter by 340 | if partition_by is not None and isinstance(partition_by, str): 341 | context.log.debug(f"Filtering by {partition_by}=={partition_key}") 342 | ldf = ldf.filter(pl.col(partition_by) == partition_key) 343 | 344 | if context.dagster_type.typing_type in POLARS_EAGER_FRAME_ANNOTATIONS: 345 | if not return_storage_metadata: 346 | return ldf.collect() 347 | else: 348 | assert metadata is not None 349 | return ldf.collect(), metadata 350 | 351 | elif context.dagster_type.typing_type in POLARS_LAZY_FRAME_ANNOTATIONS: 352 | if not return_storage_metadata: 353 | return ldf 354 | else: 355 | assert metadata is not None 356 | return ldf, metadata 357 | else: 358 | raise NotImplementedError(f"Can't load object for type annotation {context.dagster_type.typing_type}") 359 | 360 | def get_metadata( 361 | self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None] 362 | ) -> Dict[str, MetadataValue]: 363 | if obj is None: 364 | return {"missing": MetadataValue.bool(True)} 365 | else: 366 | if annotation_for_storage_metadata(context.dagster_type.typing_type): 367 | df = obj[0] 368 | else: 369 | df = obj 370 | return get_polars_metadata(context, df) if df is not None else {"missing": MetadataValue.bool(True)} 371 | 372 | def get_path_for_partition( 373 | self, context: Union[InputContext, OutputContext], path: "UPath", partition: str 374 | ) -> "UPath": 375 | """Method for accessing the path for a given partition. 376 | 377 | Override this method if you want to use a different partitioning scheme 378 | (for example, if the saving function handles partitioning instead). 379 | The extension will be added later. 380 | :param context: 381 | :param path: asset path before partitioning 382 | :param partition: formatted partition key 383 | :return: 384 | """ 385 | return path / partition 386 | 387 | def get_missing_optional_input_log_message(self, context: InputContext, path: "UPath") -> str: 388 | return f"Optional input {context.name} at {path} doesn't exist in the filesystem and won't be loaded!" 389 | 390 | def get_optional_output_none_log_message(self, context: OutputContext, path: "UPath") -> str: 391 | return f"The object for the optional output {context.name} is None, so it won't be saved to {path}!" 392 | 393 | # this method is overridden because the default one does not pass partition_key to load_from_path 394 | def _load_partition_from_path( 395 | self, 396 | context: InputContext, 397 | partition_key: str, 398 | path: "UPath", 399 | backcompat_path: Optional["UPath"] = None, 400 | ) -> Any: 401 | """1. Try to load the partition from the normal path. 402 | 2. If it was not found, try to load it from the backcompat path. 403 | 3. If allow_missing_partitions metadata is True, skip the partition if it was not found in any of the paths. 404 | Otherwise, raise an error. 405 | 406 | Args: 407 | context (InputContext): IOManager Input context 408 | partition_key (str): the partition key corresponding to the partition being loaded 409 | path (UPath): The path to the partition. 410 | backcompat_path (Optional[UPath]): The path to the partition in the backcompat location. 411 | 412 | Returns: 413 | Any: The object loaded from the partition. 
414 | """ 415 | allow_missing_partitions = ( 416 | context.metadata.get("allow_missing_partitions", False) if context.metadata is not None else False 417 | ) 418 | 419 | try: 420 | context.log.debug(self.get_loading_input_partition_log_message(path, partition_key)) 421 | obj = self.load_from_path(context=context, path=path, partition_key=partition_key) 422 | return obj 423 | except FileNotFoundError as e: 424 | if backcompat_path is not None: 425 | try: 426 | obj = self.load_from_path(context=context, path=path, partition_key=partition_key) 427 | context.log.debug( 428 | f"File not found at {path}. Loaded instead from backcompat path:" f" {backcompat_path}" 429 | ) 430 | return obj 431 | except FileNotFoundError as e: 432 | if allow_missing_partitions: 433 | context.log.warning(self.get_missing_partition_log_message(partition_key)) 434 | return None 435 | else: 436 | raise e 437 | if allow_missing_partitions: 438 | context.log.warning(self.get_missing_partition_log_message(partition_key)) 439 | return None 440 | else: 441 | raise e 442 | -------------------------------------------------------------------------------- /tests/test_upath_io_managers.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Dict, Optional, Tuple 3 | 4 | import polars as pl 5 | import polars.testing as pl_testing 6 | import pytest 7 | from dagster import ( 8 | AssetExecutionContext, 9 | AssetIn, 10 | DailyPartitionsDefinition, 11 | DimensionPartitionMapping, 12 | IdentityPartitionMapping, 13 | MultiPartitionKey, 14 | MultiPartitionMapping, 15 | MultiPartitionsDefinition, 16 | OpExecutionContext, 17 | StaticPartitionsDefinition, 18 | TimeWindowPartitionMapping, 19 | asset, 20 | materialize, 21 | ) 22 | from deepdiff import DeepDiff 23 | from packaging.version import Version 24 | 25 | from dagster_polars import ( 26 | BasePolarsUPathIOManager, 27 | DataFramePartitions, 28 | LazyFramePartitions, 29 | PolarsDeltaIOManager, 30 | PolarsParquetIOManager, 31 | StorageMetadata, 32 | ) 33 | from tests.utils import get_saved_path 34 | 35 | 36 | def test_polars_upath_io_manager_stats_metadata( 37 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 38 | ): 39 | manager, _ = io_manager_and_df 40 | 41 | df = pl.DataFrame({"a": [0, 1, None], "b": ["a", "b", "c"]}) 42 | 43 | @asset(io_manager_def=manager) 44 | def upstream() -> pl.DataFrame: 45 | return df 46 | 47 | result = materialize( 48 | [upstream], 49 | ) 50 | 51 | handled_output_events = list(filter(lambda evt: evt.is_handled_output, result.events_for_node("upstream"))) 52 | 53 | stats = handled_output_events[0].event_specific_data.metadata["stats"].value # type: ignore 54 | 55 | expected_stats = { 56 | "a": { 57 | # count started ignoring null values in polars 0.20.0 58 | "count": 3.0 if Version(pl.__version__) < Version("0.20.0") else 2.0, 59 | "null_count": 1.0, 60 | "mean": 0.5, 61 | "std": 0.7071067811865476, 62 | "min": 0.0, 63 | "max": 1.0, 64 | "median": 0.5, 65 | "25%": 0.0, 66 | "75%": 1.0, 67 | }, 68 | "b": { 69 | "count": "3", 70 | "null_count": "0", 71 | "mean": "null", 72 | "std": "null", 73 | "min": "a", 74 | "max": "c", 75 | "median": "null", 76 | "25%": "null", 77 | "75%": "null", 78 | }, 79 | } 80 | 81 | # "50%" and "median" are problematic to test because they were changed in polars 0.18.0 82 | # so we remove them from the test 83 | for col in ("50%", "median"): 84 | for s in (stats, expected_stats): 85 | if col in s["a"]: # type: ignore 86 | 
s["a"].pop(col) # type: ignore 87 | if col in s["b"]: # type: ignore 88 | s["b"].pop(col) # type: ignore 89 | 90 | assert DeepDiff(stats, expected_stats) == {} 91 | 92 | 93 | def test_polars_upath_io_manager_type_annotations( 94 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 95 | ): 96 | manager, df = io_manager_and_df 97 | 98 | @asset(io_manager_def=manager) 99 | def upstream() -> pl.DataFrame: 100 | return df 101 | 102 | @asset(io_manager_def=manager) 103 | def downstream_default_eager(upstream) -> None: 104 | assert isinstance(upstream, pl.DataFrame), type(upstream) 105 | 106 | @asset(io_manager_def=manager) 107 | def downstream_eager(upstream: pl.DataFrame) -> None: 108 | assert isinstance(upstream, pl.DataFrame), type(upstream) 109 | 110 | @asset(io_manager_def=manager) 111 | def downstream_lazy(upstream: pl.LazyFrame) -> None: 112 | assert isinstance(upstream, pl.LazyFrame), type(upstream) 113 | 114 | partitions_def = StaticPartitionsDefinition(["a", "b"]) 115 | 116 | @asset(io_manager_def=manager, partitions_def=partitions_def) 117 | def upstream_partitioned(context: OpExecutionContext) -> pl.DataFrame: 118 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 119 | 120 | @asset(io_manager_def=manager) 121 | def downstream_multi_partitioned_eager(upstream_partitioned: Dict[str, pl.DataFrame]) -> None: 122 | for _df in upstream_partitioned.values(): 123 | assert isinstance(_df, pl.DataFrame), type(_df) 124 | assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() 125 | 126 | @asset(io_manager_def=manager) 127 | def downstream_multi_partitioned_lazy(upstream_partitioned: Dict[str, pl.LazyFrame]) -> None: 128 | for _df in upstream_partitioned.values(): 129 | assert isinstance(_df, pl.LazyFrame), type(_df) 130 | assert set(upstream_partitioned.keys()) == {"a", "b"}, upstream_partitioned.keys() 131 | 132 | for partition_key in ["a", "b"]: 133 | materialize( 134 | [upstream_partitioned], 135 | partition_key=partition_key, 136 | ) 137 | 138 | materialize( 139 | [ 140 | upstream_partitioned.to_source_asset(), 141 | upstream, 142 | downstream_default_eager, 143 | downstream_eager, 144 | downstream_lazy, 145 | downstream_multi_partitioned_eager, 146 | downstream_multi_partitioned_lazy, 147 | ], 148 | ) 149 | 150 | 151 | def test_polars_upath_io_manager_nested_dtypes( 152 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 153 | ): 154 | manager, df = io_manager_and_df 155 | 156 | @asset(io_manager_def=manager) 157 | def upstream() -> pl.DataFrame: 158 | return df 159 | 160 | @asset(io_manager_def=manager) 161 | def downstream(upstream: pl.LazyFrame) -> pl.DataFrame: 162 | return upstream.collect(streaming=True) 163 | 164 | result = materialize( 165 | [upstream, downstream], 166 | ) 167 | 168 | saved_path = get_saved_path(result, "upstream") 169 | 170 | if isinstance(manager, PolarsParquetIOManager): 171 | pl_testing.assert_frame_equal(df, pl.read_parquet(saved_path)) 172 | elif isinstance(manager, PolarsDeltaIOManager): 173 | pl_testing.assert_frame_equal(df, pl.read_delta(saved_path)) 174 | else: 175 | raise ValueError(f"Test not implemented for {type(manager)}") 176 | 177 | 178 | def test_polars_upath_io_manager_input_optional_eager( 179 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 180 | ): 181 | manager, df = io_manager_and_df 182 | 183 | @asset(io_manager_def=manager) 184 | def upstream() -> pl.DataFrame: 185 | return df 186 | 187 | @asset(io_manager_def=manager) 188 | def 
downstream(upstream: Optional[pl.DataFrame]) -> pl.DataFrame: 189 | assert upstream is not None 190 | return upstream 191 | 192 | materialize( 193 | [upstream, downstream], 194 | ) 195 | 196 | 197 | def test_polars_upath_io_manager_input_optional_lazy( 198 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 199 | ): 200 | manager, df = io_manager_and_df 201 | 202 | @asset(io_manager_def=manager) 203 | def upstream() -> pl.DataFrame: 204 | return df 205 | 206 | @asset(io_manager_def=manager) 207 | def downstream(upstream: Optional[pl.LazyFrame]) -> pl.DataFrame: 208 | assert upstream is not None 209 | return upstream.collect() 210 | 211 | materialize( 212 | [upstream, downstream], 213 | ) 214 | 215 | 216 | def test_polars_upath_io_manager_input_dict_eager( 217 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 218 | ): 219 | manager, df = io_manager_and_df 220 | 221 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 222 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 223 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 224 | 225 | @asset(io_manager_def=manager) 226 | def downstream(upstream: Dict[str, pl.DataFrame]) -> pl.DataFrame: 227 | dfs = [] 228 | for df in upstream.values(): 229 | assert isinstance(df, pl.DataFrame) 230 | dfs.append(df) 231 | return pl.concat(dfs) 232 | 233 | for partition_key in ["a", "b"]: 234 | materialize( 235 | [upstream], 236 | partition_key=partition_key, 237 | ) 238 | 239 | materialize( 240 | [upstream.to_source_asset(), downstream], 241 | ) 242 | 243 | 244 | def test_polars_upath_io_manager_input_dict_lazy( 245 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 246 | ): 247 | manager, df = io_manager_and_df 248 | 249 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 250 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 251 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 252 | 253 | @asset(io_manager_def=manager) 254 | def downstream(upstream: Dict[str, pl.LazyFrame]) -> pl.DataFrame: 255 | dfs = [] 256 | for df in upstream.values(): 257 | assert isinstance(df, pl.LazyFrame) 258 | dfs.append(df) 259 | return pl.concat(dfs).collect() 260 | 261 | for partition_key in ["a", "b"]: 262 | materialize( 263 | [upstream], 264 | partition_key=partition_key, 265 | ) 266 | 267 | materialize( 268 | [upstream.to_source_asset(), downstream], 269 | ) 270 | 271 | 272 | def test_polars_upath_io_manager_input_data_frame_partitions( 273 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 274 | ): 275 | manager, df = io_manager_and_df 276 | 277 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 278 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 279 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 280 | 281 | @asset(io_manager_def=manager) 282 | def downstream(upstream: DataFramePartitions) -> pl.DataFrame: 283 | dfs = [] 284 | for df in upstream.values(): 285 | assert isinstance(df, pl.DataFrame) 286 | dfs.append(df) 287 | return pl.concat(dfs) 288 | 289 | for partition_key in ["a", "b"]: 290 | materialize( 291 | [upstream], 292 | partition_key=partition_key, 293 | ) 294 | 295 | materialize( 296 | [upstream.to_source_asset(), downstream], 297 | ) 298 | 299 | 300 | def test_polars_upath_io_manager_input_lazy_frame_partitions_lazy( 301 | io_manager_and_df: 
Tuple[BasePolarsUPathIOManager, pl.DataFrame], 302 | ): 303 | manager, df = io_manager_and_df 304 | 305 | @asset(io_manager_def=manager, partitions_def=StaticPartitionsDefinition(["a", "b"])) 306 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 307 | return df.with_columns(pl.lit(context.partition_key).alias("partition")) 308 | 309 | @asset(io_manager_def=manager) 310 | def downstream(upstream: LazyFramePartitions) -> pl.DataFrame: 311 | dfs = [] 312 | for df in upstream.values(): 313 | assert isinstance(df, pl.LazyFrame) 314 | dfs.append(df) 315 | return pl.concat(dfs).collect() 316 | 317 | for partition_key in ["a", "b"]: 318 | materialize( 319 | [upstream], 320 | partition_key=partition_key, 321 | ) 322 | 323 | materialize( 324 | [upstream.to_source_asset(), downstream], 325 | ) 326 | 327 | 328 | def test_polars_upath_io_manager_input_optional_eager_return_none( 329 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 330 | ): 331 | manager, df = io_manager_and_df 332 | 333 | @asset(io_manager_def=manager) 334 | def upstream() -> pl.DataFrame: 335 | return df 336 | 337 | @asset 338 | def downstream(upstream: Optional[pl.DataFrame]): 339 | assert upstream is None 340 | 341 | materialize( 342 | [upstream.to_source_asset(), downstream], 343 | ) 344 | 345 | 346 | def test_polars_upath_io_manager_output_optional_eager( 347 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 348 | ): 349 | manager, df = io_manager_and_df 350 | 351 | @asset(io_manager_def=manager) 352 | def upstream() -> Optional[pl.DataFrame]: 353 | return None 354 | 355 | @asset(io_manager_def=manager) 356 | def downstream(upstream: Optional[pl.DataFrame]) -> Optional[pl.DataFrame]: 357 | assert upstream is None 358 | return upstream 359 | 360 | materialize( 361 | [upstream, downstream], 362 | ) 363 | 364 | 365 | def test_polars_upath_io_manager_output_optional_lazy( 366 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 367 | ): 368 | manager, df = io_manager_and_df 369 | 370 | @asset(io_manager_def=manager) 371 | def upstream() -> Optional[pl.DataFrame]: 372 | return None 373 | 374 | @asset(io_manager_def=manager) 375 | def downstream(upstream: Optional[pl.LazyFrame]) -> Optional[pl.DataFrame]: 376 | assert upstream is None 377 | return upstream 378 | 379 | materialize( 380 | [upstream, downstream], 381 | ) 382 | 383 | 384 | IO_MANAGERS_SUPPORTING_STORAGE_METADATA = ( 385 | PolarsParquetIOManager, 386 | PolarsDeltaIOManager, 387 | ) 388 | 389 | 390 | def check_skip_storage_metadata_test(io_manager_def: BasePolarsUPathIOManager): 391 | if not isinstance(io_manager_def, IO_MANAGERS_SUPPORTING_STORAGE_METADATA): 392 | pytest.skip(f"Only {IO_MANAGERS_SUPPORTING_STORAGE_METADATA} support storage metadata") 393 | 394 | 395 | @pytest.fixture 396 | def metadata() -> StorageMetadata: 397 | return {"a": 1, "b": "2", "c": [1, 2, 3], "d": {"e": 1}, "f": [1, 2, 3, {"g": 1}]} 398 | 399 | 400 | def test_upath_io_manager_storage_metadata_eager( 401 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 402 | ): 403 | io_manager_def, df = io_manager_and_df 404 | check_skip_storage_metadata_test(io_manager_def) 405 | 406 | @asset(io_manager_def=io_manager_def) 407 | def upstream() -> Tuple[pl.DataFrame, StorageMetadata]: 408 | return df, metadata 409 | 410 | @asset(io_manager_def=io_manager_def) 411 | def downstream(upstream: Tuple[pl.DataFrame, StorageMetadata]) -> None: 412 | loaded_df, upstream_metadata = upstream 413 | assert 
upstream_metadata == metadata 414 | pl_testing.assert_frame_equal(loaded_df, df) 415 | 416 | materialize( 417 | [upstream, downstream], 418 | ) 419 | 420 | 421 | def test_upath_io_manager_storage_metadata_lazy( 422 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 423 | ): 424 | io_manager_def, df = io_manager_and_df 425 | check_skip_storage_metadata_test(io_manager_def) 426 | 427 | @asset(io_manager_def=io_manager_def) 428 | def upstream() -> Tuple[pl.DataFrame, StorageMetadata]: 429 | return df, metadata 430 | 431 | @asset(io_manager_def=io_manager_def) 432 | def downstream(upstream: Tuple[pl.LazyFrame, StorageMetadata]) -> None: 433 | df, upstream_metadata = upstream 434 | assert isinstance(df, pl.LazyFrame) 435 | assert upstream_metadata == metadata 436 | 437 | materialize( 438 | [upstream, downstream], 439 | ) 440 | 441 | 442 | def test_upath_io_manager_storage_metadata_optional_eager_exists( 443 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 444 | ): 445 | io_manager_def, df = io_manager_and_df 446 | check_skip_storage_metadata_test(io_manager_def) 447 | 448 | @asset(io_manager_def=io_manager_def) 449 | def upstream() -> Optional[Tuple[pl.DataFrame, StorageMetadata]]: 450 | return df, metadata 451 | 452 | @asset(io_manager_def=io_manager_def) 453 | def downstream(upstream: Optional[Tuple[pl.DataFrame, StorageMetadata]]) -> None: 454 | assert upstream is not None 455 | df, upstream_metadata = upstream 456 | assert upstream_metadata == metadata 457 | 458 | materialize( 459 | [upstream, downstream], 460 | ) 461 | 462 | 463 | def test_upath_io_manager_storage_metadata_optional_eager_missing( 464 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 465 | ): 466 | io_manager_def, df = io_manager_and_df 467 | check_skip_storage_metadata_test(io_manager_def) 468 | 469 | @asset(io_manager_def=io_manager_def) 470 | def upstream() -> Optional[Tuple[pl.DataFrame, StorageMetadata]]: 471 | return None 472 | 473 | @asset(io_manager_def=io_manager_def) 474 | def downstream(upstream: Optional[Tuple[pl.DataFrame, StorageMetadata]]) -> None: 475 | assert upstream is None 476 | 477 | materialize( 478 | [upstream, downstream], 479 | ) 480 | 481 | 482 | def test_upath_io_manager_storage_metadata_optional_lazy_exists( 483 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 484 | ): 485 | io_manager_def, df = io_manager_and_df 486 | check_skip_storage_metadata_test(io_manager_def) 487 | 488 | @asset(io_manager_def=io_manager_def) 489 | def upstream() -> Optional[Tuple[pl.DataFrame, StorageMetadata]]: 490 | return df, metadata 491 | 492 | @asset(io_manager_def=io_manager_def) 493 | def downstream(upstream: Optional[Tuple[pl.LazyFrame, StorageMetadata]]) -> None: 494 | assert upstream is not None 495 | df, upstream_metadata = upstream 496 | assert isinstance(df, pl.LazyFrame) 497 | assert upstream_metadata == metadata 498 | 499 | materialize( 500 | [upstream, downstream], 501 | ) 502 | 503 | 504 | def test_upath_io_manager_storage_metadata_optional_lazy_missing( 505 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], metadata: StorageMetadata 506 | ): 507 | io_manager_def, df = io_manager_and_df 508 | check_skip_storage_metadata_test(io_manager_def) 509 | 510 | @asset(io_manager_def=io_manager_def) 511 | def upstream() -> Optional[Tuple[pl.DataFrame, StorageMetadata]]: 512 | return None 513 | 514 | 
@asset(io_manager_def=io_manager_def) 515 | def downstream(upstream: Optional[Tuple[pl.LazyFrame, StorageMetadata]]) -> None: 516 | assert upstream is None 517 | 518 | materialize( 519 | [upstream, downstream], 520 | ) 521 | 522 | 523 | def test_upath_io_manager_multi_partitions_definition_load_multiple_partitions( 524 | io_manager_and_df: Tuple[BasePolarsUPathIOManager, pl.DataFrame], 525 | ): 526 | io_manager_def, df = io_manager_and_df 527 | 528 | today = datetime.now().date() 529 | 530 | partitions_def = MultiPartitionsDefinition( 531 | { 532 | "time": DailyPartitionsDefinition(start_date=str(today - timedelta(days=3))), 533 | "static": StaticPartitionsDefinition(["a"]), 534 | } 535 | ) 536 | 537 | @asset(partitions_def=partitions_def, io_manager_def=io_manager_def) 538 | def upstream(context: AssetExecutionContext) -> pl.DataFrame: 539 | return pl.DataFrame({"partition": [str(context.partition_key)]}) 540 | 541 | # this asset will request 2 upstream partitions 542 | @asset( 543 | io_manager_def=io_manager_def, 544 | partitions_def=partitions_def, 545 | ins={ 546 | "upstream": AssetIn( 547 | partition_mapping=MultiPartitionMapping( 548 | { 549 | "time": DimensionPartitionMapping("time", TimeWindowPartitionMapping(start_offset=-1)), 550 | "static": DimensionPartitionMapping("static", IdentityPartitionMapping()), 551 | } 552 | ) 553 | ) 554 | }, 555 | ) 556 | def downstream(context: AssetExecutionContext, upstream: DataFramePartitions) -> None: 557 | assert len(upstream.values()) == 2 558 | 559 | materialize( 560 | [upstream], 561 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=3)), "static": "a"}), 562 | ) 563 | materialize( 564 | [upstream], 565 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=2)), "static": "a"}), 566 | ) 567 | # materialize([upstream], partition_key=MultiPartitionKey({"time": str(today - timedelta(days=1)), "static": "a"})) 568 | 569 | materialize( 570 | [upstream.to_source_asset(), downstream], 571 | partition_key=MultiPartitionKey({"time": str(today - timedelta(days=2)), "static": "a"}), 572 | ) 573 | --------------------------------------------------------------------------------
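The abstract interface in `dagster_polars/io_managers/base.py` above boils down to three methods: `write_df_to_path`, `sink_df_to_path` and `scan_df_from_path`. The following is a minimal, untested sketch (not part of the repository) of how a hypothetical CSV-backed subclass could satisfy that contract; the class name `PolarsCSVIOManager`, the `extension` field, and the decision to drop `StorageMetadata` (CSV has no native place to persist it) are assumptions made purely for illustration.

from typing import Optional, Union

import polars as pl
from dagster import InputContext, OutputContext
from upath import UPath

from dagster_polars import BasePolarsUPathIOManager
from dagster_polars.types import LazyFrameWithMetadata, StorageMetadata


class PolarsCSVIOManager(BasePolarsUPathIOManager):
    """Hypothetical IO manager that stores outputs as CSV files (illustration only)."""

    # assumed to behave like the `extension` attribute of the built-in UPathIOManager subclasses
    extension: str = ".csv"

    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: UPath,
        metadata: Optional[StorageMetadata] = None,
    ):
        # CSV has no native metadata block, so `metadata` is intentionally ignored here
        with path.open("wb") as f:
            df.write_csv(f)

    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: UPath,
        metadata: Optional[StorageMetadata] = None,
    ):
        # collect eagerly for simplicity; a real implementation might stream the LazyFrame instead
        self.write_df_to_path(context=context, df=df.collect(), path=path, metadata=metadata)

    def scan_df_from_path(
        self, path: UPath, context: InputContext, with_metadata: Optional[bool] = False
    ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]:
        # local paths only in this sketch; remote filesystems would need storage options
        ldf = pl.scan_csv(str(path))
        return (ldf, {}) if with_metadata else ldf

The `columns` and `allow_missing_partitions` input metadata keys read by `load_from_path` and `_load_partition_from_path` are supplied through `AssetIn` metadata on the downstream asset. A hypothetical usage, assuming `upstream` is a partitioned asset stored by one of these IO managers, could look like this:

from typing import Dict

import polars as pl
from dagster import AssetIn, asset


@asset(
    ins={
        "upstream": AssetIn(
            metadata={"columns": ["a"], "allow_missing_partitions": True},
        )
    }
)
def downstream(upstream: Dict[str, pl.LazyFrame]) -> pl.DataFrame:
    # only column "a" is scanned from each partition, and partitions missing
    # from storage are skipped instead of raising an error
    return pl.concat(list(upstream.values())).collect()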