├── .flake8
├── .gitattributes
├── .github
│   └── workflows
│       ├── conda.yml
│       ├── dask_test.yaml
│       ├── pre-commit.yml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE.txt
├── README.md
├── changes.md
├── ci
│   ├── environment.yml
│   ├── environment_released.yml
│   └── recipe
│       └── meta.yaml
├── codecov.yml
├── dask_expr
│   ├── __init__.py
│   ├── _accessor.py
│   ├── _backends.py
│   ├── _categorical.py
│   ├── _collection.py
│   ├── _concat.py
│   ├── _core.py
│   ├── _cumulative.py
│   ├── _datetime.py
│   ├── _describe.py
│   ├── _dispatch.py
│   ├── _dummies.py
│   ├── _expr.py
│   ├── _groupby.py
│   ├── _indexing.py
│   ├── _interchange.py
│   ├── _merge.py
│   ├── _merge_asof.py
│   ├── _quantile.py
│   ├── _quantiles.py
│   ├── _reductions.py
│   ├── _repartition.py
│   ├── _resample.py
│   ├── _rolling.py
│   ├── _shuffle.py
│   ├── _str_accessor.py
│   ├── _util.py
│   ├── _version.py
│   ├── array
│   │   ├── __init__.py
│   │   ├── _creation.py
│   │   ├── blockwise.py
│   │   ├── core.py
│   │   ├── random.py
│   │   ├── rechunk.py
│   │   ├── reductions.py
│   │   ├── slicing.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── test_array.py
│   │       └── test_creation.py
│   ├── datasets.py
│   ├── diagnostics
│   │   ├── __init__.py
│   │   ├── _analyze.py
│   │   ├── _analyze_plugin.py
│   │   └── _explain.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── _delayed.py
│   │   ├── bag.py
│   │   ├── csv.py
│   │   ├── hdf.py
│   │   ├── io.py
│   │   ├── json.py
│   │   ├── orc.py
│   │   ├── parquet.py
│   │   ├── records.py
│   │   ├── sql.py
│   │   └── tests
│   │       ├── __init__.py
│   │       ├── test_delayed.py
│   │       ├── test_distributed.py
│   │       ├── test_from_pandas.py
│   │       ├── test_io.py
│   │       ├── test_parquet.py
│   │       └── test_sql.py
│   └── tests
│       ├── __init__.py
│       ├── _util.py
│       ├── test_align_partitions.py
│       ├── test_categorical.py
│       ├── test_collection.py
│       ├── test_concat.py
│       ├── test_core.py
│       ├── test_cumulative.py
│       ├── test_datasets.py
│       ├── test_datetime.py
│       ├── test_describe.py
│       ├── test_diagnostics.py
│       ├── test_distributed.py
│       ├── test_dummies.py
│       ├── test_format.py
│       ├── test_fusion.py
│       ├── test_groupby.py
│       ├── test_indexing.py
│       ├── test_interchange.py
│       ├── test_map_partitions_overlap.py
│       ├── test_merge.py
│       ├── test_merge_asof.py
│       ├── test_partitioning_knowledge.py
│       ├── test_predicate_pushdown.py
│       ├── test_quantiles.py
│       ├── test_reductions.py
│       ├── test_repartition.py
│       ├── test_resample.py
│       ├── test_reshape.py
│       ├── test_rolling.py
│       ├── test_shuffle.py
│       ├── test_string_accessor.py
│       └── test_ufunc.py
├── demo.ipynb
├── pyproject.toml
├── setup.cfg
└── setup.py
/.flake8:
--------------------------------------------------------------------------------
1 | # flake8 doesn't support pyproject.toml yet https://github.com/PyCQA/flake8/issues/234
2 | [flake8]
3 | # References:
4 | # https://flake8.readthedocs.io/en/latest/user/configuration.html
5 | # https://flake8.readthedocs.io/en/latest/user/error-codes.html
6 | # https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes
7 | exclude = __init__.py
8 | ignore =
9 | # Extra space in brackets
10 | E20
11 | # Multiple spaces around ","
12 | E231,E241
13 | # Comments
14 | E26
15 | # Import formatting
16 | E4
17 | # Comparing types instead of isinstance
18 | E721
19 | # Assigning lambda expression
20 | E731
21 | # Ambiguous variable names
22 | E741
23 | # Line break before binary operator
24 | W503
25 | # Line break after binary operator
26 | W504
27 | # Redefinition of unused 'loop' from line 10
28 | F811
29 | # No explicit stacklevel in warnings.warn. FIXME we should correct this in the code
30 | B028
31 |
32 | max-line-length = 120
33 | per-file-ignores =
34 | *_test.py:
35 | # Do not call assert False since python -O removes these calls
36 | B011,
37 | **/tests/*:
38 | # Do not call assert False since python -O removes these calls
39 | B011,
40 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | dask_expr/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.github/workflows/conda.yml:
--------------------------------------------------------------------------------
1 | name: Conda build
2 | on:
3 | push:
4 | branches:
5 | - main
6 | tags:
7 | - "*"
8 | pull_request:
9 | paths:
10 | - setup.py
11 | - ci/recipe/**
12 | - .github/workflows/conda.yml
13 | - pyproject.toml
14 |
15 | # When this workflow is queued, automatically cancel any previous running
16 | # or pending jobs from the same branch
17 | concurrency:
18 | group: conda-${{ github.ref }}
19 | cancel-in-progress: true
20 |
21 | # Required shell entrypoint to have properly activated conda environments
22 | defaults:
23 | run:
24 | shell: bash -l {0}
25 |
26 | jobs:
27 | conda:
28 | name: Build (and upload)
29 | runs-on: ubuntu-latest
30 | steps:
31 | - uses: actions/checkout@v4.1.1
32 | with:
33 | fetch-depth: 0
34 | - name: Set up Python
35 | uses: conda-incubator/setup-miniconda@v3.0.4
36 | with:
37 | miniforge-version: latest
38 | use-mamba: true
39 | python-version: 3.9
40 | channel-priority: strict
41 | - name: Install dependencies
42 | run: |
43 | mamba install -c conda-forge boa conda-verify
44 |
45 | which python
46 | pip list
47 | mamba list
48 | - name: Build conda packages
49 | run: |
50 | # suffix for pre-release package versions
51 | export VERSION_SUFFIX=a`date +%y%m%d`
52 |
53 | # conda search for the latest dask-core pre-release
54 | arr=($(conda search --override-channels -c dask/label/dev dask-core | tail -n 1))
55 |
56 | # extract dask-core pre-release version / build
57 | export DASK_CORE_VERSION=${arr[1]}
58 |
59 | # distributed pre-release build
60 | conda mambabuild ci/recipe \
61 | --channel dask/label/dev \
62 | --no-anaconda-upload \
63 | --output-folder .
64 | - name: Upload conda packages
65 | if: github.event_name == 'push' && github.repository == 'dask/dask-expr'
66 | env:
67 | ANACONDA_API_TOKEN: ${{ secrets.DASK_CONDA_TOKEN }}
68 | run: |
69 | # install anaconda for upload
70 | mamba install -c conda-forge anaconda-client
71 |
72 | anaconda upload --label dev noarch/*.tar.bz2
73 |
--------------------------------------------------------------------------------
/.github/workflows/dask_test.yaml:
--------------------------------------------------------------------------------
1 | name: Tests / dask/dask
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | workflow_dispatch:
9 |
10 | # When this workflow is queued, automatically cancel any previous running
11 | # or pending jobs from the same branch
12 | concurrency:
13 | group: dask_test-${{ github.ref }}
14 | cancel-in-progress: true
15 |
16 | # Required shell entrypoint to have properly activated conda environments
17 | defaults:
18 | run:
19 | shell: bash -l {0}
20 |
21 | jobs:
22 | test:
23 | runs-on: ubuntu-latest
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | python-version: ["3.10", "3.12"]
28 | environment-file: [ci/environment.yml]
29 |
30 | steps:
31 | - uses: actions/checkout@v4
32 | with:
33 | fetch-depth: 0 # Needed by codecov.io
34 |
35 | - name: Get current date
36 | id: date
37 | run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}"
38 |
39 | - name: Install Environment
40 | uses: mamba-org/setup-micromamba@v1
41 | with:
42 | environment-file: ${{ matrix.environment-file }}
43 | create-args: python=${{ matrix.python-version }}
44 | # Wipe cache every 24 hours or whenever environment.yml changes. This means it
45 | # may take up to a day before changes to unpinned packages are picked up.
46 | # To force a cache refresh, change the hardcoded numerical suffix below.
47 | cache-environment-key: environment-${{ steps.date.outputs.date }}-0
48 |
49 | - name: Install current main versions of dask
50 | run: python -m pip install git+https://github.com/dask/dask
51 |
52 | - name: Install current main versions of distributed
53 | run: python -m pip install git+https://github.com/dask/distributed
54 |
55 | - name: Install dask-expr
56 | run: python -m pip install -e . --no-deps
57 |
58 | - name: Print dask versions
59 | # Output of `micromamba list` is buggy for pip-installed packages
60 | run: pip list | grep -E 'dask|distributed'
61 |
62 | - name: Run Dask DataFrame tests
63 | run: python -c "import dask.dataframe as dd; dd.test_dataframe()"
64 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: Linting
2 |
3 | on:
4 | push:
5 | branches: main
6 | pull_request:
7 | branches: main
8 |
9 | jobs:
10 | checks:
11 | name: pre-commit hooks
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: actions/setup-python@v5
16 | with:
17 | python-version: '3.10'
18 | - uses: pre-commit/action@v3.0.0
19 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | workflow_dispatch:
9 |
10 | # When this workflow is queued, automatically cancel any previous running
11 | # or pending jobs from the same branch
12 | concurrency:
13 | group: test-${{ github.ref }}
14 | cancel-in-progress: true
15 |
16 | # Required shell entrypoint to have properly activated conda environments
17 | defaults:
18 | run:
19 | shell: bash -l {0}
20 |
21 | jobs:
22 | test:
23 | runs-on: ubuntu-latest
24 | strategy:
25 | fail-fast: false
26 | matrix:
27 | python-version: ["3.10", "3.11", "3.12", "3.13"]
28 | environment-file: [ci/environment.yml]
29 |
30 | steps:
31 | - uses: actions/checkout@v4
32 | with:
33 | fetch-depth: 0 # Needed by codecov.io
34 |
35 | - name: Get current date
36 | id: date
37 | run: echo "date=$(date +%Y-%m-%d)" >> "${GITHUB_OUTPUT}"
38 |
39 | - name: Install Environment
40 | uses: mamba-org/setup-micromamba@v1
41 | with:
42 | environment-file: ${{ matrix.environment-file }}
43 | create-args: python=${{ matrix.python-version }}
44 | # Wipe cache every 24 hours or whenever environment.yml changes. This means it
45 | # may take up to a day before changes to unpinned packages are picked up.
46 | # To force a cache refresh, change the hardcoded numerical suffix below.
47 | cache-environment-key: environment-${{ steps.date.outputs.date }}-1
48 |
49 | - name: Install current main versions of dask
50 | run: python -m pip install git+https://github.com/dask/dask
51 | if: ${{ matrix.environment-file == 'ci/environment.yml' }}
52 |
53 | - name: Install current main versions of distributed
54 | run: python -m pip install git+https://github.com/dask/distributed
55 | if: ${{ matrix.environment-file == 'ci/environment.yml' }}
56 |
57 | - name: Install dask-expr
58 | run: python -m pip install -e . --no-deps
59 |
60 | - name: Print dask versions
61 | # Output of `micromamba list` is buggy for pip-installed packages
62 | run: pip list | grep -E 'dask|distributed'
63 |
64 | - name: Run tests
65 | run: py.test -n auto --verbose --cov=dask_expr --cov-report=xml
66 |
67 | - name: Coverage
68 | uses: codecov/codecov-action@v3
69 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | build/
3 | dist/
4 | *.egg-info/
5 | bench/shakespeare.txt
6 | .coverage
7 | *.sw?
8 | .DS_STORE
9 | \.tox/
10 | .idea/
11 | .ipynb_checkpoints/
12 | coverage.xml
13 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: debug-statements
7 | - repo: https://github.com/MarcoGorelli/absolufy-imports
8 | rev: v0.3.1
9 | hooks:
10 | - id: absolufy-imports
11 | name: absolufy-imports
12 | - repo: https://github.com/pycqa/isort
13 | rev: 5.12.0
14 | hooks:
15 | - id: isort
16 | language_version: python3
17 | - repo: https://github.com/asottile/pyupgrade
18 | rev: v3.3.2
19 | hooks:
20 | - id: pyupgrade
21 | args:
22 | - --py310-plus
23 | - repo: https://github.com/psf/black
24 | rev: 23.3.0
25 | hooks:
26 | - id: black
27 | language_version: python3
28 | exclude: versioneer.py
29 | args:
30 | - --target-version=py310
31 | - repo: https://github.com/pycqa/flake8
32 | rev: 6.0.0
33 | hooks:
34 | - id: flake8
35 | language_version: python3
36 | additional_dependencies:
37 | # NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive
38 | # dependency. Make sure to update flake8-bugbear manually on a regular basis.
39 | - flake8-bugbear==22.8.23
40 | - repo: https://github.com/codespell-project/codespell
41 | rev: v2.2.4
42 | hooks:
43 | - id: codespell
44 | types_or: [rst, markdown]
45 | files: docs
46 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright Dask-expr development team
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Dask Expressions
2 | ================
3 |
4 | The implementation is now the default and only backend for Dask DataFrames and was
5 | moved to https://github.com/dask/dask.
6 |
7 | This repository is no longer maintained.
8 |
--------------------------------------------------------------------------------
/ci/environment.yml:
--------------------------------------------------------------------------------
1 | name: dask-expr
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - pytest
6 | - pytest-cov
7 | - pytest-xdist
8 | - dask # overridden by git tip below
9 | - pyarrow>=14.0.1
10 | - pandas>=2
11 | - pre-commit
12 | - sqlalchemy
13 | - xarray
14 | - pip:
15 | - git+https://github.com/dask/distributed
16 | - git+https://github.com/dask/dask
17 |
--------------------------------------------------------------------------------
/ci/environment_released.yml:
--------------------------------------------------------------------------------
1 | name: dask-expr
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - pytest
6 | - pytest-cov
7 | - pytest-xdist
8 | - dask
9 | - pyarrow>=14.0.1
10 | - pandas>=2
11 | - pre-commit
12 |
--------------------------------------------------------------------------------
/ci/recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | {% set major_minor_patch = environ.get('GIT_DESCRIBE_TAG', '0.0.0.dev').lstrip('v').split('.') %}
2 | {% set new_patch = major_minor_patch[2] | int + 1 %}
3 | {% set version = (major_minor_patch[:2] + [new_patch]) | join('.') + environ.get('VERSION_SUFFIX', '') %}
4 | {% set dask_version = environ.get('DASK_CORE_VERSION', '0.0.0.dev') %}
5 |
6 |
7 | package:
8 | name: dask-expr
9 | version: {{ version }}
10 |
11 | source:
12 | git_url: ../..
13 |
14 | build:
15 | number: {{ GIT_DESCRIBE_NUMBER }}
16 | noarch: python
17 | string: py_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }}
18 | script: {{ PYTHON }} -m pip install . -vv
19 |
20 | requirements:
21 | host:
22 | - python >=3.10
23 | - pip
24 | - dask-core {{ dask_version }}
25 | - versioneer =0.28
26 | - setuptools >=62.6
27 | - tomli # [py<311]
28 | run:
29 | - python >=3.10
30 | - {{ pin_compatible('dask-core', max_pin='x.x.x.x') }}
31 | - pyarrow
32 | - pandas >=2
33 |
34 | test:
35 | imports:
36 | - dask_expr
37 | requires:
38 | - pip
39 | commands:
40 | - pip check
41 |
42 | about:
43 | home: https://github.com/dask/dask-expr
44 | summary: 'High Level Expressions for Dask'
45 | description: |
46 | High Level Expressions for Dask
47 | license: BSD-3-Clause
48 | license_family: BSD
49 | license_file: LICENSE.txt
50 | doc_url: https://github.com/dask/dask-expr/blob/main/README.md
51 | dev_url: https://github.com/dask/dask-expr
52 |
--------------------------------------------------------------------------------
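The Jinja header of `ci/recipe/meta.yaml` derives the package version from the most recent git tag: it strips the leading `v`, bumps the patch component, and appends the pre-release suffix exported by the conda workflow. A minimal Python sketch of that computation, using hypothetical values for `GIT_DESCRIBE_TAG` and `VERSION_SUFFIX` (the real values are injected by conda-build and `.github/workflows/conda.yml`):

```python
import os

# Hypothetical inputs mirroring what conda-build (git describe) and the
# conda workflow (VERSION_SUFFIX=a`date +%y%m%d`) would provide.
os.environ.setdefault("GIT_DESCRIBE_TAG", "v1.1.10")
os.environ.setdefault("VERSION_SUFFIX", "a240315")

major_minor_patch = os.environ["GIT_DESCRIBE_TAG"].lstrip("v").split(".")
new_patch = int(major_minor_patch[2]) + 1
version = ".".join(major_minor_patch[:2] + [str(new_patch)]) + os.environ["VERSION_SUFFIX"]
print(version)  # -> "1.1.11a240315"
```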
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: yes
3 |
4 | coverage:
5 | precision: 2
6 | round: down
7 | range: "87...100"
8 |
9 | status:
10 | project:
11 | default:
12 | target: 87%
13 | threshold: 1%
14 | patch: no
15 | changes: no
16 |
17 | comment: off
18 |
--------------------------------------------------------------------------------
/dask_expr/__init__.py:
--------------------------------------------------------------------------------
1 | import dask.dataframe
2 |
3 | from dask_expr import _version, datasets
4 | from dask_expr._collection import *
5 | from dask_expr._dispatch import get_collection_type
6 | from dask_expr._dummies import get_dummies
7 | from dask_expr._groupby import Aggregation
8 | from dask_expr.io._delayed import from_delayed
9 | from dask_expr.io.bag import to_bag
10 | from dask_expr.io.csv import to_csv
11 | from dask_expr.io.hdf import read_hdf, to_hdf
12 | from dask_expr.io.json import read_json, to_json
13 | from dask_expr.io.orc import read_orc, to_orc
14 | from dask_expr.io.parquet import to_parquet
15 | from dask_expr.io.records import to_records
16 | from dask_expr.io.sql import read_sql, read_sql_query, read_sql_table, to_sql
17 |
18 | __version__ = _version.get_versions()["version"]
19 |
--------------------------------------------------------------------------------
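`dask_expr/__init__.py` re-exports the collection API (via the star import from `_collection`) together with the IO entry points, so typical usage goes through the package namespace. A small usage sketch, assuming `from_pandas` is among the names provided by `_collection` (that module is not shown in this dump):

```python
import pandas as pd
import dask_expr as dx

# Build a small two-partition collection and run a simple aggregation.
pdf = pd.DataFrame({"x": [1, 2, 3, 4], "y": ["a", "b", "a", "b"]})
df = dx.from_pandas(pdf, npartitions=2)

print(df.groupby("y").x.sum().compute())
print(dx.__version__)  # resolved from _version.get_versions() at import time
```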
/dask_expr/_accessor.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from dask.dataframe.accessor import _bind_method, _bind_property, maybe_wrap_pandas
4 | from dask.dataframe.dispatch import make_meta, meta_nonempty
5 |
6 | from dask_expr._expr import Elemwise, Expr
7 |
8 |
9 | class Accessor:
10 | """
11 | Base class for pandas Accessor objects cat, dt, and str.
12 |
13 | Notes
14 | -----
15 | Subclasses should define ``_accessor_name``, ``_accessor_methods``, and
16 | ``_accessor_properties``.
17 | """
18 |
19 | def __init__(self, series):
20 | from dask_expr import Series
21 |
22 | if not isinstance(series, Series):
23 | raise ValueError("Accessor cannot be initialized")
24 |
25 | series_meta = series._meta
26 | if hasattr(series_meta, "to_series"): # is index-like
27 | series_meta = series_meta.to_series()
28 | meta = getattr(series_meta, self._accessor_name)
29 |
30 | self._meta = meta
31 | self._series = series
32 |
33 | def __init_subclass__(cls, **kwargs):
34 | """Bind all auto-generated methods & properties"""
35 | import pandas as pd
36 |
37 | super().__init_subclass__(**kwargs)
38 | pd_cls = getattr(pd.Series, cls._accessor_name)
39 | for item in cls._accessor_methods:
40 | attr, min_version = item if isinstance(item, tuple) else (item, None)
41 | if not hasattr(cls, attr):
42 | _bind_method(cls, pd_cls, attr, min_version)
43 | for item in cls._accessor_properties:
44 | attr, min_version = item if isinstance(item, tuple) else (item, None)
45 | if not hasattr(cls, attr):
46 | _bind_property(cls, pd_cls, attr, min_version)
47 |
48 | @staticmethod
49 | def _delegate_property(obj, accessor, attr):
50 | out = getattr(getattr(obj, accessor, obj), attr)
51 | return maybe_wrap_pandas(obj, out)
52 |
53 | @staticmethod
54 | def _delegate_method(obj, accessor, attr, args, kwargs):
55 | out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs)
56 | return maybe_wrap_pandas(obj, out)
57 |
58 | def _function_map(self, attr, *args, **kwargs):
59 | from dask_expr._collection import Index, new_collection
60 |
61 | if isinstance(self._series, Index):
62 | return new_collection(
63 | FunctionMapIndex(self._series, self._accessor_name, attr, args, kwargs)
64 | )
65 |
66 | return new_collection(
67 | FunctionMap(self._series, self._accessor_name, attr, args, kwargs)
68 | )
69 |
70 | def _property_map(self, attr, *args, **kwargs):
71 | from dask_expr._collection import Index, new_collection
72 |
73 | if isinstance(self._series, Index):
74 | return new_collection(
75 | PropertyMapIndex(self._series, self._accessor_name, attr)
76 | )
77 |
78 | return new_collection(PropertyMap(self._series, self._accessor_name, attr))
79 |
80 |
81 | class PropertyMap(Elemwise):
82 | _parameters = [
83 | "frame",
84 | "accessor",
85 | "attr",
86 | ]
87 |
88 | @staticmethod
89 | def operation(obj, accessor, attr):
90 | out = getattr(getattr(obj, accessor, obj), attr)
91 | return maybe_wrap_pandas(obj, out)
92 |
93 |
94 | class PropertyMapIndex(PropertyMap):
95 | def _divisions(self):
96 | # TODO: We can do better here
97 | return (None,) * (self.frame.npartitions + 1)
98 |
99 |
100 | class FunctionMap(Elemwise):
101 | _parameters = ["frame", "accessor", "attr", "args", "kwargs"]
102 |
103 | @functools.cached_property
104 | def _meta(self):
105 | args = [
106 | meta_nonempty(op._meta) if isinstance(op, Expr) else op for op in self._args
107 | ]
108 | return make_meta(self.operation(*args, **self._kwargs))
109 |
110 | @staticmethod
111 | def operation(obj, accessor, attr, args, kwargs):
112 | out = getattr(getattr(obj, accessor, obj), attr)(*args, **kwargs)
113 | return maybe_wrap_pandas(obj, out)
114 |
115 |
116 | class FunctionMapIndex(FunctionMap):
117 | def _divisions(self):
118 | # TODO: We can do better here
119 | return (None,) * (self.frame.npartitions + 1)
120 |
--------------------------------------------------------------------------------
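A subclass of `Accessor` only declares `_accessor_name`, `_accessor_methods`, and `_accessor_properties`; `__init_subclass__` binds any missing names against the matching pandas accessor, and every bound call is routed through `_function_map`/`_property_map` into `FunctionMap`/`PropertyMap` expressions. A hedged sketch of a hypothetical subclass (not part of this codebase) that forwards two pandas `str` methods:

```python
import pandas as pd
import dask_expr as dx
from dask_expr._accessor import Accessor


class UpperOnlyAccessor(Accessor):
    """Hypothetical accessor exposing just two pandas ``str`` methods."""

    _accessor_name = "str"
    _accessor_methods = ("upper", "len")  # bound by __init_subclass__ via _bind_method
    _accessor_properties = ()


s = dx.from_pandas(pd.Series(["a", "bc"]), npartitions=1)
acc = UpperOnlyAccessor(s)

# Each call produces an Elemwise FunctionMap expression per partition.
print(acc.upper().compute().tolist())  # ['A', 'BC']
print(acc.len().compute().tolist())    # [1, 2]
```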
/dask_expr/_backends.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from dask.backends import CreationDispatch
6 | from dask.dataframe.backends import DataFrameBackendEntrypoint
7 | from dask.dataframe.dispatch import to_pandas_dispatch
8 |
9 | from dask_expr._dispatch import get_collection_type
10 | from dask_expr._expr import ToBackend
11 |
12 | try:
13 | import sparse
14 |
15 | sparse_installed = True
16 | except ImportError:
17 | sparse_installed = False
18 |
19 |
20 | try:
21 | import scipy.sparse as sp
22 |
23 | scipy_installed = True
24 | except ImportError:
25 | scipy_installed = False
26 |
27 |
28 | dataframe_creation_dispatch = CreationDispatch(
29 | module_name="dataframe",
30 | default="pandas",
31 | entrypoint_root="dask_expr",
32 | entrypoint_class=DataFrameBackendEntrypoint,
33 | name="dataframe_creation_dispatch",
34 | )
35 |
36 |
37 | class ToPandasBackend(ToBackend):
38 | @staticmethod
39 | def operation(df, options):
40 | return to_pandas_dispatch(df, **options)
41 |
42 | def _simplify_down(self):
43 | if isinstance(self.frame._meta, (pd.DataFrame, pd.Series, pd.Index)):
44 | # We already have pandas data
45 | return self.frame
46 |
47 |
48 | class PandasBackendEntrypoint(DataFrameBackendEntrypoint):
49 | """Pandas-Backend Entrypoint Class for Dask-Expressions
50 |
51 | Note that all DataFrame-creation functions are defined
52 | and registered 'in-place'.
53 | """
54 |
55 | @classmethod
56 | def to_backend(cls, data, **kwargs):
57 | from dask_expr._collection import new_collection
58 |
59 | return new_collection(ToPandasBackend(data, kwargs))
60 |
61 |
62 | dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint())
63 |
64 |
65 | @get_collection_type.register(pd.Series)
66 | def get_collection_type_series(_):
67 | from dask_expr._collection import Series
68 |
69 | return Series
70 |
71 |
72 | @get_collection_type.register(pd.DataFrame)
73 | def get_collection_type_dataframe(_):
74 | from dask_expr._collection import DataFrame
75 |
76 | return DataFrame
77 |
78 |
79 | @get_collection_type.register(pd.Index)
80 | def get_collection_type_index(_):
81 | from dask_expr._collection import Index
82 |
83 | return Index
84 |
85 |
86 | def create_array_collection(expr):
87 | # This is hacky and an abstraction leak, but utilizing get_collection_type
88 | # to infer that we want to create an array is the only way that is guaranteed
89 | # to be a general solution.
90 | # We can get rid of this when we have an Array expression
91 | from dask.highlevelgraph import HighLevelGraph
92 | from dask.layers import Blockwise
93 |
94 | result = expr.optimize()
95 | dsk = result.__dask_graph__()
96 | name = result._name
97 | meta = result._meta
98 | divisions = result.divisions
99 | import dask.array as da
100 |
101 | chunks = ((np.nan,) * (len(divisions) - 1),) + tuple((d,) for d in meta.shape[1:])
102 | if len(chunks) > 1:
103 | if isinstance(dsk, HighLevelGraph):
104 | layer = dsk.layers[name]
105 | else:
106 | # dask-expr provides a dict only
107 | layer = dsk
108 | if isinstance(layer, Blockwise):
109 | layer.new_axes["j"] = chunks[1][0]
110 | layer.output_indices = layer.output_indices + ("j",)
111 | else:
112 | from dask._task_spec import Alias, Task
113 |
114 | suffix = (0,) * (len(chunks) - 1)
115 | for i in range(len(chunks[0])):
116 | task = layer.get((name, i))
117 | new_key = (name, i) + suffix
118 | if isinstance(task, Task):
119 | task = Alias(new_key, task.key)
120 | layer[new_key] = task
121 | return da.Array(dsk, name=name, chunks=chunks, dtype=meta.dtype)
122 |
123 |
124 | @get_collection_type.register(np.ndarray)
125 | def get_collection_type_array(_):
126 | return create_array_collection
127 |
128 |
129 | if sparse_installed:
130 |
131 | @get_collection_type.register(sparse.COO)
132 | def get_collection_type_array(_):
133 | return create_array_collection
134 |
135 |
136 | if scipy_installed:
137 |
138 | @get_collection_type.register(sp.csr_matrix)
139 | def get_collection_type_array(_):
140 | return create_array_collection
141 |
142 |
143 | @get_collection_type.register(object)
144 | def get_collection_type_object(_):
145 | from dask_expr._collection import Scalar
146 |
147 | return Scalar
148 |
149 |
150 | ######################################
151 | # cuDF: Pandas Dataframes on the GPU #
152 | ######################################
153 |
154 |
155 | @get_collection_type.register_lazy("cudf")
156 | def _register_cudf():
157 | import dask_cudf # noqa: F401
158 |
--------------------------------------------------------------------------------
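`get_collection_type` maps a meta/result object to the collection class (or shim) that should wrap an expression; `_backends.py` registers the pandas handlers, the `object` fallback for scalars, the array shim above, and a lazy cudf hook. A minimal sketch of querying the dispatch directly (the printed classes are what the registrations above return):

```python
import numpy as np
import pandas as pd

import dask_expr  # noqa: F401  (importing the package runs the registrations in _backends)
from dask_expr._dispatch import get_collection_type

print(get_collection_type(pd.Series(dtype=float)))   # dask_expr._collection.Series
print(get_collection_type(pd.DataFrame()))           # dask_expr._collection.DataFrame
print(get_collection_type(pd.Index([], dtype=int)))  # dask_expr._collection.Index

# Unregistered types fall through to the `object` handler (Scalar); ndarray metas
# hit the create_array_collection shim until a real Array expression exists.
print(get_collection_type(np.float64(0.0)))
print(get_collection_type(np.empty(0)))
```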
/dask_expr/_categorical.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import pandas as pd
4 | from dask.dataframe.categorical import (
5 | _categorize_block,
6 | _get_categories,
7 | _get_categories_agg,
8 | )
9 | from dask.dataframe.utils import (
10 | AttributeNotImplementedError,
11 | clear_known_categories,
12 | has_known_categories,
13 | )
14 | from dask.utils import M
15 |
16 | from dask_expr._accessor import Accessor, PropertyMap
17 | from dask_expr._expr import Blockwise, Elemwise, Projection
18 | from dask_expr._reductions import ApplyConcatApply
19 |
20 |
21 | class CategoricalAccessor(Accessor):
22 | """
23 | Accessor object for categorical properties of the Series values.
24 |
25 | Examples
26 | --------
27 | >>> s.cat.categories # doctest: +SKIP
28 |
29 | Notes
30 | -----
31 | Attributes that depend only on metadata are eager
32 |
33 | * categories
34 | * ordered
35 |
36 | Attributes depending on the entire dataset are lazy
37 |
38 | * codes
39 | * ...
40 |
41 | So `df.a.cat.categories` <=> `df.a._meta.cat.categories`
42 | So `df.a.cat.codes` <=> `df.a.map_partitions(lambda x: x.cat.codes)`
43 | """
44 |
45 | _accessor_name = "cat"
46 | _accessor_methods = (
47 | "add_categories",
48 | "as_ordered",
49 | "as_unordered",
50 | "remove_categories",
51 | "rename_categories",
52 | "reorder_categories",
53 | "set_categories",
54 | )
55 | _accessor_properties = ()
56 |
57 | @property
58 | def known(self):
59 | """Whether the categories are fully known"""
60 | return has_known_categories(self._series)
61 |
62 | def as_known(self, **kwargs):
63 | """Ensure the categories in this series are known.
64 |
65 | If the categories are known, this is a no-op. If unknown, the
66 | categories are computed, and a new series with known categories is
67 | returned.
68 |
69 | Parameters
70 | ----------
71 | kwargs
72 | Keywords to pass on to the call to `compute`.
73 | """
74 | if self.known:
75 | return self._series
76 | from dask_expr._collection import new_collection
77 |
78 | categories = (
79 | new_collection(PropertyMap(self._series, "cat", "categories"))
80 | .unique()
81 | .compute()
82 | )
83 | return self.set_categories(categories.values)
84 |
85 | def as_unknown(self):
86 | """Ensure the categories in this series are unknown"""
87 | if not self.known:
88 | return self._series
89 |
90 | from dask_expr import new_collection
91 |
92 | return new_collection(AsUnknown(self._series))
93 |
94 | @property
95 | def ordered(self):
96 | """Whether the categories have an ordered relationship"""
97 | return self._delegate_property(self._series._meta, "cat", "ordered")
98 |
99 | @property
100 | def categories(self):
101 | """The categories of this categorical.
102 |
103 | If categories are unknown, an error is raised"""
104 | if not self.known:
105 | msg = (
106 | "`df.column.cat.categories` with unknown categories is not "
107 | "supported. Please use `column.cat.as_known()` or "
108 | "`df.categorize()` beforehand to ensure known categories"
109 | )
110 | raise AttributeNotImplementedError(msg)
111 | return self._delegate_property(self._series._meta, "cat", "categories")
112 |
113 | @property
114 | def codes(self):
115 | """The codes of this categorical.
116 |
117 | If categories are unknown, an error is raised"""
118 | if not self.known:
119 | msg = (
120 | "`df.column.cat.codes` with unknown categories is not "
121 | "supported. Please use `column.cat.as_known()` or "
122 | "`df.categorize()` beforehand to ensure known categories"
123 | )
124 | raise AttributeNotImplementedError(msg)
125 | from dask_expr._collection import new_collection
126 |
127 | return new_collection(PropertyMap(self._series, "cat", "codes"))
128 |
129 | def remove_unused_categories(self):
130 | """
131 | Removes categories which are not used
132 |
133 | Notes
134 | -----
135 | This method requires a full scan of the data to compute the
136 | unique values, which can be expensive.
137 | """
138 | # get the set of used categories
139 | present = self._series.dropna().unique()
140 | present = pd.Index(present.compute())
141 |
142 | if isinstance(self._series._meta, pd.CategoricalIndex):
143 | meta_cat = self._series._meta
144 | else:
145 | meta_cat = self._series._meta.cat
146 |
147 | # Reorder to keep cat:code relationship, filtering unused (-1)
148 | ordered, mask = present.reindex(meta_cat.categories)
149 | if mask is None:
150 | # PANDAS-23963: old and new categories match.
151 | return self._series
152 |
153 | new_categories = ordered[mask != -1]
154 | return self.set_categories(new_categories)
155 |
156 |
157 | class AsUnknown(Elemwise):
158 | _parameters = ["frame"]
159 | operation = M.copy
160 |
161 | @functools.cached_property
162 | def _meta(self):
163 | return clear_known_categories(self.frame._meta)
164 |
165 |
166 | class Categorize(Blockwise):
167 | _parameters = ["frame", "categories", "index"]
168 | operation = staticmethod(_categorize_block)
169 | _projection_passthrough = True
170 |
171 | @functools.cached_property
172 | def _meta(self):
173 | return _categorize_block(
174 | self.frame._meta, self.operand("categories"), self.operand("index")
175 | )
176 |
177 | def _simplify_up(self, parent, dependents):
178 | result = super()._simplify_up(parent, dependents)
179 | if result is None:
180 | return result
181 | # pop potentially dropped columns from categories
182 | cats = self.operand("categories")
183 | cats = {k: v for k, v in cats.items() if k in result.frame.columns}
184 | return Categorize(result.frame, cats, result.operand("index"))
185 |
186 |
187 | class GetCategories(ApplyConcatApply):
188 | _parameters = ["frame", "columns", "index", "split_every"]
189 |
190 | chunk = staticmethod(_get_categories)
191 | aggregate = staticmethod(_get_categories_agg)
192 |
193 | @property
194 | def chunk_kwargs(self):
195 | return {"columns": self.operand("columns"), "index": self.operand("index")}
196 |
197 | @functools.cached_property
198 | def _meta(self):
199 | return ({}, pd.Series())
200 |
201 | def _simplify_down(self):
202 | if set(self.frame.columns) == set(self.operand("columns")):
203 | return None
204 |
205 | return GetCategories(
206 | Projection(self.frame, self.operand("columns")),
207 | columns=self.operand("columns"),
208 | index=self.operand("index"),
209 | split_every=self.operand("split_every"),
210 | )
211 |
--------------------------------------------------------------------------------
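The categorical accessor keeps metadata-only attributes (`categories`, `ordered`) eager and data-dependent ones (`codes`) lazy; `as_known()` computes the unique categories once so later operations know their output shape. A short usage sketch, assuming the `cat` accessor is wired onto the Series collection in `_collection.py` (not shown in this dump):

```python
import pandas as pd
import dask_expr as dx

s = dx.from_pandas(pd.Series(list("abca"), dtype="category"), npartitions=2)

print(s.cat.known)             # True: from_pandas preserves the category metadata
print(list(s.cat.categories))  # eager, read from the metadata: ['a', 'b', 'c']

s = s.cat.as_unknown()         # drop the category metadata (AsUnknown expression)
print(s.cat.known)             # False

s = s.cat.as_known()           # computes the unique categories once
print(s.cat.codes.compute().tolist())  # lazy, evaluated per partition
```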
/dask_expr/_cumulative.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 |
4 | import pandas as pd
5 | from dask.dataframe import methods
6 | from dask.utils import M
7 |
8 | from dask_expr._expr import Blockwise, Expr, Projection, plain_column_projection
9 |
10 |
11 | class CumulativeAggregations(Expr):
12 | _parameters = ["frame", "axis", "skipna"]
13 | _defaults = {"axis": None}
14 |
15 | chunk_operation = None
16 | aggregate_operation = None
17 | neutral_element = None
18 |
19 | def _divisions(self):
20 | return self.frame._divisions()
21 |
22 | @functools.cached_property
23 | def _meta(self):
24 | return self.frame._meta
25 |
26 | def _lower(self):
27 | chunks = CumulativeBlockwise(
28 | self.frame, self.axis, self.skipna, self.chunk_operation
29 | )
30 | chunks_last = TakeLast(chunks, self.skipna)
31 | return CumulativeFinalize(
32 | chunks, chunks_last, self.aggregate_operation, self.neutral_element
33 | )
34 |
35 | def _simplify_up(self, parent, dependents):
36 | if isinstance(parent, Projection):
37 | return plain_column_projection(self, parent, dependents)
38 |
39 |
40 | class CumulativeBlockwise(Blockwise):
41 | _parameters = ["frame", "axis", "skipna", "operation"]
42 | _defaults = {"skipna": True, "axis": None}
43 | _projection_passthrough = True
44 |
45 | @functools.cached_property
46 | def _meta(self):
47 | return self.frame._meta
48 |
49 | @functools.cached_property
50 | def operation(self):
51 | return self.operand("operation")
52 |
53 | @functools.cached_property
54 | def _args(self) -> list:
55 | return self.operands[:-1]
56 |
57 |
58 | class TakeLast(Blockwise):
59 | _parameters = ["frame", "skipna"]
60 | _projection_passthrough = True
61 |
62 | @staticmethod
63 | def operation(a, skipna=True):
64 | if skipna:
65 | if a.ndim == 1 and (a.empty or a.isna().all()):
66 | return None
67 | a = a.ffill()
68 | return a.tail(n=1).squeeze()
69 |
70 |
71 | class CumulativeFinalize(Expr):
72 | _parameters = ["frame", "previous_partitions", "aggregator", "neutral_element"]
73 |
74 | def _divisions(self):
75 | return self.frame._divisions()
76 |
77 | @functools.cached_property
78 | def _meta(self):
79 | return self.frame._meta
80 |
81 | def _layer(self) -> dict:
82 | dsk = {}
83 | frame, previous_partitions = self.frame, self.previous_partitions
84 | dsk[(self._name, 0)] = (frame._name, 0)
85 |
86 | intermediate_name = self._name + "-intermediate"
87 | for i in range(1, self.frame.npartitions):
88 | if i == 1:
89 | dsk[(intermediate_name, i)] = (previous_partitions._name, i - 1)
90 | else:
91 | # aggregate with previous cumulation results
92 | dsk[(intermediate_name, i)] = (
93 | cumulative_wrapper_intermediate,
94 | self.aggregator,
95 | (intermediate_name, i - 1),
96 | (previous_partitions._name, i - 1),
97 | self.neutral_element,
98 | )
99 | dsk[(self._name, i)] = (
100 | cumulative_wrapper,
101 | self.aggregator,
102 | (self.frame._name, i),
103 | (intermediate_name, i),
104 | self.neutral_element,
105 | )
106 | return dsk
107 |
108 |
109 | def cumulative_wrapper(func, x, y, neutral_element):
110 | if isinstance(y, pd.Series) and len(y) == 0:
111 | y = neutral_element
112 | return func(x, y)
113 |
114 |
115 | def cumulative_wrapper_intermediate(func, x, y, neutral_element):
116 | if isinstance(y, pd.Series) and len(y) == 0:
117 | y = neutral_element
118 | return methods._cum_aggregate_apply(func, x, y)
119 |
120 |
121 | class CumSum(CumulativeAggregations):
122 | chunk_operation = M.cumsum
123 | aggregate_operation = staticmethod(methods.cumsum_aggregate)
124 | neutral_element = 0
125 |
126 |
127 | class CumProd(CumulativeAggregations):
128 | chunk_operation = M.cumprod
129 | aggregate_operation = staticmethod(methods.cumprod_aggregate)
130 | neutral_element = 1
131 |
132 |
133 | class CumMax(CumulativeAggregations):
134 | chunk_operation = M.cummax
135 | aggregate_operation = staticmethod(methods.cummax_aggregate)
136 | neutral_element = -math.inf
137 |
138 |
139 | class CumMin(CumulativeAggregations):
140 | chunk_operation = M.cummin
141 | aggregate_operation = staticmethod(methods.cummin_aggregate)
142 | neutral_element = math.inf
143 |
--------------------------------------------------------------------------------
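`CumulativeAggregations._lower()` splits a cumulative reduction into a per-partition blockwise pass (`CumulativeBlockwise`), the last value of each partial result (`TakeLast`), and a sequential combine (`CumulativeFinalize`) that folds the running totals of earlier partitions into each block. A plain-pandas sketch of that scheme for `cumsum` (neutral element 0, aggregate operation `+`):

```python
import pandas as pd

parts = [pd.Series([1, 2]), pd.Series([3, 4]), pd.Series([5])]

chunks = [p.cumsum() for p in parts]           # CumulativeBlockwise: chunk_operation
lasts = [c.tail(1).squeeze() for c in chunks]  # TakeLast: last partial total per partition

out, running = [], 0                           # CumulativeFinalize, conceptually
for chunk, last in zip(chunks, lasts):
    out.append(chunk + running)                # aggregate_operation for cumsum
    running += last

result = pd.concat(out)
assert result.tolist() == pd.concat(parts).cumsum().tolist()  # [1, 3, 6, 10, 15]
```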
/dask_expr/_datetime.py:
--------------------------------------------------------------------------------
1 | from dask_expr._accessor import Accessor
2 |
3 |
4 | class DatetimeAccessor(Accessor):
5 | """Accessor object for datetimelike properties of the Series values.
6 |
7 | Examples
8 | --------
9 |
10 | >>> s.dt.microsecond # doctest: +SKIP
11 | """
12 |
13 | _accessor_name = "dt"
14 |
15 | _accessor_methods = (
16 | "ceil",
17 | "day_name",
18 | "floor",
19 | "isocalendar",
20 | "month_name",
21 | "normalize",
22 | "round",
23 | "strftime",
24 | "to_period",
25 | "to_pydatetime",
26 | "to_pytimedelta",
27 | "to_timestamp",
28 | "total_seconds",
29 | "tz_convert",
30 | "tz_localize",
31 | )
32 |
33 | _accessor_properties = (
34 | "components",
35 | "date",
36 | "day",
37 | "day_of_week",
38 | "day_of_year",
39 | "dayofweek",
40 | "dayofyear",
41 | "days",
42 | "days_in_month",
43 | "daysinmonth",
44 | "end_time",
45 | "freq",
46 | "hour",
47 | "is_leap_year",
48 | "is_month_end",
49 | "is_month_start",
50 | "is_quarter_end",
51 | "is_quarter_start",
52 | "is_year_end",
53 | "is_year_start",
54 | "microsecond",
55 | "microseconds",
56 | "minute",
57 | "month",
58 | "nanosecond",
59 | "nanoseconds",
60 | "quarter",
61 | "qyear",
62 | "second",
63 | "seconds",
64 | "start_time",
65 | "time",
66 | "timetz",
67 | "tz",
68 | "week",
69 | "weekday",
70 | "weekofyear",
71 | "year",
72 | )
73 |
--------------------------------------------------------------------------------
/dask_expr/_describe.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | from dask.dataframe.dispatch import make_meta, meta_nonempty
5 | from dask.dataframe.methods import (
6 | describe_nonnumeric_aggregate,
7 | describe_numeric_aggregate,
8 | )
9 | from pandas.core.dtypes.common import is_datetime64_any_dtype, is_timedelta64_dtype
10 |
11 | from dask_expr._expr import Blockwise, DropnaSeries, Filter, Head, Sqrt, ToNumeric
12 | from dask_expr._quantile import SeriesQuantile
13 | from dask_expr._reductions import Reduction, Size, ValueCounts
14 |
15 |
16 | class DescribeNumeric(Reduction):
17 | _parameters = ["frame", "split_every", "percentiles", "percentile_method"]
18 | _defaults = {
19 | "percentiles": None,
20 | "split_every": None,
21 | "percentile_method": "default",
22 | }
23 |
24 | @functools.cached_property
25 | def _meta(self):
26 | return make_meta(meta_nonempty(self.frame._meta).describe())
27 |
28 | def _divisions(self):
29 | return (None, None)
30 |
31 | def _lower(self):
32 | frame = self.frame
33 | if self.percentiles is None:
34 | percentiles = self.percentiles or [0.25, 0.5, 0.75]
35 | else:
36 | percentiles = np.array(self.percentiles)
37 | percentiles = np.append(percentiles, 0.5)
38 | percentiles = np.unique(percentiles)
39 | percentiles = list(percentiles)
40 |
41 | is_td_col = is_timedelta64_dtype(frame._meta.dtype)
42 | is_dt_col = is_datetime64_any_dtype(frame._meta.dtype)
43 | if is_td_col or is_dt_col:
44 | frame = ToNumeric(DropnaSeries(frame))
45 |
46 | stats = [
47 | frame.count(split_every=self.split_every),
48 | frame.mean(split_every=self.split_every),
49 | Sqrt(frame.var(split_every=self.split_every)),
50 | frame.min(split_every=self.split_every),
51 | SeriesQuantile(frame, q=percentiles, method=self.percentile_method),
52 | frame.max(split_every=self.split_every),
53 | ]
54 | try:
55 | unit = getattr(self.frame._meta.array, "unit", None)
56 | except AttributeError:
57 | # cudf Series has no array attribute
58 | unit = None
59 | return DescribeNumericAggregate(
60 | self.frame._meta.name,
61 | is_td_col,
62 | is_dt_col,
63 | unit,
64 | *stats,
65 | )
66 |
67 |
68 | class DescribeNumericAggregate(Blockwise):
69 | _parameters = ["name", "is_timedelta_col", "is_datetime_col", "unit"]
70 | _defaults = {"is_timedelta_col": False, "is_datetime_col": False}
71 |
72 | def _broadcast_dep(self, dep):
73 | return dep.npartitions == 1
74 |
75 | @staticmethod
76 | def operation(name, is_timedelta_col, is_datetime_col, unit, *stats):
77 | return describe_numeric_aggregate(
78 | stats, name, is_timedelta_col, is_datetime_col, unit
79 | )
80 |
81 |
82 | class DescribeNonNumeric(DescribeNumeric):
83 | _parameters = ["frame", "split_every"]
84 |
85 | def _lower(self):
86 | frame = self.frame
87 | vcounts = ValueCounts(frame, split_every=self.split_every, sort=True)
88 | count_unique = Size(Filter(vcounts, vcounts > 0))
89 | stats = [
90 | count_unique,
91 | frame.count(split_every=self.split_every),
92 | Head(vcounts, n=1),
93 | ]
94 | return DescribeNonNumericAggregate(frame._meta.name, *stats)
95 |
96 |
97 | class DescribeNonNumericAggregate(DescribeNumericAggregate):
98 | _parameters = ["name"]
99 | _defaults = {}
100 |
101 | @staticmethod
102 | def operation(name, *stats):
103 | return describe_nonnumeric_aggregate(stats, name)
104 |
--------------------------------------------------------------------------------
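`DescribeNumeric._lower()` assembles the same statistics pandas reports (count, mean, std as `Sqrt` of the variance, min, the requested percentiles plus the always-appended median, max) and hands them to `describe_numeric_aggregate`. A hedged end-to-end sketch, assuming the collection-level `describe` forwards `percentiles` to this expression:

```python
import pandas as pd
import dask_expr as dx

pds = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0], name="x")
s = dx.from_pandas(pds, npartitions=2)

# 0.5 is appended to the requested percentiles and duplicates are removed.
print(s.describe(percentiles=[0.1, 0.9]).compute())
print(pds.describe(percentiles=[0.1, 0.9]))  # pandas reference output
```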
/dask_expr/_dispatch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dask.utils import Dispatch
4 |
5 | get_collection_type = Dispatch("get_collection_type")
6 |
--------------------------------------------------------------------------------
/dask_expr/_dummies.py:
--------------------------------------------------------------------------------
1 | import dask.dataframe.methods as methods
2 | import pandas as pd
3 | from dask.dataframe.utils import has_known_categories
4 | from dask.utils import get_meta_library
5 |
6 | from dask_expr._collection import DataFrame, Series, new_collection
7 | from dask_expr._expr import Blockwise
8 |
9 |
10 | def get_dummies(
11 | data,
12 | prefix=None,
13 | prefix_sep="_",
14 | dummy_na=False,
15 | columns=None,
16 | sparse=False,
17 | drop_first=False,
18 | dtype=bool,
19 | **kwargs,
20 | ):
21 | """
22 | Convert categorical variable into dummy/indicator variables.
23 |
24 | Data must have category dtype to infer result's ``columns``.
25 |
26 | Parameters
27 | ----------
28 | data : Series, or DataFrame
29 | For Series, the dtype must be categorical.
30 | For DataFrame, at least one column must be categorical.
31 | prefix : string, list of strings, or dict of strings, default None
32 | String to append DataFrame column names.
33 | Pass a list with length equal to the number of columns
34 | when calling get_dummies on a DataFrame. Alternatively, `prefix`
35 | can be a dictionary mapping column names to prefixes.
36 | prefix_sep : string, default '_'
37 | If appending prefix, separator/delimiter to use. Or pass a
38 | list or dictionary as with `prefix.`
39 | dummy_na : bool, default False
40 | Add a column to indicate NaNs, if False NaNs are ignored.
41 | columns : list-like, default None
42 | Column names in the DataFrame to be encoded.
43 | If `columns` is None then all the columns with
44 | `category` dtype will be converted.
45 | sparse : bool, default False
46 | Whether the dummy columns should be sparse or not. Returns
47 | SparseDataFrame if `data` is a Series or if all columns are included.
48 | Otherwise returns a DataFrame with some SparseBlocks.
49 |
50 | .. versionadded:: 0.18.2
51 |
52 | drop_first : bool, default False
53 | Whether to get k-1 dummies out of k categorical levels by removing the
54 | first level.
55 |
56 | dtype : dtype, default bool
57 | Data type for new columns. Only a single dtype is allowed.
58 |
59 | .. versionadded:: 0.18.2
60 |
61 | Returns
62 | -------
63 | dummies : DataFrame
64 |
65 | Examples
66 | --------
67 | Dask's version only works with Categorical data, as this is the only way to
68 | know the output shape without computing all the data.
69 |
70 | >>> import pandas as pd
71 | >>> import dask.dataframe as dd
72 | >>> s = dd.from_pandas(pd.Series(list('abca')), npartitions=2)
73 | >>> dd.get_dummies(s)
74 | Traceback (most recent call last):
75 | ...
76 | NotImplementedError: `get_dummies` with non-categorical dtypes is not supported...
77 |
78 | With categorical data:
79 |
80 | >>> s = dd.from_pandas(pd.Series(list('abca'), dtype='category'), npartitions=2)
81 | >>> dd.get_dummies(s) # doctest: +NORMALIZE_WHITESPACE
82 | Dask DataFrame Structure:
83 | a b c
84 | npartitions=2
85 | 0 bool bool bool
86 | 2 ... ... ...
87 | 3 ... ... ...
88 | Dask Name: get_dummies, 2 graph layers
89 | >>> dd.get_dummies(s).compute() # doctest: +ELLIPSIS
90 | a b c
91 | 0 True False False
92 | 1 False True False
93 | 2 False False True
94 | 3 True False False
95 |
96 | See Also
97 | --------
98 | pandas.get_dummies
99 | """
100 | if isinstance(data, (pd.Series, pd.DataFrame)):
101 | return pd.get_dummies(
102 | data,
103 | prefix=prefix,
104 | prefix_sep=prefix_sep,
105 | dummy_na=dummy_na,
106 | columns=columns,
107 | sparse=sparse,
108 | drop_first=drop_first,
109 | dtype=dtype,
110 | **kwargs,
111 | )
112 |
113 | not_cat_msg = (
114 | "`get_dummies` with non-categorical dtypes is not "
115 | "supported. Please use `df.categorize()` beforehand to "
116 | "convert to categorical dtype."
117 | )
118 |
119 | unknown_cat_msg = (
120 | "`get_dummies` with unknown categories is not "
121 | "supported. Please use `column.cat.as_known()` or "
122 | "`df.categorize()` beforehand to ensure known "
123 | "categories"
124 | )
125 |
126 | if isinstance(data, Series):
127 | if not methods.is_categorical_dtype(data):
128 | raise NotImplementedError(not_cat_msg)
129 | if not has_known_categories(data):
130 | raise NotImplementedError(unknown_cat_msg)
131 | elif isinstance(data, DataFrame):
132 | if columns is None:
133 | if (data.dtypes == "object").any():
134 | raise NotImplementedError(not_cat_msg)
135 | if (data.dtypes == "string").any():
136 | raise NotImplementedError(not_cat_msg)
137 | columns = data._meta.select_dtypes(include=["category"]).columns
138 | else:
139 | if not all(methods.is_categorical_dtype(data[c]) for c in columns):
140 | raise NotImplementedError(not_cat_msg)
141 |
142 | if not all(has_known_categories(data[c]) for c in columns):
143 | raise NotImplementedError(unknown_cat_msg)
144 |
145 | return new_collection(
146 | GetDummies(
147 | data, prefix, prefix_sep, dummy_na, columns, sparse, drop_first, dtype
148 | )
149 | )
150 |
151 |
152 | class GetDummies(Blockwise):
153 | _parameters = [
154 | "frame",
155 | "prefix",
156 | "prefix_sep",
157 | "dummy_na",
158 | "columns",
159 | "sparse",
160 | "drop_first",
161 | "dtype",
162 | ]
163 | _defaults = {
164 | "prefix": None,
165 | "prefix_sep": "_",
166 | "dummy_na": False,
167 | "columns": None,
168 | "sparse": False,
169 | "drop_first": False,
170 | "dtype": bool,
171 | }
172 | # cudf has extra kwargs after `columns`
173 | _keyword_only = ["sparse", "drop_first", "dtype"]
174 |
175 | @staticmethod
176 | def operation(df, *args, **kwargs):
177 | return get_meta_library(df).get_dummies(df, *args, **kwargs)
178 |
--------------------------------------------------------------------------------
/dask_expr/_interchange.py:
--------------------------------------------------------------------------------
1 | from dask.dataframe._compat import is_string_dtype
2 | from dask.dataframe.dispatch import is_categorical_dtype
3 | from pandas.core.interchange.dataframe_protocol import DtypeKind
4 |
5 | from dask_expr._collection import DataFrame
6 |
7 | _NP_KINDS = {
8 | "i": DtypeKind.INT,
9 | "u": DtypeKind.UINT,
10 | "f": DtypeKind.FLOAT,
11 | "b": DtypeKind.BOOL,
12 | "U": DtypeKind.STRING,
13 | "M": DtypeKind.DATETIME,
14 | "m": DtypeKind.DATETIME,
15 | }
16 |
17 |
18 | class DaskDataFrameInterchange:
19 | def __init__(
20 | self, df: DataFrame, nan_as_null: bool = False, allow_copy: bool = True
21 | ) -> None:
22 | self._df = df
23 | self._nan_as_null = nan_as_null
24 | self._allow_copy = allow_copy
25 |
26 | def get_columns(self):
27 | return [DaskColumn(self._df[name]) for name in self._df.columns]
28 |
29 | def column_names(self):
30 | return self._df.columns
31 |
32 | def num_columns(self) -> int:
33 | return len(self._df.columns)
34 |
35 | def num_rows(self) -> int:
36 | return len(self._df)
37 |
38 |
39 | class DaskColumn:
40 | def __init__(self, column, allow_copy: bool = True) -> None:
41 | self._col = column
42 | self._allow_copy = allow_copy
43 |
44 | def dtype(self) -> tuple[DtypeKind, None, None, None]:
45 | dtype = self._col.dtype
46 |
47 | if is_categorical_dtype(dtype):
48 | return (
49 | DtypeKind.CATEGORICAL,
50 | None,
51 | None,
52 | None,
53 | )
54 | elif is_string_dtype(dtype):
55 | return (
56 | DtypeKind.STRING,
57 | None,
58 | None,
59 | None,
60 | )
61 | else:
62 | return _NP_KINDS.get(dtype.kind, None), None, None, None
63 |
--------------------------------------------------------------------------------
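A hedged sketch of using these interchange-protocol objects directly (constructing `DaskDataFrameInterchange` by hand rather than going through a `__dataframe__` hook): wrap a collection, enumerate its columns, and map each dtype onto a `DtypeKind`.

```python
import pandas as pd
import dask_expr as dx
from dask_expr._interchange import DaskDataFrameInterchange

df = dx.from_pandas(
    pd.DataFrame({"i": [1, 2], "f": [0.5, 1.5], "c": pd.Categorical(["a", "b"])}),
    npartitions=1,
)

proto = DaskDataFrameInterchange(df)
print(proto.num_columns())         # 3
print(list(proto.column_names()))  # ['i', 'f', 'c']
for col in proto.get_columns():
    # int64 -> DtypeKind.INT, float64 -> DtypeKind.FLOAT, category -> DtypeKind.CATEGORICAL
    print(col.dtype()[0])
```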
/dask_expr/_quantile.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | from dask.dataframe.dispatch import make_meta, meta_nonempty
5 | from dask.utils import import_required, is_series_like
6 |
7 | from dask_expr._expr import DropnaSeries, Expr
8 |
9 |
10 | def _finalize_scalar_result(cons, *args, **kwargs):
11 | return cons(*args, **kwargs)[0]
12 |
13 |
14 | class SeriesQuantile(Expr):
15 | _parameters = ["frame", "q", "method"]
16 | _defaults = {"method": "default"}
17 |
18 | @functools.cached_property
19 | def q(self):
20 | q = np.array(self.operand("q"))
21 | if q.ndim > 0:
22 | assert len(q) > 0, f"must provide non-empty q={q}"
23 | q.sort(kind="mergesort")
24 | return q
25 | return np.asarray([self.operand("q")])
26 |
27 | @functools.cached_property
28 | def method(self):
29 | if self.operand("method") == "default":
30 | return "dask"
31 | else:
32 | return self.operand("method")
33 |
34 | @functools.cached_property
35 | def _meta(self):
36 | meta = self.frame._meta
37 | if not is_series_like(self.frame._meta):
38 | meta = meta.to_series()
39 | return make_meta(meta_nonempty(meta).quantile(self.operand("q")))
40 |
41 | def _divisions(self):
42 | if is_series_like(self._meta):
43 | return (np.min(self.q), np.max(self.q))
44 | return (None, None)
45 |
46 | @functools.cached_property
47 | def _constructor(self):
48 | meta = self.frame._meta
49 | if not is_series_like(self.frame._meta):
50 | meta = meta.to_series()
51 | return meta._constructor
52 |
53 | @functools.cached_property
54 | def _finalizer(self):
55 | if is_series_like(self._meta):
56 | return lambda tsk: (
57 | self._constructor,
58 | tsk,
59 | self.q,
60 | None,
61 | self.frame._meta.name,
62 | )
63 | else:
64 | return lambda tsk: (_finalize_scalar_result, self._constructor, tsk, [0])
65 |
66 | def _lower(self):
67 | frame = DropnaSeries(self.frame)
68 | if self.method == "tdigest":
69 | return SeriesQuantileTdigest(
70 | frame, self.operand("q"), self.operand("method")
71 | )
72 | else:
73 | return SeriesQuantileDask(frame, self.operand("q"), self.operand("method"))
74 |
75 |
76 | class SeriesQuantileTdigest(SeriesQuantile):
77 | @functools.cached_property
78 | def _meta(self):
79 | import_required(
80 | "crick", "crick is a required dependency for using the tdigest method."
81 | )
82 | return super()._meta
83 |
84 | def _layer(self) -> dict:
85 | from dask.array.percentile import _percentiles_from_tdigest, _tdigest_chunk
86 |
87 | dsk = {}
88 | for i in range(self.frame.npartitions):
89 | dsk[("chunk-" + self._name, i)] = (
90 | _tdigest_chunk,
91 | (getattr, (self.frame._name, i), "values"),
92 | )
93 |
94 | dsk[(self._name, 0)] = self._finalizer(
95 | (_percentiles_from_tdigest, self.q * 100, sorted(dsk))
96 | )
97 | return dsk
98 |
99 | def _lower(self):
100 | return None
101 |
102 |
103 | class SeriesQuantileDask(SeriesQuantile):
104 | def _layer(self) -> dict:
105 | from dask.array.dispatch import percentile_lookup as _percentile
106 | from dask.array.percentile import merge_percentiles
107 |
108 | dsk = {}
109 | # Add 0 and 100 during calculation for more robust behavior (hopefully)
110 | calc_qs = np.pad(self.q * 100, 1, mode="constant")
111 | calc_qs[-1] = 100
112 |
113 | for i in range(self.frame.npartitions):
114 | dsk[("chunk-" + self._name, i)] = (
115 | _percentile,
116 | (self.frame._name, i),
117 | calc_qs,
118 | )
119 | dsk[(self._name, 0)] = self._finalizer(
120 | (
121 | merge_percentiles,
122 | self.q * 100,
123 | [calc_qs] * self.frame.npartitions,
124 | sorted(dsk),
125 | "lower",
126 | None,
127 | False,
128 | )
129 | )
130 | return dsk
131 |
132 | def _lower(self):
133 | return None
134 |
--------------------------------------------------------------------------------
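`SeriesQuantile` normalizes `q` to a sorted array, lowers to the approximate `"dask"` percentile path by default, and uses `SeriesQuantileTdigest` only when `method="tdigest"` (which needs the optional `crick` package). A hedged usage sketch, assuming the collection-level `quantile` forwards `q` and `method` to this expression:

```python
import pandas as pd
import dask_expr as dx

s = dx.from_pandas(pd.Series(range(101), dtype="float64"), npartitions=4)

print(s.quantile(0.5).compute())           # scalar result: q was a single float
print(s.quantile([0.25, 0.75]).compute())  # Series indexed by q, divisions (0.25, 0.75)

# Requires the optional `crick` dependency:
# s.quantile([0.25, 0.75], method="tdigest").compute()
```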
/dask_expr/_quantiles.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | import toolz
5 | from dask.dataframe.partitionquantiles import (
6 | create_merge_tree,
7 | dtype_info,
8 | merge_and_compress_summaries,
9 | percentiles_summary,
10 | process_val_weights,
11 | )
12 | from dask.tokenize import tokenize
13 | from dask.utils import random_state_data
14 |
15 | from dask_expr._expr import Expr
16 |
17 |
18 | class RepartitionQuantiles(Expr):
19 | _parameters = ["frame", "input_npartitions", "upsample", "random_state"]
20 | _defaults = {"upsample": 1.0, "random_state": None}
21 |
22 | @functools.cached_property
23 | def _meta(self):
24 | return self.frame._meta
25 |
26 | @property
27 | def npartitions(self):
28 | return 1
29 |
30 | def _divisions(self):
31 | return 0.0, 1.0
32 |
33 | def __dask_postcompute__(self):
34 | return toolz.first, ()
35 |
36 | def _layer(self):
37 | import pandas as pd
38 |
39 | qs = np.linspace(0, 1, self.input_npartitions + 1)
40 | if self.random_state is None:
41 | random_state = int(tokenize(self.operands), 16) % np.iinfo(np.int32).max
42 | else:
43 | random_state = self.random_state
44 | state_data = random_state_data(self.frame.npartitions, random_state)
45 |
46 | keys = self.frame.__dask_keys__()
47 | dtype_dsk = {(self._name, 0, 0): (dtype_info, keys[0])}
48 |
49 | percentiles_dsk = {
50 | (self._name, 1, i): (
51 | percentiles_summary,
52 | key,
53 | self.frame.npartitions,
54 | self.input_npartitions,
55 | self.upsample,
56 | state,
57 | )
58 | for i, (state, key) in enumerate(zip(state_data, keys))
59 | }
60 |
61 | merge_dsk = create_merge_tree(
62 | merge_and_compress_summaries, sorted(percentiles_dsk), self._name, 2
63 | )
64 | if not merge_dsk:
65 | # Compress the data even if we only have one partition
66 | merge_dsk = {
67 | (self._name, 2, 0): (
68 | merge_and_compress_summaries,
69 | [list(percentiles_dsk)[0]],
70 | )
71 | }
72 |
73 | merged_key = max(merge_dsk)
74 | last_dsk = {
75 | (self._name, 0): (
76 | pd.Series,
77 | (
78 | process_val_weights,
79 | merged_key,
80 | self.input_npartitions,
81 | (self._name, 0, 0),
82 | ),
83 | qs,
84 | None,
85 | self.frame._meta.name,
86 | )
87 | }
88 | return {**dtype_dsk, **percentiles_dsk, **merge_dsk, **last_dsk}
89 |
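# RepartitionQuantiles is an internal expression: it summarizes per-partition
# quantiles into a single small Series of candidate divisions. A hedged sketch
# of how it is exercised indirectly, assuming set_index computes balanced
# divisions from the data as in dask.dataframe (inspecting .divisions may
# trigger that small quantile computation eagerly).
import pandas as pd

from dask_expr import from_pandas

df = from_pandas(pd.DataFrame({"x": range(100), "y": range(100)}), npartitions=4)
result = df.set_index("x")  # approximate quantiles of "x" become the divisions
print(result.divisions)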
--------------------------------------------------------------------------------
/dask_expr/_str_accessor.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | from dask.dataframe.dispatch import make_meta, meta_nonempty
4 |
5 | from dask_expr._accessor import Accessor, FunctionMap
6 | from dask_expr._expr import Blockwise
7 | from dask_expr._reductions import Reduction
8 |
9 |
10 | class StringAccessor(Accessor):
11 | """Accessor object for string properties of the Series values.
12 |
13 | Examples
14 | --------
15 |
16 | >>> s.str.lower() # doctest: +SKIP
17 | """
18 |
19 | _accessor_name = "str"
20 |
21 | _accessor_methods = (
22 | "capitalize",
23 | "casefold",
24 | "center",
25 | "contains",
26 | "count",
27 | "decode",
28 | "encode",
29 | "endswith",
30 | "extract",
31 | "extractall",
32 | "find",
33 | "findall",
34 | "fullmatch",
35 | "get",
36 | "index",
37 | "isalnum",
38 | "isalpha",
39 | "isdecimal",
40 | "isdigit",
41 | "islower",
42 | "isnumeric",
43 | "isspace",
44 | "istitle",
45 | "isupper",
46 | "join",
47 | "len",
48 | "ljust",
49 | "lower",
50 | "lstrip",
51 | "match",
52 | "normalize",
53 | "pad",
54 | "partition",
55 | "removeprefix",
56 | "removesuffix",
57 | "repeat",
58 | "replace",
59 | "rfind",
60 | "rindex",
61 | "rjust",
62 | "rpartition",
63 | "rstrip",
64 | "slice",
65 | "slice_replace",
66 | "startswith",
67 | "strip",
68 | "swapcase",
69 | "title",
70 | "translate",
71 | "upper",
72 | "wrap",
73 | "zfill",
74 | )
75 | _accessor_properties = ()
76 |
77 | def _split(self, method, pat=None, n=-1, expand=False):
78 | from dask_expr import new_collection
79 |
80 | if expand:
81 | if n == -1:
82 | raise NotImplementedError(
83 | "To use the expand parameter you must specify the number of "
84 | "expected splits with the n= parameter. Usually n splits "
85 | "result in n+1 output columns."
86 | )
87 | return new_collection(
88 | SplitMap(
89 | self._series,
90 | self._accessor_name,
91 | method,
92 | (),
93 | {"pat": pat, "n": n, "expand": expand},
94 | )
95 | )
96 | return self._function_map(method, pat=pat, n=n, expand=expand)
97 |
98 | def split(self, pat=None, n=-1, expand=False):
99 | """Known inconsistencies: ``expand=True`` with unknown ``n`` will raise a ``NotImplementedError``."""
100 | return self._split("split", pat=pat, n=n, expand=expand)
101 |
102 | def rsplit(self, pat=None, n=-1, expand=False):
103 | return self._split("rsplit", pat=pat, n=n, expand=expand)
104 |
105 | def cat(self, others=None, sep=None, na_rep=None):
106 | import pandas as pd
107 |
108 | from dask_expr._collection import Index, Series, new_collection
109 |
110 | if others is None:
111 | return new_collection(Cat(self._series, sep, na_rep))
112 |
113 | valid_types = (Series, Index, pd.Series, pd.Index)
114 | if isinstance(others, valid_types):
115 | others = [others]
116 | elif not all(isinstance(a, valid_types) for a in others):
117 | raise TypeError("others must be Series/Index")
118 |
119 | return new_collection(CatBlockwise(self._series, sep, na_rep, *others))
120 |
121 | def __getitem__(self, index):
122 | return self._function_map("__getitem__", index)
123 |
124 |
125 | class CatBlockwise(Blockwise):
126 | _parameters = ["frame", "sep", "na_rep"]
127 | _keyword_only = ["sep", "na_rep"]
128 |
129 | @property
130 | def _args(self) -> list:
131 | return [self.frame] + self.operands[len(self._parameters) :]
132 |
133 | @staticmethod
134 | def operation(ser, *args, **kwargs):
135 | return ser.str.cat(list(args), **kwargs)
136 |
137 |
138 | class Cat(Reduction):
139 | _parameters = ["frame", "sep", "na_rep"]
140 |
141 | @property
142 | def chunk_kwargs(self):
143 | return {"sep": self.sep, "na_rep": self.na_rep}
144 |
145 | @property
146 | def combine_kwargs(self):
147 | return self.chunk_kwargs
148 |
149 | @property
150 | def aggregate_kwargs(self):
151 | return self.chunk_kwargs
152 |
153 | @staticmethod
154 | def reduction_chunk(ser, *args, **kwargs):
155 | return ser.str.cat(*args, **kwargs)
156 |
157 | @staticmethod
158 | def reduction_combine(ser, *args, **kwargs):
159 | return Cat.reduction_chunk(ser, *args, **kwargs)
160 |
161 | @staticmethod
162 | def reduction_aggregate(ser, *args, **kwargs):
163 | return Cat.reduction_chunk(ser, *args, **kwargs)
164 |
165 |
166 | class SplitMap(FunctionMap):
167 | _parameters = ["frame", "accessor", "attr", "args", "kwargs"]
168 |
169 | @property
170 | def n(self):
171 | return self.kwargs["n"]
172 |
173 | @property
174 | def pat(self):
175 | return self.kwargs["pat"]
176 |
177 | @property
178 | def expand(self):
179 | return self.kwargs["expand"]
180 |
181 | @functools.cached_property
182 | def _meta(self):
183 | delimiter = " " if self.pat is None else self.pat
184 | meta = meta_nonempty(self.frame._meta)
185 | meta = self.frame._meta._constructor(
186 | [delimiter.join(["a"] * (self.n + 1))],
187 | index=meta.iloc[:1].index,
188 | )
189 | return make_meta(
190 | getattr(meta.str, self.attr)(n=self.n, expand=self.expand, pat=self.pat)
191 | )
192 |
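# A short usage sketch of the string accessor above, in particular the
# split/expand restriction handled in _split (n= must be given when expand=True).
# from_pandas is assumed to be the usual top-level constructor.
import pandas as pd

from dask_expr import from_pandas

s = from_pandas(pd.Series(["a b c", "d e f"], name="words"), npartitions=2)
print(s.str.split(" ").compute())                     # Series of lists
print(s.str.split(" ", n=2, expand=True).compute())   # DataFrame with 3 columns
# s.str.split(" ", expand=True) would raise NotImplementedError (n is unknown)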
--------------------------------------------------------------------------------
/dask_expr/_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from collections import OrderedDict, UserDict
5 | from collections.abc import Hashable, Iterable, Sequence
6 | from typing import Any, Literal, TypeVar, cast
7 |
8 | import dask
9 | import numpy as np
10 | import pandas as pd
11 | from dask import config
12 | from dask.dataframe._compat import is_string_dtype
13 | from dask.dataframe.core import is_dask_collection, is_dataframe_like, is_series_like
14 | from dask.tokenize import normalize_token, tokenize
15 | from dask.utils import get_default_shuffle_method
16 | from packaging.version import Version
17 |
18 | K = TypeVar("K", bound=Hashable)
19 | V = TypeVar("V")
20 |
21 | DASK_VERSION = Version(dask.__version__)
22 | DASK_GT_20231201 = DASK_VERSION > Version("2023.12.1")
23 | PANDAS_VERSION = Version(pd.__version__)
24 | PANDAS_GE_300 = PANDAS_VERSION.major >= 3
25 |
26 |
27 | def _calc_maybe_new_divisions(df, periods, freq):
28 | """Maybe calculate new divisions by periods of size freq
29 |
30 |     Used to shift the divisions for the `shift` method. If freq is a fixed
31 |     size (i.e. not anchored or relative), the divisions are shifted
32 |     appropriately; otherwise they cannot be inferred.
33 |
34 |     Returning None indicates the divisions ought to be cleared.
35 |
36 | Parameters
37 | ----------
38 | df : dd.DataFrame, dd.Series, or dd.Index
39 | periods : int
40 | The number of periods to shift.
41 | freq : DateOffset, timedelta, or time rule string
42 | The frequency to shift by.
43 | """
44 | if isinstance(freq, str):
45 | freq = pd.tseries.frequencies.to_offset(freq)
46 |
47 | is_offset = isinstance(freq, pd.DateOffset)
48 | if is_offset:
49 | if not isinstance(freq, pd.offsets.Tick):
50 | # Can't infer divisions on relative or anchored offsets, as
51 | # divisions may now split identical index value.
52 | # (e.g. index_partitions = [[1, 2, 3], [3, 4, 5]])
53 | return None # Would need to clear divisions
54 | if df.known_divisions:
55 | divs = pd.Series(range(len(df.divisions)), index=df.divisions)
56 | divisions = divs.shift(periods, freq=freq).index
57 | return tuple(divisions)
58 | return df.divisions
59 |
60 |
61 | def _validate_axis(axis=0, none_is_zero: bool = True) -> None | Literal[0, 1]:
62 | if axis not in (0, 1, "index", "columns", None):
63 | raise ValueError(f"No axis named {axis}")
64 | # convert to numeric axis
65 | numeric_axis: dict[str | None, Literal[0, 1]] = {"index": 0, "columns": 1}
66 | if none_is_zero:
67 | numeric_axis[None] = 0
68 |
69 | return numeric_axis.get(axis, axis)
70 |
71 |
72 | def _convert_to_list(column) -> list | None:
73 | if column is None or isinstance(column, list):
74 | pass
75 | elif isinstance(column, tuple):
76 | column = list(column)
77 | elif hasattr(column, "dtype"):
78 | column = column.tolist()
79 | else:
80 | column = [column]
81 | return column
82 |
83 |
84 | def is_scalar(x):
85 | # np.isscalar does not work for some pandas scalars, for example pd.NA
86 | if isinstance(x, (Sequence, Iterable)) and not isinstance(x, str):
87 | return False
88 | elif hasattr(x, "dtype"):
89 | return isinstance(x, np.ScalarType)
90 | if isinstance(x, dict):
91 | return False
92 | if isinstance(x, (str, int)) or x is None:
93 | return True
94 |
95 | from dask_expr._expr import Expr
96 |
97 | return not isinstance(x, Expr)
98 |
99 |
100 | def _tokenize_deterministic(*args, **kwargs) -> str:
101 | # Utility to be strict about deterministic tokens
102 | return tokenize(*args, ensure_deterministic=True, **kwargs)
103 |
104 |
105 | def _tokenize_partial(expr, ignore: list | None = None) -> str:
106 | # Helper function to "tokenize" the operands
107 | # that are not in the `ignore` list
108 | ignore = ignore or []
109 | return _tokenize_deterministic(
110 | *[
111 | op
112 | for i, op in enumerate(expr.operands)
113 | if i >= len(expr._parameters) or expr._parameters[i] not in ignore
114 | ]
115 | )
116 |
117 |
118 | class LRU(UserDict[K, V]):
119 | """Limited size mapping, evicting the least recently looked-up key when full"""
120 |
121 | def __init__(self, maxsize: float) -> None:
122 | super().__init__()
123 | self.data = OrderedDict()
124 | self.maxsize = maxsize
125 |
126 | def __getitem__(self, key: K) -> V:
127 | value = super().__getitem__(key)
128 | cast(OrderedDict, self.data).move_to_end(key)
129 | return value
130 |
131 | def __setitem__(self, key: K, value: V) -> None:
132 | if len(self) >= self.maxsize:
133 | cast(OrderedDict, self.data).popitem(last=False)
134 | super().__setitem__(key, value)
135 |
136 |
137 | class _BackendData:
138 | """Helper class to wrap backend data
139 |
140 | The primary purpose of this class is to provide
141 | caching outside the ``FromPandas`` class.
142 | """
143 |
144 | def __init__(self, data):
145 | self._data = data
146 | self._division_info = LRU(10)
147 |
148 | @functools.cached_property
149 | def _token(self):
150 | from dask_expr._util import _tokenize_deterministic
151 |
152 | return _tokenize_deterministic(self._data)
153 |
154 | def __len__(self):
155 | return len(self._data)
156 |
157 | def __getattr__(self, key: str) -> Any:
158 | try:
159 | return object.__getattribute__(self, key)
160 | except AttributeError:
161 | # Return the underlying backend attribute
162 | return getattr(self._data, key)
163 |
164 | def __reduce__(self):
165 | return type(self), (self._data,)
166 |
167 | def __deepcopy__(self, memodict=None):
168 | return type(self)(self._data.copy())
169 |
170 |
171 | @normalize_token.register(_BackendData)
172 | def normalize_data_wrapper(data):
173 | return data._token
174 |
175 |
176 | def _maybe_from_pandas(dfs):
177 | from dask_expr import from_pandas
178 |
179 | def _pd_series_or_dataframe(x):
180 | # `x` can be a cudf Series/DataFrame
181 | return not is_dask_collection(x) and (is_series_like(x) or is_dataframe_like(x))
182 |
183 | dfs = [from_pandas(df, 1) if _pd_series_or_dataframe(df) else df for df in dfs]
184 | return dfs
185 |
186 |
187 | def _get_shuffle_preferring_order(shuffle):
188 | if shuffle is not None:
189 | return shuffle
190 |
191 | # Choose tasks over disk since it keeps the order
192 | shuffle = get_default_shuffle_method()
193 | if shuffle == "disk":
194 | return "tasks"
195 |
196 | return shuffle
197 |
198 |
199 | def _raise_if_object_series(x, funcname):
200 | """
201 | Utility function to raise an error if an object column does not support
202 | a certain operation like `mean`.
203 | """
204 | if x.ndim == 1 and hasattr(x, "dtype"):
205 | if x.dtype == object:
206 | raise ValueError("`%s` not supported with object series" % funcname)
207 | elif is_string_dtype(x):
208 | raise ValueError("`%s` not supported with string series" % funcname)
209 |
210 |
211 | def _is_any_real_numeric_dtype(arr_or_dtype):
212 | try:
213 | from pandas.api.types import is_any_real_numeric_dtype
214 |
215 | return is_any_real_numeric_dtype(arr_or_dtype)
216 | except ImportError:
217 | # Temporary/soft pandas<2 support to enable cudf dev
218 | # TODO: Remove `try` block after 4/2024
219 | from pandas.api.types import is_bool_dtype, is_complex_dtype, is_numeric_dtype
220 |
221 | return (
222 | is_numeric_dtype(arr_or_dtype)
223 | and not is_complex_dtype(arr_or_dtype)
224 | and not is_bool_dtype(arr_or_dtype)
225 | )
226 |
227 |
228 | def get_specified_shuffle(shuffle_method):
229 | # Take the config shuffle if given, otherwise defer evaluation until optimize
230 | return shuffle_method or config.get("dataframe.shuffle.method", None)
231 |
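# A small, self-contained example of the LRU helper defined above: a lookup
# moves the key to the most-recently-used position, and inserting beyond
# maxsize evicts the least recently used key.
from dask_expr._util import LRU

cache = LRU(maxsize=2)
cache["a"] = 1
cache["b"] = 2
_ = cache["a"]      # touching "a" makes "b" the least recently used entry
cache["c"] = 3      # evicts "b"
print(list(cache))  # ['a', 'c']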
--------------------------------------------------------------------------------
/dask_expr/array/__init__.py:
--------------------------------------------------------------------------------
1 | # isort: skip_file
2 |
3 | from dask_expr.array import random
4 | from dask_expr.array.core import Array, asarray, from_array
5 | from dask_expr.array.reductions import (
6 | mean,
7 | moment,
8 | nanmean,
9 | nanstd,
10 | nansum,
11 | nanvar,
12 | prod,
13 | std,
14 | sum,
15 | var,
16 | )
17 | from dask_expr.array._creation import arange, linspace, ones, empty, zeros
18 |
--------------------------------------------------------------------------------
/dask_expr/array/slicing.py:
--------------------------------------------------------------------------------
1 | import toolz
2 | from dask.array.optimization import fuse_slice
3 | from dask.array.slicing import normalize_slice, slice_array
4 | from dask.array.utils import meta_from_array
5 | from dask.utils import cached_property
6 |
7 | from dask_expr.array.core import Array
8 |
9 |
10 | class Slice(Array):
11 | _parameters = ["array", "index"]
12 |
13 | @property
14 | def _meta(self):
15 | return meta_from_array(self.array._meta, ndim=len(self.chunks))
16 |
17 | @cached_property
18 | def _info(self):
19 | return slice_array(
20 | self._name,
21 | self.array._name,
22 | self.array.chunks,
23 | self.index,
24 | self.array.dtype.itemsize,
25 | )
26 |
27 | def _layer(self):
28 | return self._info[0]
29 |
30 | @property
31 | def chunks(self):
32 | return self._info[1]
33 |
34 | def _simplify_down(self):
35 | if all(
36 | isinstance(idx, slice) and idx == slice(None, None, None)
37 | for idx in self.index
38 | ):
39 | return self.array
40 | if isinstance(self.array, Slice):
41 | return Slice(
42 | self.array.array,
43 | normalize_slice(
44 | fuse_slice(self.array.index, self.index), self.array.array.ndim
45 | ),
46 | )
47 |
48 | if isinstance(self.array, Elemwise):
49 | index = self.index + (slice(None),) * (self.ndim - len(self.index))
50 | args = []
51 | for arg, ind in toolz.partition(2, self.array.args):
52 | if ind is None:
53 | args.append(arg)
54 | else:
55 | idx = tuple(index[self.array.out_ind.index(i)] for i in ind)
56 | args.append(arg[idx])
57 | return Elemwise(*self.array.operands[: -len(args)], *args)
58 |
59 | if isinstance(self.array, Transpose):
60 | if any(isinstance(idx, (int)) or idx is None for idx in self.index):
61 | return None # can't handle changes in dimension
62 | else:
63 | index = self.index + (slice(None),) * (self.ndim - len(self.index))
64 | new = tuple(index[i] for i in self.array.axes)
65 | return self.array.substitute(self.array.array, self.array.array[new])
66 |
67 |
68 | from dask_expr.array.blockwise import Elemwise, Transpose
69 |
--------------------------------------------------------------------------------
/dask_expr/array/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/array/tests/__init__.py
--------------------------------------------------------------------------------
/dask_expr/array/tests/test_array.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import numpy as np
4 | import pytest
5 | from dask.array.utils import assert_eq
6 |
7 | import dask_expr.array as da
8 |
9 |
10 | def test_basic():
11 | x = np.random.random((10, 10))
12 | xx = da.from_array(x, chunks=(4, 4))
13 | xx._meta
14 | xx.chunks
15 | repr(xx)
16 |
17 | assert_eq(x, xx)
18 |
19 |
20 | def test_rechunk():
21 | a = np.random.random((10, 10))
22 | b = da.from_array(a, chunks=(4, 4))
23 | c = b.rechunk()
24 | assert c.npartitions == 1
25 | assert_eq(b, c)
26 |
27 | d = b.rechunk((3, 3))
28 | assert d.npartitions == 16
29 | assert_eq(d, a)
30 |
31 |
32 | def test_rechunk_optimize():
33 | a = np.random.random((10, 10))
34 | b = da.from_array(a, chunks=(4, 4))
35 |
36 | c = b.rechunk((2, 5)).rechunk((5, 2))
37 | d = b.rechunk((5, 2))
38 |
39 | assert c.optimize()._name == d.optimize()._name
40 |
41 | assert (
42 | b.T.rechunk((5, 2)).optimize()._name == da.from_array(a, chunks=(2, 5)).T._name
43 | )
44 |
45 |
46 | def test_rechunk_blockwise_optimize():
47 | a = np.random.random((10, 10))
48 | b = da.from_array(a, chunks=(4, 4))
49 |
50 | result = (da.from_array(a, chunks=(4, 4)) + 1).rechunk((5, 5))
51 | expected = da.from_array(a, chunks=(5, 5)) + 1
52 | assert result.optimize()._name == expected.optimize()._name
53 |
54 | a = np.random.random((10,))
55 | aa = da.from_array(a)
56 | b = np.random.random((10, 10))
57 | bb = da.from_array(b)
58 |
59 | c = (aa + bb).rechunk((5, 2))
60 | result = c.optimize()
61 | expected = da.from_array(a, chunks=(2,)) + da.from_array(b, chunks=(5, 2))
62 | assert result._name == expected._name
63 |
64 | a = np.random.random((10, 1))
65 | aa = da.from_array(a)
66 | b = np.random.random((10, 10))
67 | bb = da.from_array(b)
68 |
69 | c = (aa + bb).rechunk((5, 2))
70 | result = c.optimize()
71 |
72 | expected = da.from_array(a, chunks=(5, 1)) + da.from_array(b, chunks=(5, 2))
73 | assert result._name == expected._name
74 |
75 |
76 | def test_elemwise():
77 | a = np.random.random((10, 10))
78 | b = da.from_array(a, chunks=(4, 4))
79 |
80 | (b + 1).compute()
81 | assert_eq(a + 1, b + 1)
82 | assert_eq(a + 2 * a, b + 2 * b)
83 |
84 | x = np.random.random(10)
85 | y = da.from_array(x, chunks=(4,))
86 |
87 | assert_eq(a + x, b + y)
88 |
89 |
90 | def test_transpose():
91 | a = np.random.random((10, 20))
92 | b = da.from_array(a, chunks=(2, 5))
93 |
94 | assert_eq(a.T, b.T)
95 |
96 | a = np.random.random((10, 1))
97 | b = da.from_array(a, chunks=(5, 1))
98 | assert_eq(a.T + a, b.T + b)
99 | assert_eq(a + a.T, b + b.T)
100 |
101 | assert b.T.T.optimize()._name == b.optimize()._name
102 |
103 |
104 | def test_slicing():
105 | a = np.random.random((10, 20))
106 | b = da.from_array(a, chunks=(2, 5))
107 |
108 | assert_eq(a[:], b[:])
109 | assert_eq(a[::2], b[::2])
110 | assert_eq(a[1, :5], b[1, :5])
111 | assert_eq(a[None, ..., ::5], b[None, ..., ::5])
112 | assert_eq(a[3], b[3])
113 |
114 |
115 | def test_slicing_optimization():
116 | a = np.random.random((10, 20))
117 | b = da.from_array(a, chunks=(2, 5))
118 |
119 | assert b[:].optimize()._name == b._name
120 | assert b[5:, 4][::2].optimize()._name == b[5::2, 4].optimize()._name
121 |
122 | assert (b + 1)[:5].optimize()._name == (b[:5] + 1)._name
123 | assert (b + 1)[5].optimize()._name == (b[5] + 1)._name
124 | assert b.T[5:].optimize()._name == b[:, 5:].T._name
125 |
126 |
127 | def test_slicing_optimization_change_dimensionality():
128 | a = np.random.random((10, 20))
129 | b = da.from_array(a, chunks=(2, 5))
130 | assert (b + 1)[5].optimize()._name == (b[5] + 1)._name
131 |
132 |
133 | def test_xarray():
134 | pytest.importorskip("xarray")
135 |
136 | import xarray as xr
137 |
138 | a = np.random.random((10, 20))
139 | b = da.from_array(a)
140 |
141 | x = (xr.DataArray(b, dims=["x", "y"]) + 1).chunk(x=2)
142 |
143 | assert x.data.optimize()._name == (da.from_array(a, chunks={0: 2}) + 1)._name
144 |
145 |
146 | def test_random():
147 | x = da.random.random((100, 100), chunks=(50, 50))
148 | assert_eq(x, x)
149 |
150 |
151 | @pytest.mark.parametrize(
152 | "reduction",
153 | ["sum", "mean", "var", "std", "any", "all", "prod", "min", "max"],
154 | )
155 | def test_reductions(reduction):
156 | a = np.random.random((10, 20))
157 | b = da.from_array(a, chunks=(2, 5))
158 |
159 | def func(x, **kwargs):
160 | return getattr(x, reduction)(**kwargs)
161 |
162 | assert_eq(func(a), func(b))
163 | assert_eq(func(a, axis=1), func(b, axis=1))
164 |
165 |
166 | @pytest.mark.parametrize(
167 | "reduction",
168 | ["nanmean", "nansum"],
169 | )
170 | def test_reduction_functions(reduction):
171 | a = np.random.random((10, 20))
172 | b = da.from_array(a, chunks=(2, 5))
173 |
174 | def func(x, **kwargs):
175 | if isinstance(x, np.ndarray):
176 | return getattr(np, reduction)(x, **kwargs)
177 | else:
178 | return getattr(da, reduction)(x, **kwargs)
179 |
180 | func(b).chunks
181 |
182 | assert_eq(func(a), func(b))
183 | assert_eq(func(a, axis=1), func(b, axis=1))
184 |
185 |
186 | @pytest.mark.parametrize(
187 | "ufunc",
188 | [np.sqrt, np.sin, np.exp],
189 | )
190 | def test_ufunc(ufunc):
191 | a = np.random.random((10, 20))
192 | b = da.from_array(a, chunks=(2, 5))
193 | assert_eq(ufunc(a), ufunc(b))
194 |
195 |
196 | @pytest.mark.parametrize(
197 | "op",
198 | [operator.add, operator.sub, operator.pow, operator.floordiv, operator.truediv],
199 | )
200 | def test_binop(op):
201 | a = np.random.random((10, 20))
202 | b = np.random.random(20)
203 | aa = da.from_array(a, chunks=(2, 5))
204 | bb = da.from_array(b, chunks=5)
205 |
206 | assert_eq(op(a, b), op(aa, bb))
207 |
208 |
209 | def test_asarray():
210 | a = np.random.random((10, 20))
211 | b = da.asarray(a)
212 | assert_eq(a, b)
213 | assert isinstance(b, da.Array) and type(b) == type(da.from_array(a))
214 |
215 |
216 | def test_unify_chunks():
217 | a = np.random.random((10, 20))
218 | aa = da.asarray(a, chunks=(4, 5))
219 | b = np.random.random((20,))
220 | bb = da.asarray(b, chunks=(10,))
221 |
222 | assert_eq(a + b, aa + bb)
223 |
224 |
225 | def test_array_function():
226 | a = np.random.random((10, 20))
227 | aa = da.asarray(a, chunks=(4, 5))
228 |
229 | assert isinstance(np.nanmean(aa), da.Array)
230 |
231 | assert_eq(np.nanmean(aa), np.nanmean(a))
232 |
--------------------------------------------------------------------------------
/dask_expr/array/tests/test_creation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from dask.array import assert_eq
4 |
5 | import dask_expr.array as da
6 |
7 |
8 | def test_arange():
9 | assert_eq(da.arange(1, 100, 7), np.arange(1, 100, 7))
10 | assert_eq(da.arange(100), np.arange(100))
11 | assert_eq(da.arange(100, like=np.arange(100)), np.arange(100))
12 |
13 |
14 | def test_linspace():
15 | assert_eq(da.linspace(1, 100, 30), np.linspace(1, 100, 30))
16 |
17 |
18 | @pytest.mark.parametrize("func", ["ones", "zeros"])
19 | def test_ones(func):
20 | assert_eq(getattr(da, func)((10, 20, 15)), getattr(np, func)((10, 20, 15)))
21 | assert_eq(
22 | getattr(da, func)((10, 20, 15), dtype="i8"),
23 | getattr(np, func)((10, 20, 15), dtype="i8"),
24 | )
25 |
--------------------------------------------------------------------------------
/dask_expr/datasets.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import operator
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from dask._task_spec import Task
7 | from dask.dataframe.utils import pyarrow_strings_enabled
8 | from dask.typing import Key
9 |
10 | from dask_expr._collection import new_collection
11 | from dask_expr._expr import ArrowStringConversion
12 | from dask_expr.io import BlockwiseIO, PartitionsFiltered
13 |
14 | __all__ = ["timeseries"]
15 |
16 |
17 | class Timeseries(PartitionsFiltered, BlockwiseIO):
18 | _absorb_projections = True
19 |
20 | _parameters = [
21 | "start",
22 | "end",
23 | "dtypes",
24 | "freq",
25 | "partition_freq",
26 | "seed",
27 | "kwargs",
28 | "columns",
29 | "_partitions",
30 | "_series",
31 | ]
32 | _defaults = {
33 | "start": "2000-01-01",
34 | "end": "2000-01-31",
35 | "dtypes": {"name": "string", "id": int, "x": float, "y": float},
36 | "freq": "1s",
37 | "partition_freq": "1d",
38 | "seed": None,
39 | "kwargs": {},
40 | "_partitions": None,
41 | "_series": False,
42 | }
43 |
44 | @functools.cached_property
45 | def _meta(self):
46 | result = self._make_timeseries_part("2000", "2000", 0).iloc[:0]
47 | if self._series:
48 | return result[result.columns[0]]
49 | return result
50 |
51 | def _divisions(self):
52 | return pd.date_range(start=self.start, end=self.end, freq=self.partition_freq)
53 |
54 | @property
55 | def _dtypes(self):
56 | dtypes = self.operand("dtypes")
57 | return {col: dtypes[col] for col in self.operand("columns")}
58 |
59 | @functools.cached_property
60 | def random_state(self):
61 | npartitions = len(self._divisions()) - 1
62 | ndtypes = max(len(self.operand("dtypes")), 1)
63 | random_state = np.random.RandomState(self.seed)
64 | n = npartitions * ndtypes
65 | random_data = random_state.bytes(n * 4) # `n` 32-bit integers
66 | l = list(np.frombuffer(random_data, dtype=np.uint32).reshape((n,)))
67 | assert len(l) == n
68 | return l
69 |
70 | @functools.cached_property
71 | def _make_timeseries_part(self):
72 | return MakeTimeseriesPart(
73 | self.operand("dtypes"),
74 | list(self._dtypes.keys()),
75 | self.freq,
76 | self.kwargs,
77 | )
78 |
79 | def _filtered_task(self, name: Key, index: int) -> Task:
80 | full_divisions = self._divisions()
81 | ndtypes = max(len(self.operand("dtypes")), 1)
82 | task = Task(
83 | name,
84 | self._make_timeseries_part,
85 | full_divisions[index],
86 | full_divisions[index + 1],
87 | self.random_state[index * ndtypes],
88 | )
89 | if self._series:
90 | return Task(name, operator.getitem, task, self.operand("columns")[0])
91 | return task
92 |
93 |
94 | names = [
95 | "Alice",
96 | "Bob",
97 | "Charlie",
98 | "Dan",
99 | "Edith",
100 | "Frank",
101 | "George",
102 | "Hannah",
103 | "Ingrid",
104 | "Jerry",
105 | "Kevin",
106 | "Laura",
107 | "Michael",
108 | "Norbert",
109 | "Oliver",
110 | "Patricia",
111 | "Quinn",
112 | "Ray",
113 | "Sarah",
114 | "Tim",
115 | "Ursula",
116 | "Victor",
117 | "Wendy",
118 | "Xavier",
119 | "Yvonne",
120 | "Zelda",
121 | ]
122 |
123 |
124 | def make_string(n, rstate):
125 | return rstate.choice(names, size=n)
126 |
127 |
128 | def make_categorical(n, rstate):
129 | return pd.Categorical.from_codes(rstate.randint(0, len(names), size=n), names)
130 |
131 |
132 | def make_float(n, rstate):
133 | return rstate.rand(n) * 2 - 1
134 |
135 |
136 | def make_int(n, rstate, lam=1000):
137 | return rstate.poisson(lam, size=n)
138 |
139 |
140 | make = {
141 | float: make_float,
142 | int: make_int,
143 | str: make_string,
144 | object: make_string,
145 | "string": make_string,
146 | "category": make_categorical,
147 | }
148 |
149 |
150 | class MakeTimeseriesPart:
151 | def __init__(self, dtypes, columns, freq, kwargs):
152 | self.dtypes = dtypes
153 | self.columns = columns
154 | self.freq = freq
155 | self.kwargs = kwargs
156 |
157 | def __call__(self, start, end, state_data):
158 | dtypes = self.dtypes
159 | columns = self.columns
160 | freq = self.freq
161 | kwargs = self.kwargs
162 | state = np.random.RandomState(state_data)
163 | index = pd.date_range(start=start, end=end, freq=freq, name="timestamp")
164 | data = {}
165 | for k, dt in dtypes.items():
166 | kws = {
167 | kk.rsplit("_", 1)[1]: v
168 | for kk, v in kwargs.items()
169 | if kk.rsplit("_", 1)[0] == k
170 | }
171 | # Note: we compute data for all dtypes in order, not just those in the output
172 | # columns. This ensures the same output given the same state_data, regardless
173 | # of whether there is any column projection.
174 | # cf. https://github.com/dask/dask/pull/9538#issuecomment-1267461887
175 | result = make[dt](len(index), state, **kws)
176 | if k in columns:
177 | data[k] = result
178 | df = pd.DataFrame(data, index=index, columns=columns)
179 | if df.index[-1] == end:
180 | df = df.iloc[:-1]
181 | return df
182 |
183 |
184 | def timeseries(
185 | start="2000-01-01",
186 | end="2000-01-31",
187 | freq="1s",
188 | partition_freq="1d",
189 | dtypes=None,
190 | seed=None,
191 | **kwargs,
192 | ):
193 | """Create timeseries dataframe with random data
194 |
195 | Parameters
196 | ----------
197 | start: datetime (or datetime-like string)
198 | Start of time series
199 | end: datetime (or datetime-like string)
200 | End of time series
201 | dtypes: dict (optional)
202 | Mapping of column names to types.
203 | Valid types include {float, int, str, 'category'}
204 | freq: string
205 | String like '2s' or '1H' or '12W' for the time series frequency
206 | partition_freq: string
207 | String like '1M' or '2Y' to divide the dataframe into partitions
208 | seed: int (optional)
209 | Randomstate seed
210 | kwargs:
211 | Keywords to pass down to individual column creation functions.
212 | Keywords should be prefixed by the column name and then an underscore.
213 |
214 | Examples
215 | --------
216 |     >>> from dask_expr.datasets import timeseries
217 | >>> df = timeseries(
218 | ... start='2000', end='2010',
219 | ... dtypes={'value': float, 'name': str, 'id': int},
220 | ... freq='2H', partition_freq='1D', seed=1
221 | ... )
222 | >>> df.head() # doctest: +SKIP
223 | id name value
224 | 2000-01-01 00:00:00 969 Jerry -0.309014
225 | 2000-01-01 02:00:00 1010 Ray -0.760675
226 | 2000-01-01 04:00:00 1016 Patricia -0.063261
227 | 2000-01-01 06:00:00 960 Charlie 0.788245
228 | 2000-01-01 08:00:00 1031 Kevin 0.466002
229 | """
230 | if dtypes is None:
231 | dtypes = {"name": "string", "id": int, "x": float, "y": float}
232 |
233 | if seed is None:
234 | seed = np.random.randint(2e9)
235 |
236 | expr = Timeseries(
237 | start,
238 | end,
239 | dtypes,
240 | freq,
241 | partition_freq,
242 | seed,
243 | kwargs,
244 | columns=list(dtypes.keys()),
245 | )
246 | if pyarrow_strings_enabled():
247 | return new_collection(ArrowStringConversion(expr))
248 | return new_collection(expr)
249 |
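# A brief sketch of the column-prefixed keyword mechanism implemented in
# MakeTimeseriesPart above: keywords of the form <column>_<kwarg> are routed to
# that column's generator, e.g. id_lam feeds the lam= parameter of make_int.
from dask_expr.datasets import timeseries

df = timeseries(
    start="2000-01-01",
    end="2000-01-03",
    dtypes={"id": int, "x": float},
    id_lam=500,  # forwarded as make_int(..., lam=500) for the "id" column
    seed=1,
)
print(df.head())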
--------------------------------------------------------------------------------
/dask_expr/diagnostics/__init__.py:
--------------------------------------------------------------------------------
1 | from dask_expr.diagnostics._analyze import analyze
2 | from dask_expr.diagnostics._explain import explain
3 |
4 | __all__ = ["analyze", "explain"]
5 |
--------------------------------------------------------------------------------
/dask_expr/diagnostics/_analyze.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from typing import TYPE_CHECKING, Any
5 |
6 | import pandas as pd
7 | from dask.base import DaskMethodsMixin
8 | from dask.sizeof import sizeof
9 | from dask.utils import format_bytes, import_required
10 |
11 | from dask_expr._expr import Blockwise, Expr
12 | from dask_expr._util import _tokenize_deterministic, is_scalar
13 | from dask_expr.diagnostics._explain import _add_graphviz_edges, _explain_info
14 | from dask_expr.io.io import FusedIO
15 |
16 | if TYPE_CHECKING:
17 | from dask_expr.diagnostics._analyze_plugin import ExpressionStatistics, Statistics
18 |
19 |
20 | def inject_analyze(expr: Expr, id: str, injected: dict) -> Expr:
21 | if expr._name in injected:
22 | return injected[expr._name]
23 |
24 | new_operands = []
25 | for operand in expr.operands:
26 | if isinstance(operand, Expr) and not isinstance(expr, FusedIO):
27 | new = inject_analyze(operand, id, injected)
28 | injected[operand._name] = new
29 | else:
30 | new = operand
31 | new_operands.append(new)
32 | return Analyze(type(expr)(*new_operands), id, expr._name)
33 |
34 |
35 | def analyze(
36 | expr: Expr, filename: str | None = None, format: str | None = None, **kwargs: Any
37 | ):
38 | import_required(
39 | "distributed",
40 | "distributed is a required dependency for using the analyze method.",
41 | )
42 | import_required(
43 | "crick", "crick is a required dependency for using the analyze method."
44 | )
45 | graphviz = import_required(
46 | "graphviz", "graphviz is a required dependency for using the analyze method."
47 | )
48 | from dask.dot import graphviz_to_file
49 | from distributed import get_client, wait
50 |
51 | from dask_expr import new_collection
52 | from dask_expr.diagnostics._analyze_plugin import AnalyzePlugin
53 |
54 | try:
55 | client = get_client()
56 | except ValueError:
57 | raise RuntimeError("analyze must be run in a distributed context.")
58 | client.register_plugin(AnalyzePlugin())
59 |
60 | # TODO: Make this work with fuse=True
61 | expr = expr.optimize(fuse=False)
62 |
63 | analysis_id = expr._name
64 |
65 | # Inject analyze nodes
66 | injected = inject_analyze(expr, analysis_id, {})
67 | out = new_collection(injected)
68 | _ = DaskMethodsMixin.compute(out, **kwargs)
69 | wait(_)
70 |
71 | # Collect data
72 | statistics: Statistics = client.sync(
73 | client.scheduler.analyze_get_statistics, id=analysis_id
74 | ) # type: noqa
75 |
76 | # Plot statistics in graph
77 |     seen = {expr._name}
78 | stack = [expr]
79 |
80 | if filename is None:
81 | filename = f"analyze-{expr._name}"
82 |
83 | if format is None:
84 | format = "svg"
85 |
86 | g = graphviz.Digraph(expr._name)
87 | g.node_attr.update(shape="record")
88 | while stack:
89 | node = stack.pop()
90 | info = _analyze_info(node, statistics._expr_statistics[node._name])
91 | _add_graphviz_node(info, g)
92 | _add_graphviz_edges(info, g)
93 |
94 | if isinstance(node, FusedIO):
95 | continue
96 | for dep in node.operands:
97 | if not isinstance(dep, Expr) or dep._name in seen:
98 | continue
99 | seen.add(dep._name)
100 | stack.append(dep)
101 | graphviz_to_file(g, filename, format)
102 | return g
103 |
104 |
105 | def _add_graphviz_node(info, graph):
106 | label = "".join(
107 | [
108 | "<{",
109 | info["label"],
110 | " | ",
111 |             "<br />".join(
112 | [f"{key}: {value}" for key, value in info["details"].items()]
113 | ),
114 | " | ",
115 | _statistics_to_graphviz(info["statistics"]),
116 | "}>",
117 | ]
118 | )
119 |
120 | graph.node(info["name"], label)
121 |
122 |
123 | def _statistics_to_graphviz(statistics: dict[str, dict[str, Any]]) -> str:
124 |     return "<br />".join(
125 | [
126 | _metric_to_graphviz(metric, statistics)
127 | for metric, statistics in statistics.items()
128 | ]
129 | )
130 |
131 |
132 | _FORMAT_FNS = {"nbytes": format_bytes, "nrows": "{:,.0f}".format}
133 |
134 |
135 | def _metric_to_graphviz(metric: str, statistics: dict[str, Any]):
136 | format_fn = _FORMAT_FNS[metric]
137 | quantiles = (
138 | "[" + ", ".join([format_fn(pctl) for pctl in statistics.pop("quantiles")]) + "]"
139 | )
140 | count = statistics["count"]
141 | total = statistics["total"]
142 |
143 |     return "<br />".join(
144 | [
145 | f"{metric}:",
146 | f"{format_fn(total / count)} ({format_fn(total)} / {count:,})",
147 | f"{quantiles}",
148 | ]
149 | )
150 |
151 |
152 | def _analyze_info(expr: Expr, statistics: ExpressionStatistics):
153 | info = _explain_info(expr)
154 | info["statistics"] = _statistics_info(statistics)
155 | return info
156 |
157 |
158 | def _statistics_info(statistics: ExpressionStatistics):
159 | info = {}
160 | for metric, digest in statistics._metric_digests.items():
161 | info[metric] = {
162 | "total": digest.total,
163 | "count": digest.count,
164 | "quantiles": [digest.sketch.quantile(q) for q in (0, 0.25, 0.5, 0.75, 1)],
165 | }
166 | return info
167 |
168 |
169 | def collect_statistics(frame, analysis_id, expr_name):
170 | from dask_expr.diagnostics._analyze_plugin import get_worker_plugin
171 |
172 | worker_plugin = get_worker_plugin()
173 | if isinstance(frame, pd.DataFrame):
174 | size = frame.memory_usage(deep=True).sum()
175 | elif isinstance(frame, pd.Series):
176 | size = frame.memory_usage(deep=True)
177 | else:
178 | size = sizeof(frame)
179 |
180 | len_frame = len(frame) if not is_scalar(frame) else 1
181 | worker_plugin.add(analysis_id, expr_name, "nrows", len_frame)
182 | worker_plugin.add(analysis_id, expr_name, "nbytes", size)
183 | return frame
184 |
185 |
186 | class Analyze(Blockwise):
187 | _parameters = ["frame", "analysis_id", "expr_name"]
188 |
189 | operation = staticmethod(collect_statistics)
190 |
191 | @functools.cached_property
192 | def _meta(self):
193 | return self.frame._meta
194 |
195 | @functools.cached_property
196 | def _name(self):
197 | return "analyze-" + _tokenize_deterministic(*self.operands)
198 |
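# A hedged usage sketch for analyze(): it requires the distributed scheduler
# plus the optional crick and graphviz dependencies, and writes the annotated
# expression graph to a file (SVG by default). The .expr attribute is assumed
# to expose the underlying expression of a collection, as elsewhere in the
# codebase.
import pandas as pd
from distributed import Client

from dask_expr import from_pandas
from dask_expr.diagnostics import analyze

client = Client()  # analyze() must run against a distributed cluster
df = from_pandas(pd.DataFrame({"a": range(1000)}), npartitions=10)
query = (df["a"] + 1).sum()
analyze(query.expr, filename="analyze-example", format="svg")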
--------------------------------------------------------------------------------
/dask_expr/diagnostics/_analyze_plugin.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import defaultdict
4 | from typing import TYPE_CHECKING, ClassVar
5 |
6 | from distributed import Scheduler, SchedulerPlugin, Worker, WorkerPlugin
7 | from distributed.protocol.pickle import dumps
8 |
9 | if TYPE_CHECKING:
10 | from crick import TDigest
11 |
12 |
13 | class Digest:
14 | count: int
15 | total: float
16 | sketch: TDigest
17 |
18 | def __init__(self) -> None:
19 | from crick import TDigest
20 |
21 | self.count = 0
22 | self.total = 0.0
23 | self.sketch = TDigest()
24 |
25 | def add(self, sample: float) -> None:
26 | self.count = self.count + 1
27 | self.total = self.total + sample
28 | self.sketch.add(sample)
29 |
30 | def merge(self, other: Digest) -> None:
31 | self.count = self.count + other.count
32 | self.total = self.total + other.total
33 | self.sketch.merge(other.sketch)
34 |
35 | @property
36 | def mean(self):
37 | return self.total / self.count
38 |
39 |
40 | class AnalyzePlugin(SchedulerPlugin):
41 | idempotent: ClassVar[bool] = True
42 | name: ClassVar[str] = "analyze"
43 | _scheduler: Scheduler | None
44 |
45 | def __init__(self) -> None:
46 | self._scheduler = None
47 |
48 | async def start(self, scheduler: Scheduler) -> None:
49 | self._scheduler = scheduler
50 | scheduler.handlers["analyze_get_statistics"] = self.get_statistics
51 | worker_plugin = _AnalyzeWorkerPlugin()
52 | await self._scheduler.register_worker_plugin(
53 | None,
54 | dumps(worker_plugin),
55 | name=worker_plugin.name,
56 | idempotent=True,
57 | )
58 |
59 | async def get_statistics(self, id: str):
60 | assert self._scheduler is not None
61 | worker_statistics = await self._scheduler.broadcast(
62 | msg={"op": "analyze_get_statistics", "id": id}
63 | )
64 | cluster_statistics = Statistics()
65 | for statistics in worker_statistics.values():
66 | cluster_statistics.merge(statistics)
67 | return cluster_statistics
68 |
69 |
70 | class ExpressionStatistics:
71 | _metric_digests: defaultdict[str, Digest]
72 |
73 | def __init__(self) -> None:
74 | self._metric_digests = defaultdict(Digest)
75 |
76 | def add(self, metric: str, value: float) -> None:
77 | self._metric_digests[metric].add(value)
78 |
79 | def merge(self, other: ExpressionStatistics) -> None:
80 | for metric, digest in other._metric_digests.items():
81 | self._metric_digests[metric].merge(digest)
82 |
83 |
84 | class Statistics:
85 | _expr_statistics: defaultdict[str, ExpressionStatistics]
86 |
87 | def __init__(self) -> None:
88 | self._expr_statistics = defaultdict(ExpressionStatistics)
89 |
90 | def add(self, expr: str, metric: str, value: float):
91 | self._expr_statistics[expr].add(metric, value)
92 |
93 | def merge(self, other: Statistics):
94 | for expr, statistics in other._expr_statistics.items():
95 | self._expr_statistics[expr].merge(statistics)
96 |
97 |
98 | class _AnalyzeWorkerPlugin(WorkerPlugin):
99 | idempotent: ClassVar[bool] = True
100 | name: ClassVar[str] = "analyze"
101 | _statistics: defaultdict[str, Statistics]
102 | _worker: Worker | None
103 |
104 | def __init__(self) -> None:
105 | self._worker = None
106 | self._statistics = defaultdict(Statistics)
107 |
108 | def setup(self, worker: Worker) -> None:
109 | self._digests = defaultdict(lambda: defaultdict(lambda: defaultdict(Digest)))
110 | self._worker = worker
111 | self._worker.handlers["analyze_get_statistics"] = self.get_statistics
112 |
113 | def add(self, id: str, expr: str, metric: str, value: float):
114 | self._statistics[id].add(expr, metric, value)
115 |
116 | def get_statistics(self, id: str) -> Statistics:
117 | return self._statistics.pop(id)
118 |
119 |
120 | def get_worker_plugin() -> _AnalyzeWorkerPlugin:
121 | from distributed import get_worker
122 |
123 | try:
124 | worker = get_worker()
125 | except ValueError as e:
126 | raise RuntimeError(
127 | "``.analyze()`` requires Dask's distributed scheduler"
128 | ) from e
129 |
130 | try:
131 | return worker.plugins["analyze"] # type: ignore
132 | except KeyError as e:
133 | raise RuntimeError(
134 | f"The worker {worker.address} does not have an Analyze plugin."
135 | ) from e
136 |
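# A tiny sketch of the Digest helper above (requires the optional crick
# package); the quantile call on the sketch mirrors its use in _analyze.py.
from dask_expr.diagnostics._analyze_plugin import Digest

d = Digest()
for sample in [1.0, 2.0, 3.0, 4.0]:
    d.add(sample)
print(d.count, d.total, d.mean)  # 4 10.0 2.5
print(d.sketch.quantile(0.5))    # approximate median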
--------------------------------------------------------------------------------
/dask_expr/diagnostics/_explain.py:
--------------------------------------------------------------------------------
1 | from dask.utils import funcname, import_required
2 |
3 | from dask_expr._core import OptimizerStage
4 | from dask_expr._expr import Expr, Projection, optimize_until
5 | from dask_expr._merge import Merge
6 | from dask_expr.io.parquet import ReadParquet
7 |
8 | STAGE_LABELS: dict[OptimizerStage, str] = {
9 | "logical": "Logical Plan",
10 | "simplified-logical": "Simplified Logical Plan",
11 | "tuned-logical": "Tuned Logical Plan",
12 | "physical": "Physical Plan",
13 | "simplified-physical": "Simplified Physical Plan",
14 | "fused": "Fused Physical Plan",
15 | }
16 |
17 |
18 | def explain(expr: Expr, stage: OptimizerStage = "fused", format: str | None = None):
19 | graphviz = import_required(
20 | "graphviz", "graphviz is a required dependency for using the explain method."
21 | )
22 |
23 | if format is None:
24 | format = "png"
25 |
26 | g = graphviz.Digraph(
27 | STAGE_LABELS[stage], filename=f"explain-{stage}-{expr._name}", format=format
28 | )
29 | g.node_attr.update(shape="record")
30 |
31 | expr = optimize_until(expr, stage)
32 |
33 |     seen = {expr._name}
34 | stack = [expr]
35 |
36 | while stack:
37 | node = stack.pop()
38 | explain_info = _explain_info(node)
39 | _add_graphviz_node(explain_info, g)
40 | _add_graphviz_edges(explain_info, g)
41 |
42 | for dep in node.operands:
43 | if not isinstance(dep, Expr) or dep._name in seen:
44 | continue
45 | seen.add(dep._name)
46 | stack.append(dep)
47 |
48 | g.view()
49 |
50 |
51 | def _add_graphviz_node(explain_info, graph):
52 | label = "".join(
53 | [
54 | "<{",
55 | explain_info["label"],
56 | " | ",
57 |             "<br />".join(
58 | [f"{key}: {value}" for key, value in explain_info["details"].items()]
59 | ),
60 | "}>",
61 | ]
62 | )
63 |
64 | graph.node(explain_info["name"], label)
65 |
66 |
67 | def _add_graphviz_edges(explain_info, graph):
68 | name = explain_info["name"]
69 | for _, dep in explain_info["dependencies"]:
70 | graph.edge(dep, name)
71 |
72 |
73 | def _explain_info(expr: Expr):
74 | return {
75 | "name": expr._name,
76 | "label": funcname(type(expr)),
77 | "details": _explain_details(expr),
78 | "dependencies": _explain_dependencies(expr),
79 | }
80 |
81 |
82 | def _explain_details(expr: Expr):
83 | details = {"npartitions": expr.npartitions}
84 |
85 | if isinstance(expr, Merge):
86 | details["how"] = expr.how
87 | elif isinstance(expr, ReadParquet):
88 | details["path"] = expr.path
89 | elif isinstance(expr, Projection):
90 | columns = expr.operand("columns")
91 | details["ncolumns"] = len(columns) if isinstance(columns, list) else "Series"
92 |
93 | return details
94 |
95 |
96 | def _explain_dependencies(expr: Expr) -> list[tuple[str, str]]:
97 | dependencies = []
98 | for i, operand in enumerate(expr.operands):
99 | if not isinstance(operand, Expr):
100 | continue
101 | param = expr._parameters[i] if i < len(expr._parameters) else ""
102 | dependencies.append((str(param), operand._name))
103 | return dependencies
104 |
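# A minimal sketch for explain(): it renders the expression tree at a chosen
# optimizer stage via graphviz and opens the resulting file (g.view()). The
# .expr attribute is assumed to expose the underlying expression of a
# collection.
import pandas as pd

from dask_expr import from_pandas
from dask_expr.diagnostics import explain

df = from_pandas(pd.DataFrame({"a": range(10), "b": range(10)}), npartitions=2)
query = df[df.a > 3][["b"]]
explain(query.expr, stage="logical", format="png")  # plan before optimization
explain(query.expr, stage="fused", format="png")    # fully optimized plan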
--------------------------------------------------------------------------------
/dask_expr/io/__init__.py:
--------------------------------------------------------------------------------
1 | from dask_expr.io.csv import *
2 | from dask_expr.io.io import *
3 | from dask_expr.io.parquet import *
4 |
--------------------------------------------------------------------------------
/dask_expr/io/_delayed.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from collections.abc import Iterable
5 | from typing import TYPE_CHECKING
6 |
7 | import pandas as pd
8 | from dask._task_spec import Alias, Task, TaskRef
9 | from dask.dataframe.dispatch import make_meta
10 | from dask.dataframe.utils import check_meta, pyarrow_strings_enabled
11 | from dask.delayed import Delayed, delayed
12 | from dask.typing import Key
13 |
14 | from dask_expr._expr import ArrowStringConversion, DelayedsExpr, PartitionsFiltered
15 | from dask_expr._util import _tokenize_deterministic
16 | from dask_expr.io import BlockwiseIO
17 |
18 | if TYPE_CHECKING:
19 | import distributed
20 |
21 |
22 | class FromDelayed(PartitionsFiltered, BlockwiseIO):
23 | _parameters = [
24 | "delayed_container",
25 | "meta",
26 | "user_divisions",
27 | "verify_meta",
28 | "_partitions",
29 | "prefix",
30 | ]
31 | _defaults = {
32 | "meta": None,
33 | "_partitions": None,
34 | "user_divisions": None,
35 | "verify_meta": True,
36 | "prefix": None,
37 | }
38 |
39 | @functools.cached_property
40 | def _name(self):
41 | if self.prefix is None:
42 | return super()._name
43 | return self.prefix + "-" + _tokenize_deterministic(*self.operands)
44 |
45 | @functools.cached_property
46 | def _meta(self):
47 | if self.operand("meta") is not None:
48 | return self.operand("meta")
49 |
50 | return delayed(make_meta)(self.delayed_container.operands[0]).compute()
51 |
52 | def _divisions(self):
53 | if self.operand("user_divisions") is not None:
54 | return self.operand("user_divisions")
55 | else:
56 | return self.delayed_container.divisions
57 |
58 | def _filtered_task(self, name: Key, index: int) -> Task:
59 | if self.verify_meta:
60 | return Task(
61 | name,
62 | functools.partial(check_meta, meta=self._meta, funcname="from_delayed"),
63 | TaskRef((self.delayed_container._name, index)),
64 | )
65 | else:
66 | return Alias((self.delayed_container._name, index))
67 |
68 |
69 | def identity(x):
70 | return x
71 |
72 |
73 | def from_delayed(
74 | dfs: Delayed | distributed.Future | Iterable[Delayed | distributed.Future],
75 | meta=None,
76 | divisions: tuple | None = None,
77 | prefix: str | None = None,
78 | verify_meta: bool = True,
79 | ):
80 | """Create Dask DataFrame from many Dask Delayed objects
81 |
82 | .. warning::
83 | ``from_delayed`` should only be used if the objects that create
84 | the data are complex and cannot be easily represented as a single
85 |         function in an embarrassingly parallel fashion.
86 |
87 | ``from_map`` is recommended if the query can be expressed as a single
88 | function like:
89 |
90 | def read_xml(path):
91 | return pd.read_xml(path)
92 |
93 | ddf = dd.from_map(read_xml, paths)
94 |
95 |     ``from_delayed`` might be deprecated in the future.
96 |
97 | Parameters
98 | ----------
99 | dfs :
100 | A ``dask.delayed.Delayed``, a ``distributed.Future``, or an iterable of either
101 | of these objects, e.g. returned by ``client.submit``. These comprise the
102 | individual partitions of the resulting dataframe.
103 | If a single object is provided (not an iterable), then the resulting dataframe
104 | will have only one partition.
105 | $META
106 | divisions :
107 | Partition boundaries along the index.
108 | For tuple, see https://docs.dask.org/en/latest/dataframe-design.html#partitions
109 | If None, then won't use index information
110 | prefix :
111 | Prefix to prepend to the keys.
112 | verify_meta :
113 | If True check that the partitions have consistent metadata, defaults to True.
114 | """
115 | if isinstance(dfs, Delayed) or hasattr(dfs, "key"):
116 | dfs = [dfs]
117 |
118 | if len(dfs) == 0:
119 | raise TypeError("Must supply at least one delayed object")
120 |
121 | if meta is None:
122 | meta = delayed(make_meta)(dfs[0]).compute()
123 |
124 | if divisions == "sorted":
125 | raise NotImplementedError(
126 | "divisions='sorted' not supported, please calculate the divisions "
127 | "yourself."
128 | )
129 | elif divisions is not None:
130 | divs = list(divisions)
131 | if len(divs) != len(dfs) + 1:
132 | raise ValueError("divisions should be a tuple of len(dfs) + 1")
133 |
134 | dfs = [
135 | delayed(df) if not isinstance(df, Delayed) and hasattr(df, "key") else df
136 | for df in dfs
137 | ]
138 |
139 | for item in dfs:
140 | if not isinstance(item, Delayed):
141 | raise TypeError("Expected Delayed object, got %s" % type(item).__name__)
142 |
143 | from dask_expr._collection import new_collection
144 |
145 | result = FromDelayed(
146 | DelayedsExpr(*dfs), make_meta(meta), divisions, verify_meta, None, prefix
147 | )
148 | if pyarrow_strings_enabled() and any(
149 | pd.api.types.is_object_dtype(dtype)
150 | for dtype in (result.dtypes.values if result.ndim == 2 else [result.dtypes])
151 | ):
152 | return new_collection(ArrowStringConversion(result))
153 | return new_collection(result)
154 |
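# A short usage sketch of from_delayed, matching the docstring above: each
# delayed object becomes one partition, and meta is inferred from the first
# object when not provided.
import pandas as pd
from dask import delayed

from dask_expr.io._delayed import from_delayed


@delayed
def make_part(i):
    return pd.DataFrame({"x": range(i * 3, (i + 1) * 3)})


parts = [make_part(i) for i in range(3)]
ddf = from_delayed(parts)  # meta inferred by computing make_meta on parts[0]
print(ddf.npartitions)     # 3
print(ddf.compute())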
--------------------------------------------------------------------------------
/dask_expr/io/bag.py:
--------------------------------------------------------------------------------
1 | from dask.dataframe.io.io import _df_to_bag
2 | from dask.tokenize import tokenize
3 |
4 | from dask_expr import FrameBase
5 |
6 |
7 | def to_bag(df, index=False, format="tuple"):
8 | """Create Dask Bag from a Dask DataFrame
9 |
10 | Parameters
11 | ----------
12 | index : bool, optional
13 | If True, the elements are tuples of ``(index, value)``, otherwise
14 | they're just the ``value``. Default is False.
15 | format : {"tuple", "dict", "frame"}, optional
16 | Whether to return a bag of tuples, dictionaries, or
17 | dataframe-like objects. Default is "tuple". If "frame",
18 | the original partitions of ``df`` will not be transformed
19 | in any way.
20 |
21 |
22 | Examples
23 | --------
24 | >>> bag = df.to_bag() # doctest: +SKIP
25 | """
26 | from dask.bag.core import Bag
27 |
28 | df = df.optimize()
29 |
30 | if not isinstance(df, FrameBase):
31 | raise TypeError("df must be either DataFrame or Series")
32 | name = "to_bag-" + tokenize(df._name, index, format)
33 | if format == "frame":
34 | dsk = df.dask
35 | name = df._name
36 | else:
37 | dsk = {
38 | (name, i): (_df_to_bag, block, index, format)
39 | for (i, block) in enumerate(df.__dask_keys__())
40 | }
41 | dsk.update(df.__dask_graph__())
42 | return Bag(dsk, name, df.npartitions)
43 |
--------------------------------------------------------------------------------
/dask_expr/io/csv.py:
--------------------------------------------------------------------------------
1 | def to_csv(
2 | df,
3 | filename,
4 | single_file=False,
5 | encoding="utf-8",
6 | mode="wt",
7 | name_function=None,
8 | compression=None,
9 | compute=True,
10 | scheduler=None,
11 | storage_options=None,
12 | header_first_partition_only=None,
13 | compute_kwargs=None,
14 | **kwargs,
15 | ):
16 | """
17 | Store Dask DataFrame to CSV files
18 |
19 | One filename per partition will be created. You can specify the
20 | filenames in a variety of ways.
21 |
22 | Use a globstring::
23 |
24 | >>> df.to_csv('/path/to/data/export-*.csv') # doctest: +SKIP
25 |
26 | The * will be replaced by the increasing sequence 0, 1, 2, ...
27 |
28 | ::
29 |
30 | /path/to/data/export-0.csv
31 | /path/to/data/export-1.csv
32 |
33 | Use a globstring and a ``name_function=`` keyword argument. The
34 | name_function function should expect an integer and produce a string.
35 | Strings produced by name_function must preserve the order of their
36 | respective partition indices.
37 |
38 | >>> from datetime import date, timedelta
39 | >>> def name(i):
40 | ... return str(date(2015, 1, 1) + i * timedelta(days=1))
41 |
42 | >>> name(0)
43 | '2015-01-01'
44 | >>> name(15)
45 | '2015-01-16'
46 |
47 | >>> df.to_csv('/path/to/data/export-*.csv', name_function=name) # doctest: +SKIP
48 |
49 | ::
50 |
51 | /path/to/data/export-2015-01-01.csv
52 | /path/to/data/export-2015-01-02.csv
53 | ...
54 |
55 | You can also provide an explicit list of paths::
56 |
57 | >>> paths = ['/path/to/data/alice.csv', '/path/to/data/bob.csv', ...] # doctest: +SKIP
58 | >>> df.to_csv(paths) # doctest: +SKIP
59 |
60 | You can also provide a directory name:
61 |
62 | >>> df.to_csv('/path/to/data') # doctest: +SKIP
63 |
64 | The files will be numbered 0, 1, 2, (and so on) suffixed with '.part':
65 |
66 | ::
67 |
68 | /path/to/data/0.part
69 | /path/to/data/1.part
70 |
71 | Parameters
72 | ----------
73 | df : dask.DataFrame
74 | Data to save
75 | filename : string or list
76 | Absolute or relative filepath(s). Prefix with a protocol like ``s3://``
77 | to save to remote filesystems.
78 | single_file : bool, default False
79 | Whether to save everything into a single CSV file. Under the
80 | single file mode, each partition is appended at the end of the
81 | specified CSV file.
82 | encoding : string, default 'utf-8'
83 | A string representing the encoding to use in the output file.
84 | mode : str, default 'w'
85 | Python file mode. The default is 'w' (or 'wt'), for writing
86 | a new file or overwriting an existing file in text mode. 'a'
87 | (or 'at') will append to an existing file in text mode or
88 | create a new file if it does not already exist. See :py:func:`open`.
89 | name_function : callable, default None
90 | Function accepting an integer (partition index) and producing a
91 | string to replace the asterisk in the given filename globstring.
92 | Should preserve the lexicographic order of partitions. Not
93 | supported when ``single_file`` is True.
94 | compression : string, optional
95 | A string representing the compression to use in the output file,
96 | allowed values are 'gzip', 'bz2', 'xz',
97 | only used when the first argument is a filename.
98 | compute : bool, default True
99 | If True, immediately executes. If False, returns a set of delayed
100 | objects, which can be computed at a later time.
101 | storage_options : dict
102 | Parameters passed on to the backend filesystem class.
103 | header_first_partition_only : bool, default None
104 | If set to True, only write the header row in the first output
105 | file. By default, headers are written to all partitions under
106 | the multiple file mode (``single_file`` is False) and written
107 | only once under the single file mode (``single_file`` is True).
108 | It must be True under the single file mode.
109 | compute_kwargs : dict, optional
110 | Options to be passed in to the compute method
111 | kwargs : dict, optional
112 | Additional parameters to pass to :meth:`pandas.DataFrame.to_csv`.
113 |
114 | Returns
115 | -------
116 | The names of the file written if they were computed right away.
117 | If not, the delayed tasks associated with writing the files.
118 |
119 | Raises
120 | ------
121 | ValueError
122 | If ``header_first_partition_only`` is set to False or
123 | ``name_function`` is specified when ``single_file`` is True.
124 |
125 | See Also
126 | --------
127 | fsspec.open_files
128 | """
129 | from dask.dataframe.io.csv import to_csv as _to_csv
130 |
131 | return _to_csv(
132 | df.optimize(),
133 | filename,
134 | single_file=single_file,
135 | encoding=encoding,
136 | mode=mode,
137 | name_function=name_function,
138 | compression=compression,
139 | compute=compute,
140 | scheduler=scheduler,
141 | storage_options=storage_options,
142 | header_first_partition_only=header_first_partition_only,
143 | compute_kwargs=compute_kwargs,
144 | **kwargs,
145 | )
146 |
--------------------------------------------------------------------------------
/dask_expr/io/hdf.py:
--------------------------------------------------------------------------------
1 | def read_hdf(
2 | pattern,
3 | key,
4 | start=0,
5 | stop=None,
6 | columns=None,
7 | chunksize=1000000,
8 | sorted_index=False,
9 | lock=True,
10 | mode="r",
11 | ):
12 | from dask.dataframe.io import read_hdf as _read_hdf
13 |
14 | return _read_hdf(
15 | pattern,
16 | key,
17 | start=start,
18 | stop=stop,
19 | columns=columns,
20 | chunksize=chunksize,
21 | sorted_index=sorted_index,
22 | lock=lock,
23 | mode=mode,
24 | )
25 |
26 |
27 | def to_hdf(
28 | df,
29 | path,
30 | key,
31 | mode="a",
32 | append=False,
33 | scheduler=None,
34 | name_function=None,
35 | compute=True,
36 | lock=None,
37 | dask_kwargs=None,
38 | **kwargs,
39 | ):
40 | """Store Dask Dataframe to Hierarchical Data Format (HDF) files
41 |
42 | This is a parallel version of the Pandas function of the same name. Please
43 | see the Pandas docstring for more detailed information about shared keyword
44 | arguments.
45 |
46 | This function differs from the Pandas version by saving the many partitions
47 | of a Dask DataFrame in parallel, either to many files, or to many datasets
48 | within the same file. You may specify this parallelism with an asterisk
49 | ``*`` within the filename or datapath, and an optional ``name_function``.
50 | The asterisk will be replaced with an increasing sequence of integers
51 | starting from ``0`` or with the result of calling ``name_function`` on each
52 | of those integers.
53 |
54 | This function only supports the Pandas ``'table'`` format, not the more
55 | specialized ``'fixed'`` format.
56 |
57 | Parameters
58 | ----------
59 | path : string, pathlib.Path
60 | Path to a target filename. Supports strings, ``pathlib.Path``, or any
61 | object implementing the ``__fspath__`` protocol. May contain a ``*`` to
62 | denote many filenames.
63 | key : string
64 | Datapath within the files. May contain a ``*`` to denote many locations
65 | name_function : function
66 | A function to convert the ``*`` in the above options to a string.
67 | Should take in a number from 0 to the number of partitions and return a
68 | string. (see examples below)
69 | compute : bool
70 | Whether or not to execute immediately. If False then this returns a
71 | ``dask.Delayed`` value.
72 | lock : bool, Lock, optional
73 | Lock to use to prevent concurrency issues. By default a
74 | ``threading.Lock``, ``multiprocessing.Lock`` or ``SerializableLock``
75 | will be used depending on your scheduler if a lock is required. See
76 | dask.utils.get_scheduler_lock for more information about lock
77 | selection.
78 | scheduler : string
79 | The scheduler to use, like "threads" or "processes"
80 | **kwargs :
81 | See pandas.to_hdf for more information.
82 |
83 | Examples
84 | --------
85 | Save Data to a single file
86 |
87 | >>> df.to_hdf('output.hdf', '/data') # doctest: +SKIP
88 |
89 | Save data to multiple datapaths within the same file:
90 |
91 | >>> df.to_hdf('output.hdf', '/data-*') # doctest: +SKIP
92 |
93 | Save data to multiple files:
94 |
95 | >>> df.to_hdf('output-*.hdf', '/data') # doctest: +SKIP
96 |
97 | Save data to multiple files, using the multiprocessing scheduler:
98 |
99 | >>> df.to_hdf('output-*.hdf', '/data', scheduler='processes') # doctest: +SKIP
100 |
101 | Specify custom naming scheme. This writes files as
102 | '2000-01-01.hdf', '2000-01-02.hdf', '2000-01-03.hdf', etc..
103 |
104 | >>> from datetime import date, timedelta
105 | >>> base = date(year=2000, month=1, day=1)
106 | >>> def name_function(i):
107 | ... ''' Convert integer 0 to n to a string '''
108 | ... return base + timedelta(days=i)
109 |
110 | >>> df.to_hdf('*.hdf', '/data', name_function=name_function) # doctest: +SKIP
111 |
112 | Returns
113 | -------
114 | filenames : list
115 | Returned if ``compute`` is True. List of file names that each partition
116 | is saved to.
117 | delayed : dask.Delayed
118 | Returned if ``compute`` is False. Delayed object to execute ``to_hdf``
119 | when computed.
120 |
121 | See Also
122 | --------
123 | read_hdf:
124 | to_parquet:
125 | """
126 | from dask.dataframe.io import to_hdf as _to_hdf
127 |
128 | return _to_hdf(
129 | df.optimize(),
130 | path,
131 | key,
132 | mode=mode,
133 | append=append,
134 | scheduler=scheduler,
135 | name_function=name_function,
136 | compute=compute,
137 | lock=lock,
138 | dask_kwargs=dask_kwargs,
139 | **kwargs,
140 | )
141 |
--------------------------------------------------------------------------------
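
A roundtrip sketch for the HDF wrappers above (not part of the repository; it requires the optional PyTables dependency, and the file name is illustrative):

import pandas as pd

import dask_expr as dx
from dask_expr.io.hdf import read_hdf, to_hdf

df = dx.from_pandas(pd.DataFrame({"x": range(6)}), npartitions=3)

# The '*' in the key is replaced by 0, 1, 2, so each partition lands in its
# own dataset inside the same file.
to_hdf(df, "output.hdf", "/data-*")

# read_hdf accepts the same glob-style key and rebuilds a collection from
# all matching datasets.
back = read_hdf("output.hdf", "/data-*")
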
/dask_expr/io/json.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from dask.dataframe.utils import insert_meta_param_description
3 |
4 | from dask_expr._backends import dataframe_creation_dispatch
5 |
6 |
7 | @dataframe_creation_dispatch.register_inplace("pandas")
8 | @insert_meta_param_description
9 | def read_json(
10 | url_path,
11 | orient="records",
12 | lines=None,
13 | storage_options=None,
14 | blocksize=None,
15 | sample=2**20,
16 | encoding="utf-8",
17 | errors="strict",
18 | compression="infer",
19 | meta=None,
20 | engine=pd.read_json,
21 | include_path_column=False,
22 | path_converter=None,
23 | **kwargs,
24 | ):
25 | """Create a dataframe from a set of JSON files
26 |
27 | This utilises ``pandas.read_json()``, and most parameters are
28 | passed through - see its docstring.
29 |
30 | Differences: orient is 'records' by default, with lines=True; this
31 | is appropriate for line-delimited "JSON-lines" data, the kind of JSON output
32 | that is most common in big-data scenarios, and which can be chunked when
33 | reading (see ``read_json()``). All other options require blocksize=None,
34 | i.e., one partition per input file.
35 |
36 | Parameters
37 | ----------
38 | url_path: str, list of str
39 | Location to read from. If a string, can include a glob character to
40 | find a set of file names.
41 | Supports protocol specifications such as ``"s3://"``.
42 | encoding, errors:
43 | The text encoding to implement, e.g., "utf-8" and how to respond
44 | to errors in the conversion (see ``str.encode()``).
45 | orient, lines, kwargs
46 | passed to pandas; if not specified, lines=True when orient='records',
47 | False otherwise.
48 | storage_options: dict
49 | Passed to backend file-system implementation
50 | blocksize: None or int
51 | If None, files are not blocked, and you get one partition per input
52 | file. If int, which can only be used for line-delimited JSON files,
53 | each partition will be approximately this size in bytes, to the nearest
54 | newline character.
55 | sample: int
56 | Number of bytes to pre-load, to provide an empty dataframe structure
57 | to any blocks without data. Only relevant when using blocksize.
58 | encoding, errors:
59 | Text conversion; see ``bytes.decode()``.
60 | compression : string or None
61 | String like 'gzip' or 'xz'.
62 | engine : callable or str, default ``pd.read_json``
63 | The underlying function that dask will use to read JSON files. By
64 | default, this will be the pandas JSON reader (``pd.read_json``).
65 | If a string is specified, this value will be passed under the ``engine``
66 | key-word argument to ``pd.read_json`` (only supported for pandas>=2.0).
67 | include_path_column : bool or str, optional
68 | Include a column with the file path where each row in the dataframe
69 | originated. If ``True``, a new column is added to the dataframe called
70 | ``path``. If ``str``, sets new column name. Default is ``False``.
71 | path_converter : function or None, optional
72 | A function that takes one argument and returns a string. Used to convert
73 | paths in the ``path`` column, for instance, to strip a common prefix from
74 | all the paths.
75 | $META
76 |
77 | Returns
78 | -------
79 | dask.DataFrame
80 |
81 | Examples
82 | --------
83 | Load single file
84 |
85 | >>> dd.read_json('myfile.1.json') # doctest: +SKIP
86 |
87 | Load multiple files
88 |
89 | >>> dd.read_json('myfile.*.json') # doctest: +SKIP
90 |
91 | >>> dd.read_json(['myfile.1.json', 'myfile.2.json']) # doctest: +SKIP
92 |
93 | Load large line-delimited JSON files using partitions of approx
94 | 256MB size
95 |
96 | >> dd.read_json('data/file*.json', blocksize=2**28)
97 | """
98 | from dask.dataframe.io.json import read_json
99 |
100 | return read_json(
101 | url_path,
102 | orient=orient,
103 | lines=lines,
104 | storage_options=storage_options,
105 | blocksize=blocksize,
106 | sample=sample,
107 | encoding=encoding,
108 | errors=errors,
109 | compression=compression,
110 | meta=meta,
111 | engine=engine,
112 | include_path_column=include_path_column,
113 | path_converter=path_converter,
114 | **kwargs,
115 | )
116 |
117 |
118 | def to_json(
119 | df,
120 | url_path,
121 | orient="records",
122 | lines=None,
123 | storage_options=None,
124 | compute=True,
125 | encoding="utf-8",
126 | errors="strict",
127 | compression=None,
128 | compute_kwargs=None,
129 | name_function=None,
130 | **kwargs,
131 | ):
132 | """Write dataframe into JSON text files
133 |
134 | This utilises ``pandas.DataFrame.to_json()``, and most parameters are
135 | passed through - see its docstring.
136 |
137 | Differences: orient is 'records' by default, with lines=True; this
138 | produces the kind of JSON output that is most common in big-data
139 | applications, and which can be chunked when reading (see ``read_json()``).
140 |
141 | Parameters
142 | ----------
143 | df: dask.DataFrame
144 | Data to save
145 | url_path: str, list of str
146 | Location to write to. If a string and there is more than one
147 | partition in df, it should include a glob character to expand into a
148 | set of file names, or provide a ``name_function=`` parameter.
149 | Supports protocol specifications such as ``"s3://"``.
150 | encoding, errors:
151 | The text encoding to implement, e.g., "utf-8" and how to respond
152 | to errors in the conversion (see ``str.encode()``).
153 | orient, lines, kwargs
154 | passed to pandas; if not specified, lines=True when orient='records',
155 | False otherwise.
156 | storage_options: dict
157 | Passed to backend file-system implementation
158 | compute: bool
159 | If True, immediately executes. If False, returns a set of delayed
160 | objects, which can be computed at a later time.
161 | compute_kwargs : dict, optional
162 | Options to be passed in to the compute method
163 | compression : string or None
164 | String like 'gzip' or 'xz'.
165 | name_function : callable, default None
166 | Function accepting an integer (partition index) and producing a
167 | string to replace the asterisk in the given filename globstring.
168 | Should preserve the lexicographic order of partitions.
169 | """
170 | from dask.dataframe.io.json import to_json
171 |
172 | return to_json(
173 | df,
174 | url_path,
175 | orient=orient,
176 | lines=lines,
177 | storage_options=storage_options,
178 | compute=compute,
179 | encoding=encoding,
180 | errors=errors,
181 | compression=compression,
182 | compute_kwargs=compute_kwargs,
183 | name_function=name_function,
184 | **kwargs,
185 | )
186 |
--------------------------------------------------------------------------------
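
A roundtrip sketch for the JSON wrappers above (not part of the repository; file names are illustrative):

import pandas as pd

import dask_expr as dx
from dask_expr.io.json import read_json, to_json

df = dx.from_pandas(pd.DataFrame({"a": [1, 2, 3, 4]}), npartitions=2)

# With the default orient='records', lines=True is implied, so every row is
# written as one JSON object per line and one file is produced per partition.
to_json(df, "records-*.json")

# Without blocksize each input file becomes exactly one partition; blocksize
# may only be combined with line-delimited data.
back = read_json("records-*.json")
assert back.npartitions == 2
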
/dask_expr/io/orc.py:
--------------------------------------------------------------------------------
1 | from dask_expr._backends import dataframe_creation_dispatch
2 |
3 |
4 | @dataframe_creation_dispatch.register_inplace("pandas")
5 | def read_orc(
6 | path,
7 | engine="pyarrow",
8 | columns=None,
9 | index=None,
10 | split_stripes=1,
11 | aggregate_files=None,
12 | storage_options=None,
13 | ):
14 | """Read dataframe from ORC file(s)
15 |
16 | Parameters
17 | ----------
18 | path: str or list(str)
19 | Location of file(s), which can be a full URL with protocol
20 | specifier, and may include glob character if a single string.
21 | engine: 'pyarrow' or ORCEngine
22 | Backend ORC engine to use for I/O. Default is "pyarrow".
23 | columns: None or list(str)
24 | Columns to load. If None, loads all.
25 | index: str
26 | Column name to set as index.
27 | split_stripes: int or False
28 | Maximum number of ORC stripes to include in each output-DataFrame
29 | partition. Use False to specify a 1-to-1 mapping between files
30 | and partitions. Default is 1.
31 | aggregate_files : bool, default False
32 | Whether distinct file paths may be aggregated into the same output
33 | partition. A setting of True means that any two file paths may be
34 | aggregated into the same output partition, while False means that
35 | inter-file aggregation is prohibited.
36 | storage_options: None or dict
37 | Further parameters to pass to the bytes backend.
38 |
39 | Returns
40 | -------
41 | Dask.DataFrame (even if there is only one column)
42 |
43 | Examples
44 | --------
45 | >>> df = dd.read_orc('https://github.com/apache/orc/raw/'
46 | ... 'master/examples/demo-11-zlib.orc') # doctest: +SKIP
47 | """
48 | from dask.dataframe.io import read_orc as _read_orc
49 |
50 | return _read_orc(
51 | path,
52 | engine=engine,
53 | columns=columns,
54 | index=index,
55 | split_stripes=split_stripes,
56 | aggregate_files=aggregate_files,
57 | storage_options=storage_options,
58 | )
59 |
60 |
61 | def to_orc(
62 | df,
63 | path,
64 | engine="pyarrow",
65 | write_index=True,
66 | storage_options=None,
67 | compute=True,
68 | compute_kwargs=None,
69 | ):
70 | from dask.dataframe.io import to_orc as _to_orc
71 |
72 | return _to_orc(
73 | df.optimize(),
74 | path,
75 | engine=engine,
76 | write_index=write_index,
77 | storage_options=storage_options,
78 | compute=compute,
79 | compute_kwargs=compute_kwargs,
80 | )
81 |
--------------------------------------------------------------------------------
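
to_orc above carries no docstring, so here is a brief usage sketch (not part of the repository; it requires pyarrow, and the directory name is illustrative):

import pandas as pd

import dask_expr as dx
from dask_expr.io.orc import read_orc, to_orc

df = dx.from_pandas(pd.DataFrame({"a": range(8)}), npartitions=2)

# One ORC file is written per partition under the target directory.
to_orc(df, "orc_data", write_index=False)

# split_stripes=False keeps a 1-to-1 mapping between files and partitions.
back = read_orc("orc_data", split_stripes=False)
assert back.npartitions == 2
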
/dask_expr/io/records.py:
--------------------------------------------------------------------------------
1 | from dask.utils import M
2 |
3 |
4 | def to_records(df):
5 | """Create Dask Array from a Dask Dataframe
6 |
7 | Warning: This creates a dask.array without precise shape information.
8 | Operations that depend on shape information, like slicing or reshaping,
9 | will not work.
10 |
11 | Examples
12 | --------
13 | >>> df.to_records() # doctest: +SKIP
14 |
15 | See Also
16 | --------
17 | dask.dataframe.DataFrame.values
18 | dask.dataframe.from_dask_array
19 | """
20 | return df.map_partitions(M.to_records)
21 |
--------------------------------------------------------------------------------
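
A short sketch of to_records above (not part of the repository); per the warning in the docstring, the resulting array lacks precise shape information, so shape-dependent operations will not work:

import pandas as pd

import dask_expr as dx
from dask_expr.io.records import to_records

df = dx.from_pandas(
    pd.DataFrame({"x": [1, 2, 3], "y": [4.0, 5.0, 6.0]}), npartitions=2
)

records = to_records(df)  # same as df.map_partitions(M.to_records)
arr = records.compute()   # numpy record array; the index becomes a field by default
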
/dask_expr/io/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/io/tests/__init__.py
--------------------------------------------------------------------------------
/dask_expr/io/tests/test_delayed.py:
--------------------------------------------------------------------------------
1 | import dask
2 | import numpy as np
3 | import pytest
4 | from dask import delayed
5 |
6 | from dask_expr import from_delayed, from_dict
7 | from dask_expr.tests._util import _backend_library, assert_eq
8 |
9 | pd = _backend_library()
10 |
11 |
12 | def test_from_delayed_optimizing():
13 | parts = from_dict({"a": np.arange(300)}, npartitions=30).to_delayed()
14 | result = from_delayed(parts[0], meta=pd.DataFrame({"a": pd.Series(dtype=np.int64)}))
15 | assert len(result.optimize().dask) == 2
16 | assert_eq(result, pd.DataFrame({"a": pd.Series(np.arange(10))}))
17 |
18 |
19 | @pytest.mark.parametrize("prefix", [None, "foo"])
20 | def test_from_delayed(prefix):
21 | pdf = pd.DataFrame(
22 | data=np.random.normal(size=(10, 4)), columns=["a", "b", "c", "d"]
23 | )
24 | parts = [pdf.iloc[:1], pdf.iloc[1:3], pdf.iloc[3:6], pdf.iloc[6:10]]
25 | dfs = [delayed(parts.__getitem__)(i) for i in range(4)]
26 |
27 | df = from_delayed(dfs, meta=pdf.head(0), divisions=None, prefix=prefix)
28 | assert_eq(df, pdf)
29 | assert len({k[0] for k in df.optimize().dask}) == 2
30 | if prefix:
31 | assert df._name.startswith(prefix)
32 |
33 | divisions = tuple([p.index[0] for p in parts] + [parts[-1].index[-1]])
34 | df = from_delayed(dfs, meta=pdf.head(0), divisions=divisions, prefix=prefix)
35 | assert_eq(df, pdf)
36 | if prefix:
37 | assert df._name.startswith(prefix)
38 |
39 |
40 | def test_from_delayed_dask():
41 | df = pd.DataFrame(data=np.random.normal(size=(10, 4)), columns=list("abcd"))
42 | parts = [df.iloc[:1], df.iloc[1:3], df.iloc[3:6], df.iloc[6:10]]
43 | dfs = [delayed(parts.__getitem__)(i) for i in range(4)]
44 | meta = dfs[0].compute()
45 |
46 | my_len = lambda x: pd.Series([len(x)])
47 |
48 | for divisions in [None, [0, 1, 3, 6, 10]]:
49 | ddf = from_delayed(dfs, meta=meta, divisions=divisions)
50 | assert_eq(ddf, df)
51 | assert list(ddf.map_partitions(my_len).compute()) == [1, 2, 3, 4]
52 | assert ddf.known_divisions == (divisions is not None)
53 |
54 | s = from_delayed([d.a for d in dfs], meta=meta.a, divisions=divisions)
55 | assert_eq(s, df.a)
56 | assert list(s.map_partitions(my_len).compute()) == [1, 2, 3, 4]
57 | assert ddf.known_divisions == (divisions is not None)
58 |
59 |
60 | @dask.delayed
61 | def _load(x):
62 | return pd.DataFrame({"x": x, "y": [1, 2, 3]})
63 |
64 |
65 | def test_from_delayed_fusion():
66 | func = lambda x: None
67 | df = from_delayed([_load(x) for x in range(10)], meta={"x": "int64", "y": "int64"})
68 | result = df.map_partitions(func, meta={}).optimize().dask
69 | expected = df.map_partitions(func, meta={}).optimize(fuse=False).dask
70 | assert result.keys() == expected.keys()
71 |
72 | expected = df.map_partitions(func, meta={}).lower_completely().dask
73 | assert result.keys() == expected.keys()
74 |
--------------------------------------------------------------------------------
/dask_expr/io/tests/test_distributed.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 |
5 | import pytest
6 |
7 | from dask_expr import read_parquet
8 | from dask_expr.tests._util import _backend_library, assert_eq
9 |
10 | distributed = pytest.importorskip("distributed")
11 |
12 | from distributed import Client, LocalCluster
13 | from distributed.utils_test import client as c # noqa F401
14 | from distributed.utils_test import gen_cluster
15 |
16 | import dask_expr as dx
17 |
18 | pd = _backend_library()
19 |
20 |
21 | @pytest.fixture(params=["arrow"])
22 | def filesystem(request):
23 | return request.param
24 |
25 |
26 | def _make_file(dir, df=None, filename="myfile.parquet", **kwargs):
27 | fn = os.path.join(str(dir), filename)
28 | if df is None:
29 | df = pd.DataFrame({c: range(10) for c in "abcde"})
30 | df.to_parquet(fn, **kwargs)
31 | return fn
32 |
33 |
34 | def test_io_fusion_merge(tmpdir):
35 | pdf = pd.DataFrame({c: range(100) for c in "abcdefghij"})
36 | with LocalCluster(processes=False, n_workers=2) as cluster:
37 | with Client(cluster) as client: # noqa: F841
38 | dx.from_pandas(pdf, 2).to_parquet(tmpdir)
39 |
40 | df = dx.read_parquet(tmpdir).merge(
41 | dx.read_parquet(tmpdir).add_suffix("_x"), left_on="a", right_on="a_x"
42 | )[["a_x", "b_x", "b"]]
43 | out = df.compute()
44 | pd.testing.assert_frame_equal(
45 | out.sort_values(by="a_x", ignore_index=True),
46 | pdf.merge(pdf.add_suffix("_x"), left_on="a", right_on="a_x")[
47 | ["a_x", "b_x", "b"]
48 | ],
49 | )
50 |
51 |
52 | @pytest.mark.filterwarnings("error")
53 | @gen_cluster(client=True)
54 | async def test_parquet_distributed(c, s, a, b, tmpdir, filesystem):
55 | pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
56 | df = read_parquet(_make_file(tmpdir, df=pdf), filesystem=filesystem)
57 | assert_eq(await c.gather(c.compute(df.optimize())), pdf)
58 |
59 |
60 | def test_pickle_size(tmpdir, filesystem):
61 | pdf = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
62 | [_make_file(tmpdir, df=pdf, filename=f"{x}.parquet") for x in range(10)]
63 | df = read_parquet(tmpdir, filesystem=filesystem)
64 | from distributed.protocol import dumps
65 |
66 | assert len(b"".join(dumps(df.optimize().dask))) <= 9100
67 |
--------------------------------------------------------------------------------
/dask_expr/io/tests/test_from_pandas.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import dask
4 | import pytest
5 | from dask.dataframe.utils import assert_eq
6 |
7 | from dask_expr import from_pandas, repartition
8 | from dask_expr.tests._util import _backend_library
9 |
10 | pd = _backend_library()
11 |
12 |
13 | @pytest.fixture(params=["Series", "DataFrame"])
14 | def pdf(request):
15 | out = pd.DataFrame({"x": [1, 4, 3, 2, 0, 5]})
16 | return out["x"] if request.param == "Series" else out
17 |
18 |
19 | @pytest.mark.parametrize("sort", [True, False])
20 | def test_from_pandas(pdf, sort):
21 | df = from_pandas(pdf, npartitions=2, sort=sort)
22 |
23 | assert df.npartitions == 2
24 | assert df.divisions == (0, 3, 5)
25 | assert_eq(df, pdf, sort_results=sort)
26 |
27 |
28 | def test_from_pandas_noargs(pdf):
29 | df = from_pandas(pdf)
30 |
31 | assert df.npartitions == 1
32 | assert df.divisions == (0, 5)
33 | assert_eq(df, pdf)
34 |
35 |
36 | def test_from_pandas_empty(pdf):
37 | pdf = pdf.iloc[:0]
38 | df = from_pandas(pdf, npartitions=2)
39 | assert_eq(pdf, df)
40 |
41 |
42 | def test_from_pandas_immutable(pdf):
43 | expected = pdf.copy()
44 | df = from_pandas(pdf)
45 | pdf.iloc[0] = 100
46 | assert_eq(df, expected)
47 |
48 |
49 | def test_from_pandas_sort_and_different_partitions():
50 | pdf = pd.DataFrame({"a": [1, 2, 3] * 3, "b": 1}).set_index("a")
51 | df = from_pandas(pdf, npartitions=4, sort=True)
52 | assert_eq(pdf.sort_index(), df, sort_results=False)
53 |
54 | pdf = pd.DataFrame({"a": [1, 2, 3] * 3, "b": 1}).set_index("a")
55 | df = from_pandas(pdf, npartitions=4, sort=False)
56 | assert_eq(pdf, df, sort_results=False)
57 |
58 |
59 | def test_from_pandas_sort():
60 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[6, 5, 4, 3, 2, 1])
61 | df = from_pandas(pdf, npartitions=2)
62 | assert_eq(df, pdf.sort_index(), sort_results=False)
63 |
64 |
65 | def test_from_pandas_divisions():
66 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[7, 6, 4, 3, 2, 1])
67 | df = repartition(pdf, (1, 5, 8))
68 | assert_eq(df, pdf.sort_index())
69 |
70 | pdf = pd.DataFrame({"a": [1, 2, 3, 1, 2, 2]}, index=[7, 6, 4, 3, 2, 1])
71 | df = repartition(pdf, (1, 4, 8))
72 | assert_eq(df.partitions[1], pd.DataFrame({"a": [3, 2, 1]}, index=[4, 6, 7]))
73 |
74 | df = repartition(df, divisions=(1, 3, 8), force=True)
75 | assert_eq(df, pdf.sort_index())
76 |
77 |
78 | def test_from_pandas_empty_projection():
79 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1})
80 | df = from_pandas(pdf)
81 | assert_eq(df[[]], pdf[[]])
82 |
83 |
84 | def test_from_pandas_divisions_duplicated():
85 | pdf = pd.DataFrame({"a": 1}, index=[1, 2, 3, 4, 5, 5, 5, 6, 8])
86 | df = repartition(pdf, (1, 5, 7, 10))
87 | assert_eq(df, pdf)
88 | assert_eq(df.partitions[0], pdf.loc[1:4])
89 | assert_eq(df.partitions[1], pdf.loc[5:6])
90 | assert_eq(df.partitions[2], pdf.loc[8:])
91 |
92 |
93 | @pytest.mark.parametrize("npartitions", [1, 3, 6, 7])
94 | @pytest.mark.parametrize("sort", [True, False])
95 | def test_from_pandas_npartitions(pdf, npartitions, sort):
96 | df = from_pandas(pdf, sort=sort, npartitions=npartitions)
97 | assert df.npartitions == min(pdf.shape[0], npartitions)
98 | assert "pandas" in df._name
99 | assert_eq(df, pdf, sort_results=sort)
100 |
101 |
102 | @pytest.mark.parametrize("chunksize,npartitions", [(1, 6), (2, 3), (6, 1), (7, 1)])
103 | @pytest.mark.parametrize("sort", [True, False])
104 | def test_from_pandas_chunksize(pdf, chunksize, npartitions, sort):
105 | df = from_pandas(pdf, sort=sort, chunksize=chunksize)
106 | assert df.npartitions == npartitions
107 | assert "pandas" in df._name
108 | assert_eq(df, pdf, sort_results=sort)
109 |
110 |
111 | def test_from_pandas_npartitions_and_chunksize(pdf):
112 | with pytest.raises(ValueError, match="npartitions and chunksize"):
113 | from_pandas(pdf, npartitions=2, chunksize=3)
114 |
115 |
116 | def test_from_pandas_string_option():
117 | pdf = pd.DataFrame({"x": [1, 2, 3], "y": "a"}, index=["a", "b", "c"])
118 | df = from_pandas(pdf, npartitions=2)
119 | assert df.dtypes["y"] == "string"
120 | assert df.index.dtype == "string"
121 | assert df.compute().dtypes["y"] == "string"
122 | assert df.compute().index.dtype == "string"
123 | assert_eq(df, pdf)
124 |
125 | with dask.config.set({"dataframe.convert-string": False}):
126 | df = from_pandas(pdf, npartitions=2)
127 | assert df.dtypes["y"] == "object"
128 | assert df.index.dtype == "object"
129 | assert df.compute().dtypes["y"] == "object"
130 | assert df.compute().index.dtype == "object"
131 | assert_eq(df, pdf)
132 |
133 |
134 | def test_from_pandas_deepcopy():
135 | pdf = pd.DataFrame({"col1": [1, 2, 3, 4, 5, 6]})
136 | df = from_pandas(pdf, npartitions=3)
137 | df_dict = {"dataset": df}
138 | result = copy.deepcopy(df_dict)
139 | assert_eq(result["dataset"], pdf)
140 |
141 |
142 | def test_from_pandas_empty_chunksize():
143 | pdf = pd.DataFrame()
144 | df = from_pandas(pdf, chunksize=10_000)
145 | assert_eq(pdf, df)
146 |
--------------------------------------------------------------------------------
/dask_expr/io/tests/test_sql.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from dask.utils import tmpfile
3 |
4 | from dask_expr import from_pandas, read_sql_table
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | pd = _backend_library()
8 |
9 | pytest.importorskip("sqlalchemy")
10 |
11 |
12 | def test_shuffle_after_read_sql():
13 | with tmpfile() as f:
14 | uri = "sqlite:///%s" % f
15 |
16 | df = pd.DataFrame(
17 | {
18 | "id": [1, 2, 3, 4, 5, 6, 7, 8],
19 | "value": [
20 | "value1",
21 | "value2",
22 | "value3",
23 | "value3",
24 | "value4",
25 | "value4",
26 | "value4",
27 | "value5",
28 | ],
29 | }
30 | ).set_index("id")
31 | ddf = from_pandas(df, npartitions=1)
32 |
33 | ddf.to_sql("test_table", uri, if_exists="append")
34 | result = read_sql_table("test_table", con=uri, index_col="id")
35 | assert_eq(
36 | result["value"].unique(), pd.Series(df["value"].unique(), name="value")
37 | )
38 | assert_eq(result.shuffle(on_index=True), df)
39 |
--------------------------------------------------------------------------------
/dask_expr/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dask/dask-expr/42e7a8958ba286a4c7b2f199d143a70b0b479489/dask_expr/tests/__init__.py
--------------------------------------------------------------------------------
/dask_expr/tests/_util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import pickle
3 |
4 | import pytest
5 | from dask import config
6 | from dask.dataframe.utils import assert_eq as dd_assert_eq
7 |
8 |
9 | def _backend_name() -> str:
10 | return config.get("dataframe.backend", "pandas")
11 |
12 |
13 | def _backend_library():
14 | return importlib.import_module(_backend_name())
15 |
16 |
17 | def xfail_gpu(reason=None):
18 | condition = _backend_name() == "cudf"
19 | reason = reason or "Failure expected for cudf backend."
20 | return pytest.mark.xfail(condition, reason=reason)
21 |
22 |
23 | def assert_eq(a, b, *args, serialize_graph=True, **kwargs):
24 | if serialize_graph:
25 | # Check that no `Expr` instances are found in
26 | # the graph generated by `Expr.dask`
27 | with config.set({"dask-expr-no-serialize": True}):
28 | for obj in [a, b]:
29 | if hasattr(obj, "dask"):
30 | try:
31 | pickle.dumps(obj.dask)
32 | except AttributeError:
33 | try:
34 | import cloudpickle as cp
35 |
36 | cp.dumps(obj.dask)
37 | except ImportError:
38 | pass
39 |
40 | # Use `dask.dataframe.assert_eq`
41 | return dd_assert_eq(a, b, *args, **kwargs)
42 |
--------------------------------------------------------------------------------
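
A sketch of how the assert_eq helper above is used throughout these tests (not part of the repository):

import pandas as pd

from dask_expr import from_pandas
from dask_expr.tests._util import assert_eq

pdf = pd.DataFrame({"x": [1, 2, 3]})
df = from_pandas(pdf, npartitions=2)

# First pickles df.dask to verify that no Expr objects leak into the graph,
# then delegates the value comparison to dask.dataframe.utils.assert_eq.
assert_eq(df, pdf)

# The graph-serialization pass can be skipped explicitly.
assert_eq(df, pdf, serialize_graph=False)
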
/dask_expr/tests/test_align_partitions.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from itertools import product
3 |
4 | import numpy as np
5 | import pytest
6 |
7 | from dask_expr import from_pandas
8 | from dask_expr._expr import OpAlignPartitions
9 | from dask_expr._repartition import RepartitionDivisions
10 | from dask_expr._shuffle import Shuffle, divisions_lru
11 | from dask_expr.tests._util import _backend_library, assert_eq
12 |
13 | # Set DataFrame backend for this module
14 | pd = _backend_library()
15 |
16 |
17 | @pytest.fixture
18 | def pdf():
19 | pdf = pd.DataFrame({"x": range(100)})
20 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions
21 | yield pdf
22 |
23 |
24 | @pytest.fixture
25 | def df(pdf):
26 | yield from_pandas(pdf, npartitions=10)
27 |
28 |
29 | @pytest.mark.parametrize("op", ["__add__", "add"])
30 | def test_broadcasting_scalar(pdf, df, op):
31 | df2 = from_pandas(pdf, npartitions=2)
32 | result = getattr(df, op)(df2.x.sum())
33 | assert_eq(result, pdf + pdf.x.sum())
34 | assert len(list(result.expr.find_operations(OpAlignPartitions))) == 0
35 |
36 | divisions_lru.data = OrderedDict()
37 | result = getattr(df.set_index("x"), op)(df2.x.sum())
38 | # Make sure that we don't touch divisions
39 | assert len(divisions_lru.data) == 0
40 | assert_eq(result, pdf.set_index("x") + pdf.x.sum())
41 | assert len(list(result.expr.find_operations(OpAlignPartitions))) == 0
42 |
43 | if op == "__add__":
44 | # Can't avoid expensive alignment check, but don't touch divisions while figuring it out
45 | divisions_lru.data = OrderedDict()
46 | result = getattr(df.set_index("x"), op)(df2.set_index("x").sum())
47 | # Make sure that we don't touch divisions
48 | assert len(divisions_lru.data) == 0
49 | assert_eq(result, pdf.set_index("x") + pdf.set_index("x").sum())
50 | assert len(list(result.expr.find_operations(OpAlignPartitions))) > 0
51 |
52 | assert (
53 | len(
54 | list(
55 | result.optimize(fuse=False).expr.find_operations(
56 | RepartitionDivisions
57 | )
58 | )
59 | )
60 | == 0
61 | )
62 |
63 | # Can't avoid alignment, but don't touch divisions while figuring it out
64 | divisions_lru.data = OrderedDict()
65 | result = getattr(df.set_index("x"), op)(df2.set_index("x"))
66 | # Make sure that we don't touch divisions
67 | assert len(divisions_lru.data) == 0
68 | assert_eq(result, pdf.set_index("x") + pdf.set_index("x"))
69 | assert len(list(result.expr.find_operations(OpAlignPartitions))) > 0
70 |
71 | assert (
72 | len(
73 | list(result.optimize(fuse=False).expr.find_operations(RepartitionDivisions))
74 | )
75 | > 0
76 | )
77 |
78 |
79 | @pytest.mark.parametrize("sorted_index", [False, True])
80 | @pytest.mark.parametrize("sorted_map_index", [False, True])
81 | def test_series_map(sorted_index, sorted_map_index):
82 | base = pd.Series(
83 | ["".join(np.random.choice(["a", "b", "c"], size=3)) for x in range(100)]
84 | )
85 | if not sorted_index:
86 | index = np.arange(100)
87 | np.random.shuffle(index)
88 | base.index = index
89 | map_index = ["".join(x) for x in product("abc", repeat=3)]
90 | mapper = pd.Series(np.random.randint(50, size=len(map_index)), index=map_index)
91 | if not sorted_map_index:
92 | map_index = np.array(map_index)
93 | np.random.shuffle(map_index)
94 | mapper.index = map_index
95 | expected = base.map(mapper)
96 | dask_base = from_pandas(base, npartitions=1, sort=False)
97 | dask_map = from_pandas(mapper, npartitions=1, sort=False)
98 | result = dask_base.map(dask_map)
99 | assert_eq(expected, result)
100 |
101 |
102 | def test_assign_align_partitions():
103 | pdf = pd.DataFrame({"x": [0] * 20, "y": range(20)})
104 | df = from_pandas(pdf, npartitions=2)
105 | s = pd.Series(range(10, 30))
106 | ds = from_pandas(s, npartitions=df.npartitions)
107 | result = df.assign(z=ds)[["y", "z"]]
108 | expected = pdf.assign(z=s)[["y", "z"]]
109 | assert_eq(result, expected)
110 |
111 |
112 | def test_assign_unknown_partitions(pdf):
113 | pdf2 = pdf.sort_index(ascending=False)
114 | df2 = from_pandas(pdf2, npartitions=3, sort=False)
115 | df1 = from_pandas(pdf, npartitions=3).clear_divisions()
116 | df1["new"] = df2.x
117 | expected = pdf.copy()
118 | expected["new"] = pdf2.x
119 | assert_eq(df1, expected)
120 | assert len(list(df1.optimize(fuse=False).expr.find_operations(Shuffle))) == 2
121 |
122 | pdf["c"] = "a"
123 | pdf = pdf.set_index("c")
124 | df = from_pandas(pdf, npartitions=3)
125 | df["new"] = df2.x
126 | with pytest.raises(TypeError, match="have differing dtypes"):
127 | df.optimize()
128 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_categorical.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr import from_pandas
4 | from dask_expr._categorical import GetCategories
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | # Set DataFrame backend for this module
8 | pd = _backend_library()
9 |
10 |
11 | @pytest.fixture
12 | def pdf():
13 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 1, 2], "y": "bcbbbc"})
14 | return pdf
15 |
16 |
17 | @pytest.fixture
18 | def df(pdf):
19 | yield from_pandas(pdf, npartitions=2)
20 |
21 |
22 | def test_set_categories(pdf):
23 | pdf = pdf.astype("category")
24 | df = from_pandas(pdf, npartitions=2)
25 | assert df.x.cat.known
26 | assert_eq(df.x.cat.codes, pdf.x.cat.codes)
27 | ser = df.x.cat.as_unknown()
28 | assert not ser.cat.known
29 | ser = ser.cat.as_known()
30 | assert_eq(ser.cat.categories, pd.Index([1, 2, 3, 4]))
31 | ser = ser.cat.set_categories([1, 2, 3, 5, 4])
32 | assert_eq(ser.cat.categories, pd.Index([1, 2, 3, 5, 4]))
33 | assert not ser.cat.ordered
34 |
35 |
36 | def test_categorize(df, pdf):
37 | df = df.categorize()
38 |
39 | assert df.y.cat.known
40 | assert_eq(df, pdf.astype({"y": "category"}), check_categorical=False)
41 |
42 |
43 | def test_get_categories_simplify_adds_projection(df):
44 | optimized = GetCategories(
45 | df, columns=["y"], index=False, split_every=None
46 | ).simplify()
47 | expected = GetCategories(
48 | df[["y"]].simplify(), columns=["y"], index=False, split_every=None
49 | )
50 | assert optimized._name == expected._name
51 |
52 |
53 | def test_categorical_set_index():
54 | df = pd.DataFrame({"x": [1, 2, 3, 4], "y": ["a", "b", "b", "c"]})
55 | df["y"] = pd.Categorical(df["y"], categories=["a", "b", "c"], ordered=True)
56 | a = from_pandas(df, npartitions=2)
57 |
58 | b = a.set_index("y", divisions=["a", "b", "c"], npartitions=a.npartitions)
59 | d1, d2 = b.get_partition(0), b.get_partition(1)
60 | assert list(d1.index.compute(fuse=False)) == ["a"]
61 | assert list(sorted(d2.index.compute())) == ["b", "b", "c"]
62 |
63 |
64 | def test_categorize_drops_category_columns():
65 | pdf = pd.DataFrame({"a": [1, 2, 1, 2, 3], "b": 1})
66 | df = from_pandas(pdf)
67 | df = df.categorize(columns=["a"])
68 | result = df["b"].to_frame()
69 | expected = pdf["b"].to_frame()
70 | assert_eq(result, expected)
71 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_core.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr._core import Expr
4 |
5 |
6 | class ExprB(Expr):
7 | def _simplify_down(self):
8 | return ExprA()
9 |
10 |
11 | class ExprA(Expr):
12 | def _simplify_down(self):
13 | return ExprB()
14 |
15 |
16 | def test_endless_simplify():
17 | expr = ExprA()
18 | with pytest.raises(RuntimeError, match="converge"):
19 | expr.simplify()
20 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_cumulative.py:
--------------------------------------------------------------------------------
1 | from dask_expr import from_pandas
2 | from dask_expr.tests._util import _backend_library, assert_eq
3 |
4 | # Set DataFrame backend for this module
5 | pd = _backend_library()
6 |
7 |
8 | def test_cumulative_empty_partitions():
9 | pdf = pd.DataFrame(
10 | {"x": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]},
11 | index=pd.date_range("1995-02-26", periods=8, freq="5min"),
12 | dtype=float,
13 | )
14 | pdf2 = pdf.drop(pdf.between_time("00:10", "00:20").index)
15 |
16 | df = from_pandas(pdf, npartitions=8)
17 | df2 = from_pandas(pdf2, npartitions=1).repartition(df.divisions)
18 |
19 | assert_eq(df2.cumprod(), pdf2.cumprod())
20 | assert_eq(df2.cumsum(), pdf2.cumsum())
21 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import sys
3 |
4 | import pytest
5 |
6 | from dask_expr import new_collection
7 | from dask_expr._expr import Lengths
8 | from dask_expr.datasets import Timeseries, timeseries
9 | from dask_expr.tests._util import assert_eq
10 |
11 |
12 | def test_timeseries():
13 | df = timeseries(freq="360 s", start="2000-01-01", end="2000-01-02")
14 | assert_eq(df, df)
15 |
16 |
17 | def test_optimization():
18 | df = timeseries(dtypes={"x": int, "y": float}, seed=123)
19 | expected = timeseries(dtypes={"x": int}, seed=123)
20 | result = df[["x"]].optimize(fuse=False)
21 | assert result.expr.frame.operand("columns") == expected.expr.frame.operand(
22 | "columns"
23 | )
24 |
25 | expected = timeseries(dtypes={"x": int}, seed=123)["x"].simplify()
26 | result = df["x"].optimize(fuse=False)
27 | assert expected.expr.frame.operand("columns") == result.expr.frame.operand(
28 | "columns"
29 | )
30 |
31 |
32 | def test_arrow_string_option():
33 | df = timeseries(dtypes={"x": object, "y": float}, seed=123)
34 | result = df.optimize(fuse=False)
35 | assert result.x.dtype == "string"
36 | assert result.x.compute().dtype == "string"
37 |
38 |
39 | def test_column_projection_deterministic():
40 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-02", seed=123)
41 | result_id = df[["id"]].optimize()
42 | result_id_x = df[["id", "x"]].optimize()
43 | assert_eq(result_id["id"], result_id_x["id"])
44 |
45 |
46 | def test_timeseries_culling():
47 | df = timeseries(dtypes={"x": int, "y": float}, seed=123)
48 | pdf = df.compute()
49 | offset = len(df.partitions[0].compute())
50 | df = (df[["x"]] + 1).partitions[1]
51 | df2 = df.optimize()
52 |
53 | # All tasks should be fused for the single output partition
54 | assert df2.npartitions == 1
55 | assert len(df2.dask) == df2.npartitions
56 | expected = pdf.iloc[offset : 2 * offset][["x"]] + 1
57 | assert_eq(df2, expected)
58 |
59 |
60 | def test_persist():
61 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-02", seed=123)
62 | a = df["x"]
63 | b = a.persist()
64 |
65 | assert_eq(a, b)
66 | assert len(b.dask) == 2 * b.npartitions
67 |
68 |
69 | def test_lengths():
70 | df = timeseries(freq="1h", start="2000-01-01", end="2000-01-03", seed=123)
71 | assert len(df) == sum(new_collection(Lengths(df.expr).optimize()).compute())
72 |
73 |
74 | def test_timeseries_empty_projection():
75 | ts = timeseries(end="2000-01-02", dtypes={})
76 | expected = timeseries(end="2000-01-02")
77 | assert len(ts) == len(expected)
78 |
79 |
80 | def test_combine_similar(tmpdir):
81 | df = timeseries(end="2000-01-02")
82 | pdf = df.compute()
83 | got = df[df["name"] == "a"][["id"]]
84 |
85 | expected = pdf[pdf["name"] == "a"][["id"]]
86 | assert_eq(got, expected)
87 | assert_eq(got.optimize(fuse=False), expected)
88 | assert_eq(got.optimize(fuse=True), expected)
89 |
90 | # We should only have one Timeseries node, and
91 | # it should not include "x" or "y" in the dtypes
92 | timeseries_nodes = list(got.optimize(fuse=False).find_operations(Timeseries))
93 | assert len(timeseries_nodes) == 1
94 | assert set(timeseries_nodes[0].dtypes.keys()) == {"id", "name"}
95 |
96 | df = timeseries(end="2000-01-02")
97 | df2 = timeseries(end="2000-01-02")
98 |
99 | got = df + df2
100 | timeseries_nodes = list(got.optimize(fuse=False).find_operations(Timeseries))
101 | assert len(timeseries_nodes) == 2
102 | with pytest.raises(AssertionError):
103 | assert_eq(df + df2, 2 * df)
104 |
105 |
106 | @pytest.mark.parametrize("seed", [42, None])
107 | def test_timeseries_deterministic_head(seed):
108 | # Make sure our `random_state` code gives
109 | # us deterministic results
110 | df = timeseries(end="2000-01-02", seed=seed)
111 | assert_eq(df.head(), df.head())
112 | assert_eq(df["x"].head(), df.head()["x"])
113 | assert_eq(df.head()["x"], df["x"].partitions[0].compute().head())
114 |
115 |
116 | @pytest.mark.parametrize("seed", [42, None])
117 | def test_timeseries_graph_size(seed):
118 | from dask.datasets import timeseries as dd_timeseries
119 |
120 | # Check that our graph size is reasonable
121 | df = timeseries(seed=seed)
122 | ddf = dd_timeseries(seed=seed)
123 | graph_size = sys.getsizeof(pickle.dumps(df.dask))
124 | graph_size_dd = sys.getsizeof(pickle.dumps(dict(ddf.dask)))
125 | # Make sure we are close to the dask.dataframe graph size
126 | threshold = 1.10
127 | assert graph_size < threshold * graph_size_dd
128 |
129 |
130 | def test_dataset_head():
131 | ddf = timeseries(freq="1d")
132 | expected = ddf.compute()
133 | assert_eq(ddf.head(30, npartitions=-1), expected)
134 | assert_eq(ddf.head(30, npartitions=-1, compute=False), expected)
135 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_datetime.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr._collection import from_pandas
4 | from dask_expr.tests._util import _backend_library, assert_eq
5 |
6 | pd = _backend_library()
7 |
8 |
9 | @pytest.fixture()
10 | def ser():
11 | return pd.Series(pd.date_range(start="2020-01-01", end="2020-03-03"))
12 |
13 |
14 | @pytest.fixture()
15 | def dser(ser):
16 | return from_pandas(ser, npartitions=3)
17 |
18 |
19 | @pytest.fixture()
20 | def ser_td():
21 | return pd.Series(pd.timedelta_range(start="0s", end="1000s", freq="s"))
22 |
23 |
24 | @pytest.fixture()
25 | def dser_td(ser_td):
26 | return from_pandas(ser_td, npartitions=3)
27 |
28 |
29 | @pytest.fixture()
30 | def ser_pr():
31 | return pd.Series(pd.period_range(start="2020-01-01", end="2020-12-31", freq="D"))
32 |
33 |
34 | @pytest.fixture()
35 | def dser_pr(ser_pr):
36 | return from_pandas(ser_pr, npartitions=3)
37 |
38 |
39 | @pytest.mark.parametrize(
40 | "func, args",
41 | [
42 | ("ceil", ("D",)),
43 | ("day_name", ()),
44 | ("floor", ("D",)),
45 | ("isocalendar", ()),
46 | ("month_name", ()),
47 | ("normalize", ()),
48 | ("round", ("D",)),
49 | ("strftime", ("%B %d, %Y, %r",)),
50 | ("to_period", ("D",)),
51 | ],
52 | )
53 | def test_datetime_accessor_methods(ser, dser, func, args):
54 | assert_eq(getattr(ser.dt, func)(*args), getattr(dser.dt, func)(*args))
55 |
56 |
57 | @pytest.mark.parametrize(
58 | "func",
59 | [
60 | "date",
61 | "day",
62 | "day_of_week",
63 | "day_of_year",
64 | "dayofweek",
65 | "dayofyear",
66 | "days_in_month",
67 | "daysinmonth",
68 | "hour",
69 | "is_leap_year",
70 | "is_month_end",
71 | "is_month_start",
72 | "is_quarter_end",
73 | "is_quarter_start",
74 | "is_year_end",
75 | "is_year_start",
76 | "microsecond",
77 | "minute",
78 | "month",
79 | "nanosecond",
80 | "quarter",
81 | "second",
82 | "time",
83 | "timetz",
84 | "weekday",
85 | "year",
86 | ],
87 | )
88 | def test_datetime_accessor_properties(ser, dser, func):
89 | assert_eq(getattr(ser.dt, func), getattr(dser.dt, func))
90 |
91 |
92 | @pytest.mark.parametrize(
93 | "func",
94 | [
95 | "components",
96 | "days",
97 | "microseconds",
98 | "nanoseconds",
99 | "seconds",
100 | ],
101 | )
102 | def test_timedelta_accessor_properties(ser_td, dser_td, func):
103 | assert_eq(getattr(ser_td.dt, func), getattr(dser_td.dt, func))
104 |
105 |
106 | @pytest.mark.parametrize(
107 | "func",
108 | [
109 | "end_time",
110 | "start_time",
111 | "qyear",
112 | "week",
113 | "weekofyear",
114 | ],
115 | )
116 | def test_period_accessor_properties(ser_pr, dser_pr, func):
117 | assert_eq(getattr(ser_pr.dt, func), getattr(dser_pr.dt, func))
118 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_describe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from dask_expr import from_pandas
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | # Set DataFrame backend for this module
8 | pd = _backend_library()
9 |
10 |
11 | @pytest.fixture
12 | def pdf():
13 | pdf = pd.DataFrame(
14 | {
15 | "x": [None, 0, 1, 2, 3, 4] * 2,
16 | "ts": [
17 | pd.Timestamp("2017-05-09 00:00:00.006000"),
18 | pd.Timestamp("2017-05-09 00:00:00.006000"),
19 | pd.Timestamp("2017-05-09 07:56:23.858694"),
20 | pd.Timestamp("2017-05-09 05:59:58.938999"),
21 | None,
22 | None,
23 | ]
24 | * 2,
25 | "td": [
26 | np.timedelta64(3, "D"),
27 | np.timedelta64(1, "D"),
28 | None,
29 | None,
30 | np.timedelta64(3, "D"),
31 | np.timedelta64(1, "D"),
32 | ]
33 | * 2,
34 | "y": "a",
35 | }
36 | )
37 | yield pdf
38 |
39 |
40 | @pytest.fixture
41 | def df(pdf):
42 | yield from_pandas(pdf, npartitions=2)
43 |
44 |
45 | def _drop_mean(df, col=None):
46 | """TODO: In pandas 2.0, mean is implemented for datetimes, but Dask returns None."""
47 | if isinstance(df, pd.DataFrame):
48 | df.at["mean", col] = np.nan
49 | df.dropna(how="all", inplace=True)
50 | elif isinstance(df, pd.Series):
51 | df.drop(labels=["mean"], inplace=True, errors="ignore")
52 | else:
53 | raise NotImplementedError("Expected Series or DataFrame with mean")
54 | return df
55 |
56 |
57 | def test_describe_series(df, pdf):
58 | assert_eq(df.x.describe(), pdf.x.describe())
59 | assert_eq(df.y.describe(), pdf.y.describe())
60 | assert_eq(df.ts.describe(), _drop_mean(pdf.ts.describe()))
61 | assert_eq(df.td.describe(), pdf.td.describe())
62 |
63 |
64 | def test_describe_df(df, pdf):
65 | assert_eq(df.describe(), _drop_mean(pdf.describe(), "ts"))
66 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_diagnostics.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 |
5 | import pytest
6 |
7 | pytest.importorskip("distributed")
8 |
9 | from distributed.utils_test import * # noqa
10 |
11 | from dask_expr import from_pandas
12 | from dask_expr.tests._util import _backend_library
13 |
14 | # Set DataFrame backend for this module
15 | pd = _backend_library()
16 |
17 |
18 | @pytest.fixture
19 | def pdf():
20 | pdf = pd.DataFrame({"x": range(100)})
21 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions
22 | yield pdf
23 |
24 |
25 | @pytest.fixture
26 | def df(pdf):
27 | yield from_pandas(pdf, npartitions=10)
28 |
29 |
30 | def test_analyze(df, client, tmpdir):
31 | pytest.importorskip("crick")
32 | filename = str(tmpdir / "analyze")
33 | expr = df.groupby(df.columns[1]).apply(lambda x: x)
34 | digraph = expr.analyze(filename=filename)
35 | assert os.path.exists(filename + ".svg")
36 | for exp in expr.optimize().walk():
37 | assert any(exp._name in el for el in digraph.body)
38 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_dummies.py:
--------------------------------------------------------------------------------
1 | import pandas
2 | import pytest
3 |
4 | from dask_expr import from_pandas, get_dummies
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | pd = _backend_library()
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "data",
12 | [
13 | pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype="category"),
14 | pd.Series(
15 | pandas.Categorical([1, 1, 1, 2, 2, 1, 3, 4], categories=[4, 3, 2, 1])
16 | ),
17 | pd.DataFrame(
18 | {"a": [1, 2, 3, 4, 4, 3, 2, 1], "b": pandas.Categorical(list("abcdabcd"))}
19 | ),
20 | ],
21 | )
22 | def test_get_dummies(data):
23 | exp = pd.get_dummies(data)
24 |
25 | ddata = from_pandas(data, 2)
26 | res = get_dummies(ddata)
27 | assert_eq(res, exp)
28 | pandas.testing.assert_index_equal(res.columns, exp.columns)
29 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_format.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: W291
2 | from textwrap import dedent
3 |
4 | import pytest
5 | from dask.utils import maybe_pluralize
6 |
7 | from dask_expr import from_pandas
8 | from dask_expr.tests._util import _backend_library
9 |
10 | # Set DataFrame backend for this module
11 | pd = _backend_library()
12 |
13 |
14 | def test_to_string():
15 | pytest.importorskip("jinja2")
16 | df = pd.DataFrame(
17 | {
18 | "A": [1, 2, 3, 4, 5, 6, 7, 8],
19 | "B": list("ABCDEFGH"),
20 | "C": pd.Categorical(list("AAABBBCC")),
21 | }
22 | )
23 | ddf = from_pandas(df, 3)
24 |
25 | exp = dedent(
26 | """\
27 | A B C
28 | npartitions=3
29 | 0 int64 string category[known]
30 | 3 ... ... ...
31 | 6 ... ... ...
32 | 7 ... ... ..."""
33 | )
34 | assert ddf.to_string() == exp
35 |
36 | exp_table = """
37 |
38 |
39 | |
40 | A |
41 | B |
42 | C |
43 |
44 |
45 | npartitions=3 |
46 | |
47 | |
48 | |
49 |
50 |
51 |
52 |
53 | 0 |
54 | int64 |
55 | string |
56 | category[known] |
57 |
58 |
59 | 3 |
60 | ... |
61 | ... |
62 | ... |
63 |
64 |
65 | 6 |
66 | ... |
67 | ... |
68 | ... |
69 |
70 |
71 | 7 |
72 | ... |
73 | ... |
74 | ... |
75 |
76 |
77 |
""" # noqa E222, E702
78 | footer = f"Dask Name: frompandas, {maybe_pluralize(1, 'expression')}"
79 | exp = f"""Dask DataFrame Structure:
80 | {exp_table}
81 | {footer}
"""
82 | assert ddf.to_html() == exp
83 |
84 |
85 | def test_series_format():
86 | s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], index=list("ABCDEFGH"))
87 | ds = from_pandas(s, 3)
88 |
89 | exp = dedent(
90 | """\
91 | npartitions=3
92 | A int64
93 | D ...
94 | G ...
95 | H ..."""
96 | )
97 | assert ds.to_string() == exp
98 |
99 |
100 | def test_series_repr():
101 | s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], index=list("ABCDEFGH"))
102 | ds = from_pandas(s, 3)
103 |
104 | exp = dedent(
105 | """\
106 | Dask Series Structure:
107 | npartitions=3
108 | A int64
109 | D ...
110 | G ...
111 | H ...
112 | Dask Name: frompandas, 1 expression
113 | Expr=df"""
114 | )
115 | assert repr(ds) == exp
116 |
117 | # There is no cheap way to determine whether a Series is empty,
118 | # so we do not prefix with "Empty" as we do with an empty DataFrame
119 | s = pd.Series([])
120 | ds = from_pandas(s, 3)
121 |
122 | exp = dedent(
123 | """\
124 | Dask Series Structure:
125 | npartitions=3
126 | string
127 | ...
128 | ...
129 | ...
130 | Dask Name: frompandas, 1 expression
131 | Expr=df"""
132 | )
133 | assert repr(ds) == exp
134 |
135 |
136 | def test_df_repr():
137 | df = pd.DataFrame({"col1": range(10), "col2": map(float, range(10))})
138 | ddf = from_pandas(df, 3)
139 |
140 | exp = dedent(
141 | """\
142 | Dask DataFrame Structure:
143 | col1 col2
144 | npartitions=3
145 | 0 int64 float64
146 | 4 ... ...
147 | 7 ... ...
148 | 9 ... ...
149 | Dask Name: frompandas, 1 expression
150 | Expr=df"""
151 | )
152 | assert repr(ddf) == exp
153 |
154 | df = pd.DataFrame()
155 | ddf = from_pandas(df, 3)
156 |
157 | exp = dedent(
158 | """\
159 | Empty Dask DataFrame Structure:
160 | npartitions=3
161 | 0 int64 float64
162 | 4 ... ...
163 | 7 ... ...
164 | 9 ... ...
165 | Dask Name: frompandas, 1 expression
166 | Expr=df"""
167 | )
168 |
169 |
170 | def test_df_to_html():
171 | df = pd.DataFrame({"col1": range(10), "col2": map(float, range(10))})
172 | ddf = from_pandas(df, 3)
173 |
174 | exp = dedent(
175 | """\
176 | Dask DataFrame Structure:
177 |
178 |
179 |
180 | |
181 | col1 |
182 | col2 |
183 |
184 |
185 | npartitions=3 |
186 | |
187 | |
188 |
189 |
190 |
191 |
192 | 0 |
193 | int64 |
194 | float64 |
195 |
196 |
197 | 4 |
198 | ... |
199 | ... |
200 |
201 |
202 | 7 |
203 | ... |
204 | ... |
205 |
206 |
207 | 9 |
208 | ... |
209 | ... |
210 |
211 |
212 |
213 | Dask Name: frompandas, 1 expression
"""
214 | )
215 | assert ddf.to_html() == exp
216 | assert ddf._repr_html_() == exp # for jupyter
217 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_fusion.py:
--------------------------------------------------------------------------------
1 | import dask.dataframe as dd
2 | import pytest
3 |
4 | from dask_expr import from_pandas, optimize
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | # Set DataFrame backend for this module
8 | pd = _backend_library()
9 |
10 |
11 | @pytest.fixture
12 | def pdf():
13 | pdf = pd.DataFrame({"x": range(100)})
14 | pdf["y"] = pdf.x * 10.0
15 | yield pdf
16 |
17 |
18 | @pytest.fixture
19 | def df(pdf):
20 | yield from_pandas(pdf, npartitions=10)
21 |
22 |
23 | def test_simple(df):
24 | out = (df["x"] + df["y"]) - 1
25 | unfused = optimize(out, fuse=False)
26 | fused = optimize(out, fuse=True)
27 |
28 | # Should only get one task per partition
29 | # from_pandas is not fused together
30 | assert len(fused.dask) == df.npartitions + 10
31 | assert_eq(fused, unfused)
32 |
33 |
34 | def test_with_non_fusable_on_top(df):
35 | out = (df["x"] + df["y"] - 1).sum()
36 | unfused = optimize(out, fuse=False)
37 | fused = optimize(out, fuse=True)
38 |
39 | assert len(fused.dask) < len(unfused.dask)
40 | assert_eq(fused, unfused)
41 |
42 | # Check that we still get fusion
43 | # after a non-blockwise operation as well
44 | fused_2 = optimize((out + 10) - 5, fuse=True)
45 | assert len(fused_2.dask) == len(fused.dask) + 1 # only one more task
46 |
47 |
48 | def test_optimize_fusion_many():
49 | # Test that many ``Blockwise`` operations,
50 | # originating from various IO operations,
51 | # can all be fused together
52 | a = from_pandas(pd.DataFrame({"x": range(100), "y": range(100)}), 10)
53 | b = from_pandas(pd.DataFrame({"a": range(100)}), 10)
54 |
55 | # some generic elemwise operations
56 | aa = a[["x"]] + 1
57 | aa["a"] = a["y"] + a["x"]
58 | aa["b"] = aa["x"] + 2
59 | series_a = aa[a["x"] > 1]["b"]
60 |
61 | bb = b[["a"]] + 1
62 | bb["b"] = b["a"] + b["a"]
63 | series_b = bb[b["a"] > 1]["b"]
64 |
65 | result = (series_a + series_b) + 1
66 | fused = optimize(result, fuse=True)
67 | unfused = optimize(result, fuse=False)
68 | assert fused.npartitions == a.npartitions
69 | # from_pandas is not fused together
70 | assert len(fused.dask) == fused.npartitions + 20
71 | assert_eq(fused, unfused)
72 |
73 |
74 | def test_optimize_fusion_repeat(df):
75 | # Test that we can optimize a collection
76 | # more than once, and fusion still works
77 |
78 | original = df.copy()
79 |
80 | # some generic elemwise operations
81 | df["x"] += 1
82 | df["z"] = df.y
83 | df += 2
84 |
85 | # repeatedly call optimize after doing new fusable things
86 | fused = optimize(optimize(optimize(df) + 2).x)
87 |
88 | # from_pandas is not fused together
89 | assert len(fused.dask) - 10 == fused.npartitions == original.npartitions
90 | assert_eq(fused, df.x + 2)
91 |
92 |
93 | def test_optimize_fusion_broadcast(df):
94 | # Check fusion with a broadcast reduction
95 | result = ((df["x"] + 1) + df["y"].sum()) + 1
96 | fused = optimize(result)
97 |
98 | assert_eq(fused, result)
99 | assert len(fused.dask) < len(result.dask)
100 |
101 |
102 | def test_persist_with_fusion(df):
103 | # Check that fusion works after persisting
104 | df = (df + 2).persist()
105 | out = (df.y + 1).sum()
106 | fused = optimize(out)
107 |
108 | assert_eq(out, fused)
109 | assert len(fused.dask) < len(out.dask)
110 |
111 |
112 | def test_fuse_broadcast_deps():
113 | pdf = pd.DataFrame({"a": [1, 2, 3]})
114 | pdf2 = pd.DataFrame({"a": [2, 3, 4]})
115 | pdf3 = pd.DataFrame({"a": [3, 4, 5]})
116 | df = from_pandas(pdf, npartitions=1)
117 | df2 = from_pandas(pdf2, npartitions=1)
118 | df3 = from_pandas(pdf3, npartitions=2)
119 |
120 | query = df.merge(df2).merge(df3)
121 | # from_pandas is not fused together
122 | assert len(query.optimize().__dask_graph__()) == 2 + 4
123 | assert_eq(query, pdf.merge(pdf2).merge(pdf3))
124 |
125 |
126 | def test_name(df):
127 | out = (df["x"] + df["y"]) - 1
128 | fused = optimize(out, fuse=True)
129 | assert "getitem" in str(fused.expr)
130 | assert "sub" in str(fused.expr)
131 | assert str(fused.expr) == str(fused.expr).lower()
132 |
133 |
134 | def test_fusion_executes_only_once():
135 | times_called = []
136 | import pandas as pd
137 |
138 | def test(i):
139 | times_called.append(i)
140 | return pd.DataFrame({"a": [1, 2, 3], "b": 1})
141 |
142 | df = dd.from_map(test, [1], meta=[("a", "i8"), ("b", "i8")])
143 | df = df[df.a > 1]
144 | df.sum().compute()
145 | assert len(times_called) == 1
146 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_indexing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from dask_expr import from_pandas
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | pd = _backend_library()
8 |
9 |
10 | @pytest.fixture
11 | def pdf():
12 | pdf = pd.DataFrame({"x": range(20)})
13 | pdf["y"] = pdf.x * 10.0
14 | yield pdf
15 |
16 |
17 | @pytest.fixture
18 | def df(pdf):
19 | yield from_pandas(pdf, npartitions=4)
20 |
21 |
22 | def test_iloc(df, pdf):
23 | assert_eq(df.iloc[:, 1], pdf.iloc[:, 1])
24 | assert_eq(df.iloc[:, [1]], pdf.iloc[:, [1]])
25 | assert_eq(df.iloc[:, [0, 1]], pdf.iloc[:, [0, 1]])
26 | assert_eq(df.iloc[:, []], pdf.iloc[:, []])
27 |
28 |
29 | def test_iloc_errors(df):
30 | with pytest.raises(NotImplementedError):
31 | df.iloc[1]
32 | with pytest.raises(NotImplementedError):
33 | df.iloc[1, 1]
34 | with pytest.raises(ValueError, match="Too many"):
35 | df.iloc[(1, 2, 3)]
36 |
37 |
38 | def test_loc_slice(pdf, df):
39 | pdf.columns = [10, 20]
40 | df.columns = [10, 20]
41 | assert_eq(df.loc[:, :15], pdf.loc[:, :15])
42 | assert_eq(df.loc[:, 15:], pdf.loc[:, 15:])
43 | assert_eq(df.loc[:, 25:], pdf.loc[:, 25:]) # no columns
44 | assert_eq(df.loc[:, ::-1], pdf.loc[:, ::-1])
45 |
46 |
47 | def test_iloc_slice(df, pdf):
48 | assert_eq(df.iloc[:, :1], pdf.iloc[:, :1])
49 | assert_eq(df.iloc[:, 1:], pdf.iloc[:, 1:])
50 | assert_eq(df.iloc[:, 99:], pdf.iloc[:, 99:]) # no columns
51 | assert_eq(df.iloc[:, ::-1], pdf.iloc[:, ::-1])
52 |
53 |
54 | @pytest.mark.parametrize("loc", [False, True])
55 | @pytest.mark.parametrize("update", [False, True])
56 | def test_columns_dtype_on_empty_slice(df, pdf, loc, update):
57 | pdf.columns = [10, 20]
58 | if update:
59 | df.columns = [10, 20]
60 | else:
61 | df = from_pandas(pdf, npartitions=10)
62 |
63 | assert df.columns.dtype == pdf.columns.dtype
64 | assert df.compute().columns.dtype == pdf.columns.dtype
65 | assert_eq(df, pdf)
66 |
67 | if loc:
68 | df = df.loc[:, []]
69 | pdf = pdf.loc[:, []]
70 | else:
71 | df = df[[]]
72 | pdf = pdf[[]]
73 |
74 | assert df.columns.dtype == pdf.columns.dtype
75 | assert df.compute().columns.dtype == pdf.columns.dtype
76 | assert_eq(df, pdf)
77 |
78 |
79 | def test_loc(df, pdf):
80 | assert_eq(df.loc[:, "x"], pdf.loc[:, "x"])
81 | assert_eq(df.loc[:, ["x"]], pdf.loc[:, ["x"]])
82 | assert_eq(df.loc[:, []], pdf.loc[:, []])
83 |
84 | assert_eq(df.loc[df.y == 20, "x"], pdf.loc[pdf.y == 20, "x"])
85 | assert_eq(df.loc[df.y == 20, ["x"]], pdf.loc[pdf.y == 20, ["x"]])
86 | assert df.loc[3:8].divisions[0] == 3
87 | assert df.loc[3:8].divisions[-1] == 8
88 |
89 | assert df.loc[5].divisions == (5, 5)
90 |
91 | assert_eq(df.loc[5], pdf.loc[5:5])
92 | assert_eq(df.loc[3:8], pdf.loc[3:8])
93 | assert_eq(df.loc[:8], pdf.loc[:8])
94 | assert_eq(df.loc[3:], pdf.loc[3:])
95 | assert_eq(df.loc[[5]], pdf.loc[[5]])
96 |
97 | assert_eq(df.x.loc[5], pdf.x.loc[5:5])
98 | assert_eq(df.x.loc[3:8], pdf.x.loc[3:8])
99 | assert_eq(df.x.loc[:8], pdf.x.loc[:8])
100 | assert_eq(df.x.loc[3:], pdf.x.loc[3:])
101 | assert_eq(df.x.loc[[5]], pdf.x.loc[[5]])
102 | assert_eq(df.x.loc[[]], pdf.x.loc[[]])
103 | assert_eq(df.x.loc[np.array([])], pdf.x.loc[np.array([])])
104 |
105 | pytest.raises(KeyError, lambda: df.loc[1000])
106 | assert_eq(df.loc[1000:], pdf.loc[1000:])
107 | assert_eq(df.loc[1000:2000], pdf.loc[1000:2000])
108 | assert_eq(df.loc[:-1000], pdf.loc[:-1000])
109 | assert_eq(df.loc[-2000:-1000], pdf.loc[-2000:-1000])
110 |
111 |
112 | def test_loc_non_informative_index():
113 | df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 30, 40])
114 | ddf = from_pandas(df, npartitions=2, sort=True).clear_divisions()
115 | assert not ddf.known_divisions
116 |
117 | ddf.loc[20:30].compute(scheduler="sync")
118 |
119 | assert_eq(ddf.loc[20:30], df.loc[20:30])
120 |
121 | df = pd.DataFrame({"x": [1, 2, 3, 4]}, index=[10, 20, 20, 40])
122 | ddf = from_pandas(df, npartitions=2, sort=True)
123 | assert_eq(ddf.loc[20], df.loc[20:20])
124 |
125 |
126 | def test_loc_with_series(df, pdf):
127 | assert_eq(df.loc[df.x % 2 == 0], pdf.loc[pdf.x % 2 == 0])
128 |
129 |
130 | def test_loc_with_array(df, pdf):
131 | assert_eq(df.loc[(df.x % 2 == 0).values], pdf.loc[(pdf.x % 2 == 0).values])
132 |
133 |
134 | def test_loc_with_function(df, pdf):
135 | assert_eq(df.loc[lambda df: df["x"] > 3, :], pdf.loc[lambda df: df["x"] > 3, :])
136 |
137 | def _col_loc_fun(_df):
138 | return _df.columns.str.contains("y")
139 |
140 | assert_eq(df.loc[:, _col_loc_fun], pdf.loc[:, _col_loc_fun])
141 |
142 |
143 | def test_getitem_align():
144 | df = pd.DataFrame(
145 | {
146 | "A": [1, 2, 3, 4, 5, 6, 7, 8, 9],
147 | "B": [9, 8, 7, 6, 5, 4, 3, 2, 1],
148 | "C": [True, False, True] * 3,
149 | },
150 | columns=list("ABC"),
151 | )
152 | ddf = from_pandas(df, 2)
153 | assert_eq(ddf[ddf.C.repartition([0, 2, 5, 8])], df[df.C])
154 |
155 |
156 | def test_loc_bool_cindex():
157 | # https://github.com/dask/dask/issues/11015
158 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
159 | ddf = from_pandas(pdf, npartitions=1)
160 | indexer = [True, False]
161 | assert_eq(pdf.loc[:, indexer], ddf.loc[:, indexer])
162 |
163 |
164 | def test_loc_slicing():
165 | npartitions = 10
166 | pdf = pd.DataFrame(
167 | {
168 | "A": np.random.randn(npartitions * 10),
169 | },
170 | index=pd.date_range("2024-01-01", "2024-12-31", npartitions * 10),
171 | )
172 | df = from_pandas(pdf, npartitions=npartitions)
173 | result = df["2024-03-01":"2024-09-30"]["A"]
174 | assert_eq(result, pdf["2024-03-01":"2024-09-30"]["A"])
175 |
176 |
177 | def test_indexing_element_index():
178 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
179 | result = from_pandas(pdf, 2).loc[2].index
180 | pd.testing.assert_index_equal(result.compute(), pdf.loc[[2]].index)
181 |
182 | result = from_pandas(pdf, 2).loc[[2]].index
183 | pd.testing.assert_index_equal(result.compute(), pdf.loc[[2]].index)
184 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_interchange.py:
--------------------------------------------------------------------------------
1 | from pandas.core.interchange.dataframe_protocol import DtypeKind
2 |
3 | from dask_expr import from_pandas
4 | from dask_expr.tests._util import _backend_library
5 |
6 | # Set DataFrame backend for this module
7 | pd = _backend_library()
8 |
9 |
10 | def test_interchange_protocol():
11 | pdf = pd.DataFrame({"a": [1, 2, 3], "b": 1})
12 | df = from_pandas(pdf, npartitions=2)
13 | df_int = df.__dataframe__()
14 | pd.testing.assert_index_equal(pdf.columns, df_int.column_names())
15 | assert df_int.num_columns() == 2
16 | assert df_int.num_rows() == 3
17 | column = df_int.get_columns()[0]
18 | assert column.dtype()[0] == DtypeKind.INT
19 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_merge_asof.py:
--------------------------------------------------------------------------------
1 | from dask_expr import from_pandas, merge_asof
2 | from dask_expr.tests._util import _backend_library, assert_eq
3 |
4 | pd = _backend_library()
5 |
6 |
7 | def test_merge_asof_indexed():
8 | A = pd.DataFrame(
9 | {"left_val": list("abcd" * 3)},
10 | index=[1, 3, 7, 9, 10, 13, 14, 17, 20, 24, 25, 28],
11 | )
12 | a = from_pandas(A, npartitions=4)
13 | B = pd.DataFrame(
14 | {"right_val": list("xyz" * 4)},
15 | index=[1, 2, 3, 6, 7, 10, 12, 14, 16, 19, 23, 26],
16 | )
17 | b = from_pandas(B, npartitions=3)
18 |
19 | C = pd.merge_asof(A, B, left_index=True, right_index=True)
20 | c = merge_asof(a, b, left_index=True, right_index=True)
21 |
22 | assert_eq(c, C)
23 |
24 |
25 | def test_merge_asof_on_basic():
26 | A = pd.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]})
27 | a = from_pandas(A, npartitions=2)
28 | B = pd.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]})
29 | b = from_pandas(B, npartitions=2)
30 |
31 | C = pd.merge_asof(A, B, on="a")
32 | c = merge_asof(a, b, on="a")
33 | # merge_asof does not preserve index
34 | assert_eq(c, C, check_index=False)
35 |
36 |
37 | def test_merge_asof_one_partition():
38 | left = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
39 | right = pd.DataFrame({"a": [1, 2, 3], "c": [4, 5, 6]})
40 |
41 | ddf_left = from_pandas(left, npartitions=1)
42 | ddf_left = ddf_left.set_index("a", sort=True)
43 | ddf_right = from_pandas(right, npartitions=1)
44 | ddf_right = ddf_right.set_index("a", sort=True)
45 |
46 | result = merge_asof(
47 | ddf_left, ddf_right, left_index=True, right_index=True, direction="nearest"
48 | )
49 | expected = pd.merge_asof(
50 | left.set_index("a"),
51 | right.set_index("a"),
52 | left_index=True,
53 | right_index=True,
54 | direction="nearest",
55 | )
56 | assert_eq(result, expected)
57 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_predicate_pushdown.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr import from_pandas
4 | from dask_expr._expr import rewrite_filters
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | pd = _backend_library()
8 |
9 |
10 | @pytest.fixture
11 | def pdf():
12 | pdf = pd.DataFrame({"x": range(100), "a": 1, "b": 2})
13 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions
14 | yield pdf
15 |
16 |
17 | @pytest.fixture
18 | def df(pdf):
19 | yield from_pandas(pdf, npartitions=10)
20 |
21 |
22 | def test_rewrite_filters(df):
23 | predicate = (df.x == 1) | (df.x == 1)
24 | expected = df.x == 1
25 | assert rewrite_filters(predicate.expr)._name == expected._name
26 |
27 | predicate = (df.x == 1) | ((df.x == 1) & (df.y == 2))
28 | expected = df.x == 1
29 | assert rewrite_filters(predicate.expr)._name == expected._name
30 |
31 | predicate = ((df.x == 1) & (df.y == 3)) | ((df.x == 1) & (df.y == 2))
32 | expected = (df.x == 1) & ((df.y == 3) | (df.y == 2))
33 | assert rewrite_filters(predicate.expr)._name == expected._name
34 |
35 | predicate = ((df.x == 1) & (df.y == 3) & (df.a == 1)) | (
36 | (df.x == 1) & (df.y == 2) & (df.a == 1)
37 | )
38 | expected = ((df.x == 1) & (df.a == 1)) & ((df.y == 3) | (df.y == 2))
39 | assert rewrite_filters(predicate.expr)._name == expected._name
40 |
41 | predicate = (df.x == 1) | (df.y == 1)
42 | assert rewrite_filters(predicate.expr)._name == predicate._name
43 |
44 | predicate = df.x == 1
45 | assert rewrite_filters(predicate.expr)._name == predicate._name
46 |
47 | predicate = (df.x.isin([1, 2, 3]) & (df.y == 3)) | (
48 | df.x.isin([1, 2, 3]) & (df.y == 2)
49 | )
50 | expected = df.x.isin([1, 2, 3]) & ((df.y == 3) | (df.y == 2))
51 | assert rewrite_filters(predicate.expr)._name == expected._name
52 |
53 |
54 | def test_rewrite_filters_query(df, pdf):
55 | result = df[((df.x == 1) & (df.y > 1)) | ((df.x == 1) & (df.y > 2))]
56 | result = result[["x"]]
57 | expected = pdf[((pdf.x == 1) & (pdf.y > 1)) | ((pdf.x == 1) & (pdf.y > 2))]
58 | expected = expected[["x"]]
59 | assert_eq(result, expected)
60 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_quantiles.py:
--------------------------------------------------------------------------------
1 | from dask_expr import from_pandas
2 | from dask_expr.tests._util import _backend_library, assert_eq
3 |
4 | # Set DataFrame backend for this module
5 | pd = _backend_library()
6 |
7 |
8 | def test_repartition_quantiles():
9 | pdf = pd.DataFrame({"a": [1, 2, 3, 4, 5, 15, 7, 8, 9, 10, 11], "d": 3})
10 | df = from_pandas(pdf, npartitions=5)
11 | result = df.a._repartition_quantiles(npartitions=5)
12 | expected = pd.Series(
13 | [1, 1, 3, 7, 9, 15], index=[0, 0.2, 0.4, 0.6, 0.8, 1], name="a"
14 | )
15 | assert_eq(result, expected, check_exact=False)
16 |
17 | result = df.a._repartition_quantiles(npartitions=4)
18 | expected = pd.Series([1, 2, 5, 8, 15], index=[0, 0.25, 0.5, 0.75, 1], name="a")
19 | assert_eq(result, expected, check_exact=False)
20 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_repartition.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from dask_expr import Repartition, from_pandas, repartition
5 | from dask_expr.tests._util import _backend_library, assert_eq
6 |
7 | pd = _backend_library()
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "kwargs",
12 | [
13 | {"npartitions": 2},
14 | {"npartitions": 4},
15 | {"divisions": (0, 1, 79)},
16 | {"partition_size": "1kb"},
17 | ],
18 | )
19 | def test_repartition_combine_similar(kwargs):
20 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2})
21 | df = from_pandas(pdf, npartitions=3)
22 | query = df.repartition(**kwargs)
23 | query["new"] = query.x + query.y
24 | result = query.optimize(fuse=False)
25 |
26 | expected = df.repartition(**kwargs).optimize(fuse=False)
27 | arg1 = expected.x
28 | arg2 = expected.y
29 | expected["new"] = arg1 + arg2
30 | assert result._name == expected._name
31 |
32 | expected_pdf = pdf.copy()
33 | expected_pdf["new"] = expected_pdf.x + expected_pdf.y
34 | assert_eq(result, expected_pdf)
35 |
36 |
37 | @pytest.mark.parametrize("type_ctor", [lambda o: o, tuple, list])
38 | def test_repartition_noop(type_ctor):
39 | pdf = pd.DataFrame({"x": [1, 2, 4, 5], "y": [6, 7, 8, 9]}, index=[-1, 0, 2, 7])
40 | df = from_pandas(pdf, npartitions=2)
41 | ds = df.x
42 |
43 | def assert_not_repartitions(expr, fuse=False):
44 | repartitions = [
45 | x for x in expr.optimize(fuse=fuse).walk() if isinstance(x, Repartition)
46 | ]
47 | assert len(repartitions) == 0
48 |
49 | # DataFrame method
50 | df2 = df.repartition(divisions=type_ctor(df.divisions))
51 | assert_not_repartitions(df2.expr)
52 |
53 | # Top-level dask.dataframe method
54 | df3 = repartition(df, divisions=type_ctor(df.divisions))
55 | assert_not_repartitions(df3.expr)
56 |
57 | # Series method
58 | ds2 = ds.repartition(divisions=type_ctor(ds.divisions))
59 | assert_not_repartitions(ds2.expr)
60 |
61 | # Top-level dask.dataframe method applied to a Series
62 | ds3 = repartition(ds, divisions=type_ctor(ds.divisions))
63 | assert_not_repartitions(ds3.expr)
64 |
65 |
66 | def test_repartition_freq():
67 | ts = pd.date_range("2015-01-01 00:00", "2015-05-01 23:50", freq="10min")
68 | pdf = pd.DataFrame(
69 | np.random.randint(0, 100, size=(len(ts), 4)), columns=list("ABCD"), index=ts
70 | )
71 | df = from_pandas(pdf, npartitions=1).repartition(freq="MS")
72 |
73 | assert_eq(df, pdf)
74 |
75 | assert df.divisions == (
76 | pd.Timestamp("2015-1-1 00:00:00"),
77 | pd.Timestamp("2015-2-1 00:00:00"),
78 | pd.Timestamp("2015-3-1 00:00:00"),
79 | pd.Timestamp("2015-4-1 00:00:00"),
80 | pd.Timestamp("2015-5-1 00:00:00"),
81 | pd.Timestamp("2015-5-1 23:50:00"),
82 | )
83 |
84 | assert df.npartitions == 5
85 |
86 |
87 | def test_repartition_freq_errors():
88 | pdf = pd.DataFrame({"x": [1, 2, 3]})
89 | df = from_pandas(pdf, npartitions=1)
90 | with pytest.raises(TypeError, match="for timeseries"):
91 | df.repartition(freq="1s")
92 |
93 |
94 | def test_repartition_npartitions_numeric_edge_case():
95 | """
96 | Test that we cover numeric edge cases when
97 |     int(ddf.npartitions / npartitions) * npartitions != ddf.npartitions
98 | """
99 | df = pd.DataFrame({"x": range(100)})
100 | a = from_pandas(df, npartitions=15)
101 | assert a.npartitions == 15
102 | b = a.repartition(npartitions=11)
103 | assert_eq(a, b)
104 |
105 |
106 | def test_repartition_empty_partitions_dtype():
107 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8]})
108 | df = from_pandas(pdf, npartitions=4)
109 | assert_eq(
110 | df[df.x < 5].repartition(npartitions=1),
111 | pdf[pdf.x < 5],
112 | )
113 |
114 |
115 | def test_repartition_filter_pushdown():
116 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2})
117 | df = from_pandas(pdf, npartitions=10)
118 | result = df.repartition(npartitions=5)
119 | result = result[result.x > 5.0]
120 | expected = df[df.x > 5.0].repartition(npartitions=5)
121 | assert result.simplify()._name == expected._name
122 |
123 | result = df.repartition(npartitions=5)
124 | result = result[result.x > 5.0][["x", "y"]]
125 | expected = df[["x", "y"]]
126 | expected = expected[expected.x > 5.0].repartition(npartitions=5)
127 | assert result.simplify()._name == expected.simplify()._name
128 |
129 | result = df.repartition(npartitions=5)[["x", "y"]]
130 | result = result[result.x > 5.0]
131 | expected = df[["x", "y"]]
132 | expected = expected[expected.x > 5.0].repartition(npartitions=5)
133 | assert result.simplify()._name == expected.simplify()._name
134 |
135 |
136 | def test_repartition_unknown_divisions():
137 | pdf = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8] * 10, "y": 1, "z": 2})
138 | df = from_pandas(pdf, npartitions=5).clear_divisions()
139 | with pytest.raises(
140 | ValueError, match="Cannot repartition on divisions with unknown divisions"
141 | ):
142 | df.repartition(divisions=(0, 100)).compute()
143 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_resample.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import pytest
4 |
5 | from dask_expr import from_pandas
6 | from dask_expr.tests._util import _backend_library, assert_eq
7 |
8 | # Set DataFrame backend for this module
9 | pd = _backend_library()
10 |
11 |
12 | def resample(df, freq, how="mean", **kwargs):
13 | return getattr(df.resample(freq, **kwargs), how)()
14 |
15 |
16 | @pytest.fixture
17 | def pdf():
18 | idx = pd.date_range("2000-01-01", periods=12, freq="min")
19 | pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx)
20 | pdf["bar"] = 1
21 | yield pdf
22 |
23 |
24 | @pytest.fixture
25 | def df(pdf):
26 | yield from_pandas(pdf, npartitions=4)
27 |
28 |
29 | @pytest.mark.parametrize("kwargs", [{}, {"closed": "left"}])
30 | @pytest.mark.parametrize(
31 | "api",
32 | [
33 | "count",
34 | "prod",
35 | "mean",
36 | "sum",
37 | "min",
38 | "max",
39 | "first",
40 | "last",
41 | "var",
42 | "std",
43 | "size",
44 | "nunique",
45 | "median",
46 | "quantile",
47 | "ohlc",
48 | "sem",
49 | ],
50 | )
51 | def test_resample_apis(df, pdf, api, kwargs):
52 | result = getattr(df.resample("2min", **kwargs), api)()
53 | expected = getattr(pdf.resample("2min", **kwargs), api)()
54 | assert_eq(result, expected)
55 |
56 |     # size returns a Series, so there is no column to project
57 | if api not in ("size",):
58 | result = getattr(df.resample("2min"), api)()["foo"]
59 | expected = getattr(pdf.resample("2min"), api)()["foo"]
60 | assert_eq(result, expected)
61 |
62 | if api != "ohlc":
63 | # ohlc actually gives back a DataFrame, so this doesn't work
64 | q = result.simplify()
65 | eq = getattr(df["foo"].resample("2min"), api)().simplify()
66 | assert q._name == eq._name
67 |
68 |
69 | @pytest.mark.parametrize(
70 | ["obj", "method", "npartitions", "freq", "closed", "label"],
71 | list(
72 | product(
73 | ["series", "frame"],
74 | ["count", "mean", "ohlc"],
75 | [2, 5],
76 | ["30min", "h", "d", "W"],
77 | ["right", "left"],
78 | ["right", "left"],
79 | )
80 | ),
81 | )
82 | def test_series_resample(obj, method, npartitions, freq, closed, label):
83 | index = pd.date_range("1-1-2000", "2-15-2000", freq="h")
84 | index = index.union(pd.date_range("4-15-2000", "5-15-2000", freq="h"))
85 | if obj == "series":
86 | ps = pd.Series(range(len(index)), index=index)
87 | elif obj == "frame":
88 | ps = pd.DataFrame({"a": range(len(index))}, index=index)
89 | ds = from_pandas(ps, npartitions=npartitions)
90 | # Series output
91 |
92 | result = resample(ds, freq, how=method, closed=closed, label=label)
93 | expected = resample(ps, freq, how=method, closed=closed, label=label)
94 |
95 | assert_eq(result, expected, check_dtype=False)
96 |
97 | divisions = result.divisions
98 |
99 | assert expected.index[0] == divisions[0]
100 | assert expected.index[-1] == divisions[-1]
101 |
102 |
103 | def test_resample_agg(df, pdf):
104 | def my_sum(vals, foo=None, *, bar=None):
105 | return vals.sum()
106 |
107 | result = df.resample("2min").agg(my_sum, "foo", bar="bar")
108 | expected = pdf.resample("2min").agg(my_sum, "foo", bar="bar")
109 | assert_eq(result, expected)
110 |
111 | result = df.resample("2min").agg(my_sum)["foo"]
112 | expected = pdf.resample("2min").agg(my_sum)["foo"]
113 | assert_eq(result, expected)
114 |
115 | # simplify up disabled for `agg`, function may access other columns
116 | q = df.resample("2min").agg(my_sum)["foo"].simplify()
117 | eq = df["foo"].resample("2min").agg(my_sum).simplify()
118 | assert q._name != eq._name
119 |
120 |
121 | @pytest.mark.parametrize("method", ["count", "nunique", "size", "sum"])
122 | def test_resample_has_correct_fill_value(method):
123 | index = pd.date_range("2000-01-01", "2000-02-15", freq="h")
124 | index = index.union(pd.date_range("4-15-2000", "5-15-2000", freq="h"))
125 | ps = pd.Series(range(len(index)), index=index)
126 | ds = from_pandas(ps, npartitions=2)
127 |
128 | assert_eq(
129 | getattr(ds.resample("30min"), method)(), getattr(ps.resample("30min"), method)()
130 | )
131 |
132 |
133 | def test_resample_divisions_propagation():
134 | idx = pd.date_range(start="10:00:00.873821", end="10:05:00", freq="0.002s")
135 | pdf = pd.DataFrame({"data": 1}, index=idx)
136 | df = from_pandas(pdf, npartitions=10)
137 | result = df.resample("0.03s").mean()
138 | result = result.repartition(freq="1d")
139 | expected = pdf.resample("0.03s").mean()
140 | assert_eq(result, expected)
141 |
142 | result = df.resample("0.03s").mean().partitions[1]
143 | expected = pdf.resample("0.03s").mean()[997 : 2 * 997]
144 | assert_eq(result, expected)
145 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_reshape.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr import from_pandas, pivot_table
4 | from dask_expr.tests._util import _backend_library, assert_eq
5 |
6 | # Set DataFrame backend for this module
7 | pd = _backend_library()
8 |
9 |
10 | @pytest.fixture
11 | def pdf():
12 | pdf = pd.DataFrame(
13 | {
14 | "x": [1, 2, 3, 4, 5, 6],
15 | "y": pd.Series([4, 5, 8, 6, 1, 4], dtype="category"),
16 | "z": [4, 15, 8, 16, 1, 14],
17 | "a": 1,
18 | }
19 | )
20 | yield pdf
21 |
22 |
23 | @pytest.fixture
24 | def df(pdf):
25 | yield from_pandas(pdf, npartitions=3)
26 |
27 |
28 | @pytest.mark.parametrize("aggfunc", ["first", "last", "sum", "mean", "count"])
29 | def test_pivot_table(df, pdf, aggfunc):
30 | assert_eq(
31 | df.pivot_table(index="x", columns="y", values="z", aggfunc=aggfunc),
32 | pdf.pivot_table(
33 | index="x", columns="y", values="z", aggfunc=aggfunc, observed=False
34 | ),
35 | check_dtype=aggfunc != "count",
36 | )
37 |
38 | assert_eq(
39 | df.pivot_table(index="x", columns="y", values=["z", "a"], aggfunc=aggfunc),
40 | pdf.pivot_table(
41 | index="x", columns="y", values=["z", "a"], aggfunc=aggfunc, observed=False
42 | ),
43 | check_dtype=aggfunc != "count",
44 | )
45 |
46 | assert_eq(
47 | pivot_table(df, index="x", columns="y", values=["z", "a"], aggfunc=aggfunc),
48 | pdf.pivot_table(
49 | index="x", columns="y", values=["z", "a"], aggfunc=aggfunc, observed=False
50 | ),
51 | check_dtype=aggfunc != "count",
52 | )
53 |
54 |
55 | def test_pivot_table_fails(df):
56 | with pytest.raises(ValueError, match="must be the name of an existing column"):
57 | df.pivot_table(index="aaa", columns="y", values="z")
58 | with pytest.raises(ValueError, match="must be the name of an existing column"):
59 | df.pivot_table(index=["a"], columns="y", values="z")
60 |
61 | with pytest.raises(ValueError, match="must be the name of an existing column"):
62 | df.pivot_table(index="a", columns="xxx", values="z")
63 | with pytest.raises(ValueError, match="must be the name of an existing column"):
64 | df.pivot_table(index="a", columns=["x"], values="z")
65 |
66 | with pytest.raises(ValueError, match="'columns' must be category dtype"):
67 | df.pivot_table(index="a", columns="x", values="z")
68 |
69 | df2 = df.copy()
70 | df2["y"] = df2.y.cat.as_unknown()
71 | with pytest.raises(ValueError, match="'columns' must have known categories"):
72 | df2.pivot_table(index="a", columns="y", values="z")
73 |
74 | with pytest.raises(
75 | ValueError, match="'values' must refer to an existing column or columns"
76 | ):
77 | df.pivot_table(index="x", columns="y", values="aaa")
78 |
79 | msg = "aggfunc must be either 'mean', 'sum', 'count', 'first', 'last'"
80 | with pytest.raises(ValueError, match=msg):
81 | pivot_table(df, index="x", columns="y", values="z", aggfunc=["sum"])
82 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_rolling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from dask_expr import from_pandas
6 | from dask_expr.tests._util import _backend_library, assert_eq
7 |
8 | # Set DataFrame backend for this module
9 | pd = _backend_library()
10 |
11 |
12 | @pytest.fixture
13 | def pdf():
14 | idx = pd.date_range("2000-01-01", periods=12, freq="min")
15 | pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx)
16 | pdf["bar"] = 1
17 | yield pdf
18 |
19 |
20 | @pytest.fixture
21 | def df(pdf, request):
22 | npartitions = getattr(request, "param", 2)
23 | yield from_pandas(pdf, npartitions=npartitions)
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "api,how_args",
28 | [
29 | ("count", ()),
30 | ("mean", ()),
31 | ("sum", ()),
32 | ("min", ()),
33 | ("max", ()),
34 | ("var", ()),
35 | ("std", ()),
36 | ("median", ()),
37 | ("skew", ()),
38 | ("quantile", (0.5,)),
39 | ("kurt", ()),
40 | ],
41 | )
42 | @pytest.mark.parametrize("window,min_periods", ((1, None), (3, 2), (3, 3)))
43 | @pytest.mark.parametrize("center", (True, False))
44 | @pytest.mark.parametrize("df", (1, 2), indirect=True)
45 | def test_rolling_apis(df, pdf, window, api, how_args, min_periods, center):
46 | args = (window,)
47 | kwargs = dict(min_periods=min_periods, center=center)
48 |
49 | result = getattr(df.rolling(*args, **kwargs), api)(*how_args)
50 | expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args)
51 | assert_eq(result, expected)
52 |
53 | result = getattr(df.rolling(*args, **kwargs), api)(*how_args)["foo"]
54 | expected = getattr(pdf.rolling(*args, **kwargs), api)(*how_args)["foo"]
55 | assert_eq(result, expected)
56 |
57 | q = result.simplify()
58 | eq = getattr(df["foo"].rolling(*args, **kwargs), api)(*how_args).simplify()
59 | assert q._name == eq._name
60 |
61 |
62 | @pytest.mark.parametrize("window", (1, 2))
63 | @pytest.mark.parametrize("df", (1, 2), indirect=True)
64 | def test_rolling_agg(df, pdf, window):
65 | def my_sum(vals, foo=None, *, bar=None):
66 | return vals.sum()
67 |
68 | result = df.rolling(window).agg(my_sum, "foo", bar="bar")
69 | expected = pdf.rolling(window).agg(my_sum, "foo", bar="bar")
70 | assert_eq(result, expected)
71 |
72 | result = df.rolling(window).agg(my_sum)["foo"]
73 | expected = pdf.rolling(window).agg(my_sum)["foo"]
74 | assert_eq(result, expected)
75 |
76 | # simplify up disabled for `agg`, function may access other columns
77 | q = df.rolling(window).agg(my_sum)["foo"].simplify()
78 | eq = df["foo"].rolling(window).agg(my_sum).simplify()
79 | assert q._name != eq._name
80 |
81 |
82 | @pytest.mark.parametrize("window", (1, 2))
83 | @pytest.mark.parametrize("df", (1, 2), indirect=True)
84 | @pytest.mark.parametrize("raw", (True, False))
85 | @pytest.mark.parametrize("foo", (1, None))
86 | @pytest.mark.parametrize("bar", (2, None))
87 | def test_rolling_apply(df, pdf, window, raw, foo, bar):
88 | def my_sum(vals, foo_=None, *, bar_=None):
89 | assert foo_ == foo
90 | assert bar_ == bar
91 | if raw:
92 | assert isinstance(vals, np.ndarray)
93 | else:
94 | assert isinstance(vals, pd.Series)
95 | return vals.sum()
96 |
97 | kwargs = dict(raw=raw, args=(foo,), kwargs=dict(bar_=bar))
98 |
99 | result = df.rolling(window).apply(my_sum, **kwargs)
100 | expected = pdf.rolling(window).apply(my_sum, **kwargs)
101 | assert_eq(result, expected)
102 |
103 | result = df.rolling(window).apply(my_sum, **kwargs)["foo"]
104 | expected = pdf.rolling(window).apply(my_sum, **kwargs)["foo"]
105 | assert_eq(result, expected)
106 |
107 |     # unlike `agg`, `apply` works column-wise, so the column projection is pushed below it
108 | q = df.rolling(window).apply(my_sum, **kwargs)["foo"].simplify()
109 | eq = df["foo"].rolling(window).apply(my_sum, **kwargs).simplify()
110 | assert q._name == eq._name
111 |
112 |
113 | def test_rolling_one_element_window(df, pdf):
114 | pdf.index = pd.date_range("2000-01-01", periods=12, freq="2s")
115 | df = from_pandas(pdf, npartitions=3)
116 |     result = df.foo.rolling("1s").count()
117 |     expected = pdf.foo.rolling("1s").count()
118 | assert_eq(result, expected)
119 |
120 |
121 | @pytest.mark.parametrize("window", ["2s", "5s", "20s", "10h"])
122 | def test_time_rolling_large_window_variable_chunks(window):
123 | df = pd.DataFrame(
124 | {
125 | "a": pd.date_range("2016-01-01 00:00:00", periods=100, freq="1s"),
126 | "b": np.random.randint(100, size=(100,)),
127 | }
128 | )
129 | ddf = from_pandas(df, 5)
130 | ddf = ddf.repartition(divisions=[0, 5, 20, 28, 33, 54, 79, 80, 82, 99])
131 | df = df.set_index("a")
132 | ddf = ddf.set_index("a")
133 | assert_eq(ddf.rolling(window).sum(), df.rolling(window).sum())
134 | assert_eq(ddf.rolling(window).count(), df.rolling(window).count())
135 | assert_eq(ddf.rolling(window).mean(), df.rolling(window).mean())
136 |
137 |
138 | def test_rolling_one_element_window_empty_after(df, pdf):
139 | pdf.index = pd.date_range("2000-01-01", periods=12, freq="2s")
140 | df = from_pandas(pdf, npartitions=3)
141 | result = df.map_overlap(lambda x: x.rolling("1s").count(), before="1s", after="1s")
142 | expected = pdf.rolling("1s").count()
143 | assert_eq(result, expected)
144 |
145 |
146 | @pytest.mark.parametrize("window", [1, 2, 4, 5])
147 | @pytest.mark.parametrize("center", [True, False])
148 | def test_rolling_cov(df, pdf, window, center):
149 | # DataFrame
150 | prolling = pdf.drop("foo", axis=1).rolling(window, center=center)
151 | drolling = df.drop("foo", axis=1).rolling(window, center=center)
152 | assert_eq(prolling.cov(), drolling.cov())
153 |
154 | # Series
155 | prolling = pdf.bar.rolling(window, center=center)
156 | drolling = df.bar.rolling(window, center=center)
157 | assert_eq(prolling.cov(), drolling.cov())
158 |
159 | # Projection
160 | actual = df.rolling(window, center=center).cov()[["foo", "bar"]].simplify()
161 | expected = df[["foo", "bar"]].rolling(window, center=center).cov().simplify()
162 | assert actual._name == expected._name
163 |
164 |
165 | def test_rolling_raises():
166 | df = pd.DataFrame(
167 | {"a": np.random.randn(25).cumsum(), "b": np.random.randint(100, size=(25,))}
168 | )
169 | ddf = from_pandas(df, npartitions=2)
170 |
171 | pytest.raises(ValueError, lambda: ddf.rolling(1.5))
172 | pytest.raises(ValueError, lambda: ddf.rolling(-1))
173 | pytest.raises(ValueError, lambda: ddf.rolling(3, min_periods=1.2))
174 | pytest.raises(ValueError, lambda: ddf.rolling(3, min_periods=-2))
175 | pytest.raises(NotImplementedError, lambda: ddf.rolling(100).mean().compute())
176 |
177 |
178 | def test_time_rolling_constructor(df):
179 | result = df.rolling("4s")
180 | assert result.window == "4s"
181 | assert result.min_periods is None
182 | assert result.win_type is None
183 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_string_accessor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dask_expr._collection import DataFrame, from_pandas
4 | from dask_expr.tests._util import _backend_library, assert_eq
5 |
6 | pd = _backend_library()
7 |
8 |
9 | @pytest.fixture()
10 | def ser():
11 | return pd.Series(["a", "b", "1", "aaa", "bbb", "ccc", "ddd", "abcd"])
12 |
13 |
14 | @pytest.fixture()
15 | def dser(ser):
16 | import dask.dataframe as dd
17 |
18 | return dd.from_pandas(ser, npartitions=3)
19 |
20 |
21 | @pytest.mark.parametrize(
22 | "func, kwargs",
23 | [
24 | ("len", {}),
25 | ("capitalize", {}),
26 | ("casefold", {}),
27 | ("contains", {"pat": "a"}),
28 | ("count", {"pat": "a"}),
29 | ("endswith", {"pat": "a"}),
30 | ("extract", {"pat": r"[ab](\d)"}),
31 | ("extractall", {"pat": r"[ab](\d)"}),
32 | ("find", {"sub": "a"}),
33 | ("findall", {"pat": "a"}),
34 | ("fullmatch", {"pat": "a"}),
35 | ("get", {"i": 0}),
36 | ("isalnum", {}),
37 | ("isalpha", {}),
38 | ("isdecimal", {}),
39 | ("isdigit", {}),
40 | ("islower", {}),
41 | ("isspace", {}),
42 | ("istitle", {}),
43 | ("isupper", {}),
44 | ("join", {"sep": "-"}),
45 | ("len", {}),
46 | ("ljust", {"width": 3}),
47 | ("lower", {}),
48 | ("lstrip", {}),
49 | ("match", {"pat": r"[ab](\d)"}),
50 | ("normalize", {"form": "NFC"}),
51 | ("pad", {"width": 3}),
52 | ("removeprefix", {"prefix": "a"}),
53 | ("removesuffix", {"suffix": "a"}),
54 | ("repeat", {"repeats": 2}),
55 | ("replace", {"pat": "a", "repl": "b"}),
56 | ("rfind", {"sub": "a"}),
57 | ("rjust", {"width": 3}),
58 | ("rstrip", {}),
59 | ("slice", {"start": 0, "stop": 1}),
60 | ("slice_replace", {"start": 0, "stop": 1, "repl": "a"}),
61 | ("startswith", {"pat": "a"}),
62 | ("strip", {}),
63 | ("swapcase", {}),
64 | ("title", {}),
65 | ("upper", {}),
66 | ("wrap", {"width": 2}),
67 | ("zfill", {"width": 2}),
68 | ("split", {"pat": "a"}),
69 | ("rsplit", {"pat": "a"}),
70 | ("cat", {}),
71 | ("cat", {"others": pd.Series(["a"])}),
72 | ],
73 | )
74 | def test_string_accessor(ser, dser, func, kwargs):
75 | ser = ser.astype("string[pyarrow]")
76 |
77 | assert_eq(getattr(ser.str, func)(**kwargs), getattr(dser.str, func)(**kwargs))
78 |
79 | if func in (
80 | "contains",
81 | "endswith",
82 | "fullmatch",
83 | "isalnum",
84 | "isalpha",
85 | "isdecimal",
86 | "isdigit",
87 | "islower",
88 | "isspace",
89 | "istitle",
90 | "isupper",
91 | "startswith",
92 | "match",
93 | ):
94 | # This returns arrays and doesn't work in dask/dask either
95 | return
96 |
97 | ser.index = ser.values
98 | ser = ser.sort_index()
99 | dser = from_pandas(ser, npartitions=3)
100 | pdf_result = getattr(ser.index.str, func)(**kwargs)
101 |
102 | if func == "cat" and len(kwargs) > 0:
103 | # Doesn't work with others on Index
104 | return
105 | if isinstance(pdf_result, pd.DataFrame):
106 | assert_eq(
107 | getattr(dser.index.str, func)(**kwargs), pdf_result, check_index=False
108 | )
109 | else:
110 | assert_eq(getattr(dser.index.str, func)(**kwargs), pdf_result)
111 |
112 |
113 | def test_str_accessor_cat(ser, dser):
114 | sol = ser.str.cat(ser.str.upper(), sep=":")
115 | assert_eq(dser.str.cat(dser.str.upper(), sep=":"), sol)
116 | assert_eq(dser.str.cat(ser.str.upper(), sep=":"), sol)
117 | assert_eq(
118 | dser.str.cat([dser.str.upper(), ser.str.lower()], sep=":"),
119 | ser.str.cat([ser.str.upper(), ser.str.lower()], sep=":"),
120 | )
121 | assert_eq(dser.str.cat(sep=":"), ser.str.cat(sep=":"))
122 |
123 | for o in ["foo", ["foo"]]:
124 | with pytest.raises(TypeError):
125 | dser.str.cat(o)
126 |
127 |
128 | @pytest.mark.parametrize("index", [None, [0]], ids=["range_index", "other index"])
129 | def test_str_split_(index):
130 | df = pd.DataFrame({"a": ["a\nb"]}, index=index)
131 | ddf = from_pandas(df, npartitions=1)
132 |
133 | pd_a = df["a"].str.split("\n", n=1, expand=True)
134 | dd_a = ddf["a"].str.split("\n", n=1, expand=True)
135 |
136 | assert_eq(dd_a, pd_a)
137 |
138 |
139 | def test_str_accessor_not_available():
140 | pdf = pd.DataFrame({"a": [1, 2, 3]})
141 | df = from_pandas(pdf, npartitions=2)
142 | # Not available on invalid dtypes
143 | with pytest.raises(AttributeError, match=".str accessor"):
144 | df.a.str
145 |
146 | assert "str" not in dir(df.a)
147 |
148 |
149 | def test_partition():
150 | df = DataFrame.from_dict({"A": ["A|B", "C|D"]}, npartitions=2)["A"].str.partition(
151 | "|"
152 | )
153 | result = df[1]
154 | expected = pd.DataFrame.from_dict({"A": ["A|B", "C|D"]})["A"].str.partition("|")[1]
155 | assert_eq(result, expected)
156 |
--------------------------------------------------------------------------------
/dask_expr/tests/test_ufunc.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import numpy as np
3 | import pytest
4 |
5 | from dask_expr import Index, from_pandas
6 | from dask_expr.tests._util import _backend_library, assert_eq
7 |
8 | pd = _backend_library()
9 |
10 |
11 | @pytest.fixture
12 | def pdf():
13 | pdf = pd.DataFrame({"x": range(100)})
14 | pdf["y"] = pdf.x // 7 # Not unique; duplicates span different partitions
15 | yield pdf
16 |
17 |
18 | @pytest.fixture
19 | def df(pdf):
20 | yield from_pandas(pdf, npartitions=10)
21 |
22 |
23 | def test_ufunc(df, pdf):
24 | ufunc = "conj"
25 | dafunc = getattr(da, ufunc)
26 | npfunc = getattr(np, ufunc)
27 |
28 | pandas_type = pdf.__class__
29 | dask_type = df.__class__
30 |
31 | assert isinstance(dafunc(df), dask_type)
32 | assert_eq(dafunc(df), npfunc(pdf))
33 |
34 | if isinstance(npfunc, np.ufunc):
35 | assert isinstance(npfunc(df), dask_type)
36 | else:
37 | assert isinstance(npfunc(df), pandas_type)
38 | assert_eq(npfunc(df), npfunc(pdf))
39 |
40 |     # applying Dask ufunc to a plain pandas object triggers computation
41 | assert isinstance(dafunc(pdf), pandas_type)
42 | assert_eq(dafunc(df), npfunc(pdf))
43 |
44 | assert isinstance(dafunc(df.index), Index)
45 | assert_eq(dafunc(df.index), npfunc(pdf.index))
46 |
47 | if isinstance(npfunc, np.ufunc):
48 | assert isinstance(npfunc(df.index), Index)
49 | else:
50 | assert isinstance(npfunc(df.index), pd.Index)
51 |
52 | assert_eq(npfunc(df.index), npfunc(pdf.index))
53 |
54 |
55 | def test_ufunc_with_2args(pdf, df):
56 | ufunc = "logaddexp"
57 | dafunc = getattr(da, ufunc)
58 | npfunc = getattr(np, ufunc)
59 |
60 | pandas_type = pdf.__class__
61 | pdf2 = pdf.sort_index(ascending=False)
62 | dask_type = df.__class__
63 | df2 = from_pandas(pdf2, npartitions=8)
64 | # applying Dask ufunc doesn't trigger computation
65 | assert isinstance(dafunc(df, df2), dask_type)
66 | assert_eq(dafunc(df, df2), npfunc(pdf, pdf2))
67 |
68 | # should be fine with pandas as a second arg, too
69 | assert isinstance(dafunc(df, pdf2), dask_type)
70 | assert_eq(dafunc(df, pdf2), npfunc(pdf, pdf2))
71 |
72 | # applying NumPy ufunc is lazy
73 | if isinstance(npfunc, np.ufunc):
74 | assert isinstance(npfunc(df, df2), dask_type)
75 | assert isinstance(npfunc(df, pdf2), dask_type)
76 | else:
77 | assert isinstance(npfunc(df, df2), pandas_type)
78 | assert isinstance(npfunc(df, pdf2), pandas_type)
79 |
80 | assert_eq(npfunc(df, df2), npfunc(pdf, pdf2))
81 | assert_eq(npfunc(df, pdf), npfunc(pdf, pdf2))
82 |
83 |     # applying Dask ufunc to a plain pandas object triggers computation
84 | assert isinstance(dafunc(pdf, pdf2), pandas_type)
85 | assert_eq(dafunc(pdf, pdf2), npfunc(pdf, pdf2))
86 |
87 |
88 | @pytest.mark.parametrize(
89 | "ufunc",
90 | [
91 | np.mean,
92 | np.std,
93 | np.sum,
94 | np.cumsum,
95 | np.cumprod,
96 | np.var,
97 | np.min,
98 | np.max,
99 | np.all,
100 | np.any,
101 | np.prod,
102 | ],
103 | )
104 | def test_reducers(pdf, df, ufunc):
105 | assert_eq(ufunc(pdf, axis=0), ufunc(df, axis=0))
106 |
107 |
108 | def test_clip(pdf, df):
109 | assert_eq(np.clip(pdf, 10, 20), np.clip(df, 10, 20))
110 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=62.6", "versioneer[toml]==0.28"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "dask-expr"
7 | description = "High Level Expressions for Dask"
8 | maintainers = [{name = "Matthew Rocklin", email = "mrocklin@gmail.com"}]
9 | license = {text = "BSD"}
10 | keywords = ["dask pandas"]
11 | classifiers = [
12 | "Intended Audience :: Developers",
13 | "Intended Audience :: Science/Research",
14 | "License :: OSI Approved :: BSD License",
15 | "Operating System :: OS Independent",
16 | "Programming Language :: Python",
17 | "Programming Language :: Python :: 3",
18 | "Programming Language :: Python :: 3 :: Only",
19 | "Programming Language :: Python :: 3.10",
20 | "Programming Language :: Python :: 3.11",
21 | "Programming Language :: Python :: 3.12",
22 | "Programming Language :: Python :: 3.13",
23 | "Topic :: Scientific/Engineering",
24 | "Topic :: System :: Distributed Computing",
25 | ]
26 | readme = "README.md"
27 | requires-python = ">=3.10"
28 | dependencies = [
29 |     "dask == 2024.12.1",
30 |     "pyarrow >= 14.0.1",
31 |     "pandas >= 2",
32 | ]
33 |
34 | dynamic = ["version"]
35 |
36 | [project.optional-dependencies]
37 | analyze = ["crick", "distributed", "graphviz"]
38 |
39 | [project.urls]
40 | "Source code" = "https://github.com/dask-contrib/dask-expr/"
41 |
42 | [tool.setuptools.packages.find]
43 | exclude = ["*tests*"]
44 | namespaces = false
45 |
46 | [tool.coverage.run]
47 | omit = [
48 | "*/test_*.py",
49 | ]
50 | source = ["dask_expr"]
51 |
52 | [tool.coverage.report]
53 | # Regexes for lines to exclude from consideration
54 | exclude_lines = [
55 | "pragma: no cover",
56 | "raise AssertionError",
57 | "raise NotImplementedError",
58 | ]
59 | ignore_errors = true
60 |
61 | [tool.versioneer]
62 | VCS = "git"
63 | style = "pep440"
64 | versionfile_source = "dask_expr/_version.py"
65 | versionfile_build = "dask_expr/_version.py"
66 | tag_prefix = "v"
67 | parentdir_prefix = "dask_expr-"
68 |
69 |
70 | [tool.pytest.ini_options]
71 | addopts = "-v -rsxfE --durations=10 --color=yes"
72 | filterwarnings = [
73 | 'ignore:Passing a BlockManager to DataFrame is deprecated and will raise in a future version. Use public APIs instead:DeprecationWarning', # https://github.com/apache/arrow/issues/35081
74 | 'ignore:The previous implementation of stack is deprecated and will be removed in a future version of pandas\.:FutureWarning',
75 | 'error:\nA value is trying to be set on a copy of a slice from a DataFrame',
76 | ]
77 | xfail_strict = true
78 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | sections = FUTURE,STDLIB,THIRDPARTY,DISTRIBUTED,FIRSTPARTY,LOCALFOLDER
3 | profile = black
4 | skip_gitignore = true
5 | force_to_top = true
6 | default_section = THIRDPARTY
7 |
8 |
9 | [aliases]
10 | test = pytest
11 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from __future__ import annotations
4 |
5 | import versioneer
6 | from setuptools import setup
7 |
8 | setup(
9 | version=versioneer.get_version(),
10 | cmdclass=versioneer.get_cmdclass(),
11 | )
12 |
--------------------------------------------------------------------------------