├── .copier-answers.yml
├── .envrc
├── .gitattributes
├── .github
│   ├── CODEOWNERS
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   ├── release-drafter.yml
│   └── workflows
│       ├── build.yml
│       ├── chore.yml
│       ├── ci.yml
│       └── scorecard.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .prettierignore
├── .prettierrc
├── .readthedocs.yml
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── README.md
├── SECURITY.md
├── dataframely
│   ├── __init__.py
│   ├── _base_collection.py
│   ├── _base_schema.py
│   ├── _compat.py
│   ├── _deprecation.py
│   ├── _extre.pyi
│   ├── _filter.py
│   ├── _polars.py
│   ├── _rule.py
│   ├── _typing.py
│   ├── _validation.py
│   ├── collection.py
│   ├── columns
│   │   ├── __init__.py
│   │   ├── _base.py
│   │   ├── _mixins.py
│   │   ├── _utils.py
│   │   ├── any.py
│   │   ├── array.py
│   │   ├── bool.py
│   │   ├── datetime.py
│   │   ├── decimal.py
│   │   ├── enum.py
│   │   ├── float.py
│   │   ├── integer.py
│   │   ├── list.py
│   │   ├── object.py
│   │   ├── string.py
│   │   └── struct.py
│   ├── config.py
│   ├── exc.py
│   ├── failure.py
│   ├── functional.py
│   ├── mypy.py
│   ├── py.typed
│   ├── random.py
│   ├── schema.py
│   └── testing
│       ├── __init__.py
│       ├── const.py
│       ├── factory.py
│       ├── mask.py
│       ├── rules.py
│       └── typing.py
├── docker-compose.yml
├── docs
│   ├── Makefile
│   ├── _api
│   │   ├── dataframely.collection.rst
│   │   ├── dataframely.columns.any.rst
│   │   ├── dataframely.columns.bool.rst
│   │   ├── dataframely.columns.datetime.rst
│   │   ├── dataframely.columns.decimal.rst
│   │   ├── dataframely.columns.enum.rst
│   │   ├── dataframely.columns.float.rst
│   │   ├── dataframely.columns.integer.rst
│   │   ├── dataframely.columns.list.rst
│   │   ├── dataframely.columns.rst
│   │   ├── dataframely.columns.string.rst
│   │   ├── dataframely.columns.struct.rst
│   │   ├── dataframely.config.rst
│   │   ├── dataframely.exc.rst
│   │   ├── dataframely.failure.rst
│   │   ├── dataframely.functional.rst
│   │   ├── dataframely.mypy.rst
│   │   ├── dataframely.random.rst
│   │   ├── dataframely.rst
│   │   ├── dataframely.schema.rst
│   │   ├── dataframely.testing.const.rst
│   │   ├── dataframely.testing.factory.rst
│   │   ├── dataframely.testing.mask.rst
│   │   ├── dataframely.testing.rst
│   │   ├── dataframely.testing.rules.rst
│   │   ├── dataframely.testing.typing.rst
│   │   └── modules.rst
│   ├── _static
│   │   ├── custom.css
│   │   └── favicon.ico
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   └── sites
│       ├── development.rst
│       ├── examples
│       │   └── real-world.ipynb
│       ├── faq.rst
│       ├── installation.rst
│       ├── quickstart.rst
│       └── versioning.rst
├── pixi.lock
├── pixi.toml
├── pyproject.toml
├── src
│   ├── errdefs.rs
│   ├── lib.rs
│   └── regex_repr.rs
└── tests
    ├── collection
    │   ├── test_base.py
    │   ├── test_cast.py
    │   ├── test_create_empty.py
    │   ├── test_filter_one_to_n.py
    │   ├── test_filter_validate.py
    │   ├── test_ignore_in_filter.py
    │   ├── test_implementation.py
    │   ├── test_optional_members.py
    │   ├── test_sample.py
    │   └── test_validate_input.py
    ├── column_types
    │   ├── __init__.py
    │   ├── test_any.py
    │   ├── test_array.py
    │   ├── test_datetime.py
    │   ├── test_decimal.py
    │   ├── test_enum.py
    │   ├── test_float.py
    │   ├── test_integer.py
    │   ├── test_list.py
    │   ├── test_object.py
    │   ├── test_string.py
    │   └── test_struct.py
    ├── columns
    │   ├── __init__.py
    │   ├── test_alias.py
    │   ├── test_check.py
    │   ├── test_default_dtypes.py
    │   ├── test_metadata.py
    │   ├── test_polars_schema.py
    │   ├── test_pyarrow.py
    │   ├── test_rules.py
    │   ├── test_sample.py
    │   ├── test_sql_schema.py
    │   ├── test_str.py
    │   └── test_utils.py
    ├── core_validation
    │   ├── __init__.py
    │   ├── test_column_validation.py
    │   ├── test_dtype_validation.py
    │   └── test_rule_evaluation.py
    ├── functional
    │   ├── test_concat.py
    │   └── test_relationships.py
    ├── schema
    │   ├── test_base.py
    │   ├── test_cast.py
    │   ├── test_create_empty.py
    │   ├── test_create_empty_if_none.py
    │   ├── test_filter.py
    │   ├── test_inheritance.py
    │   ├── test_rule_implementation.py
    │   ├── test_sample.py
    │   └── test_validate.py
    ├── test_compat.py
    ├── test_config.py
    ├── test_deprecation.py
    ├── test_exc.py
    ├── test_extre.py
    ├── test_failure_info.py
    ├── test_random.py
    └── test_typing.py
/.copier-answers.yml:
--------------------------------------------------------------------------------
1 | # This file is managed by Copier; DO NOT EDIT OR REMOVE.
2 | _commit: v0.4.1
3 | _src_path: https://github.com/quantco/copier-template-python-open-source
4 | add_autobump_workflow: false
5 | author_email: oliver.borchert@quantco.com
6 | author_name: Oliver Borchert
7 | github_url: https://github.com/quantco/dataframely
8 | github_user: borchero
9 | minimal_python_version: py311
10 | project_short_description: A declarative, polars-native data frame validation library
11 | project_slug: dataframely
12 | use_devcontainer: false
13 |
--------------------------------------------------------------------------------
/.envrc:
--------------------------------------------------------------------------------
1 | watch_file pixi.toml pixi.lock
2 | eval "$(pixi shell-hook)"
3 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | pixi.lock merge=binary linguist-language=YAML linguist-generated=true
2 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @borchero @AndreasAlbertQC @delsner
2 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Motivation
2 |
3 |
4 |
5 | # Changes
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: github-actions
4 | directory: /
5 | schedule:
6 | interval: monthly
7 | groups:
8 | gh-actions:
9 | patterns:
10 | - "*"
11 | commit-message:
12 | prefix: ci
13 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | # ------------------------------------- PULL REQUEST LABELS ------------------------------------- #
2 | autolabeler:
3 | # Conventional Commit Types (https://github.com/commitizen/conventional-commit-types)
4 | - label: build
5 | title:
6 | - '/^build(\(.*\))?(\!)?\:/'
7 | - label: chore
8 | title:
9 | - '/^chore(\(.*\))?(\!)?\:/'
10 | - label: ci
11 | title:
12 | - '/^ci(\(.*\))?(\!)?\:/'
13 | - label: documentation
14 | title:
15 | - '/^docs(\(.*\))?(\!)?\:/'
16 | - label: enhancement
17 | title:
18 | - '/^feat(\(.*\))?(\!)?\:/'
19 | - label: fix
20 | title:
21 | - '/^fix(\(.*\))?(\!)?\:/'
22 | - label: performance
23 | title:
24 | - '/^perf(\(.*\))?(\!)?\:/'
25 | - label: refactor
26 | title:
27 | - '/^refactor(\(.*\))?(\!)?\:/'
28 | - label: revert
29 | title:
30 | - '/^revert(\(.*\))?(\!)?\:/'
31 | - label: style
32 | title:
33 | - '/^style(\(.*\))?(\!)?\:/'
34 | - label: test
35 | title:
36 | - '/^test(\(.*\))?(\!)?\:/'
37 | # Custom Types
38 | - label: breaking
39 | title:
40 | - '/^[a-z]+(\(.*\))?\!\:/'
41 | # ------------------------------------- AUTOMATIC VERSIONING ------------------------------------ #
42 | version-resolver:
43 | major:
44 | labels:
45 | - breaking
46 | minor:
47 | labels:
48 | - enhancement
49 | default: patch
50 | # ------------------------------------ RELEASE CONFIGURATION ------------------------------------ #
51 | name-template: "v$RESOLVED_VERSION"
52 | tag-template: "v$RESOLVED_VERSION"
53 | category-template: "### $TITLE"
54 | change-template: "- $TITLE by @$AUTHOR in [#$NUMBER]($URL)"
55 | replacers:
56 | # remove conventional commit tag & scope from change list
57 | - search: '/- [a-z]+(\(.*\))?(\!)?\: /g'
58 | replace: "- "
59 | template: |
60 | ## What's Changed
61 |
62 | $CHANGES
63 |
64 | **Full Changelog:** [`$PREVIOUS_TAG...v$RESOLVED_VERSION`](https://github.com/$OWNER/$REPOSITORY/compare/$PREVIOUS_TAG...v$RESOLVED_VERSION)
65 | categories:
66 | - title: ⚠️ Breaking Changes
67 | labels:
68 | - breaking
69 | - title: ✨ New Features
70 | labels:
71 | - enhancement
72 | - title: 🐞 Bug Fixes
73 | labels:
74 | - fix
75 | - title: 🏎️ Performance Improvements
76 | labels:
77 | - performance
78 | - title: 📚 Documentation
79 | labels:
80 | - documentation
81 | - title: 🏗️ Testing
82 | labels:
83 | - test
84 | - title: ⚙️ Automation
85 | labels:
86 | - ci
87 | - title: 🛠 Builds
88 | labels:
89 | - build
90 | - title: 💎 Code Style
91 | labels:
92 | - style
93 | - title: 📦 Refactorings
94 | labels:
95 | - refactor
96 | - title: ♻️ Chores
97 | labels:
98 | - chore
99 | - title: 🗑 Reverts
100 | labels:
101 | - revert
102 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Build
2 | on:
3 | pull_request:
4 | push:
5 | branches: [main]
6 | release:
7 | types: [published]
8 |
9 | jobs:
10 | build-sdist:
11 | name: Build Sdist
12 | runs-on: ubuntu-latest
13 | permissions:
14 | contents: read
15 | steps:
16 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
17 | with:
18 | fetch-depth: 0
19 | - name: Set up pixi
20 | uses: prefix-dev/setup-pixi@19eac09b398e3d0c747adc7921926a6d802df4da # v0.8.8
21 | with:
22 | environments: build
23 | - name: Set version
24 | run: pixi run -e build set-version
25 | - name: Build project
26 | run: pixi run -e build build-sdist
27 | - name: Upload package
28 | uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
29 | with:
30 | name: sdist
31 | path: dist/*
32 |
33 | build-wheel:
34 | name: Build Wheel (${{ matrix.target-platform }})
35 | runs-on: ${{ matrix.os }}
36 | strategy:
37 | fail-fast: false
38 | matrix:
39 | include:
40 | - target-platform: linux-64
41 | os: ubuntu-latest
42 | - target-platform: linux-aarch64
43 | os: ubuntu-24.04-arm
44 | - target-platform: osx-64
45 | os: macos-13
46 | - target-platform: osx-arm64
47 | os: macos-latest
48 | - target-platform: win-64
49 | os: windows-latest
50 | steps:
51 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
52 | with:
53 | fetch-depth: 0
54 | - name: Set up pixi
55 | uses: prefix-dev/setup-pixi@19eac09b398e3d0c747adc7921926a6d802df4da # v0.8.8
56 | with:
57 | environments: build
58 | - name: Set version
59 | run: pixi run -e build set-version
60 | - name: Build wheel
61 | uses: PyO3/maturin-action@aef21716ff3dcae8a1c301d23ec3e4446972a6e3 # v1.49.1
62 | with:
63 | command: build
64 | args: --out dist -i python3.11
65 | manylinux: auto
66 | - name: Check package
67 | run: pixi run -e build check-wheel
68 | - name: Upload package
69 | uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
70 | with:
71 | name: wheel-${{ matrix.target-platform }}
72 | path: dist/*
73 |
74 | release:
75 | name: Publish package
76 | if: github.event_name == 'release'
77 | needs: build-wheel
78 | runs-on: ubuntu-latest
79 | permissions:
80 | id-token: write
81 | environment: pypi
82 | steps:
83 | - uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
84 | with:
85 | path: dist
86 | merge-multiple: true
87 | - name: Publish package on PyPi
88 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
89 |
--------------------------------------------------------------------------------
/.github/workflows/chore.yml:
--------------------------------------------------------------------------------
1 | name: Chore
2 | on:
3 | pull_request_target:
4 | branches: [main]
5 | types: [opened, reopened, edited, synchronize]
6 | push:
7 | branches: [main]
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.ref }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | check-pr-title:
15 | name: Check PR Title
16 | if: github.event_name == 'pull_request_target'
17 | runs-on: ubuntu-latest
18 | permissions:
19 | contents: read
20 | pull-requests: write
21 | steps:
22 | - name: Check valid conventional commit message
23 | id: lint
24 | uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
25 | with:
26 | subjectPattern: ^[A-Z].+[^. ]$ # subject must start with uppercase letter and may not end with a dot/space
27 | env:
28 | GITHUB_TOKEN: ${{ github.token }}
29 | - name: Post comment about invalid PR title
30 | if: failure()
31 | uses: marocchino/sticky-pull-request-comment@52423e01640425a022ef5fd42c6fb5f633a02728 # v2.9.1
32 | with:
33 | header: conventional-commit-pr-title
34 | message: |
35 | Thank you for opening this pull request! 👋🏼
36 |
37 | This repository requires pull request titles to follow the [Conventional Commits specification](https://www.conventionalcommits.org/en/v1.0.0/) and it looks like your proposed title needs to be adjusted.
38 |
39 | <details><summary>Details</summary>
40 |
41 | ```
42 | ${{ steps.lint.outputs.error_message }}
43 | ```
44 |
45 | </details>
46 | - name: Delete comment about invalid PR title
47 | if: success()
48 | uses: marocchino/sticky-pull-request-comment@52423e01640425a022ef5fd42c6fb5f633a02728 # v2.9.1
49 | with:
50 | header: conventional-commit-pr-title
51 | delete: true
52 |
53 | release-drafter:
54 | name: ${{ github.event_name == 'pull_request_target' && 'Assign Labels' || 'Draft Release' }}
55 | runs-on: ubuntu-latest
56 | permissions:
57 | contents: write
58 | pull-requests: write
59 | steps:
60 | - name: ${{ github.event_name == 'pull_request_target' && 'Assign labels' || 'Update release draft' }}
61 | uses: release-drafter/release-drafter@b1476f6e6eb133afa41ed8589daba6dc69b4d3f5 # v6.1.0
62 | with:
63 | disable-releaser: ${{ github.event_name == 'pull_request_target' }}
64 | disable-autolabeler: ${{ github.event_name == 'push' }}
65 | env:
66 | GITHUB_TOKEN: ${{ github.token }}
67 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on:
3 | pull_request:
4 | push:
5 | branches: [main]
6 |
7 | # Automatically stop old builds on the same branch/PR
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.ref }}
10 | cancel-in-progress: true
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | pre-commit-checks:
17 | name: Pre-commit Checks
18 | timeout-minutes: 30
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: Checkout branch
22 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
23 | with:
24 | # needed for 'pre-commit-mirrors-insert-license'
25 | fetch-depth: 0
26 | - name: Set up pixi
27 | uses: prefix-dev/setup-pixi@19eac09b398e3d0c747adc7921926a6d802df4da # v0.8.8
28 | with:
29 | environments: default lint
30 | - name: Install repository
31 | run: pixi run -e default postinstall
32 | - name: pre-commit
33 | run: pixi run pre-commit-run --color=always --show-diff-on-failure
34 |
35 | unit-tests:
36 | name: Unit Tests (${{ matrix.os == 'ubuntu-latest' && 'Linux' || 'Windows' }}) - ${{ matrix.environment }}
37 | timeout-minutes: 30
38 | runs-on: ${{ matrix.os }}
39 | strategy:
40 | fail-fast: true
41 | matrix:
42 | os: [ubuntu-latest, windows-latest]
43 | environment: [py311, py312, py313]
44 | steps:
45 | - name: Checkout branch
46 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
47 | - name: Set up pixi
48 | uses: prefix-dev/setup-pixi@19eac09b398e3d0c747adc7921926a6d802df4da # v0.8.8
49 | with:
50 | environments: ${{ matrix.environment }}
51 | - name: Install repository
52 | run: pixi run -e ${{ matrix.environment }} postinstall
53 | - name: Run pytest
54 | run: pixi run -e ${{ matrix.environment }} test-coverage --color=yes
55 | - name: Upload codecov
56 | uses: codecov/codecov-action@18283e04ce6e62d37312384ff67231eb8fd56d24 # v5.4.3
57 | with:
58 | files: ./coverage.xml
59 | token: ${{ secrets.CODECOV_TOKEN }}
60 |
--------------------------------------------------------------------------------
/.github/workflows/scorecard.yml:
--------------------------------------------------------------------------------
1 | # This workflow uses actions that are not certified by GitHub. They are provided
2 | # by a third-party and are governed by separate terms of service, privacy
3 | # policy, and support documentation.
4 |
5 | name: Scorecard supply-chain security
6 | on:
7 | # For Branch-Protection check. Only the default branch is supported. See
8 | # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
9 | branch_protection_rule:
10 | # To guarantee Maintained check is occasionally updated. See
11 | # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
12 | schedule:
13 | - cron: "34 5 * * 0"
14 | workflow_dispatch:
15 | push:
16 | branches: ["main"]
17 |
18 | # Declare default permissions as read only.
19 | permissions: read-all
20 |
21 | jobs:
22 | analysis:
23 | name: Scorecard analysis
24 | runs-on: ubuntu-latest
25 | # `publish_results: true` only works when run from the default branch. conditional can be removed if disabled.
26 | if: (github.event.repository.default_branch == github.ref_name || github.event_name == 'pull_request') && github.repository == 'quantco/dataframely'
27 | permissions:
28 | # Needed to upload the results to code-scanning dashboard.
29 | security-events: write
30 | # Needed to publish results and get a badge (see publish_results below).
31 | id-token: write
32 | # Uncomment the permissions below if installing in a private repository.
33 | # contents: read
34 | # actions: read
35 |
36 | steps:
37 | - name: "Checkout code"
38 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
39 | with:
40 | persist-credentials: false
41 |
42 | - name: "Run analysis"
43 | uses: ossf/scorecard-action@05b42c624433fc40578a4040d5cf5e36ddca8cde # v2.4.2
44 | with:
45 | results_file: results.sarif
46 | results_format: sarif
47 | # (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
48 | # - you want to enable the Branch-Protection check on a *public* repository, or
49 | # - you are installing Scorecard on a *private* repository
50 | # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action?tab=readme-ov-file#authentication-with-fine-grained-pat-optional.
51 | # repo_token: ${{ secrets.SCORECARD_TOKEN }}
52 |
53 | # Public repositories:
54 | # - Publish results to OpenSSF REST API for easy access by consumers
55 | # - Allows the repository to include the Scorecard badge.
56 | # - See https://github.com/ossf/scorecard-action#publishing-results.
57 | # For private repositories:
58 | # - `publish_results` will always be set to `false`, regardless
59 | # of the value entered here.
60 | publish_results: true
61 |
62 | # (Optional) Uncomment file_mode if you have a .gitattributes with files marked export-ignore
63 | # file_mode: git
64 |
65 | # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
66 | # format to the repository Actions tab.
67 | - name: "Upload artifact"
68 | uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
69 | with:
70 | name: SARIF file
71 | path: results.sarif
72 | retention-days: 5
73 |
74 | # Upload the results to GitHub's code scanning dashboard (optional).
75 | # Commenting out will disable upload of results to your repo's Code Scanning dashboard
76 | - name: "Upload to code-scanning"
77 | uses: github/codeql-action/upload-sarif@ff0a06e83cb2de871e5a09832bc6a81e7276941f # v3.28.18
78 | with:
79 | sarif_file: results.sarif
80 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: ^(\.copier-answers\.yml)|.pixi$
2 | repos:
3 | - repo: local
4 | hooks:
5 | # ensure pixi environments are up to date
6 | # workaround for https://github.com/prefix-dev/pixi/issues/1482
7 | - id: pixi-install
8 | name: pixi-install
9 | entry: pixi install -e default -e lint
10 | language: system
11 | always_run: true
12 | require_serial: true
13 | pass_filenames: false
14 | # insert-license
15 | - id: insert-license
16 | name: insert-license
17 | entry: >-
18 | pixi run -e lint
19 | insert-license
20 | --license-base64 Q29weXJpZ2h0IChjKSBRdWFudENvIDIwMjUtMjAyNQpTUERYLUxpY2Vuc2UtSWRlbnRpZmllcjogQlNELTMtQ2xhdXNl
21 | --dynamic-years
22 | --comment-style "#"
23 | language: system
24 | types: [python]
25 | # docformatter
26 | - id: docformatter
27 | name: docformatter
28 | entry: pixi run -e lint docformatter -i
29 | language: system
30 | types: [python]
31 | # ruff
32 | - id: ruff
33 | name: ruff
34 | entry: pixi run -e lint ruff check --fix --exit-non-zero-on-fix --force-exclude
35 | language: system
36 | types_or: [python, pyi]
37 | require_serial: true
38 | - id: ruff-format
39 | name: ruff-format
40 | entry: pixi run -e lint ruff format --force-exclude
41 | language: system
42 | types_or: [python, pyi]
43 | require_serial: true
44 | # mypy
45 | - id: mypy
46 | name: mypy
47 | entry: pixi run -e default mypy
48 | language: system
49 | types: [python]
50 | require_serial: true
51 | # prettier
52 | - id: prettier
53 | name: prettier
54 | entry: pixi run -e lint prettier --write --list-different --ignore-unknown
55 | language: system
56 | types: [text]
57 | files: \.(md|yml|yaml)$
58 | # taplo
59 | - id: taplo
60 | name: taplo
61 | entry: pixi run -e lint taplo format
62 | language: system
63 | types: [toml]
64 | # pre-commit-hooks
65 | - id: trailing-whitespace-fixer
66 | name: trailing-whitespace-fixer
67 | entry: pixi run -e lint trailing-whitespace-fixer
68 | language: system
69 | types: [text]
70 | - id: end-of-file-fixer
71 | name: end-of-file-fixer
72 | entry: pixi run -e lint end-of-file-fixer
73 | language: system
74 | types: [text]
75 | - id: check-merge-conflict
76 | name: check-merge-conflict
77 | entry: pixi run -e lint check-merge-conflict --assume-in-merge
78 | language: system
79 | types: [text]
80 | # typos
81 | - id: typos
82 | name: typos
83 | entry: pixi run -e lint typos --force-exclude
84 | language: system
85 | types: [text]
86 | require_serial: true
87 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | build
2 | conda.recipe
3 | .copier-answers.yml
4 |
5 | *.html
6 | *.properties
7 |
--------------------------------------------------------------------------------
/.prettierrc:
--------------------------------------------------------------------------------
1 | {
2 | "singleQuote": false,
3 | "bracketSpacing": true,
4 | "printWidth": 200,
5 | "endOfLine": "auto",
6 | "tabWidth": 2
7 | }
8 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-20.04
4 | tools:
5 | python: mambaforge-latest
6 | commands:
7 | - mamba install -c conda-forge -c nodefaults pixi
8 | - pixi run -e docs postinstall
9 | - pixi run -e docs docs
10 | - pixi run -e docs readthedocs
11 | sphinx:
12 | configuration: docs/conf.py
13 | formats:
14 | - pdf
15 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | edition = "2021"
3 | name = "dataframely"
4 | version = "0.1.0"
5 |
6 | [lib]
7 | crate-type = ["cdylib"]
8 | name = "dataframely"
9 |
10 | [dependencies]
11 | pyo3 = { version = "0.24", features = ["abi3-py311", "extension-module"] }
12 | rand = { version = "0.9", features = ["std_rng"] }
13 | regex-syntax = "0.8"
14 | thiserror = "2.0"
15 |
16 | [profile.release]
17 | codegen-units = 1
18 | lto = true
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2025, QuantCo
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | # dataframely — A declarative, 🐻‍❄️-native data frame validation library
8 |
9 |
10 | [](https://github.com/quantco/dataframely/actions/workflows/ci.yml)
11 | [](https://prefix.dev/channels/conda-forge/packages/dataframely)
12 | [](https://pypi.org/project/dataframely)
13 | [](https://pypi.org/project/dataframely)
14 | [](https://codecov.io/gh/Quantco/dataframely)
15 |
16 |
17 |
18 | ## 🗂 Table of Contents
19 |
20 | - [Introduction](#-introduction)
21 | - [Installation](#-installation)
22 | - [Usage](#-usage)
23 |
24 | ## 📖 Introduction
25 |
26 | Dataframely is a Python package to validate the schema and content of [`polars`](https://pola.rs/) data frames. Its
27 | purpose is to make data pipelines more robust by ensuring that data meets expectations and more readable by adding
28 | schema information to data frame type hints.
29 |
30 | ## 💿 Installation
31 |
32 | You can install `dataframely` using your favorite package manager, e.g., `pixi` or `pip`:
33 |
34 | ```bash
35 | pixi add dataframely
36 | pip install dataframely
37 | ```
38 |
39 | ## 🎯 Usage
40 |
41 | ### Defining a data frame schema
42 |
43 | ```python
44 | import dataframely as dy
45 | import polars as pl
46 |
47 | class HouseSchema(dy.Schema):
48 | zip_code = dy.String(nullable=False, min_length=3)
49 | num_bedrooms = dy.UInt8(nullable=False)
50 | num_bathrooms = dy.UInt8(nullable=False)
51 | price = dy.Float64(nullable=False)
52 |
53 | @dy.rule()
54 | def reasonable_bathroom_to_bedroom_ratio() -> pl.Expr:
55 | ratio = pl.col("num_bathrooms") / pl.col("num_bedrooms")
56 | return (ratio >= 1 / 3) & (ratio <= 3)
57 |
58 | @dy.rule(group_by=["zip_code"])
59 | def minimum_zip_code_count() -> pl.Expr:
60 | return pl.len() >= 2
61 | ```
62 |
63 | ### Validating data against schema
64 |
65 | ```python
66 |
67 | import polars as pl
68 |
69 | df = pl.DataFrame({
70 | "zip_code": ["01234", "01234", "1", "213", "123", "213"],
71 | "num_bedrooms": [2, 2, 1, None, None, 2],
72 | "num_bathrooms": [1, 2, 1, 1, 0, 8],
73 | "price": [100_000, 110_000, 50_000, 80_000, 60_000, 160_000]
74 | })
75 |
76 | # Validate the data and cast columns to expected types
77 | validated_df: dy.DataFrame[HouseSchema] = HouseSchema.validate(df, cast=True)
78 | ```
79 |
80 | See more advanced usage examples in the [documentation](https://dataframely.readthedocs.io/en/latest/).
81 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Issues
2 |
3 | We take security bugs in our projects seriously. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
4 |
5 | To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/quantco/dataframely/security/advisories/new) tab.
6 |
7 | We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
8 |
--------------------------------------------------------------------------------
/dataframely/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import importlib.metadata
5 | import warnings
6 |
7 | try:
8 | __version__ = importlib.metadata.version(__name__)
9 | except importlib.metadata.PackageNotFoundError as e: # pragma: no cover
10 | warnings.warn(f"Could not determine version of {__name__}\n{e!s}", stacklevel=2)
11 | __version__ = "unknown"
12 |
13 | from . import random
14 | from ._base_collection import CollectionMember
15 | from ._filter import filter
16 | from ._rule import rule
17 | from ._typing import DataFrame, LazyFrame
18 | from .collection import Collection
19 | from .columns import (
20 | Any,
21 | Array,
22 | Bool,
23 | Column,
24 | Date,
25 | Datetime,
26 | Decimal,
27 | Duration,
28 | Enum,
29 | Float,
30 | Float32,
31 | Float64,
32 | Int8,
33 | Int16,
34 | Int32,
35 | Int64,
36 | Integer,
37 | List,
38 | Object,
39 | String,
40 | Struct,
41 | Time,
42 | UInt8,
43 | UInt16,
44 | UInt32,
45 | UInt64,
46 | )
47 | from .config import Config
48 | from .failure import FailureInfo
49 | from .functional import (
50 | concat_collection_members,
51 | filter_relationship_one_to_at_least_one,
52 | filter_relationship_one_to_one,
53 | )
54 | from .schema import Schema
55 |
56 | __all__ = [
57 | "random",
58 | "filter",
59 | "rule",
60 | "DataFrame",
61 | "LazyFrame",
62 | "Collection",
63 | "CollectionMember",
64 | "Config",
65 | "FailureInfo",
66 | "concat_collection_members",
67 | "filter_relationship_one_to_at_least_one",
68 | "filter_relationship_one_to_one",
69 | "Schema",
70 | "Any",
71 | "Bool",
72 | "Column",
73 | "Date",
74 | "Datetime",
75 | "Decimal",
76 | "Duration",
77 | "Time",
78 | "Enum",
79 | "Float",
80 | "Float32",
81 | "Float64",
82 | "Int8",
83 | "Int16",
84 | "Int32",
85 | "Int64",
86 | "Integer",
87 | "UInt8",
88 | "UInt16",
89 | "UInt32",
90 | "UInt64",
91 | "String",
92 | "Struct",
93 | "List",
94 | "Array",
95 | "Object",
96 | ]
97 |
--------------------------------------------------------------------------------
/dataframely/_compat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | from typing import Any
6 |
7 |
8 | class _DummyModule: # pragma: no cover
9 | def __init__(self, module: str) -> None:
10 | self.module = module
11 |
12 | def __getattr__(self, name: str) -> Any:
13 | raise ValueError(f"Module '{self.module}' is not installed.")
14 |
15 |
16 | # ------------------------------------ SQLALCHEMY ------------------------------------ #
17 |
18 | try:
19 | import sqlalchemy as sa
20 | import sqlalchemy.dialects.mssql as sa_mssql
21 | from sqlalchemy.sql.type_api import TypeEngine as sa_TypeEngine
22 | except ImportError: # pragma: no cover
23 | sa = _DummyModule("sqlalchemy") # type: ignore
24 | sa_mssql = _DummyModule("sqlalchemy") # type: ignore
25 |
26 | class sa_TypeEngine: # type: ignore # noqa: N801
27 | pass
28 |
29 |
30 | # -------------------------------------- PYARROW ------------------------------------- #
31 |
32 | try:
33 | import pyarrow as pa
34 | except ImportError: # pragma: no cover
35 | pa = _DummyModule("pyarrow")
36 |
37 | # ------------------------------------------------------------------------------------ #
38 |
39 | __all__ = ["sa", "sa_mssql", "sa_TypeEngine", "pa"]
40 |
--------------------------------------------------------------------------------
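
The `_DummyModule` shim above makes optional dependencies fail lazily: importing `dataframely._compat` always succeeds, and an error is raised only once an attribute of the missing module is actually touched. A minimal sketch of that behavior, assuming `sqlalchemy` is not installed:

```python
from dataframely._compat import sa

# The import above succeeded even without sqlalchemy; only attribute
# access on the dummy module raises.
try:
    sa.String()
except ValueError as e:
    print(e)  # Module 'sqlalchemy' is not installed.
```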
/dataframely/_deprecation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import os
5 | import warnings
6 | from collections.abc import Callable
7 | from functools import wraps
8 |
9 | TRUTHY_VALUES = ["1", "true"]
10 |
11 |
12 | def skip_if(env: str) -> Callable:
13 | """Decorator to skip warnings based on environment variable.
14 |
15 | If the environment variable is equivalent to any of TRUTHY_VALUES, the wrapped
16 | function is skipped.
17 | """
18 |
19 | def decorator(fun: Callable) -> Callable:
20 | @wraps(fun)
21 | def wrapper() -> None:
22 | if os.getenv(env, "").lower() in TRUTHY_VALUES:
23 | return
24 | fun()
25 |
26 | return wrapper
27 |
28 | return decorator
29 |
30 |
31 | @skip_if(env="DATAFRAMELY_NO_FUTURE_WARNINGS")
32 | def warn_nullable_default_change() -> None:
33 | warnings.warn(
34 | "The 'nullable' argument was not explicitly set. In a future release, "
35 | "'nullable=False' will be the default if 'nullable' is not specified. "
36 | "Explicitly set 'nullable=True' if you want your column to be nullable.",
37 | FutureWarning,
38 | stacklevel=4,
39 | )
40 |
--------------------------------------------------------------------------------
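
A short sketch of the `skip_if` mechanics: the environment variable is read at call time inside the wrapper, so the same decorated function warns or stays silent depending on the current environment.

```python
import os
import warnings

from dataframely._deprecation import warn_nullable_default_change

# Default: the FutureWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_nullable_default_change()
assert any(issubclass(w.category, FutureWarning) for w in caught)

# With a truthy value for the environment variable, the call is a no-op.
os.environ["DATAFRAMELY_NO_FUTURE_WARNINGS"] = "1"
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_nullable_default_change()
assert not caught
```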
/dataframely/_extre.pyi:
--------------------------------------------------------------------------------
1 | from typing import Literal, overload
2 |
3 | def matching_string_length(regex: str) -> tuple[int, int | None]:
4 | """
5 | Compute the minimum and maximum length (if available) of strings matching a regular expression.
6 |
7 | Args:
8 | regex: The regular expression to analyze. The regular expression must not
9 | contain any lookaround operators.
10 |
11 | Returns:
12 | A tuple of the minimum and maximum length of the matching strings. While the minimum
13 | length is guaranteed to be available, the maximum length may be ``None`` if ``regex``
14 | matches strings of potentially infinite length (e.g. due to the use of ``+`` or ``*``).
15 |
16 | Raises:
17 | ValueError: If the regex cannot be parsed or analyzed.
18 | """
19 |
20 | @overload
21 | def sample(
22 | regex: str, n: int, max_repetitions: int = 16, seed: int | None = None
23 | ) -> list[str]:
24 | """
25 | Sample a random (set of) string(s) matching the provided regular expression.
26 |
27 | Args:
28 | regex: The regular expression generated strings must match. The regular
29 | expression must not contain any lookaround operators.
30 | n: The number of random strings to generate or ``None`` if a single one should
31 | be generated.
32 | max_repetitions: The maximum number of repetitions for ``+`` and ``*``
33 | quantifiers.
34 | seed: The seed to use for the random sampling procedure.
35 |
36 | Returns:
37 | A single randomly generated string if ``n is None`` or a list of randomly
38 | generated strings if ``n`` is an integer.
39 |
40 | Raises:
41 | ValueError: If the regex cannot be parsed.
42 |
43 | Attention:
44 | Using wildcards (i.e. ``.``) really means _any_ valid Unicode character.
45 | Consider using more precise regular expressions if this is undesired.
46 | """
47 |
48 | @overload
49 | def sample(
50 | regex: str,
51 | n: Literal[None] = None,
52 | max_repetitions: int = 16,
53 | seed: int | None = None,
54 | ) -> str: ...
55 |
--------------------------------------------------------------------------------
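
A usage sketch for the two stubbed functions; `_extre` is the compiled Rust extension (see `src/`), and the concrete sampled strings below are illustrative.

```python
from dataframely import _extre

# Minimum and maximum length of strings matching a regex.
assert _extre.matching_string_length(r"[a-z]{3,5}") == (3, 5)
assert _extre.matching_string_length(r"[a-z]+")[1] is None  # unbounded

# Sample matching strings, reproducibly via a seed.
samples = _extre.sample(r"[0-9]{5}", n=3, seed=42)
print(samples)  # e.g. ['40977', '03172', '68517']

# Without `n`, a single string is returned instead of a list.
single = _extre.sample(r"[0-9]{5}")
```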
/dataframely/_filter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Callable
5 | from typing import Generic, TypeVar
6 |
7 | import polars as pl
8 |
9 | C = TypeVar("C")
10 |
11 |
12 | class Filter(Generic[C]):
13 | """Internal class representing logic for filtering members of a collection."""
14 |
15 | def __init__(self, logic: Callable[[C], pl.LazyFrame]) -> None:
16 | self.logic = logic
17 |
18 |
19 | def filter() -> Callable[[Callable[[C], pl.LazyFrame]], Filter[C]]:
20 | """Mark a function as filters for rows in the members of a collection.
21 |
22 | The name of the function will be used as the name of the filter. The name must not
23 | clash with the name of any column in the member schemas or rules defined on the
24 | member schemas.
25 |
26 | A filter receives a collection as input and must return a data frame like the
27 | following:
28 |
29 | - The columns must be a superset of the common primary keys across all members.
30 | - The rows must provide the primary keys which ought to be *kept* across the
31 | members. The filter results in the removal of rows which are lost as the result
32 | of inner-joining members onto the return value of this function.
33 |
34 | Attention:
35 | Make sure to provide unique combinations of the primary keys or the filters
36 | might introduce duplicate rows.
37 | """
38 |
39 | def decorator(validation_fn: Callable[[C], pl.LazyFrame]) -> Filter[C]:
40 | return Filter(logic=validation_fn)
41 |
42 | return decorator
43 |
--------------------------------------------------------------------------------
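
A hedged sketch of declaring a filter on a collection (the member and schema names are illustrative; the `Collection` machinery itself lives in `collection.py`):

```python
import dataframely as dy
import polars as pl


class PersonSchema(dy.Schema):
    person_id = dy.Int64(primary_key=True)


class VisitSchema(dy.Schema):
    person_id = dy.Int64(primary_key=True)
    date = dy.Date(primary_key=True)


class HouseholdData(dy.Collection):
    people: dy.LazyFrame[PersonSchema]
    visits: dy.LazyFrame[VisitSchema]

    @dy.filter()
    def person_has_at_least_one_visit(self) -> pl.LazyFrame:
        # Return the primary keys to *keep*: members are inner-joined onto
        # this frame, so people without visits are dropped everywhere.
        return self.people.join(self.visits, on="person_id", how="semi")
```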
/dataframely/_polars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import datetime as dt
5 | from collections.abc import Iterable
6 | from typing import TypeVar
7 |
8 | import polars as pl
9 | from polars.datatypes import DataTypeClass
10 |
11 | PolarsDataType = pl.DataType | DataTypeClass
12 | FrameType = TypeVar("FrameType", pl.DataFrame, pl.LazyFrame)
13 |
14 | EPOCH_DATETIME = dt.datetime(1970, 1, 1)
15 | SECONDS_PER_DAY = 86400
16 |
17 | # --------------------------------------- JOINS -------------------------------------- #
18 |
19 |
20 | def join_all_inner(dfs: Iterable[FrameType], on: str | list[str]) -> FrameType:
21 | it = iter(dfs)
22 | result = next(it)
23 | while (df := next(it, None)) is not None:
24 | result = result.join(df, on=on)
25 | return result
26 |
27 |
28 | def join_all_outer(dfs: Iterable[FrameType], on: str | list[str]) -> FrameType:
29 | it = iter(dfs)
30 | result = next(it)
31 | while (df := next(it, None)) is not None:
32 | result = result.join(df, on=on, how="full", coalesce=True)
33 | return result
34 |
35 |
36 | # ------------------------------------- DATETIMES ------------------------------------ #
37 |
38 |
39 | def date_matches_resolution(t: dt.date, resolution: str) -> bool:
40 | return pl.Series([t], dtype=pl.Date).dt.truncate(resolution).item() == t
41 |
42 |
43 | def datetime_matches_resolution(t: dt.datetime, resolution: str) -> bool:
44 | return pl.Series([t], dtype=pl.Datetime).dt.truncate(resolution).item() == t
45 |
46 |
47 | def time_matches_resolution(t: dt.time, resolution: str) -> bool:
48 | return (
49 | pl.Series([t], dtype=pl.Time)
50 | .to_frame("t")
51 | .select(
52 | pl.lit(EPOCH_DATETIME.date())
53 | .dt.combine(pl.col("t"))
54 | .dt.truncate(resolution)
55 | .dt.time()
56 | )
57 | .item()
58 | == t
59 | )
60 |
61 |
62 | def timedelta_matches_resolution(d: dt.timedelta, resolution: str) -> bool:
63 | return datetime_matches_resolution(EPOCH_DATETIME + d, resolution)
64 |
--------------------------------------------------------------------------------
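
The join helpers fold an arbitrary iterable of frames into a single frame; note they assume at least one frame, since the initial `next(it)` raises `StopIteration` on an empty iterable. A small usage sketch:

```python
import polars as pl

from dataframely._polars import join_all_inner, join_all_outer

a = pl.DataFrame({"id": [1, 2, 3], "x": ["a", "b", "c"]})
b = pl.DataFrame({"id": [2, 3, 4], "y": [0.1, 0.2, 0.3]})

# Inner join keeps only ids present in every frame: {2, 3}.
print(join_all_inner([a, b], on="id"))

# Full outer join with key coalescing keeps all ids: {1, 2, 3, 4}.
print(join_all_outer([a, b], on="id"))
```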
/dataframely/_typing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | from collections.abc import Callable
7 | from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeVar
8 |
9 | import polars as pl
10 |
11 | from ._base_schema import BaseSchema
12 |
13 | S = TypeVar("S", bound=BaseSchema, covariant=True)
14 |
15 | P = ParamSpec("P")
16 | R = TypeVar("R")
17 |
18 |
19 | def inherit_signature( # pragma: no cover
20 | target_fn: Callable[P, Any],
21 | ) -> Callable[[Callable[..., R]], Callable[P, R]]:
22 | # NOTE: This code is executed during parsing but has no effect at runtime.
23 | if TYPE_CHECKING:
24 | return lambda _: target_fn
25 |
26 | return lambda _: None
27 |
28 |
29 | class DataFrame(pl.DataFrame, Generic[S]):
30 | """Generic wrapper around a :class:`polars.DataFrame` to attach schema information.
31 |
32 | This class is merely used for the type system and never actually instantiated. This
33 | means that no instance of it ever exists at runtime and ``isinstance`` checks against
34 | it will always fail. Accordingly, users should not try to create instances of this class.
35 | """
36 |
37 | # NOTE: Code in this class will never be executed.
38 |
39 | @inherit_signature(pl.DataFrame.clear)
40 | def clear(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
41 | raise NotImplementedError # pragma: no cover
42 |
43 | @inherit_signature(pl.DataFrame.clone)
44 | def clone(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
45 | raise NotImplementedError # pragma: no cover
46 |
47 | @inherit_signature(pl.DataFrame.lazy)
48 | def lazy(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
49 | raise NotImplementedError # pragma: no cover
50 |
51 | def pipe(
52 | self,
53 | function: Callable[Concatenate[DataFrame[S], P], R],
54 | *args: P.args,
55 | **kwargs: P.kwargs,
56 | ) -> R:
57 | raise NotImplementedError # pragma: no cover
58 |
59 | @inherit_signature(pl.DataFrame.rechunk)
60 | def rechunk(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
61 | raise NotImplementedError # pragma: no cover
62 |
63 | @inherit_signature(pl.DataFrame.set_sorted)
64 | def set_sorted(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
65 | raise NotImplementedError # pragma: no cover
66 |
67 | @inherit_signature(pl.DataFrame.shrink_to_fit)
68 | def shrink_to_fit(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
69 | raise NotImplementedError # pragma: no cover
70 |
71 |
72 | class LazyFrame(pl.LazyFrame, Generic[S]):
73 | """Generic wrapper around a :class:`polars.LazyFrame` to attach schema information.
74 |
75 | This class is merely used for the type system and never actually instantiated. This
76 | means that no instance of it ever exists at runtime and ``isinstance`` checks against
77 | it will always fail. Accordingly, users should not try to create instances of this class.
78 | """
79 |
80 | # NOTE: Code in this class will never be executed.
81 |
82 | @inherit_signature(pl.LazyFrame.cache)
83 | def cache(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
84 | raise NotImplementedError # pragma: no cover
85 |
86 | @inherit_signature(pl.LazyFrame.clear)
87 | def clear(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
88 | raise NotImplementedError # pragma: no cover
89 |
90 | @inherit_signature(pl.LazyFrame.clone)
91 | def clone(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
92 | raise NotImplementedError # pragma: no cover
93 |
94 | # NOTE: inheriting the signature does not work since `mypy` doesn't correctly
95 | # propagate overloads
96 | def collect(self, *args: Any, **kwargs: Any) -> DataFrame[S]: # type: ignore
97 | raise NotImplementedError # pragma: no cover
98 |
99 | @inherit_signature(pl.LazyFrame.lazy)
100 | def lazy(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
101 | raise NotImplementedError # pragma: no cover
102 |
103 | def pipe(
104 | self,
105 | function: Callable[Concatenate[LazyFrame[S], P], R],
106 | *args: P.args,
107 | **kwargs: P.kwargs,
108 | ) -> R:
109 | raise NotImplementedError # pragma: no cover
110 |
111 | @inherit_signature(pl.LazyFrame.set_sorted)
112 | def set_sorted(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
113 | raise NotImplementedError # pragma: no cover
114 |
--------------------------------------------------------------------------------
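
A sketch of how these wrappers are used in practice: at runtime everything is a plain polars frame, while static type checkers see the schema parameter flow through `lazy`/`collect`.

```python
import dataframely as dy
import polars as pl


class MySchema(dy.Schema):
    id = dy.Int64(primary_key=True)


def load() -> dy.LazyFrame[MySchema]:
    # The runtime value is an ordinary pl.LazyFrame; the parameterized
    # annotation only carries schema information for type checkers.
    return MySchema.validate(pl.DataFrame({"id": [1, 2]}), cast=True).lazy()


lf = load()
df: dy.DataFrame[MySchema] = lf.collect()  # collect() keeps the schema type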
/dataframely/_validation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Iterable
5 | from typing import Literal
6 |
7 | import polars as pl
8 |
9 | from dataframely.exc import DtypeValidationError, ValidationError
10 |
11 | from ._polars import PolarsDataType
12 | from .columns import Column
13 |
14 | DtypeCasting = Literal["none", "lenient", "strict"]
15 |
16 |
17 | def validate_columns(
18 | lf: pl.LazyFrame,
19 | actual: Iterable[str],
20 | expected: Iterable[str],
21 | ) -> pl.LazyFrame:
22 | """Validate the existence of expected columns in a data frame.
23 |
24 | Args:
25 | lf: The data frame whose list of columns to validate.
26 | actual: The list of columns that _are_ observed. Passed as a separate argument
27 | for performance, since it minimizes the number of schema collections.
28 | expected: The list of columns that _should_ be observed.
29 |
30 | Raises:
31 | ValidationError: If any expected column is not part of the actual columns.
32 |
33 | Returns:
34 | The input data frame, either as-is or with extra columns stripped.
35 | """
36 | actual_set = set(actual)
37 | expected_set = set(expected)
38 |
39 | missing_columns = expected_set - actual_set
40 | if len(missing_columns) > 0:
41 | raise ValidationError(
42 | f"{len(missing_columns)} columns in the schema are missing in the "
43 | f"data frame: {sorted(missing_columns)}."
44 | )
45 |
46 | return lf.select(expected)
47 |
48 |
49 | def validate_dtypes(
50 | lf: pl.LazyFrame,
51 | actual: pl.Schema,
52 | expected: dict[str, Column],
53 | casting: DtypeCasting,
54 | ) -> pl.LazyFrame:
55 | """Validate the dtypes of all expected columns in a data frame.
56 |
57 | Args:
58 | lf: The data frame whose column dtypes to validate.
59 | actual: The actual schema of the data frame. Passed as a separate argument
60 | for performance, since it minimizes the number of schema collections.
61 | expected: The column definitions carrying the expected dtypes.
62 | casting: The strategy for casting dtypes.
63 |
64 | Raises:
65 | DtypeValidationError: If the expected column dtypes do not match the input's
66 | and ``casting`` is set to ``none``.
67 |
68 | Returns:
69 | The input data frame with all column dtypes ensured to have the expected dtype.
70 | """
71 | dtype_errors: dict[str, tuple[PolarsDataType, PolarsDataType]] = {}
72 | for name, col in expected.items():
73 | if not col.validate_dtype(actual[name]):
74 | dtype_errors[name] = (actual[name], col.dtype)
75 |
76 | if len(dtype_errors) > 0:
77 | if casting == "none":
78 | raise DtypeValidationError(dtype_errors)
79 | else:
80 | return lf.with_columns(
81 | pl.col(name).cast(expected[name].dtype, strict=(casting == "strict"))
82 | for name in dtype_errors.keys()
83 | )
84 |
85 | return lf
86 |
--------------------------------------------------------------------------------
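
A minimal sketch of the two helpers in sequence, using column types from `dataframely.columns` (the explicit `nullable=True` merely silences the future warning from `_deprecation.py`):

```python
import polars as pl

from dataframely._validation import validate_columns, validate_dtypes
from dataframely.columns import Int64, String

expected = {"id": Int64(nullable=True), "name": String(nullable=True)}
lf = pl.LazyFrame({"id": ["1", "2"], "name": ["a", "b"], "extra": [0, 0]})

# 1. Presence check: missing columns raise, extra columns are stripped.
lf = validate_columns(lf, actual=lf.collect_schema().names(), expected=expected.keys())

# 2. Dtype check: "id" (String) mismatches Int64 and is cast leniently.
lf = validate_dtypes(lf, actual=lf.collect_schema(), expected=expected, casting="lenient")
print(lf.collect())  # columns: id (Int64), name (String)
```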
/dataframely/columns/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from ._base import Column
5 | from .any import Any
6 | from .array import Array
7 | from .bool import Bool
8 | from .datetime import Date, Datetime, Duration, Time
9 | from .decimal import Decimal
10 | from .enum import Enum
11 | from .float import Float, Float32, Float64
12 | from .integer import Int8, Int16, Int32, Int64, Integer, UInt8, UInt16, UInt32, UInt64
13 | from .list import List
14 | from .object import Object
15 | from .string import String
16 | from .struct import Struct
17 |
18 | __all__ = [
19 | "Column",
20 | "Any",
21 | "Array",
22 | "Bool",
23 | "Date",
24 | "Datetime",
25 | "Decimal",
26 | "Duration",
27 | "Enum",
28 | "Time",
29 | "Float",
30 | "Float32",
31 | "Float64",
32 | "Int8",
33 | "Int16",
34 | "Int32",
35 | "Int64",
36 | "Integer",
37 | "Object",
38 | "UInt8",
39 | "UInt16",
40 | "UInt32",
41 | "UInt64",
42 | "String",
43 | "List",
44 | "Struct",
45 | ]
46 |
--------------------------------------------------------------------------------
/dataframely/columns/_mixins.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Sequence
5 | from typing import TYPE_CHECKING, Any, Generic, Protocol, Self, TypeVar
6 |
7 | import polars as pl
8 |
9 | if TYPE_CHECKING: # pragma: no cover
10 | from ._base import Column
11 |
12 | Base = Column
13 | else:
14 | Base = object
15 |
16 | # ----------------------------------- ORDINAL MIXIN ---------------------------------- #
17 |
18 |
19 | class Comparable(Protocol):
20 | def __gt__(self, other: Self, /) -> bool: ...
21 | def __ge__(self, other: Self, /) -> bool: ...
22 |
23 |
24 | T = TypeVar("T", bound=Comparable)
25 |
26 |
27 | class OrdinalMixin(Generic[T], Base):
28 | """Mixin to use for ordinal types."""
29 |
30 | def __init__(
31 | self,
32 | *,
33 | min: T | None = None,
34 | min_exclusive: T | None = None,
35 | max: T | None = None,
36 | max_exclusive: T | None = None,
37 | **kwargs: Any,
38 | ):
39 | if min is not None and min_exclusive is not None:
40 | raise ValueError("At most one of `min` and `min_exclusive` must be set.")
41 | if max is not None and max_exclusive is not None:
42 | raise ValueError("At most one of `max` and `max_exclusive` must be set.")
43 |
44 | if min is not None and max is not None and min > max:
45 | raise ValueError("`min` must not be greater than `max`.")
46 | if min_exclusive is not None and max is not None and min_exclusive >= max:
47 | raise ValueError("`min_exclusive` must not be greater than or equal to `max`.")
48 | if min is not None and max_exclusive is not None and min >= max_exclusive:
49 | raise ValueError("`min` must not be greater than or equal to `max_exclusive`.")
50 | if (
51 | min_exclusive is not None
52 | and max_exclusive is not None
53 | and min_exclusive >= max_exclusive
54 | ):
55 | raise ValueError(
56 | "`min_exclusive` must not be greater or equal to `max_exclusive`."
57 | )
58 |
59 | super().__init__(**kwargs)
60 | self.min = min
61 | self.min_exclusive = min_exclusive
62 | self.max = max
63 | self.max_exclusive = max_exclusive
64 |
65 | def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
66 | result = super().validation_rules(expr)
67 | if self.min is not None:
68 | result["min"] = expr >= self.min # type: ignore
69 | if self.min_exclusive is not None:
70 | result["min_exclusive"] = expr > self.min_exclusive # type: ignore
71 | if self.max is not None:
72 | result["max"] = expr <= self.max # type: ignore
73 | if self.max_exclusive is not None:
74 | result["max_exclusive"] = expr < self.max_exclusive # type: ignore
75 | return result
76 |
77 |
78 | # ------------------------------------ IS IN MIXIN ----------------------------------- #
79 |
80 | U = TypeVar("U")
81 |
82 |
83 | class IsInMixin(Generic[U], Base):
84 | """Mixin to use for types implementing "is in"."""
85 |
86 | def __init__(self, *, is_in: Sequence[U] | None = None, **kwargs: Any) -> None:
87 | super().__init__(**kwargs)
88 | self.is_in = is_in
89 |
90 | def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
91 | result = super().validation_rules(expr)
92 | if self.is_in is not None:
93 | result["is_in"] = expr.is_in(self.is_in)
94 | return result
95 |
--------------------------------------------------------------------------------
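
A sketch of what `OrdinalMixin` contributes on a concrete column type (assuming the integer columns mix it in; any further rule names come from the `Column` base class):

```python
import dataframely as dy
import polars as pl

# Bounds turn into individually named boolean validation rules.
col = dy.Int64(nullable=False, min=0, max_exclusive=100)
rules = col.validation_rules(pl.col("age"))
print(sorted(rules))  # includes 'min' and 'max_exclusive'

# Conflicting bounds are rejected at construction time:
# dy.Int64(nullable=False, min=1, min_exclusive=0)  -> ValueError
```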
/dataframely/columns/_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Callable
5 | from typing import Any, Concatenate, Literal, ParamSpec, TypeVar, overload
6 |
7 |
8 | class classproperty(property): # noqa: N801
9 | """Replacement for the deprecated @classmethod @property decorator combination.
10 |
11 | Usage:
12 | ```
13 | @classproperty
14 | def num_bytes(self) -> int:
15 | ...
16 | ```
17 | """
18 |
19 | def __get__(self, instance: Any, owner: type | None = None, /) -> Any:
20 | return self.fget(owner) if self.fget is not None else None
21 |
22 |
23 | T = TypeVar("T")
24 | R = TypeVar("R")
25 | P = ParamSpec("P")
26 |
27 |
28 | def map_optional(
29 | fn: Callable[Concatenate[T, P], R],
30 | value: T | None,
31 | *args: P.args,
32 | **kwargs: P.kwargs,
33 | ) -> R | None:
34 | if value is None:
35 | return None
36 | return fn(value, *args, **kwargs)
37 |
38 |
39 | @overload
40 | def first_non_null(
41 | *values: T | None, allow_null_response: Literal[True]
42 | ) -> T | None: ...
43 |
44 |
45 | @overload
46 | def first_non_null(*values: T | None, default: T) -> T: ...
47 |
48 |
49 | def first_non_null(
50 | *values: T | None,
51 | default: T | None = None,
52 | allow_null_response: Literal[True] | None = None,
53 | ) -> T | None:
54 | """Returns the first element in a sequence that is not None."""
55 | for value in values:
56 | if value is not None:
57 | return value
58 | if allow_null_response:
59 | return None
60 | return default
61 |
--------------------------------------------------------------------------------
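
Brief usage sketches for the helpers above:

```python
from dataframely.columns._utils import first_non_null, map_optional

# map_optional: apply a function only to non-None values.
assert map_optional(str.upper, "abc") == "ABC"
assert map_optional(str.upper, None) is None

# first_non_null: first non-None value, with an explicit fallback.
assert first_non_null(None, 3, 5, default=0) == 3
assert first_non_null(None, None, default=0) == 0
assert first_non_null(None, allow_null_response=True) is None
```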
/dataframely/columns/any.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | from collections.abc import Callable
7 |
8 | import polars as pl
9 |
10 | from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine
11 | from dataframely._polars import PolarsDataType
12 | from dataframely.random import Generator
13 |
14 | from ._base import Column
15 |
16 |
17 | class Any(Column):
18 | """A column with arbitrary type.
19 |
20 | As a column with arbitrary type is commonly mapped to the ``Null`` type (this is the
21 | default in :mod:`polars` and :mod:`pyarrow` for empty columns), dataframely also
22 | requires this column to be nullable. Hence, it cannot be used as a primary key.
23 | """
24 |
25 | def __init__(
26 | self,
27 | *,
28 | check: (
29 | Callable[[pl.Expr], pl.Expr]
30 | | list[Callable[[pl.Expr], pl.Expr]]
31 | | dict[str, Callable[[pl.Expr], pl.Expr]]
32 | | None
33 | ) = None,
34 | alias: str | None = None,
35 | metadata: dict[str, Any] | None = None,
36 | ):
37 | """
38 | Args:
39 | check: A custom rule or multiple rules to run for this column. This can be:
40 | - A single callable that returns a non-aggregated boolean expression.
41 | The name of the rule is derived from the callable name, or defaults to
42 | "check" for lambdas.
43 | - A list of callables, where each callable returns a non-aggregated
44 | boolean expression. The name of the rule is derived from the callable
45 | name, or defaults to "check" for lambdas. Where multiple rules result
46 | in the same name, the suffix __i is appended to the name.
47 | - A dictionary mapping rule names to callables, where each callable
48 | returns a non-aggregated boolean expression.
49 | All rule names provided here are given the prefix "check_".
50 | alias: An overwrite for this column's name which allows for using a column
51 | name that is not a valid Python identifier. Especially note that setting
52 | this option does _not_ allow to refer to the column with two different
53 | names, the specified alias is the only valid name.
54 | metadata: A dictionary of metadata to attach to the column.
55 | """
56 | super().__init__(
57 | nullable=True,
58 | primary_key=False,
59 | check=check,
60 | alias=alias,
61 | metadata=metadata,
62 | )
63 |
64 | @property
65 | def dtype(self) -> pl.DataType:
66 | return pl.Null() # default polars dtype
67 |
68 | def validate_dtype(self, dtype: PolarsDataType) -> bool:
69 | return True
70 |
71 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
72 | match dialect.name:
73 | case "mssql":
74 | return sa_mssql.SQL_VARIANT()
75 | case _: # pragma: no cover
76 | raise NotImplementedError("SQL column cannot have 'Any' type.")
77 |
78 | def pyarrow_field(self, name: str) -> pa.Field:
79 | return pa.field(name, self.pyarrow_dtype, nullable=self.nullable)
80 |
81 | @property
82 | def pyarrow_dtype(self) -> pa.DataType:
83 | return pa.null()
84 |
85 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
86 | return pl.repeat(None, n, dtype=pl.Null, eager=True)
87 |
--------------------------------------------------------------------------------
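
A sketch of `Any` in a schema: any dtype passes validation, but the column is forced to be nullable and cannot be a primary key (schema name is illustrative).

```python
import dataframely as dy
import polars as pl


class EventSchema(dy.Schema):
    event_id = dy.Int64(primary_key=True)
    payload = dy.Any()  # accepts any dtype, always nullable


df = pl.DataFrame({"event_id": [1, 2], "payload": [None, None]})
validated = EventSchema.validate(df, cast=True)  # payload dtype Null is fine
```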
/dataframely/columns/array.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | import math
7 | from collections.abc import Callable, Sequence
8 | from typing import Any, Literal
9 |
10 | import polars as pl
11 |
12 | from dataframely._compat import pa, sa, sa_TypeEngine
13 | from dataframely.random import Generator
14 |
15 | from ._base import Column
16 | from .struct import Struct
17 |
18 |
19 | class Array(Column):
20 | """A fixed-shape array column."""
21 |
22 | def __init__(
23 | self,
24 | inner: Column,
25 | shape: int | tuple[int, ...],
26 | *,
27 | nullable: bool = True,
28 | # polars doesn't yet support grouping by arrays,
29 | # see https://github.com/pola-rs/polars/issues/22574
30 | primary_key: Literal[False] = False,
31 | check: (
32 | Callable[[pl.Expr], pl.Expr]
33 | | list[Callable[[pl.Expr], pl.Expr]]
34 | | dict[str, Callable[[pl.Expr], pl.Expr]]
35 | | None
36 | ) = None,
37 | alias: str | None = None,
38 | metadata: dict[str, Any] | None = None,
39 | ):
40 | """
41 | Args:
42 | inner: The inner column type. No validation rules on the inner type are supported yet.
43 | shape: The shape of the array.
44 | nullable: Whether this column may contain null values.
45 | primary_key: Whether this column is part of the primary key of the schema.
46 | Not yet supported for the Array type.
47 | check: A custom rule or multiple rules to run for this column. This can be:
48 | - A single callable that returns a non-aggregated boolean expression.
49 | The name of the rule is derived from the callable name, or defaults to
50 | "check" for lambdas.
51 | - A list of callables, where each callable returns a non-aggregated
52 | boolean expression. The name of the rule is derived from the callable
53 | name, or defaults to "check" for lambdas. Where multiple rules result
54 | in the same name, the suffix __i is appended to the name.
55 | - A dictionary mapping rule names to callables, where each callable
56 | returns a non-aggregated boolean expression.
57 | All rule names provided here are given the prefix "check_".
 58 |             alias: An override for this column's name which allows using a column
 59 |                 name that is not a valid Python identifier. Note that setting this
 60 |                 option does _not_ allow referring to the column by two different
 61 |                 names; the specified alias is the only valid name.
62 | metadata: A dictionary of metadata to attach to the column.
63 | """
64 | if inner.primary_key or (
65 | isinstance(inner, Struct)
66 | and any(col.primary_key for col in inner.inner.values())
67 | ):
68 | raise ValueError(
69 | "`primary_key=True` is not yet supported for inner types of the Array type."
70 | )
71 |
 72 |         # We disallow validation rules on the inner type since Polars arrays currently don't support .eval().
 73 |         # Converting to a list and calling .list.eval() would be possible; however, since the shape can have
 74 |         # multiple axes, the recursive conversion could have a significant performance impact. Hence, we simply
 75 |         # disallow inner validation rules. An alternative would be to allow rules for sampling only, without enforcing them.
76 | if inner.validation_rules(pl.lit(None)):
77 | raise ValueError(
78 | "Validation rules on the inner type of Array are not yet supported."
79 | )
80 |
81 | super().__init__(
82 | nullable=nullable,
83 | primary_key=False,
84 | check=check,
85 | alias=alias,
86 | metadata=metadata,
87 | )
88 | self.inner = inner
89 | self.shape = shape if isinstance(shape, tuple) else (shape,)
90 |
91 | @property
92 | def dtype(self) -> pl.DataType:
93 | return pl.Array(self.inner.dtype, self.shape)
94 |
95 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
96 | # NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
97 | raise NotImplementedError("SQL column cannot have 'Array' type.")
98 |
99 | def _pyarrow_dtype_of_shape(self, shape: Sequence[int]) -> pa.DataType:
100 | if shape:
101 | size, *rest = shape
102 | return pa.list_(self._pyarrow_dtype_of_shape(rest), size)
103 | else:
104 | return self.inner.pyarrow_dtype
105 |
106 | @property
107 | def pyarrow_dtype(self) -> pa.DataType:
108 | return self._pyarrow_dtype_of_shape(self.shape)
109 |
110 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
111 | # Sample the inner elements in a flat series
112 | n_elements = n * math.prod(self.shape)
113 | all_elements = self.inner.sample(generator, n_elements)
114 |
115 | # Finally, apply a null mask
116 | return generator._apply_null_mask(
117 | all_elements.reshape((n, *self.shape)),
118 | null_probability=self._null_probability,
119 | )
120 |
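Usage note: a short sketch of the shape handling, assuming `Array` and the inner column types are re-exported by `dataframely.columns` (imported as `dc`, following the convention in `dataframely/testing/const.py`):

    import polars as pl
    import dataframely.columns as dc

    # An integer shape is normalized to a 1-tuple; tuples describe multiple axes.
    col = dc.Array(dc.Int64(), shape=(2, 3))
    assert col.dtype == pl.Array(pl.Int64, (2, 3))

    # The pyarrow dtype nests one fixed-size list per axis:
    # fixed_size_list<fixed_size_list<int64, 3>, 2>
    print(col.pyarrow_dtype)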
--------------------------------------------------------------------------------
/dataframely/columns/bool.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | import polars as pl
7 |
8 | from dataframely._compat import pa, sa, sa_TypeEngine
9 | from dataframely.random import Generator
10 |
11 | from ._base import Column
12 |
13 | # ------------------------------------------------------------------------------------ #
14 |
15 |
16 | class Bool(Column):
17 | """A column of booleans."""
18 |
19 | @property
20 | def dtype(self) -> pl.DataType:
21 | return pl.Boolean()
22 |
23 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
24 | return sa.Boolean()
25 |
26 | @property
27 | def pyarrow_dtype(self) -> pa.DataType:
28 | return pa.bool_()
29 |
30 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
31 | return generator.sample_bool(n, null_probability=self._null_probability)
32 |
--------------------------------------------------------------------------------
/dataframely/columns/enum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | from collections.abc import Callable, Sequence
7 | from typing import Any
8 |
9 | import polars as pl
10 |
11 | from dataframely._compat import pa, sa, sa_TypeEngine
12 | from dataframely._polars import PolarsDataType
13 | from dataframely.random import Generator
14 |
15 | from ._base import Column
16 |
17 |
18 | class Enum(Column):
19 | """A column of enum (string) values."""
20 |
21 | def __init__(
22 | self,
23 | categories: Sequence[str],
24 | *,
25 | nullable: bool | None = None,
26 | primary_key: bool = False,
27 | check: (
28 | Callable[[pl.Expr], pl.Expr]
29 | | list[Callable[[pl.Expr], pl.Expr]]
30 | | dict[str, Callable[[pl.Expr], pl.Expr]]
31 | | None
32 | ) = None,
33 | alias: str | None = None,
34 | metadata: dict[str, Any] | None = None,
35 | ):
36 | """
37 | Args:
38 | categories: The list of valid categories for the enum.
39 | nullable: Whether this column may contain null values.
40 | Explicitly set `nullable=True` if you want your column to be nullable.
41 | In a future release, `nullable=False` will be the default if `nullable`
42 | is not specified.
43 | primary_key: Whether this column is part of the primary key of the schema.
44 | If ``True``, ``nullable`` is automatically set to ``False``.
45 | check: A custom rule or multiple rules to run for this column. This can be:
46 | - A single callable that returns a non-aggregated boolean expression.
47 | The name of the rule is derived from the callable name, or defaults to
48 | "check" for lambdas.
49 | - A list of callables, where each callable returns a non-aggregated
50 | boolean expression. The name of the rule is derived from the callable
51 | name, or defaults to "check" for lambdas. Where multiple rules result
52 | in the same name, the suffix __i is appended to the name.
53 | - A dictionary mapping rule names to callables, where each callable
54 | returns a non-aggregated boolean expression.
55 | All rule names provided here are given the prefix "check_".
 56 |             alias: An override for this column's name which allows using a column
 57 |                 name that is not a valid Python identifier. Note that setting this
 58 |                 option does _not_ allow referring to the column by two different
 59 |                 names; the specified alias is the only valid name.
60 | metadata: A dictionary of metadata to attach to the column.
61 | """
62 | super().__init__(
63 | nullable=nullable,
64 | primary_key=primary_key,
65 | check=check,
66 | alias=alias,
67 | metadata=metadata,
68 | )
69 | self.categories = list(categories)
70 |
71 | @property
72 | def dtype(self) -> pl.DataType:
73 | return pl.Enum(self.categories)
74 |
75 | def validate_dtype(self, dtype: PolarsDataType) -> bool:
76 | if not isinstance(dtype, pl.Enum):
77 | return False
78 | return self.categories == dtype.categories.to_list()
79 |
80 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
81 | category_lengths = [len(c) for c in self.categories]
82 | if all(length == category_lengths[0] for length in category_lengths):
83 | return sa.CHAR(category_lengths[0])
84 | return sa.String(max(category_lengths))
85 |
86 | @property
87 | def pyarrow_dtype(self) -> pa.DataType:
88 | return pa.dictionary(pa.uint32(), pa.large_string())
89 |
90 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
91 | return generator.sample_choice(
92 | n, choices=self.categories, null_probability=self._null_probability
93 | ).cast(self.dtype)
94 |
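Usage note: a sketch of the SQL type mapping above (the `dc` import follows `dataframely/testing/const.py`; the variable names are illustrative):

    import dataframely.columns as dc

    fixed = dc.Enum(["open", "done"], nullable=True)  # all categories have length 4
    varied = dc.Enum(["a", "abc"], nullable=True)

    # Equal-length categories map to a fixed-width CHAR, otherwise to a String
    # sized by the longest category:
    #   fixed.sqlalchemy_dtype(dialect)  -> sa.CHAR(4)
    #   varied.sqlalchemy_dtype(dialect) -> sa.String(3)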
--------------------------------------------------------------------------------
/dataframely/columns/object.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | from collections.abc import Callable
7 | from typing import Any
8 |
9 | import polars as pl
10 |
11 | from dataframely._compat import pa, sa, sa_TypeEngine
12 | from dataframely.random import Generator
13 |
14 | from ._base import Column
15 |
16 |
17 | class Object(Column):
18 | """A Python Object column."""
19 |
20 | def __init__(
21 | self,
22 | *,
23 | nullable: bool = True,
24 | primary_key: bool = False,
25 | check: (
26 | Callable[[pl.Expr], pl.Expr]
27 | | list[Callable[[pl.Expr], pl.Expr]]
28 | | dict[str, Callable[[pl.Expr], pl.Expr]]
29 | | None
30 | ) = None,
31 | alias: str | None = None,
32 | metadata: dict[str, Any] | None = None,
33 | ):
34 | """
35 | Args:
36 | nullable: Whether this column may contain null values.
37 | primary_key: Whether this column is part of the primary key of the schema.
38 | check: A custom rule or multiple rules to run for this column. This can be:
39 | - A single callable that returns a non-aggregated boolean expression.
40 | The name of the rule is derived from the callable name, or defaults to
41 | "check" for lambdas.
42 | - A list of callables, where each callable returns a non-aggregated
43 | boolean expression. The name of the rule is derived from the callable
44 | name, or defaults to "check" for lambdas. Where multiple rules result
45 | in the same name, the suffix __i is appended to the name.
46 | - A dictionary mapping rule names to callables, where each callable
47 | returns a non-aggregated boolean expression.
48 | All rule names provided here are given the prefix "check_".
 49 |             alias: An override for this column's name which allows using a column
 50 |                 name that is not a valid Python identifier. Note that setting this
 51 |                 option does _not_ allow referring to the column by two different
 52 |                 names; the specified alias is the only valid name.
53 | metadata: A dictionary of metadata to attach to the column.
54 | """
55 | super().__init__(
56 | nullable=nullable,
57 | primary_key=primary_key,
58 | check=check,
59 | alias=alias,
60 | metadata=metadata,
61 | )
62 |
63 | @property
64 | def dtype(self) -> pl.DataType:
65 | return pl.Object()
66 |
67 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
68 | raise NotImplementedError("SQL column cannot have 'Object' type.")
69 |
70 | @property
71 | def pyarrow_dtype(self) -> pa.DataType:
72 | raise NotImplementedError("PyArrow column cannot have 'Object' type.")
73 |
74 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
75 | raise NotImplementedError(
76 | "Random data sampling not implemented for 'Object' type."
77 | )
78 |
--------------------------------------------------------------------------------
/dataframely/columns/struct.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from __future__ import annotations
5 |
6 | from collections.abc import Callable
7 | from typing import Any
8 |
9 | import polars as pl
10 |
11 | from dataframely._compat import pa, sa, sa_TypeEngine
12 | from dataframely._polars import PolarsDataType
13 | from dataframely.random import Generator
14 |
15 | from ._base import Column
16 |
17 |
18 | class Struct(Column):
19 | """A struct column."""
20 |
21 | def __init__(
22 | self,
23 | inner: dict[str, Column],
24 | *,
25 | nullable: bool | None = None,
26 | primary_key: bool = False,
27 | check: (
28 | Callable[[pl.Expr], pl.Expr]
29 | | list[Callable[[pl.Expr], pl.Expr]]
30 | | dict[str, Callable[[pl.Expr], pl.Expr]]
31 | | None
32 | ) = None,
33 | alias: str | None = None,
34 | metadata: dict[str, Any] | None = None,
35 | ):
36 | """
37 | Args:
 38 |             inner: The dictionary of struct fields. Struct fields may have
 39 |                 ``primary_key=True`` set, but this setting only takes effect if the
 40 |                 struct is nested inside a list. In this case, the list items must be
 41 |                 unique with respect to the struct fields with ``primary_key=True`` set.
42 | nullable: Whether this column may contain null values.
43 | Explicitly set `nullable=True` if you want your column to be nullable.
44 | In a future release, `nullable=False` will be the default if `nullable`
45 | is not specified.
46 | primary_key: Whether this column is part of the primary key of the schema.
47 | check: A custom rule or multiple rules to run for this column. This can be:
48 | - A single callable that returns a non-aggregated boolean expression.
49 | The name of the rule is derived from the callable name, or defaults to
50 | "check" for lambdas.
51 | - A list of callables, where each callable returns a non-aggregated
52 | boolean expression. The name of the rule is derived from the callable
53 | name, or defaults to "check" for lambdas. Where multiple rules result
54 | in the same name, the suffix __i is appended to the name.
55 | - A dictionary mapping rule names to callables, where each callable
56 | returns a non-aggregated boolean expression.
57 | All rule names provided here are given the prefix "check_".
 58 |             alias: An override for this column's name which allows using a column
 59 |                 name that is not a valid Python identifier. Note that setting this
 60 |                 option does _not_ allow referring to the column by two different
 61 |                 names; the specified alias is the only valid name.
62 | metadata: A dictionary of metadata to attach to the column.
63 | """
64 | super().__init__(
65 | nullable=nullable,
66 | primary_key=primary_key,
67 | check=check,
68 | alias=alias,
69 | metadata=metadata,
70 | )
71 | self.inner = inner
72 |
73 | @property
74 | def dtype(self) -> pl.DataType:
75 | return pl.Struct({name: col.dtype for name, col in self.inner.items()})
76 |
77 | def validate_dtype(self, dtype: PolarsDataType) -> bool:
78 | if not isinstance(dtype, pl.Struct):
79 | return False
80 | if len(dtype.fields) != len(self.inner):
81 | return False
82 |
83 | fields = {field.name: field.dtype for field in dtype.fields}
84 | for name, col in self.inner.items():
85 | field_dtype = fields.get(name)
86 | if field_dtype is None:
87 | return False
88 | if not col.validate_dtype(field_dtype):
89 | return False
90 | return True
91 |
92 | def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
93 | inner_rules = {
94 | f"inner_{name}_{rule_name}": (
95 | pl.when(expr.is_null()).then(pl.lit(True)).otherwise(inner_expr)
96 | )
97 | for name, col in self.inner.items()
98 | for rule_name, inner_expr in col.validation_rules(
99 | expr.struct.field(name)
100 | ).items()
101 | }
102 | return {
103 | **super().validation_rules(expr),
104 | **inner_rules,
105 | }
106 |
107 | def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
108 | # NOTE: We might want to add support for PostgreSQL's JSON in the future.
109 | raise NotImplementedError("SQL column cannot have 'Struct' type.")
110 |
111 | @property
112 | def pyarrow_dtype(self) -> pa.DataType:
113 | return pa.struct({name: col.pyarrow_dtype for name, col in self.inner.items()})
114 |
115 | def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
116 | series = (
117 | pl.DataFrame(
118 | {name: col.sample(generator, n) for name, col in self.inner.items()}
119 | )
120 | .select(pl.struct(pl.all()))
121 | .to_series()
122 | )
123 | # Apply a null mask.
124 | return generator._apply_null_mask(
125 | series, null_probability=self._null_probability
126 | )
127 |
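Usage note: a brief sketch of dtype construction and inner-rule propagation (same `dc` import convention; the exact rule names depend on the rules the inner columns define):

    import polars as pl
    import dataframely.columns as dc

    col = dc.Struct({"a": dc.Int64()}, nullable=True)
    assert col.dtype == pl.Struct({"a": pl.Int64})

    # Inner rules are re-emitted as "inner_<field>_<rule>" and evaluate to
    # True for null structs (see `validation_rules` above).
    rules = col.validation_rules(pl.col("s"))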
--------------------------------------------------------------------------------
/dataframely/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import contextlib
5 | from types import TracebackType
6 | from typing import TypedDict, Unpack
7 |
8 |
9 | class Options(TypedDict):
10 | #: The maximum number of iterations to use for "fuzzy" sampling.
11 | max_sampling_iterations: int
12 |
13 |
14 | def default_options() -> Options:
15 | return {
16 | "max_sampling_iterations": 10_000,
17 | }
18 |
19 |
20 | class Config(contextlib.ContextDecorator):
21 | """An object to track global configuration for operations in dataframely."""
22 |
23 | #: The currently valid config options.
24 | options: Options = default_options()
25 | #: Singleton stack to track where to go back after exiting a context.
26 | _stack: list[Options] = []
27 |
28 | def __init__(self, **options: Unpack[Options]) -> None:
29 | self._local_options: Options = {**default_options(), **options}
30 |
31 | @staticmethod
32 | def set_max_sampling_iterations(iterations: int) -> None:
33 | """Set the maximum number of sampling iterations to use on
34 | :meth:`Schema.sample`."""
35 | Config.options["max_sampling_iterations"] = iterations
36 |
37 | @staticmethod
38 | def restore_defaults() -> None:
39 | """Restore the defaults of the configuration."""
40 | Config.options = default_options()
41 |
42 | # ------------------------------------ CONTEXT ----------------------------------- #
43 |
44 | def __enter__(self) -> None:
45 | Config._stack.append(Config.options)
46 | Config.options = self._local_options
47 |
48 | def __exit__(
49 | self,
50 | exc_type: type[BaseException] | None,
51 | exc_val: BaseException | None,
52 | exc_tb: TracebackType | None,
53 | ) -> None:
54 | Config.options = Config._stack.pop()
55 |
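Usage note: ``Config`` is a context manager and, via :class:`contextlib.ContextDecorator`, also usable as a decorator. A minimal sketch:

    from dataframely.config import Config

    with Config(max_sampling_iterations=100):
        # Code in this block sees the local options.
        assert Config.options["max_sampling_iterations"] == 100

    # On exit, the previous options are restored from the singleton stack
    # (here: the default of 10_000, assuming no other overrides were active).
    assert Config.options["max_sampling_iterations"] == 10_000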
--------------------------------------------------------------------------------
/dataframely/exc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections import defaultdict
5 |
6 | import polars as pl
7 |
8 | from ._polars import PolarsDataType
9 |
10 |
11 | class ValidationError(Exception):
12 | """Error raised when :mod:`dataframely` validation encounters an issue."""
13 |
14 | def __init__(self, message: str) -> None:
15 | super().__init__()
16 | self.message = message
17 |
18 | def __str__(self) -> str:
19 | return self.message
20 |
21 |
22 | class DtypeValidationError(ValidationError):
23 | """Validation error raised when column dtypes are wrong."""
24 |
25 | def __init__(
26 | self, errors: dict[str, tuple[PolarsDataType, PolarsDataType]]
27 | ) -> None:
28 | super().__init__(f"{len(errors)} columns have an invalid dtype")
29 | self.errors = errors
30 |
31 | def __str__(self) -> str:
32 | details = [
33 | f" - '{col}': got dtype '{actual}' but expected '{expected}'"
34 | for col, (actual, expected) in self.errors.items()
35 | ]
36 | return "\n".join([f"{self.message}:"] + details)
37 |
38 |
39 | class RuleValidationError(ValidationError):
40 | """Complex validation error raised when rule validation fails."""
41 |
42 | def __init__(self, errors: dict[str, int]) -> None:
43 | super().__init__(f"{len(errors)} rules failed validation")
44 |
45 | # Split into schema errors and column errors
46 | schema_errors: dict[str, int] = {}
47 | column_errors: dict[str, dict[str, int]] = defaultdict(dict)
48 | for name, count in sorted(errors.items()):
49 | if "|" in name:
50 | column, rule = name.split("|", maxsplit=1)
51 | column_errors[column][rule] = count
52 | else:
53 | schema_errors[name] = count
54 |
55 | self.schema_errors = schema_errors
56 | self.column_errors = column_errors
57 |
58 | def __str__(self) -> str:
59 | schema_details = [
60 | f" - '{name}' failed validation for {count:,} rows"
61 | for name, count in self.schema_errors.items()
62 | ]
63 | column_details = [
64 | msg
65 | for column, errors in self.column_errors.items()
66 | for msg in (
67 | [f" * Column '{column}' failed validation for {len(errors)} rules:"]
68 | + [
69 | f" - '{name}' failed for {count:,} rows"
70 | for name, count in errors.items()
71 | ]
72 | )
73 | ]
74 | return "\n".join([f"{self.message}:"] + schema_details + column_details)
75 |
76 |
77 | class MemberValidationError(ValidationError):
78 | """Validation error raised when multiple members of a collection fail validation."""
79 |
80 | def __init__(self, errors: dict[str, ValidationError]) -> None:
81 | super().__init__(f"{len(errors)} members failed validation")
82 | self.errors = errors
83 |
84 | def __str__(self) -> str:
85 | details = [
86 | f" > Member '{name}' failed validation:\n"
87 | + "\n".join(" " + line for line in str(error).split("\n"))
88 | for name, error in self.errors.items()
89 | ]
90 | return "\n".join([f"{self.message}:"] + details)
91 |
92 |
93 | class ImplementationError(Exception):
94 | """Error raised when a schema is implemented incorrectly."""
95 |
96 |
97 | class AnnotationImplementationError(ImplementationError):
98 | """Error raised when the annotations of a collection are invalid."""
99 |
100 | def __init__(self, attr: str, kls: type) -> None:
101 | message = (
102 | "Annotations of a 'dy.Collection' may only be an (optional) "
103 | f"'dy.LazyFrame', but \"{attr}\" has type '{kls}'."
104 | )
105 | super().__init__(message)
106 |
107 |
108 | class RuleImplementationError(ImplementationError):
109 | """Error raised when a rule is implemented incorrectly."""
110 |
111 | def __init__(
112 | self, name: str, return_dtype: pl.DataType, is_group_rule: bool
113 | ) -> None:
114 | if is_group_rule:
115 | details = (
116 | " When implementing a group rule (i.e. when using the `group_by` "
117 | "parameter), make sure to use an aggregation function such as `.any()`, "
118 | "`.all()`, and others to reduce an expression evaluated on multiple "
119 | "rows in the same group to a single boolean value for the group."
120 | )
121 | else:
122 | details = ""
123 |
124 | message = (
125 | f"Validation rule '{name}' has not been implemented correctly. It "
126 | f"returns dtype '{return_dtype}' but it must return a boolean value."
127 | + details
128 | )
129 | super().__init__(message)
130 |
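For illustration: rule names containing ``|`` are attributed to columns, all others count as schema-level rules. A hypothetical error with one of each renders as follows:

    from dataframely.exc import RuleValidationError

    err = RuleValidationError({"primary_key": 2, "age|check_min": 3})
    print(err)
    # 2 rules failed validation:
    #  - 'primary_key' failed validation for 2 rows
    #  * Column 'age' failed validation for 1 rules:
    #    - 'check_min' failed for 3 rows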
--------------------------------------------------------------------------------
/dataframely/functional.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Sequence
5 | from typing import TypeVar
6 |
7 | import polars as pl
8 |
9 | from ._base_collection import BaseCollection
10 | from ._typing import LazyFrame
11 | from .schema import Schema
12 |
13 | S = TypeVar("S", bound=Schema)
14 | T = TypeVar("T", bound=Schema)
15 |
16 | # NOTE: Binding to `BaseCollection` is required here as the TypeVar default for the
17 | # sampling type otherwise causes issues for Python 3.13.
18 | C = TypeVar("C", bound=BaseCollection)
19 |
20 | # ------------------------------------------------------------------------------------ #
21 | # FILTER #
22 | # ------------------------------------------------------------------------------------ #
23 |
24 | # --------------------------------- RELATIONSHIP 1:1 --------------------------------- #
25 |
26 |
27 | def filter_relationship_one_to_one(
28 | lhs: LazyFrame[S] | pl.LazyFrame,
29 | rhs: LazyFrame[T] | pl.LazyFrame,
30 | /,
31 | on: str | list[str],
32 | ) -> pl.LazyFrame:
33 | """Express a 1:1 mapping between data frames for a collection filter.
34 |
35 | Args:
36 | lhs: The first data frame in the 1:1 mapping.
37 | rhs: The second data frame in the 1:1 mapping.
 38 |         on: The columns to join the data frames on, typically the shared primary
 39 |             key columns of the two data frames.
40 | """
41 | return lhs.join(rhs, on=on)
42 |
43 |
44 | # ------------------------------- RELATIONSHIP 1:{1,N} ------------------------------- #
45 |
46 |
47 | def filter_relationship_one_to_at_least_one(
48 | lhs: LazyFrame[S] | pl.LazyFrame,
49 | rhs: LazyFrame[T] | pl.LazyFrame,
50 | /,
51 | on: str | list[str],
52 | ) -> pl.LazyFrame:
53 | """Express a 1:{1,N} mapping between data frames for a collection filter.
54 |
55 | Args:
56 | lhs: The data frame with exactly one occurrence for a set of key columns.
57 | rhs: The data frame with at least one occurrence for a set of key columns.
 58 |         on: The columns to join the data frames on, typically the shared primary
 59 |             key columns of the two data frames.
60 | """
61 | return lhs.join(rhs.unique(on), on=on)
62 |
63 |
64 | # ------------------------------------------------------------------------------------ #
65 | # CONCAT #
66 | # ------------------------------------------------------------------------------------ #
67 |
68 |
69 | def concat_collection_members(collections: Sequence[C], /) -> dict[str, pl.LazyFrame]:
70 | """Concatenate the members of collections with the same type.
71 |
72 | Args:
73 | collections: The collections whose members to concatenate. Optional members
74 | are concatenated only from the collections that provide them.
75 |
76 | Returns:
77 | A mapping from member names to a lazy concatenation of data frames. All keys
78 | are guaranteed to be valid members of the collection.
79 | """
80 | if len(collections) == 0:
 81 |         raise ValueError("Cannot concatenate an empty sequence of collections.")
82 | members = [c.to_dict() for c in collections]
83 | key_union = set(members[0]).union(*members[1:])
84 | return {
85 | key: pl.concat(
86 | [member_dict[key] for member_dict in members if key in member_dict]
87 | )
88 | for key in key_union
89 | }
90 |
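Usage note: a small sketch of the join-based filters on hypothetical lazy frames:

    import polars as pl
    from dataframely.functional import (
        filter_relationship_one_to_at_least_one,
        filter_relationship_one_to_one,
    )

    users = pl.LazyFrame({"id": [1, 2, 3]})
    profiles = pl.LazyFrame({"id": [1, 2], "bio": ["x", "y"]})
    logins = pl.LazyFrame({"id": [1, 1, 2]})

    # The inner join keeps only users 1 and 2, which have exactly one profile each.
    filter_relationship_one_to_one(users, profiles, on="id").collect()

    # The right side is de-duplicated first, so repeated logins still yield
    # one row per matching user.
    filter_relationship_one_to_at_least_one(users, logins, on="id").collect()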
--------------------------------------------------------------------------------
/dataframely/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Quantco/dataframely/1a8f64e031586416677954f12ed41276b4325945/dataframely/py.typed
--------------------------------------------------------------------------------
/dataframely/testing/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from .const import (
5 | ALL_COLUMN_TYPES,
6 | COLUMN_TYPES,
7 | FLOAT_COLUMN_TYPES,
8 | INTEGER_COLUMN_TYPES,
9 | NO_VALIDATION_COLUMN_TYPES,
10 | SUPERTYPE_COLUMN_TYPES,
11 | )
12 | from .factory import create_collection, create_collection_raw, create_schema
13 | from .mask import validation_mask
14 | from .rules import evaluate_rules, rules_from_exprs
15 |
16 | __all__ = [
17 | "ALL_COLUMN_TYPES",
18 | "COLUMN_TYPES",
19 | "FLOAT_COLUMN_TYPES",
20 | "INTEGER_COLUMN_TYPES",
21 | "SUPERTYPE_COLUMN_TYPES",
22 | "NO_VALIDATION_COLUMN_TYPES",
23 | "create_collection",
24 | "create_collection_raw",
25 | "create_schema",
26 | "validation_mask",
27 | "evaluate_rules",
28 | "rules_from_exprs",
29 | ]
30 |
--------------------------------------------------------------------------------
/dataframely/testing/const.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import dataframely.columns as dc
5 |
6 | COLUMN_TYPES: list[type[dc.Column]] = [
7 | dc.Bool,
8 | dc.Date,
9 | dc.Datetime,
10 | dc.Time,
11 | dc.Decimal,
12 | dc.Duration,
13 | dc.Float32,
14 | dc.Float64,
15 | dc.Int8,
16 | dc.Int16,
17 | dc.Int32,
18 | dc.Int64,
19 | dc.UInt8,
20 | dc.UInt16,
21 | dc.UInt32,
22 | dc.UInt64,
23 | dc.String,
24 | ]
25 | INTEGER_COLUMN_TYPES: list[type[dc.Column]] = [
26 | dc.Integer,
27 | dc.Int8,
28 | dc.Int16,
29 | dc.Int32,
30 | dc.Int64,
31 | dc.UInt8,
32 | dc.UInt16,
33 | dc.UInt32,
34 | dc.UInt64,
35 | ]
36 | FLOAT_COLUMN_TYPES: list[type[dc.Column]] = [
37 | dc.Float,
38 | dc.Float32,
39 | dc.Float64,
40 | ]
41 |
42 | SUPERTYPE_COLUMN_TYPES: list[type[dc.Column]] = [
43 | dc.Float,
44 | dc.Integer,
45 | ]
46 |
47 | ALL_COLUMN_TYPES: list[type[dc.Column]] = (
48 | [dc.Any] + COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES
49 | )
50 |
51 | # The following is a list of column types that, when created with default parameter values, add no validation rules.
52 | NO_VALIDATION_COLUMN_TYPES: list[type[dc.Column]] = [
53 | t for t in ALL_COLUMN_TYPES if t not in FLOAT_COLUMN_TYPES
54 | ]
55 |
--------------------------------------------------------------------------------
/dataframely/testing/factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Any
5 |
6 | from dataframely._filter import Filter
7 | from dataframely._rule import Rule
8 | from dataframely._typing import LazyFrame
9 | from dataframely.collection import Collection
10 | from dataframely.columns import Column
11 | from dataframely.schema import Schema
12 |
13 |
14 | def create_schema(
15 | name: str,
16 | columns: dict[str, Column],
17 | rules: dict[str, Rule] | None = None,
18 | ) -> type[Schema]:
19 | """Dynamically create a new schema with the provided name.
20 |
21 | Args:
22 | name: The name of the schema.
23 | columns: The columns to set on the schema. When properly defining the schema,
24 | this would be the annotations that define the column types.
25 | rules: The custom non-column-specific validation rules. When properly defining
26 | the schema, this would be the functions annotated with ``@dy.rule``.
27 |
28 | Returns:
29 | The dynamically created schema.
30 | """
31 | return type(name, (Schema,), {**columns, **(rules or {})})
32 |
33 |
34 | def create_collection(
35 | name: str,
36 | schemas: dict[str, type[Schema]],
37 | filters: dict[str, Filter] | None = None,
38 | *,
39 | annotation_base_class: type = LazyFrame,
40 | ) -> type[Collection]:
41 | return create_collection_raw(
42 | name,
43 | annotations={
44 | name: annotation_base_class[schema] # type: ignore
45 | for name, schema in schemas.items()
46 | },
47 | filters=filters,
48 | )
49 |
50 |
51 | def create_collection_raw(
52 | name: str,
53 | annotations: dict[str, Any],
54 | filters: dict[str, Filter] | None = None,
55 | ) -> type[Collection]:
56 | return type(
57 | name,
58 | (Collection,),
59 | {
60 | "__annotations__": annotations,
61 | **(filters or {}),
62 | },
63 | )
64 |
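Usage note: a sketch of building a throwaway schema, mirroring how the test suite uses the factory (the column parameters are illustrative):

    import dataframely.columns as dc
    from dataframely.testing import create_schema

    MySchema = create_schema(
        "MySchema",
        columns={"id": dc.Int64(primary_key=True), "name": dc.String(nullable=True)},
    )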
--------------------------------------------------------------------------------
/dataframely/testing/mask.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | from dataframely.failure import FailureInfo
7 |
8 |
9 | def validation_mask(df: pl.DataFrame | pl.LazyFrame, failure: FailureInfo) -> pl.Series:
10 | """Build a validation mask for the left data frame based on the failure info.
11 |
12 | Args:
13 | df: The data frame for whose rows to generate the validation mask.
14 | failure: The failure object whose information should be used to determine
15 | which rows of the input data frame are invalid.
16 |
17 | Returns:
 18 |         A series with the same length as the input data frame where a value of
 19 |         ``True`` indicates a valid row and ``False`` an invalid one.
20 |
21 | Raises:
 22 |         ValueError: If a column with a struct or nested list dtype is contained
 23 |             in the input. As of polars v1.1.0, neither of these works reliably.
24 | """
25 | if any(
26 | isinstance(dtype, pl.List) and isinstance(dtype.inner, pl.List)
27 | for dtype in df.collect_schema().dtypes()
28 | ): # pragma: no cover
29 | raise ValueError("`validation_mask` currently does not allow for nested lists.")
30 | if any(
31 | isinstance(dtype, pl.Struct) for dtype in df.collect_schema().dtypes()
32 | ): # pragma: no cover
33 | raise ValueError("`validation_mask` currently does not allow for structs.")
34 |
35 | return (
36 | df.lazy()
37 | .collect()
38 | .join(
39 | failure.invalid().unique().with_columns(__marker__=pl.lit(True)),
40 | on=list(df.collect_schema()),
41 | how="left",
42 | nulls_equal=True,
43 | )
44 | .select(pl.col("__marker__").is_null())
45 | .to_series()
46 | )
47 |
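Usage note: a sketch of the intended use, assuming the ``Schema.filter`` API exercised by ``tests/schema/test_filter.py``, which yields the valid rows together with a :class:`FailureInfo`:

    import polars as pl
    import dataframely as dy
    from dataframely.testing import validation_mask

    class S(dy.Schema):
        a = dy.Int64(nullable=False)

    df = pl.DataFrame({"a": [1, None]})
    good, failure = S.filter(df)         # assumed signature, see tests/schema/test_filter.py
    print(validation_mask(df, failure))  # [true, false]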
--------------------------------------------------------------------------------
/dataframely/testing/rules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | from dataframely._rule import Rule, with_evaluation_rules
7 |
8 |
9 | def rules_from_exprs(exprs: dict[str, pl.Expr]) -> dict[str, Rule]:
10 | """Turn a set of expressions into simple rules.
11 |
12 | Args:
13 | exprs: The expressions, mapping from names to :class:`polars.Expr`.
14 |
15 | Returns:
16 | The rules corresponding to the expressions.
17 | """
18 | return {name: Rule(expr) for name, expr in exprs.items()}
19 |
20 |
21 | def evaluate_rules(lf: pl.LazyFrame, rules: dict[str, Rule]) -> pl.LazyFrame:
22 | """Evaluate the provided rules and return the rules' evaluation.
23 |
24 | Args:
25 | lf: The data frame on which to evaluate the rules.
26 | rules: The rules to evaluate where the key of the dictionary provides the name
27 | of the rule.
28 |
29 | Returns:
 30 |         The same return value as :meth:`with_evaluation_rules`, except that the
 31 |         columns of the input data frame are dropped.
32 | """
33 | return with_evaluation_rules(lf, rules).drop(lf.collect_schema())
34 |
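For illustration, a minimal sketch:

    import polars as pl
    from dataframely.testing import evaluate_rules, rules_from_exprs

    lf = pl.LazyFrame({"a": [1, -2, 3]})
    rules = rules_from_exprs({"positive": pl.col("a") > 0})

    # Yields the boolean rule column "positive"; the input column "a" is dropped.
    print(evaluate_rules(lf, rules).collect())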
--------------------------------------------------------------------------------
/dataframely/testing/typing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from dataframely import (
5 | Any,
6 | Date,
7 | Datetime,
8 | Decimal,
9 | Enum,
10 | Float32,
11 | Int64,
12 | List,
13 | Schema,
14 | Struct,
15 | )
16 |
17 |
18 | class MyImportedBaseSchema(Schema):
19 | a = Int64()
20 |
21 |
22 | class MyImportedSchema(MyImportedBaseSchema):
23 | b = Float32()
24 | c = Enum(["a", "b", "c"])
25 | d = Struct({"a": Int64(), "b": Struct({"c": Enum(["a", "b"])})})
26 | e = List(Struct({"a": Int64()}))
27 | f = Datetime()
28 | g = Date()
29 | h = Any()
30 | some_decimal = Decimal(12, 8)
31 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 | mssql:
4 | image: mcr.microsoft.com/azure-sql-edge:latest
5 | environment:
  6 |       ACCEPT_EULA: "Y"
7 | MSSQL_USER: sa
8 | SA_PASSWORD: P@ssword1
9 | ports:
10 | - "1455:1433"
11 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.collection.rst:
--------------------------------------------------------------------------------
1 | dataframely.collection module
2 | =============================
3 |
4 | .. automodule:: dataframely.collection
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.any.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.any module
2 | ==============================
3 |
4 | .. automodule:: dataframely.columns.any
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.bool.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.bool module
2 | ===============================
3 |
4 | .. automodule:: dataframely.columns.bool
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.datetime.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.datetime module
2 | ===================================
3 |
4 | .. automodule:: dataframely.columns.datetime
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.decimal.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.decimal module
2 | ==================================
3 |
4 | .. automodule:: dataframely.columns.decimal
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.enum.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.enum module
2 | ===============================
3 |
4 | .. automodule:: dataframely.columns.enum
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.float.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.float module
2 | ================================
3 |
4 | .. automodule:: dataframely.columns.float
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.integer.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.integer module
2 | ==================================
3 |
4 | .. automodule:: dataframely.columns.integer
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.list.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.list module
2 | ===============================
3 |
4 | .. automodule:: dataframely.columns.list
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns package
2 | ===========================
3 |
4 | .. automodule:: dataframely.columns
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
9 | Submodules
10 | ----------
11 |
12 | dataframely.columns.any module
13 | ------------------------------
14 |
15 | .. automodule:: dataframely.columns.any
16 | :members:
17 | :show-inheritance:
18 | :undoc-members:
19 |
20 | dataframely.columns.array module
21 | --------------------------------
22 |
23 | .. automodule:: dataframely.columns.array
24 | :members:
25 | :show-inheritance:
26 | :undoc-members:
27 |
28 | dataframely.columns.bool module
29 | -------------------------------
30 |
31 | .. automodule:: dataframely.columns.bool
32 | :members:
33 | :show-inheritance:
34 | :undoc-members:
35 |
36 | dataframely.columns.datetime module
37 | -----------------------------------
38 |
39 | .. automodule:: dataframely.columns.datetime
40 | :members:
41 | :show-inheritance:
42 | :undoc-members:
43 |
44 | dataframely.columns.decimal module
45 | ----------------------------------
46 |
47 | .. automodule:: dataframely.columns.decimal
48 | :members:
49 | :show-inheritance:
50 | :undoc-members:
51 |
52 | dataframely.columns.enum module
53 | -------------------------------
54 |
55 | .. automodule:: dataframely.columns.enum
56 | :members:
57 | :show-inheritance:
58 | :undoc-members:
59 |
60 | dataframely.columns.float module
61 | --------------------------------
62 |
63 | .. automodule:: dataframely.columns.float
64 | :members:
65 | :show-inheritance:
66 | :undoc-members:
67 |
68 | dataframely.columns.integer module
69 | ----------------------------------
70 |
71 | .. automodule:: dataframely.columns.integer
72 | :members:
73 | :show-inheritance:
74 | :undoc-members:
75 |
76 | dataframely.columns.list module
77 | -------------------------------
78 |
79 | .. automodule:: dataframely.columns.list
80 | :members:
81 | :show-inheritance:
82 | :undoc-members:
83 |
84 | dataframely.columns.object module
85 | ---------------------------------
86 |
87 | .. automodule:: dataframely.columns.object
88 | :members:
89 | :show-inheritance:
90 | :undoc-members:
91 |
92 | dataframely.columns.string module
93 | ---------------------------------
94 |
95 | .. automodule:: dataframely.columns.string
96 | :members:
97 | :show-inheritance:
98 | :undoc-members:
99 |
100 | dataframely.columns.struct module
101 | ---------------------------------
102 |
103 | .. automodule:: dataframely.columns.struct
104 | :members:
105 | :show-inheritance:
106 | :undoc-members:
107 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.string.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.string module
2 | =================================
3 |
4 | .. automodule:: dataframely.columns.string
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.columns.struct.rst:
--------------------------------------------------------------------------------
1 | dataframely.columns.struct module
2 | =================================
3 |
4 | .. automodule:: dataframely.columns.struct
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.config.rst:
--------------------------------------------------------------------------------
1 | dataframely.config module
2 | =========================
3 |
4 | .. automodule:: dataframely.config
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.exc.rst:
--------------------------------------------------------------------------------
1 | dataframely.exc module
2 | ======================
3 |
4 | .. automodule:: dataframely.exc
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.failure.rst:
--------------------------------------------------------------------------------
1 | dataframely.failure module
2 | ==========================
3 |
4 | .. automodule:: dataframely.failure
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.functional.rst:
--------------------------------------------------------------------------------
1 | dataframely.functional module
2 | =============================
3 |
4 | .. automodule:: dataframely.functional
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.mypy.rst:
--------------------------------------------------------------------------------
1 | dataframely.mypy module
2 | =======================
3 |
4 | .. automodule:: dataframely.mypy
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.random.rst:
--------------------------------------------------------------------------------
1 | dataframely.random module
2 | =========================
3 |
4 | .. automodule:: dataframely.random
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.rst:
--------------------------------------------------------------------------------
1 | dataframely package
2 | ===================
3 |
4 | .. automodule:: dataframely
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
9 | Subpackages
10 | -----------
11 |
12 | .. toctree::
13 | :maxdepth: 4
14 |
15 | dataframely.columns
16 | dataframely.testing
17 |
18 | Submodules
19 | ----------
20 |
21 | dataframely.collection module
22 | -----------------------------
23 |
24 | .. automodule:: dataframely.collection
25 | :members:
26 | :show-inheritance:
27 | :undoc-members:
28 |
29 | dataframely.config module
30 | -------------------------
31 |
32 | .. automodule:: dataframely.config
33 | :members:
34 | :show-inheritance:
35 | :undoc-members:
36 |
37 | dataframely.exc module
38 | ----------------------
39 |
40 | .. automodule:: dataframely.exc
41 | :members:
42 | :show-inheritance:
43 | :undoc-members:
44 |
45 | dataframely.failure module
46 | --------------------------
47 |
48 | .. automodule:: dataframely.failure
49 | :members:
50 | :show-inheritance:
51 | :undoc-members:
52 |
53 | dataframely.functional module
54 | -----------------------------
55 |
56 | .. automodule:: dataframely.functional
57 | :members:
58 | :show-inheritance:
59 | :undoc-members:
60 |
61 | dataframely.mypy module
62 | -----------------------
63 |
64 | .. automodule:: dataframely.mypy
65 | :members:
66 | :show-inheritance:
67 | :undoc-members:
68 |
69 | dataframely.random module
70 | -------------------------
71 |
72 | .. automodule:: dataframely.random
73 | :members:
74 | :show-inheritance:
75 | :undoc-members:
76 |
77 | dataframely.schema module
78 | -------------------------
79 |
80 | .. automodule:: dataframely.schema
81 | :members:
82 | :show-inheritance:
83 | :undoc-members:
84 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.schema.rst:
--------------------------------------------------------------------------------
1 | dataframely.schema module
2 | =========================
3 |
4 | .. automodule:: dataframely.schema
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.const.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing.const module
2 | ================================
3 |
4 | .. automodule:: dataframely.testing.const
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.factory.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing.factory module
2 | ==================================
3 |
4 | .. automodule:: dataframely.testing.factory
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.mask.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing.mask module
2 | ===============================
3 |
4 | .. automodule:: dataframely.testing.mask
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing package
2 | ===========================
3 |
4 | .. automodule:: dataframely.testing
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
9 | Submodules
10 | ----------
11 |
12 | dataframely.testing.const module
13 | --------------------------------
14 |
15 | .. automodule:: dataframely.testing.const
16 | :members:
17 | :show-inheritance:
18 | :undoc-members:
19 |
20 | dataframely.testing.factory module
21 | ----------------------------------
22 |
23 | .. automodule:: dataframely.testing.factory
24 | :members:
25 | :show-inheritance:
26 | :undoc-members:
27 |
28 | dataframely.testing.mask module
29 | -------------------------------
30 |
31 | .. automodule:: dataframely.testing.mask
32 | :members:
33 | :show-inheritance:
34 | :undoc-members:
35 |
36 | dataframely.testing.rules module
37 | --------------------------------
38 |
39 | .. automodule:: dataframely.testing.rules
40 | :members:
41 | :show-inheritance:
42 | :undoc-members:
43 |
44 | dataframely.testing.typing module
45 | ---------------------------------
46 |
47 | .. automodule:: dataframely.testing.typing
48 | :members:
49 | :show-inheritance:
50 | :undoc-members:
51 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.rules.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing.rules module
2 | ================================
3 |
4 | .. automodule:: dataframely.testing.rules
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/dataframely.testing.typing.rst:
--------------------------------------------------------------------------------
1 | dataframely.testing.typing module
2 | =================================
3 |
4 | .. automodule:: dataframely.testing.typing
5 | :members:
6 | :show-inheritance:
7 | :undoc-members:
8 |
--------------------------------------------------------------------------------
/docs/_api/modules.rst:
--------------------------------------------------------------------------------
1 | dataframely
2 | ===========
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | dataframely
8 |
--------------------------------------------------------------------------------
/docs/_static/custom.css:
--------------------------------------------------------------------------------
1 | #furo-main-content .caption-text {
2 | display: none;
3 | }
4 |
--------------------------------------------------------------------------------
/docs/_static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Quantco/dataframely/1a8f64e031586416677954f12ed41276b4325945/docs/_static/favicon.ico
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | # Configuration file for the Sphinx documentation builder.
5 | #
6 | # This file only contains a selection of the most common options. For a full
7 | # list see the documentation:
8 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
9 |
10 | # -- Path setup --------------------------------------------------------------
11 |
12 | # If extensions (or modules to document with autodoc) are in another directory,
13 | # add these directories to sys.path here. If the directory is relative to the
14 | # documentation root, use os.path.abspath to make it absolute, like shown here.
15 | #
16 | # import os
17 | # import sys
18 | # sys.path.insert(0, os.path.abspath('.'))
19 |
20 |
21 | # -- Project information -----------------------------------------------------
22 |
23 | import datetime
24 | import importlib
25 | import inspect
26 | import os
27 | import subprocess
28 | import sys
29 | from subprocess import CalledProcessError
30 | from typing import cast
31 |
32 | _mod = importlib.import_module("dataframely")
33 |
34 |
35 | project = "dataframely"
36 | copyright = f"{datetime.date.today().year}, QuantCo, Inc"
37 | author = "QuantCo, Inc."
38 |
39 | extensions = [
40 | "nbsphinx",
41 | "numpydoc",
42 | "sphinx_copybutton",
43 | "sphinx.ext.autodoc",
44 | "sphinx.ext.linkcode",
45 | "sphinxcontrib.apidoc",
46 | ]
47 |
48 | numpydoc_class_members_toctree = False
49 |
50 | apidoc_module_dir = "../dataframely"
51 | apidoc_output_dir = "_api"
52 | apidoc_module_first = True
53 | apidoc_extra_args = ["--implicit-namespaces"]
54 |
55 | autodoc_default_options = {
56 | "inherited-members": True,
57 | }
58 |
59 | templates_path = ["_templates"]
60 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
61 | html_theme = "furo"
62 | html_title = "Dataframely"
63 | html_static_path = ["_static"]
64 | html_css_files = ["custom.css"]
65 | html_favicon = "_static/favicon.ico"
66 |
67 |
68 | # Copied and adapted from
69 | # https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659
70 | # Required configuration function to use sphinx.ext.linkcode
71 | def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
72 | """Determine the URL corresponding to a given Python object."""
73 | if domain != "py":
74 | return None
75 |
76 | module_name = info["module"]
77 | full_name = info["fullname"]
78 |
79 | _submodule = sys.modules.get(module_name)
80 | if _submodule is None:
81 | return None
82 |
83 | _object = _submodule
84 | for _part in full_name.split("."):
85 | try:
86 | _object = getattr(_object, _part)
87 | except AttributeError:
88 | return None
89 |
90 | try:
91 | fn = inspect.getsourcefile(inspect.unwrap(_object)) # type: ignore
92 | except TypeError:
93 | fn = None
94 | if not fn:
95 | return None
96 |
97 | try:
98 | source, line_number = inspect.getsourcelines(_object)
99 | except OSError:
100 | line_number = None
101 |
102 | if line_number:
103 | linespec = f"#L{line_number}-L{line_number + len(source) - 1}"
104 | else:
105 | linespec = ""
106 |
107 | fn = os.path.relpath(fn, start=os.path.dirname(cast(str, _mod.__file__)))
108 |
109 | try:
110 | # See https://stackoverflow.com/a/21901260
111 | commit = (
112 | subprocess.check_output(["git", "rev-parse", "HEAD"])
113 | .decode("ascii")
114 | .strip()
115 | )
116 | except CalledProcessError:
117 | # If subprocess returns non-zero exit status
118 | commit = "main"
119 |
120 | return (
121 | "https://github.com/quantco/dataframely"
122 | f"/blob/{commit}/{_mod.__name__.replace('.', '/')}/{fn}{linespec}"
123 | )
124 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Dataframely
2 | ============
3 |
  4 | Dataframely is a Python package to validate the schema and content of `polars <https://pola.rs/>`_ data frames.
5 | Its purpose is to make data pipelines more robust by ensuring that data meet expectations and more readable by adding schema information to data frame type hints.
6 |
7 | Features
8 | --------
9 |
10 | - Declaratively define schemas as classes with arbitrary inheritance structure (see the sketch below)
11 | - Specify column-specific validation rules (e.g. nullability, minimum string length, ...)
12 | - Specify cross-column and group validation rules with built-in support for checking the primary key property of a column set
13 | - Specify validation constraints across collections of interdependent data frames
14 | - Validate data frames softly by filtering out rows that violate rules instead of failing hard
15 | - Introspect validation failure information for run-time failures
16 | - Enhanced type hints for validated data frames, allowing users to clearly express expectations about inputs and outputs (i.e., contracts) in data pipelines
17 | - Integrate schemas with external tools (e.g., ``sqlalchemy`` or ``pyarrow``)
18 | - Generate test data that comply with a schema or collection of schemas and their validation rules
19 |
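A minimal sketch of the first few features, using only names that are part of the public API exercised in this repository's test suite (``dy.Schema``, ``dy.Integer``, ``dy.String``, ``Schema.filter``, ``Schema.is_valid``); the exact failure counts are indicative:

.. code-block:: python

   import polars as pl

   import dataframely as dy


   class UserSchema(dy.Schema):
       user_id = dy.Integer(primary_key=True)
       name = dy.String(nullable=False, min_length=1)


   df = pl.DataFrame({"user_id": [1, 1, 2], "name": ["ann", "bob", None]})
   good, failures = UserSchema.filter(df)  # soft: rows violating rules are dropped
   assert not UserSchema.is_valid(df)  # strict: the frame as a whole is invalid
   print(failures.counts())  # e.g. {"primary_key": 2, "name|nullability": 1}
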
20 | Contents
21 | ========
22 |
23 | .. toctree::
24 | :caption: Contents
25 | :maxdepth: 1
26 |
27 | Installation <sites/installation>
28 | Quickstart <sites/quickstart>
29 | Real-world Example <sites/examples/real-world>
30 | FAQ <sites/faq>
31 | Development Guide <sites/development>
32 | Versioning <sites/versioning>
33 |
34 | API Documentation
35 | =================
36 |
37 | .. toctree::
38 | :caption: API Documentation
39 | :maxdepth: 1
40 |
41 | Collection <_api/dataframely.collection>
42 | Column Types <_api/dataframely.columns>
43 | Config <_api/dataframely.config>
44 | Random Data Generation <_api/dataframely.random>
45 | Failure Information <_api/dataframely.failure>
46 | Schema <_api/dataframely.schema>
47 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/sites/development.rst:
--------------------------------------------------------------------------------
1 | Development
2 | ===========
3 |
4 |
5 | Thanks for deciding to work on ``dataframely``!
6 | You can create a development environment with the following steps:
7 |
8 | Environment Installation
9 | ------------------------
10 |
11 | .. code-block:: bash
12 |
13 | git clone https://github.com/Quantco/dataframely
14 | cd dataframely
15 | pixi install
16 |
17 | Next, make sure to install the package locally and set up the pre-commit hooks:
18 |
19 | .. code-block:: bash
20 |
21 | pixi run postinstall
22 | pixi run pre-commit-install
23 |
24 |
25 | Running the tests
26 | -----------------
27 |
28 | .. code-block:: bash
29 |
30 | pixi run test
31 |
32 |
33 | To run a subset of the tests, you can pass a specific directory or module under ``tests/`` to the command; assuming your ``pixi`` version forwards extra task arguments to ``pytest``, for example:
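
.. code-block:: bash

   # Run only the collection tests; the trailing path is forwarded to pytest.
   pixi run test tests/collection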
34 |
35 |
36 | Building the Documentation
37 | --------------------------
38 |
39 | When updating the documentation, you can compile a local build of the
40 | documentation and then open it in your web browser using the commands below:
41 |
42 | .. code-block:: bash
43 |
44 | # Run build
45 | pixi run -e docs postinstall
46 | pixi run docs
47 |
48 | # Open documentation
49 | open docs/_build/html/index.html
50 |
--------------------------------------------------------------------------------
/docs/sites/faq.rst:
--------------------------------------------------------------------------------
1 | FAQ
2 | ===
3 |
4 | Whenever you come across something that surprised you or required some non-trivial
5 | thinking, please add it here.
6 |
--------------------------------------------------------------------------------
/docs/sites/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | To install ``dataframely``, use your favorite package manager, e.g., ``pixi`` or ``pip``:
5 |
6 | .. code:: bash
7 |
8 | pixi add dataframely
9 | pip install dataframely
10 |
--------------------------------------------------------------------------------
/docs/sites/versioning.rst:
--------------------------------------------------------------------------------
1 | Versioning policy and breaking changes
2 | ======================================
3 |
4 | Dataframely uses `semantic versioning <https://semver.org/>`_.
5 | This versioning scheme is designed to make it easy for users to anticipate what kinds of changes to expect from a given version update of their dependencies.
6 | We generally recommend that users take measures to control dependency versions. We like to use ``pixi`` as a package manager, which comes with built-in
7 | support for lockfiles. Many other package managers offer similar functionality. When updating lockfiles, we recommend using automated testing
8 | to ensure that user code still works with newer versions of dependencies such as ``dataframely``.
9 |
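For instance, a downstream project using ``pixi`` could pin ``dataframely`` to a compatible release range in its manifest (a hypothetical snippet; the dependency syntax matches the ``pixi.toml`` in this repository):

.. code-block:: toml

   [dependencies]
   # Pick up fixes and minor features automatically, but hold back the next major release.
   dataframely = ">=1.0,<2"
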
10 | Most importantly, semantic versioning implies that breaking changes to user-facing functionality are only introduced in **major releases**.
11 | We therefore recommend that users be particularly vigilant when updating their environments to a newer major release of ``dataframely``.
12 | As always, automated testing is useful here, but we also recommend checking the release notes `published on GitHub <https://github.com/Quantco/dataframely/releases>`_.
13 |
14 | In order to give users a heads-up before breaking changes are released, we introduce `FutureWarnings <https://docs.python.org/3/library/exceptions.html#FutureWarning>`_.
15 | Warnings are the most effective tool at our disposal for reaching users directly.
16 | We therefore generally recommend that users do not silence such warnings explicitly, but instead migrate their code proactively whenever possible.
17 | However, we also understand that the need for migration may catch users at an inconvenient time, and a temporary band-aid solution might be required.
18 | Users can disable ``FutureWarnings`` either through `Python builtins <https://docs.python.org/3/library/warnings.html>`_,
19 | builtins from tools like `pytest <https://docs.pytest.org/en/stable/how-to/capture-warnings.html>`_,
20 | or by setting the ``DATAFRAMELY_NO_FUTURE_WARNINGS`` environment variable to ``true`` or ``1``.
21 |
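As a sketch, the three escape hatches look as follows; the ``pytest`` filter mirrors the ``filterwarnings`` entry in this repository's own ``pyproject.toml``:

.. code-block:: python

   import warnings

   # Option 1: Python builtins -- silence FutureWarnings for the current process.
   warnings.filterwarnings("ignore", category=FutureWarning)

   # Option 2: pytest -- declare a filter in pyproject.toml:
   #   [tool.pytest.ini_options]
   #   filterwarnings = ["ignore::FutureWarning"]

   # Option 3: dataframely-specific, via the environment:
   #   export DATAFRAMELY_NO_FUTURE_WARNINGS=true
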
--------------------------------------------------------------------------------
/pixi.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | channels = ["conda-forge"]
3 | description = "A declarative, polars-native data frame validation library."
4 | name = "dataframely"
5 | platforms = ["linux-64", "linux-aarch64", "osx-64", "osx-arm64", "win-64"]
6 |
7 | [tasks]
8 | postinstall = "pip install --no-build-isolation --no-deps --disable-pip-version-check -e ."
9 |
10 | [dependencies]
11 | python = ">=3.11"
12 | rust = "=1.85"
13 |
14 | numpy = "*"
15 | polars = ">=1.12"
16 |
17 | [host-dependencies]
18 | maturin = ">=1.7,<2"
19 | pip = "*"
20 |
21 | [feature.dev.dependencies]
22 | jupyter = "*"
23 | pandas = "*"
24 | scikit-learn = "*"
25 |
26 | [feature.docs.dependencies]
27 | furo = "*"
28 | ipython = "*"
29 | make = "*"
30 | # Needed for generating docs for dataframely.mypy
31 | mypy = "*"
32 | nbsphinx = "*"
33 | numpydoc = "*"
34 | sphinx = "*"
35 | sphinx-copybutton = "*"
36 | sphinx_rtd_theme = "*"
37 | sphinxcontrib-apidoc = "*"
38 |
39 | [feature.docs.tasks]
40 | docs = "cd docs && make html"
41 | readthedocs = "rm -rf $READTHEDOCS_OUTPUT/html && cp -r docs/_build/html $READTHEDOCS_OUTPUT/html"
42 |
43 | [feature.test.dependencies]
44 | mypy = ">=1.13"
45 | pyarrow = "*"
46 | pyodbc = "*"
47 | pytest = ">=6"
48 | pytest-cov = "*"
49 | pytest-md = "*"
50 | sqlalchemy = ">=2"
51 |
52 | [feature.test.tasks]
53 | test = "pytest"
54 | test-coverage = "pytest --cov=dataframely --cov-report=xml --cov-report=term-missing"
55 |
56 | [feature.build.dependencies]
57 | python-build = "*"
58 | setuptools-scm = "*"
59 | twine = "*"
60 | wheel = "*"
61 | [feature.build.target.unix.dependencies]
62 | sed = "*"
63 |
64 | [feature.build.tasks]
65 | build-sdist = "python -m build --sdist --no-isolation ."
66 | build-wheel = "python -m build --wheel --no-isolation ."
67 | check-wheel = "twine check dist/*"
68 | set-version = "sed -i \"s/0.0.0/$(python -m setuptools_scm)/\" pyproject.toml"
69 |
70 | [feature.lint.dependencies]
71 | docformatter = "*"
72 | insert-license-header = "*"
73 | pre-commit = "*"
74 | pre-commit-hooks = "*"
75 | prettier = "*"
76 | ruff = "*"
77 | taplo = "*"
78 | typos = "*"
79 | [feature.lint.tasks]
80 | pre-commit-install = "pre-commit install"
81 | pre-commit-run = "pre-commit run -a"
82 |
83 | [feature.py311.dependencies]
84 | python = "3.11.*"
85 |
86 | [feature.py312.dependencies]
87 | python = "3.12.*"
88 |
89 | [feature.py313.dependencies]
90 | python = "3.13.*"
91 |
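# Each environment combines the feature sets defined above; `no-default-feature` keeps the lint environment minimal.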
92 | [environments]
93 | build = ["build"]
94 | default = ["dev", "lint", "test"]
95 | docs = ["docs"]
96 | lint = { features = ["lint"], no-default-feature = true }
97 | py311 = ["py311", "test"]
98 | py312 = ["py312", "test"]
99 | py313 = ["py313", "test"]
100 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "maturin"
3 | requires = ["maturin>=1.7,<2.0"]
4 |
5 | [project]
6 | authors = [
7 | { name = "Andreas Albert", email = "andreas.albert@quantco.com" },
8 | { name = "Daniel Elsner", email = "daniel.elsner@quantco.com" },
9 | { name = "Oliver Borchert", email = "oliver.borchert@quantco.com" },
10 | ]
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "Programming Language :: Python :: 3.11",
14 | "Programming Language :: Python :: 3.12",
15 | "Programming Language :: Python :: 3.13",
16 | ]
17 | dependencies = ["numpy", "polars>=1.12"]
18 | description = "A declarative, polars-native data frame validation library"
19 | name = "dataframely"
20 | readme = "README.md"
21 | requires-python = ">=3.11"
22 | version = "0.0.0"
23 |
24 | [project.urls]
25 | Documentation = "https://dataframely.readthedocs.io/"
26 | Repository = "https://github.com/quantco/dataframely"
27 |
28 | [tool.maturin]
29 | module-name = "dataframely._extre"
30 | profile = "release"
31 |
32 | [tool.setuptools.packages.find]
33 | include = ["dataframely"]
34 | namespaces = false
35 |
36 | [project.scripts]
37 |
38 | [tool.docformatter]
39 | black = true # only sets the style options to the default values of black
40 |
41 | [tool.ruff]
42 | line-length = 88
43 |
44 | [tool.ruff.lint]
45 | ignore = [
46 | "E501", # https://docs.astral.sh/ruff/faq/#is-the-ruff-linter-compatible-with-black
47 | "N803", # https://docs.astral.sh/ruff/rules/invalid-argument-name
48 | "N806", # https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
49 | ]
50 | select = [
51 | # pyflakes
52 | "F",
53 | # pycodestyle
54 | "E",
55 | "W",
56 | # isort
57 | "I",
58 | # pep8-naming
59 | "N",
60 | # pyupgrade
61 | "UP",
62 | ]
63 |
64 | [tool.ruff.format]
65 | indent-style = "space"
66 | quote-style = "double"
67 |
68 | [tool.mypy]
69 | check_untyped_defs = true
70 | disallow_untyped_defs = true
71 | exclude = ["docs/"]
72 | explicit_package_bases = true
73 | no_implicit_optional = true
74 | plugins = ["dataframely.mypy"]
75 | python_version = '3.11'
76 | warn_unused_ignores = true
77 |
78 | [[tool.mypy.overrides]]
79 | ignore_missing_imports = true
80 | module = ["pyarrow.*"]
81 |
82 | [tool.pytest.ini_options]
83 | addopts = "--import-mode=importlib"
84 | filterwarnings = [
85 | # Almost all tests are oblivious to the value of `nullable`. Let's ignore the warning as long as it exists.
86 | "ignore:The 'nullable' argument was not explicitly set:FutureWarning",
87 | ]
88 | testpaths = ["tests"]
89 |
90 | [tool.coverage.run]
91 | omit = ["dataframely/mypy.py", "dataframely/testing/generate.py"]
92 |
--------------------------------------------------------------------------------
/src/errdefs.rs:
--------------------------------------------------------------------------------
1 | use pyo3::exceptions;
2 | use pyo3::PyErr;
3 |
4 | pub type Result<T> = std::result::Result<T, Error>;
5 |
6 | #[derive(thiserror::Error, Debug)]
7 | pub enum Error {
8 | #[error("failed to parse regex: {0}")]
9 | Parsing(Box<regex_syntax::Error>),
10 | #[error("failed to interpret bytes as UTF-8: {0}")]
11 | Utf8(#[from] std::str::Utf8Error),
12 | }
13 |
14 | impl From<regex_syntax::Error> for Error {
15 | fn from(value: regex_syntax::Error) -> Self {
16 | Self::Parsing(Box::new(value))
17 | }
18 | }
19 |
20 | impl From<Error> for PyErr {
21 | fn from(value: Error) -> Self {
22 | exceptions::PyValueError::new_err(value.to_string())
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | mod errdefs;
2 | mod regex_repr;
3 |
4 | use pyo3::prelude::*;
5 | use rand::prelude::*;
6 | use rand::rngs::StdRng;
7 | use regex_repr::Regex;
8 |
9 | #[derive(IntoPyObject)]
10 | enum SampleResult {
11 | #[pyo3(transparent)]
12 | One(String),
13 | #[pyo3(transparent)]
14 | Many(Vec<String>),
15 | }
16 |
17 | /// Obtain the minimum and maximum length (if available) of strings matching a regular expression.
18 | #[pyfunction]
19 | fn matching_string_length(regex: &str) -> PyResult<(usize, Option<usize>)> {
20 | let compiled = Regex::new(regex)?;
21 | let result = compiled.matching_string_length()?;
22 | Ok(result)
23 | }
24 |
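/// Sample a random string matching `regex`, or a list of `n` such strings if `n` is given; a `seed` makes the draw reproducible.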
25 | #[pyfunction]
26 | #[pyo3(signature = (regex, n = None, max_repetitions = 16, seed = None))]
27 | fn sample(
28 | regex: &str,
29 | n: Option<usize>,
30 | max_repetitions: u32,
31 | seed: Option<u64>,
32 | ) -> PyResult<SampleResult> {
33 | let compiled = Regex::new(regex)?;
34 | let mut rng = match seed {
35 | None => StdRng::from_os_rng(),
36 | Some(seed) => StdRng::seed_from_u64(seed),
37 | };
38 | let result = match n {
39 | None => {
40 | let result = compiled.sample(&mut rng, max_repetitions)?;
41 | SampleResult::One(result)
42 | }
43 | Some(n) => {
44 | let results = (0..n)
45 | .map(|_| compiled.sample(&mut rng, max_repetitions))
46 | .collect::<Result<Vec<_>, _>>()?;
47 | SampleResult::Many(results)
48 | }
49 | };
50 | Ok(result)
51 | }
52 |
53 | #[pymodule]
54 | #[pyo3(name = "_extre")]
55 | fn extre(m: &Bound<'_, PyModule>) -> PyResult<()> {
56 | m.add_function(wrap_pyfunction!(matching_string_length, m)?)?;
57 | m.add_function(wrap_pyfunction!(sample, m)?)?;
58 | Ok(())
59 | }
60 |
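For illustration, these bindings surface in Python roughly as follows (a sketch: the module path follows ``module-name = "dataframely._extre"`` in ``pyproject.toml``, the signatures follow the ``#[pyo3]`` attributes above, and the commented values are indicative only):

from dataframely import _extre

# Length bounds for strings matching the pattern; the maximum is None when unbounded.
print(_extre.matching_string_length(r"[a-z]{2,4}"))  # e.g. (2, 4)

# A single sample, or a list of `n` samples; `seed` makes the draw reproducible.
one = _extre.sample(r"[0-9][a-z]$", seed=42)
many = _extre.sample(r"[0-9][a-z]$", n=3, seed=42)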
--------------------------------------------------------------------------------
/tests/collection/test_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from collections.abc import Callable
5 | from pathlib import Path
6 |
7 | import polars as pl
8 | import pytest
9 | from polars.testing import assert_frame_equal
10 |
11 | import dataframely as dy
12 |
13 |
14 | class MyFirstSchema(dy.Schema):
15 | a = dy.UInt8(primary_key=True)
16 |
17 |
18 | class MySecondSchema(dy.Schema):
19 | a = dy.UInt16(primary_key=True)
20 | b = dy.Integer()
21 |
22 |
23 | class MyCollection(dy.Collection):
24 | first: dy.LazyFrame[MyFirstSchema]
25 | second: dy.LazyFrame[MySecondSchema] | None
26 |
27 |
28 | def test_common_primary_keys() -> None:
29 | assert MyCollection.common_primary_keys() == ["a"]
30 |
31 |
32 | def test_members() -> None:
33 | members = MyCollection.members()
34 | assert not members["first"].is_optional
35 | assert members["second"].is_optional
36 |
37 |
38 | def test_member_schemas() -> None:
39 | schemas = MyCollection.member_schemas()
40 | assert schemas == {"first": MyFirstSchema, "second": MySecondSchema}
41 |
42 |
43 | def test_required_members() -> None:
44 | required_members = MyCollection.required_members()
45 | assert required_members == {"first"}
46 |
47 |
48 | def test_optional_members() -> None:
49 | optional_members = MyCollection.optional_members()
50 | assert optional_members == {"second"}
51 |
52 |
53 | def test_cast() -> None:
54 | collection = MyCollection.cast(
55 | {
56 | "first": pl.LazyFrame({"a": [1, 2, 3]}),
57 | "second": pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
58 | },
59 | )
60 | assert collection.first.collect_schema() == MyFirstSchema.polars_schema()
61 | assert collection.second is not None
62 | assert collection.second.collect_schema() == MySecondSchema.polars_schema()
63 |
64 |
65 | @pytest.mark.parametrize(
66 | "expected",
67 | [
68 | {
69 | "first": pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}),
70 | "second": pl.LazyFrame(
71 | {"a": [1, 2, 3], "b": [4, 5, 6]}, schema={"a": pl.UInt16, "b": pl.Int64}
72 | ),
73 | },
74 | {"first": pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8})},
75 | ],
76 | )
77 | def test_to_dict(expected: dict[str, pl.LazyFrame]) -> None:
78 | collection = MyCollection.validate(expected)
79 |
80 | # Check that export looks as expected
81 | observed = collection.to_dict()
82 | assert set(expected.keys()) == set(observed.keys())
83 | for key in expected.keys():
84 | assert_frame_equal(expected[key], observed[key])
85 |
86 | # Make sure that "roundtrip" validation works
87 | assert MyCollection.is_valid(observed)
88 |
89 |
90 | def test_collect_all() -> None:
91 | collection = MyCollection.cast(
92 | {
93 | "first": pl.LazyFrame({"a": [1, 2, 3]}).filter(pl.col("a") < 3),
94 | "second": pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).filter(
95 | pl.col("b") <= 5
96 | ),
97 | }
98 | )
99 | out = collection.collect_all()
100 |
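# collect_all() materializes every member, so each lazy plan reduces to a scan of an in-memory DataFrame.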
101 | assert isinstance(out, MyCollection)
102 | assert out.first.explain() == 'DF ["a"]; PROJECT */1 COLUMNS'
103 | assert len(out.first.collect()) == 2
104 | assert out.second is not None
105 | assert out.second.explain() == 'DF ["a", "b"]; PROJECT */2 COLUMNS'
106 | assert len(out.second.collect()) == 2
107 |
108 |
109 | def test_collect_all_optional() -> None:
110 | collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
111 | out = collection.collect_all()
112 |
113 | assert isinstance(out, MyCollection)
114 | assert len(out.first.collect()) == 3
115 | assert out.second is None
116 |
117 |
118 | @pytest.mark.parametrize(
119 | "read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet]
120 | )
121 | def test_read_write_parquet(
122 | tmp_path: Path, read_fn: Callable[[Path], MyCollection]
123 | ) -> None:
124 | collection = MyCollection.cast(
125 | {
126 | "first": pl.LazyFrame({"a": [1, 2, 3]}),
127 | "second": pl.LazyFrame({"a": [1, 2], "b": [10, 15]}),
128 | }
129 | )
130 | collection.write_parquet(tmp_path)
131 |
132 | read = read_fn(tmp_path)
133 | assert_frame_equal(collection.first, read.first)
134 | assert collection.second is not None
135 | assert read.second is not None
136 | assert_frame_equal(collection.second, read.second)
137 |
138 |
139 | @pytest.mark.parametrize(
140 | "read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet]
141 | )
142 | def test_read_write_parquet_optional(
143 | tmp_path: Path, read_fn: Callable[[Path], MyCollection]
144 | ) -> None:
145 | collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
146 | collection.write_parquet(tmp_path)
147 |
148 | read = read_fn(tmp_path)
149 | assert_frame_equal(collection.first, read.first)
150 | assert collection.second is None
151 | assert read.second is None
152 |
--------------------------------------------------------------------------------
/tests/collection/test_cast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import polars.exceptions as plexc
6 | import pytest
7 |
8 | import dataframely as dy
9 |
10 |
11 | class FirstSchema(dy.Schema):
12 | a = dy.Float64()
13 |
14 |
15 | class SecondSchema(dy.Schema):
16 | a = dy.String()
17 |
18 |
19 | class Collection(dy.Collection):
20 | first: dy.LazyFrame[FirstSchema]
21 | second: dy.LazyFrame[SecondSchema] | None
22 |
23 |
24 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
25 | def test_cast_valid(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
26 | first = df_type({"a": [3]})
27 | second = df_type({"a": [1]})
28 | out = Collection.cast({"first": first, "second": second}) # type: ignore
29 | assert out.first.collect_schema() == FirstSchema.polars_schema()
30 | assert out.second is not None
31 | assert out.second.collect_schema() == SecondSchema.polars_schema()
32 |
33 |
34 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
35 | def test_cast_valid_optional(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
36 | first = df_type({"a": [3]})
37 | out = Collection.cast({"first": first}) # type: ignore
38 | assert out.first.collect_schema() == FirstSchema.polars_schema()
39 | assert out.second is None
40 |
41 |
42 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
43 | def test_cast_invalid_members(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
44 | first = df_type({"a": [3]})
45 | with pytest.raises(ValueError):
46 | Collection.cast({"third": first}) # type: ignore
47 |
48 |
49 | def test_cast_invalid_member_schema_eager() -> None:
50 | first = pl.DataFrame({"b": [3]})
51 | with pytest.raises(plexc.ColumnNotFoundError):
52 | Collection.cast({"first": first})
53 |
54 |
55 | def test_cast_invalid_member_schema_lazy() -> None:
56 | first = pl.LazyFrame({"b": [3]})
57 | collection = Collection.cast({"first": first})
58 | with pytest.raises(plexc.ColumnNotFoundError):
59 | collection.collect_all()
60 |
--------------------------------------------------------------------------------
/tests/collection/test_create_empty.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | import dataframely as dy
6 |
7 |
8 | class MyFirstSchema(dy.Schema):
9 | a = dy.Integer(primary_key=True)
10 | b = dy.Integer()
11 |
12 |
13 | class MySecondSchema(dy.Schema):
14 | a = dy.Integer(primary_key=True)
15 | b = dy.Integer(min=1)
16 |
17 |
18 | class MyCollection(dy.Collection):
19 | first: dy.LazyFrame[MyFirstSchema]
20 | second: dy.LazyFrame[MySecondSchema] | None
21 |
22 |
23 | def test_create_empty() -> None:
24 | collection = MyCollection.create_empty()
25 | assert collection.first.collect().height == 0
26 | assert collection.first.collect_schema() == MyFirstSchema.polars_schema()
27 | assert collection.second is not None
28 | assert collection.second.collect().height == 0
29 | assert collection.second.collect_schema() == MySecondSchema.polars_schema()
30 |
--------------------------------------------------------------------------------
/tests/collection/test_filter_one_to_n.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Any
5 |
6 | import polars as pl
7 |
8 | import dataframely as dy
9 |
10 |
11 | class CarSchema(dy.Schema):
12 | vin = dy.String(primary_key=True)
13 | manufacturer = dy.String(nullable=False)
14 |
15 |
16 | class CarPartSchema(dy.Schema):
17 | vin = dy.String(primary_key=True)
18 | part = dy.String(primary_key=True)
19 | price = dy.Float64(primary_key=True)
20 |
21 |
22 | class CarFleet(dy.Collection):
23 | cars: dy.LazyFrame[CarSchema]
24 | car_parts: dy.LazyFrame[CarPartSchema]
25 |
26 | @dy.filter()
27 | def not_car_with_vin_123(self) -> pl.LazyFrame:
28 | return self.cars.filter(pl.col("vin") != pl.lit("123"))
29 |
30 |
31 | def test_valid_failure_infos() -> None:
32 | cars = {"vin": ["123", "456"], "manufacturer": ["BMW", "Mercedes"]}
33 | car_parts: dict[str, list[Any]] = {
34 | "vin": ["123", "123", "456"],
35 | "part": ["Motor", "Wheel", "Motor"],
36 | "price": [1000, 100, 1000],
37 | }
38 | car_fleet, failures = CarFleet.filter(
39 | {"cars": pl.DataFrame(cars), "car_parts": pl.DataFrame(car_parts)},
40 | cast=True,
41 | )
42 |
43 | assert len(car_fleet.cars.collect()) + len(failures["cars"].invalid()) == len(
44 | cars["vin"]
45 | )
46 | assert len(car_fleet.car_parts.collect()) + len(
47 | failures["car_parts"].invalid()
48 | ) == len(car_parts["vin"])
49 |
--------------------------------------------------------------------------------
/tests/collection/test_ignore_in_filter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Annotated
5 |
6 | import polars as pl
7 |
8 | import dataframely as dy
9 |
10 |
11 | class MyTestSchema(dy.Schema):
12 | a_id = dy.UInt8(primary_key=True)
13 |
14 |
15 | class MyTestSchema2(dy.Schema):
16 | a_id = dy.UInt8(primary_key=True)
17 | b_id = dy.UInt8()
18 |
19 |
20 | class IgnoredSchema(dy.Schema):
21 | b_id = dy.UInt8(primary_key=True)
22 |
23 |
24 | class MyTestCollection(dy.Collection):
25 | a: dy.LazyFrame[MyTestSchema]
26 | b: dy.LazyFrame[MyTestSchema2]
27 |
28 | ignored: Annotated[
29 | dy.LazyFrame[IgnoredSchema],
30 | dy.CollectionMember(ignored_in_filters=True),
31 | ]
32 |
33 | @dy.filter()
34 | def filter_a_id(self) -> pl.LazyFrame:
35 | return self.a.join(self.b, on="a_id")
36 |
37 | @dy.filter()
38 | def custom_filter_on_ignored(self) -> pl.LazyFrame:
39 | # Even a filter driven by the ignored member must return rows keyed by the shared column "a_id".
40 | used_a_ids = (
41 | self.ignored.unique()
42 | .join(self.b, on="b_id")
43 | .join(self.a, on="a_id")
44 | .select("a_id")
45 | )
46 | return used_a_ids
47 |
48 |
49 | def test_collection_ignore_in_filter_meta() -> None:
50 | assert MyTestCollection.non_ignored_members() == {"a", "b"}
51 | assert MyTestCollection.ignored_members() == {"ignored"}
52 |
53 |
54 | def test_collection_ignore_in_filter() -> None:
55 | success, failure = MyTestCollection.filter(
56 | {
57 | "a": pl.LazyFrame({"a_id": [1, 2, 3]}),
58 | "b": pl.LazyFrame({"a_id": [1, 2, 3], "b_id": [4, 5, 6]}),
59 | "ignored": pl.LazyFrame({"b_id": [4, 5, 6]}),
60 | },
61 | cast=True,
62 | )
63 | assert failure["a"].invalid().height == 0
64 | assert failure["b"].invalid().height == 0
65 | assert failure["ignored"].invalid().height == 0
66 |
67 |
68 | def test_collection_ignore_in_filter_failure() -> None:
69 | success, failure = MyTestCollection.filter(
70 | {
71 | "a": pl.LazyFrame({"a_id": [1, 2, 3]}),
72 | "b": pl.LazyFrame({"a_id": [1, 2, 3], "b_id": [4, 5, 6]}),
73 | "ignored": pl.LazyFrame(
74 | {"b_id": [9999, 5, 6]}
75 | ), # b_id=4 is missing, so a_id=1 is not referenced by any row in "ignored"
76 | },
77 | cast=True,
78 | )
79 | assert failure["a"].invalid().height == 1
80 | assert failure["b"].invalid().height == 1
81 | assert failure["ignored"].invalid().height == 1
82 |
83 | assert failure["a"].counts() == {"custom_filter_on_ignored": 1}
84 |
--------------------------------------------------------------------------------
/tests/collection/test_optional_members.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | import dataframely as dy
7 |
8 |
9 | class TestSchema(dy.Schema):
10 | a = dy.Integer()
11 |
12 |
13 | class MyCollection(dy.Collection):
14 | first: dy.LazyFrame[TestSchema]
15 | second: dy.LazyFrame[TestSchema] | None
16 |
17 |
18 | def test_collection_optional_member() -> None:
19 | MyCollection.validate({"first": pl.LazyFrame({"a": [1, 2, 3]})})
20 |
21 |
22 | def test_filter_failure_info_keys_only_required() -> None:
23 | out, failure = MyCollection.filter({"first": pl.LazyFrame({"a": [1, 2, 3]})})
24 | assert out.second is None
25 | assert set(failure.keys()) == {"first"}
26 |
27 |
28 | def test_filter_failure_info_keys_required_and_optional() -> None:
29 | out, failure = MyCollection.filter(
30 | {
31 | "first": pl.LazyFrame({"a": [1, 2, 3]}),
32 | "second": pl.LazyFrame({"a": [1, 2, 3]}),
33 | },
34 | )
35 | assert out.second is not None
36 | assert set(failure.keys()) == {"first", "second"}
37 |
--------------------------------------------------------------------------------
/tests/collection/test_validate_input.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 |
9 |
10 | class TestSchema(dy.Schema):
11 | a = dy.Integer()
12 |
13 |
14 | class MyCollection(dy.Collection):
15 | first: dy.LazyFrame[TestSchema]
16 | second: dy.LazyFrame[TestSchema] | None
17 |
18 |
19 | def test_collection_missing_required_member() -> None:
20 | with pytest.raises(ValueError):
21 | MyCollection.validate({"second": pl.LazyFrame({"a": [1, 2, 3]})})
22 |
23 |
24 | def test_collection_superfluous_member() -> None:
25 | with pytest.warns(Warning):
26 | MyCollection.validate(
27 | {
28 | "first": pl.LazyFrame({"a": [1, 2, 3]}),
29 | "third": pl.LazyFrame({"a": [1, 2, 3]}),
30 | },
31 | )
32 |
--------------------------------------------------------------------------------
/tests/column_types/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
--------------------------------------------------------------------------------
/tests/column_types/test_any.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Any
5 |
6 | import polars as pl
7 | import pytest
8 |
9 | import dataframely as dy
10 |
11 |
12 | class AnySchema(dy.Schema):
13 | a = dy.Any()
14 |
15 |
16 | @pytest.mark.parametrize(
17 | "data",
18 | [{"a": [None]}, {"a": [True, None]}, {"a": ["foo"]}, {"a": [3.5]}],
19 | )
20 | def test_any_dtype_passes(data: dict[str, Any]) -> None:
21 | df = pl.DataFrame(data)
22 | assert AnySchema.is_valid(df)
23 |
--------------------------------------------------------------------------------
/tests/column_types/test_array.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 | from dataframely.columns._base import Column
9 | from dataframely.testing import create_schema
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "inner",
14 | [
15 | (dy.Int64()),
16 | (dy.Integer()),
17 | ],
18 | )
19 | def test_integer_array(inner: Column) -> None:
20 | schema = create_schema("test", {"a": dy.Array(inner, 1)})
21 | assert schema.is_valid(
22 | pl.DataFrame(
23 | {"a": [[1], [2], [3]]},
24 | schema={
25 | "a": pl.Array(pl.Int64, 1),
26 | },
27 | )
28 | )
29 |
30 |
31 | def test_invalid_inner_type() -> None:
32 | schema = create_schema("test", {"a": dy.Array(dy.Int64(), 1)})
33 | assert not schema.is_valid(pl.DataFrame({"a": [["1"], ["2"], ["3"]]}))
34 |
35 |
36 | def test_invalid_shape() -> None:
37 | schema = create_schema("test", {"a": dy.Array(dy.Int64(), 2)})
38 | assert not schema.is_valid(
39 | pl.DataFrame(
40 | {"a": [[1], [2], [3]]},
41 | schema={
42 | "a": pl.Array(pl.Int64, 1),
43 | },
44 | )
45 | )
46 |
47 |
48 | @pytest.mark.parametrize(
49 | ("column", "dtype", "is_valid"),
50 | [
51 | (
52 | dy.Array(dy.Int64(), 1),
53 | pl.Array(pl.Int64(), 1),
54 | True,
55 | ),
56 | (
57 | dy.Array(dy.String(), 1),
58 | pl.Array(pl.Int64(), 1),
59 | False,
60 | ),
61 | (
62 | dy.Array(dy.String(), 1),
63 | pl.Array(pl.Int64(), 2),
64 | False,
65 | ),
66 | (
67 | dy.Array(dy.Int64(), (1,)),
68 | pl.Array(pl.Int64(), (1,)),
69 | True,
70 | ),
71 | (
72 | dy.Array(dy.Int64(), (1,)),
73 | pl.Array(pl.Int64(), (2,)),
74 | False,
75 | ),
76 | (
77 | dy.Array(dy.String(), 1),
78 | dy.Array(dy.String(), 1),
79 | False,
80 | ),
81 | (
82 | dy.Array(dy.String(), 1),
83 | dy.String(),
84 | False,
85 | ),
86 | (
87 | dy.Array(dy.String(), 1),
88 | pl.String(),
89 | False,
90 | ),
91 | (
92 | dy.Array(dy.Array(dy.String(), 1), 1),
93 | pl.Array(pl.String(), (1, 1)),
94 | True,
95 | ),
96 | (
97 | dy.Array(dy.String(), (1, 1)),
98 | pl.Array(pl.Array(pl.String(), 1), 1),
99 | True,
100 | ),
101 | ],
102 | )
103 | def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None:
104 | assert column.validate_dtype(dtype) == is_valid
105 |
106 |
107 | def test_nested_arrays() -> None:
108 | schema = create_schema("test", {"a": dy.Array(dy.Array(dy.Int64(), 1), 1)})
109 | assert schema.is_valid(
110 | pl.DataFrame(
111 | {"a": [[[1]], [[2]], [[3]]]},
112 | schema={
113 | "a": pl.Array(pl.Int64, (1, 1)),
114 | },
115 | )
116 | )
117 |
118 |
119 | def test_nested_array() -> None:
120 | schema = create_schema("test", {"a": dy.Array(dy.Array(dy.Int64(), 1), 1)})
121 | assert schema.is_valid(
122 | pl.DataFrame(
123 | {"a": [[[1]], [[2]], [[3]]]},
124 | schema={
125 | "a": pl.Array(pl.Int64, (1, 1)),
126 | },
127 | )
128 | )
129 |
130 |
131 | def test_array_with_inner_pk() -> None:
132 | with pytest.raises(ValueError):
133 | column = dy.Array(dy.String(primary_key=True), 2)
134 | create_schema(
135 | "test",
136 | {"a": column},
137 | )
138 |
139 |
140 | def test_array_with_rules() -> None:
141 | with pytest.raises(ValueError):
142 | create_schema(
143 | "test", {"a": dy.Array(dy.String(min_length=2, nullable=False), 1)}
144 | )
145 |
146 |
147 | def test_outer_nullability() -> None:
148 | schema = create_schema(
149 | "test",
150 | {
151 | "nullable": dy.Array(
152 | inner=dy.Integer(),
153 | shape=1,
154 | nullable=True,
155 | )
156 | },
157 | )
158 | df = pl.DataFrame({"nullable": [None, None]})
159 | schema.validate(df, cast=True)
160 |
--------------------------------------------------------------------------------
/tests/column_types/test_decimal.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import decimal
5 | from typing import Any
6 |
7 | import polars as pl
8 | import pytest
9 | from polars.datatypes import DataTypeClass
10 | from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
11 | from polars.testing import assert_frame_equal
12 |
13 | import dataframely as dy
14 | from dataframely.testing import evaluate_rules, rules_from_exprs
15 |
16 |
17 | class DecimalSchema(dy.Schema):
18 | a = dy.Decimal()
19 |
20 |
21 | @pytest.mark.parametrize(
22 | "kwargs",
23 | [
24 | {"min": decimal.Decimal(2), "max": decimal.Decimal(1)},
25 | {"min_exclusive": decimal.Decimal(2), "max": decimal.Decimal(2)},
26 | {"min": decimal.Decimal(2), "max_exclusive": decimal.Decimal(2)},
27 | {"min_exclusive": decimal.Decimal(2), "max_exclusive": decimal.Decimal(2)},
28 | {"min": decimal.Decimal(2), "min_exclusive": decimal.Decimal(2)},
29 | {"max": decimal.Decimal(2), "max_exclusive": decimal.Decimal(2)},
30 | ],
31 | )
32 | def test_args_consistency_min_max(kwargs: dict[str, Any]) -> None:
33 | with pytest.raises(ValueError):
34 | dy.Decimal(**kwargs)
35 |
36 |
37 | @pytest.mark.parametrize(
38 | "kwargs",
39 | [
40 | dict(scale=1, min=decimal.Decimal("3.14")),
41 | dict(scale=1, min_exclusive=decimal.Decimal("3.14")),
42 | dict(scale=1, max=decimal.Decimal("3.14")),
43 | dict(scale=1, max_exclusive=decimal.Decimal("3.14")),
44 | dict(min=decimal.Decimal(float("inf"))),
45 | dict(max=decimal.Decimal(float("inf"))),
46 | dict(precision=2, min=decimal.Decimal("100")),
47 | dict(precision=2, max=decimal.Decimal("100")),
48 | ],
49 | )
50 | def test_invalid_args(kwargs: dict[str, Any]) -> None:
51 | with pytest.raises(ValueError):
52 | dy.Decimal(**kwargs)
53 |
54 |
55 | @pytest.mark.parametrize(
56 | "dtype", [pl.Decimal, pl.Decimal(12), pl.Decimal(None, 8), pl.Decimal(6, 2)]
57 | )
58 | def test_any_decimal_dtype_passes(dtype: DataTypeClass) -> None:
59 | df = pl.DataFrame(schema={"a": dtype})
60 | assert DecimalSchema.is_valid(df)
61 |
62 |
63 | @pytest.mark.parametrize(
64 | "dtype", [pl.Boolean, pl.String] + list(INTEGER_DTYPES) + list(FLOAT_DTYPES)
65 | )
66 | def test_non_decimal_dtype_fails(dtype: DataTypeClass) -> None:
67 | df = pl.DataFrame(schema={"a": dtype})
68 | assert not DecimalSchema.is_valid(df)
69 |
70 |
71 | @pytest.mark.parametrize(
72 | ("inclusive", "valid"),
73 | [
74 | (True, {"min": [False, False, True, True, True]}),
75 | (False, {"min_exclusive": [False, False, False, True, True]}),
76 | ],
77 | )
78 | def test_validate_min(inclusive: bool, valid: dict[str, list[bool]]) -> None:
79 | kwargs = {("min" if inclusive else "min_exclusive"): decimal.Decimal(3)}
80 | column = dy.Decimal(**kwargs) # type: ignore
81 | lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
82 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
83 | expected = pl.LazyFrame(valid)
84 | assert_frame_equal(actual, expected)
85 |
86 |
87 | @pytest.mark.parametrize(
88 | ("inclusive", "valid"),
89 | [
90 | (True, {"max": [True, True, True, False, False]}),
91 | (False, {"max_exclusive": [True, True, False, False, False]}),
92 | ],
93 | )
94 | def test_validate_max(inclusive: bool, valid: dict[str, list[bool]]) -> None:
95 | kwargs = {("max" if inclusive else "max_exclusive"): decimal.Decimal(3)}
96 | column = dy.Decimal(**kwargs) # type: ignore
97 | lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
98 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
99 | expected = pl.LazyFrame(valid)
100 | assert_frame_equal(actual, expected)
101 |
102 |
103 | @pytest.mark.parametrize(
104 | ("min_inclusive", "max_inclusive", "valid"),
105 | [
106 | (
107 | True,
108 | True,
109 | {
110 | "min": [False, True, True, True, True],
111 | "max": [True, True, True, True, False],
112 | },
113 | ),
114 | (
115 | True,
116 | False,
117 | {
118 | "min": [False, True, True, True, True],
119 | "max_exclusive": [True, True, True, False, False],
120 | },
121 | ),
122 | (
123 | False,
124 | True,
125 | {
126 | "min_exclusive": [False, False, True, True, True],
127 | "max": [True, True, True, True, False],
128 | },
129 | ),
130 | (
131 | False,
132 | False,
133 | {
134 | "min_exclusive": [False, False, True, True, True],
135 | "max_exclusive": [True, True, True, False, False],
136 | },
137 | ),
138 | ],
139 | )
140 | def test_validate_range(
141 | min_inclusive: bool,
142 | max_inclusive: bool,
143 | valid: dict[str, list[bool]],
144 | ) -> None:
145 | kwargs = {
146 | ("min" if min_inclusive else "min_exclusive"): decimal.Decimal(2),
147 | ("max" if max_inclusive else "max_exclusive"): decimal.Decimal(4),
148 | }
149 | column = dy.Decimal(**kwargs) # type: ignore
150 | lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
151 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
152 | expected = pl.LazyFrame(valid)
153 | assert_frame_equal(actual, expected)
154 |
--------------------------------------------------------------------------------
/tests/column_types/test_enum.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Any
5 |
6 | import polars as pl
7 | import pytest
8 |
9 | import dataframely as dy
10 | from dataframely.testing.factory import create_schema
11 |
12 |
13 | @pytest.mark.parametrize(
14 | ("dy_enum", "pl_dtype", "valid"),
15 | [
16 | (dy.Enum(["x", "y"]), pl.Enum(["x", "y"]), True),
17 | (dy.Enum(["y", "x"]), pl.Enum(["x", "y"]), False),
18 | (dy.Enum(["x"]), pl.Enum(["x", "y"]), False),
19 | (dy.Enum(["x", "y", "z"]), pl.Enum(["x", "y"]), False),
20 | (dy.Enum(["x", "y"]), pl.String(), False),
21 | ],
22 | )
23 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
24 | def test_valid(
25 | df_type: type[pl.DataFrame] | type[pl.LazyFrame],
26 | dy_enum: dy.Enum,
27 | pl_dtype: pl.Enum,
28 | valid: bool,
29 | ) -> None:
30 | schema = create_schema("test", {"a": dy_enum})
31 | df = df_type({"a": ["x", "y", "x", "x"]}).cast(pl_dtype)
32 | assert schema.is_valid(df) == valid
33 |
34 |
35 | @pytest.mark.parametrize("enum", [dy.Enum(["x", "y"]), dy.Enum(["y", "x"])])
36 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
37 | @pytest.mark.parametrize(
38 | ("data", "valid"),
39 | [
40 | ({"a": ["x", "y", "x", "x"]}, True),
41 | ({"a": ["x", "y", "x", "x"]}, True),
42 | ({"a": ["x", "y", "z"]}, False),
43 | ({"a": ["x", "y", "z"]}, False),
44 | ],
45 | )
46 | def test_valid_cast(
47 | enum: dy.Enum,
48 | data: Any,
49 | valid: bool,
50 | df_type: type[pl.DataFrame] | type[pl.LazyFrame],
51 | ) -> None:
52 | schema = create_schema("test", {"a": enum})
53 | df = df_type(data)
54 | assert schema.is_valid(df, cast=True) == valid
55 |
56 |
57 | @pytest.mark.parametrize("type1", [list, tuple])
58 | @pytest.mark.parametrize("type2", [list, tuple])
59 | def test_different_sequences(type1: type, type2: type) -> None:
60 | allowed = ["a", "b"]
61 | S = create_schema("test", {"x": dy.Enum(type1(allowed))})
62 | df = pl.DataFrame({"x": pl.Series(["a", "b"], dtype=pl.Enum(type2(allowed)))})
63 | S.validate(df)
64 |
--------------------------------------------------------------------------------
/tests/column_types/test_object.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | import polars as pl
6 | import pytest
7 |
8 | import dataframely as dy
9 | from dataframely.columns._base import Column
10 | from dataframely.random import Generator
11 | from dataframely.testing import create_schema
12 |
13 |
14 | class CustomObject:
15 | def __init__(self, a: int, b: str) -> None:
16 | self.a = a
17 | self.b = b
18 |
19 |
20 | def test_simple_object() -> None:
21 | schema = create_schema("test", {"o": dy.Object()})
22 | assert schema.is_valid(
23 | pl.DataFrame({"o": [CustomObject(a=1, b="foo"), CustomObject(a=2, b="bar")]})
24 | )
25 |
26 |
27 | @pytest.mark.parametrize(
28 | ("column", "dtype", "is_valid"),
29 | [
30 | (
31 | dy.Object(),
32 | pl.Object(),
33 | True,
34 | ),
35 | (
36 | dy.Object(),
37 | object(),
38 | False,
39 | ),
40 | ],
41 | )
42 | def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None:
43 | assert column.validate_dtype(dtype) == is_valid
44 |
45 |
46 | def test_pyarrow_dtype_raises() -> None:
47 | column = dy.Object()
48 | with pytest.raises(
49 | NotImplementedError, match="PyArrow column cannot have 'Object' type."
50 | ):
51 | column.pyarrow_dtype
52 |
53 |
54 | def test_sampling_raises() -> None:
55 | column = dy.Object()
56 | with pytest.raises(
57 | NotImplementedError,
58 | match="Random data sampling not implemented for 'Object' type.",
59 | ):
60 | column.sample(generator=Generator(), n=10)
61 |
--------------------------------------------------------------------------------
/tests/column_types/test_string.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | from polars.testing import assert_frame_equal
6 |
7 | import dataframely as dy
8 | from dataframely.testing import evaluate_rules, rules_from_exprs
9 |
10 |
11 | def test_validate_min_length() -> None:
12 | column = dy.String(min_length=2)
13 | lf = pl.LazyFrame({"a": ["foo", "x"]})
14 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
15 | expected = pl.LazyFrame({"min_length": [True, False]})
16 | assert_frame_equal(actual, expected)
17 |
18 |
19 | def test_validate_max_length() -> None:
20 | column = dy.String(max_length=2)
21 | lf = pl.LazyFrame({"a": ["foo", "x"]})
22 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
23 | expected = pl.LazyFrame({"max_length": [False, True]})
24 | assert_frame_equal(actual, expected)
25 |
26 |
27 | def test_validate_regex() -> None:
28 | column = dy.String(regex="[0-9][a-z]$")
29 | lf = pl.LazyFrame({"a": ["33x", "3x", "44"]})
30 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
31 | expected = pl.LazyFrame({"regex": [True, True, False]})
32 | assert_frame_equal(actual, expected)
33 |
34 |
35 | def test_validate_all_rules() -> None:
36 | column = dy.String(nullable=False, min_length=2, max_length=4)
37 | lf = pl.LazyFrame({"a": ["foo", "x", "foobar", None]})
38 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
39 | expected = pl.LazyFrame(
40 | {
41 | "min_length": [True, False, True, True],
42 | "max_length": [True, True, False, True],
43 | "nullability": [True, True, True, False],
44 | }
45 | )
46 | assert_frame_equal(actual, expected, check_column_order=False)
47 |
--------------------------------------------------------------------------------
/tests/column_types/test_struct.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | import polars as pl
6 | import pytest
7 |
8 | import dataframely as dy
9 | from dataframely.columns._base import Column
10 | from dataframely.testing import create_schema
11 |
12 |
13 | def test_simple_struct() -> None:
14 | schema = create_schema(
15 | "test", {"s": dy.Struct({"a": dy.Integer(), "b": dy.String()})}
16 | )
17 | assert schema.is_valid(
18 | pl.DataFrame({"s": [{"a": 1, "b": "foo"}, {"a": 2, "b": "foo"}]})
19 | )
20 |
21 |
22 | @pytest.mark.parametrize(
23 | ("column", "dtype", "is_valid"),
24 | [
25 | (
26 | dy.Struct({"a": dy.Int64(), "b": dy.String()}),
27 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
28 | True,
29 | ),
30 | (
31 | dy.Struct({"b": dy.String(), "a": dy.Int64()}),
32 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
33 | True,
34 | ),
35 | (
36 | dy.Struct({"a": dy.Int64(), "b": dy.String(), "c": dy.String()}),
37 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
38 | False,
39 | ),
40 | (
41 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
42 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
43 | False,
44 | ),
45 | (
46 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
47 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
48 | False,
49 | ),
50 | (
51 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
52 | pl.Struct({"a": pl.Int64(), "b": pl.String()}),
53 | False,
54 | ),
55 | (
56 | dy.Struct({"a": dy.Int64(), "b": dy.String()}),
57 | pl.Struct({"a": pl.Int64(), "c": pl.String()}),
58 | False,
59 | ),
60 | (
61 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
62 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
63 | False,
64 | ),
65 | (
66 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
67 | dy.String(),
68 | False,
69 | ),
70 | (
71 | dy.Struct({"a": dy.String(), "b": dy.Int64()}),
72 | pl.String(),
73 | False,
74 | ),
75 | ],
76 | )
77 | def test_validate_dtype(column: Column, dtype: pl.DataType, is_valid: bool) -> None:
78 | assert column.validate_dtype(dtype) == is_valid
79 |
80 |
81 | def test_invalid_inner_type() -> None:
82 | schema = create_schema("test", {"a": dy.Struct({"a": dy.Int64()})})
83 | assert not schema.is_valid(pl.DataFrame({"a": [{"a": "1"}, {"a": "2"}]}))
84 |
85 |
86 | def test_nested_structs() -> None:
87 | schema = create_schema(
88 | "test",
89 | {
90 | "s1": dy.Struct(
91 | {
92 | "s2": dy.Struct({"a": dy.Integer(), "b": dy.String()}),
93 | "c": dy.String(),
94 | }
95 | )
96 | },
97 | )
98 | assert schema.is_valid(
99 | pl.DataFrame({"s1": [{"s2": {"a": 1, "b": "foo"}, "c": "bar"}]})
100 | )
101 |
102 |
103 | def test_struct_with_pk() -> None:
104 | schema = create_schema(
105 | "test",
106 | {"s": dy.Struct({"a": dy.String(), "b": dy.Integer()}, primary_key=True)},
107 | )
108 | df = pl.DataFrame(
109 | {"s": [{"a": "foo", "b": 1}, {"a": "bar", "b": 1}, {"a": "bar", "b": 1}]}
110 | )
111 | _, failures = schema.filter(df)
112 | assert failures.invalid().to_dict(as_series=False) == {
113 | "s": [{"a": "bar", "b": 1}, {"a": "bar", "b": 1}]
114 | }
115 | assert failures.counts() == {"primary_key": 2}
116 |
117 |
118 | def test_struct_with_rules() -> None:
119 | schema = create_schema(
120 | "test", {"s": dy.Struct({"a": dy.String(min_length=2, nullable=False)})}
121 | )
122 | df = pl.DataFrame({"s": [{"a": "ab"}, {"a": "a"}, {"a": None}]})
123 | _, failures = schema.filter(df)
124 | assert failures.invalid().to_dict(as_series=False) == {
125 | "s": [{"a": "a"}, {"a": None}]
126 | }
127 | assert failures.counts() == {"s|inner_a_nullability": 1, "s|inner_a_min_length": 1}
128 |
129 |
130 | def test_nested_struct_with_rules() -> None:
131 | schema = create_schema(
132 | "test",
133 | {
134 | "s1": dy.Struct(
135 | {"s2": dy.Struct({"a": dy.String(min_length=2, nullable=False)})}
136 | )
137 | },
138 | )
139 | df = pl.DataFrame(
140 | {"s1": [{"s2": {"a": "ab"}}, {"s2": {"a": "a"}}, {"s2": {"a": None}}]}
141 | )
142 | _, failures = schema.filter(df)
143 | assert failures.invalid().to_dict(as_series=False) == {
144 | "s1": [{"s2": {"a": "a"}}, {"s2": {"a": None}}]
145 | }
146 | assert failures.counts() == {
147 | "s1|inner_s2_inner_a_nullability": 1,
148 | "s1|inner_s2_inner_a_min_length": 1,
149 | }
150 |
151 |
152 | def test_outer_inner_nullability() -> None:
153 | schema = create_schema(
154 | "test",
155 | {
156 | "nullable": dy.Struct(
157 | inner={
158 | "not_nullable1": dy.Integer(nullable=False),
159 | "not_nullable2": dy.Integer(nullable=False),
160 | },
161 | nullable=True,
162 | )
163 | },
164 | )
165 | df = pl.DataFrame({"nullable": [None, None]})
166 |
167 | schema.validate(df, cast=True)
168 |
--------------------------------------------------------------------------------
/tests/columns/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
--------------------------------------------------------------------------------
/tests/columns/test_alias.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | import dataframely as dy
7 |
8 |
9 | class AliasSchema(dy.Schema):
10 | a = dy.Int64(alias="hello world: col with space!")
11 |
12 |
13 | def test_column_names() -> None:
14 | assert AliasSchema.column_names() == ["hello world: col with space!"]
15 |
16 |
17 | def test_validation() -> None:
18 | df = pl.DataFrame({"hello world: col with space!": [1, 2]})
19 | assert AliasSchema.is_valid(df)
20 |
21 |
22 | def test_create_empty() -> None:
23 | df = AliasSchema.create_empty()
24 | assert AliasSchema.is_valid(df)
25 |
--------------------------------------------------------------------------------
/tests/columns/test_check.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | import dataframely as dy
7 | from dataframely.testing import validation_mask
8 |
9 |
10 | class CheckSchema(dy.Schema):
11 | a = dy.Int64(check=lambda col: (col < 5) | (col > 10))
12 | b = dy.String(min_length=3, check=lambda col: col.str.contains("x"))
13 |
14 |
15 | def test_check() -> None:
16 | df = pl.DataFrame({"a": [7, 3, 15], "b": ["abc", "xyz", "x"]})
17 | _, failures = CheckSchema.filter(df)
18 | assert validation_mask(df, failures).to_list() == [False, True, False]
19 | assert failures.counts() == {"a|check": 1, "b|min_length": 1, "b|check": 1}
20 |
21 |
22 | def test_check_names() -> None:
23 | def str_starts_with_a(col: pl.Expr) -> pl.Expr:
24 | return col.str.starts_with("a")
25 |
26 | def str_end_with_z(col: pl.Expr) -> pl.Expr:
27 | return col.str.ends_with("z")
28 |
29 | class MultiCheckSchema(dy.Schema):
30 | name_from_dict = dy.Int64(
31 | check={
32 | "min_max_check": lambda col: (col < 5) | (col > 10),
33 | "summation_check": lambda col: col.sum() < 3,
34 | }
35 | )
36 | name_from_callable = dy.String(check=str_starts_with_a)
37 | name_from_list_of_callables = dy.String(
38 | check=[
39 | str_starts_with_a,
40 | str_end_with_z,
41 | str_end_with_z,
42 | lambda x: x.str.contains("x"),
43 | lambda x: x.str.contains("y"),
44 | ]
45 | )
46 | name_from_lambda = dy.Int64(check=lambda x: x < 2)
47 |
48 | df = pl.DataFrame(
49 | {
50 | "name_from_dict": [2, 4, 6],
51 | "name_from_callable": ["abc", "acd", "dca"],
52 | "name_from_list_of_callables": ["xyz", "xac", "aqq"],
53 | "name_from_lambda": [1, 2, 3],
54 | }
55 | )
56 | _, failures = MultiCheckSchema.filter(df)
57 |
58 | assert failures.counts() == {
59 | "name_from_dict|check__min_max_check": 1,
60 | "name_from_dict|check__summation_check": 3,
61 | "name_from_callable|check__str_starts_with_a": 1,
62 | "name_from_list_of_callables|check__str_starts_with_a": 2,
63 | "name_from_list_of_callables|check__str_end_with_z__0": 2,
64 | "name_from_list_of_callables|check__str_end_with_z__1": 2,
65 | "name_from_list_of_callables|check__0": 1,
66 | "name_from_list_of_callables|check__1": 2,
67 | "name_from_lambda|check": 2,
68 | }
69 |
--------------------------------------------------------------------------------
/tests/columns/test_default_dtypes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 | from dataframely.columns import Column
9 | from dataframely.testing import create_schema
10 |
11 |
12 | @pytest.mark.parametrize(
13 | ("column", "dtype"),
14 | [
15 | (dy.Any(), pl.Null()),
16 | (dy.Bool(), pl.Boolean()),
17 | (dy.Date(), pl.Date()),
18 | (dy.Datetime(), pl.Datetime()),
19 | (dy.Time(), pl.Time()),
20 | (dy.Duration(), pl.Duration()),
21 | (dy.Decimal(), pl.Decimal()),
22 | (dy.Decimal(12), pl.Decimal(12)),
23 | (dy.Decimal(None, 8), pl.Decimal(None, 8)),
24 | (dy.Decimal(6, 2), pl.Decimal(6, 2)),
25 | (dy.Float(), pl.Float64()),
26 | (dy.Float32(), pl.Float32()),
27 | (dy.Float64(), pl.Float64()),
28 | (dy.Integer(), pl.Int64()),
29 | (dy.Int8(), pl.Int8()),
30 | (dy.Int16(), pl.Int16()),
31 | (dy.Int32(), pl.Int32()),
32 | (dy.Int64(), pl.Int64()),
33 | (dy.UInt8(), pl.UInt8()),
34 | (dy.UInt16(), pl.UInt16()),
35 | (dy.UInt32(), pl.UInt32()),
36 | (dy.UInt64(), pl.UInt64()),
37 | (dy.String(), pl.String()),
38 | (dy.List(dy.String()), pl.List(pl.String())),
39 | (dy.Array(dy.String(), 1), pl.Array(pl.String(), 1)),
40 | (dy.Struct({"a": dy.String()}), pl.Struct({"a": pl.String()})),
41 | (dy.Enum(["a", "b"]), pl.Enum(["a", "b"])),
42 | ],
43 | )
44 | def test_default_dtype(column: Column, dtype: pl.DataType) -> None:
45 | schema = create_schema("test", {"a": column})
46 | df = schema.create_empty()
47 | assert df.schema["a"] == dtype
48 | schema.validate(df)
49 | assert schema.is_valid(df)
50 |
--------------------------------------------------------------------------------
/tests/columns/test_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2024-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import dataframely as dy
5 |
6 |
7 | class SchemaWithMetadata(dy.Schema):
8 | a = dy.Int64(metadata={"masked": True, "comment": "foo", "order": 1})
9 | b = dy.String()
10 |
11 |
12 | def test_metadata() -> None:
13 | assert SchemaWithMetadata.a.metadata == {
14 | "masked": True,
15 | "comment": "foo",
16 | "order": 1,
17 | }
18 | assert SchemaWithMetadata.b.metadata is None
19 |
--------------------------------------------------------------------------------
/tests/columns/test_polars_schema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | import dataframely as dy
7 | from dataframely.testing.factory import create_schema
8 |
9 |
10 | def test_polars_schema() -> None:
11 | schema = create_schema("test", {"a": dy.Int32(nullable=False), "b": dy.Float32()})
12 | pl_schema = schema.polars_schema()
13 | assert pl_schema == {"a": pl.Int32, "b": pl.Float32}
14 |
--------------------------------------------------------------------------------
/tests/columns/test_pyarrow.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import pytest
5 | from polars._typing import TimeUnit
6 |
7 | import dataframely as dy
8 | from dataframely.columns import Column
9 | from dataframely.testing import (
10 | ALL_COLUMN_TYPES,
11 | COLUMN_TYPES,
12 | NO_VALIDATION_COLUMN_TYPES,
13 | SUPERTYPE_COLUMN_TYPES,
14 | create_schema,
15 | )
16 |
17 |
18 | @pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES)
19 | def test_equal_to_polars_schema(column_type: type[Column]) -> None:
20 | schema = create_schema("test", {"a": column_type()})
21 | actual = schema.pyarrow_schema()
22 | expected = schema.create_empty().to_arrow().schema
23 | assert actual == expected
24 |
25 |
26 | def test_equal_polars_schema_enum() -> None:
27 | schema = create_schema("test", {"a": dy.Enum(["a", "b"])})
28 | actual = schema.pyarrow_schema()
29 | expected = schema.create_empty().to_arrow().schema
30 | assert actual == expected
31 |
32 |
33 | @pytest.mark.parametrize(
34 | "inner",
35 | [c() for c in ALL_COLUMN_TYPES]
36 | + [dy.List(t()) for t in ALL_COLUMN_TYPES]
37 | + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
38 | + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
39 | )
40 | def test_equal_polars_schema_list(inner: Column) -> None:
41 | schema = create_schema("test", {"a": dy.List(inner)})
42 | actual = schema.pyarrow_schema()
43 | expected = schema.create_empty().to_arrow().schema
44 | assert actual == expected
45 |
46 |
47 | @pytest.mark.parametrize(
48 | "inner",
49 | [c() for c in NO_VALIDATION_COLUMN_TYPES]
50 | + [dy.List(t()) for t in NO_VALIDATION_COLUMN_TYPES]
51 | + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
52 | + [dy.Struct({"a": t()}) for t in NO_VALIDATION_COLUMN_TYPES],
53 | )
54 | @pytest.mark.parametrize(
55 | "shape",
56 | [
57 | 1,
58 | 0,
59 | (0, 0),
60 | ],
61 | )
62 | def test_equal_polars_schema_array(inner: Column, shape: int | tuple[int, ...]) -> None:
63 | schema = create_schema("test", {"a": dy.Array(inner, shape)})
64 | actual = schema.pyarrow_schema()
65 | expected = schema.create_empty().to_arrow().schema
66 | assert actual == expected
67 |
68 |
69 | @pytest.mark.parametrize(
70 | "inner",
71 | [c() for c in ALL_COLUMN_TYPES]
72 | + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES]
73 | + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
74 | + [dy.List(t()) for t in ALL_COLUMN_TYPES],
75 | )
76 | def test_equal_polars_schema_struct(inner: Column) -> None:
77 | schema = create_schema("test", {"a": dy.Struct({"a": inner})})
78 | actual = schema.pyarrow_schema()
79 | expected = schema.create_empty().to_arrow().schema
80 | assert actual == expected
81 |
82 |
83 | @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES)
84 | @pytest.mark.parametrize("nullable", [True, False])
85 | def test_nullability_information(column_type: type[Column], nullable: bool) -> None:
86 | schema = create_schema("test", {"a": column_type(nullable=nullable)})
87 | assert ("not null" in str(schema.pyarrow_schema())) != nullable
88 |
89 |
90 | @pytest.mark.parametrize("nullable", [True, False])
91 | def test_nullability_information_enum(nullable: bool) -> None:
92 | schema = create_schema("test", {"a": dy.Enum(["a", "b"], nullable=nullable)})
93 | assert ("not null" in str(schema.pyarrow_schema())) != nullable
94 |
95 |
96 | @pytest.mark.parametrize(
97 | "inner",
98 | [c() for c in ALL_COLUMN_TYPES]
99 | + [dy.List(t()) for t in ALL_COLUMN_TYPES]
100 | + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
101 | + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES],
102 | )
103 | @pytest.mark.parametrize("nullable", [True, False])
104 | def test_nullability_information_list(inner: Column, nullable: bool) -> None:
105 | schema = create_schema("test", {"a": dy.List(inner, nullable=nullable)})
106 | assert ("not null" in str(schema.pyarrow_schema())) != nullable
107 |
108 |
109 | @pytest.mark.parametrize(
110 | "inner",
111 | [c() for c in ALL_COLUMN_TYPES]
112 | + [dy.Struct({"a": t()}) for t in ALL_COLUMN_TYPES]
113 | + [dy.Array(t(), 1) for t in NO_VALIDATION_COLUMN_TYPES]
114 | + [dy.List(t()) for t in ALL_COLUMN_TYPES],
115 | )
116 | @pytest.mark.parametrize("nullable", [True, False])
117 | def test_nullability_information_struct(inner: Column, nullable: bool) -> None:
118 | schema = create_schema("test", {"a": dy.Struct({"a": inner}, nullable=nullable)})
119 | assert ("not null" in str(schema.pyarrow_schema())) != nullable
120 |
121 |
122 | def test_multiple_columns() -> None:
123 | schema = create_schema("test", {"a": dy.Int32(nullable=False), "b": dy.Integer()})
124 | assert str(schema.pyarrow_schema()).split("\n") == ["a: int32 not null", "b: int64"]
125 |
126 |
127 | @pytest.mark.parametrize("time_unit", ["ns", "us", "ms"])
128 | def test_datetime_time_unit(time_unit: TimeUnit) -> None:
129 | schema = create_schema("test", {"a": dy.Datetime(time_unit=time_unit)})
130 | assert str(schema.pyarrow_schema()) == f"a: timestamp[{time_unit}]"
131 |
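The invariant these tests pin down, as a short sketch: `pyarrow_schema` must agree with the arrow schema polars itself derives when converting an empty frame, including nullability.

    import dataframely as dy
    from dataframely.testing import create_schema

    schema = create_schema("example", {"a": dy.Int32(nullable=False)})
    # Same schema whether computed directly or via polars' arrow conversion;
    # the string form renders non-nullable columns as "a: int32 not null".
    assert schema.pyarrow_schema() == schema.create_empty().to_arrow().schema
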
--------------------------------------------------------------------------------
/tests/columns/test_rules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 | from polars.testing import assert_frame_equal
7 |
8 | from dataframely.columns import Column
9 | from dataframely.columns.float import _BaseFloat
10 | from dataframely.testing import (
11 | COLUMN_TYPES,
12 | SUPERTYPE_COLUMN_TYPES,
13 | evaluate_rules,
14 | rules_from_exprs,
15 | )
16 |
17 |
18 | @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES)
19 | @pytest.mark.parametrize("nullable", [True, False])
20 | def test_rule_count_nullability(column_type: type[Column], nullable: bool) -> None:
21 | column = column_type(nullable=nullable)
22 | assert len(column.validation_rules(pl.col("a"))) == int(not nullable) + (
23 | 1 if isinstance(column, _BaseFloat) else 0
24 | )
25 |
26 |
27 | @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES)
28 | def test_nullability_rule_for_primary_key(column_type: type[Column]) -> None:
29 | column = column_type(primary_key=True)
30 | assert len(column.validation_rules(pl.col("a"))) == (
31 |         2  # floats additionally have nan/inf rules
32 |         if isinstance(column, _BaseFloat)
33 |         else 1
34 | )
35 |
36 |
37 | @pytest.mark.parametrize("column_type", COLUMN_TYPES + SUPERTYPE_COLUMN_TYPES)
38 | def test_nullability_rule(column_type: type[Column]) -> None:
39 | column = column_type(nullable=False)
40 | lf = pl.LazyFrame({"a": [None]}, schema={"a": column.dtype})
41 | actual = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
42 | expected = pl.LazyFrame({"nullability": [False]})
43 | assert_frame_equal(actual.select(expected.collect_schema().names()), expected)
44 |
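A sketch of the rule machinery these tests rely on (frame contents illustrative): a non-nullable column contributes a `nullability` rule that evaluates to `False` exactly on null values.

    import polars as pl
    import dataframely as dy
    from dataframely.testing import evaluate_rules, rules_from_exprs

    column = dy.Int64(nullable=False)
    lf = pl.LazyFrame({"a": [1, None]})
    out = evaluate_rules(lf, rules_from_exprs(column.validation_rules(pl.col("a"))))
    # One boolean column per rule, row-aligned with the input frame.
    assert out.select("nullability").collect().to_series().to_list() == [True, False]
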
--------------------------------------------------------------------------------
/tests/columns/test_str.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import pytest
5 |
6 | import dataframely as dy
7 | from dataframely.columns import Column
8 | from dataframely.testing import ALL_COLUMN_TYPES
9 |
10 |
11 | @pytest.mark.parametrize("column_type", ALL_COLUMN_TYPES)
12 | def test_string_representation(column_type: type[Column]) -> None:
13 | column = column_type()
14 | assert str(column) == column_type.__name__.lower()
15 |
16 |
17 | def test_string_representation_enum() -> None:
18 | column = dy.Enum(["a", "b"])
19 | assert str(column) == dy.Enum.__name__.lower()
20 |
21 |
22 | def test_string_representation_list() -> None:
23 | column = dy.List(dy.String())
24 | assert str(column) == dy.List.__name__.lower()
25 |
26 |
27 | def test_string_representation_array() -> None:
28 | column = dy.Array(dy.String(), 1)
29 | assert str(column) == dy.Array.__name__.lower()
30 |
31 |
32 | def test_string_representation_struct() -> None:
33 | column = dy.Struct({"a": dy.String()})
34 | assert str(column) == dy.Struct.__name__.lower()
35 |
--------------------------------------------------------------------------------
/tests/columns/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | from dataframely.columns._utils import first_non_null
6 |
7 |
8 | def test_first_non_null_basic() -> None:
9 | assert first_non_null(1, 2, default=3) == 1
10 | assert first_non_null(None, 2, default=3) == 2
11 | assert first_non_null(None, None, default=3) == 3
12 |
13 |
14 | def test_first_non_null_allow_null_response() -> None:
15 | assert first_non_null(None, None, None, allow_null_response=True) is None
16 |
17 |
18 | def test_first_non_null_with_terminal() -> None:
19 | assert first_non_null(None, None, None, default=42) == 42
20 | assert first_non_null(None, 3, None, default=42) == 3
21 |
22 |
23 | def test_first_non_null_mixed_types() -> None:
24 | assert first_non_null(None, "a", default=3) == "a"
25 | assert first_non_null(None, 0, default="b") == 0 # 0 is a valid non-null value
26 | assert (
27 | first_non_null(None, False, default=1) is False
28 | ) # False is a valid non-null value
29 |
30 |
31 | def test_first_non_null_with_kwargs() -> None:
32 | assert first_non_null(None, None, allow_null_response=True) is None
33 | assert first_non_null(None, None, default="fallback") == "fallback"
34 |
--------------------------------------------------------------------------------
/tests/core_validation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
--------------------------------------------------------------------------------
/tests/core_validation/test_column_validation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | from dataframely._validation import validate_columns
8 | from dataframely.exc import ValidationError
9 |
10 |
11 | def test_success() -> None:
12 | df = pl.DataFrame(schema={k: pl.Int64() for k in ["a", "b"]})
13 | lf = validate_columns(df.lazy(), actual=df.schema.keys(), expected=["a"])
14 | assert set(lf.collect_schema().names()) == {"a"}
15 |
16 |
17 | @pytest.mark.parametrize(
18 | ("actual", "expected", "error"),
19 | [
20 | (["a"], ["a", "b"], r"1 columns in the schema are missing.*'b'"),
21 | (["c"], ["a", "b"], r"2 columns in the schema are missing.*'a'.*'b'"),
22 | ],
23 | )
24 | def test_failure(actual: list[str], expected: list[str], error: str) -> None:
25 | df = pl.DataFrame(schema={k: pl.Int64() for k in actual})
26 | with pytest.raises(ValidationError, match=error):
27 | validate_columns(df.lazy(), actual=df.schema.keys(), expected=expected)
28 |
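In short (sketch with illustrative data): `validate_columns` projects the frame down to the expected columns and raises a `ValidationError` naming any that are missing.

    import polars as pl
    from dataframely._validation import validate_columns

    df = pl.DataFrame({"a": [1], "b": [2]})
    # Superfluous column "b" is dropped; a missing expected column would raise.
    lf = validate_columns(df.lazy(), actual=df.schema.keys(), expected=["a"])
    assert set(lf.collect_schema().names()) == {"a"}
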
--------------------------------------------------------------------------------
/tests/core_validation/test_dtype_validation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import re
5 |
6 | import polars as pl
7 | import polars.exceptions as plexc
8 | import pytest
9 | from polars.testing import assert_frame_equal
10 |
11 | import dataframely as dy
12 | from dataframely._validation import DtypeCasting, validate_dtypes
13 | from dataframely.columns import Column
14 | from dataframely.exc import DtypeValidationError
15 |
16 |
17 | @pytest.mark.parametrize(
18 | ("actual", "expected", "casting"),
19 | [
20 | ({"a": pl.Int64()}, {"a": dy.Int64()}, "none"),
21 | ({"a": pl.Int32()}, {"a": dy.Int64()}, "lenient"),
22 | ({"a": pl.Int32()}, {"a": dy.Int64()}, "strict"),
23 | (
24 | {"a": pl.Int32(), "b": pl.String()},
25 | {"a": dy.Int64(), "b": dy.UInt8()},
26 | "strict",
27 | ),
28 | ],
29 | )
30 | def test_success(
31 | actual: dict[str, pl.DataType],
32 | expected: dict[str, Column],
33 | casting: DtypeCasting,
34 | ) -> None:
35 | df = pl.DataFrame(schema=actual)
36 | lf = validate_dtypes(
37 | df.lazy(), actual=df.schema, expected=expected, casting=casting
38 | )
39 | schema = lf.collect_schema()
40 | for key, col in expected.items():
41 | assert col.validate_dtype(schema[key])
42 |
43 |
44 | @pytest.mark.parametrize(
45 | ("actual", "expected", "error", "fail_columns"),
46 | [
47 | (
48 | {"a": pl.Int32()},
49 | {"a": dy.Int64()},
50 | r"1 columns have an invalid dtype.*\n.*got dtype 'Int32'",
51 | {"a"},
52 | ),
53 | (
54 | {"a": pl.Int32(), "b": pl.String()},
55 | {"a": dy.Int64(), "b": dy.UInt8()},
56 | r"2 columns have an invalid dtype",
57 | {"a", "b"},
58 | ),
59 | ],
60 | )
61 | def test_failure(
62 | actual: dict[str, pl.DataType],
63 | expected: dict[str, Column],
64 | error: str,
65 | fail_columns: set[str],
66 | ) -> None:
67 | df = pl.DataFrame(schema=actual)
68 | try:
69 | validate_dtypes(df.lazy(), actual=df.schema, expected=expected, casting="none")
70 | assert False # above should raise
71 | except DtypeValidationError as exc:
72 | assert set(exc.errors.keys()) == fail_columns
73 | assert re.match(error, str(exc))
74 |
75 |
76 | def test_lenient_casting() -> None:
77 | lf = pl.LazyFrame(
78 | {"a": [1, 2, 3], "b": ["foo", "12", "1313"]},
79 | schema={"a": pl.Int64(), "b": pl.String()},
80 | )
81 | actual = validate_dtypes(
82 | lf,
83 | actual=lf.collect_schema(),
84 | expected={"a": dy.UInt8(), "b": dy.UInt8()},
85 | casting="lenient",
86 | )
87 | expected = pl.LazyFrame(
88 | {"a": [1, 2, 3], "b": [None, 12, None]},
89 | schema={"a": pl.UInt8(), "b": pl.UInt8()},
90 | )
91 | assert_frame_equal(actual, expected)
92 |
93 |
94 | def test_strict_casting() -> None:
95 | lf = pl.LazyFrame(
96 | {"a": [1, 2, 3], "b": ["foo", "12", "1313"]},
97 | schema={"a": pl.Int64(), "b": pl.String()},
98 | )
99 | lf_valid = validate_dtypes(
100 | lf,
101 | actual=lf.collect_schema(),
102 | expected={"a": dy.UInt8(), "b": dy.UInt8()},
103 | casting="strict",
104 | )
105 | with pytest.raises(plexc.InvalidOperationError):
106 | lf_valid.collect()
107 |
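The difference between the two casting modes, sketched with illustrative data: `lenient` turns uncastable values into nulls, while `strict` defers an `InvalidOperationError` to collection.

    import polars as pl
    import dataframely as dy
    from dataframely._validation import validate_dtypes

    lf = pl.LazyFrame({"a": ["12", "foo"]})
    out = validate_dtypes(
        lf, actual=lf.collect_schema(), expected={"a": dy.UInt8()}, casting="lenient"
    )
    # "foo" cannot be cast to UInt8 and becomes null instead of raising.
    assert out.collect().to_dict(as_series=False) == {"a": [12, None]}
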
--------------------------------------------------------------------------------
/tests/core_validation/test_rule_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | from polars.testing import assert_frame_equal
6 |
7 | from dataframely._rule import GroupRule, Rule
8 | from dataframely.testing import evaluate_rules
9 |
10 |
11 | def test_single_column_single_rule() -> None:
12 | lf = pl.LazyFrame({"a": [1, 2]})
13 | rules = {
14 | "a|min": Rule(pl.col("a") >= 2),
15 | }
16 | actual = evaluate_rules(lf, rules)
17 |
18 | expected = pl.LazyFrame({"a|min": [False, True]})
19 | assert_frame_equal(actual, expected)
20 |
21 |
22 | def test_single_column_multi_rule() -> None:
23 | lf = pl.LazyFrame({"a": [1, 2, 3]})
24 | rules = {
25 | "a|min": Rule(pl.col("a") >= 2),
26 | "a|max": Rule(pl.col("a") <= 2),
27 | }
28 | actual = evaluate_rules(lf, rules)
29 |
30 | expected = pl.LazyFrame(
31 | {"a|min": [False, True, True], "a|max": [True, True, False]}
32 | )
33 | assert_frame_equal(actual, expected)
34 |
35 |
36 | def test_multi_column_multi_rule() -> None:
37 | lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
38 | rules = {
39 | "a|min": Rule(pl.col("a") >= 2),
40 | "a|max": Rule(pl.col("a") <= 2),
41 | "b|even": Rule(pl.col("b") % 2 == 0),
42 | }
43 | actual = evaluate_rules(lf, rules)
44 |
45 | expected = pl.LazyFrame(
46 | {
47 | "a|min": [False, True, True],
48 | "a|max": [True, True, False],
49 | "b|even": [True, False, True],
50 | }
51 | )
52 | assert_frame_equal(actual, expected)
53 |
54 |
55 | def test_cross_column_rule() -> None:
56 | lf = pl.LazyFrame({"a": [1, 1, 2, 2], "b": [1, 1, 1, 2]})
57 | rules = {"primary_key": Rule(~pl.struct("a", "b").is_duplicated())}
58 | actual = evaluate_rules(lf, rules)
59 |
60 | expected = pl.LazyFrame({"primary_key": [False, False, True, True]})
61 | assert_frame_equal(actual, expected)
62 |
63 |
64 | def test_group_rule() -> None:
65 | lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]})
66 | rules: dict[str, Rule] = {
67 | "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"])
68 | }
69 | actual = evaluate_rules(lf, rules)
70 |
71 | expected = pl.LazyFrame({"unique_b": [True, True, False, False, True]})
72 | assert_frame_equal(actual, expected)
73 |
74 |
75 | def test_simple_rule_and_group_rule() -> None:
76 | lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]})
77 | rules: dict[str, Rule] = {
78 | "b|max": Rule(pl.col("b") <= 1),
79 | "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"]),
80 | }
81 | actual = evaluate_rules(lf, rules)
82 |
83 | expected = pl.LazyFrame(
84 | {
85 | "b|max": [True, True, True, False, True],
86 | "unique_b": [True, True, False, False, True],
87 | }
88 | )
89 | assert_frame_equal(actual, expected, check_column_order=False)
90 |
91 |
92 | def test_multiple_group_rules() -> None:
93 | lf = pl.LazyFrame({"a": [1, 1, 2, 2, 3], "b": [1, 1, 1, 2, 1]})
94 | rules: dict[str, Rule] = {
95 | "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"]),
96 | "sum_b": GroupRule(pl.col("b").sum() >= 2, group_columns=["a"]),
97 | "group_count": GroupRule(pl.len() >= 2, group_columns=["a", "b"]),
98 | }
99 | actual = evaluate_rules(lf, rules)
100 |
101 | expected = pl.LazyFrame(
102 | {
103 | "unique_b": [True, True, False, False, True],
104 | "sum_b": [True, True, True, True, False],
105 | "group_count": [True, True, False, False, False],
106 | }
107 | )
108 | assert_frame_equal(actual, expected)
109 |
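The semantics pinned down above, as a compact sketch: a `GroupRule` is evaluated once per group and broadcast back to every row of that group.

    import polars as pl
    from dataframely._rule import GroupRule, Rule
    from dataframely.testing import evaluate_rules

    lf = pl.LazyFrame({"a": [1, 1, 2], "b": [1, 2, 3]})
    rules: dict[str, Rule] = {
        "unique_b": GroupRule(pl.col("b").n_unique() == 1, group_columns=["a"])
    }
    # Group a=1 has two distinct "b" values, so both of its rows fail.
    out = evaluate_rules(lf, rules).collect()
    assert out.get_column("unique_b").to_list() == [False, False, True]
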
--------------------------------------------------------------------------------
/tests/functional/test_concat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 |
9 |
10 | class MySchema(dy.Schema):
11 | a = dy.Int64()
12 |
13 |
14 | class SimpleCollection(dy.Collection):
15 | first: dy.LazyFrame[MySchema]
16 | second: dy.LazyFrame[MySchema] | None
17 | third: dy.LazyFrame[MySchema] | None
18 |
19 |
20 | def test_concat() -> None:
21 | col1 = SimpleCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
22 | col2 = SimpleCollection.cast(
23 | {
24 | "first": pl.LazyFrame({"a": [4, 5, 6]}),
25 | "second": pl.LazyFrame({"a": [4, 5, 6]}),
26 | }
27 | )
28 | col3 = SimpleCollection.cast(
29 | {
30 | "first": pl.LazyFrame({"a": [7, 8, 9]}),
31 | "second": pl.LazyFrame({"a": [7, 8, 9]}),
32 | "third": pl.LazyFrame({"a": [7, 8, 9]}),
33 | }
34 | )
35 | concat = dy.concat_collection_members([col1, col2, col3])
36 | assert concat["first"].collect().get_column("a").to_list() == list(range(1, 10))
37 | assert concat["second"].collect().get_column("a").to_list() == list(range(4, 10))
38 | assert concat["third"].collect().get_column("a").to_list() == list(range(7, 10))
39 |
40 |
41 | def test_concat_empty() -> None:
42 | with pytest.raises(ValueError):
43 | dy.concat_collection_members([])
44 |
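Note on `test_concat`: optional members absent from a collection simply contribute no rows, which is why the concatenated `second` member starts at 4 (it is only present in `col2` and `col3`) and `third` starts at 7.
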
--------------------------------------------------------------------------------
/tests/functional/test_relationships.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 |
9 | # -------------------------------------- SCHEMA -------------------------------------- #
10 |
11 |
12 | class DepartmentSchema(dy.Schema):
13 | department_id = dy.Int64(primary_key=True)
14 |
15 |
16 | class ManagerSchema(dy.Schema):
17 | department_id = dy.Int64(primary_key=True)
18 | name = dy.String(nullable=False)
19 |
20 |
21 | class EmployeeSchema(dy.Schema):
22 | department_id = dy.Int64(primary_key=True)
23 | employee_number = dy.Int64(primary_key=True)
24 | name = dy.String(nullable=False)
25 |
26 |
27 | # ------------------------------------- FIXTURES ------------------------------------- #
28 |
29 |
30 | @pytest.fixture()
31 | def departments() -> dy.LazyFrame[DepartmentSchema]:
32 | return DepartmentSchema.cast(pl.LazyFrame({"department_id": [1, 2]}))
33 |
34 |
35 | @pytest.fixture()
36 | def managers() -> dy.LazyFrame[ManagerSchema]:
37 | return ManagerSchema.cast(
38 | pl.LazyFrame({"department_id": [1], "name": ["Donald Duck"]})
39 | )
40 |
41 |
42 | @pytest.fixture()
43 | def employees() -> dy.LazyFrame[EmployeeSchema]:
44 | return EmployeeSchema.cast(
45 | pl.LazyFrame(
46 | {
47 | "department_id": [2, 2, 2],
48 | "employee_number": [101, 102, 103],
49 | "name": ["Huey", "Dewey", "Louie"],
50 | }
51 | )
52 | )
53 |
54 |
55 | # ------------------------------------------------------------------------------------ #
56 | # TESTS #
57 | # ------------------------------------------------------------------------------------ #
58 |
59 |
60 | def test_one_to_one(
61 | departments: dy.LazyFrame[DepartmentSchema],
62 | managers: dy.LazyFrame[ManagerSchema],
63 | ) -> None:
64 | actual = dy.filter_relationship_one_to_one(
65 | departments, managers, on="department_id"
66 | )
67 | assert actual.select("department_id").collect().to_series().to_list() == [1]
68 |
69 |
70 | def test_one_to_at_least_one(
71 | departments: dy.LazyFrame[DepartmentSchema],
72 | employees: dy.LazyFrame[EmployeeSchema],
73 | ) -> None:
74 | actual = dy.filter_relationship_one_to_at_least_one(
75 | departments, employees, on="department_id"
76 | )
77 | assert actual.select("department_id").collect().to_series().to_list() == [2]
78 |
--------------------------------------------------------------------------------
/tests/schema/test_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 |
5 | import polars as pl
6 | import pytest
7 |
8 | import dataframely as dy
9 | from dataframely._rule import Rule
10 | from dataframely.exc import ImplementationError
11 | from dataframely.testing import create_schema
12 |
13 |
14 | class MySchema(dy.Schema):
15 | a = dy.Integer(primary_key=True)
16 | b = dy.String(primary_key=True)
17 | c = dy.Float64()
18 | d = dy.Any(alias="e")
19 |
20 |
21 | class MySchemaWithRule(MySchema):
22 | @dy.rule()
23 | def a_greater_than_c() -> pl.Expr:
24 | return pl.col("a") > pl.col("c")
25 |
26 |
27 | def test_column_names() -> None:
28 | assert MySchema.column_names() == ["a", "b", "c", "e"]
29 |
30 |
31 | def test_columns() -> None:
32 | columns = MySchema.columns()
33 | assert isinstance(columns["a"], dy.Integer)
34 | assert isinstance(columns["b"], dy.String)
35 | assert isinstance(columns["c"], dy.Float64)
36 | assert isinstance(columns["e"], dy.Any)
37 |
38 |
39 | def test_nullability() -> None:
40 | columns = MySchema.columns()
41 | assert not columns["a"].nullable
42 | assert not columns["b"].nullable
43 | assert columns["c"].nullable
44 | assert columns["e"].nullable
45 |
46 |
47 | def test_primary_keys() -> None:
48 | assert MySchema.primary_keys() == ["a", "b"]
49 |
50 |
51 | def test_no_rule_named_primary_key() -> None:
52 | with pytest.raises(ImplementationError):
53 | create_schema(
54 | "test",
55 | {"a": dy.String()},
56 | {"primary_key": Rule(pl.col("a").str.len_bytes() > 1)},
57 | )
58 |
59 |
60 | def test_col() -> None:
61 | assert MySchema.a.col.__dict__ == pl.col("a").__dict__
62 | assert MySchema.b.col.__dict__ == pl.col("b").__dict__
63 | assert MySchema.c.col.__dict__ == pl.col("c").__dict__
64 | assert MySchema.d.col.__dict__ == pl.col("e").__dict__
65 |
66 |
67 | def test_col_raise_if_none() -> None:
68 | class InvalidSchema(dy.Schema):
69 | a = dy.Integer()
70 |
71 | # Manually override alias to be ``None``.
72 | InvalidSchema.a.alias = None
73 | with pytest.raises(ValueError):
74 | InvalidSchema.a.col
75 |
76 |
77 | def test_col_in_polars_expression() -> None:
78 | df = (
79 | pl.DataFrame({"a": [1, 2], "b": ["a", "b"], "c": [1.0, 2.0], "e": [None, None]})
80 | .filter((MySchema.b.col == "a") & (MySchema.a.col > 0))
81 | .select(MySchema.a.col)
82 | )
83 | assert df.row(0) == (1,)
84 |
85 |
86 | def test_dunder_name() -> None:
87 | assert MySchema.__name__ == "MySchema"
88 |
89 |
90 | def test_dunder_name_with_rule() -> None:
91 | assert MySchemaWithRule.__name__ == "MySchemaWithRule"
92 |
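A sketch of the `.col` accessor exercised above (data illustrative): it resolves to `pl.col(<alias or attribute name>)`, so polars expressions can reference schema columns without repeating string literals.

    import polars as pl
    import dataframely as dy

    class Example(dy.Schema):
        a = dy.Integer()
        d = dy.Any(alias="e")

    df = pl.DataFrame({"a": [1, 2], "e": [None, None]})
    # Example.a.col is pl.col("a"); Example.d.col resolves to the alias "e".
    assert df.filter(Example.a.col > 1).get_column("a").to_list() == [2]
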
--------------------------------------------------------------------------------
/tests/schema/test_cast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from typing import Any
5 |
6 | import polars as pl
7 | import polars.exceptions as plexc
8 | import pytest
9 |
10 | import dataframely as dy
11 |
12 |
13 | class MySchema(dy.Schema):
14 | a = dy.Float64()
15 | b = dy.String()
16 |
17 |
18 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
19 | @pytest.mark.parametrize(
20 | "data",
21 | [
22 | {"a": [3], "b": [1]},
23 | {"a": [1], "b": [2], "c": [3]},
24 | ],
25 | )
26 | def test_cast_valid(
27 | df_type: type[pl.DataFrame] | type[pl.LazyFrame], data: dict[str, Any]
28 | ) -> None:
29 | df = df_type(data)
30 | out = MySchema.cast(df)
31 | assert isinstance(out, df_type)
32 | assert out.lazy().collect_schema() == MySchema.polars_schema()
33 |
34 |
35 | def test_cast_invalid_schema_eager() -> None:
36 | df = pl.DataFrame({"a": [1]})
37 | with pytest.raises(plexc.ColumnNotFoundError):
38 | MySchema.cast(df)
39 |
40 |
41 | def test_cast_invalid_schema_lazy() -> None:
42 | lf = pl.LazyFrame({"a": [1]})
43 | lf = MySchema.cast(lf)
44 | with pytest.raises(plexc.ColumnNotFoundError):
45 | lf.collect()
46 |
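The eager/lazy distinction tested above, as a sketch: casting an eager frame surfaces a missing column immediately, while casting a lazy frame defers the `ColumnNotFoundError` until the plan is collected.

    import polars as pl
    import polars.exceptions as plexc
    import dataframely as dy

    class CastSchema(dy.Schema):
        a = dy.Float64()
        b = dy.String()

    lf = CastSchema.cast(pl.LazyFrame({"a": [1]}))  # missing "b": no error yet
    try:
        lf.collect()
    except plexc.ColumnNotFoundError:
        pass  # the error only materialises at collection time
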
--------------------------------------------------------------------------------
/tests/schema/test_create_empty.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 |
9 |
10 | class MySchema(dy.Schema):
11 | a = dy.Int64()
12 | b = dy.String()
13 |
14 |
15 | @pytest.mark.parametrize("with_arg", [True, False])
16 | def test_create_empty_eager(with_arg: bool) -> None:
17 | if with_arg:
18 | df = MySchema.create_empty(lazy=False)
19 | else:
20 | df = MySchema.create_empty()
21 |
22 | assert isinstance(df, pl.DataFrame)
23 | assert df.columns == ["a", "b"]
24 | assert df.dtypes == [pl.Int64, pl.String]
25 | assert len(df) == 0
26 |
27 |
28 | def test_create_empty_lazy() -> None:
29 | df = MySchema.create_empty(lazy=True)
30 | assert isinstance(df, pl.LazyFrame)
31 | assert df.collect_schema().names() == ["a", "b"]
32 | assert df.collect_schema().dtypes() == [pl.Int64, pl.String]
33 | assert len(df.collect()) == 0
34 |
--------------------------------------------------------------------------------
/tests/schema/test_create_empty_if_none.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2024-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 | from polars.testing import assert_frame_equal
7 |
8 | import dataframely as dy
9 |
10 |
11 | class MySchema(dy.Schema):
12 | a = dy.Int64()
13 | b = dy.String()
14 |
15 |
16 | @pytest.mark.parametrize("lazy_in", [True, False])
17 | @pytest.mark.parametrize("lazy_out", [True, False])
18 | def test_create_empty_if_none_non_none(lazy_in: bool, lazy_out: bool) -> None:
19 | # Arrange
20 | df_raw = MySchema.validate(pl.DataFrame({"a": [1], "b": ["foo"]}))
21 | df = df_raw.lazy() if lazy_in else df_raw
22 |
23 | # Act
24 | result = MySchema.create_empty_if_none(df, lazy=lazy_out)
25 |
26 | # Assert
27 | if lazy_out:
28 | assert isinstance(result, pl.LazyFrame)
29 | else:
30 | assert isinstance(result, pl.DataFrame)
31 | assert_frame_equal(result.lazy().collect(), df.lazy().collect())
32 |
33 |
34 | @pytest.mark.parametrize("lazy", [True, False])
35 | def test_create_empty_if_none_none(lazy: bool) -> None:
36 | # Act
37 | result = MySchema.create_empty_if_none(None, lazy=lazy)
38 |
39 | # Assert
40 | if lazy:
41 | assert isinstance(result, pl.LazyFrame)
42 | else:
43 | assert isinstance(result, pl.DataFrame)
44 | assert_frame_equal(result.lazy().collect(), MySchema.create_empty())
45 |
--------------------------------------------------------------------------------
/tests/schema/test_inheritance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import dataframely as dy
5 |
6 |
7 | class ParentSchema(dy.Schema):
8 | a = dy.Integer()
9 |
10 |
11 | class ChildSchema(ParentSchema):
12 | b = dy.Integer()
13 |
14 |
15 | class GrandchildSchema(ChildSchema):
16 | c = dy.Integer()
17 |
18 |
19 | def test_columns() -> None:
20 | assert ParentSchema.column_names() == ["a"]
21 | assert ChildSchema.column_names() == ["a", "b"]
22 | assert GrandchildSchema.column_names() == ["a", "b", "c"]
23 |
--------------------------------------------------------------------------------
/tests/schema/test_rule_implementation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 |
7 | import dataframely as dy
8 | from dataframely._rule import GroupRule, Rule
9 | from dataframely.exc import ImplementationError, RuleImplementationError
10 | from dataframely.testing import create_schema
11 |
12 |
13 | def test_group_rule_group_by_error() -> None:
14 | with pytest.raises(
15 | ImplementationError,
16 | match=(
17 | r"Group validation rule 'b_greater_zero' has been implemented "
18 | r"incorrectly\. It references 1 columns which are not in the schema"
19 | ),
20 | ):
21 | create_schema(
22 | "test",
23 | columns={"a": dy.Integer(), "b": dy.Integer()},
24 | rules={
25 | "b_greater_zero": GroupRule(
26 | (pl.col("b") > 0).all(), group_columns=["c"]
27 | )
28 | },
29 | )
30 |
31 |
32 | def test_rule_implementation_error() -> None:
33 | with pytest.raises(
34 | RuleImplementationError, match=r"rule 'integer_rule'.*returns dtype 'Int64'"
35 | ):
36 | create_schema(
37 | "test",
38 | columns={"a": dy.Integer()},
39 | rules={"integer_rule": Rule(pl.col("a") + 1)},
40 | )
41 |
42 |
43 | def test_group_rule_implementation_error() -> None:
44 | with pytest.raises(
45 | RuleImplementationError,
46 | match=(
47 | r"rule 'b_greater_zero'.*returns dtype 'List\(Boolean\)'.*"
48 | r"make sure to use an aggregation function"
49 | ),
50 | ):
51 | create_schema(
52 | "test",
53 | columns={"a": dy.Integer(), "b": dy.Integer()},
54 | rules={"b_greater_zero": GroupRule(pl.col("b") > 0, group_columns=["a"])},
55 | )
56 |
57 |
58 | def test_rule_column_overlap_error() -> None:
59 | with pytest.raises(
60 | ImplementationError,
61 | match=r"Rules and columns must not be named equally but found 1 overlaps",
62 | ):
63 | create_schema(
64 | "test",
65 | columns={"test": dy.Integer(alias="a")},
66 | rules={"a": Rule(pl.col("a") > 0)},
67 | )
68 |
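For contrast with the failure cases above, a sketch of well-formed rules: a plain rule must evaluate to one boolean per row, and a group rule must aggregate so it yields one boolean per group.

    import polars as pl
    import dataframely as dy

    class Balanced(dy.Schema):
        a = dy.Integer()
        b = dy.Integer()

        @dy.rule()
        def b_positive() -> pl.Expr:
            return pl.col("b") > 0  # boolean per row

        @dy.rule(group_by=["a"])
        def all_b_positive_per_a() -> pl.Expr:
            return (pl.col("b") > 0).all()  # aggregated: one boolean per group

    Balanced.validate(pl.DataFrame({"a": [1, 1], "b": [2, 3]}))
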
--------------------------------------------------------------------------------
/tests/schema/test_sample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import numpy as np
5 | import polars as pl
6 | import pytest
7 | from polars.testing import assert_frame_equal
8 |
9 | import dataframely as dy
10 | from dataframely.random import Generator
11 |
12 |
13 | class MySimpleSchema(dy.Schema):
14 | a = dy.Int64()
15 | b = dy.String()
16 |
17 |
18 | class PrimaryKeySchema(dy.Schema):
19 | a = dy.Int64(primary_key=True)
20 | b = dy.String()
21 |
22 |
23 | class CheckSchema(dy.Schema):
24 | a = dy.UInt64()
25 | b = dy.UInt64()
26 |
27 | @dy.rule()
28 | def a_ge_b() -> pl.Expr:
29 | return pl.col("a") >= pl.col("b")
30 |
31 |
32 | class ComplexSchema(dy.Schema):
33 | a = dy.UInt8(primary_key=True)
34 | b = dy.UInt8(primary_key=True)
35 |
36 | @dy.rule()
37 | def a_greater_b() -> pl.Expr:
38 | return pl.col("a") > pl.col("b")
39 |
40 | @dy.rule(group_by=["a"])
41 | def minimum_two_per_a() -> pl.Expr:
42 | return pl.len() >= 2
43 |
44 |
45 | class LimitedComplexSchema(dy.Schema):
46 | a = dy.UInt8(primary_key=True)
47 | b = dy.UInt8(primary_key=True)
48 |
49 | @dy.rule()
50 | def a_greater_b() -> pl.Expr:
51 | return pl.col("a") > pl.col("b")
52 |
53 | @dy.rule(group_by=["a"])
54 | def minimum_two_per_a() -> pl.Expr:
55 |         # We cannot generate more than 768 rows with this rule (256 "a" values * 3)
56 | return pl.len() <= 3
57 |
58 |
59 | # --------------------------------------- TESTS -------------------------------------- #
60 |
61 |
62 | @pytest.mark.parametrize("n", [0, 1000])
63 | def test_sample_deterministic(n: int) -> None:
64 | with dy.Config(max_sampling_iterations=1):
65 | df = MySimpleSchema.sample(n)
66 | MySimpleSchema.validate(df)
67 |
68 |
69 | @pytest.mark.parametrize("schema", [PrimaryKeySchema, CheckSchema, ComplexSchema])
70 | @pytest.mark.parametrize("n", [0, 1000])
71 | def test_sample_fuzzy(schema: type[dy.Schema], n: int) -> None:
72 | df = schema.sample(n, generator=Generator(seed=42))
73 | assert len(df) == n
74 | schema.validate(df)
75 |
76 |
77 | def test_sample_fuzzy_failure() -> None:
78 | with pytest.raises(ValueError):
79 | with dy.Config(max_sampling_iterations=5):
80 | ComplexSchema.sample(1000, generator=Generator(seed=42))
81 |
82 |
83 | @pytest.mark.parametrize("n", [1, 1000])
84 | def test_sample_overrides(n: int) -> None:
85 | df = CheckSchema.sample(overrides={"b": range(n)})
86 | CheckSchema.validate(df)
87 | assert len(df) == n
88 | assert df.get_column("b").to_list() == list(range(n))
89 |
90 |
91 | def test_sample_overrides_with_removing_groups() -> None:
92 | generator = Generator()
93 | n = 333 # we cannot use something too large here or we'll never return
94 | overrides = np.random.randint(100, size=n)
95 | df = LimitedComplexSchema.sample(overrides={"b": overrides}, generator=generator)
96 | LimitedComplexSchema.validate(df)
97 | assert len(df) == n
98 | assert df.get_column("b").to_list() == list(overrides)
99 |
100 |
101 | @pytest.mark.parametrize("n", [1, 1000])
102 | def test_sample_overrides_allow_no_fuzzy(n: int) -> None:
103 | with dy.Config(max_sampling_iterations=1):
104 | df = CheckSchema.sample(n, overrides={"b": [0] * n})
105 | CheckSchema.validate(df)
106 | assert len(df) == n
107 | assert df.get_column("b").to_list() == [0] * n
108 |
109 |
110 | @pytest.mark.parametrize("n", [1, 1000])
111 | def test_sample_overrides_full(n: int) -> None:
112 | df = CheckSchema.sample(n)
113 | df_override = CheckSchema.sample(n, overrides=df.to_dict())
114 | assert_frame_equal(df, df_override)
115 |
116 |
117 | def test_sample_overrides_row_layout() -> None:
118 | df = MySimpleSchema.sample(overrides=[{"a": 1}, {"a": 2}, {"a": 3}])
119 | assert len(df) == 3
120 | assert df.get_column("a").to_list() == [1, 2, 3]
121 |
122 |
123 | def test_sample_overrides_invalid_column() -> None:
124 | with pytest.raises(ValueError, match=r"not in the schema"):
125 | MySimpleSchema.sample(overrides={"foo": []})
126 |
127 |
128 | def test_sample_overrides_invalid_length() -> None:
129 | with pytest.raises(ValueError, match=r"`num_rows` is different"):
130 | MySimpleSchema.sample(3, overrides={"a": [1, 2]})
131 |
132 |
133 | def test_sample_no_overrides_no_num_rows() -> None:
134 | # This case infers `num_rows == 1`
135 | df = MySimpleSchema.sample()
136 | MySimpleSchema.validate(df)
137 | assert len(df) == 1
138 |
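A sketch tying the sampling behaviours above together (schema illustrative): sampling retries until every rule holds, bounded by `max_sampling_iterations`, and `overrides` pin chosen columns while the remaining ones are resampled to satisfy the rules.

    import polars as pl
    import dataframely as dy
    from dataframely.random import Generator

    class PairSchema(dy.Schema):
        a = dy.UInt64()
        b = dy.UInt64()

        @dy.rule()
        def a_ge_b() -> pl.Expr:
            return pl.col("a") >= pl.col("b")

    # "b" is fixed; "a" is sampled such that a >= b holds on every row.
    df = PairSchema.sample(
        5, overrides={"b": [0, 1, 2, 3, 4]}, generator=Generator(seed=42)
    )
    assert df.get_column("b").to_list() == [0, 1, 2, 3, 4]
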
--------------------------------------------------------------------------------
/tests/schema/test_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 | import pytest
6 | from polars.testing import assert_frame_equal
7 |
8 | import dataframely as dy
9 | from dataframely.exc import DtypeValidationError, RuleValidationError, ValidationError
10 |
11 |
12 | class MySchema(dy.Schema):
13 | a = dy.Int64(primary_key=True)
14 | b = dy.String(nullable=False, max_length=5)
15 | c = dy.String()
16 |
17 |
18 | class MyComplexSchema(dy.Schema):
19 | a = dy.Int64()
20 | b = dy.Int64()
21 |
22 | @dy.rule()
23 | def b_greater_a() -> pl.Expr:
24 | return pl.col("b") > pl.col("a")
25 |
26 | @dy.rule(group_by=["a"])
27 | def b_unique_within_a() -> pl.Expr:
28 | return pl.col("b").n_unique() == 1
29 |
30 |
31 | # -------------------------------------- COLUMNS ------------------------------------- #
32 |
33 |
34 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
35 | def test_missing_columns(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
36 | df = df_type({"a": [1], "b": [""]})
37 | with pytest.raises(ValidationError):
38 | MySchema.validate(df)
39 | assert not MySchema.is_valid(df)
40 |
41 |
42 | # -------------------------------------- DTYPES -------------------------------------- #
43 |
44 |
45 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
46 | def test_invalid_dtype(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
47 | df = df_type({"a": [1], "b": [1], "c": [1]})
48 | try:
49 | MySchema.validate(df)
50 | assert False # above should raise
51 | except DtypeValidationError as exc:
52 | assert len(exc.errors) == 2
53 | assert not MySchema.is_valid(df)
54 |
55 |
56 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
57 | def test_invalid_dtype_cast(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
58 | df = df_type({"a": [1], "b": [1], "c": [1]})
59 | actual = MySchema.validate(df, cast=True)
60 | expected = pl.DataFrame({"a": [1], "b": ["1"], "c": ["1"]})
61 | assert_frame_equal(actual, expected)
62 | assert MySchema.is_valid(df, cast=True)
63 |
64 |
65 | # --------------------------------------- RULES -------------------------------------- #
66 |
67 |
68 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
69 | def test_invalid_column_contents(
70 | df_type: type[pl.DataFrame] | type[pl.LazyFrame],
71 | ) -> None:
72 | df = df_type({"a": [1, 2, 3], "b": ["x", "longtext", None], "c": ["1", None, "3"]})
73 | try:
74 | MySchema.validate(df)
75 | assert False # above should raise
76 | except RuleValidationError as exc:
77 | assert len(exc.schema_errors) == 0
78 | assert exc.column_errors == {"b": {"nullability": 1, "max_length": 1}}
79 | assert not MySchema.is_valid(df)
80 |
81 |
82 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
83 | def test_invalid_primary_key(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
84 | df = df_type({"a": [1, 1], "b": ["x", "y"], "c": ["1", "2"]})
85 | try:
86 | MySchema.validate(df)
87 | assert False # above should raise
88 | except RuleValidationError as exc:
89 | assert exc.schema_errors == {"primary_key": 2}
90 | assert len(exc.column_errors) == 0
91 | assert not MySchema.is_valid(df)
92 |
93 |
94 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
95 | def test_violated_custom_rule(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
96 | df = df_type({"a": [1, 1, 2, 3, 3], "b": [2, 2, 2, 4, 5]})
97 | try:
98 | MyComplexSchema.validate(df)
99 | assert False # above should raise
100 | except RuleValidationError as exc:
101 | assert exc.schema_errors == {"b_greater_a": 1, "b_unique_within_a": 2}
102 | assert len(exc.column_errors) == 0
103 | assert not MyComplexSchema.is_valid(df)
104 |
105 |
106 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
107 | def test_success_multi_row_strip_cast(
108 | df_type: type[pl.DataFrame] | type[pl.LazyFrame],
109 | ) -> None:
110 | df = df_type(
111 | {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": [1, None, None], "d": [1, 2, 3]}
112 | )
113 | actual = MySchema.validate(df, cast=True)
114 | expected = pl.DataFrame(
115 | {"a": [1, 2, 3], "b": ["x", "y", "z"], "c": ["1", None, None]}
116 | )
117 | assert_frame_equal(actual, expected)
118 | assert MySchema.is_valid(df, cast=True)
119 |
120 |
121 | @pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
122 | def test_group_rule_on_nulls(df_type: type[pl.DataFrame] | type[pl.LazyFrame]) -> None:
123 | # The schema is violated because we have multiple "b" values for the same "a" value
124 | df = df_type({"a": [None, None], "b": [1, 2]})
125 | with pytest.raises(RuleValidationError):
126 | MyComplexSchema.validate(df, cast=True)
127 | assert not MyComplexSchema.is_valid(df, cast=True)
128 |
--------------------------------------------------------------------------------
/tests/test_compat.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import pytest
5 |
6 | from dataframely._compat import _DummyModule
7 |
8 |
9 | def test_dummy_module() -> None:
10 | module = "sqlalchemy"
11 | dm = _DummyModule(module=module)
12 | assert dm.module == module
13 | with pytest.raises(ValueError):
14 | getattr(dm, "foo")
15 |
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import dataframely as dy
5 |
6 |
7 | def test_config_global() -> None:
8 | dy.Config.set_max_sampling_iterations(50)
9 | assert dy.Config.options["max_sampling_iterations"] == 50
10 | dy.Config.restore_defaults()
11 |
12 |
13 | def test_config_local() -> None:
14 | try:
15 | with dy.Config(max_sampling_iterations=35):
16 | assert dy.Config.options["max_sampling_iterations"] == 35
17 | assert dy.Config.options["max_sampling_iterations"] == 10_000
18 | finally:
19 | dy.Config.restore_defaults()
20 |
21 |
22 | def test_config_local_nested() -> None:
23 | try:
24 | with dy.Config(max_sampling_iterations=35):
25 | assert dy.Config.options["max_sampling_iterations"] == 35
26 | with dy.Config(max_sampling_iterations=20):
27 | assert dy.Config.options["max_sampling_iterations"] == 20
28 | assert dy.Config.options["max_sampling_iterations"] == 35
29 | assert dy.Config.options["max_sampling_iterations"] == 10_000
30 | finally:
31 | dy.Config.restore_defaults()
32 |
33 |
34 | def test_config_global_local() -> None:
35 | try:
36 | dy.Config.set_max_sampling_iterations(50)
37 | assert dy.Config.options["max_sampling_iterations"] == 50
38 | with dy.Config(max_sampling_iterations=35):
39 | assert dy.Config.options["max_sampling_iterations"] == 35
40 | assert dy.Config.options["max_sampling_iterations"] == 50
41 | finally:
42 | dy.Config.restore_defaults()
43 |
--------------------------------------------------------------------------------
/tests/test_deprecation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import warnings
5 |
6 | import pytest
7 |
8 | import dataframely as dy
9 |
10 |
11 | def test_column_constructor_warns_about_nullable(
12 | monkeypatch: pytest.MonkeyPatch,
13 | ) -> None:
14 | monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", "")
15 | with pytest.warns(
16 | FutureWarning, match="The 'nullable' argument was not explicitly set"
17 | ):
18 | dy.Integer()
19 |
20 |
21 | @pytest.mark.parametrize("env_var", ["1", "True", "true"])
22 | def test_future_warning_skip(monkeypatch: pytest.MonkeyPatch, env_var: str) -> None:
23 | monkeypatch.setenv("DATAFRAMELY_NO_FUTURE_WARNINGS", env_var)
24 |
25 | # Elevates FutureWarning to an exception
26 | with warnings.catch_warnings():
27 | warnings.simplefilter("error", FutureWarning)
28 | dy.Integer()
29 |
--------------------------------------------------------------------------------
/tests/test_exc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import polars as pl
5 |
6 | from dataframely.exc import DtypeValidationError, RuleValidationError, ValidationError
7 |
8 |
9 | def test_validation_error_str() -> None:
10 | message = "validation failed"
11 | exc = ValidationError(message)
12 | assert str(exc) == message
13 |
14 |
15 | def test_dtype_validation_error_str() -> None:
16 | exc = DtypeValidationError(
17 | errors={"a": (pl.Int64, pl.String), "b": (pl.Boolean, pl.String)}
18 | )
19 | assert str(exc).split("\n") == [
20 | "2 columns have an invalid dtype:",
21 | " - 'a': got dtype 'Int64' but expected 'String'",
22 | " - 'b': got dtype 'Boolean' but expected 'String'",
23 | ]
24 |
25 |
26 | def test_rule_validation_error_str() -> None:
27 | exc = RuleValidationError(
28 | {
29 | "b|max_length": 1500,
30 | "a|nullability": 2,
31 | "primary_key": 2000,
32 | "a|min_length": 5,
33 | },
34 | )
35 | assert str(exc).split("\n") == [
36 | "4 rules failed validation:",
37 | " - 'primary_key' failed validation for 2,000 rows",
38 | " * Column 'a' failed validation for 2 rules:",
39 | " - 'min_length' failed for 5 rows",
40 | " - 'nullability' failed for 2 rows",
41 | " * Column 'b' failed validation for 1 rules:",
42 | " - 'max_length' failed for 1,500 rows",
43 | ]
44 |
--------------------------------------------------------------------------------
/tests/test_extre.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | import math
5 | import re
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | import dataframely._extre as extre
11 |
12 | # ------------------------------------ MATCHING STRING LENGTH ----------------------------------- #
13 |
14 |
15 | @pytest.mark.parametrize(
16 | ("regex", "expected_lower", "expected_upper"),
17 | [
18 | (r"abc", 3, 3),
19 | (r".*", 0, None),
20 | (r"[a-z]{3,5}", 3, 5),
21 | (r"[0-9]{2}[0-9a-zA-Z]{2,4}", 4, 6),
22 | (r"^[0-9]{2}[0-9a-zA-Z]{2,4}$", 4, 6),
23 | (r"^[0-9]{2}[0-9a-zA-Z]{2,4}.+$", 5, None),
24 | ],
25 | )
26 | def test_matching_string_length(
27 | regex: str, expected_lower: int, expected_upper: int | None
28 | ) -> None:
29 | actual_lower, actual_upper = extre.matching_string_length(regex)
30 | assert actual_lower == expected_lower
31 | assert actual_upper == expected_upper
32 |
33 |
34 | @pytest.mark.parametrize("regex", [r"(?=[A-Za-z\d])"])
35 | def test_failing_matching_string_length(regex: str) -> None:
36 | with pytest.raises(ValueError):
37 | extre.matching_string_length(regex)
38 |
39 |
40 | # ------------------------------------------- SAMPLING ------------------------------------------ #
41 |
42 | TEST_REGEXES = [
43 | "",
44 | "a",
45 | "ab",
46 | "a|b",
47 | "[A-Z]+",
48 | "[A-Za-z0-9]?",
49 |     "([a-z]+:)?[0-9]*", r"[^@]+@[^@]+\.[^@]+",
50 | r"[a-z0-9\._%+!$&*=^|~#%'`?{}/\-]+@([a-z0-9\-]+\.){1,}([a-z]{2,16})",
51 | ]
52 |
53 |
54 | @pytest.mark.parametrize("regex", TEST_REGEXES)
55 | def test_sample_one(regex: str) -> None:
56 | sample = extre.sample(regex, max_repetitions=10)
57 | assert re.fullmatch(regex, sample) is not None
58 |
59 |
60 | @pytest.mark.parametrize("regex", TEST_REGEXES)
61 | def test_sample_many(regex: str) -> None:
62 | samples = extre.sample(regex, n=100, max_repetitions=10)
63 | assert all(re.fullmatch(regex, s) is not None for s in samples)
64 |
65 |
66 | def test_sample_equal_alternation_probabilities() -> None:
67 | n = 100_000
68 | samples = extre.sample("a|b|c", n=n)
69 |     assert np.allclose(np.unique_counts(samples).counts / n, np.ones(3) / 3, atol=0.01)
70 |
71 |
72 | def test_sample_max_repetitions() -> None:
73 | samples = extre.sample(".*", n=100_000, max_repetitions=10)
74 | assert max(len(s) for s in samples) == 10
75 | assert math.isclose(np.mean([len(s) for s in samples]), 5, abs_tol=0.05)
76 |
77 |
78 | def test_sample_equal_class_probabilities() -> None:
79 | n = 1_000_000
80 | samples = extre.sample("[a-z0-9]", n=n)
81 |     assert np.allclose(np.unique_counts(samples).counts / n, np.ones(36) / 36, atol=0.001)
82 |
83 |
84 | def test_sample_one_seed() -> None:
85 | choices = [extre.sample("a|b", seed=42) for _ in range(10_000)]
86 | assert len(set(choices)) == 1
87 |
88 |
89 | def test_sample_many_seed() -> None:
90 | choices = extre.sample("a|b", n=10_000, seed=42)
91 | assert len(set(choices)) == 2
92 |
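The two `_extre` entry points used above, in one sketch: `matching_string_length` bounds the length of strings a regex can match (upper bound `None` when unbounded), and `sample` draws strings that match the pattern.

    import re

    import dataframely._extre as extre

    lower, upper = extre.matching_string_length(r"[a-z]{3,5}")
    assert (lower, upper) == (3, 5)
    # Seeded sampling is deterministic; every sample fully matches the regex.
    samples = extre.sample(r"[A-Z]+", n=3, max_repetitions=10, seed=42)
    assert all(re.fullmatch(r"[A-Z]+", s) is not None for s in samples)
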
--------------------------------------------------------------------------------
/tests/test_failure_info.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) QuantCo 2025-2025
2 | # SPDX-License-Identifier: BSD-3-Clause
3 |
4 | from pathlib import Path
5 |
6 | import polars as pl
7 | from polars.testing import assert_frame_equal
8 |
9 | import dataframely as dy
10 |
11 |
12 | class MySchema(dy.Schema):
13 | a = dy.Integer(primary_key=True, min=5, max=10)
14 | b = dy.Integer(nullable=False, is_in=[1, 2, 3, 5, 7, 11])
15 |
16 |
17 | def test_read_write_parquet(tmp_path: Path) -> None:
18 | df = pl.DataFrame(
19 | {
20 | "a": [4, 5, 6, 6, 7, 8],
21 | "b": [1, 2, 3, 4, 5, 6],
22 | }
23 | )
24 | _, failure = MySchema.filter(df)
25 | assert failure._df.height == 4
26 | failure.write_parquet(tmp_path / "failure.parquet")
27 |
28 | read: dy.FailureInfo[MySchema] = dy.FailureInfo.scan_parquet(
29 | tmp_path / "failure.parquet"
30 | )
31 | assert_frame_equal(failure._lf, read._lf)
32 | assert failure._rule_columns == read._rule_columns
33 | assert failure.schema == read.schema == MySchema
34 |
--------------------------------------------------------------------------------