├── .flake8
├── .gitattributes
├── .github
│   └── workflows
│       ├── conda.yml
│       ├── dask_test.yaml
│       ├── pre-commit.yml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── changes.md
├── ci
│   ├── environment.yml
│   ├── environment_released.yml
│   └── recipe
│       └── meta.yaml
├── codecov.yml
├── dask_expr
│   ├── __init__.py
│   ├── _accessor.py
│   ├── _backends.py
│   ├── _categorical.py
│   ├── _collection.py
│   ├── _concat.py
│   ├── _core.py
│   ├── _cumulative.py
│   ├── _datetime.py
│   ├── _describe.py
│   ├── _dispatch.py
│   ├── _dummies.py
│   ├── _expr.py
│   ├── _groupby.py
│   ├── _indexing.py
│   ├── _interchange.py
│   ├── _merge.py
│   ├── _merge_asof.py
│   ├── _quantile.py
│   ├── _quantiles.py
│   ├── _reductions.py
│   ├── _repartition.py
│   ├── _resample.py
│   ├── _rolling.py
│   ├── _shuffle.py
│   ├── _str_accessor.py
│   ├── _util.py
│   ├── _version.py
│   ├── array
│   │   ├── __init__.py
│   │   ├── _creation.py
│   │   ├── blockwise.py
│   │   ├── core.py
│   │   ├── random.py
│   │   ├── rechunk.py
│   │   ├── reductions.py
│   │   ├── slicing.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── test_array.py
│   │       └── test_creation.py
│   ├── datasets.py
│   ├── diagnostics
│   │   ├── __init__.py
│   │   ├── _analyze.py
│   │   ├── _analyze_plugin.py
│   │   └── _explain.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── _delayed.py
│   │   ├── bag.py
│   │   ├── csv.py
│   │   ├── hdf.py
│   │   ├── io.py
│   │   ├── json.py
│   │   ├── orc.py
│   │   ├── parquet.py
│   │   ├── records.py
│   │   ├── sql.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── test_delayed.py
│   │       ├── test_distributed.py
│   │       ├── test_from_pandas.py
│   │       ├── test_io.py
│   │       ├── test_parquet.py
│   │       └── test_sql.py
│   └── tests
│       ├── __init__.py
│       ├── _util.py
│       ├── test_align_partitions.py
│       ├── test_categorical.py
│       ├── test_collection.py
│       ├── test_concat.py
│       ├── test_core.py
│       ├── test_cumulative.py
│       ├── test_datasets.py
│       ├── test_datetime.py
│       ├── test_describe.py
│       ├── test_diagnostics.py
│       ├── test_distributed.py
│       ├── test_dummies.py
│       ├── test_format.py
│       ├── test_fusion.py
│       ├── test_groupby.py
│       ├── test_indexing.py
│       ├── test_interchange.py
│       ├── test_map_partitions_overlap.py
│       ├── test_merge.py
│       ├── test_merge_asof.py
│       ├── test_partitioning_knowledge.py
│       ├── test_predicate_pushdown.py
│       ├── test_quantiles.py
│       ├── test_reductions.py
│       ├── test_repartition.py
│       ├── test_resample.py
│       ├── test_reshape.py
│       ├── test_rolling.py
│       ├── test_shuffle.py
│       ├── test_string_accessor.py
│       └── test_ufunc.py
├── demo.ipynb
├── pyproject.toml
├── setup.cfg
└── setup.py
/.flake8: -------------------------------------------------------------------------------- 1 | # flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234 2 | [flake8] 3 | # References: 4 | # https://flake8.readthedocs.io/en/latest/user/configuration.html 5 | # https://flake8.readthedocs.io/en/latest/user/error-codes.html 6 | # https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes 7 | exclude = __init__.py 8 | ignore = 9 | # Extra space in brackets 10 | E20 11 | # Multiple spaces around "," 12 | E231,E241 13 | # Comments 14 | E26 15 | # Import formatting 16 | E4 17 | # Comparing types instead of isinstance 18 | E721 19 | # Assigning lambda expression 20 | E731 21 | # Ambiguous variable names 22 | E741 23 | # Line break before binary operator 24 | W503 25 | # Line break after binary operator 26 | W504 27 | # Redefinition of unused 'loop' from line 10 28 | F811 29 | # No explicit stacklevel in warnings.warn. 
FIXME we should correct this in the code 30 | B028 31 | 32 | max-line-length = 120 33 | per-file-ignores = 34 | *_test.py: 35 | # Do not call assert False since python -O removes these calls 36 | B011, 37 | **/tests/*: 38 | # Do not call assert False since python -O removes these calls 39 | B011, 40 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | dask_expr/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/workflows/conda.yml: -------------------------------------------------------------------------------- 1 | name: Conda build 2 | on: 3 | push: 4 | branches: 5 | - main 6 | tags: 7 | - "*" 8 | pull_request: 9 | paths: 10 | - setup.py 11 | - ci/recipe/** 12 | - .github/workflows/conda.yml 13 | - pyproject.toml 14 | 15 | # When this workflow is queued, automatically cancel any previous running 16 | # or pending jobs from the same branch 17 | concurrency: 18 | group: conda-${{ github.ref }} 19 | cancel-in-progress: true 20 | 21 | # Required shell entrypoint to have properly activated conda environments 22 | defaults: 23 | run: 24 | shell: bash -l {0} 25 | 26 | jobs: 27 | conda: 28 | name: Build (and upload) 29 | runs-on: ubuntu-latest 30 | steps: 31 | - uses: actions/checkout@v4.1.1 32 | with: 33 | fetch-depth: 0 34 | - name: Set up Python 35 | uses: conda-incubator/setup-miniconda@v3.0.4 36 | with: 37 | miniforge-version: latest 38 | use-mamba: true 39 | python-version: 3.9 40 | channel-priority: strict 41 | - name: Install dependencies 42 | run: | 43 | mamba install -c conda-forge boa conda-verify 44 | 45 | which python 46 | pip list 47 | mamba list 48 | - name: Build conda packages 49 | run: | 50 | # suffix for pre-release package versions 51 | export VERSION_SUFFIX=a`date +%y%m%d` 52 | 53 | # conda search for the latest dask-core pre-release 54 | arr=($(conda search --override-channels -c dask/label/dev dask-core | tail -n 1)) 55 | 56 | # extract dask-core pre-release version / build 57 | export DASK_CORE_VERSION=${arr[1]} 58 | 59 | # distributed pre-release build 60 | conda mambabuild ci/recipe \ 61 | --channel dask/label/dev \ 62 | --no-anaconda-upload \ 63 | --output-folder . 
64 | - name: Upload conda packages 65 | if: github.event_name == 'push' && github.repository == 'dask/dask-expr' 66 | env: 67 | ANACONDA_API_TOKEN: ${{ secrets.DASK_CONDA_TOKEN }} 68 | run: | 69 | # install anaconda for upload 70 | mamba install -c conda-forge anaconda-client 71 | 72 | anaconda upload --label dev noarch/*.tar.bz2 73 | -------------------------------------------------------------------------------- /.github/workflows/dask_test.yaml: -------------------------------------------------------------------------------- 1 | name: Tests / dask/dask 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | # When this workflow is queued, automatically cancel any previous running 11 | # or pending jobs from the same branch 12 | concurrency: 13 | group: dask_test-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | # Required shell entrypoint to have properly activated conda environments 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | 21 | jobs: 22 | test: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | python-version: ["3.10", "3.12"] 28 | environment-file: [ci/environment.yml] 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | with: 33 | fetch-depth: 0 # Needed by codecov.io 34 | 35 | - name: Get current date 36 | id: date 37 | run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}" 38 | 39 | - name: Install Environment 40 | uses: mamba-org/setup-micromamba@v1 41 | with: 42 | environment-file: ${{ matrix.environment-file }} 43 | create-args: python=${{ matrix.python-version }} 44 | # Wipe cache every 24 hours or whenever environment.yml changes. This means it 45 | # may take up to a day before changes to unpinned packages are picked up. 46 | # To force a cache refresh, change the hardcoded numerical suffix below. 47 | cache-environment-key: environment-${{ steps.date.outputs.date }}-0 48 | 49 | - name: Install current main versions of dask 50 | run: python -m pip install git+https://github.com/dask/dask 51 | 52 | - name: Install current main versions of distributed 53 | run: python -m pip install git+https://github.com/dask/distributed 54 | 55 | - name: Install dask-expr 56 | run: python -m pip install -e . 
--no-deps 57 | 58 | - name: Print dask versions 59 | # Output of `micromamba list` is buggy for pip-installed packages 60 | run: pip list | grep -E 'dask|distributed' 61 | 62 | - name: Run Dask DataFrame tests 63 | run: python -c "import dask.dataframe as dd; dd.test_dataframe()" 64 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: main 8 | 9 | jobs: 10 | checks: 11 | name: pre-commit hooks 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-python@v5 16 | with: 17 | python-version: '3.10' 18 | - uses: pre-commit/action@v3.0.0 19 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | workflow_dispatch: 9 | 10 | # When this workflow is queued, automatically cancel any previous running 11 | # or pending jobs from the same branch 12 | concurrency: 13 | group: test-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | # Required shell entrypoint to have properly activated conda environments 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | 21 | jobs: 22 | test: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | python-version: ["3.10", "3.11", "3.12", "3.13"] 28 | environment-file: [ci/environment.yml] 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | with: 33 | fetch-depth: 0 # Needed by codecov.io 34 | 35 | - name: Get current date 36 | id: date 37 | run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}" 38 | 39 | - name: Install Environment 40 | uses: mamba-org/setup-micromamba@v1 41 | with: 42 | environment-file: ${{ matrix.environment-file }} 43 | create-args: python=${{ matrix.python-version }} 44 | # Wipe cache every 24 hours or whenever environment.yml changes. This means it 45 | # may take up to a day before changes to unpinned packages are picked up. 46 | # To force a cache refresh, change the hardcoded numerical suffix below. 47 | cache-environment-key: environment-${{ steps.date.outputs.date }}-1 48 | 49 | - name: Install current main versions of dask 50 | run: python -m pip install git+https://github.com/dask/dask 51 | if: ${{ matrix.environment-file == 'ci/environment.yml' }} 52 | 53 | - name: Install current main versions of distributed 54 | run: python -m pip install git+https://github.com/dask/distributed 55 | if: ${{ matrix.environment-file == 'ci/environment.yml' }} 56 | 57 | - name: Install dask-expr 58 | run: python -m pip install -e . --no-deps 59 | 60 | - name: Print dask versions 61 | # Output of `micromamba list` is buggy for pip-installed packages 62 | run: pip list | grep -E 'dask|distributed' 63 | 64 | - name: Run tests 65 | run: py.test -n auto --verbose --cov=dask_expr --cov-report=xml 66 | 67 | - name: Coverage 68 | uses: codecov/codecov-action@v3 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | build/ 3 | dist/ 4 | *.egg-info/ 5 | bench/shakespeare.txt 6 | .coverage 7 | *.sw? 
8 | .DS_STORE 9 | \.tox/ 10 | .idea/ 11 | .ipynb_checkpoints/ 12 | coverage.xml 13 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: debug-statements 7 | - repo: https://github.com/MarcoGorelli/absolufy-imports 8 | rev: v0.3.1 9 | hooks: 10 | - id: absolufy-imports 11 | name: absolufy-imports 12 | - repo: https://github.com/pycqa/isort 13 | rev: 5.12.0 14 | hooks: 15 | - id: isort 16 | language_version: python3 17 | - repo: https://github.com/asottile/pyupgrade 18 | rev: v3.3.2 19 | hooks: 20 | - id: pyupgrade 21 | args: 22 | - --py310-plus 23 | - repo: https://github.com/psf/black 24 | rev: 23.3.0 25 | hooks: 26 | - id: black 27 | language_version: python3 28 | exclude: versioneer.py 29 | args: 30 | - --target-version=py310 31 | - repo: https://github.com/pycqa/flake8 32 | rev: 6.0.0 33 | hooks: 34 | - id: flake8 35 | language_version: python3 36 | additional_dependencies: 37 | # NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive 38 | # dependency. Make sure to update flake8-bugbear manually on a regular basis. 39 | - flake8-bugbear==22.8.23 40 | - repo: https://github.com/codespell-project/codespell 41 | rev: v2.2.4 42 | hooks: 43 | - id: codespell 44 | types_or: [rst, markdown] 45 | files: docs 46 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright Dask-expr development team 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Dask Expressions 2 | ================ 3 | 4 | The Implementation is now the default and only backend for Dask DataFrames and was 5 | moved to https://github.com/dask/dask. 6 | 7 | This repository is no longer maintained. 8 | -------------------------------------------------------------------------------- /ci/environment.yml: -------------------------------------------------------------------------------- 1 | name: dask-expr 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - pytest 6 | - pytest-cov 7 | - pytest-xdist 8 | - dask # overridden by git tip below 9 | - pyarrow>=14.0.1 10 | - pandas>=2 11 | - pre-commit 12 | - sqlalchemy 13 | - xarray 14 | - pip: 15 | - git+https://github.com/dask/distributed 16 | - git+https://github.com/dask/dask 17 | -------------------------------------------------------------------------------- /ci/environment_released.yml: -------------------------------------------------------------------------------- 1 | name: dask-expr 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - pytest 6 | - pytest-cov 7 | - pytest-xdist 8 | - dask 9 | - pyarrow>=14.0.1 10 | - pandas>=2 11 | - pre-commit 12 | -------------------------------------------------------------------------------- /ci/recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set major_minor_patch = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v').split('.') %} 2 | {% set new_patch = major_minor_patch[2] | int + 1 %} 3 | {% set version = (major_minor_patch[:2] + [new_patch]) | join('.') + environ.get('VERSION_SUFFIX', '') %} 4 | {% set dask_version = environ.get('DASK_CORE_VERSION', '0.0.0.dev') %} 5 | 6 | 7 | package: 8 | name: dask-expr 9 | version: {{ version }} 10 | 11 | source: 12 | git_url: ../.. 13 | 14 | build: 15 | number: {{ GIT_DESCRIBE_NUMBER }} 16 | noarch: python 17 | string: py_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} 18 | script: {{ PYTHON }} -m pip install . 
-vv 19 | 20 | requirements: 21 | host: 22 | - python >=3.10 23 | - pip 24 | - dask-core {{ dask_version }} 25 | - versioneer =0.28 26 | - setuptools >=62.6 27 | - tomli # [py<311] 28 | run: 29 | - python >=3.10 30 | - {{ pin_compatible('dask-core', max_pin='x.x.x.x') }} 31 | - pyarrow 32 | - pandas >=2 33 | 34 | test: 35 | imports: 36 | - dask_expr 37 | requires: 38 | - pip 39 | commands: 40 | - pip check 41 | 42 | about: 43 | home: https://github.com/dask/dask-expr 44 | summary: 'High Level Expressions for Dask' 45 | description: | 46 | High Level Expressions for Dask 47 | license: BSD-3-Clause 48 | license_family: BSD 49 | license_file: LICENSE.txt 50 | doc_url: https://github.com/dask/dask-expr/blob/main/README.md 51 | dev_url: https://github.com/dask/dask-expr 52 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "87...100" 8 | 9 | status: 10 | project: 11 | default: 12 | target: 87% 13 | threshold: 1% 14 | patch: no 15 | changes: no 16 | 17 | comment: off 18 | -------------------------------------------------------------------------------- /dask_expr/__init__.py: -------------------------------------------------------------------------------- 1 | import dask.dataframe 2 | 3 | from dask_expr import _version, datasets 4 | from dask_expr._collection import * 5 | from dask_expr._dispatch import get_collection_type 6 | from dask_expr._dummies import get_dummies 7 | from dask_expr._groupby import Aggregation 8 | from dask_expr.io._delayed import from_delayed 9 | from dask_expr.io.bag import to_bag 10 | from dask_expr.io.csv import to_csv 11 | from dask_expr.io.hdf import read_hdf, to_hdf 12 | from dask_expr.io.json import read_json, to_json 13 | from dask_expr.io.orc import read_orc, to_orc 14 | from dask_expr.io.parquet import to_parquet 15 | from dask_expr.io.records import to_records 16 | from dask_expr.io.sql import read_sql, read_sql_query, read_sql_table, to_sql 17 | 18 | __version__ = _version.get_versions()["version"] 19 | -------------------------------------------------------------------------------- /dask_expr/_accessor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from dask.dataframe.accessor import _bind_method, _bind_property, maybe_wrap_pandas 4 | from dask.dataframe.dispatch import make_meta, meta_nonempty 5 | 6 | from dask_expr._expr import Elemwise, Expr 7 | 8 | 9 | class Accessor: 10 | """ 11 | Base class for pandas Accessor objects cat, dt, and str. 12 | 13 | Notes 14 | ----- 15 | Subclasses should define ``_accessor_name``, ``_accessor_methods``, and 16 | ``_accessor_properties``. 
17 | """ 18 | 19 | def __init__(self, series): 20 | from dask_expr import Series 21 | 22 | if not isinstance(series, Series): 23 | raise ValueError("Accessor cannot be initialized") 24 | 25 | series_meta = series._meta 26 | if hasattr(series_meta, "to_series"): # is index-like 27 | series_meta = series_meta.to_series() 28 | meta = getattr(series_meta, self._accessor_name) 29 | 30 | self._meta = meta 31 | self._series = series 32 | 33 | def __init_subclass__(cls, **kwargs): 34 | """Bind all auto-generated methods & properties""" 35 | import pandas as pd 36 | 37 | super().__init_subclass__(**kwargs) 38 | pd_cls = getattr(pd.Series, cls._accessor_name) 39 | for item in cls._accessor_methods: 40 | attr, min_version = item if isinstance(item, tuple) else (item, None) 41 | if not hasattr(cls, attr): 42 | _bind_method(cls, pd_cls, attr, min_version) 43 | for item in cls._accessor_properties: 44 | attr, min_version = item if isinstance(item, tuple) else (item, None) 45 | if not hasattr(cls, attr): 46 | _bind_property(cls, pd_cls, attr, min_version) 47 | 48 | @staticmethod 49 | def _delegate_property(obj, accessor, attr): 50 | out = getattr(getattr(obj, accessor, obj), attr) 51 | return maybe_wrap_pandas(obj, out) 52 | 53 | @staticmethod 54 | def _delegate_method(obj, accessor, attr, args, kwargs): 55 | out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs) 56 | return maybe_wrap_pandas(obj, out) 57 | 58 | def _function_map(self, attr, *args, **kwargs): 59 | from dask_expr._collection import Index, new_collection 60 | 61 | if isinstance(self._series, Index): 62 | return new_collection( 63 | FunctionMapIndex(self._series, self._accessor_name, attr, args, kwargs) 64 | ) 65 | 66 | return new_collection( 67 | FunctionMap(self._series, self._accessor_name, attr, args, kwargs) 68 | ) 69 | 70 | def _property_map(self, attr, *args, **kwargs): 71 | from dask_expr._collection import Index, new_collection 72 | 73 | if isinstance(self._series, Index): 74 | return new_collection( 75 | PropertyMapIndex(self._series, self._accessor_name, attr) 76 | ) 77 | 78 | return new_collection(PropertyMap(self._series, self._accessor_name, attr)) 79 | 80 | 81 | class PropertyMap(Elemwise): 82 | _parameters = [ 83 | "frame", 84 | "accessor", 85 | "attr", 86 | ] 87 | 88 | @staticmethod 89 | def operation(obj, accessor, attr): 90 | out = getattr(getattr(obj, accessor, obj), attr) 91 | return maybe_wrap_pandas(obj, out) 92 | 93 | 94 | class PropertyMapIndex(PropertyMap): 95 | def _divisions(self): 96 | # TODO: We can do better here 97 | return (None,) * (self.frame.npartitions + 1) 98 | 99 | 100 | class FunctionMap(Elemwise): 101 | _parameters = ["frame", "accessor", "attr", "args", "kwargs"] 102 | 103 | @functools.cached_property 104 | def _meta(self): 105 | args = [ 106 | meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args 107 | ] 108 | return make_meta(self.operation(*args, **self._kwargs)) 109 | 110 | @staticmethod 111 | def operation(obj, accessor, attr, args, kwargs): 112 | out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs) 113 | return maybe_wrap_pandas(obj, out) 114 | 115 | 116 | class FunctionMapIndex(FunctionMap): 117 | def _divisions(self): 118 | # TODO: We can do better here 119 | return (None,) * (self.frame.npartitions + 1) 120 | -------------------------------------------------------------------------------- /dask_expr/_backends.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 
import numpy as np 4 | import pandas as pd 5 | from dask.backends import CreationDispatch 6 | from dask.dataframe.backends import DataFrameBackendEntrypoint 7 | from dask.dataframe.dispatch import to_pandas_dispatch 8 | 9 | from dask_expr._dispatch import get_collection_type 10 | from dask_expr._expr import ToBackend 11 | 12 | try: 13 | import sparse 14 | 15 | sparse_installed = True 16 | except ImportError: 17 | sparse_installed = False 18 | 19 | 20 | try: 21 | import scipy.sparse as sp 22 | 23 | scipy_installed = True 24 | except ImportError: 25 | scipy_installed = False 26 | 27 | 28 | dataframe_creation_dispatch = CreationDispatch( 29 | module_name="dataframe", 30 | default="pandas", 31 | entrypoint_root="dask_expr", 32 | entrypoint_class=DataFrameBackendEntrypoint, 33 | name="dataframe_creation_dispatch", 34 | ) 35 | 36 | 37 | class ToPandasBackend(ToBackend): 38 | @staticmethod 39 | def operation(df, options): 40 | return to_pandas_dispatch(df, **options) 41 | 42 | def _simplify_down(self): 43 | if isinstance(self.frame._meta, (pd.DataFrame, pd.Series, pd.Index)): 44 | # We already have pandas data 45 | return self.frame 46 | 47 | 48 | class PandasBackendEntrypoint(DataFrameBackendEntrypoint): 49 | """Pandas-Backend Entrypoint Class for Dask-Expressions 50 | 51 | Note that all DataFrame-creation functions are defined 52 | and registered 'in-place'. 53 | """ 54 | 55 | @classmethod 56 | def to_backend(cls, data, **kwargs): 57 | from dask_expr._collection import new_collection 58 | 59 | return new_collection(ToPandasBackend(data, kwargs)) 60 | 61 | 62 | dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint()) 63 | 64 | 65 | @get_collection_type.register(pd.Series) 66 | def get_collection_type_series(_): 67 | from dask_expr._collection import Series 68 | 69 | return Series 70 | 71 | 72 | @get_collection_type.register(pd.DataFrame) 73 | def get_collection_type_dataframe(_): 74 | from dask_expr._collection import DataFrame 75 | 76 | return DataFrame 77 | 78 | 79 | @get_collection_type.register(pd.Index) 80 | def get_collection_type_index(_): 81 | from dask_expr._collection import Index 82 | 83 | return Index 84 | 85 | 86 | def create_array_collection(expr): 87 | # This is hacky and an abstraction leak, but utilizing get_collection_type 88 | # to infer that we want to create an array is the only way that is guaranteed 89 | # to be a general solution. 
90 | # We can get rid of this when we have an Array expression 91 | from dask.highlevelgraph import HighLevelGraph 92 | from dask.layers import Blockwise 93 | 94 | result = expr.optimize() 95 | dsk = result.__dask_graph__() 96 | name = result._name 97 | meta = result._meta 98 | divisions = result.divisions 99 | import dask.array as da 100 | 101 | chunks = ((np.nan,) * (len(divisions) - 1),) + tuple((d,) for d in meta.shape[1:]) 102 | if len(chunks) > 1: 103 | if isinstance(dsk, HighLevelGraph): 104 | layer = dsk.layers[name] 105 | else: 106 | # dask-expr provides a dict only 107 | layer = dsk 108 | if isinstance(layer, Blockwise): 109 | layer.new_axes["j"] = chunks[1][0] 110 | layer.output_indices = layer.output_indices + ("j",) 111 | else: 112 | from dask._task_spec import Alias, Task 113 | 114 | suffix = (0,) * (len(chunks) - 1) 115 | for i in range(len(chunks[0])): 116 | task = layer.get((name, i)) 117 | new_key = (name, i) + suffix 118 | if isinstance(task, Task): 119 | task = Alias(new_key, task.key) 120 | layer[new_key] = task 121 | return da.Array(dsk, name=name, chunks=chunks, dtype=meta.dtype) 122 | 123 | 124 | @get_collection_type.register(np.ndarray) 125 | def get_collection_type_array(_): 126 | return create_array_collection 127 | 128 | 129 | if sparse_installed: 130 | 131 | @get_collection_type.register(sparse.COO) 132 | def get_collection_type_array(_): 133 | return create_array_collection 134 | 135 | 136 | if scipy_installed: 137 | 138 | @get_collection_type.register(sp.csr_matrix) 139 | def get_collection_type_array(_): 140 | return create_array_collection 141 | 142 | 143 | @get_collection_type.register(object) 144 | def get_collection_type_object(_): 145 | from dask_expr._collection import Scalar 146 | 147 | return Scalar 148 | 149 | 150 | ###################################### 151 | # cuDF: Pandas Dataframes on the GPU # 152 | ###################################### 153 | 154 | 155 | @get_collection_type.register_lazy("cudf") 156 | def _register_cudf(): 157 | import dask_cudf # noqa: F401 158 | -------------------------------------------------------------------------------- /dask_expr/_categorical.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pandas as pd 4 | from dask.dataframe.categorical import ( 5 | _categorize_block, 6 | _get_categories, 7 | _get_categories_agg, 8 | ) 9 | from dask.dataframe.utils import ( 10 | AttributeNotImplementedError, 11 | clear_known_categories, 12 | has_known_categories, 13 | ) 14 | from dask.utils import M 15 | 16 | from dask_expr._accessor import Accessor, PropertyMap 17 | from dask_expr._expr import Blockwise, Elemwise, Projection 18 | from dask_expr._reductions import ApplyConcatApply 19 | 20 | 21 | class CategoricalAccessor(Accessor): 22 | """ 23 | Accessor object for categorical properties of the Series values. 24 | 25 | Examples 26 | -------- 27 | >>> s.cat.categories # doctest: +SKIP 28 | 29 | Notes 30 | ----- 31 | Attributes that depend only on metadata are eager 32 | 33 | * categories 34 | * ordered 35 | 36 | Attributes depending on the entire dataset are lazy 37 | 38 | * codes 39 | * ... 
40 | 41 | So `df.a.cat.categories` <=> `df.a._meta.cat.categories` 42 | So `df.a.cat.codes` <=> `df.a.map_partitions(lambda x: x.cat.codes)` 43 | """ 44 | 45 | _accessor_name = "cat" 46 | _accessor_methods = ( 47 | "add_categories", 48 | "as_ordered", 49 | "as_unordered", 50 | "remove_categories", 51 | "rename_categories", 52 | "reorder_categories", 53 | "set_categories", 54 | ) 55 | _accessor_properties = () 56 | 57 | @property 58 | def known(self): 59 | """Whether the categories are fully known""" 60 | return has_known_categories(self._series) 61 | 62 | def as_known(self, **kwargs): 63 | """Ensure the categories in this series are known. 64 | 65 | If the categories are known, this is a no-op. If unknown, the 66 | categories are computed, and a new series with known categories is 67 | returned. 68 | 69 | Parameters 70 | ---------- 71 | kwargs 72 | Keywords to pass on to the call to `compute`. 73 | """ 74 | if self.known: 75 | return self._series 76 | from dask_expr._collection import new_collection 77 | 78 | categories = ( 79 | new_collection(PropertyMap(self._series, "cat", "categories")) 80 | .unique() 81 | .compute() 82 | ) 83 | return self.set_categories(categories.values) 84 | 85 | def as_unknown(self): 86 | """Ensure the categories in this series are unknown""" 87 | if not self.known: 88 | return self._series 89 | 90 | from dask_expr import new_collection 91 | 92 | return new_collection(AsUnknown(self._series)) 93 | 94 | @property 95 | def ordered(self): 96 | """Whether the categories have an ordered relationship""" 97 | return self._delegate_property(self._series._meta, "cat", "ordered") 98 | 99 | @property 100 | def categories(self): 101 | """The categories of this categorical. 102 | 103 | If categories are unknown, an error is raised""" 104 | if not self.known: 105 | msg = ( 106 | "`df.column.cat.categories` with unknown categories is not " 107 | "supported. Please use `column.cat.as_known()` or " 108 | "`df.categorize()` beforehand to ensure known categories" 109 | ) 110 | raise AttributeNotImplementedError(msg) 111 | return self._delegate_property(self._series._meta, "cat", "categories") 112 | 113 | @property 114 | def codes(self): 115 | """The codes of this categorical. 116 | 117 | If categories are unknown, an error is raised""" 118 | if not self.known: 119 | msg = ( 120 | "`df.column.cat.codes` with unknown categories is not " 121 | "supported. Please use `column.cat.as_known()` or " 122 | "`df.categorize()` beforehand to ensure known categories" 123 | ) 124 | raise AttributeNotImplementedError(msg) 125 | from dask_expr._collection import new_collection 126 | 127 | return new_collection(PropertyMap(self._series, "cat", "codes")) 128 | 129 | def remove_unused_categories(self): 130 | """ 131 | Removes categories which are not used 132 | 133 | Notes 134 | ----- 135 | This method requires a full scan of the data to compute the 136 | unique values, which can be expensive. 137 | """ 138 | # get the set of used categories 139 | present = self._series.dropna().unique() 140 | present = pd.Index(present.compute()) 141 | 142 | if isinstance(self._series._meta, pd.CategoricalIndex): 143 | meta_cat = self._series._meta 144 | else: 145 | meta_cat = self._series._meta.cat 146 | 147 | # Reorder to keep cat:code relationship, filtering unused (-1) 148 | ordered, mask = present.reindex(meta_cat.categories) 149 | if mask is None: 150 | # PANDAS-23963: old and new categories match. 
151 | return self._series 152 | 153 | new_categories = ordered[mask != -1] 154 | return self.set_categories(new_categories) 155 | 156 | 157 | class AsUnknown(Elemwise): 158 | _parameters = ["frame"] 159 | operation = M.copy 160 | 161 | @functools.cached_property 162 | def _meta(self): 163 | return clear_known_categories(self.frame._meta) 164 | 165 | 166 | class Categorize(Blockwise): 167 | _parameters = ["frame", "categories", "index"] 168 | operation = staticmethod(_categorize_block) 169 | _projection_passthrough = True 170 | 171 | @functools.cached_property 172 | def _meta(self): 173 | return _categorize_block( 174 | self.frame._meta, self.operand("categories"), self.operand("index") 175 | ) 176 | 177 | def _simplify_up(self, parent, dependents): 178 | result = super()._simplify_up(parent, dependents) 179 | if result is None: 180 | return result 181 | # pop potentially dropped columns from categories 182 | cats = self.operand("categories") 183 | cats = {k: v for k, v in cats.items() if k in result.frame.columns} 184 | return Categorize(result.frame, cats, result.operand("index")) 185 | 186 | 187 | class GetCategories(ApplyConcatApply): 188 | _parameters = ["frame", "columns", "index", "split_every"] 189 | 190 | chunk = staticmethod(_get_categories) 191 | aggregate = staticmethod(_get_categories_agg) 192 | 193 | @property 194 | def chunk_kwargs(self): 195 | return {"columns": self.operand("columns"), "index": self.operand("index")} 196 | 197 | @functools.cached_property 198 | def _meta(self): 199 | return ({}, pd.Series()) 200 | 201 | def _simplify_down(self): 202 | if set(self.frame.columns) == set(self.operand("columns")): 203 | return None 204 | 205 | return GetCategories( 206 | Projection(self.frame, self.operand("columns")), 207 | columns=self.operand("columns"), 208 | index=self.operand("index"), 209 | split_every=self.operand("split_every"), 210 | ) 211 | -------------------------------------------------------------------------------- /dask_expr/_cumulative.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | 4 | import pandas as pd 5 | from dask.dataframe import methods 6 | from dask.utils import M 7 | 8 | from dask_expr._expr import Blockwise, Expr, Projection, plain_column_projection 9 | 10 | 11 | class CumulativeAggregations(Expr): 12 | _parameters = ["frame", "axis", "skipna"] 13 | _defaults = {"axis": None} 14 | 15 | chunk_operation = None 16 | aggregate_operation = None 17 | neutral_element = None 18 | 19 | def _divisions(self): 20 | return self.frame._divisions() 21 | 22 | @functools.cached_property 23 | def _meta(self): 24 | return self.frame._meta 25 | 26 | def _lower(self): 27 | chunks = CumulativeBlockwise( 28 | self.frame, self.axis, self.skipna, self.chunk_operation 29 | ) 30 | chunks_last = TakeLast(chunks, self.skipna) 31 | return CumulativeFinalize( 32 | chunks, chunks_last, self.aggregate_operation, self.neutral_element 33 | ) 34 | 35 | def _simplify_up(self, parent, dependents): 36 | if isinstance(parent, Projection): 37 | return plain_column_projection(self, parent, dependents) 38 | 39 | 40 | class CumulativeBlockwise(Blockwise): 41 | _parameters = ["frame", "axis", "skipna", "operation"] 42 | _defaults = {"skipna": True, "axis": None} 43 | _projection_passthrough = True 44 | 45 | @functools.cached_property 46 | def _meta(self): 47 | return self.frame._meta 48 | 49 | @functools.cached_property 50 | def operation(self): 51 | return self.operand("operation") 52 | 53 | 
@functools.cached_property 54 | def _args(self) -> list: 55 | return self.operands[:-1] 56 | 57 | 58 | class TakeLast(Blockwise): 59 | _parameters = ["frame", "skipna"] 60 | _projection_passthrough = True 61 | 62 | @staticmethod 63 | def operation(a, skipna=True): 64 | if skipna: 65 | if a.ndim == 1 and (a.empty or a.isna().all()): 66 | return None 67 | a = a.ffill() 68 | return a.tail(n=1).squeeze() 69 | 70 | 71 | class CumulativeFinalize(Expr): 72 | _parameters = ["frame", "previous_partitions", "aggregator", "neutral_element"] 73 | 74 | def _divisions(self): 75 | return self.frame._divisions() 76 | 77 | @functools.cached_property 78 | def _meta(self): 79 | return self.frame._meta 80 | 81 | def _layer(self) -> dict: 82 | dsk = {} 83 | frame, previous_partitions = self.frame, self.previous_partitions 84 | dsk[(self._name, 0)] = (frame._name, 0) 85 | 86 | intermediate_name = self._name + "-intermediate" 87 | for i in range(1, self.frame.npartitions): 88 | if i == 1: 89 | dsk[(intermediate_name, i)] = (previous_partitions._name, i - 1) 90 | else: 91 | # aggregate with previous cumulation results 92 | dsk[(intermediate_name, i)] = ( 93 | cumulative_wrapper_intermediate, 94 | self.aggregator, 95 | (intermediate_name, i - 1), 96 | (previous_partitions._name, i - 1), 97 | self.neutral_element, 98 | ) 99 | dsk[(self._name, i)] = ( 100 | cumulative_wrapper, 101 | self.aggregator, 102 | (self.frame._name, i), 103 | (intermediate_name, i), 104 | self.neutral_element, 105 | ) 106 | return dsk 107 | 108 | 109 | def cumulative_wrapper(func, x, y, neutral_element): 110 | if isinstance(y, pd.Series) and len(y) == 0: 111 | y = neutral_element 112 | return func(x, y) 113 | 114 | 115 | def cumulative_wrapper_intermediate(func, x, y, neutral_element): 116 | if isinstance(y, pd.Series) and len(y) == 0: 117 | y = neutral_element 118 | return methods._cum_aggregate_apply(func, x, y) 119 | 120 | 121 | class CumSum(CumulativeAggregations): 122 | chunk_operation = M.cumsum 123 | aggregate_operation = staticmethod(methods.cumsum_aggregate) 124 | neutral_element = 0 125 | 126 | 127 | class CumProd(CumulativeAggregations): 128 | chunk_operation = M.cumprod 129 | aggregate_operation = staticmethod(methods.cumprod_aggregate) 130 | neutral_element = 1 131 | 132 | 133 | class CumMax(CumulativeAggregations): 134 | chunk_operation = M.cummax 135 | aggregate_operation = staticmethod(methods.cummax_aggregate) 136 | neutral_element = -math.inf 137 | 138 | 139 | class CumMin(CumulativeAggregations): 140 | chunk_operation = M.cummin 141 | aggregate_operation = staticmethod(methods.cummin_aggregate) 142 | neutral_element = math.inf 143 | -------------------------------------------------------------------------------- /dask_expr/_datetime.py: -------------------------------------------------------------------------------- 1 | from dask_expr._accessor import Accessor 2 | 3 | 4 | class DatetimeAccessor(Accessor): 5 | """Accessor object for datetimelike properties of the Series values. 
6 | 7 | Examples 8 | -------- 9 | 10 | >>> s.dt.microsecond # doctest: +SKIP 11 | """ 12 | 13 | _accessor_name = "dt" 14 | 15 | _accessor_methods = ( 16 | "ceil", 17 | "day_name", 18 | "floor", 19 | "isocalendar", 20 | "month_name", 21 | "normalize", 22 | "round", 23 | "strftime", 24 | "to_period", 25 | "to_pydatetime", 26 | "to_pytimedelta", 27 | "to_timestamp", 28 | "total_seconds", 29 | "tz_convert", 30 | "tz_localize", 31 | ) 32 | 33 | _accessor_properties = ( 34 | "components", 35 | "date", 36 | "day", 37 | "day_of_week", 38 | "day_of_year", 39 | "dayofweek", 40 | "dayofyear", 41 | "days", 42 | "days_in_month", 43 | "daysinmonth", 44 | "end_time", 45 | "freq", 46 | "hour", 47 | "is_leap_year", 48 | "is_month_end", 49 | "is_month_start", 50 | "is_quarter_end", 51 | "is_quarter_start", 52 | "is_year_end", 53 | "is_year_start", 54 | "microsecond", 55 | "microseconds", 56 | "minute", 57 | "month", 58 | "nanosecond", 59 | "nanoseconds", 60 | "quarter", 61 | "qyear", 62 | "second", 63 | "seconds", 64 | "start_time", 65 | "time", 66 | "timetz", 67 | "tz", 68 | "week", 69 | "weekday", 70 | "weekofyear", 71 | "year", 72 | ) 73 | -------------------------------------------------------------------------------- /dask_expr/_describe.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | from dask.dataframe.dispatch import make_meta, meta_nonempty 5 | from dask.dataframe.methods import ( 6 | describe_nonnumeric_aggregate, 7 | describe_numeric_aggregate, 8 | ) 9 | from pandas.core.dtypes.common import is_datetime64_any_dtype, is_timedelta64_dtype 10 | 11 | from dask_expr._expr import Blockwise, DropnaSeries, Filter, Head, Sqrt, ToNumeric 12 | from dask_expr._quantile import SeriesQuantile 13 | from dask_expr._reductions import Reduction, Size, ValueCounts 14 | 15 | 16 | class DescribeNumeric(Reduction): 17 | _parameters = ["frame", "split_every", "percentiles", "percentile_method"] 18 | _defaults = { 19 | "percentiles": None, 20 | "split_every": None, 21 | "percentile_method": "default", 22 | } 23 | 24 | @functools.cached_property 25 | def _meta(self): 26 | return make_meta(meta_nonempty(self.frame._meta).describe()) 27 | 28 | def _divisions(self): 29 | return (None, None) 30 | 31 | def _lower(self): 32 | frame = self.frame 33 | if self.percentiles is None: 34 | percentiles = self.percentiles or [0.25, 0.5, 0.75] 35 | else: 36 | percentiles = np.array(self.percentiles) 37 | percentiles = np.append(percentiles, 0.5) 38 | percentiles = np.unique(percentiles) 39 | percentiles = list(percentiles) 40 | 41 | is_td_col = is_timedelta64_dtype(frame._meta.dtype) 42 | is_dt_col = is_datetime64_any_dtype(frame._meta.dtype) 43 | if is_td_col or is_dt_col: 44 | frame = ToNumeric(DropnaSeries(frame)) 45 | 46 | stats = [ 47 | frame.count(split_every=self.split_every), 48 | frame.mean(split_every=self.split_every), 49 | Sqrt(frame.var(split_every=self.split_every)), 50 | frame.min(split_every=self.split_every), 51 | SeriesQuantile(frame, q=percentiles, method=self.percentile_method), 52 | frame.max(split_every=self.split_every), 53 | ] 54 | try: 55 | unit = getattr(self.frame._meta.array, "unit", None) 56 | except AttributeError: 57 | # cudf Series has no array attribute 58 | unit = None 59 | return DescribeNumericAggregate( 60 | self.frame._meta.name, 61 | is_td_col, 62 | is_dt_col, 63 | unit, 64 | *stats, 65 | ) 66 | 67 | 68 | class DescribeNumericAggregate(Blockwise): 69 | _parameters = ["name", "is_timedelta_col", 
"is_datetime_col", "unit"] 70 | _defaults = {"is_timedelta_col": False, "is_datetime_col": False} 71 | 72 | def _broadcast_dep(self, dep): 73 | return dep.npartitions == 1 74 | 75 | @staticmethod 76 | def operation(name, is_timedelta_col, is_datetime_col, unit, *stats): 77 | return describe_numeric_aggregate( 78 | stats, name, is_timedelta_col, is_datetime_col, unit 79 | ) 80 | 81 | 82 | class DescribeNonNumeric(DescribeNumeric): 83 | _parameters = ["frame", "split_every"] 84 | 85 | def _lower(self): 86 | frame = self.frame 87 | vcounts = ValueCounts(frame, split_every=self.split_every, sort=True) 88 | count_unique = Size(Filter(vcounts, vcounts > 0)) 89 | stats = [ 90 | count_unique, 91 | frame.count(split_every=self.split_every), 92 | Head(vcounts, n=1), 93 | ] 94 | return DescribeNonNumericAggregate(frame._meta.name, *stats) 95 | 96 | 97 | class DescribeNonNumericAggregate(DescribeNumericAggregate): 98 | _parameters = ["name"] 99 | _defaults = {} 100 | 101 | @staticmethod 102 | def operation(name, *stats): 103 | return describe_nonnumeric_aggregate(stats, name) 104 | -------------------------------------------------------------------------------- /dask_expr/_dispatch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dask.utils import Dispatch 4 | 5 | get_collection_type = Dispatch("get_collection_type") 6 | -------------------------------------------------------------------------------- /dask_expr/_dummies.py: -------------------------------------------------------------------------------- 1 | import dask.dataframe.methods as methods 2 | import pandas as pd 3 | from dask.dataframe.utils import has_known_categories 4 | from dask.utils import get_meta_library 5 | 6 | from dask_expr._collection import DataFrame, Series, new_collection 7 | from dask_expr._expr import Blockwise 8 | 9 | 10 | def get_dummies( 11 | data, 12 | prefix=None, 13 | prefix_sep="_", 14 | dummy_na=False, 15 | columns=None, 16 | sparse=False, 17 | drop_first=False, 18 | dtype=bool, 19 | **kwargs, 20 | ): 21 | """ 22 | Convert categorical variable into dummy/indicator variables. 23 | 24 | Data must have category dtype to infer result's ``columns``. 25 | 26 | Parameters 27 | ---------- 28 | data : Series, or DataFrame 29 | For Series, the dtype must be categorical. 30 | For DataFrame, at least one column must be categorical. 31 | prefix : string, list of strings, or dict of strings, default None 32 | String to append DataFrame column names. 33 | Pass a list with length equal to the number of columns 34 | when calling get_dummies on a DataFrame. Alternatively, `prefix` 35 | can be a dictionary mapping column names to prefixes. 36 | prefix_sep : string, default '_' 37 | If appending prefix, separator/delimiter to use. Or pass a 38 | list or dictionary as with `prefix.` 39 | dummy_na : bool, default False 40 | Add a column to indicate NaNs, if False NaNs are ignored. 41 | columns : list-like, default None 42 | Column names in the DataFrame to be encoded. 43 | If `columns` is None then all the columns with 44 | `category` dtype will be converted. 45 | sparse : bool, default False 46 | Whether the dummy columns should be sparse or not. Returns 47 | SparseDataFrame if `data` is a Series or if all columns are included. 48 | Otherwise returns a DataFrame with some SparseBlocks. 49 | 50 | .. 
versionadded:: 0.18.2 51 | 52 | drop_first : bool, default False 53 | Whether to get k-1 dummies out of k categorical levels by removing the 54 | first level. 55 | 56 | dtype : dtype, default bool 57 | Data type for new columns. Only a single dtype is allowed. 58 | 59 | .. versionadded:: 0.18.2 60 | 61 | Returns 62 | ------- 63 | dummies : DataFrame 64 | 65 | Examples 66 | -------- 67 | Dask's version only works with Categorical data, as this is the only way to 68 | know the output shape without computing all the data. 69 | 70 | >>> import pandas as pd 71 | >>> import dask.dataframe as dd 72 | >>> s = dd.from_pandas(pd.Series(list('abca')), npartitions=2) 73 | >>> dd.get_dummies(s) 74 | Traceback (most recent call last): 75 | ... 76 | NotImplementedError: `get_dummies` with non-categorical dtypes is not supported... 77 | 78 | With categorical data: 79 | 80 | >>> s = dd.from_pandas(pd.Series(list('abca'), dtype='category'), npartitions=2) 81 | >>> dd.get_dummies(s) # doctest: +NORMALIZE_WHITESPACE 82 | Dask DataFrame Structure: 83 | a b c 84 | npartitions=2 85 | 0 bool bool bool 86 | 2 ... ... ... 87 | 3 ... ... ... 88 | Dask Name: get_dummies, 2 graph layers 89 | >>> dd.get_dummies(s).compute() # doctest: +ELLIPSIS 90 | a b c 91 | 0 True False False 92 | 1 False True False 93 | 2 False False True 94 | 3 True False False 95 | 96 | See Also 97 | -------- 98 | pandas.get_dummies 99 | """ 100 | if isinstance(data, (pd.Series, pd.DataFrame)): 101 | return pd.get_dummies( 102 | data, 103 | prefix=prefix, 104 | prefix_sep=prefix_sep, 105 | dummy_na=dummy_na, 106 | columns=columns, 107 | sparse=sparse, 108 | drop_first=drop_first, 109 | dtype=dtype, 110 | **kwargs, 111 | ) 112 | 113 | not_cat_msg = ( 114 | "`get_dummies` with non-categorical dtypes is not " 115 | "supported. Please use `df.categorize()` beforehand to " 116 | "convert to categorical dtype." 117 | ) 118 | 119 | unknown_cat_msg = ( 120 | "`get_dummies` with unknown categories is not " 121 | "supported. 
Please use `column.cat.as_known()` or " 122 | "`df.categorize()` beforehand to ensure known " 123 | "categories" 124 | ) 125 | 126 | if isinstance(data, Series): 127 | if not methods.is_categorical_dtype(data): 128 | raise NotImplementedError(not_cat_msg) 129 | if not has_known_categories(data): 130 | raise NotImplementedError(unknown_cat_msg) 131 | elif isinstance(data, DataFrame): 132 | if columns is None: 133 | if (data.dtypes == "object").any(): 134 | raise NotImplementedError(not_cat_msg) 135 | if (data.dtypes == "string").any(): 136 | raise NotImplementedError(not_cat_msg) 137 | columns = data._meta.select_dtypes(include=["category"]).columns 138 | else: 139 | if not all(methods.is_categorical_dtype(data[c]) for c in columns): 140 | raise NotImplementedError(not_cat_msg) 141 | 142 | if not all(has_known_categories(data[c]) for c in columns): 143 | raise NotImplementedError(unknown_cat_msg) 144 | 145 | return new_collection( 146 | GetDummies( 147 | data, prefix, prefix_sep, dummy_na, columns, sparse, drop_first, dtype 148 | ) 149 | ) 150 | 151 | 152 | class GetDummies(Blockwise): 153 | _parameters = [ 154 | "frame", 155 | "prefix", 156 | "prefix_sep", 157 | "dummy_na", 158 | "columns", 159 | "sparse", 160 | "drop_first", 161 | "dtype", 162 | ] 163 | _defaults = { 164 | "prefix": None, 165 | "prefix_sep": "_", 166 | "dummy_na": False, 167 | "columns": None, 168 | "sparse": False, 169 | "drop_first": False, 170 | "dtype": bool, 171 | } 172 | # cudf has extra kwargs after `columns` 173 | _keyword_only = ["sparse", "drop_first", "dtype"] 174 | 175 | @staticmethod 176 | def operation(df, *args, **kwargs): 177 | return get_meta_library(df).get_dummies(df, *args, **kwargs) 178 | -------------------------------------------------------------------------------- /dask_expr/_interchange.py: -------------------------------------------------------------------------------- 1 | from dask.dataframe._compat import is_string_dtype 2 | from dask.dataframe.dispatch import is_categorical_dtype 3 | from pandas.core.interchange.dataframe_protocol import DtypeKind 4 | 5 | from dask_expr._collection import DataFrame 6 | 7 | _NP_KINDS = { 8 | "i": DtypeKind.INT, 9 | "u": DtypeKind.UINT, 10 | "f": DtypeKind.FLOAT, 11 | "b": DtypeKind.BOOL, 12 | "U": DtypeKind.STRING, 13 | "M": DtypeKind.DATETIME, 14 | "m": DtypeKind.DATETIME, 15 | } 16 | 17 | 18 | class DaskDataFrameInterchange: 19 | def __init__( 20 | self, df: DataFrame, nan_as_null: bool = False, allow_copy: bool = True 21 | ) -> None: 22 | self._df = df 23 | self._nan_as_null = nan_as_null 24 | self._allow_copy = allow_copy 25 | 26 | def get_columns(self): 27 | return [DaskColumn(self._df[name]) for name in self._df.columns] 28 | 29 | def column_names(self): 30 | return self._df.columns 31 | 32 | def num_columns(self) -> int: 33 | return len(self._df.columns) 34 | 35 | def num_rows(self) -> int: 36 | return len(self._df) 37 | 38 | 39 | class DaskColumn: 40 | def __init__(self, column, allow_copy: bool = True) -> None: 41 | self._col = column 42 | self._allow_copy = allow_copy 43 | 44 | def dtype(self) -> tuple[DtypeKind, None, None, None]: 45 | dtype = self._col.dtype 46 | 47 | if is_categorical_dtype(dtype): 48 | return ( 49 | DtypeKind.CATEGORICAL, 50 | None, 51 | None, 52 | None, 53 | ) 54 | elif is_string_dtype(dtype): 55 | return ( 56 | DtypeKind.STRING, 57 | None, 58 | None, 59 | None, 60 | ) 61 | else: 62 | return _NP_KINDS.get(dtype.kind, None), None, None, None 63 | -------------------------------------------------------------------------------- 
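The interchange wrapper just above (DaskDataFrameInterchange and DaskColumn in dask_expr/_interchange.py) works almost entirely from metadata: column names, column counts, and dtype kinds are read from the collection's meta, and only num_rows() has to call len() on the frame. The following is a minimal usage sketch, assuming a local dask-expr install; constructing the wrapper by hand is only an illustration of the mapping, not how a dataframe-interchange consumer would normally obtain it.

# Illustrative sketch: drive DaskDataFrameInterchange / DaskColumn directly.
import pandas as pd

import dask_expr as dx
from dask_expr._interchange import DaskDataFrameInterchange

pdf = pd.DataFrame({"x": [1, 2, 3], "y": list("abc")})
df = dx.from_pandas(pdf, npartitions=2)

xchg = DaskDataFrameInterchange(df)
print(xchg.num_columns())         # 2, taken from metadata, no compute
print(list(xchg.column_names()))  # ['x', 'y']
for col in xchg.get_columns():
    # dtype() returns a 4-tuple whose first element is a DtypeKind;
    # "x" maps to DtypeKind.INT, while the kind reported for "y" depends on
    # which string dtype the frame uses (object vs. string-backed storage).
    print(col.dtype()[0])
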
/dask_expr/_quantile.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | from dask.dataframe.dispatch import make_meta, meta_nonempty 5 | from dask.utils import import_required, is_series_like 6 | 7 | from dask_expr._expr import DropnaSeries, Expr 8 | 9 | 10 | def _finalize_scalar_result(cons, *args, **kwargs): 11 | return cons(*args, **kwargs)[0] 12 | 13 | 14 | class SeriesQuantile(Expr): 15 | _parameters = ["frame", "q", "method"] 16 | _defaults = {"method": "default"} 17 | 18 | @functools.cached_property 19 | def q(self): 20 | q = np.array(self.operand("q")) 21 | if q.ndim > 0: 22 | assert len(q) > 0, f"must provide non-empty q={q}" 23 | q.sort(kind="mergesort") 24 | return q 25 | return np.asarray([self.operand("q")]) 26 | 27 | @functools.cached_property 28 | def method(self): 29 | if self.operand("method") == "default": 30 | return "dask" 31 | else: 32 | return self.operand("method") 33 | 34 | @functools.cached_property 35 | def _meta(self): 36 | meta = self.frame._meta 37 | if not is_series_like(self.frame._meta): 38 | meta = meta.to_series() 39 | return make_meta(meta_nonempty(meta).quantile(self.operand("q"))) 40 | 41 | def _divisions(self): 42 | if is_series_like(self._meta): 43 | return (np.min(self.q), np.max(self.q)) 44 | return (None, None) 45 | 46 | @functools.cached_property 47 | def _constructor(self): 48 | meta = self.frame._meta 49 | if not is_series_like(self.frame._meta): 50 | meta = meta.to_series() 51 | return meta._constructor 52 | 53 | @functools.cached_property 54 | def _finalizer(self): 55 | if is_series_like(self._meta): 56 | return lambda tsk: ( 57 | self._constructor, 58 | tsk, 59 | self.q, 60 | None, 61 | self.frame._meta.name, 62 | ) 63 | else: 64 | return lambda tsk: (_finalize_scalar_result, self._constructor, tsk, [0]) 65 | 66 | def _lower(self): 67 | frame = DropnaSeries(self.frame) 68 | if self.method == "tdigest": 69 | return SeriesQuantileTdigest( 70 | frame, self.operand("q"), self.operand("method") 71 | ) 72 | else: 73 | return SeriesQuantileDask(frame, self.operand("q"), self.operand("method")) 74 | 75 | 76 | class SeriesQuantileTdigest(SeriesQuantile): 77 | @functools.cached_property 78 | def _meta(self): 79 | import_required( 80 | "crick", "crick is a required dependency for using the tdigest method." 
81 | ) 82 | return super()._meta 83 | 84 | def _layer(self) -> dict: 85 | from dask.array.percentile import _percentiles_from_tdigest, _tdigest_chunk 86 | 87 | dsk = {} 88 | for i in range(self.frame.npartitions): 89 | dsk[("chunk-" + self._name, i)] = ( 90 | _tdigest_chunk, 91 | (getattr, (self.frame._name, i), "values"), 92 | ) 93 | 94 | dsk[(self._name, 0)] = self._finalizer( 95 | (_percentiles_from_tdigest, self.q * 100, sorted(dsk)) 96 | ) 97 | return dsk 98 | 99 | def _lower(self): 100 | return None 101 | 102 | 103 | class SeriesQuantileDask(SeriesQuantile): 104 | def _layer(self) -> dict: 105 | from dask.array.dispatch import percentile_lookup as _percentile 106 | from dask.array.percentile import merge_percentiles 107 | 108 | dsk = {} 109 | # Add 0 and 100 during calculation for more robust behavior (hopefully) 110 | calc_qs = np.pad(self.q * 100, 1, mode="constant") 111 | calc_qs[-1] = 100 112 | 113 | for i in range(self.frame.npartitions): 114 | dsk[("chunk-" + self._name, i)] = ( 115 | _percentile, 116 | (self.frame._name, i), 117 | calc_qs, 118 | ) 119 | dsk[(self._name, 0)] = self._finalizer( 120 | ( 121 | merge_percentiles, 122 | self.q * 100, 123 | [calc_qs] * self.frame.npartitions, 124 | sorted(dsk), 125 | "lower", 126 | None, 127 | False, 128 | ) 129 | ) 130 | return dsk 131 | 132 | def _lower(self): 133 | return None 134 | -------------------------------------------------------------------------------- /dask_expr/_quantiles.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | import toolz 5 | from dask.dataframe.partitionquantiles import ( 6 | create_merge_tree, 7 | dtype_info, 8 | merge_and_compress_summaries, 9 | percentiles_summary, 10 | process_val_weights, 11 | ) 12 | from dask.tokenize import tokenize 13 | from dask.utils import random_state_data 14 | 15 | from dask_expr._expr import Expr 16 | 17 | 18 | class RepartitionQuantiles(Expr): 19 | _parameters = ["frame", "input_npartitions", "upsample", "random_state"] 20 | _defaults = {"upsample": 1.0, "random_state": None} 21 | 22 | @functools.cached_property 23 | def _meta(self): 24 | return self.frame._meta 25 | 26 | @property 27 | def npartitions(self): 28 | return 1 29 | 30 | def _divisions(self): 31 | return 0.0, 1.0 32 | 33 | def __dask_postcompute__(self): 34 | return toolz.first, () 35 | 36 | def _layer(self): 37 | import pandas as pd 38 | 39 | qs = np.linspace(0, 1, self.input_npartitions + 1) 40 | if self.random_state is None: 41 | random_state = int(tokenize(self.operands), 16) % np.iinfo(np.int32).max 42 | else: 43 | random_state = self.random_state 44 | state_data = random_state_data(self.frame.npartitions, random_state) 45 | 46 | keys = self.frame.__dask_keys__() 47 | dtype_dsk = {(self._name, 0, 0): (dtype_info, keys[0])} 48 | 49 | percentiles_dsk = { 50 | (self._name, 1, i): ( 51 | percentiles_summary, 52 | key, 53 | self.frame.npartitions, 54 | self.input_npartitions, 55 | self.upsample, 56 | state, 57 | ) 58 | for i, (state, key) in enumerate(zip(state_data, keys)) 59 | } 60 | 61 | merge_dsk = create_merge_tree( 62 | merge_and_compress_summaries, sorted(percentiles_dsk), self._name, 2 63 | ) 64 | if not merge_dsk: 65 | # Compress the data even if we only have one partition 66 | merge_dsk = { 67 | (self._name, 2, 0): ( 68 | merge_and_compress_summaries, 69 | [list(percentiles_dsk)[0]], 70 | ) 71 | } 72 | 73 | merged_key = max(merge_dsk) 74 | last_dsk = { 75 | (self._name, 0): ( 76 | pd.Series, 77 | ( 78 | 
process_val_weights, 79 | merged_key, 80 | self.input_npartitions, 81 | (self._name, 0, 0), 82 | ), 83 | qs, 84 | None, 85 | self.frame._meta.name, 86 | ) 87 | } 88 | return {**dtype_dsk, **percentiles_dsk, **merge_dsk, **last_dsk} 89 | -------------------------------------------------------------------------------- /dask_expr/_str_accessor.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from dask.dataframe.dispatch import make_meta, meta_nonempty 4 | 5 | from dask_expr._accessor import Accessor, FunctionMap 6 | from dask_expr._expr import Blockwise 7 | from dask_expr._reductions import Reduction 8 | 9 | 10 | class StringAccessor(Accessor): 11 | """Accessor object for string properties of the Series values. 12 | 13 | Examples 14 | -------- 15 | 16 | >>> s.str.lower() # doctest: +SKIP 17 | """ 18 | 19 | _accessor_name = "str" 20 | 21 | _accessor_methods = ( 22 | "capitalize", 23 | "casefold", 24 | "center", 25 | "contains", 26 | "count", 27 | "decode", 28 | "encode", 29 | "endswith", 30 | "extract", 31 | "extractall", 32 | "find", 33 | "findall", 34 | "fullmatch", 35 | "get", 36 | "index", 37 | "isalnum", 38 | "isalpha", 39 | "isdecimal", 40 | "isdigit", 41 | "islower", 42 | "isnumeric", 43 | "isspace", 44 | "istitle", 45 | "isupper", 46 | "join", 47 | "len", 48 | "ljust", 49 | "lower", 50 | "lstrip", 51 | "match", 52 | "normalize", 53 | "pad", 54 | "partition", 55 | "removeprefix", 56 | "removesuffix", 57 | "repeat", 58 | "replace", 59 | "rfind", 60 | "rindex", 61 | "rjust", 62 | "rpartition", 63 | "rstrip", 64 | "slice", 65 | "slice_replace", 66 | "startswith", 67 | "strip", 68 | "swapcase", 69 | "title", 70 | "translate", 71 | "upper", 72 | "wrap", 73 | "zfill", 74 | ) 75 | _accessor_properties = () 76 | 77 | def _split(self, method, pat=None, n=-1, expand=False): 78 | from dask_expr import new_collection 79 | 80 | if expand: 81 | if n == -1: 82 | raise NotImplementedError( 83 | "To use the expand parameter you must specify the number of " 84 | "expected splits with the n= parameter. Usually n splits " 85 | "result in n+1 output columns." 
86 | ) 87 | return new_collection( 88 | SplitMap( 89 | self._series, 90 | self._accessor_name, 91 | method, 92 | (), 93 | {"pat": pat, "n": n, "expand": expand}, 94 | ) 95 | ) 96 | return self._function_map(method, pat=pat, n=n, expand=expand) 97 | 98 | def split(self, pat=None, n=-1, expand=False): 99 | """Known inconsistencies: ``expand=True`` with unknown ``n`` will raise a ``NotImplementedError``.""" 100 | return self._split("split", pat=pat, n=n, expand=expand) 101 | 102 | def rsplit(self, pat=None, n=-1, expand=False): 103 | return self._split("rsplit", pat=pat, n=n, expand=expand) 104 | 105 | def cat(self, others=None, sep=None, na_rep=None): 106 | import pandas as pd 107 | 108 | from dask_expr._collection import Index, Series, new_collection 109 | 110 | if others is None: 111 | return new_collection(Cat(self._series, sep, na_rep)) 112 | 113 | valid_types = (Series, Index, pd.Series, pd.Index) 114 | if isinstance(others, valid_types): 115 | others = [others] 116 | elif not all(isinstance(a, valid_types) for a in others): 117 | raise TypeError("others must be Series/Index") 118 | 119 | return new_collection(CatBlockwise(self._series, sep, na_rep, *others)) 120 | 121 | def __getitem__(self, index): 122 | return self._function_map("__getitem__", index) 123 | 124 | 125 | class CatBlockwise(Blockwise): 126 | _parameters = ["frame", "sep", "na_rep"] 127 | _keyword_only = ["sep", "na_rep"] 128 | 129 | @property 130 | def _args(self) -> list: 131 | return [self.frame] + self.operands[len(self._parameters) :] 132 | 133 | @staticmethod 134 | def operation(ser, *args, **kwargs): 135 | return ser.str.cat(list(args), **kwargs) 136 | 137 | 138 | class Cat(Reduction): 139 | _parameters = ["frame", "sep", "na_rep"] 140 | 141 | @property 142 | def chunk_kwargs(self): 143 | return {"sep": self.sep, "na_rep": self.na_rep} 144 | 145 | @property 146 | def combine_kwargs(self): 147 | return self.chunk_kwargs 148 | 149 | @property 150 | def aggregate_kwargs(self): 151 | return self.chunk_kwargs 152 | 153 | @staticmethod 154 | def reduction_chunk(ser, *args, **kwargs): 155 | return ser.str.cat(*args, **kwargs) 156 | 157 | @staticmethod 158 | def reduction_combine(ser, *args, **kwargs): 159 | return Cat.reduction_chunk(ser, *args, **kwargs) 160 | 161 | @staticmethod 162 | def reduction_aggregate(ser, *args, **kwargs): 163 | return Cat.reduction_chunk(ser, *args, **kwargs) 164 | 165 | 166 | class SplitMap(FunctionMap): 167 | _parameters = ["frame", "accessor", "attr", "args", "kwargs"] 168 | 169 | @property 170 | def n(self): 171 | return self.kwargs["n"] 172 | 173 | @property 174 | def pat(self): 175 | return self.kwargs["pat"] 176 | 177 | @property 178 | def expand(self): 179 | return self.kwargs["expand"] 180 | 181 | @functools.cached_property 182 | def _meta(self): 183 | delimiter = " " if self.pat is None else self.pat 184 | meta = meta_nonempty(self.frame._meta) 185 | meta = self.frame._meta._constructor( 186 | [delimiter.join(["a"] * (self.n + 1))], 187 | index=meta.iloc[:1].index, 188 | ) 189 | return make_meta( 190 | getattr(meta.str, self.attr)(n=self.n, expand=self.expand, pat=self.pat) 191 | ) 192 | -------------------------------------------------------------------------------- /dask_expr/_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from collections import OrderedDict, UserDict 5 | from collections.abc import Hashable, Iterable, Sequence 6 | from typing import Any, Literal, TypeVar, 
cast 7 | 8 | import dask 9 | import numpy as np 10 | import pandas as pd 11 | from dask import config 12 | from dask.dataframe._compat import is_string_dtype 13 | from dask.dataframe.core import is_dask_collection, is_dataframe_like, is_series_like 14 | from dask.tokenize import normalize_token, tokenize 15 | from dask.utils import get_default_shuffle_method 16 | from packaging.version import Version 17 | 18 | K = TypeVar("K", bound=Hashable) 19 | V = TypeVar("V") 20 | 21 | DASK_VERSION = Version(dask.__version__) 22 | DASK_GT_20231201 = DASK_VERSION > Version("2023.12.1") 23 | PANDAS_VERSION = Version(pd.__version__) 24 | PANDAS_GE_300 = PANDAS_VERSION.major >= 3 25 | 26 | 27 | def _calc_maybe_new_divisions(df, periods, freq): 28 | """Maybe calculate new divisions by periods of size freq 29 | 30 | Used to shift the divisions for the `shift` method. If freq isn't a fixed 31 | size (not anchored or relative), then the divisions are shifted 32 | appropriately. 33 | 34 | Returning None, indicates divisions ought to be cleared. 35 | 36 | Parameters 37 | ---------- 38 | df : dd.DataFrame, dd.Series, or dd.Index 39 | periods : int 40 | The number of periods to shift. 41 | freq : DateOffset, timedelta, or time rule string 42 | The frequency to shift by. 43 | """ 44 | if isinstance(freq, str): 45 | freq = pd.tseries.frequencies.to_offset(freq) 46 | 47 | is_offset = isinstance(freq, pd.DateOffset) 48 | if is_offset: 49 | if not isinstance(freq, pd.offsets.Tick): 50 | # Can't infer divisions on relative or anchored offsets, as 51 | # divisions may now split identical index value. 52 | # (e.g. index_partitions = [[1, 2, 3], [3, 4, 5]]) 53 | return None # Would need to clear divisions 54 | if df.known_divisions: 55 | divs = pd.Series(range(len(df.divisions)), index=df.divisions) 56 | divisions = divs.shift(periods, freq=freq).index 57 | return tuple(divisions) 58 | return df.divisions 59 | 60 | 61 | def _validate_axis(axis=0, none_is_zero: bool = True) -> None | Literal[0, 1]: 62 | if axis not in (0, 1, "index", "columns", None): 63 | raise ValueError(f"No axis named {axis}") 64 | # convert to numeric axis 65 | numeric_axis: dict[str | None, Literal[0, 1]] = {"index": 0, "columns": 1} 66 | if none_is_zero: 67 | numeric_axis[None] = 0 68 | 69 | return numeric_axis.get(axis, axis) 70 | 71 | 72 | def _convert_to_list(column) -> list | None: 73 | if column is None or isinstance(column, list): 74 | pass 75 | elif isinstance(column, tuple): 76 | column = list(column) 77 | elif hasattr(column, "dtype"): 78 | column = column.tolist() 79 | else: 80 | column = [column] 81 | return column 82 | 83 | 84 | def is_scalar(x): 85 | # np.isscalar does not work for some pandas scalars, for example pd.NA 86 | if isinstance(x, (Sequence, Iterable)) and not isinstance(x, str): 87 | return False 88 | elif hasattr(x, "dtype"): 89 | return isinstance(x, np.ScalarType) 90 | if isinstance(x, dict): 91 | return False 92 | if isinstance(x, (str, int)) or x is None: 93 | return True 94 | 95 | from dask_expr._expr import Expr 96 | 97 | return not isinstance(x, Expr) 98 | 99 | 100 | def _tokenize_deterministic(*args, **kwargs) -> str: 101 | # Utility to be strict about deterministic tokens 102 | return tokenize(*args, ensure_deterministic=True, **kwargs) 103 | 104 | 105 | def _tokenize_partial(expr, ignore: list | None = None) -> str: 106 | # Helper function to "tokenize" the operands 107 | # that are not in the `ignore` list 108 | ignore = ignore or [] 109 | return _tokenize_deterministic( 110 | *[ 111 | op 112 | for i, op in 
enumerate(expr.operands) 113 | if i >= len(expr._parameters) or expr._parameters[i] not in ignore 114 | ] 115 | ) 116 | 117 | 118 | class LRU(UserDict[K, V]): 119 | """Limited size mapping, evicting the least recently looked-up key when full""" 120 | 121 | def __init__(self, maxsize: float) -> None: 122 | super().__init__() 123 | self.data = OrderedDict() 124 | self.maxsize = maxsize 125 | 126 | def __getitem__(self, key: K) -> V: 127 | value = super().__getitem__(key) 128 | cast(OrderedDict, self.data).move_to_end(key) 129 | return value 130 | 131 | def __setitem__(self, key: K, value: V) -> None: 132 | if len(self) >= self.maxsize: 133 | cast(OrderedDict, self.data).popitem(last=False) 134 | super().__setitem__(key, value) 135 | 136 | 137 | class _BackendData: 138 | """Helper class to wrap backend data 139 | 140 | The primary purpose of this class is to provide 141 | caching outside the ``FromPandas`` class. 142 | """ 143 | 144 | def __init__(self, data): 145 | self._data = data 146 | self._division_info = LRU(10) 147 | 148 | @functools.cached_property 149 | def _token(self): 150 | from dask_expr._util import _tokenize_deterministic 151 | 152 | return _tokenize_deterministic(self._data) 153 | 154 | def __len__(self): 155 | return len(self._data) 156 | 157 | def __getattr__(self, key: str) -> Any: 158 | try: 159 | return object.__getattribute__(self, key) 160 | except AttributeError: 161 | # Return the underlying backend attribute 162 | return getattr(self._data, key) 163 | 164 | def __reduce__(self): 165 | return type(self), (self._data,) 166 | 167 | def __deepcopy__(self, memodict=None): 168 | return type(self)(self._data.copy()) 169 | 170 | 171 | @normalize_token.register(_BackendData) 172 | def normalize_data_wrapper(data): 173 | return data._token 174 | 175 | 176 | def _maybe_from_pandas(dfs): 177 | from dask_expr import from_pandas 178 | 179 | def _pd_series_or_dataframe(x): 180 | # `x` can be a cudf Series/DataFrame 181 | return not is_dask_collection(x) and (is_series_like(x) or is_dataframe_like(x)) 182 | 183 | dfs = [from_pandas(df, 1) if _pd_series_or_dataframe(df) else df for df in dfs] 184 | return dfs 185 | 186 | 187 | def _get_shuffle_preferring_order(shuffle): 188 | if shuffle is not None: 189 | return shuffle 190 | 191 | # Choose tasks over disk since it keeps the order 192 | shuffle = get_default_shuffle_method() 193 | if shuffle == "disk": 194 | return "tasks" 195 | 196 | return shuffle 197 | 198 | 199 | def _raise_if_object_series(x, funcname): 200 | """ 201 | Utility function to raise an error if an object column does not support 202 | a certain operation like `mean`. 
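    An illustrative sketch (assuming pandas is imported as ``pd``)::

        >>> _raise_if_object_series(pd.Series(["a", "b"], dtype=object), "mean")
        Traceback (most recent call last):
            ...
        ValueError: `mean` not supported with object series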
203 | """ 204 | if x.ndim == 1 and hasattr(x, "dtype"): 205 | if x.dtype == object: 206 | raise ValueError("`%s` not supported with object series" % funcname) 207 | elif is_string_dtype(x): 208 | raise ValueError("`%s` not supported with string series" % funcname) 209 | 210 | 211 | def _is_any_real_numeric_dtype(arr_or_dtype): 212 | try: 213 | from pandas.api.types import is_any_real_numeric_dtype 214 | 215 | return is_any_real_numeric_dtype(arr_or_dtype) 216 | except ImportError: 217 | # Temporary/soft pandas<2 support to enable cudf dev 218 | # TODO: Remove `try` block after 4/2024 219 | from pandas.api.types import is_bool_dtype, is_complex_dtype, is_numeric_dtype 220 | 221 | return ( 222 | is_numeric_dtype(arr_or_dtype) 223 | and not is_complex_dtype(arr_or_dtype) 224 | and not is_bool_dtype(arr_or_dtype) 225 | ) 226 | 227 | 228 | def get_specified_shuffle(shuffle_method): 229 | # Take the config shuffle if given, otherwise defer evaluation until optimize 230 | return shuffle_method or config.get("dataframe.shuffle.method", None) 231 | -------------------------------------------------------------------------------- /dask_expr/array/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: skip_file 2 | 3 | from dask_expr.array import random 4 | from dask_expr.array.core import Array, asarray, from_array 5 | from dask_expr.array.reductions import ( 6 | mean, 7 | moment, 8 | nanmean, 9 | nanstd, 10 | nansum, 11 | nanvar, 12 | prod, 13 | std, 14 | sum, 15 | var, 16 | ) 17 | from dask_expr.array._creation import arange, linspace, ones, empty, zeros 18 | -------------------------------------------------------------------------------- /dask_expr/array/slicing.py: -------------------------------------------------------------------------------- 1 | import toolz 2 | from dask.array.optimization import fuse_slice 3 | from dask.array.slicing import normalize_slice, slice_array 4 | from dask.array.utils import meta_from_array 5 | from dask.utils import cached_property 6 | 7 | from dask_expr.array.core import Array 8 | 9 | 10 | class Slice(Array): 11 | _parameters = ["array", "index"] 12 | 13 | @property 14 | def _meta(self): 15 | return meta_from_array(self.array._meta, ndim=len(self.chunks)) 16 | 17 | @cached_property 18 | def _info(self): 19 | return slice_array( 20 | self._name, 21 | self.array._name, 22 | self.array.chunks, 23 | self.index, 24 | self.array.dtype.itemsize, 25 | ) 26 | 27 | def _layer(self): 28 | return self._info[0] 29 | 30 | @property 31 | def chunks(self): 32 | return self._info[1] 33 | 34 | def _simplify_down(self): 35 | if all( 36 | isinstance(idx, slice) and idx == slice(None, None, None) 37 | for idx in self.index 38 | ): 39 | return self.array 40 | if isinstance(self.array, Slice): 41 | return Slice( 42 | self.array.array, 43 | normalize_slice( 44 | fuse_slice(self.array.index, self.index), self.array.array.ndim 45 | ), 46 | ) 47 | 48 | if isinstance(self.array, Elemwise): 49 | index = self.index + (slice(None),) * (self.ndim - len(self.index)) 50 | args = [] 51 | for arg, ind in toolz.partition(2, self.array.args): 52 | if ind is None: 53 | args.append(arg) 54 | else: 55 | idx = tuple(index[self.array.out_ind.index(i)] for i in ind) 56 | args.append(arg[idx]) 57 | return Elemwise(*self.array.operands[: -len(args)], *args) 58 | 59 | if isinstance(self.array, Transpose): 60 | if any(isinstance(idx, (int)) or idx is None for idx in self.index): 61 | return None # can't handle changes in dimension 62 | else: 63 | index = 
self.index + (slice(None),) * (self.ndim - len(self.index)) 64 | new = tuple(index[i] for i in self.array.axes) 65 | return self.array.substitute(self.array.array, self.array.array[new]) 66 | 67 | 68 | from dask_expr.array.blockwise import Elemwise, Transpose 69 | -------------------------------------------------------------------------------- /dask_expr/array/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/array/tests/__init__.py -------------------------------------------------------------------------------- /dask_expr/array/tests/test_array.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import numpy as np 4 | import pytest 5 | from dask.array.utils import assert_eq 6 | 7 | import dask_expr.array as da 8 | 9 | 10 | def test_basic(): 11 | x = np.random.random((10, 10)) 12 | xx = da.from_array(x, chunks=(4, 4)) 13 | xx._meta 14 | xx.chunks 15 | repr(xx) 16 | 17 | assert_eq(x, xx) 18 | 19 | 20 | def test_rechunk(): 21 | a = np.random.random((10, 10)) 22 | b = da.from_array(a, chunks=(4, 4)) 23 | c = b.rechunk() 24 | assert c.npartitions == 1 25 | assert_eq(b, c) 26 | 27 | d = b.rechunk((3, 3)) 28 | assert d.npartitions == 16 29 | assert_eq(d, a) 30 | 31 | 32 | def test_rechunk_optimize(): 33 | a = np.random.random((10, 10)) 34 | b = da.from_array(a, chunks=(4, 4)) 35 | 36 | c = b.rechunk((2, 5)).rechunk((5, 2)) 37 | d = b.rechunk((5, 2)) 38 | 39 | assert c.optimize()._name == d.optimize()._name 40 | 41 | assert ( 42 | b.T.rechunk((5, 2)).optimize()._name == da.from_array(a, chunks=(2, 5)).T._name 43 | ) 44 | 45 | 46 | def test_rechunk_blockwise_optimize(): 47 | a = np.random.random((10, 10)) 48 | b = da.from_array(a, chunks=(4, 4)) 49 | 50 | result = (da.from_array(a, chunks=(4, 4)) + 1).rechunk((5, 5)) 51 | expected = da.from_array(a, chunks=(5, 5)) + 1 52 | assert result.optimize()._name == expected.optimize()._name 53 | 54 | a = np.random.random((10,)) 55 | aa = da.from_array(a) 56 | b = np.random.random((10, 10)) 57 | bb = da.from_array(b) 58 | 59 | c = (aa + bb).rechunk((5, 2)) 60 | result = c.optimize() 61 | expected = da.from_array(a, chunks=(2,)) + da.from_array(b, chunks=(5, 2)) 62 | assert result._name == expected._name 63 | 64 | a = np.random.random((10, 1)) 65 | aa = da.from_array(a) 66 | b = np.random.random((10, 10)) 67 | bb = da.from_array(b) 68 | 69 | c = (aa + bb).rechunk((5, 2)) 70 | result = c.optimize() 71 | 72 | expected = da.from_array(a, chunks=(5, 1)) + da.from_array(b, chunks=(5, 2)) 73 | assert result._name == expected._name 74 | 75 | 76 | def test_elemwise(): 77 | a = np.random.random((10, 10)) 78 | b = da.from_array(a, chunks=(4, 4)) 79 | 80 | (b + 1).compute() 81 | assert_eq(a + 1, b + 1) 82 | assert_eq(a + 2 * a, b + 2 * b) 83 | 84 | x = np.random.random(10) 85 | y = da.from_array(x, chunks=(4,)) 86 | 87 | assert_eq(a + x, b + y) 88 | 89 | 90 | def test_transpose(): 91 | a = np.random.random((10, 20)) 92 | b = da.from_array(a, chunks=(2, 5)) 93 | 94 | assert_eq(a.T, b.T) 95 | 96 | a = np.random.random((10, 1)) 97 | b = da.from_array(a, chunks=(5, 1)) 98 | assert_eq(a.T + a, b.T + b) 99 | assert_eq(a + a.T, b + b.T) 100 | 101 | assert b.T.T.optimize()._name == b.optimize()._name 102 | 103 | 104 | def test_slicing(): 105 | a = np.random.random((10, 20)) 106 | b = da.from_array(a, chunks=(2, 5)) 107 | 108 | assert_eq(a[:], b[:]) 109 | 
assert_eq(a[::2], b[::2]) 110 | assert_eq(a[1, :5], b[1, :5]) 111 | assert_eq(a[None, ..., ::5], b[None, ..., ::5]) 112 | assert_eq(a[3], b[3]) 113 | 114 | 115 | def test_slicing_optimization(): 116 | a = np.random.random((10, 20)) 117 | b = da.from_array(a, chunks=(2, 5)) 118 | 119 | assert b[:].optimize()._name == b._name 120 | assert b[5:, 4][::2].optimize()._name == b[5::2, 4].optimize()._name 121 | 122 | assert (b + 1)[:5].optimize()._name == (b[:5] + 1)._name 123 | assert (b + 1)[5].optimize()._name == (b[5] + 1)._name 124 | assert b.T[5:].optimize()._name == b[:, 5:].T._name 125 | 126 | 127 | def test_slicing_optimization_change_dimensionality(): 128 | a = np.random.random((10, 20)) 129 | b = da.from_array(a, chunks=(2, 5)) 130 | assert (b + 1)[5].optimize()._name == (b[5] + 1)._name 131 | 132 | 133 | def test_xarray(): 134 | pytest.importorskip("xarray") 135 | 136 | import xarray as xr 137 | 138 | a = np.random.random((10, 20)) 139 | b = da.from_array(a) 140 | 141 | x = (xr.DataArray(b, dims=["x", "y"]) + 1).chunk(x=2) 142 | 143 | assert x.data.optimize()._name == (da.from_array(a, chunks={0: 2}) + 1)._name 144 | 145 | 146 | def test_random(): 147 | x = da.random.random((100, 100), chunks=(50, 50)) 148 | assert_eq(x, x) 149 | 150 | 151 | @pytest.mark.parametrize( 152 | "reduction", 153 | ["sum", "mean", "var", "std", "any", "all", "prod", "min", "max"], 154 | ) 155 | def test_reductions(reduction): 156 | a = np.random.random((10, 20)) 157 | b = da.from_array(a, chunks=(2, 5)) 158 | 159 | def func(x, **kwargs): 160 | return getattr(x, reduction)(**kwargs) 161 | 162 | assert_eq(func(a), func(b)) 163 | assert_eq(func(a, axis=1), func(b, axis=1)) 164 | 165 | 166 | @pytest.mark.parametrize( 167 | "reduction", 168 | ["nanmean", "nansum"], 169 | ) 170 | def test_reduction_functions(reduction): 171 | a = np.random.random((10, 20)) 172 | b = da.from_array(a, chunks=(2, 5)) 173 | 174 | def func(x, **kwargs): 175 | if isinstance(x, np.ndarray): 176 | return getattr(np, reduction)(x, **kwargs) 177 | else: 178 | return getattr(da, reduction)(x, **kwargs) 179 | 180 | func(b).chunks 181 | 182 | assert_eq(func(a), func(b)) 183 | assert_eq(func(a, axis=1), func(b, axis=1)) 184 | 185 | 186 | @pytest.mark.parametrize( 187 | "ufunc", 188 | [np.sqrt, np.sin, np.exp], 189 | ) 190 | def test_ufunc(ufunc): 191 | a = np.random.random((10, 20)) 192 | b = da.from_array(a, chunks=(2, 5)) 193 | assert_eq(ufunc(a), ufunc(b)) 194 | 195 | 196 | @pytest.mark.parametrize( 197 | "op", 198 | [operator.add, operator.sub, operator.pow, operator.floordiv, operator.truediv], 199 | ) 200 | def test_binop(op): 201 | a = np.random.random((10, 20)) 202 | b = np.random.random(20) 203 | aa = da.from_array(a, chunks=(2, 5)) 204 | bb = da.from_array(b, chunks=5) 205 | 206 | assert_eq(op(a, b), op(aa, bb)) 207 | 208 | 209 | def test_asarray(): 210 | a = np.random.random((10, 20)) 211 | b = da.asarray(a) 212 | assert_eq(a, b) 213 | assert isinstance(b, da.Array) and type(b) == type(da.from_array(a)) 214 | 215 | 216 | def test_unify_chunks(): 217 | a = np.random.random((10, 20)) 218 | aa = da.asarray(a, chunks=(4, 5)) 219 | b = np.random.random((20,)) 220 | bb = da.asarray(b, chunks=(10,)) 221 | 222 | assert_eq(a + b, aa + bb) 223 | 224 | 225 | def test_array_function(): 226 | a = np.random.random((10, 20)) 227 | aa = da.asarray(a, chunks=(4, 5)) 228 | 229 | assert isinstance(np.nanmean(aa), da.Array) 230 | 231 | assert_eq(np.nanmean(aa), np.nanmean(a)) 232 | 
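# A minimal end-to-end sketch of the workflow the tests above rely on
# (illustrative only; the shapes and chunk sizes are arbitrary assumptions):
# build a lazy expression, let ``optimize()`` rewrite the plan, and compare
# the result against eager NumPy.
def test_optimize_then_compute_sketch():
    a = np.random.random((10, 20))
    b = da.from_array(a, chunks=(2, 5))

    expr = (b + 1).T[::2]        # lazy expression graph, nothing computed yet
    optimized = expr.optimize()  # plan rewriting, e.g. pushing the slice down

    # both the raw and the optimized plan must match the eager NumPy result
    assert_eq(expr, (a + 1).T[::2])
    assert_eq(optimized, (a + 1).T[::2])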
-------------------------------------------------------------------------------- /dask_expr/array/tests/test_creation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from dask.array import assert_eq 4 | 5 | import dask_expr.array as da 6 | 7 | 8 | def test_arange(): 9 | assert_eq(da.arange(1, 100, 7), np.arange(1, 100, 7)) 10 | assert_eq(da.arange(100), np.arange(100)) 11 | assert_eq(da.arange(100, like=np.arange(100)), np.arange(100)) 12 | 13 | 14 | def test_linspace(): 15 | assert_eq(da.linspace(1, 100, 30), np.linspace(1, 100, 30)) 16 | 17 | 18 | @pytest.mark.parametrize("func", ["ones", "zeros"]) 19 | def test_ones(func): 20 | assert_eq(getattr(da, func)((10, 20, 15)), getattr(np, func)((10, 20, 15))) 21 | assert_eq( 22 | getattr(da, func)((10, 20, 15), dtype="i8"), 23 | getattr(np, func)((10, 20, 15), dtype="i8"), 24 | ) 25 | -------------------------------------------------------------------------------- /dask_expr/datasets.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import operator 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from dask._task_spec import Task 7 | from dask.dataframe.utils import pyarrow_strings_enabled 8 | from dask.typing import Key 9 | 10 | from dask_expr._collection import new_collection 11 | from dask_expr._expr import ArrowStringConversion 12 | from dask_expr.io import BlockwiseIO, PartitionsFiltered 13 | 14 | __all__ = ["timeseries"] 15 | 16 | 17 | class Timeseries(PartitionsFiltered, BlockwiseIO): 18 | _absorb_projections = True 19 | 20 | _parameters = [ 21 | "start", 22 | "end", 23 | "dtypes", 24 | "freq", 25 | "partition_freq", 26 | "seed", 27 | "kwargs", 28 | "columns", 29 | "_partitions", 30 | "_series", 31 | ] 32 | _defaults = { 33 | "start": "2000-01-01", 34 | "end": "2000-01-31", 35 | "dtypes": {"name": "string", "id": int, "x": float, "y": float}, 36 | "freq": "1s", 37 | "partition_freq": "1d", 38 | "seed": None, 39 | "kwargs": {}, 40 | "_partitions": None, 41 | "_series": False, 42 | } 43 | 44 | @functools.cached_property 45 | def _meta(self): 46 | result = self._make_timeseries_part("2000", "2000", 0).iloc[:0] 47 | if self._series: 48 | return result[result.columns[0]] 49 | return result 50 | 51 | def _divisions(self): 52 | return pd.date_range(start=self.start, end=self.end, freq=self.partition_freq) 53 | 54 | @property 55 | def _dtypes(self): 56 | dtypes = self.operand("dtypes") 57 | return {col: dtypes[col] for col in self.operand("columns")} 58 | 59 | @functools.cached_property 60 | def random_state(self): 61 | npartitions = len(self._divisions()) - 1 62 | ndtypes = max(len(self.operand("dtypes")), 1) 63 | random_state = np.random.RandomState(self.seed) 64 | n = npartitions * ndtypes 65 | random_data = random_state.bytes(n * 4) # `n` 32-bit integers 66 | l = list(np.frombuffer(random_data, dtype=np.uint32).reshape((n,))) 67 | assert len(l) == n 68 | return l 69 | 70 | @functools.cached_property 71 | def _make_timeseries_part(self): 72 | return MakeTimeseriesPart( 73 | self.operand("dtypes"), 74 | list(self._dtypes.keys()), 75 | self.freq, 76 | self.kwargs, 77 | ) 78 | 79 | def _filtered_task(self, name: Key, index: int) -> Task: 80 | full_divisions = self._divisions() 81 | ndtypes = max(len(self.operand("dtypes")), 1) 82 | task = Task( 83 | name, 84 | self._make_timeseries_part, 85 | full_divisions[index], 86 | full_divisions[index + 1], 87 | self.random_state[index * ndtypes], 88 | ) 89 | 
if self._series: 90 | return Task(name, operator.getitem, task, self.operand("columns")[0]) 91 | return task 92 | 93 | 94 | names = [ 95 | "Alice", 96 | "Bob", 97 | "Charlie", 98 | "Dan", 99 | "Edith", 100 | "Frank", 101 | "George", 102 | "Hannah", 103 | "Ingrid", 104 | "Jerry", 105 | "Kevin", 106 | "Laura", 107 | "Michael", 108 | "Norbert", 109 | "Oliver", 110 | "Patricia", 111 | "Quinn", 112 | "Ray", 113 | "Sarah", 114 | "Tim", 115 | "Ursula", 116 | "Victor", 117 | "Wendy", 118 | "Xavier", 119 | "Yvonne", 120 | "Zelda", 121 | ] 122 | 123 | 124 | def make_string(n, rstate): 125 | return rstate.choice(names, size=n) 126 | 127 | 128 | def make_categorical(n, rstate): 129 | return pd.Categorical.from_codes(rstate.randint(0, len(names), size=n), names) 130 | 131 | 132 | def make_float(n, rstate): 133 | return rstate.rand(n) * 2 - 1 134 | 135 | 136 | def make_int(n, rstate, lam=1000): 137 | return rstate.poisson(lam, size=n) 138 | 139 | 140 | make = { 141 | float: make_float, 142 | int: make_int, 143 | str: make_string, 144 | object: make_string, 145 | "string": make_string, 146 | "category": make_categorical, 147 | } 148 | 149 | 150 | class MakeTimeseriesPart: 151 | def __init__(self, dtypes, columns, freq, kwargs): 152 | self.dtypes = dtypes 153 | self.columns = columns 154 | self.freq = freq 155 | self.kwargs = kwargs 156 | 157 | def __call__(self, start, end, state_data): 158 | dtypes = self.dtypes 159 | columns = self.columns 160 | freq = self.freq 161 | kwargs = self.kwargs 162 | state = np.random.RandomState(state_data) 163 | index = pd.date_range(start=start, end=end, freq=freq, name="timestamp") 164 | data = {} 165 | for k, dt in dtypes.items(): 166 | kws = { 167 | kk.rsplit("_", 1)[1]: v 168 | for kk, v in kwargs.items() 169 | if kk.rsplit("_", 1)[0] == k 170 | } 171 | # Note: we compute data for all dtypes in order, not just those in the output 172 | # columns. This ensures the same output given the same state_data, regardless 173 | # of whether there is any column projection. 174 | # cf. https://github.com/dask/dask/pull/9538#issuecomment-1267461887 175 | result = make[dt](len(index), state, **kws) 176 | if k in columns: 177 | data[k] = result 178 | df = pd.DataFrame(data, index=index, columns=columns) 179 | if df.index[-1] == end: 180 | df = df.iloc[:-1] 181 | return df 182 | 183 | 184 | def timeseries( 185 | start="2000-01-01", 186 | end="2000-01-31", 187 | freq="1s", 188 | partition_freq="1d", 189 | dtypes=None, 190 | seed=None, 191 | **kwargs, 192 | ): 193 | """Create timeseries dataframe with random data 194 | 195 | Parameters 196 | ---------- 197 | start: datetime (or datetime-like string) 198 | Start of time series 199 | end: datetime (or datetime-like string) 200 | End of time series 201 | dtypes: dict (optional) 202 | Mapping of column names to types. 203 | Valid types include {float, int, str, 'category'} 204 | freq: string 205 | String like '2s' or '1H' or '12W' for the time series frequency 206 | partition_freq: string 207 | String like '1M' or '2Y' to divide the dataframe into partitions 208 | seed: int (optional) 209 | RandomState seed 210 | kwargs: 211 | Keywords to pass down to individual column creation functions. 212 | Keywords should be prefixed by the column name and then an underscore. 213 | 214 | Examples 215 | -------- 216 | >>> from dask_expr.datasets import timeseries 217 | >>> df = timeseries( 218 | ... start='2000', end='2010', 219 | ... dtypes={'value': float, 'name': str, 'id': int}, 220 | ... freq='2H', partition_freq='1D', seed=1 221 | ...
) 222 | >>> df.head() # doctest: +SKIP 223 | id name value 224 | 2000-01-01 00:00:00 969 Jerry -0.309014 225 | 2000-01-01 02:00:00 1010 Ray -0.760675 226 | 2000-01-01 04:00:00 1016 Patricia -0.063261 227 | 2000-01-01 06:00:00 960 Charlie 0.788245 228 | 2000-01-01 08:00:00 1031 Kevin 0.466002 229 | """ 230 | if dtypes is None: 231 | dtypes = {"name": "string", "id": int, "x": float, "y": float} 232 | 233 | if seed is None: 234 | seed = np.random.randint(2e9) 235 | 236 | expr = Timeseries( 237 | start, 238 | end, 239 | dtypes, 240 | freq, 241 | partition_freq, 242 | seed, 243 | kwargs, 244 | columns=list(dtypes.keys()), 245 | ) 246 | if pyarrow_strings_enabled(): 247 | return new_collection(ArrowStringConversion(expr)) 248 | return new_collection(expr) 249 | -------------------------------------------------------------------------------- /dask_expr/diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | from dask_expr.diagnostics._analyze import analyze 2 | from dask_expr.diagnostics._explain import explain 3 | 4 | __all__ = ["analyze", "explain"] 5 | -------------------------------------------------------------------------------- /dask_expr/diagnostics/_analyze.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from typing import TYPE_CHECKING, Any 5 | 6 | import pandas as pd 7 | from dask.base import DaskMethodsMixin 8 | from dask.sizeof import sizeof 9 | from dask.utils import format_bytes, import_required 10 | 11 | from dask_expr._expr import Blockwise, Expr 12 | from dask_expr._util import _tokenize_deterministic, is_scalar 13 | from dask_expr.diagnostics._explain import _add_graphviz_edges, _explain_info 14 | from dask_expr.io.io import FusedIO 15 | 16 | if TYPE_CHECKING: 17 | from dask_expr.diagnostics._analyze_plugin import ExpressionStatistics, Statistics 18 | 19 | 20 | def inject_analyze(expr: Expr, id: str, injected: dict) -> Expr: 21 | if expr._name in injected: 22 | return injected[expr._name] 23 | 24 | new_operands = [] 25 | for operand in expr.operands: 26 | if isinstance(operand, Expr) and not isinstance(expr, FusedIO): 27 | new = inject_analyze(operand, id, injected) 28 | injected[operand._name] = new 29 | else: 30 | new = operand 31 | new_operands.append(new) 32 | return Analyze(type(expr)(*new_operands), id, expr._name) 33 | 34 | 35 | def analyze( 36 | expr: Expr, filename: str | None = None, format: str | None = None, **kwargs: Any 37 | ): 38 | import_required( 39 | "distributed", 40 | "distributed is a required dependency for using the analyze method.", 41 | ) 42 | import_required( 43 | "crick", "crick is a required dependency for using the analyze method." 44 | ) 45 | graphviz = import_required( 46 | "graphviz", "graphviz is a required dependency for using the analyze method." 
47 | ) 48 | from dask.dot import graphviz_to_file 49 | from distributed import get_client, wait 50 | 51 | from dask_expr import new_collection 52 | from dask_expr.diagnostics._analyze_plugin import AnalyzePlugin 53 | 54 | try: 55 | client = get_client() 56 | except ValueError: 57 | raise RuntimeError("analyze must be run in a distributed context.") 58 | client.register_plugin(AnalyzePlugin()) 59 | 60 | # TODO: Make this work with fuse=True 61 | expr = expr.optimize(fuse=False) 62 | 63 | analysis_id = expr._name 64 | 65 | # Inject analyze nodes 66 | injected = inject_analyze(expr, analysis_id, {}) 67 | out = new_collection(injected) 68 | _ = DaskMethodsMixin.compute(out, **kwargs) 69 | wait(_) 70 | 71 | # Collect data 72 | statistics: Statistics = client.sync( 73 | client.scheduler.analyze_get_statistics, id=analysis_id 74 | ) # type: noqa 75 | 76 | # Plot statistics in graph 77 | seen = set(expr._name) 78 | stack = [expr] 79 | 80 | if filename is None: 81 | filename = f"analyze-{expr._name}" 82 | 83 | if format is None: 84 | format = "svg" 85 | 86 | g = graphviz.Digraph(expr._name) 87 | g.node_attr.update(shape="record") 88 | while stack: 89 | node = stack.pop() 90 | info = _analyze_info(node, statistics._expr_statistics[node._name]) 91 | _add_graphviz_node(info, g) 92 | _add_graphviz_edges(info, g) 93 | 94 | if isinstance(node, FusedIO): 95 | continue 96 | for dep in node.operands: 97 | if not isinstance(dep, Expr) or dep._name in seen: 98 | continue 99 | seen.add(dep._name) 100 | stack.append(dep) 101 | graphviz_to_file(g, filename, format) 102 | return g 103 | 104 | 105 | def _add_graphviz_node(info, graph): 106 | label = "".join( 107 | [ 108 | "<{", 109 | info["label"], 110 | " | ", 111 | "
".join( 112 | [f"{key}: {value}" for key, value in info["details"].items()] 113 | ), 114 | " | ", 115 | _statistics_to_graphviz(info["statistics"]), 116 | "}>", 117 | ] 118 | ) 119 | 120 | graph.node(info["name"], label) 121 | 122 | 123 | def _statistics_to_graphviz(statistics: dict[str, dict[str, Any]]) -> str: 124 | return "

".join( 125 | [ 126 | _metric_to_graphviz(metric, statistics) 127 | for metric, statistics in statistics.items() 128 | ] 129 | ) 130 | 131 | 132 | _FORMAT_FNS = {"nbytes": format_bytes, "nrows": "{:,.0f}".format} 133 | 134 | 135 | def _metric_to_graphviz(metric: str, statistics: dict[str, Any]): 136 | format_fn = _FORMAT_FNS[metric] 137 | quantiles = ( 138 | "[" + ", ".join([format_fn(pctl) for pctl in statistics.pop("quantiles")]) + "]" 139 | ) 140 | count = statistics["count"] 141 | total = statistics["total"] 142 | 143 | return "
".join( 144 | [ 145 | f"{metric}:", 146 | f"{format_fn(total / count)} ({format_fn(total)} / {count:,})", 147 | f"{quantiles}", 148 | ] 149 | ) 150 | 151 | 152 | def _analyze_info(expr: Expr, statistics: ExpressionStatistics): 153 | info = _explain_info(expr) 154 | info["statistics"] = _statistics_info(statistics) 155 | return info 156 | 157 | 158 | def _statistics_info(statistics: ExpressionStatistics): 159 | info = {} 160 | for metric, digest in statistics._metric_digests.items(): 161 | info[metric] = { 162 | "total": digest.total, 163 | "count": digest.count, 164 | "quantiles": [digest.sketch.quantile(q) for q in (0, 0.25, 0.5, 0.75, 1)], 165 | } 166 | return info 167 | 168 | 169 | def collect_statistics(frame, analysis_id, expr_name): 170 | from dask_expr.diagnostics._analyze_plugin import get_worker_plugin 171 | 172 | worker_plugin = get_worker_plugin() 173 | if isinstance(frame, pd.DataFrame): 174 | size = frame.memory_usage(deep=True).sum() 175 | elif isinstance(frame, pd.Series): 176 | size = frame.memory_usage(deep=True) 177 | else: 178 | size = sizeof(frame) 179 | 180 | len_frame = len(frame) if not is_scalar(frame) else 1 181 | worker_plugin.add(analysis_id, expr_name, "nrows", len_frame) 182 | worker_plugin.add(analysis_id, expr_name, "nbytes", size) 183 | return frame 184 | 185 | 186 | class Analyze(Blockwise): 187 | _parameters = ["frame", "analysis_id", "expr_name"] 188 | 189 | operation = staticmethod(collect_statistics) 190 | 191 | @functools.cached_property 192 | def _meta(self): 193 | return self.frame._meta 194 | 195 | @functools.cached_property 196 | def _name(self): 197 | return "analyze-" + _tokenize_deterministic(*self.operands) 198 | -------------------------------------------------------------------------------- /dask_expr/diagnostics/_analyze_plugin.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import defaultdict 4 | from typing import TYPE_CHECKING, ClassVar 5 | 6 | from distributed import Scheduler, SchedulerPlugin, Worker, WorkerPlugin 7 | from distributed.protocol.pickle import dumps 8 | 9 | if TYPE_CHECKING: 10 | from crick import TDigest 11 | 12 | 13 | class Digest: 14 | count: int 15 | total: float 16 | sketch: TDigest 17 | 18 | def __init__(self) -> None: 19 | from crick import TDigest 20 | 21 | self.count = 0 22 | self.total = 0.0 23 | self.sketch = TDigest() 24 | 25 | def add(self, sample: float) -> None: 26 | self.count = self.count + 1 27 | self.total = self.total + sample 28 | self.sketch.add(sample) 29 | 30 | def merge(self, other: Digest) -> None: 31 | self.count = self.count + other.count 32 | self.total = self.total + other.total 33 | self.sketch.merge(other.sketch) 34 | 35 | @property 36 | def mean(self): 37 | return self.total / self.count 38 | 39 | 40 | class AnalyzePlugin(SchedulerPlugin): 41 | idempotent: ClassVar[bool] = True 42 | name: ClassVar[str] = "analyze" 43 | _scheduler: Scheduler | None 44 | 45 | def __init__(self) -> None: 46 | self._scheduler = None 47 | 48 | async def start(self, scheduler: Scheduler) -> None: 49 | self._scheduler = scheduler 50 | scheduler.handlers["analyze_get_statistics"] = self.get_statistics 51 | worker_plugin = _AnalyzeWorkerPlugin() 52 | await self._scheduler.register_worker_plugin( 53 | None, 54 | dumps(worker_plugin), 55 | name=worker_plugin.name, 56 | idempotent=True, 57 | ) 58 | 59 | async def get_statistics(self, id: str): 60 | assert self._scheduler is not None 61 | worker_statistics = await 
self._scheduler.broadcast( 62 | msg={"op": "analyze_get_statistics", "id": id} 63 | ) 64 | cluster_statistics = Statistics() 65 | for statistics in worker_statistics.values(): 66 | cluster_statistics.merge(statistics) 67 | return cluster_statistics 68 | 69 | 70 | class ExpressionStatistics: 71 | _metric_digests: defaultdict[str, Digest] 72 | 73 | def __init__(self) -> None: 74 | self._metric_digests = defaultdict(Digest) 75 | 76 | def add(self, metric: str, value: float) -> None: 77 | self._metric_digests[metric].add(value) 78 | 79 | def merge(self, other: ExpressionStatistics) -> None: 80 | for metric, digest in other._metric_digests.items(): 81 | self._metric_digests[metric].merge(digest) 82 | 83 | 84 | class Statistics: 85 | _expr_statistics: defaultdict[str, ExpressionStatistics] 86 | 87 | def __init__(self) -> None: 88 | self._expr_statistics = defaultdict(ExpressionStatistics) 89 | 90 | def add(self, expr: str, metric: str, value: float): 91 | self._expr_statistics[expr].add(metric, value) 92 | 93 | def merge(self, other: Statistics): 94 | for expr, statistics in other._expr_statistics.items(): 95 | self._expr_statistics[expr].merge(statistics) 96 | 97 | 98 | class _AnalyzeWorkerPlugin(WorkerPlugin): 99 | idempotent: ClassVar[bool] = True 100 | name: ClassVar[str] = "analyze" 101 | _statistics: defaultdict[str, Statistics] 102 | _worker: Worker | None 103 | 104 | def __init__(self) -> None: 105 | self._worker = None 106 | self._statistics = defaultdict(Statistics) 107 | 108 | def setup(self, worker: Worker) -> None: 109 | self._digests = defaultdict(lambda: defaultdict(lambda: defaultdict(Digest))) 110 | self._worker = worker 111 | self._worker.handlers["analyze_get_statistics"] = self.get_statistics 112 | 113 | def add(self, id: str, expr: str, metric: str, value: float): 114 | self._statistics[id].add(expr, metric, value) 115 | 116 | def get_statistics(self, id: str) -> Statistics: 117 | return self._statistics.pop(id) 118 | 119 | 120 | def get_worker_plugin() -> _AnalyzeWorkerPlugin: 121 | from distributed import get_worker 122 | 123 | try: 124 | worker = get_worker() 125 | except ValueError as e: 126 | raise RuntimeError( 127 | "``.analyze()`` requires Dask's distributed scheduler" 128 | ) from e 129 | 130 | try: 131 | return worker.plugins["analyze"] # type: ignore 132 | except KeyError as e: 133 | raise RuntimeError( 134 | f"The worker {worker.address} does not have an Analyze plugin." 135 | ) from e 136 | -------------------------------------------------------------------------------- /dask_expr/diagnostics/_explain.py: -------------------------------------------------------------------------------- 1 | from dask.utils import funcname, import_required 2 | 3 | from dask_expr._core import OptimizerStage 4 | from dask_expr._expr import Expr, Projection, optimize_until 5 | from dask_expr._merge import Merge 6 | from dask_expr.io.parquet import ReadParquet 7 | 8 | STAGE_LABELS: dict[OptimizerStage, str] = { 9 | "logical": "Logical Plan", 10 | "simplified-logical": "Simplified Logical Plan", 11 | "tuned-logical": "Tuned Logical Plan", 12 | "physical": "Physical Plan", 13 | "simplified-physical": "Simplified Physical Plan", 14 | "fused": "Fused Physical Plan", 15 | } 16 | 17 | 18 | def explain(expr: Expr, stage: OptimizerStage = "fused", format: str | None = None): 19 | graphviz = import_required( 20 | "graphviz", "graphviz is a required dependency for using the explain method." 
21 | ) 22 | 23 | if format is None: 24 | format = "png" 25 | 26 | g = graphviz.Digraph( 27 | STAGE_LABELS[stage], filename=f"explain-{stage}-{expr._name}", format=format 28 | ) 29 | g.node_attr.update(shape="record") 30 | 31 | expr = optimize_until(expr, stage) 32 | 33 | seen = set(expr._name) 34 | stack = [expr] 35 | 36 | while stack: 37 | node = stack.pop() 38 | explain_info = _explain_info(node) 39 | _add_graphviz_node(explain_info, g) 40 | _add_graphviz_edges(explain_info, g) 41 | 42 | for dep in node.operands: 43 | if not isinstance(dep, Expr) or dep._name in seen: 44 | continue 45 | seen.add(dep._name) 46 | stack.append(dep) 47 | 48 | g.view() 49 | 50 | 51 | def _add_graphviz_node(explain_info, graph): 52 | label = "".join( 53 | [ 54 | "<{", 55 | explain_info["label"], 56 | " | ", 57 | "<br/>
".join( 58 | [f"{key}: {value}" for key, value in explain_info["details"].items()] 59 | ), 60 | "}>", 61 | ] 62 | ) 63 | 64 | graph.node(explain_info["name"], label) 65 | 66 | 67 | def _add_graphviz_edges(explain_info, graph): 68 | name = explain_info["name"] 69 | for _, dep in explain_info["dependencies"]: 70 | graph.edge(dep, name) 71 | 72 | 73 | def _explain_info(expr: Expr): 74 | return { 75 | "name": expr._name, 76 | "label": funcname(type(expr)), 77 | "details": _explain_details(expr), 78 | "dependencies": _explain_dependencies(expr), 79 | } 80 | 81 | 82 | def _explain_details(expr: Expr): 83 | details = {"npartitions": expr.npartitions} 84 | 85 | if isinstance(expr, Merge): 86 | details["how"] = expr.how 87 | elif isinstance(expr, ReadParquet): 88 | details["path"] = expr.path 89 | elif isinstance(expr, Projection): 90 | columns = expr.operand("columns") 91 | details["ncolumns"] = len(columns) if isinstance(columns, list) else "Series" 92 | 93 | return details 94 | 95 | 96 | def _explain_dependencies(expr: Expr) -> list[tuple[str, str]]: 97 | dependencies = [] 98 | for i, operand in enumerate(expr.operands): 99 | if not isinstance(operand, Expr): 100 | continue 101 | param = expr._parameters[i] if i < len(expr._parameters) else "" 102 | dependencies.append((str(param), operand._name)) 103 | return dependencies 104 | -------------------------------------------------------------------------------- /dask_expr/io/__init__.py: -------------------------------------------------------------------------------- 1 | from dask_expr.io.csv import * 2 | from dask_expr.io.io import * 3 | from dask_expr.io.parquet import * 4 | -------------------------------------------------------------------------------- /dask_expr/io/_delayed.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from collections.abc import Iterable 5 | from typing import TYPE_CHECKING 6 | 7 | import pandas as pd 8 | from dask._task_spec import Alias, Task, TaskRef 9 | from dask.dataframe.dispatch import make_meta 10 | from dask.dataframe.utils import check_meta, pyarrow_strings_enabled 11 | from dask.delayed import Delayed, delayed 12 | from dask.typing import Key 13 | 14 | from dask_expr._expr import ArrowStringConversion, DelayedsExpr, PartitionsFiltered 15 | from dask_expr._util import _tokenize_deterministic 16 | from dask_expr.io import BlockwiseIO 17 | 18 | if TYPE_CHECKING: 19 | import distributed 20 | 21 | 22 | class FromDelayed(PartitionsFiltered, BlockwiseIO): 23 | _parameters = [ 24 | "delayed_container", 25 | "meta", 26 | "user_divisions", 27 | "verify_meta", 28 | "_partitions", 29 | "prefix", 30 | ] 31 | _defaults = { 32 | "meta": None, 33 | "_partitions": None, 34 | "user_divisions": None, 35 | "verify_meta": True, 36 | "prefix": None, 37 | } 38 | 39 | @functools.cached_property 40 | def _name(self): 41 | if self.prefix is None: 42 | return super()._name 43 | return self.prefix + "-" + _tokenize_deterministic(*self.operands) 44 | 45 | @functools.cached_property 46 | def _meta(self): 47 | if self.operand("meta") is not None: 48 | return self.operand("meta") 49 | 50 | return delayed(make_meta)(self.delayed_container.operands[0]).compute() 51 | 52 | def _divisions(self): 53 | if self.operand("user_divisions") is not None: 54 | return self.operand("user_divisions") 55 | else: 56 | return self.delayed_container.divisions 57 | 58 | def _filtered_task(self, name: Key, index: int) -> Task: 59 | if self.verify_meta: 60 | 
return Task( 61 | name, 62 | functools.partial(check_meta, meta=self._meta, funcname="from_delayed"), 63 | TaskRef((self.delayed_container._name, index)), 64 | ) 65 | else: 66 | return Alias((self.delayed_container._name, index)) 67 | 68 | 69 | def identity(x): 70 | return x 71 | 72 | 73 | def from_delayed( 74 | dfs: Delayed | distributed.Future | Iterable[Delayed | distributed.Future], 75 | meta=None, 76 | divisions: tuple | None = None, 77 | prefix: str | None = None, 78 | verify_meta: bool = True, 79 | ): 80 | """Create Dask DataFrame from many Dask Delayed objects 81 | 82 | .. warning:: 83 | ``from_delayed`` should only be used if the objects that create 84 | the data are complex and cannot be easily represented as a single 85 | function in an embarrassingly parallel fashion. 86 | 87 | ``from_map`` is recommended if the query can be expressed as a single 88 | function like: 89 | 90 | def read_xml(path): 91 | return pd.read_xml(path) 92 | 93 | ddf = dd.from_map(read_xml, paths) 94 | 95 | ``from_delayed`` might be deprecated in the future. 96 | 97 | Parameters 98 | ---------- 99 | dfs : 100 | A ``dask.delayed.Delayed``, a ``distributed.Future``, or an iterable of either 101 | of these objects, e.g. returned by ``client.submit``. These comprise the 102 | individual partitions of the resulting dataframe. 103 | If a single object is provided (not an iterable), then the resulting dataframe 104 | will have only one partition. 105 | $META 106 | divisions : 107 | Partition boundaries along the index. 108 | For tuple, see https://docs.dask.org/en/latest/dataframe-design.html#partitions 109 | If None, then won't use index information 110 | prefix : 111 | Prefix to prepend to the keys. 112 | verify_meta : 113 | If True, check that the partitions have consistent metadata; defaults to True. 114 | """ 115 | if isinstance(dfs, Delayed) or hasattr(dfs, "key"): 116 | dfs = [dfs] 117 | 118 | if len(dfs) == 0: 119 | raise TypeError("Must supply at least one delayed object") 120 | 121 | if meta is None: 122 | meta = delayed(make_meta)(dfs[0]).compute() 123 | 124 | if divisions == "sorted": 125 | raise NotImplementedError( 126 | "divisions='sorted' not supported, please calculate the divisions " 127 | "yourself."
128 | ) 129 | elif divisions is not None: 130 | divs = list(divisions) 131 | if len(divs) != len(dfs) + 1: 132 | raise ValueError("divisions should be a tuple of len(dfs) + 1") 133 | 134 | dfs = [ 135 | delayed(df) if not isinstance(df, Delayed) and hasattr(df, "key") else df 136 | for df in dfs 137 | ] 138 | 139 | for item in dfs: 140 | if not isinstance(item, Delayed): 141 | raise TypeError("Expected Delayed object, got %s" % type(item).__name__) 142 | 143 | from dask_expr._collection import new_collection 144 | 145 | result = FromDelayed( 146 | DelayedsExpr(*dfs), make_meta(meta), divisions, verify_meta, None, prefix 147 | ) 148 | if pyarrow_strings_enabled() and any( 149 | pd.api.types.is_object_dtype(dtype) 150 | for dtype in (result.dtypes.values if result.ndim == 2 else [result.dtypes]) 151 | ): 152 | return new_collection(ArrowStringConversion(result)) 153 | return new_collection(result) 154 | -------------------------------------------------------------------------------- /dask_expr/io/bag.py: -------------------------------------------------------------------------------- 1 | from dask.dataframe.io.io import _df_to_bag 2 | from dask.tokenize import tokenize 3 | 4 | from dask_expr import FrameBase 5 | 6 | 7 | def to_bag(df, index=False, format="tuple"): 8 | """Create Dask Bag from a Dask DataFrame 9 | 10 | Parameters 11 | ---------- 12 | index : bool, optional 13 | If True, the elements are tuples of ``(index, value)``, otherwise 14 | they're just the ``value``. Default is False. 15 | format : {"tuple", "dict", "frame"}, optional 16 | Whether to return a bag of tuples, dictionaries, or 17 | dataframe-like objects. Default is "tuple". If "frame", 18 | the original partitions of ``df`` will not be transformed 19 | in any way. 20 | 21 | 22 | Examples 23 | -------- 24 | >>> bag = df.to_bag() # doctest: +SKIP 25 | """ 26 | from dask.bag.core import Bag 27 | 28 | df = df.optimize() 29 | 30 | if not isinstance(df, FrameBase): 31 | raise TypeError("df must be either DataFrame or Series") 32 | name = "to_bag-" + tokenize(df._name, index, format) 33 | if format == "frame": 34 | dsk = df.dask 35 | name = df._name 36 | else: 37 | dsk = { 38 | (name, i): (_df_to_bag, block, index, format) 39 | for (i, block) in enumerate(df.__dask_keys__()) 40 | } 41 | dsk.update(df.__dask_graph__()) 42 | return Bag(dsk, name, df.npartitions) 43 | -------------------------------------------------------------------------------- /dask_expr/io/csv.py: -------------------------------------------------------------------------------- 1 | def to_csv( 2 | df, 3 | filename, 4 | single_file=False, 5 | encoding="utf-8", 6 | mode="wt", 7 | name_function=None, 8 | compression=None, 9 | compute=True, 10 | scheduler=None, 11 | storage_options=None, 12 | header_first_partition_only=None, 13 | compute_kwargs=None, 14 | **kwargs, 15 | ): 16 | """ 17 | Store Dask DataFrame to CSV files 18 | 19 | One filename per partition will be created. You can specify the 20 | filenames in a variety of ways. 21 | 22 | Use a globstring:: 23 | 24 | >>> df.to_csv('/path/to/data/export-*.csv') # doctest: +SKIP 25 | 26 | The * will be replaced by the increasing sequence 0, 1, 2, ... 27 | 28 | :: 29 | 30 | /path/to/data/export-0.csv 31 | /path/to/data/export-1.csv 32 | 33 | Use a globstring and a ``name_function=`` keyword argument. The 34 | name_function function should expect an integer and produce a string. 35 | Strings produced by name_function must preserve the order of their 36 | respective partition indices. 
37 | 38 | >>> from datetime import date, timedelta 39 | >>> def name(i): 40 | ... return str(date(2015, 1, 1) + i * timedelta(days=1)) 41 | 42 | >>> name(0) 43 | '2015-01-01' 44 | >>> name(15) 45 | '2015-01-16' 46 | 47 | >>> df.to_csv('/path/to/data/export-*.csv', name_function=name) # doctest: +SKIP 48 | 49 | :: 50 | 51 | /path/to/data/export-2015-01-01.csv 52 | /path/to/data/export-2015-01-02.csv 53 | ... 54 | 55 | You can also provide an explicit list of paths:: 56 | 57 | >>> paths = ['/path/to/data/alice.csv', '/path/to/data/bob.csv', ...] # doctest: +SKIP 58 | >>> df.to_csv(paths) # doctest: +SKIP 59 | 60 | You can also provide a directory name: 61 | 62 | >>> df.to_csv('/path/to/data') # doctest: +SKIP 63 | 64 | The files will be numbered 0, 1, 2, (and so on) suffixed with '.part': 65 | 66 | :: 67 | 68 | /path/to/data/0.part 69 | /path/to/data/1.part 70 | 71 | Parameters 72 | ---------- 73 | df : dask.DataFrame 74 | Data to save 75 | filename : string or list 76 | Absolute or relative filepath(s). Prefix with a protocol like ``s3://`` 77 | to save to remote filesystems. 78 | single_file : bool, default False 79 | Whether to save everything into a single CSV file. Under the 80 | single file mode, each partition is appended at the end of the 81 | specified CSV file. 82 | encoding : string, default 'utf-8' 83 | A string representing the encoding to use in the output file. 84 | mode : str, default 'w' 85 | Python file mode. The default is 'w' (or 'wt'), for writing 86 | a new file or overwriting an existing file in text mode. 'a' 87 | (or 'at') will append to an existing file in text mode or 88 | create a new file if it does not already exist. See :py:func:`open`. 89 | name_function : callable, default None 90 | Function accepting an integer (partition index) and producing a 91 | string to replace the asterisk in the given filename globstring. 92 | Should preserve the lexicographic order of partitions. Not 93 | supported when ``single_file`` is True. 94 | compression : string, optional 95 | A string representing the compression to use in the output file, 96 | allowed values are 'gzip', 'bz2', 'xz', 97 | only used when the first argument is a filename. 98 | compute : bool, default True 99 | If True, immediately executes. If False, returns a set of delayed 100 | objects, which can be computed at a later time. 101 | storage_options : dict 102 | Parameters passed on to the backend filesystem class. 103 | header_first_partition_only : bool, default None 104 | If set to True, only write the header row in the first output 105 | file. By default, headers are written to all partitions under 106 | the multiple file mode (``single_file`` is False) and written 107 | only once under the single file mode (``single_file`` is True). 108 | It must be True under the single file mode. 109 | compute_kwargs : dict, optional 110 | Options to be passed in to the compute method 111 | kwargs : dict, optional 112 | Additional parameters to pass to :meth:`pandas.DataFrame.to_csv`. 113 | 114 | Returns 115 | ------- 116 | The names of the file written if they were computed right away. 117 | If not, the delayed tasks associated with writing the files. 118 | 119 | Raises 120 | ------ 121 | ValueError 122 | If ``header_first_partition_only`` is set to False or 123 | ``name_function`` is specified when ``single_file`` is True. 
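    For instance (an illustrative call that is expected to raise)::

        >>> df.to_csv('/path/to/data/export.csv', single_file=True,
        ...           name_function=str)  # doctest: +SKIP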
124 | 125 | See Also 126 | -------- 127 | fsspec.open_files 128 | """ 129 | from dask.dataframe.io.csv import to_csv as _to_csv 130 | 131 | return _to_csv( 132 | df.optimize(), 133 | filename, 134 | single_file=single_file, 135 | encoding=encoding, 136 | mode=mode, 137 | name_function=name_function, 138 | compression=compression, 139 | compute=compute, 140 | scheduler=scheduler, 141 | storage_options=storage_options, 142 | header_first_partition_only=header_first_partition_only, 143 | compute_kwargs=compute_kwargs, 144 | **kwargs, 145 | ) 146 | -------------------------------------------------------------------------------- /dask_expr/io/hdf.py: -------------------------------------------------------------------------------- 1 | def read_hdf( 2 | pattern, 3 | key, 4 | start=0, 5 | stop=None, 6 | columns=None, 7 | chunksize=1000000, 8 | sorted_index=False, 9 | lock=True, 10 | mode="r", 11 | ): 12 | from dask.dataframe.io import read_hdf as _read_hdf 13 | 14 | return _read_hdf( 15 | pattern, 16 | key, 17 | start=start, 18 | stop=stop, 19 | columns=columns, 20 | chunksize=chunksize, 21 | sorted_index=sorted_index, 22 | lock=lock, 23 | mode=mode, 24 | ) 25 | 26 | 27 | def to_hdf( 28 | df, 29 | path, 30 | key, 31 | mode="a", 32 | append=False, 33 | scheduler=None, 34 | name_function=None, 35 | compute=True, 36 | lock=None, 37 | dask_kwargs=None, 38 | **kwargs, 39 | ): 40 | """Store Dask Dataframe to Hierarchical Data Format (HDF) files 41 | 42 | This is a parallel version of the Pandas function of the same name. Please 43 | see the Pandas docstring for more detailed information about shared keyword 44 | arguments. 45 | 46 | This function differs from the Pandas version by saving the many partitions 47 | of a Dask DataFrame in parallel, either to many files, or to many datasets 48 | within the same file. You may specify this parallelism with an asterix 49 | ``*`` within the filename or datapath, and an optional ``name_function``. 50 | The asterix will be replaced with an increasing sequence of integers 51 | starting from ``0`` or with the result of calling ``name_function`` on each 52 | of those integers. 53 | 54 | This function only supports the Pandas ``'table'`` format, not the more 55 | specialized ``'fixed'`` format. 56 | 57 | Parameters 58 | ---------- 59 | path : string, pathlib.Path 60 | Path to a target filename. Supports strings, ``pathlib.Path``, or any 61 | object implementing the ``__fspath__`` protocol. May contain a ``*`` to 62 | denote many filenames. 63 | key : string 64 | Datapath within the files. May contain a ``*`` to denote many locations 65 | name_function : function 66 | A function to convert the ``*`` in the above options to a string. 67 | Should take in a number from 0 to the number of partitions and return a 68 | string. (see examples below) 69 | compute : bool 70 | Whether or not to execute immediately. If False then this returns a 71 | ``dask.Delayed`` value. 72 | lock : bool, Lock, optional 73 | Lock to use to prevent concurrency issues. By default a 74 | ``threading.Lock``, ``multiprocessing.Lock`` or ``SerializableLock`` 75 | will be used depending on your scheduler if a lock is required. See 76 | dask.utils.get_scheduler_lock for more information about lock 77 | selection. 
78 | scheduler : string 79 | The scheduler to use, like "threads" or "processes" 80 | **other: 81 | See pandas.to_hdf for more information 82 | 83 | Examples 84 | -------- 85 | Save data to a single file 86 | 87 | >>> df.to_hdf('output.hdf', '/data') # doctest: +SKIP 88 | 89 | Save data to multiple datapaths within the same file: 90 | 91 | >>> df.to_hdf('output.hdf', '/data-*') # doctest: +SKIP 92 | 93 | Save data to multiple files: 94 | 95 | >>> df.to_hdf('output-*.hdf', '/data') # doctest: +SKIP 96 | 97 | Save data to multiple files, using the multiprocessing scheduler: 98 | 99 | >>> df.to_hdf('output-*.hdf', '/data', scheduler='processes') # doctest: +SKIP 100 | 101 | Specify custom naming scheme. This writes files as 102 | '2000-01-01.hdf', '2000-01-02.hdf', '2000-01-03.hdf', etc. 103 | 104 | >>> from datetime import date, timedelta 105 | >>> base = date(year=2000, month=1, day=1) 106 | >>> def name_function(i): 107 | ... ''' Convert integer 0 to n to a string ''' 108 | ... return str(base + timedelta(days=i)) 109 | 110 | >>> df.to_hdf('*.hdf', '/data', name_function=name_function) # doctest: +SKIP 111 | 112 | Returns 113 | ------- 114 | filenames : list 115 | Returned if ``compute`` is True. List of file names that each partition 116 | is saved to. 117 | delayed : dask.Delayed 118 | Returned if ``compute`` is False. Delayed object to execute ``to_hdf`` 119 | when computed. 120 | 121 | See Also 122 | -------- 123 | read_hdf: 124 | to_parquet: 125 | """ 126 | from dask.dataframe.io import to_hdf as _to_hdf 127 | 128 | return _to_hdf( 129 | df.optimize(), 130 | path, 131 | key, 132 | mode=mode, 133 | append=append, 134 | scheduler=scheduler, 135 | name_function=name_function, 136 | compute=compute, 137 | lock=lock, 138 | dask_kwargs=dask_kwargs, 139 | **kwargs, 140 | ) 141 | -------------------------------------------------------------------------------- /dask_expr/io/json.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dask.dataframe.utils import insert_meta_param_description 3 | 4 | from dask_expr._backends import dataframe_creation_dispatch 5 | 6 | 7 | @dataframe_creation_dispatch.register_inplace("pandas") 8 | @insert_meta_param_description 9 | def read_json( 10 | url_path, 11 | orient="records", 12 | lines=None, 13 | storage_options=None, 14 | blocksize=None, 15 | sample=2**20, 16 | encoding="utf-8", 17 | errors="strict", 18 | compression="infer", 19 | meta=None, 20 | engine=pd.read_json, 21 | include_path_column=False, 22 | path_converter=None, 23 | **kwargs, 24 | ): 25 | """Create a dataframe from a set of JSON files 26 | 27 | This utilises ``pandas.read_json()``, and most parameters are 28 | passed through - see its docstring. 29 | 30 | Differences: orient is 'records' by default, with lines=True; this 31 | is appropriate for line-delimited "JSON-lines" data, the kind of JSON output 32 | that is most common in big-data scenarios, and which can be chunked when 33 | reading (see ``read_json()``). All other options require blocksize=None, 34 | i.e., one partition per input file. 35 | 36 | Parameters 37 | ---------- 38 | url_path: str, list of str 39 | Location to read from. If a string, can include a glob character to 40 | find a set of file names. 41 | Supports protocol specifications such as ``"s3://"``. 42 | encoding, errors: 43 | The text encoding to implement, e.g., "utf-8" and how to respond 44 | to errors in the conversion (see ``str.encode()``).
45 | orient, lines, kwargs 46 | passed to pandas; if not specified, lines=True when orient='records', 47 | False otherwise. 48 | storage_options: dict 49 | Passed to backend file-system implementation 50 | blocksize: None or int 51 | If None, files are not blocked, and you get one partition per input 52 | file. If int, which can only be used for line-delimited JSON files, 53 | each partition will be approximately this size in bytes, to the nearest 54 | newline character. 55 | sample: int 56 | Number of bytes to pre-load, to provide an empty dataframe structure 57 | to any blocks without data. Only relevant when using blocksize. 58 | encoding, errors: 59 | Text conversion, see ``bytes.decode()`` 60 | compression : string or None 61 | String like 'gzip' or 'xz'. 62 | engine : callable or str, default ``pd.read_json`` 63 | The underlying function that dask will use to read JSON files. By 64 | default, this will be the pandas JSON reader (``pd.read_json``). 65 | If a string is specified, this value will be passed under the ``engine`` 66 | keyword argument to ``pd.read_json`` (only supported for pandas>=2.0). 67 | include_path_column : bool or str, optional 68 | Include a column with the file path where each row in the dataframe 69 | originated. If ``True``, a new column is added to the dataframe called 70 | ``path``. If ``str``, sets new column name. Default is ``False``. 71 | path_converter : function or None, optional 72 | A function that takes one argument and returns a string. Used to convert 73 | paths in the ``path`` column, for instance, to strip a common prefix from 74 | all the paths. 75 | $META 76 | 77 | Returns 78 | ------- 79 | dask.DataFrame 80 | 81 | Examples 82 | -------- 83 | Load single file 84 | 85 | >>> dd.read_json('myfile.1.json') # doctest: +SKIP 86 | 87 | Load multiple files 88 | 89 | >>> dd.read_json('myfile.*.json') # doctest: +SKIP 90 | 91 | >>> dd.read_json(['myfile.1.json', 'myfile.2.json']) # doctest: +SKIP 92 | 93 | Load large line-delimited JSON files using partitions of approx 94 | 256MB size 95 | 96 | >> dd.read_json('data/file*.json', blocksize=2**28) 97 | """ 98 | from dask.dataframe.io.json import read_json 99 | 100 | return read_json( 101 | url_path, 102 | orient=orient, 103 | lines=lines, 104 | storage_options=storage_options, 105 | blocksize=blocksize, 106 | sample=sample, 107 | encoding=encoding, 108 | errors=errors, 109 | compression=compression, 110 | meta=meta, 111 | engine=engine, 112 | include_path_column=include_path_column, 113 | path_converter=path_converter, 114 | **kwargs, 115 | ) 116 | 117 | 118 | def to_json( 119 | df, 120 | url_path, 121 | orient="records", 122 | lines=None, 123 | storage_options=None, 124 | compute=True, 125 | encoding="utf-8", 126 | errors="strict", 127 | compression=None, 128 | compute_kwargs=None, 129 | name_function=None, 130 | **kwargs, 131 | ): 132 | """Write dataframe into JSON text files 133 | 134 | This utilises ``pandas.DataFrame.to_json()``, and most parameters are 135 | passed through - see its docstring. 136 | 137 | Differences: orient is 'records' by default, with lines=True; this 138 | produces the kind of JSON output that is most common in big-data 139 | applications, and which can be chunked when reading (see ``read_json()``). 140 | 141 | Parameters 142 | ---------- 143 | df: dask.DataFrame 144 | Data to save 145 | url_path: str, list of str 146 | Location to write to.
If a string, and there are more than one 147 | partitions in df, should include a glob character to expand into a 148 | set of file names, or provide a ``name_function=`` parameter. 149 | Supports protocol specifications such as ``"s3://"``. 150 | encoding, errors: 151 | The text encoding to implement, e.g., "utf-8" and how to respond 152 | to errors in the conversion (see ``str.encode()``). 153 | orient, lines, kwargs 154 | passed to pandas; if not specified, lines=True when orient='records', 155 | False otherwise. 156 | storage_options: dict 157 | Passed to backend file-system implementation 158 | compute: bool 159 | If true, immediately executes. If False, returns a set of delayed 160 | objects, which can be computed at a later time. 161 | compute_kwargs : dict, optional 162 | Options to be passed in to the compute method 163 | compression : string or None 164 | String like 'gzip' or 'xz'. 165 | name_function : callable, default None 166 | Function accepting an integer (partition index) and producing a 167 | string to replace the asterisk in the given filename globstring. 168 | Should preserve the lexicographic order of partitions. 169 | """ 170 | from dask.dataframe.io.json import to_json 171 | 172 | return to_json( 173 | df, 174 | url_path, 175 | orient=orient, 176 | lines=lines, 177 | storage_options=storage_options, 178 | compute=compute, 179 | encoding=encoding, 180 | errors=errors, 181 | compression=compression, 182 | compute_kwargs=compute_kwargs, 183 | name_function=name_function, 184 | **kwargs, 185 | ) 186 | -------------------------------------------------------------------------------- /dask_expr/io/orc.py: -------------------------------------------------------------------------------- 1 | from dask_expr._backends import dataframe_creation_dispatch 2 | 3 | 4 | @dataframe_creation_dispatch.register_inplace("pandas") 5 | def read_orc( 6 | path, 7 | engine="pyarrow", 8 | columns=None, 9 | index=None, 10 | split_stripes=1, 11 | aggregate_files=None, 12 | storage_options=None, 13 | ): 14 | """Read dataframe from ORC file(s) 15 | 16 | Parameters 17 | ---------- 18 | path: str or list(str) 19 | Location of file(s), which can be a full URL with protocol 20 | specifier, and may include glob character if a single string. 21 | engine: 'pyarrow' or ORCEngine 22 | Backend ORC engine to use for I/O. Default is "pyarrow". 23 | columns: None or list(str) 24 | Columns to load. If None, loads all. 25 | index: str 26 | Column name to set as index. 27 | split_stripes: int or False 28 | Maximum number of ORC stripes to include in each output-DataFrame 29 | partition. Use False to specify a 1-to-1 mapping between files 30 | and partitions. Default is 1. 31 | aggregate_files : bool, default False 32 | Whether distinct file paths may be aggregated into the same output 33 | partition. A setting of True means that any two file paths may be 34 | aggregated into the same output partition, while False means that 35 | inter-file aggregation is prohibited. 36 | storage_options: None or dict 37 | Further parameters to pass to the bytes backend. 38 | 39 | Returns 40 | ------- 41 | Dask.DataFrame (even if there is only one column) 42 | 43 | Examples 44 | -------- 45 | >>> df = dd.read_orc('https://github.com/apache/orc/raw/' 46 | ... 
'master/examples/demo-11-zlib.orc') # doctest: +SKIP 47 | """ 48 | from dask.dataframe.io import read_orc as _read_orc 49 | 50 | return _read_orc( 51 | path, 52 | engine=engine, 53 | columns=columns, 54 | index=index, 55 | split_stripes=split_stripes, 56 | aggregate_files=aggregate_files, 57 | storage_options=storage_options, 58 | ) 59 | 60 | 61 | def to_orc( 62 | df, 63 | path, 64 | engine="pyarrow", 65 | write_index=True, 66 | storage_options=None, 67 | compute=True, 68 | compute_kwargs=None, 69 | ): 70 | from dask.dataframe.io import to_orc as _to_orc 71 | 72 | return _to_orc( 73 | df.optimize(), 74 | path, 75 | engine=engine, 76 | write_index=write_index, 77 | storage_options=storage_options, 78 | compute=compute, 79 | compute_kwargs=compute_kwargs, 80 | ) 81 | -------------------------------------------------------------------------------- /dask_expr/io/records.py: -------------------------------------------------------------------------------- 1 | from dask.utils import M 2 | 3 | 4 | def to_records(df): 5 | """Create Dask Array from a Dask Dataframe 6 | 7 | Warning: This creates a dask.array without precise shape information. 8 | Operations that depend on shape information, like slicing or reshaping, 9 | will not work. 10 | 11 | Examples 12 | -------- 13 | >>> df.to_records() # doctest: +SKIP 14 | 15 | See Also 16 | -------- 17 | dask.dataframe.DataFrame.values 18 | dask.dataframe.from_dask_array 19 | """ 20 | return df.map_partitions(M.to_records) 21 | -------------------------------------------------------------------------------- /dask_expr/io/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/io/tests/__init__.py -------------------------------------------------------------------------------- /dask_expr/io/tests/test_delayed.py: -------------------------------------------------------------------------------- 1 | import dask 2 | import numpy as np 3 | import pytest 4 | from dask import delayed 5 | 6 | from dask_expr import from_delayed, from_dict 7 | from dask_expr.tests._util import _backend_library, assert_eq 8 | 9 | pd = _backend_library() 10 | 11 | 12 | def test_from_delayed_optimizing(): 13 | parts = from_dict({"a": np.arange(300)}, npartitions=30).to_delayed() 14 | result = from_delayed(parts[0], meta=pd.DataFrame({"a": pd.Series(dtype=np.int64)})) 15 | assert len(result.optimize().dask) == 2 16 | assert_eq(result, pd.DataFrame({"a": pd.Series(np.arange(10))})) 17 | 18 | 19 | @pytest.mark.parametrize("prefix", [None, "foo"]) 20 | def test_from_delayed(prefix): 21 | pdf = pd.DataFrame( 22 | data=np.random.normal(size=(10, 4)), columns=["a", "b", "c", "d"] 23 | ) 24 | parts = [pdf.iloc[:1], pdf.iloc[1:3], pdf.iloc[3:6], pdf.iloc[6:10]] 25 | dfs = [delayed(parts.__getitem__)(i) for i in range(4)] 26 | 27 | df = from_delayed(dfs, meta=pdf.head(0), divisions=None, prefix=prefix) 28 | assert_eq(df, pdf) 29 | assert len({k[0] for k in df.optimize().dask}) == 2 30 | if prefix: 31 | assert df._name.startswith(prefix) 32 | 33 | divisions = tuple([p.index[0] for p in parts] + [parts[-1].index[-1]]) 34 | df = from_delayed(dfs, meta=pdf.head(0), divisions=divisions, prefix=prefix) 35 | assert_eq(df, pdf) 36 | if prefix: 37 | assert df._name.startswith(prefix) 38 | 39 | 40 | def test_from_delayed_dask(): 41 | df = pd.DataFrame(data=np.random.normal(size=(10, 4)), columns=list("abcd")) 42 | parts = [df.iloc[:1], df.iloc[1:3], df.iloc[3:6], 
df.iloc[6:10]] 43 | dfs = [delayed(parts.__getitem__)(i) for i in range(4)] 44 | meta = dfs[0].compute() 45 | 46 | my_len = lambda x: pd.Series([len(x)]) 47 | 48 | for divisions in [None, [0, 1, 3, 6, 10]]: 49 | ddf = from_delayed(dfs, meta=meta, divisions=divisions) 50 | assert_eq(ddf, df) 51 | assert list(ddf.map_partitions(my_len).compute()) == [1, 2, 3, 4] 52 | assert ddf.known_divisions == (divisions is not None) 53 | 54 | s = from_delayed([d.a for d in dfs], meta=meta.a, divisions=divisions) 55 | assert_eq(s, df.a) 56 | assert list(s.map_partitions(my_len).compute()) == [1, 2, 3, 4] 57 | assert ddf.known_divisions == (divisions is not None) 58 | 59 | 60 | @dask.delayed 61 | def _load(x): 62 | return pd.DataFrame({"x": x, "y": [1, 2, 3]}) 63 | 64 | 65 | def test_from_delayed_fusion(): 66 | func = lambda x: None 67 | df = from_delayed([_load(x) for x in range(10)], meta={"x": "int64", "y": "int64"}) 68 | result = df.map_partitions(func, meta={}).optimize().dask 69 | expected = df.map_partitions(func, meta={}).optimize(fuse=False).dask 70 | assert result.keys() == expected.keys() 71 | 72 | expected = df.map_partitions(func, meta={}).lower_completely().dask 73 | assert result.keys() == expected.keys() 74 | -------------------------------------------------------------------------------- /dask_expr/io/tests/test_distributed.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | import pytest 6 | 7 | from dask_expr import read_parquet 8 | from dask_expr.tests._util import _backend_library, assert_eq 9 | 10 | distributed = pytest.importorskip("distributed") 11 | 12 | from distributed import Client, LocalCluster 13 | from distributed.utils_test import client as c # noqa F401 14 | from distributed.utils_test import gen_cluster 15 | 16 | import dask_expr as dx 17 | 18 | pd = _backend_library() 19 | 20 | 21 | @pytest.fixture(params=["arrow"]) 22 | def filesystem(request): 23 | return request.param 24 | 25 | 26 | def _make_file(dir, df=None, filename="myfile.parquet", **kwargs): 27 | fn = os.path.join(str(dir), filename) 28 | if df is None: 29 | df = pd.DataFrame({c: range(10) for c in "abcde"}) 30 | df.to_parquet(fn, **kwargs) 31 | return fn 32 | 33 | 34 | def test_io_fusion_merge(tmpdir): 35 | pdf = pd.DataFrame({c: range(100) for c in "abcdefghij"}) 36 | with LocalCluster(processes=False, n_workers=2) as cluster: 37 | with Client(cluster) as client: # noqa: F841 38 | dx.from_pandas(pdf, 2).to_parquet(tmpdir) 39 | 40 | df = dx.read_parquet(tmpdir).merge( 41 | dx.read_parquet(tmpdir).add_suffix("_x"), left_on="a", right_on="a_x" 42 | )[["a_x", "b_x", "b"]] 43 | out = df.compute() 44 | pd.testing.assert_frame_equal( 45 | out.sort_values(by="a_x", ignore_index=True), 46 | pdf.merge(pdf.add_suffix("_x"), left_on="a", right_on="a_x")[ 47 | ["a_x", "b_x", "b"] 48 | ], 49 | ) 50 | 51 | 52 | @pytest.mark.filterwarnings("error") 53 | @gen_cluster(client=True) 54 | async def test_parquet_distriuted(c, s, a, b, tmpdir, filesystem): 55 | pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]}) 56 | df = read_parquet(_make_file(tmpdir, df=pdf), filesystem=filesystem) 57 | assert_eq(await c.gather(c.compute(df.optimize())), pdf) 58 | 59 | 60 | def test_pickle_size(tmpdir, filesystem): 61 | pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]}) 62 | [_make_file(tmpdir, df=pdf, filename=f"{x}.parquet") for x in range(10)] 63 | df = read_parquet(tmpdir, filesystem=filesystem) 64 | from distributed.protocol import dumps 65 | 66 | assert 
len(b"".join(dumps(df.optimize().dask))) <= 9100 67 | -------------------------------------------------------------------------------- /dask_expr/io/tests/test_from_pandas.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import dask 4 | import pytest 5 | from dask.dataframe.utils import assert_eq 6 | 7 | from dask_expr import from_pandas, repartition 8 | from dask_expr.tests._util import _backend_library 9 | 10 | pd = _backend_library() 11 | 12 | 13 | @pytest.fixture(params=["Series", "DataFrame"]) 14 | def pdf(request): 15 | out = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]}) 16 | return out["x"] if request.param == "Series" else out 17 | 18 | 19 | @pytest.mark.parametrize("sort", [True, False]) 20 | def test_from_pandas(pdf, sort): 21 | df = from_pandas(pdf, npartitions=2, sort=sort) 22 | 23 | assert df.npartitions == 2 24 | assert df.divisions == (0, 3, 5) 25 | assert_eq(df, pdf, sort_results=sort) 26 | 27 | 28 | def test_from_pandas_noargs(pdf): 29 | df = from_pandas(pdf) 30 | 31 | assert df.npartitions == 1 32 | assert df.divisions == (0, 5) 33 | assert_eq(df, pdf) 34 | 35 | 36 | def test_from_pandas_empty(pdf): 37 | pdf = pdf.iloc[:0] 38 | df = from_pandas(pdf, npartitions=2) 39 | assert_eq(pdf, df) 40 | 41 | 42 | def test_from_pandas_immutable(pdf): 43 | expected = pdf.copy() 44 | df = from_pandas(pdf) 45 | pdf.iloc[0] = 100 46 | assert_eq(df, expected) 47 | 48 | 49 | def test_from_pandas_sort_and_different_partitions(): 50 | pdf = pd.DataFrame({"a": [1, 2, 3] * 3, "b": 1}).set_index("a") 51 | df = from_pandas(pdf, npartitions=4, sort=True) 52 | assert_eq(pdf.sort_index(), df, sort_results=False) 53 | 54 | pdf = pd.DataFrame({"a": [1, 2, 3] * 3, "b": 1}).set_index("a") 55 | df = from_pandas(pdf, npartitions=4, sort=False) 56 | assert_eq(pdf, df, sort_results=False) 57 | 58 | 59 | def test_from_pandas_sort(): 60 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[6, 5, 4, 3, 2, 1]) 61 | df = from_pandas(pdf, npartitions=2) 62 | assert_eq(df, pdf.sort_index(), sort_results=False) 63 | 64 | 65 | def test_from_pandas_divisions(): 66 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[7, 6, 4, 3, 2, 1]) 67 | df = repartition(pdf, (1, 5, 8)) 68 | assert_eq(df, pdf.sort_index()) 69 | 70 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[7, 6, 4, 3, 2, 1]) 71 | df = repartition(pdf, (1, 4, 8)) 72 | assert_eq(df.partitions[1], pd.DataFrame({"a": [3, 2, 1]}, index=[4, 6, 7])) 73 | 74 | df = repartition(df, divisions=(1, 3, 8), force=True) 75 | assert_eq(df, pdf.sort_index()) 76 | 77 | 78 | def test_from_pandas_empty_projection(): 79 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1}) 80 | df = from_pandas(pdf) 81 | assert_eq(df[[]], pdf[[]]) 82 | 83 | 84 | def test_from_pandas_divisions_duplicated(): 85 | pdf = pd.DataFrame({"a": 1}, index=[1, 2, 3, 4, 5, 5, 5, 6, 8]) 86 | df = repartition(pdf, (1, 5, 7, 10)) 87 | assert_eq(df, pdf) 88 | assert_eq(df.partitions[0], pdf.loc[1:4]) 89 | assert_eq(df.partitions[1], pdf.loc[5:6]) 90 | assert_eq(df.partitions[2], pdf.loc[8:]) 91 | 92 | 93 | @pytest.mark.parametrize("npartitions", [1, 3, 6, 7]) 94 | @pytest.mark.parametrize("sort", [True, False]) 95 | def test_from_pandas_npartitions(pdf, npartitions, sort): 96 | df = from_pandas(pdf, sort=sort, npartitions=npartitions) 97 | assert df.npartitions == min(pdf.shape[0], npartitions) 98 | assert "pandas" in df._name 99 | assert_eq(df, pdf, sort_results=sort) 100 | 101 | 102 | @pytest.mark.parametrize("chunksize,npartitions", [(1, 6), (2, 3), (6, 
1), (7, 1)]) 103 | @pytest.mark.parametrize("sort", [True, False]) 104 | def test_from_pandas_chunksize(pdf, chunksize, npartitions, sort): 105 | df = from_pandas(pdf, sort=sort, chunksize=chunksize) 106 | assert df.npartitions == npartitions 107 | assert "pandas" in df._name 108 | assert_eq(df, pdf, sort_results=sort) 109 | 110 | 111 | def test_from_pandas_npartitions_and_chunksize(pdf): 112 | with pytest.raises(ValueError, match="npartitions and chunksize"): 113 | from_pandas(pdf, npartitions=2, chunksize=3) 114 | 115 | 116 | def test_from_pandas_string_option(): 117 | pdf = pd.DataFrame({"x": [1, 2, 3], "y": "a"}, index=["a", "b", "c"]) 118 | df = from_pandas(pdf, npartitions=2) 119 | assert df.dtypes["y"] == "string" 120 | assert df.index.dtype == "string" 121 | assert df.compute().dtypes["y"] == "string" 122 | assert df.compute().index.dtype == "string" 123 | assert_eq(df, pdf) 124 | 125 | with dask.config.set({"dataframe.convert-string": False}): 126 | df = from_pandas(pdf, npartitions=2) 127 | assert df.dtypes["y"] == "object" 128 | assert df.index.dtype == "object" 129 | assert df.compute().dtypes["y"] == "object" 130 | assert df.compute().index.dtype == "object" 131 | assert_eq(df, pdf) 132 | 133 | 134 | def test_from_pandas_deepcopy(): 135 | pdf = pd.DataFrame({"col1": [1, 2, 3, 4, 5, 6]}) 136 | df = from_pandas(pdf, npartitions=3) 137 | df_dict = {"dataset": df} 138 | result = copy.deepcopy(df_dict) 139 | assert_eq(result["dataset"], pdf) 140 | 141 | 142 | def test_from_pandas_empty_chunksize(): 143 | pdf = pd.DataFrame() 144 | df = from_pandas(pdf, chunksize=10_000) 145 | assert_eq(pdf, df) 146 | -------------------------------------------------------------------------------- /dask_expr/io/tests/test_sql.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from dask.utils import tmpfile 3 | 4 | from dask_expr import from_pandas, read_sql_table 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | pd = _backend_library() 8 | 9 | pytest.importorskip("sqlalchemy") 10 | 11 | 12 | def test_shuffle_after_read_sql(): 13 | with tmpfile() as f: 14 | uri = "sqlite:///%s" % f 15 | 16 | df = pd.DataFrame( 17 | { 18 | "id": [1, 2, 3, 4, 5, 6, 7, 8], 19 | "value": [ 20 | "value1", 21 | "value2", 22 | "value3", 23 | "value3", 24 | "value4", 25 | "value4", 26 | "value4", 27 | "value5", 28 | ], 29 | } 30 | ).set_index("id") 31 | ddf = from_pandas(df, npartitions=1) 32 | 33 | ddf.to_sql("test_table", uri, if_exists="append") 34 | result = read_sql_table("test_table", con=uri, index_col="id") 35 | assert_eq( 36 | result["value"].unique(), pd.Series(df["value"].unique(), name="value") 37 | ) 38 | assert_eq(result.shuffle(on_index=True), df) 39 | -------------------------------------------------------------------------------- /dask_expr/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/tests/__init__.py -------------------------------------------------------------------------------- /dask_expr/tests/_util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import pickle 3 | 4 | import pytest 5 | from dask import config 6 | from dask.dataframe.utils import assert_eq as dd_assert_eq 7 | 8 | 9 | def _backend_name() -> str: 10 | return config.get("dataframe.backend", "pandas") 11 | 12 | 13 | def _backend_library(): 14 | 
return importlib.import_module(_backend_name()) 15 | 16 | 17 | def xfail_gpu(reason=None): 18 | condition = _backend_name() == "cudf" 19 | reason = reason or "Failure expected for cudf backend." 20 | return pytest.mark.xfail(condition, reason=reason) 21 | 22 | 23 | def assert_eq(a, b, *args, serialize_graph=True, **kwargs): 24 | if serialize_graph: 25 | # Check that no `Expr` instances are found in 26 | # the graph generated by `Expr.dask` 27 | with config.set({"dask-expr-no-serialize": True}): 28 | for obj in [a, b]: 29 | if hasattr(obj, "dask"): 30 | try: 31 | pickle.dumps(obj.dask) 32 | except AttributeError: 33 | try: 34 | import cloudpickle as cp 35 | 36 | cp.dumps(obj.dask) 37 | except ImportError: 38 | pass 39 | 40 | # Use `dask.dataframe.assert_eq` 41 | return dd_assert_eq(a, b, *args, **kwargs) 42 | -------------------------------------------------------------------------------- /dask_expr/tests/test_align_partitions.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from itertools import product 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from dask_expr import from_pandas 8 | from dask_expr._expr import OpAlignPartitions 9 | from dask_expr._repartition import RepartitionDivisions 10 | from dask_expr._shuffle import Shuffle, divisions_lru 11 | from dask_expr.tests._util import _backend_library, assert_eq 12 | 13 | # Set DataFrame backend for this module 14 | pd = _backend_library() 15 | 16 | 17 | @pytest.fixture 18 | def pdf(): 19 | pdf = pd.DataFrame({"x": range(100)}) 20 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions 21 | yield pdf 22 | 23 | 24 | @pytest.fixture 25 | def df(pdf): 26 | yield from_pandas(pdf, npartitions=10) 27 | 28 | 29 | @pytest.mark.parametrize("op", ["__add__", "add"]) 30 | def test_broadcasting_scalar(pdf, df, op): 31 | df2 = from_pandas(pdf, npartitions=2) 32 | result = getattr(df, op)(df2.x.sum()) 33 | assert_eq(result, pdf + pdf.x.sum()) 34 | assert len(list(result.expr.find_operations(OpAlignPartitions))) == 0 35 | 36 | divisions_lru.data = OrderedDict() 37 | result = getattr(df.set_index("x"), op)(df2.x.sum()) 38 | # Make sure that we don't touch divisions 39 | assert len(divisions_lru.data) == 0 40 | assert_eq(result, pdf.set_index("x") + pdf.x.sum()) 41 | assert len(list(result.expr.find_operations(OpAlignPartitions))) == 0 42 | 43 | if op == "__add__": 44 | # Can't avoid expensive alignment check, but don't touch divisions while figuring it out 45 | divisions_lru.data = OrderedDict() 46 | result = getattr(df.set_index("x"), op)(df2.set_index("x").sum()) 47 | # Make sure that we don't touch divisions 48 | assert len(divisions_lru.data) == 0 49 | assert_eq(result, pdf.set_index("x") + pdf.set_index("x").sum()) 50 | assert len(list(result.expr.find_operations(OpAlignPartitions))) > 0 51 | 52 | assert ( 53 | len( 54 | list( 55 | result.optimize(fuse=False).expr.find_operations( 56 | RepartitionDivisions 57 | ) 58 | ) 59 | ) 60 | == 0 61 | ) 62 | 63 | # Can't avoid alignment, but don't touch divisions while figuring it out 64 | divisions_lru.data = OrderedDict() 65 | result = getattr(df.set_index("x"), op)(df2.set_index("x")) 66 | # Make sure that we don't touch divisions 67 | assert len(divisions_lru.data) == 0 68 | assert_eq(result, pdf.set_index("x") + pdf.set_index("x")) 69 | assert len(list(result.expr.find_operations(OpAlignPartitions))) > 0 70 | 71 | assert ( 72 | len( 73 | 
list(result.optimize(fuse=False).expr.find_operations(RepartitionDivisions)) 74 | ) 75 | > 0 76 | ) 77 | 78 | 79 | @pytest.mark.parametrize("sorted_index", [False, True]) 80 | @pytest.mark.parametrize("sorted_map_index", [False, True]) 81 | def test_series_map(sorted_index, sorted_map_index): 82 | base = pd.Series( 83 | ["".join(np.random.choice(["a", "b", "c"], size=3)) for x in range(100)] 84 | ) 85 | if not sorted_index: 86 | index = np.arange(100) 87 | np.random.shuffle(index) 88 | base.index = index 89 | map_index = ["".join(x) for x in product("abc", repeat=3)] 90 | mapper = pd.Series(np.random.randint(50, size=len(map_index)), index=map_index) 91 | if not sorted_map_index: 92 | map_index = np.array(map_index) 93 | np.random.shuffle(map_index) 94 | mapper.index = map_index 95 | expected = base.map(mapper) 96 | dask_base = from_pandas(base, npartitions=1, sort=False) 97 | dask_map = from_pandas(mapper, npartitions=1, sort=False) 98 | result = dask_base.map(dask_map) 99 | assert_eq(expected, result) 100 | 101 | 102 | def test_assign_align_partitions(): 103 | pdf = pd.DataFrame({"x": [0] * 20, "y": range(20)}) 104 | df = from_pandas(pdf, npartitions=2) 105 | s = pd.Series(range(10, 30)) 106 | ds = from_pandas(s, npartitions=df.npartitions) 107 | result = df.assign(z=ds)[["y", "z"]] 108 | expected = pdf.assign(z=s)[["y", "z"]] 109 | assert_eq(result, expected) 110 | 111 | 112 | def test_assign_unknown_partitions(pdf): 113 | pdf2 = pdf.sort_index(ascending=False) 114 | df2 = from_pandas(pdf2, npartitions=3, sort=False) 115 | df1 = from_pandas(pdf, npartitions=3).clear_divisions() 116 | df1["new"] = df2.x 117 | expected = pdf.copy() 118 | expected["new"] = pdf2.x 119 | assert_eq(df1, expected) 120 | assert len(list(df1.optimize(fuse=False).expr.find_operations(Shuffle))) == 2 121 | 122 | pdf["c"] = "a" 123 | pdf = pdf.set_index("c") 124 | df = from_pandas(pdf, npartitions=3) 125 | df["new"] = df2.x 126 | with pytest.raises(TypeError, match="have differing dtypes"): 127 | df.optimize() 128 | -------------------------------------------------------------------------------- /dask_expr/tests/test_categorical.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr import from_pandas 4 | from dask_expr._categorical import GetCategories 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | # Set DataFrame backend for this module 8 | pd = _backend_library() 9 | 10 | 11 | @pytest.fixture 12 | def pdf(): 13 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 1, 2], "y": "bcbbbc"}) 14 | return pdf 15 | 16 | 17 | @pytest.fixture 18 | def df(pdf): 19 | yield from_pandas(pdf, npartitions=2) 20 | 21 | 22 | def test_set_categories(pdf): 23 | pdf = pdf.astype("category") 24 | df = from_pandas(pdf, npartitions=2) 25 | assert df.x.cat.known 26 | assert_eq(df.x.cat.codes, pdf.x.cat.codes) 27 | ser = df.x.cat.as_unknown() 28 | assert not ser.cat.known 29 | ser = ser.cat.as_known() 30 | assert_eq(ser.cat.categories, pd.Index([1, 2, 3, 4])) 31 | ser = ser.cat.set_categories([1, 2, 3, 5, 4]) 32 | assert_eq(ser.cat.categories, pd.Index([1, 2, 3, 5, 4])) 33 | assert not ser.cat.ordered 34 | 35 | 36 | def test_categorize(df, pdf): 37 | df = df.categorize() 38 | 39 | assert df.y.cat.known 40 | assert_eq(df, pdf.astype({"y": "category"}), check_categorical=False) 41 | 42 | 43 | def test_get_categories_simplify_adds_projection(df): 44 | optimized = GetCategories( 45 | df, columns=["y"], index=False, split_every=None 46 | ).simplify() 47 | 
expected = GetCategories( 48 | df[["y"]].simplify(), columns=["y"], index=False, split_every=None 49 | ) 50 | assert optimized._name == expected._name 51 | 52 | 53 | def test_categorical_set_index(): 54 | df = pd.DataFrame({"x": [1, 2, 3, 4], "y": ["a", "b", "b", "c"]}) 55 | df["y"] = pd.Categorical(df["y"], categories=["a", "b", "c"], ordered=True) 56 | a = from_pandas(df, npartitions=2) 57 | 58 | b = a.set_index("y", divisions=["a", "b", "c"], npartitions=a.npartitions) 59 | d1, d2 = b.get_partition(0), b.get_partition(1) 60 | assert list(d1.index.compute(fuse=False)) == ["a"] 61 | assert list(sorted(d2.index.compute())) == ["b", "b", "c"] 62 | 63 | 64 | def test_categorize_drops_category_columns(): 65 | pdf = pd.DataFrame({"a": [1, 2, 1, 2, 3], "b": 1}) 66 | df = from_pandas(pdf) 67 | df = df.categorize(columns=["a"]) 68 | result = df["b"].to_frame() 69 | expected = pdf["b"].to_frame() 70 | assert_eq(result, expected) 71 | -------------------------------------------------------------------------------- /dask_expr/tests/test_core.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr._core import Expr 4 | 5 | 6 | class ExprB(Expr): 7 | def _simplify_down(self): 8 | return ExprA() 9 | 10 | 11 | class ExprA(Expr): 12 | def _simplify_down(self): 13 | return ExprB() 14 | 15 | 16 | def test_endless_simplify(): 17 | expr = ExprA() 18 | with pytest.raises(RuntimeError, match="converge"): 19 | expr.simplify() 20 | -------------------------------------------------------------------------------- /dask_expr/tests/test_cumulative.py: -------------------------------------------------------------------------------- 1 | from dask_expr import from_pandas 2 | from dask_expr.tests._util import _backend_library, assert_eq 3 | 4 | # Set DataFrame backend for this module 5 | pd = _backend_library() 6 | 7 | 8 | def test_cumulative_empty_partitions(): 9 | pdf = pd.DataFrame( 10 | {"x": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]}, 11 | index=pd.date_range("1995-02-26", periods=8, freq="5min"), 12 | dtype=float, 13 | ) 14 | pdf2 = pdf.drop(pdf.between_time("00:10", "00:20").index) 15 | 16 | df = from_pandas(pdf, npartitions=8) 17 | df2 = from_pandas(pdf2, npartitions=1).repartition(df.divisions) 18 | 19 | assert_eq(df2.cumprod(), pdf2.cumprod()) 20 | assert_eq(df2.cumsum(), pdf2.cumsum()) 21 | -------------------------------------------------------------------------------- /dask_expr/tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import sys 3 | 4 | import pytest 5 | 6 | from dask_expr import new_collection 7 | from dask_expr._expr import Lengths 8 | from dask_expr.datasets import Timeseries, timeseries 9 | from dask_expr.tests._util import assert_eq 10 | 11 | 12 | def test_timeseries(): 13 | df = timeseries(freq="360 s", start="2000-01-01", end="2000-01-02") 14 | assert_eq(df, df) 15 | 16 | 17 | def test_optimization(): 18 | df = timeseries(dtypes={"x": int, "y": float}, seed=123) 19 | expected = timeseries(dtypes={"x": int}, seed=123) 20 | result = df[["x"]].optimize(fuse=False) 21 | assert result.expr.frame.operand("columns") == expected.expr.frame.operand( 22 | "columns" 23 | ) 24 | 25 | expected = timeseries(dtypes={"x": int}, seed=123)["x"].simplify() 26 | result = df["x"].optimize(fuse=False) 27 | assert expected.expr.frame.operand("columns") == result.expr.frame.operand( 28 | "columns" 29 | ) 30 | 31 | 32 | def test_arrow_string_option(): 33 | df = 
timeseries(dtypes={"x": object, "y": float}, seed=123) 34 | result = df.optimize(fuse=False) 35 | assert result.x.dtype == "string" 36 | assert result.x.compute().dtype == "string" 37 | 38 | 39 | def test_column_projection_deterministic(): 40 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-02", seed=123) 41 | result_id = df[["id"]].optimize() 42 | result_id_x = df[["id", "x"]].optimize() 43 | assert_eq(result_id["id"], result_id_x["id"]) 44 | 45 | 46 | def test_timeseries_culling(): 47 | df = timeseries(dtypes={"x": int, "y": float}, seed=123) 48 | pdf = df.compute() 49 | offset = len(df.partitions[0].compute()) 50 | df = (df[["x"]] + 1).partitions[1] 51 | df2 = df.optimize() 52 | 53 | # All tasks should be fused for the single output partition 54 | assert df2.npartitions == 1 55 | assert len(df2.dask) == df2.npartitions 56 | expected = pdf.iloc[offset : 2 * offset][["x"]] + 1 57 | assert_eq(df2, expected) 58 | 59 | 60 | def test_persist(): 61 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-02", seed=123) 62 | a = df["x"] 63 | b = a.persist() 64 | 65 | assert_eq(a, b) 66 | assert len(b.dask) == 2 * b.npartitions 67 | 68 | 69 | def test_lengths(): 70 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-03", seed=123) 71 | assert len(df) == sum(new_collection(Lengths(df.expr).optimize()).compute()) 72 | 73 | 74 | def test_timeseries_empty_projection(): 75 | ts = timeseries(end="2000-01-02", dtypes={}) 76 | expected = timeseries(end="2000-01-02") 77 | assert len(ts) == len(expected) 78 | 79 | 80 | def test_combine_similar(tmpdir): 81 | df = timeseries(end="2000-01-02") 82 | pdf = df.compute() 83 | got = df[df["name"] == "a"][["id"]] 84 | 85 | expected = pdf[pdf["name"] == "a"][["id"]] 86 | assert_eq(got, expected) 87 | assert_eq(got.optimize(fuse=False), expected) 88 | assert_eq(got.optimize(fuse=True), expected) 89 | 90 | # We should only have one Timeseries node, and 91 | # it should not include "z" in the dtypes 92 | timeseries_nodes = list(got.optimize(fuse=False).find_operations(Timeseries)) 93 | assert len(timeseries_nodes) == 1 94 | assert set(timeseries_nodes[0].dtypes.keys()) == {"id", "name"} 95 | 96 | df = timeseries(end="2000-01-02") 97 | df2 = timeseries(end="2000-01-02") 98 | 99 | got = df + df2 100 | timeseries_nodes = list(got.optimize(fuse=False).find_operations(Timeseries)) 101 | assert len(timeseries_nodes) == 2 102 | with pytest.raises(AssertionError): 103 | assert_eq(df + df2, 2 * df) 104 | 105 | 106 | @pytest.mark.parametrize("seed", [42, None]) 107 | def test_timeseries_deterministic_head(seed): 108 | # Make sure our `random_state` code gives 109 | # us deterministic results 110 | df = timeseries(end="2000-01-02", seed=seed) 111 | assert_eq(df.head(), df.head()) 112 | assert_eq(df["x"].head(), df.head()["x"]) 113 | assert_eq(df.head()["x"], df["x"].partitions[0].compute().head()) 114 | 115 | 116 | @pytest.mark.parametrize("seed", [42, None]) 117 | def test_timeseries_gaph_size(seed): 118 | from dask.datasets import timeseries as dd_timeseries 119 | 120 | # Check that our graph size is reasonable 121 | df = timeseries(seed=seed) 122 | ddf = dd_timeseries(seed=seed) 123 | graph_size = sys.getsizeof(pickle.dumps(df.dask)) 124 | graph_size_dd = sys.getsizeof(pickle.dumps(dict(ddf.dask))) 125 | # Make sure we are close to the dask.dataframe graph size 126 | threshold = 1.10 127 | assert graph_size < threshold * graph_size_dd 128 | 129 | 130 | def test_dataset_head(): 131 | ddf = timeseries(freq="1d") 132 | expected = ddf.compute() 133 | 
assert_eq(ddf.head(30, npartitions=-1), expected) 134 | assert_eq(ddf.head(30, npartitions=-1, compute=False), expected) 135 | -------------------------------------------------------------------------------- /dask_expr/tests/test_datetime.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr._collection import from_pandas 4 | from dask_expr.tests._util import _backend_library, assert_eq 5 | 6 | pd = _backend_library() 7 | 8 | 9 | @pytest.fixture() 10 | def ser(): 11 | return pd.Series(pd.date_range(start="2020-01-01", end="2020-03-03")) 12 | 13 | 14 | @pytest.fixture() 15 | def dser(ser): 16 | return from_pandas(ser, npartitions=3) 17 | 18 | 19 | @pytest.fixture() 20 | def ser_td(): 21 | return pd.Series(pd.timedelta_range(start="0s", end="1000s", freq="s")) 22 | 23 | 24 | @pytest.fixture() 25 | def dser_td(ser_td): 26 | return from_pandas(ser_td, npartitions=3) 27 | 28 | 29 | @pytest.fixture() 30 | def ser_pr(): 31 | return pd.Series(pd.period_range(start="2020-01-01", end="2020-12-31", freq="D")) 32 | 33 | 34 | @pytest.fixture() 35 | def dser_pr(ser_pr): 36 | return from_pandas(ser_pr, npartitions=3) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "func, args", 41 | [ 42 | ("ceil", ("D",)), 43 | ("day_name", ()), 44 | ("floor", ("D",)), 45 | ("isocalendar", ()), 46 | ("month_name", ()), 47 | ("normalize", ()), 48 | ("round", ("D",)), 49 | ("strftime", ("%B %d, %Y, %r",)), 50 | ("to_period", ("D",)), 51 | ], 52 | ) 53 | def test_datetime_accessor_methods(ser, dser, func, args): 54 | assert_eq(getattr(ser.dt, func)(*args), getattr(dser.dt, func)(*args)) 55 | 56 | 57 | @pytest.mark.parametrize( 58 | "func", 59 | [ 60 | "date", 61 | "day", 62 | "day_of_week", 63 | "day_of_year", 64 | "dayofweek", 65 | "dayofyear", 66 | "days_in_month", 67 | "daysinmonth", 68 | "hour", 69 | "is_leap_year", 70 | "is_month_end", 71 | "is_month_start", 72 | "is_quarter_end", 73 | "is_quarter_start", 74 | "is_year_end", 75 | "is_year_start", 76 | "microsecond", 77 | "minute", 78 | "month", 79 | "nanosecond", 80 | "quarter", 81 | "second", 82 | "time", 83 | "timetz", 84 | "weekday", 85 | "year", 86 | ], 87 | ) 88 | def test_datetime_accessor_properties(ser, dser, func): 89 | assert_eq(getattr(ser.dt, func), getattr(dser.dt, func)) 90 | 91 | 92 | @pytest.mark.parametrize( 93 | "func", 94 | [ 95 | "components", 96 | "days", 97 | "microseconds", 98 | "nanoseconds", 99 | "seconds", 100 | ], 101 | ) 102 | def test_timedelta_accessor_properties(ser_td, dser_td, func): 103 | assert_eq(getattr(ser_td.dt, func), getattr(dser_td.dt, func)) 104 | 105 | 106 | @pytest.mark.parametrize( 107 | "func", 108 | [ 109 | "end_time", 110 | "start_time", 111 | "qyear", 112 | "week", 113 | "weekofyear", 114 | ], 115 | ) 116 | def test_period_accessor_properties(ser_pr, dser_pr, func): 117 | assert_eq(getattr(ser_pr.dt, func), getattr(dser_pr.dt, func)) 118 | -------------------------------------------------------------------------------- /dask_expr/tests/test_describe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from dask_expr import from_pandas 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | # Set DataFrame backend for this module 8 | pd = _backend_library() 9 | 10 | 11 | @pytest.fixture 12 | def pdf(): 13 | pdf = pd.DataFrame( 14 | { 15 | "x": [None, 0, 1, 2, 3, 4] * 2, 16 | "ts": [ 17 | pd.Timestamp("2017-05-09 00:00:00.006000"), 18 | 
pd.Timestamp("2017-05-09 00:00:00.006000"), 19 | pd.Timestamp("2017-05-09 07:56:23.858694"), 20 | pd.Timestamp("2017-05-09 05:59:58.938999"), 21 | None, 22 | None, 23 | ] 24 | * 2, 25 | "td": [ 26 | np.timedelta64(3, "D"), 27 | np.timedelta64(1, "D"), 28 | None, 29 | None, 30 | np.timedelta64(3, "D"), 31 | np.timedelta64(1, "D"), 32 | ] 33 | * 2, 34 | "y": "a", 35 | } 36 | ) 37 | yield pdf 38 | 39 | 40 | @pytest.fixture 41 | def df(pdf): 42 | yield from_pandas(pdf, npartitions=2) 43 | 44 | 45 | def _drop_mean(df, col=None): 46 | """TODO: In pandas 2.0, mean is implemented for datetimes, but Dask returns None.""" 47 | if isinstance(df, pd.DataFrame): 48 | df.at["mean", col] = np.nan 49 | df.dropna(how="all", inplace=True) 50 | elif isinstance(df, pd.Series): 51 | df.drop(labels=["mean"], inplace=True, errors="ignore") 52 | else: 53 | raise NotImplementedError("Expected Series or DataFrame with mean") 54 | return df 55 | 56 | 57 | def test_describe_series(df, pdf): 58 | assert_eq(df.x.describe(), pdf.x.describe()) 59 | assert_eq(df.y.describe(), pdf.y.describe()) 60 | assert_eq(df.ts.describe(), _drop_mean(pdf.ts.describe())) 61 | assert_eq(df.td.describe(), pdf.td.describe()) 62 | 63 | 64 | def test_describe_df(df, pdf): 65 | assert_eq(df.describe(), _drop_mean(pdf.describe(), "ts")) 66 | -------------------------------------------------------------------------------- /dask_expr/tests/test_diagnostics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | 5 | import pytest 6 | 7 | pytest.importorskip("distributed") 8 | 9 | from distributed.utils_test import * # noqa 10 | 11 | from dask_expr import from_pandas 12 | from dask_expr.tests._util import _backend_library 13 | 14 | # Set DataFrame backend for this module 15 | pd = _backend_library() 16 | 17 | 18 | @pytest.fixture 19 | def pdf(): 20 | pdf = pd.DataFrame({"x": range(100)}) 21 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions 22 | yield pdf 23 | 24 | 25 | @pytest.fixture 26 | def df(pdf): 27 | yield from_pandas(pdf, npartitions=10) 28 | 29 | 30 | def test_analyze(df, client, tmpdir): 31 | pytest.importorskip("crick") 32 | filename = str(tmpdir / "analyze") 33 | expr = df.groupby(df.columns[1]).apply(lambda x: x) 34 | digraph = expr.analyze(filename=filename) 35 | assert os.path.exists(filename + ".svg") 36 | for exp in expr.optimize().walk(): 37 | assert any(exp._name in el for el in digraph.body) 38 | -------------------------------------------------------------------------------- /dask_expr/tests/test_dummies.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import pytest 3 | 4 | from dask_expr import from_pandas, get_dummies 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | pd = _backend_library() 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "data", 12 | [ 13 | pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype="category"), 14 | pd.Series( 15 | pandas.Categorical([1, 1, 1, 2, 2, 1, 3, 4], categories=[4, 3, 2, 1]) 16 | ), 17 | pd.DataFrame( 18 | {"a": [1, 2, 3, 4, 4, 3, 2, 1], "b": pandas.Categorical(list("abcdabcd"))} 19 | ), 20 | ], 21 | ) 22 | def test_get_dummies(data): 23 | exp = pd.get_dummies(data) 24 | 25 | ddata = from_pandas(data, 2) 26 | res = get_dummies(ddata) 27 | assert_eq(res, exp) 28 | pandas.testing.assert_index_equal(res.columns, exp.columns) 29 | -------------------------------------------------------------------------------- 
/dask_expr/tests/test_format.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: W291 2 | from textwrap import dedent 3 | 4 | import pytest 5 | from dask.utils import maybe_pluralize 6 | 7 | from dask_expr import from_pandas 8 | from dask_expr.tests._util import _backend_library 9 | 10 | # Set DataFrame backend for this module 11 | pd = _backend_library() 12 | 13 | 14 | def test_to_string(): 15 | pytest.importorskip("jinja2") 16 | df = pd.DataFrame( 17 | { 18 | "A": [1, 2, 3, 4, 5, 6, 7, 8], 19 | "B": list("ABCDEFGH"), 20 | "C": pd.Categorical(list("AAABBBCC")), 21 | } 22 | ) 23 | ddf = from_pandas(df, 3) 24 | 25 | exp = dedent( 26 | """\ 27 | A B C 28 | npartitions=3 29 | 0 int64 string category[known] 30 | 3 ... ... ... 31 | 6 ... ... ... 32 | 7 ... ... ...""" 33 | ) 34 | assert ddf.to_string() == exp 35 | 36 | exp_table = """ 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 |
ABC
npartitions=3
0int64stringcategory[known]
3.........
6.........
7.........
""" # noqa E222, E702 78 | footer = f"Dask Name: frompandas, {maybe_pluralize(1, 'expression')}" 79 | exp = f"""
Dask DataFrame Structure:
80 | {exp_table} 81 |
{footer}
""" 82 | assert ddf.to_html() == exp 83 | 84 | 85 | def test_series_format(): 86 | s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], index=list("ABCDEFGH")) 87 | ds = from_pandas(s, 3) 88 | 89 | exp = dedent( 90 | """\ 91 | npartitions=3 92 | A int64 93 | D ... 94 | G ... 95 | H ...""" 96 | ) 97 | assert ds.to_string() == exp 98 | 99 | 100 | def test_series_repr(): 101 | s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], index=list("ABCDEFGH")) 102 | ds = from_pandas(s, 3) 103 | 104 | exp = dedent( 105 | """\ 106 | Dask Series Structure: 107 | npartitions=3 108 | A int64 109 | D ... 110 | G ... 111 | H ... 112 | Dask Name: frompandas, 1 expression 113 | Expr=df""" 114 | ) 115 | assert repr(ds) == exp 116 | 117 | # Not a cheap way to determine if series is empty 118 | # so does not prefix with "Empty" as we do w/ empty DataFrame 119 | s = pd.Series([]) 120 | ds = from_pandas(s, 3) 121 | 122 | exp = dedent( 123 | """\ 124 | Dask Series Structure: 125 | npartitions=3 126 | string 127 | ... 128 | ... 129 | ... 130 | Dask Name: frompandas, 1 expression 131 | Expr=df""" 132 | ) 133 | assert repr(ds) == exp 134 | 135 | 136 | def test_df_repr(): 137 | df = pd.DataFrame({"col1": range(10), "col2": map(float, range(10))}) 138 | ddf = from_pandas(df, 3) 139 | 140 | exp = dedent( 141 | """\ 142 | Dask DataFrame Structure: 143 | col1 col2 144 | npartitions=3 145 | 0 int64 float64 146 | 4 ... ... 147 | 7 ... ... 148 | 9 ... ... 149 | Dask Name: frompandas, 1 expression 150 | Expr=df""" 151 | ) 152 | assert repr(ddf) == exp 153 | 154 | df = pd.DataFrame() 155 | ddf = from_pandas(df, 3) 156 | 157 | exp = dedent( 158 | """\ 159 | Empty Dask DataFrame Structure: 160 | npartitions=3 161 | 0 int64 float64 162 | 4 ... ... 163 | 7 ... ... 164 | 9 ... ... 165 | Dask Name: frompandas, 1 expression 166 | Expr=df""" 167 | ) 168 | 169 | 170 | def test_df_to_html(): 171 | df = pd.DataFrame({"col1": range(10), "col2": map(float, range(10))}) 172 | ddf = from_pandas(df, 3) 173 | 174 | exp = dedent( 175 | """\ 176 |
Dask DataFrame Structure:
177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 |
col1col2
npartitions=3
0int64float64
4......
7......
9......
213 |
Dask Name: frompandas, 1 expression
""" 214 | ) 215 | assert ddf.to_html() == exp 216 | assert ddf._repr_html_() == exp # for jupyter 217 | -------------------------------------------------------------------------------- /dask_expr/tests/test_fusion.py: -------------------------------------------------------------------------------- 1 | import dask.dataframe as dd 2 | import pytest 3 | 4 | from dask_expr import from_pandas, optimize 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | # Set DataFrame backend for this module 8 | pd = _backend_library() 9 | 10 | 11 | @pytest.fixture 12 | def pdf(): 13 | pdf = pd.DataFrame({"x": range(100)}) 14 | pdf["y"] = pdf.x * 10.0 15 | yield pdf 16 | 17 | 18 | @pytest.fixture 19 | def df(pdf): 20 | yield from_pandas(pdf, npartitions=10) 21 | 22 | 23 | def test_simple(df): 24 | out = (df["x"] + df["y"]) - 1 25 | unfused = optimize(out, fuse=False) 26 | fused = optimize(out, fuse=True) 27 | 28 | # Should only get one task per partition 29 | # from_pandas is not fused together 30 | assert len(fused.dask) == df.npartitions + 10 31 | assert_eq(fused, unfused) 32 | 33 | 34 | def test_with_non_fusable_on_top(df): 35 | out = (df["x"] + df["y"] - 1).sum() 36 | unfused = optimize(out, fuse=False) 37 | fused = optimize(out, fuse=True) 38 | 39 | assert len(fused.dask) < len(unfused.dask) 40 | assert_eq(fused, unfused) 41 | 42 | # Check that we still get fusion 43 | # after a non-blockwise operation as well 44 | fused_2 = optimize((out + 10) - 5, fuse=True) 45 | assert len(fused_2.dask) == len(fused.dask) + 1 # only one more task 46 | 47 | 48 | def test_optimize_fusion_many(): 49 | # Test that many `Blockwise`` operations, 50 | # originating from various IO operations, 51 | # can all be fused together 52 | a = from_pandas(pd.DataFrame({"x": range(100), "y": range(100)}), 10) 53 | b = from_pandas(pd.DataFrame({"a": range(100)}), 10) 54 | 55 | # some generic elemwise operations 56 | aa = a[["x"]] + 1 57 | aa["a"] = a["y"] + a["x"] 58 | aa["b"] = aa["x"] + 2 59 | series_a = aa[a["x"] > 1]["b"] 60 | 61 | bb = b[["a"]] + 1 62 | bb["b"] = b["a"] + b["a"] 63 | series_b = bb[b["a"] > 1]["b"] 64 | 65 | result = (series_a + series_b) + 1 66 | fused = optimize(result, fuse=True) 67 | unfused = optimize(result, fuse=False) 68 | assert fused.npartitions == a.npartitions 69 | # from_pandas is not fused together 70 | assert len(fused.dask) == fused.npartitions + 20 71 | assert_eq(fused, unfused) 72 | 73 | 74 | def test_optimize_fusion_repeat(df): 75 | # Test that we can optimize a collection 76 | # more than once, and fusion still works 77 | 78 | original = df.copy() 79 | 80 | # some generic elemwise operations 81 | df["x"] += 1 82 | df["z"] = df.y 83 | df += 2 84 | 85 | # repeatedly call optimize after doing new fusable things 86 | fused = optimize(optimize(optimize(df) + 2).x) 87 | 88 | # from_pandas is not fused together 89 | assert len(fused.dask) - 10 == fused.npartitions == original.npartitions 90 | assert_eq(fused, df.x + 2) 91 | 92 | 93 | def test_optimize_fusion_broadcast(df): 94 | # Check fusion with broadcated reduction 95 | result = ((df["x"] + 1) + df["y"].sum()) + 1 96 | fused = optimize(result) 97 | 98 | assert_eq(fused, result) 99 | assert len(fused.dask) < len(result.dask) 100 | 101 | 102 | def test_persist_with_fusion(df): 103 | # Check that fusion works after persisting 104 | df = (df + 2).persist() 105 | out = (df.y + 1).sum() 106 | fused = optimize(out) 107 | 108 | assert_eq(out, fused) 109 | assert len(fused.dask) < len(out.dask) 110 | 111 | 112 | def 
test_fuse_broadcast_deps(): 113 | pdf = pd.DataFrame({"a": [1, 2, 3]}) 114 | pdf2 = pd.DataFrame({"a": [2, 3, 4]}) 115 | pdf3 = pd.DataFrame({"a": [3, 4, 5]}) 116 | df = from_pandas(pdf, npartitions=1) 117 | df2 = from_pandas(pdf2, npartitions=1) 118 | df3 = from_pandas(pdf3, npartitions=2) 119 | 120 | query = df.merge(df2).merge(df3) 121 | # from_pandas is not fused together 122 | assert len(query.optimize().__dask_graph__()) == 2 + 4 123 | assert_eq(query, pdf.merge(pdf2).merge(pdf3)) 124 | 125 | 126 | def test_name(df): 127 | out = (df["x"] + df["y"]) - 1 128 | fused = optimize(out, fuse=True) 129 | assert "getitem" in str(fused.expr) 130 | assert "sub" in str(fused.expr) 131 | assert str(fused.expr) == str(fused.expr).lower() 132 | 133 | 134 | def test_fusion_executes_only_once(): 135 | times_called = [] 136 | import pandas as pd 137 | 138 | def test(i): 139 | times_called.append(i) 140 | return pd.DataFrame({"a": [1, 2, 3], "b": 1}) 141 | 142 | df = dd.from_map(test, [1], meta=[("a", "i8"), ("b", "i8")]) 143 | df = df[df.a > 1] 144 | df.sum().compute() 145 | assert len(times_called) == 1 146 | -------------------------------------------------------------------------------- /dask_expr/tests/test_indexing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from dask_expr import from_pandas 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | pd = _backend_library() 8 | 9 | 10 | @pytest.fixture 11 | def pdf(): 12 | pdf = pd.DataFrame({"x": range(20)}) 13 | pdf["y"] = pdf.x * 10.0 14 | yield pdf 15 | 16 | 17 | @pytest.fixture 18 | def df(pdf): 19 | yield from_pandas(pdf, npartitions=4) 20 | 21 | 22 | def test_iloc(df, pdf): 23 | assert_eq(df.iloc[:, 1], pdf.iloc[:, 1]) 24 | assert_eq(df.iloc[:, [1]], pdf.iloc[:, [1]]) 25 | assert_eq(df.iloc[:, [0, 1]], pdf.iloc[:, [0, 1]]) 26 | assert_eq(df.iloc[:, []], pdf.iloc[:, []]) 27 | 28 | 29 | def test_iloc_errors(df): 30 | with pytest.raises(NotImplementedError): 31 | df.iloc[1] 32 | with pytest.raises(NotImplementedError): 33 | df.iloc[1, 1] 34 | with pytest.raises(ValueError, match="Too many"): 35 | df.iloc[(1, 2, 3)] 36 | 37 | 38 | def test_loc_slice(pdf, df): 39 | pdf.columns = [10, 20] 40 | df.columns = [10, 20] 41 | assert_eq(df.loc[:, :15], pdf.loc[:, :15]) 42 | assert_eq(df.loc[:, 15:], pdf.loc[:, 15:]) 43 | assert_eq(df.loc[:, 25:], pdf.loc[:, 25:]) # no columns 44 | assert_eq(df.loc[:, ::-1], pdf.loc[:, ::-1]) 45 | 46 | 47 | def test_iloc_slice(df, pdf): 48 | assert_eq(df.iloc[:, :1], pdf.iloc[:, :1]) 49 | assert_eq(df.iloc[:, 1:], pdf.iloc[:, 1:]) 50 | assert_eq(df.iloc[:, 99:], pdf.iloc[:, 99:]) # no columns 51 | assert_eq(df.iloc[:, ::-1], pdf.iloc[:, ::-1]) 52 | 53 | 54 | @pytest.mark.parametrize("loc", [False, True]) 55 | @pytest.mark.parametrize("update", [False, True]) 56 | def test_columns_dtype_on_empty_slice(df, pdf, loc, update): 57 | pdf.columns = [10, 20] 58 | if update: 59 | df.columns = [10, 20] 60 | else: 61 | df = from_pandas(pdf, npartitions=10) 62 | 63 | assert df.columns.dtype == pdf.columns.dtype 64 | assert df.compute().columns.dtype == pdf.columns.dtype 65 | assert_eq(df, pdf) 66 | 67 | if loc: 68 | df = df.loc[:, []] 69 | pdf = pdf.loc[:, []] 70 | else: 71 | df = df[[]] 72 | pdf = pdf[[]] 73 | 74 | assert df.columns.dtype == pdf.columns.dtype 75 | assert df.compute().columns.dtype == pdf.columns.dtype 76 | assert_eq(df, pdf) 77 | 78 | 79 | def test_loc(df, pdf): 80 | assert_eq(df.loc[:, "x"], pdf.loc[:, "x"]) 81 
| assert_eq(df.loc[:, ["x"]], pdf.loc[:, ["x"]]) 82 | assert_eq(df.loc[:, []], pdf.loc[:, []]) 83 | 84 | assert_eq(df.loc[df.y == 20, "x"], pdf.loc[pdf.y == 20, "x"]) 85 | assert_eq(df.loc[df.y == 20, ["x"]], pdf.loc[pdf.y == 20, ["x"]]) 86 | assert df.loc[3:8].divisions[0] == 3 87 | assert df.loc[3:8].divisions[-1] == 8 88 | 89 | assert df.loc[5].divisions == (5, 5) 90 | 91 | assert_eq(df.loc[5], pdf.loc[5:5]) 92 | assert_eq(df.loc[3:8], pdf.loc[3:8]) 93 | assert_eq(df.loc[:8], pdf.loc[:8]) 94 | assert_eq(df.loc[3:], pdf.loc[3:]) 95 | assert_eq(df.loc[[5]], pdf.loc[[5]]) 96 | 97 | assert_eq(df.x.loc[5], pdf.x.loc[5:5]) 98 | assert_eq(df.x.loc[3:8], pdf.x.loc[3:8]) 99 | assert_eq(df.x.loc[:8], pdf.x.loc[:8]) 100 | assert_eq(df.x.loc[3:], pdf.x.loc[3:]) 101 | assert_eq(df.x.loc[[5]], pdf.x.loc[[5]]) 102 | assert_eq(df.x.loc[[]], pdf.x.loc[[]]) 103 | assert_eq(df.x.loc[np.array([])], pdf.x.loc[np.array([])]) 104 | 105 | pytest.raises(KeyError, lambda: df.loc[1000]) 106 | assert_eq(df.loc[1000:], pdf.loc[1000:]) 107 | assert_eq(df.loc[1000:2000], pdf.loc[1000:2000]) 108 | assert_eq(df.loc[:-1000], pdf.loc[:-1000]) 109 | assert_eq(df.loc[-2000:-1000], pdf.loc[-2000:-1000]) 110 | 111 | 112 | def test_loc_non_informative_index(): 113 | df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40]) 114 | ddf = from_pandas(df, npartitions=2, sort=True).clear_divisions() 115 | assert not ddf.known_divisions 116 | 117 | ddf.loc[20:30].compute(scheduler="sync") 118 | 119 | assert_eq(ddf.loc[20:30], df.loc[20:30]) 120 | 121 | df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 20, 40]) 122 | ddf = from_pandas(df, npartitions=2, sort=True) 123 | assert_eq(ddf.loc[20], df.loc[20:20]) 124 | 125 | 126 | def test_loc_with_series(df, pdf): 127 | assert_eq(df.loc[df.x % 2 == 0], pdf.loc[pdf.x % 2 == 0]) 128 | 129 | 130 | def test_loc_with_array(df, pdf): 131 | assert_eq(df.loc[(df.x % 2 == 0).values], pdf.loc[(pdf.x % 2 == 0).values]) 132 | 133 | 134 | def test_loc_with_function(df, pdf): 135 | assert_eq(df.loc[lambda df: df["x"] > 3, :], pdf.loc[lambda df: df["x"] > 3, :]) 136 | 137 | def _col_loc_fun(_df): 138 | return _df.columns.str.contains("y") 139 | 140 | assert_eq(df.loc[:, _col_loc_fun], pdf.loc[:, _col_loc_fun]) 141 | 142 | 143 | def test_getitem_align(): 144 | df = pd.DataFrame( 145 | { 146 | "A": [1, 2, 3, 4, 5, 6, 7, 8, 9], 147 | "B": [9, 8, 7, 6, 5, 4, 3, 2, 1], 148 | "C": [True, False, True] * 3, 149 | }, 150 | columns=list("ABC"), 151 | ) 152 | ddf = from_pandas(df, 2) 153 | assert_eq(ddf[ddf.C.repartition([0, 2, 5, 8])], df[df.C]) 154 | 155 | 156 | def test_loc_bool_cindex(): 157 | # https://github.com/dask/dask/issues/11015 158 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) 159 | ddf = from_pandas(pdf, npartitions=1) 160 | indexer = [True, False] 161 | assert_eq(pdf.loc[:, indexer], ddf.loc[:, indexer]) 162 | 163 | 164 | def test_loc_slicing(): 165 | npartitions = 10 166 | pdf = pd.DataFrame( 167 | { 168 | "A": np.random.randn(npartitions * 10), 169 | }, 170 | index=pd.date_range("2024-01-01", "2024-12-31", npartitions * 10), 171 | ) 172 | df = from_pandas(pdf, npartitions=npartitions) 173 | result = df["2024-03-01":"2024-09-30"]["A"] 174 | assert_eq(result, pdf["2024-03-01":"2024-09-30"]["A"]) 175 | 176 | 177 | def test_indexing_element_index(): 178 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) 179 | result = from_pandas(pdf, 2).loc[2].index 180 | pd.testing.assert_index_equal(result.compute(), pdf.loc[[2]].index) 181 | 182 | result = from_pandas(pdf, 2).loc[[2]].index 
183 | pd.testing.assert_index_equal(result.compute(), pdf.loc[[2]].index) 184 | -------------------------------------------------------------------------------- /dask_expr/tests/test_interchange.py: -------------------------------------------------------------------------------- 1 | from pandas.core.interchange.dataframe_protocol import DtypeKind 2 | 3 | from dask_expr import from_pandas 4 | from dask_expr.tests._util import _backend_library 5 | 6 | # Set DataFrame backend for this module 7 | pd = _backend_library() 8 | 9 | 10 | def test_interchange_protocol(): 11 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1}) 12 | df = from_pandas(pdf, npartitions=2) 13 | df_int = df.__dataframe__() 14 | pd.testing.assert_index_equal(pdf.columns, df_int.column_names()) 15 | assert df_int.num_columns() == 2 16 | assert df_int.num_rows() == 3 17 | column = df_int.get_columns()[0] 18 | assert column.dtype()[0] == DtypeKind.INT 19 | -------------------------------------------------------------------------------- /dask_expr/tests/test_merge_asof.py: -------------------------------------------------------------------------------- 1 | from dask_expr import from_pandas, merge_asof 2 | from dask_expr.tests._util import _backend_library, assert_eq 3 | 4 | pd = _backend_library() 5 | 6 | 7 | def test_merge_asof_indexed(): 8 | A = pd.DataFrame( 9 | {"left_val": list("abcd" * 3)}, 10 | index=[1, 3, 7, 9, 10, 13, 14, 17, 20, 24, 25, 28], 11 | ) 12 | a = from_pandas(A, npartitions=4) 13 | B = pd.DataFrame( 14 | {"right_val": list("xyz" * 4)}, 15 | index=[1, 2, 3, 6, 7, 10, 12, 14, 16, 19, 23, 26], 16 | ) 17 | b = from_pandas(B, npartitions=3) 18 | 19 | C = pd.merge_asof(A, B, left_index=True, right_index=True) 20 | c = merge_asof(a, b, left_index=True, right_index=True) 21 | 22 | assert_eq(c, C) 23 | 24 | 25 | def test_merge_asof_on_basic(): 26 | A = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) 27 | a = from_pandas(A, npartitions=2) 28 | B = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) 29 | b = from_pandas(B, npartitions=2) 30 | 31 | C = pd.merge_asof(A, B, on="a") 32 | c = merge_asof(a, b, on="a") 33 | # merge_asof does not preserve index 34 | assert_eq(c, C, check_index=False) 35 | 36 | 37 | def test_merge_asof_one_partition(): 38 | left = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) 39 | right = pd.DataFrame({"a": [1, 2, 3], "c": [4, 5, 6]}) 40 | 41 | ddf_left = from_pandas(left, npartitions=1) 42 | ddf_left = ddf_left.set_index("a", sort=True) 43 | ddf_right = from_pandas(right, npartitions=1) 44 | ddf_right = ddf_right.set_index("a", sort=True) 45 | 46 | result = merge_asof( 47 | ddf_left, ddf_right, left_index=True, right_index=True, direction="nearest" 48 | ) 49 | expected = pd.merge_asof( 50 | left.set_index("a"), 51 | right.set_index("a"), 52 | left_index=True, 53 | right_index=True, 54 | direction="nearest", 55 | ) 56 | assert_eq(result, expected) 57 | -------------------------------------------------------------------------------- /dask_expr/tests/test_predicate_pushdown.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr import from_pandas 4 | from dask_expr._expr import rewrite_filters 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | pd = _backend_library() 8 | 9 | 10 | @pytest.fixture 11 | def pdf(): 12 | pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 2}) 13 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions 14 | yield pdf 15 | 16 | 17 | 
@pytest.fixture 18 | def df(pdf): 19 | yield from_pandas(pdf, npartitions=10) 20 | 21 | 22 | def test_rewrite_filters(df): 23 | predicate = (df.x == 1) | (df.x == 1) 24 | expected = df.x == 1 25 | assert rewrite_filters(predicate.expr)._name == expected._name 26 | 27 | predicate = (df.x == 1) | ((df.x == 1) & (df.y == 2)) 28 | expected = df.x == 1 29 | assert rewrite_filters(predicate.expr)._name == expected._name 30 | 31 | predicate = ((df.x == 1) & (df.y == 3)) | ((df.x == 1) & (df.y == 2)) 32 | expected = (df.x == 1) & ((df.y == 3) | (df.y == 2)) 33 | assert rewrite_filters(predicate.expr)._name == expected._name 34 | 35 | predicate = ((df.x == 1) & (df.y == 3) & (df.a == 1)) | ( 36 | (df.x == 1) & (df.y == 2) & (df.a == 1) 37 | ) 38 | expected = ((df.x == 1) & (df.a == 1)) & ((df.y == 3) | (df.y == 2)) 39 | assert rewrite_filters(predicate.expr)._name == expected._name 40 | 41 | predicate = (df.x == 1) | (df.y == 1) 42 | assert rewrite_filters(predicate.expr)._name == predicate._name 43 | 44 | predicate = df.x == 1 45 | assert rewrite_filters(predicate.expr)._name == predicate._name 46 | 47 | predicate = (df.x.isin([1, 2, 3]) & (df.y == 3)) | ( 48 | df.x.isin([1, 2, 3]) & (df.y == 2) 49 | ) 50 | expected = df.x.isin([1, 2, 3]) & ((df.y == 3) | (df.y == 2)) 51 | assert rewrite_filters(predicate.expr)._name == expected._name 52 | 53 | 54 | def test_rewrite_filters_query(df, pdf): 55 | result = df[((df.x == 1) & (df.y > 1)) | ((df.x == 1) & (df.y > 2))] 56 | result = result[["x"]] 57 | expected = pdf[((pdf.x == 1) & (pdf.y > 1)) | ((pdf.x == 1) & (pdf.y > 2))] 58 | expected = expected[["x"]] 59 | assert_eq(result, expected) 60 | -------------------------------------------------------------------------------- /dask_expr/tests/test_quantiles.py: -------------------------------------------------------------------------------- 1 | from dask_expr import from_pandas 2 | from dask_expr.tests._util import _backend_library, assert_eq 3 | 4 | # Set DataFrame backend for this module 5 | pd = _backend_library() 6 | 7 | 8 | def test_repartition_quantiles(): 9 | pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 15, 7, 8, 9, 10, 11], "d": 3}) 10 | df = from_pandas(pdf, npartitions=5) 11 | result = df.a._repartition_quantiles(npartitions=5) 12 | expected = pd.Series( 13 | [1, 1, 3, 7, 9, 15], index=[0, 0.2, 0.4, 0.6, 0.8, 1], name="a" 14 | ) 15 | assert_eq(result, expected, check_exact=False) 16 | 17 | result = df.a._repartition_quantiles(npartitions=4) 18 | expected = pd.Series([1, 2, 5, 8, 15], index=[0, 0.25, 0.5, 0.75, 1], name="a") 19 | assert_eq(result, expected, check_exact=False) 20 | -------------------------------------------------------------------------------- /dask_expr/tests/test_repartition.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from dask_expr import Repartition, from_pandas, repartition 5 | from dask_expr.tests._util import _backend_library, assert_eq 6 | 7 | pd = _backend_library() 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "kwargs", 12 | [ 13 | {"npartitions": 2}, 14 | {"npartitions": 4}, 15 | {"divisions": (0, 1, 79)}, 16 | {"partition_size": "1kb"}, 17 | ], 18 | ) 19 | def test_repartition_combine_similar(kwargs): 20 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2}) 21 | df = from_pandas(pdf, npartitions=3) 22 | query = df.repartition(**kwargs) 23 | query["new"] = query.x + query.y 24 | result = query.optimize(fuse=False) 25 | 26 | expected = 
df.repartition(**kwargs).optimize(fuse=False) 27 | arg1 = expected.x 28 | arg2 = expected.y 29 | expected["new"] = arg1 + arg2 30 | assert result._name == expected._name 31 | 32 | expected_pdf = pdf.copy() 33 | expected_pdf["new"] = expected_pdf.x + expected_pdf.y 34 | assert_eq(result, expected_pdf) 35 | 36 | 37 | @pytest.mark.parametrize("type_ctor", [lambda o: o, tuple, list]) 38 | def test_repartition_noop(type_ctor): 39 | pdf = pd.DataFrame({"x": [1, 2, 4, 5], "y": [6, 7, 8, 9]}, index=[-1, 0, 2, 7]) 40 | df = from_pandas(pdf, npartitions=2) 41 | ds = df.x 42 | 43 | def assert_not_repartitions(expr, fuse=False): 44 | repartitions = [ 45 | x for x in expr.optimize(fuse=fuse).walk() if isinstance(x, Repartition) 46 | ] 47 | assert len(repartitions) == 0 48 | 49 | # DataFrame method 50 | df2 = df.repartition(divisions=type_ctor(df.divisions)) 51 | assert_not_repartitions(df2.expr) 52 | 53 | # Top-level dask.dataframe method 54 | df3 = repartition(df, divisions=type_ctor(df.divisions)) 55 | assert_not_repartitions(df3.expr) 56 | 57 | # Series method 58 | ds2 = ds.repartition(divisions=type_ctor(ds.divisions)) 59 | assert_not_repartitions(ds2.expr) 60 | 61 | # Top-level dask.dataframe method applied to a Series 62 | ds3 = repartition(ds, divisions=type_ctor(ds.divisions)) 63 | assert_not_repartitions(ds3.expr) 64 | 65 | 66 | def test_repartition_freq(): 67 | ts = pd.date_range("2015-01-01 00:00", "2015-05-01 23:50", freq="10min") 68 | pdf = pd.DataFrame( 69 | np.random.randint(0, 100, size=(len(ts), 4)), columns=list("ABCD"), index=ts 70 | ) 71 | df = from_pandas(pdf, npartitions=1).repartition(freq="MS") 72 | 73 | assert_eq(df, pdf) 74 | 75 | assert df.divisions == ( 76 | pd.Timestamp("2015-1-1 00:00:00"), 77 | pd.Timestamp("2015-2-1 00:00:00"), 78 | pd.Timestamp("2015-3-1 00:00:00"), 79 | pd.Timestamp("2015-4-1 00:00:00"), 80 | pd.Timestamp("2015-5-1 00:00:00"), 81 | pd.Timestamp("2015-5-1 23:50:00"), 82 | ) 83 | 84 | assert df.npartitions == 5 85 | 86 | 87 | def test_repartition_freq_errors(): 88 | pdf = pd.DataFrame({"x": [1, 2, 3]}) 89 | df = from_pandas(pdf, npartitions=1) 90 | with pytest.raises(TypeError, match="for timeseries"): 91 | df.repartition(freq="1s") 92 | 93 | 94 | def test_repartition_npartitions_numeric_edge_case(): 95 | """ 96 | Test that we cover numeric edge cases when 97 | int(ddf.npartitions / npartitions) * npartitions) != ddf.npartitions 98 | """ 99 | df = pd.DataFrame({"x": range(100)}) 100 | a = from_pandas(df, npartitions=15) 101 | assert a.npartitions == 15 102 | b = a.repartition(npartitions=11) 103 | assert_eq(a, b) 104 | 105 | 106 | def test_repartition_empty_partitions_dtype(): 107 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8]}) 108 | df = from_pandas(pdf, npartitions=4) 109 | assert_eq( 110 | df[df.x < 5].repartition(npartitions=1), 111 | pdf[pdf.x < 5], 112 | ) 113 | 114 | 115 | def test_repartition_filter_pushdown(): 116 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2}) 117 | df = from_pandas(pdf, npartitions=10) 118 | result = df.repartition(npartitions=5) 119 | result = result[result.x > 5.0] 120 | expected = df[df.x > 5.0].repartition(npartitions=5) 121 | assert result.simplify()._name == expected._name 122 | 123 | result = df.repartition(npartitions=5) 124 | result = result[result.x > 5.0][["x", "y"]] 125 | expected = df[["x", "y"]] 126 | expected = expected[expected.x > 5.0].repartition(npartitions=5) 127 | assert result.simplify()._name == expected.simplify()._name 128 | 129 | result = 
df.repartition(npartitions=5)[["x", "y"]] 130 | result = result[result.x > 5.0] 131 | expected = df[["x", "y"]] 132 | expected = expected[expected.x > 5.0].repartition(npartitions=5) 133 | assert result.simplify()._name == expected.simplify()._name 134 | 135 | 136 | def test_repartition_unknown_divisions(): 137 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2}) 138 | df = from_pandas(pdf, npartitions=5).clear_divisions() 139 | with pytest.raises( 140 | ValueError, match="Cannot repartition on divisions with unknown divisions" 141 | ): 142 | df.repartition(divisions=(0, 100)).compute() 143 | -------------------------------------------------------------------------------- /dask_expr/tests/test_resample.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import pytest 4 | 5 | from dask_expr import from_pandas 6 | from dask_expr.tests._util import _backend_library, assert_eq 7 | 8 | # Set DataFrame backend for this module 9 | pd = _backend_library() 10 | 11 | 12 | def resample(df, freq, how="mean", **kwargs): 13 | return getattr(df.resample(freq, **kwargs), how)() 14 | 15 | 16 | @pytest.fixture 17 | def pdf(): 18 | idx = pd.date_range("2000-01-01", periods=12, freq="min") 19 | pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx) 20 | pdf["bar"] = 1 21 | yield pdf 22 | 23 | 24 | @pytest.fixture 25 | def df(pdf): 26 | yield from_pandas(pdf, npartitions=4) 27 | 28 | 29 | @pytest.mark.parametrize("kwargs", [{}, {"closed": "left"}]) 30 | @pytest.mark.parametrize( 31 | "api", 32 | [ 33 | "count", 34 | "prod", 35 | "mean", 36 | "sum", 37 | "min", 38 | "max", 39 | "first", 40 | "last", 41 | "var", 42 | "std", 43 | "size", 44 | "nunique", 45 | "median", 46 | "quantile", 47 | "ohlc", 48 | "sem", 49 | ], 50 | ) 51 | def test_resample_apis(df, pdf, api, kwargs): 52 | result = getattr(df.resample("2min", **kwargs), api)() 53 | expected = getattr(pdf.resample("2min", **kwargs), api)() 54 | assert_eq(result, expected) 55 | 56 | # No column output 57 | if api not in ("size",): 58 | result = getattr(df.resample("2min"), api)()["foo"] 59 | expected = getattr(pdf.resample("2min"), api)()["foo"] 60 | assert_eq(result, expected) 61 | 62 | if api != "ohlc": 63 | # ohlc actually gives back a DataFrame, so this doesn't work 64 | q = result.simplify() 65 | eq = getattr(df["foo"].resample("2min"), api)().simplify() 66 | assert q._name == eq._name 67 | 68 | 69 | @pytest.mark.parametrize( 70 | ["obj", "method", "npartitions", "freq", "closed", "label"], 71 | list( 72 | product( 73 | ["series", "frame"], 74 | ["count", "mean", "ohlc"], 75 | [2, 5], 76 | ["30min", "h", "d", "W"], 77 | ["right", "left"], 78 | ["right", "left"], 79 | ) 80 | ), 81 | ) 82 | def test_series_resample(obj, method, npartitions, freq, closed, label): 83 | index = pd.date_range("1-1-2000", "2-15-2000", freq="h") 84 | index = index.union(pd.date_range("4-15-2000", "5-15-2000", freq="h")) 85 | if obj == "series": 86 | ps = pd.Series(range(len(index)), index=index) 87 | elif obj == "frame": 88 | ps = pd.DataFrame({"a": range(len(index))}, index=index) 89 | ds = from_pandas(ps, npartitions=npartitions) 90 | # Series output 91 | 92 | result = resample(ds, freq, how=method, closed=closed, label=label) 93 | expected = resample(ps, freq, how=method, closed=closed, label=label) 94 | 95 | assert_eq(result, expected, check_dtype=False) 96 | 97 | divisions = result.divisions 98 | 99 | assert expected.index[0] == divisions[0] 100 | assert expected.index[-1] == 
divisions[-1] 101 | 102 | 103 | def test_resample_agg(df, pdf): 104 | def my_sum(vals, foo=None, *, bar=None): 105 | return vals.sum() 106 | 107 | result = df.resample("2min").agg(my_sum, "foo", bar="bar") 108 | expected = pdf.resample("2min").agg(my_sum, "foo", bar="bar") 109 | assert_eq(result, expected) 110 | 111 | result = df.resample("2min").agg(my_sum)["foo"] 112 | expected = pdf.resample("2min").agg(my_sum)["foo"] 113 | assert_eq(result, expected) 114 | 115 | # simplify up disabled for `agg`, function may access other columns 116 | q = df.resample("2min").agg(my_sum)["foo"].simplify() 117 | eq = df["foo"].resample("2min").agg(my_sum).simplify() 118 | assert q._name != eq._name 119 | 120 | 121 | @pytest.mark.parametrize("method", ["count", "nunique", "size", "sum"]) 122 | def test_resample_has_correct_fill_value(method): 123 | index = pd.date_range("2000-01-01", "2000-02-15", freq="h") 124 | index = index.union(pd.date_range("4-15-2000", "5-15-2000", freq="h")) 125 | ps = pd.Series(range(len(index)), index=index) 126 | ds = from_pandas(ps, npartitions=2) 127 | 128 | assert_eq( 129 | getattr(ds.resample("30min"), method)(), getattr(ps.resample("30min"), method)() 130 | ) 131 | 132 | 133 | def test_resample_divisions_propagation(): 134 | idx = pd.date_range(start="10:00:00.873821", end="10:05:00", freq="0.002s") 135 | pdf = pd.DataFrame({"data": 1}, index=idx) 136 | df = from_pandas(pdf, npartitions=10) 137 | result = df.resample("0.03s").mean() 138 | result = result.repartition(freq="1d") 139 | expected = pdf.resample("0.03s").mean() 140 | assert_eq(result, expected) 141 | 142 | result = df.resample("0.03s").mean().partitions[1] 143 | expected = pdf.resample("0.03s").mean()[997 : 2 * 997] 144 | assert_eq(result, expected) 145 | -------------------------------------------------------------------------------- /dask_expr/tests/test_reshape.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr import from_pandas, pivot_table 4 | from dask_expr.tests._util import _backend_library, assert_eq 5 | 6 | # Set DataFrame backend for this module 7 | pd = _backend_library() 8 | 9 | 10 | @pytest.fixture 11 | def pdf(): 12 | pdf = pd.DataFrame( 13 | { 14 | "x": [1, 2, 3, 4, 5, 6], 15 | "y": pd.Series([4, 5, 8, 6, 1, 4], dtype="category"), 16 | "z": [4, 15, 8, 16, 1, 14], 17 | "a": 1, 18 | } 19 | ) 20 | yield pdf 21 | 22 | 23 | @pytest.fixture 24 | def df(pdf): 25 | yield from_pandas(pdf, npartitions=3) 26 | 27 | 28 | @pytest.mark.parametrize("aggfunc", ["first", "last", "sum", "mean", "count"]) 29 | def test_pivot_table(df, pdf, aggfunc): 30 | assert_eq( 31 | df.pivot_table(index="x", columns="y", values="z", aggfunc=aggfunc), 32 | pdf.pivot_table( 33 | index="x", columns="y", values="z", aggfunc=aggfunc, observed=False 34 | ), 35 | check_dtype=aggfunc != "count", 36 | ) 37 | 38 | assert_eq( 39 | df.pivot_table(index="x", columns="y", values=["z", "a"], aggfunc=aggfunc), 40 | pdf.pivot_table( 41 | index="x", columns="y", values=["z", "a"], aggfunc=aggfunc, observed=False 42 | ), 43 | check_dtype=aggfunc != "count", 44 | ) 45 | 46 | assert_eq( 47 | pivot_table(df, index="x", columns="y", values=["z", "a"], aggfunc=aggfunc), 48 | pdf.pivot_table( 49 | index="x", columns="y", values=["z", "a"], aggfunc=aggfunc, observed=False 50 | ), 51 | check_dtype=aggfunc != "count", 52 | ) 53 | 54 | 55 | def test_pivot_table_fails(df): 56 | with pytest.raises(ValueError, match="must be the name of an existing column"): 57 | 
df.pivot_table(index="aaa", columns="y", values="z") 58 | with pytest.raises(ValueError, match="must be the name of an existing column"): 59 | df.pivot_table(index=["a"], columns="y", values="z") 60 | 61 | with pytest.raises(ValueError, match="must be the name of an existing column"): 62 | df.pivot_table(index="a", columns="xxx", values="z") 63 | with pytest.raises(ValueError, match="must be the name of an existing column"): 64 | df.pivot_table(index="a", columns=["x"], values="z") 65 | 66 | with pytest.raises(ValueError, match="'columns' must be category dtype"): 67 | df.pivot_table(index="a", columns="x", values="z") 68 | 69 | df2 = df.copy() 70 | df2["y"] = df2.y.cat.as_unknown() 71 | with pytest.raises(ValueError, match="'columns' must have known categories"): 72 | df2.pivot_table(index="a", columns="y", values="z") 73 | 74 | with pytest.raises( 75 | ValueError, match="'values' must refer to an existing column or columns" 76 | ): 77 | df.pivot_table(index="x", columns="y", values="aaa") 78 | 79 | msg = "aggfunc must be either 'mean', 'sum', 'count', 'first', 'last'" 80 | with pytest.raises(ValueError, match=msg): 81 | pivot_table(df, index="x", columns="y", values="z", aggfunc=["sum"]) 82 | -------------------------------------------------------------------------------- /dask_expr/tests/test_rolling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from dask_expr import from_pandas 6 | from dask_expr.tests._util import _backend_library, assert_eq 7 | 8 | # Set DataFrame backend for this module 9 | pd = _backend_library() 10 | 11 | 12 | @pytest.fixture 13 | def pdf(): 14 | idx = pd.date_range("2000-01-01", periods=12, freq="min") 15 | pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx) 16 | pdf["bar"] = 1 17 | yield pdf 18 | 19 | 20 | @pytest.fixture 21 | def df(pdf, request): 22 | npartitions = getattr(request, "param", 2) 23 | yield from_pandas(pdf, npartitions=npartitions) 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "api,how_args", 28 | [ 29 | ("count", ()), 30 | ("mean", ()), 31 | ("sum", ()), 32 | ("min", ()), 33 | ("max", ()), 34 | ("var", ()), 35 | ("std", ()), 36 | ("median", ()), 37 | ("skew", ()), 38 | ("quantile", (0.5,)), 39 | ("kurt", ()), 40 | ], 41 | ) 42 | @pytest.mark.parametrize("window,min_periods", ((1, None), (3, 2), (3, 3))) 43 | @pytest.mark.parametrize("center", (True, False)) 44 | @pytest.mark.parametrize("df", (1, 2), indirect=True) 45 | def test_rolling_apis(df, pdf, window, api, how_args, min_periods, center): 46 | args = (window,) 47 | kwargs = dict(min_periods=min_periods, center=center) 48 | 49 | result = getattr(df.rolling(*args, **kwargs), api)(*how_args) 50 | expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args) 51 | assert_eq(result, expected) 52 | 53 | result = getattr(df.rolling(*args, **kwargs), api)(*how_args)["foo"] 54 | expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args)["foo"] 55 | assert_eq(result, expected) 56 | 57 | q = result.simplify() 58 | eq = getattr(df["foo"].rolling(*args, **kwargs), api)(*how_args).simplify() 59 | assert q._name == eq._name 60 | 61 | 62 | @pytest.mark.parametrize("window", (1, 2)) 63 | @pytest.mark.parametrize("df", (1, 2), indirect=True) 64 | def test_rolling_agg(df, pdf, window): 65 | def my_sum(vals, foo=None, *, bar=None): 66 | return vals.sum() 67 | 68 | result = df.rolling(window).agg(my_sum, "foo", bar="bar") 69 | expected = pdf.rolling(window).agg(my_sum, "foo", bar="bar") 
70 | assert_eq(result, expected) 71 | 72 | result = df.rolling(window).agg(my_sum)["foo"] 73 | expected = pdf.rolling(window).agg(my_sum)["foo"] 74 | assert_eq(result, expected) 75 | 76 | # simplify up disabled for `agg`, function may access other columns 77 | q = df.rolling(window).agg(my_sum)["foo"].simplify() 78 | eq = df["foo"].rolling(window).agg(my_sum).simplify() 79 | assert q._name != eq._name 80 | 81 | 82 | @pytest.mark.parametrize("window", (1, 2)) 83 | @pytest.mark.parametrize("df", (1, 2), indirect=True) 84 | @pytest.mark.parametrize("raw", (True, False)) 85 | @pytest.mark.parametrize("foo", (1, None)) 86 | @pytest.mark.parametrize("bar", (2, None)) 87 | def test_rolling_apply(df, pdf, window, raw, foo, bar): 88 | def my_sum(vals, foo_=None, *, bar_=None): 89 | assert foo_ == foo 90 | assert bar_ == bar 91 | if raw: 92 | assert isinstance(vals, np.ndarray) 93 | else: 94 | assert isinstance(vals, pd.Series) 95 | return vals.sum() 96 | 97 | kwargs = dict(raw=raw, args=(foo,), kwargs=dict(bar_=bar)) 98 | 99 | result = df.rolling(window).apply(my_sum, **kwargs) 100 | expected = pdf.rolling(window).apply(my_sum, **kwargs) 101 | assert_eq(result, expected) 102 | 103 | result = df.rolling(window).apply(my_sum, **kwargs)["foo"] 104 | expected = pdf.rolling(window).apply(my_sum, **kwargs)["foo"] 105 | assert_eq(result, expected) 106 | 107 | # simplify up disabled for `apply`, function may access other columns 108 | q = df.rolling(window).apply(my_sum, **kwargs)["foo"].simplify() 109 | eq = df["foo"].rolling(window).apply(my_sum, **kwargs).simplify() 110 | assert q._name == eq._name 111 | 112 | 113 | def test_rolling_one_element_window(df, pdf): 114 | pdf.index = pd.date_range("2000-01-01", periods=12, freq="2s") 115 | df = from_pandas(pdf, npartitions=3) 116 | result = pdf.foo.rolling("1s").count() 117 | expected = df.foo.rolling("1s").count() 118 | assert_eq(result, expected) 119 | 120 | 121 | @pytest.mark.parametrize("window", ["2s", "5s", "20s", "10h"]) 122 | def test_time_rolling_large_window_variable_chunks(window): 123 | df = pd.DataFrame( 124 | { 125 | "a": pd.date_range("2016-01-01 00:00:00", periods=100, freq="1s"), 126 | "b": np.random.randint(100, size=(100,)), 127 | } 128 | ) 129 | ddf = from_pandas(df, 5) 130 | ddf = ddf.repartition(divisions=[0, 5, 20, 28, 33, 54, 79, 80, 82, 99]) 131 | df = df.set_index("a") 132 | ddf = ddf.set_index("a") 133 | assert_eq(ddf.rolling(window).sum(), df.rolling(window).sum()) 134 | assert_eq(ddf.rolling(window).count(), df.rolling(window).count()) 135 | assert_eq(ddf.rolling(window).mean(), df.rolling(window).mean()) 136 | 137 | 138 | def test_rolling_one_element_window_empty_after(df, pdf): 139 | pdf.index = pd.date_range("2000-01-01", periods=12, freq="2s") 140 | df = from_pandas(pdf, npartitions=3) 141 | result = df.map_overlap(lambda x: x.rolling("1s").count(), before="1s", after="1s") 142 | expected = pdf.rolling("1s").count() 143 | assert_eq(result, expected) 144 | 145 | 146 | @pytest.mark.parametrize("window", [1, 2, 4, 5]) 147 | @pytest.mark.parametrize("center", [True, False]) 148 | def test_rolling_cov(df, pdf, window, center): 149 | # DataFrame 150 | prolling = pdf.drop("foo", axis=1).rolling(window, center=center) 151 | drolling = df.drop("foo", axis=1).rolling(window, center=center) 152 | assert_eq(prolling.cov(), drolling.cov()) 153 | 154 | # Series 155 | prolling = pdf.bar.rolling(window, center=center) 156 | drolling = df.bar.rolling(window, center=center) 157 | assert_eq(prolling.cov(), drolling.cov()) 158 | 159 | # 
Projection 160 | actual = df.rolling(window, center=center).cov()[["foo", "bar"]].simplify() 161 | expected = df[["foo", "bar"]].rolling(window, center=center).cov().simplify() 162 | assert actual._name == expected._name 163 | 164 | 165 | def test_rolling_raises(): 166 | df = pd.DataFrame( 167 | {"a": np.random.randn(25).cumsum(), "b": np.random.randint(100, size=(25,))} 168 | ) 169 | ddf = from_pandas(df, npartitions=2) 170 | 171 | pytest.raises(ValueError, lambda: ddf.rolling(1.5)) 172 | pytest.raises(ValueError, lambda: ddf.rolling(-1)) 173 | pytest.raises(ValueError, lambda: ddf.rolling(3, min_periods=1.2)) 174 | pytest.raises(ValueError, lambda: ddf.rolling(3, min_periods=-2)) 175 | pytest.raises(NotImplementedError, lambda: ddf.rolling(100).mean().compute()) 176 | 177 | 178 | def test_time_rolling_constructor(df): 179 | result = df.rolling("4s") 180 | assert result.window == "4s" 181 | assert result.min_periods is None 182 | assert result.win_type is None 183 | -------------------------------------------------------------------------------- /dask_expr/tests/test_string_accessor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dask_expr._collection import DataFrame, from_pandas 4 | from dask_expr.tests._util import _backend_library, assert_eq 5 | 6 | pd = _backend_library() 7 | 8 | 9 | @pytest.fixture() 10 | def ser(): 11 | return pd.Series(["a", "b", "1", "aaa", "bbb", "ccc", "ddd", "abcd"]) 12 | 13 | 14 | @pytest.fixture() 15 | def dser(ser): 16 | import dask.dataframe as dd 17 | 18 | return dd.from_pandas(ser, npartitions=3) 19 | 20 | 21 | @pytest.mark.parametrize( 22 | "func, kwargs", 23 | [ 24 | ("len", {}), 25 | ("capitalize", {}), 26 | ("casefold", {}), 27 | ("contains", {"pat": "a"}), 28 | ("count", {"pat": "a"}), 29 | ("endswith", {"pat": "a"}), 30 | ("extract", {"pat": r"[ab](\d)"}), 31 | ("extractall", {"pat": r"[ab](\d)"}), 32 | ("find", {"sub": "a"}), 33 | ("findall", {"pat": "a"}), 34 | ("fullmatch", {"pat": "a"}), 35 | ("get", {"i": 0}), 36 | ("isalnum", {}), 37 | ("isalpha", {}), 38 | ("isdecimal", {}), 39 | ("isdigit", {}), 40 | ("islower", {}), 41 | ("isspace", {}), 42 | ("istitle", {}), 43 | ("isupper", {}), 44 | ("join", {"sep": "-"}), 45 | ("len", {}), 46 | ("ljust", {"width": 3}), 47 | ("lower", {}), 48 | ("lstrip", {}), 49 | ("match", {"pat": r"[ab](\d)"}), 50 | ("normalize", {"form": "NFC"}), 51 | ("pad", {"width": 3}), 52 | ("removeprefix", {"prefix": "a"}), 53 | ("removesuffix", {"suffix": "a"}), 54 | ("repeat", {"repeats": 2}), 55 | ("replace", {"pat": "a", "repl": "b"}), 56 | ("rfind", {"sub": "a"}), 57 | ("rjust", {"width": 3}), 58 | ("rstrip", {}), 59 | ("slice", {"start": 0, "stop": 1}), 60 | ("slice_replace", {"start": 0, "stop": 1, "repl": "a"}), 61 | ("startswith", {"pat": "a"}), 62 | ("strip", {}), 63 | ("swapcase", {}), 64 | ("title", {}), 65 | ("upper", {}), 66 | ("wrap", {"width": 2}), 67 | ("zfill", {"width": 2}), 68 | ("split", {"pat": "a"}), 69 | ("rsplit", {"pat": "a"}), 70 | ("cat", {}), 71 | ("cat", {"others": pd.Series(["a"])}), 72 | ], 73 | ) 74 | def test_string_accessor(ser, dser, func, kwargs): 75 | ser = ser.astype("string[pyarrow]") 76 | 77 | assert_eq(getattr(ser.str, func)(**kwargs), getattr(dser.str, func)(**kwargs)) 78 | 79 | if func in ( 80 | "contains", 81 | "endswith", 82 | "fullmatch", 83 | "isalnum", 84 | "isalpha", 85 | "isdecimal", 86 | "isdigit", 87 | "islower", 88 | "isspace", 89 | "istitle", 90 | "isupper", 91 | "startswith", 92 | "match", 93 | ): 94 | # 
This returns arrays and doesn't work in dask/dask either 95 | return 96 | 97 | ser.index = ser.values 98 | ser = ser.sort_index() 99 | dser = from_pandas(ser, npartitions=3) 100 | pdf_result = getattr(ser.index.str, func)(**kwargs) 101 | 102 | if func == "cat" and len(kwargs) > 0: 103 | # Doesn't work with others on Index 104 | return 105 | if isinstance(pdf_result, pd.DataFrame): 106 | assert_eq( 107 | getattr(dser.index.str, func)(**kwargs), pdf_result, check_index=False 108 | ) 109 | else: 110 | assert_eq(getattr(dser.index.str, func)(**kwargs), pdf_result) 111 | 112 | 113 | def test_str_accessor_cat(ser, dser): 114 | sol = ser.str.cat(ser.str.upper(), sep=":") 115 | assert_eq(dser.str.cat(dser.str.upper(), sep=":"), sol) 116 | assert_eq(dser.str.cat(ser.str.upper(), sep=":"), sol) 117 | assert_eq( 118 | dser.str.cat([dser.str.upper(), ser.str.lower()], sep=":"), 119 | ser.str.cat([ser.str.upper(), ser.str.lower()], sep=":"), 120 | ) 121 | assert_eq(dser.str.cat(sep=":"), ser.str.cat(sep=":")) 122 | 123 | for o in ["foo", ["foo"]]: 124 | with pytest.raises(TypeError): 125 | dser.str.cat(o) 126 | 127 | 128 | @pytest.mark.parametrize("index", [None, [0]], ids=["range_index", "other index"]) 129 | def test_str_split_(index): 130 | df = pd.DataFrame({"a": ["a\nb"]}, index=index) 131 | ddf = from_pandas(df, npartitions=1) 132 | 133 | pd_a = df["a"].str.split("\n", n=1, expand=True) 134 | dd_a = ddf["a"].str.split("\n", n=1, expand=True) 135 | 136 | assert_eq(dd_a, pd_a) 137 | 138 | 139 | def test_str_accessor_not_available(): 140 | pdf = pd.DataFrame({"a": [1, 2, 3]}) 141 | df = from_pandas(pdf, npartitions=2) 142 | # Not available on invalid dtypes 143 | with pytest.raises(AttributeError, match=".str accessor"): 144 | df.a.str 145 | 146 | assert "str" not in dir(df.a) 147 | 148 | 149 | def test_partition(): 150 | df = DataFrame.from_dict({"A": ["A|B", "C|D"]}, npartitions=2)["A"].str.partition( 151 | "|" 152 | ) 153 | result = df[1] 154 | expected = pd.DataFrame.from_dict({"A": ["A|B", "C|D"]})["A"].str.partition("|")[1] 155 | assert_eq(result, expected) 156 | -------------------------------------------------------------------------------- /dask_expr/tests/test_ufunc.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import numpy as np 3 | import pytest 4 | 5 | from dask_expr import Index, from_pandas 6 | from dask_expr.tests._util import _backend_library, assert_eq 7 | 8 | pd = _backend_library() 9 | 10 | 11 | @pytest.fixture 12 | def pdf(): 13 | pdf = pd.DataFrame({"x": range(100)}) 14 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions 15 | yield pdf 16 | 17 | 18 | @pytest.fixture 19 | def df(pdf): 20 | yield from_pandas(pdf, npartitions=10) 21 | 22 | 23 | def test_ufunc(df, pdf): 24 | ufunc = "conj" 25 | dafunc = getattr(da, ufunc) 26 | npfunc = getattr(np, ufunc) 27 | 28 | pandas_type = pdf.__class__ 29 | dask_type = df.__class__ 30 | 31 | assert isinstance(dafunc(df), dask_type) 32 | assert_eq(dafunc(df), npfunc(pdf)) 33 | 34 | if isinstance(npfunc, np.ufunc): 35 | assert isinstance(npfunc(df), dask_type) 36 | else: 37 | assert isinstance(npfunc(df), pandas_type) 38 | assert_eq(npfunc(df), npfunc(pdf)) 39 | 40 | # applying Dask ufunc to normal Series triggers computation 41 | assert isinstance(dafunc(pdf), pandas_type) 42 | assert_eq(dafunc(df), npfunc(pdf)) 43 | 44 | assert isinstance(dafunc(df.index), Index) 45 | assert_eq(dafunc(df.index), npfunc(pdf.index)) 46 | 47 | if isinstance(npfunc, 
np.ufunc): 48 | assert isinstance(npfunc(df.index), Index) 49 | else: 50 | assert isinstance(npfunc(df.index), pd.Index) 51 | 52 | assert_eq(npfunc(df.index), npfunc(pdf.index)) 53 | 54 | 55 | def test_ufunc_with_2args(pdf, df): 56 | ufunc = "logaddexp" 57 | dafunc = getattr(da, ufunc) 58 | npfunc = getattr(np, ufunc) 59 | 60 | pandas_type = pdf.__class__ 61 | pdf2 = pdf.sort_index(ascending=False) 62 | dask_type = df.__class__ 63 | df2 = from_pandas(pdf2, npartitions=8) 64 | # applying Dask ufunc doesn't trigger computation 65 | assert isinstance(dafunc(df, df2), dask_type) 66 | assert_eq(dafunc(df, df2), npfunc(pdf, pdf2)) 67 | 68 | # should be fine with pandas as a second arg, too 69 | assert isinstance(dafunc(df, pdf2), dask_type) 70 | assert_eq(dafunc(df, pdf2), npfunc(pdf, pdf2)) 71 | 72 | # applying NumPy ufunc is lazy 73 | if isinstance(npfunc, np.ufunc): 74 | assert isinstance(npfunc(df, df2), dask_type) 75 | assert isinstance(npfunc(df, pdf2), dask_type) 76 | else: 77 | assert isinstance(npfunc(df, df2), pandas_type) 78 | assert isinstance(npfunc(df, pdf2), pandas_type) 79 | 80 | assert_eq(npfunc(df, df2), npfunc(pdf, pdf2)) 81 | assert_eq(npfunc(df, pdf), npfunc(pdf, pdf2)) 82 | 83 | # applying Dask ufunc to normal Series triggers computation 84 | assert isinstance(dafunc(pdf, pdf2), pandas_type) 85 | assert_eq(dafunc(pdf, pdf2), npfunc(pdf, pdf2)) 86 | 87 | 88 | @pytest.mark.parametrize( 89 | "ufunc", 90 | [ 91 | np.mean, 92 | np.std, 93 | np.sum, 94 | np.cumsum, 95 | np.cumprod, 96 | np.var, 97 | np.min, 98 | np.max, 99 | np.all, 100 | np.any, 101 | np.prod, 102 | ], 103 | ) 104 | def test_reducers(pdf, df, ufunc): 105 | assert_eq(ufunc(pdf, axis=0), ufunc(df, axis=0)) 106 | 107 | 108 | def test_clip(pdf, df): 109 | assert_eq(np.clip(pdf, 10, 20), np.clip(df, 10, 20)) 110 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=62.6", "versioneer[toml]==0.28"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dask-expr" 7 | description = "High Level Expressions for Dask " 8 | maintainers = [{name = "Matthew Rocklin", email = "mrocklin@gmail.com"}] 9 | license = {text = "BSD"} 10 | keywords = ["dask pandas"] 11 | classifiers = [ 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "License :: OSI Approved :: BSD License", 15 | "Operating System :: OS Independent", 16 | "Programming Language :: Python", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3 :: Only", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Programming Language :: Python :: 3.13", 23 | "Topic :: Scientific/Engineering", 24 | "Topic :: System :: Distributed Computing", 25 | ] 26 | readme = "README.md" 27 | requires-python = ">=3.10" 28 | dependencies = [ 29 | "dask == 2024.12.1", 30 | "pyarrow>=14.0.1", 31 | "pandas >= 2", 32 | ] 33 | 34 | dynamic = ["version"] 35 | 36 | [project.optional-dependencies] 37 | analyze = ["crick", "distributed", "graphviz"] 38 | 39 | [project.urls] 40 | "Source code" = "https://github.com/dask-contrib/dask-expr/" 41 | 42 | [tool.setuptools.packages.find] 43 | exclude = ["*tests*"] 44 | namespaces = false 45 | 46 | [tool.coverage.run] 47 | omit = [ 48 | "*/test_*.py", 49 | ] 50 | source = ["dask_expr"] 51 | 52 | 
[tool.coverage.report] 53 | # Regexes for lines to exclude from consideration 54 | exclude_lines = [ 55 | "pragma: no cover", 56 | "raise AssertionError", 57 | "raise NotImplementedError", 58 | ] 59 | ignore_errors = true 60 | 61 | [tool.versioneer] 62 | VCS = "git" 63 | style = "pep440" 64 | versionfile_source = "dask_expr/_version.py" 65 | versionfile_build = "dask_expr/_version.py" 66 | tag_prefix = "v" 67 | parentdir_prefix = "dask_expr-" 68 | 69 | 70 | [tool.pytest.ini_options] 71 | addopts = "-v -rsxfE --durations=10 --color=yes" 72 | filterwarnings = [ 73 | 'ignore:Passing a BlockManager to DataFrame is deprecated and will raise in a future version. Use public APIs instead:DeprecationWarning', # https://github.com/apache/arrow/issues/35081 74 | 'ignore:The previous implementation of stack is deprecated and will be removed in a future version of pandas\.:FutureWarning', 75 | 'error:\nA value is trying to be set on a copy of a slice from a DataFrame', 76 | ] 77 | xfail_strict = true 78 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | sections = FUTURE,STDLIB,THIRDPARTY,DISTRIBUTED,FIRSTPARTY,LOCALFOLDER 3 | profile = black 4 | skip_gitignore = true 5 | force_to_top = true 6 | default_section = THIRDPARTY 7 | 8 | 9 | [aliases] 10 | test = pytest 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import annotations 4 | 5 | import versioneer 6 | from setuptools import setup 7 | 8 | setup( 9 | version=versioneer.get_version(), 10 | cmdclass=versioneer.get_cmdclass(), 11 | ) 12 | --------------------------------------------------------------------------------
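A minimal sketch (editor's illustration, assuming dask-expr and its bundled test utilities are importable) of the from_pandas / assert_eq round-trip pattern that the test modules above rely on: build an eager pandas frame, wrap it lazily, run the same operation on both objects, and compare the results.

import pandas as pd

from dask_expr import from_pandas
from dask_expr.tests._util import assert_eq

# Eager pandas frame mirroring the `pdf` fixture used in test_indexing.py.
pdf = pd.DataFrame({"x": range(20)})
pdf["y"] = pdf.x * 10.0

# Lazy dask-expr collection split over 4 partitions.
df = from_pandas(pdf, npartitions=4)

# Compare the lazy dask-expr result against the eager pandas result.
assert_eq(df.loc[df.y == 20, "x"], pdf.loc[pdf.y == 20, "x"])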