├── .github
│   └── workflows
│       ├── build.yml
│       ├── publish.yml
│       └── semgrep.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CODE_OF_CONDUCT.md
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── remove_metadata.py
│   ├── run_notebooks.sh
│   ├── source
│   │   ├── README.rst
│   │   ├── advanced.rst
│   │   ├── advanced_linting_support.ipynb
│   │   ├── api.rst
│   │   ├── autocomplete_in_notebooks.ipynb
│   │   ├── column_ambiguity.ipynb
│   │   ├── complex_datatypes.ipynb
│   │   ├── conf.py
│   │   ├── contributing.rst
│   │   ├── create_empty_datasets.ipynb
│   │   ├── create_schema_in_notebook.ipynb
│   │   ├── dataset_implements.ipynb
│   │   ├── documentation.ipynb
│   │   ├── ide.rst
│   │   ├── index.rst
│   │   ├── loading_datasets_in_notebooks.ipynb
│   │   ├── notebook.rst
│   │   ├── robots.txt
│   │   ├── schema_attributes.ipynb
│   │   ├── structtype_columns.ipynb
│   │   ├── structtypes_in_notebooks.ipynb
│   │   ├── subclass_column_meta.ipynb
│   │   ├── subclassing_schemas.ipynb
│   │   ├── transforming_datasets.ipynb
│   │   └── type_checking.ipynb
│   └── videos
│       ├── ide.ipynb
│       └── notebook.ipynb
├── pyproject.toml
├── renovate.json
├── requirements-dev.txt
├── requirements.txt
├── setup.py
├── tests
│   ├── _core
│   │   ├── test_column.py
│   │   ├── test_dataset.py
│   │   ├── test_datatypes.py
│   │   └── test_metadata.py
│   ├── _schema
│   │   ├── test_create_spark_schema.py
│   │   ├── test_get_schema_definition.py
│   │   ├── test_offending_schemas.py
│   │   ├── test_schema.py
│   │   └── test_structfield.py
│   ├── _transforms
│   │   ├── test_structtype_column.py
│   │   └── test_transform_to_schema.py
│   ├── _utils
│   │   ├── test_create_dataset.py
│   │   ├── test_load_table.py
│   │   └── test_register_schema_to_dataset.py
│   └── conftest.py
├── tox.ini
└── typedspark
    ├── __init__.py
    ├── _core
    │   ├── __init__.py
    │   ├── column.py
    │   ├── column_meta.py
    │   ├── dataset.py
    │   ├── datatypes.py
    │   ├── literaltype.py
    │   └── validate_schema.py
    ├── _schema
    │   ├── __init__.py
    │   ├── dlt_kwargs.py
    │   ├── get_schema_definition.py
    │   ├── get_schema_imports.py
    │   ├── schema.py
    │   └── structfield.py
    ├── _transforms
    │   ├── __init__.py
    │   ├── rename_duplicate_columns.py
    │   ├── structtype_column.py
    │   ├── transform_to_schema.py
    │   └── utils.py
    ├── _utils
    │   ├── __init__.py
    │   ├── camelcase.py
    │   ├── create_dataset.py
    │   ├── create_dataset_from_structtype.py
    │   ├── databases.py
    │   ├── load_table.py
    │   └── register_schema_to_dataset.py
    └── py.typed
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [pull_request]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-latest
8 | timeout-minutes: 20
9 | strategy:
10 | matrix:
11 | python-version: ["3.9", "3.10", "3.11", "3.12"]
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 | - uses: actions/setup-java@v4
16 | with:
17 | distribution: 'temurin'
18 | java-version: 11
19 | - uses: vemonet/setup-spark@v1
20 | with:
21 | spark-version: '3.5.3'
22 | hadoop-version: '3'
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | cache: 'pip'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install -r requirements.txt
32 | pip install -r requirements-dev.txt
33 | - name: Linting
34 | run: |
35 | flake8
36 | pylint typedspark
37 | mypy .
38 | pyright .
39 | bandit typedspark/**/*.py
40 | black --check .
41 | isort --check .
42 | docformatter --black -c **/*.py
43 | - name: Testing
44 | run: |
45 |           # we run this test separately, to ensure that it is run without an active Spark session
46 | python -m pytest -m no_spark_session
47 |
48 | coverage run -m pytest
49 | coverage report -m --fail-under 100
50 | - name: Run notebooks
51 | run: |
52 | for FILE in docs/*/*.ipynb; do
53 | BASE=$(basename $FILE)
54 | cp $FILE .
55 | jupyter nbconvert --to notebook $BASE --execute
56 | done
57 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Set up Python
13 | uses: actions/setup-python@v5
14 | with:
15 | python-version: '3.x'
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Build and publish
21 | env:
22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
24 | run: |
25 | python setup.py sdist bdist_wheel
26 | twine upload dist/*
27 |
--------------------------------------------------------------------------------
/.github/workflows/semgrep.yml:
--------------------------------------------------------------------------------
1 | # This workflow uses actions that are not certified by GitHub.
2 | # They are provided by a third-party and are governed by
3 | # separate terms of service, privacy policy, and support
4 | # documentation.
5 |
6 | # This workflow file requires a free account on Semgrep.dev to
7 | # manage rules, file ignores, notifications, and more.
8 | #
9 | # See https://semgrep.dev/docs
10 |
11 | name: Semgrep
12 |
13 | on:
14 | push:
15 | branches: [ "main" ]
16 | pull_request:
17 | # The branches below must be a subset of the branches above
18 | branches: [ "main" ]
19 | schedule:
20 | - cron: '42 3 * * 1'
21 |
22 | permissions:
23 | contents: read
24 |
25 | jobs:
26 | semgrep:
27 | permissions:
28 | contents: read # for actions/checkout to fetch code
29 | security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
30 | actions: read # only required for a private repository by github/codeql-action/upload-sarif to get the Action run status
31 | name: Scan
32 | runs-on: ubuntu-latest
33 | steps:
34 | # Checkout project source
35 | - uses: actions/checkout@v4
36 |
37 | # Scan code using project's configuration on https://semgrep.dev/manage
38 | - uses: returntocorp/semgrep-action@713efdd345f3035192eaa63f56867b88e63e4e5d
39 | with:
40 | publishToken: ${{ secrets.SEMGREP_APP_TOKEN }}
41 | publishDeployment: ${{ secrets.SEMGREP_DEPLOYMENT_ID }}
42 | generateSarif: "1"
43 |
44 | # Upload SARIF file generated in previous step
45 | - name: Upload SARIF file
46 | uses: github/codeql-action/upload-sarif@v3
47 | with:
48 | sarif_file: semgrep.sarif
49 | if: always()
50 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .mypy_cache
2 | .pytest_cache
3 | **/__pycache__
4 | .vscode
5 | .venv
6 | *.pyc
7 | *.ipynb_checkpoints*
8 | .DS_Store
9 | .coverage
10 | .cache
11 | *.egg*
12 | docs/build
13 | spark-warehouse
14 | dist/
15 | build/
16 | output.json
17 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: flake8
5 | name: flake8
6 | entry: flake8
7 | language: system
8 | types: [python]
9 | - id: pylint
10 | name: pylint
11 | entry: pylint typedspark
12 | language: system
13 | types: [python]
14 | args:
15 | [
16 | '-rn', # Only display messages
17 | '-sn', # Don't display the score
18 | ]
19 | - id: mypy
20 | name: mypy
21 | entry: mypy
22 | language: system
23 | types: [python]
24 | - id: pyright
25 | name: pyright
26 | entry: pyright
27 | language: system
28 | types: [python]
29 | - id: black
30 | name: black
31 | entry: black --check
32 | language: system
33 | files: \.(py|ipynb)$
34 | - id: isort
35 | name: isort
36 | entry: isort --check
37 | language: system
38 | types: [python]
39 | - id: docformatter
40 | name: docformatter
41 | entry: docformatter --black -c
42 | language: system
43 | types: [python]
44 | - id: pytest-no-spark
45 | name: pytest-no-spark
46 | entry: python -m pytest -m no_spark_session
47 | language: system
48 | types: [python]
49 | pass_filenames: false
50 | - id: pytest-spark
51 | name: pytest-spark
52 | entry: coverage run -m pytest
53 | language: system
54 | types: [python]
55 | pass_filenames: false
56 | - id: pytest-spark-coverage
57 | name: pytest-spark-coverage
58 | entry: coverage report -m --fail-under 100
59 | language: system
60 | types: [python]
61 | pass_filenames: false
62 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | os: ubuntu-22.04
10 | tools:
11 | python: "3.11"
12 |
13 | sphinx:
14 | configuration: docs/source/conf.py
15 |
16 | python:
17 | install:
18 | - method: pip
19 | path: .
20 | - requirements: requirements-dev.txt
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | typedspark@kaiko.ai.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Typedspark: column-wise type annotations for pyspark DataFrames
2 |
3 | We love Spark! But in production code we're wary when we see:
4 |
5 | ```python
6 | from pyspark.sql import DataFrame
7 |
8 | def foo(df: DataFrame) -> DataFrame:
9 | # do stuff
10 | return df
11 | ```
12 |
13 | Because… How do we know which columns are supposed to be in ``df``?
14 |
15 | Using ``typedspark``, we can be more explicit about what these data should look like.
16 |
17 | ```python
18 | from typedspark import Column, DataSet, Schema
19 | from pyspark.sql.types import LongType, StringType
20 |
21 | class Person(Schema):
22 | id: Column[LongType]
23 | name: Column[StringType]
24 | age: Column[LongType]
25 |
26 | def foo(df: DataSet[Person]) -> DataSet[Person]:
27 | # do stuff
28 | return df
29 | ```
30 | The advantages include:
31 |
32 | * Improved readability of the code
33 | * Typechecking, both during runtime and linting
34 | * Auto-complete of column names
35 | * Easy refactoring of column names
36 | * Easier unit testing through the generation of empty ``DataSets`` based on their schemas
37 | * Improved documentation of tables
38 |
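For example, the type checking and auto-complete advantages above look like this in practice (a minimal sketch, assuming an active ``SparkSession`` named ``spark`` and the ``Person`` schema defined above; ``create_partially_filled_dataset()`` is covered in the documentation linked below):

```python
from typedspark import DataSet, create_partially_filled_dataset

df = create_partially_filled_dataset(
    spark,
    Person,
    {
        Person.name: ["Alice", "Bob"],
        Person.age: [20, 30],
    },
)

# Schema attributes replace string column names, so your editor can
# auto-complete Person.<column> and renaming a column is an ordinary refactor.
df.filter(Person.age > 25).show()

# Casting to DataSet[Person] checks at runtime that the columns and data
# types of df match the Person schema.
validated = DataSet[Person](df)
```
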
39 | ## Documentation
40 | Please see our documentation on [readthedocs](https://typedspark.readthedocs.io/en/latest/index.html).
41 |
42 | ## Installation
43 |
44 | You can install ``typedspark`` from [pypi](https://pypi.org/project/typedspark/) by running:
45 |
46 | ```bash
47 | pip install typedspark
48 | ```
49 | By default, ``typedspark`` does not list ``pyspark`` as a dependency, since many platforms (e.g. Databricks) come with ``pyspark`` preinstalled. If you want to install ``typedspark`` with ``pyspark``, you can run:
50 |
51 | ```bash
52 | pip install "typedspark[pyspark]"
53 | ```
54 |
55 | ## Demo videos
56 |
57 | ### IDE demo
58 |
59 | https://github.com/kaiko-ai/typedspark/assets/47976799/e6f7fa9c-6d14-4f68-baba-fe3c22f75b67
60 |
61 | You can find the corresponding code [here](docs/videos/ide.ipynb).
62 |
63 | ### Jupyter / Databricks notebooks demo
64 |
65 | https://github.com/kaiko-ai/typedspark/assets/47976799/39e157c3-6db0-436a-9e72-44b2062df808
66 |
67 | You can find the corresponding code [here](docs/videos/notebook.ipynb).
68 |
69 | ## FAQ
70 |
71 | **I found a bug! What should I do?**
72 | Great! Please make an issue and we'll look into it.
73 |
74 | **I have a great idea to improve typedspark! How can we make this work?**
75 | Awesome, please make an issue and let us know!
76 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/remove_metadata.py:
--------------------------------------------------------------------------------
1 | """Removes the metadata from a notebook.
2 |
3 | Also removes the Spark warnings from cells where the SparkSession is initialized.
4 | """
5 |
6 | import sys
7 |
8 | import nbformat
9 |
10 |
11 | def clear_metadata(cell):
12 | """Clears the metadata of a notebook cell."""
13 | cell.metadata = {}
14 |
15 |
16 | def remove_spark_warnings(cell):
17 | """Removes the spark warnings from a notebook cell."""
18 | if "outputs" in cell.keys():
19 | outputs = []
20 | for output in cell.outputs:
21 | if "text" in output.keys():
22 | if 'Setting default log level to "WARN"' in output.text:
23 | continue
24 | if (
25 | "WARN NativeCodeLoader: Unable to load native-hadoop library for your platform."
26 | in output.text
27 | ):
28 | continue
29 | if "WARN Utils: Service 'SparkUI' could not bind on port" in output.text:
30 | continue
31 | if (
32 | "FutureWarning: is_datetime64tz_dtype is deprecated and will be removed in a future version." # noqa: E501
33 | in output.text
34 | ):
35 | continue
36 | outputs.append(output)
37 |
38 | cell.outputs = outputs
39 |
40 |
41 | if __name__ == "__main__":
42 | FILENAME = sys.argv[1]
43 | nb = nbformat.read(FILENAME, as_version=4)
44 |
45 | for nb_cell in nb["cells"]:
46 | clear_metadata(nb_cell)
47 | remove_spark_warnings(nb_cell)
48 |
49 | nbformat.write(nb, FILENAME)
50 |
--------------------------------------------------------------------------------
/docs/run_notebooks.sh:
--------------------------------------------------------------------------------
1 | for FILE in docs/*/*.ipynb; do
2 | echo "Running $FILE"
3 | DIR=$(dirname $FILE)
4 | BASE=$(basename $FILE)
5 | mv $FILE .
6 |
7 | jupyter nbconvert --to notebook $BASE --execute --inplace
8 | python docs/remove_metadata.py $BASE;
9 |
10 | mv $BASE $DIR
11 | done
12 |
--------------------------------------------------------------------------------
/docs/source/README.rst:
--------------------------------------------------------------------------------
1 | ===============================================================
2 | Typedspark: column-wise type annotations for pyspark DataFrames
3 | ===============================================================
4 |
5 | We love Spark! But in production code we're wary when we see:
6 |
7 | .. code-block:: python
8 |
9 | from pyspark.sql import DataFrame
10 |
11 | def foo(df: DataFrame) -> DataFrame:
12 | # do stuff
13 | return df
14 |
15 | Because… How do we know which columns are supposed to be in ``df``?
16 |
17 | Using ``typedspark``, we can be more explicit about what these data should look like.
18 |
19 | .. code-block:: python
20 |
21 | from typedspark import Column, DataSet, Schema
22 | from pyspark.sql.types import LongType, StringType
23 |
24 | class Person(Schema):
25 | id: Column[LongType]
26 | name: Column[StringType]
27 | age: Column[LongType]
28 |
29 | def foo(df: DataSet[Person]) -> DataSet[Person]:
30 | # do stuff
31 | return df
32 |
33 | The advantages include:
34 |
35 | * Improved readability of the code
36 | * Typechecking, both during runtime and linting
37 | * Auto-complete of column names
38 | * Easy refactoring of column names
39 | * Easier unit testing through the generation of empty ``DataSets`` based on their schemas
40 | * Improved documentation of tables
41 |
42 | Installation
43 | ============
44 |
45 | You can install ``typedspark`` from `pypi <https://pypi.org/project/typedspark/>`_ by running:
46 |
47 | .. code-block:: bash
48 |
49 | pip install typedspark
50 |
51 | By default, ``typedspark`` does not list ``pyspark`` as a dependency, since many platforms (e.g. Databricks) come with ``pyspark`` preinstalled. If you want to install ``typedspark`` with ``pyspark``, you can run:
52 |
53 | .. code-block:: bash
54 |
55 | pip install "typedspark[pyspark]"
56 |
57 | Demo videos
58 | ===========
59 |
60 | * IDE demo: `video <https://github.com/kaiko-ai/typedspark/assets/47976799/e6f7fa9c-6d14-4f68-baba-fe3c22f75b67>`_ and `code <https://github.com/kaiko-ai/typedspark/blob/main/docs/videos/ide.ipynb>`_.
61 | * Jupyter / Databricks Notebook demo: `video <https://github.com/kaiko-ai/typedspark/assets/47976799/39e157c3-6db0-436a-9e72-44b2062df808>`_ and `code <https://github.com/kaiko-ai/typedspark/blob/main/docs/videos/notebook.ipynb>`_.
62 |
63 | FAQ
64 | ===
65 |
66 | | **I found a bug! What should I do?**
67 | | Great! Please make an issue and we'll look into it.
68 | |
69 | | **I have a great idea to improve typedspark! How can we make this work?**
70 | | Awesome, please make an issue and let us know!
71 |
--------------------------------------------------------------------------------
/docs/source/advanced.rst:
--------------------------------------------------------------------------------
1 | Advanced Topics
2 | ===============
3 |
4 | .. toctree::
5 |
6 | subclassing_schemas
7 | subclass_column_meta
8 | column_ambiguity
9 | advanced_linting_support
10 | dataset_implements
11 |
--------------------------------------------------------------------------------
/docs/source/advanced_linting_support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "b83251c2",
7 | "metadata": {},
8 | "source": [
9 | "# Advanced usage of type hints\n",
10 | "## Functions that do not affect the schema\n",
11 | "\n",
12 | "There are a number of functions in `DataSet` which do not affect the schema. For example:"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "bdf75d36",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "from pyspark.sql import SparkSession\n",
23 | "\n",
24 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
25 | "spark.sparkContext.setLogLevel(\"ERROR\")"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "id": "c84b26e9",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "from typedspark import Column, Schema, DataSet, create_partially_filled_dataset\n",
36 | "from pyspark.sql.types import StringType\n",
37 | "\n",
38 | "\n",
39 | "class A(Schema):\n",
40 | " a: Column[StringType]\n",
41 | "\n",
42 | "\n",
43 | "df = create_partially_filled_dataset(\n",
44 | " spark,\n",
45 | " A,\n",
46 | " {\n",
47 | " A.a: [\"a\", \"b\", \"c\"],\n",
48 | " },\n",
49 | ")\n",
50 | "res = df.filter(A.a == \"a\")"
51 | ]
52 | },
53 | {
54 | "attachments": {},
55 | "cell_type": "markdown",
56 | "id": "01675118",
57 | "metadata": {},
58 | "source": [
59 | "In the above example, `filter()` will not actually make any changes to the schema, hence we have implemented the return type of `DataSet.filter()` to be a `DataSet` of the same `Schema` that you started with. In other words, a linter will see that `res` is of the type `DataSet[A]`.\n",
60 | "\n",
61 | "This allows you to skip casting steps in many cases and instead define functions as:"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "id": "b0fe9344",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "def foo(df: DataSet[A]) -> DataSet[A]:\n",
72 | " return df.filter(A.a == \"a\")"
73 | ]
74 | },
75 | {
76 | "attachments": {},
77 | "cell_type": "markdown",
78 | "id": "4ed27a7d",
79 | "metadata": {},
80 | "source": [
81 | "The functions for which this is currently implemented include:\n",
82 | "\n",
83 | "* `filter()`\n",
84 | "* `distinct()`\n",
85 | "* `orderBy()`\n",
86 | "* `where()`\n",
87 | "* `alias()`\n",
88 | "* `persist()`\n",
89 | "* `unpersist()`\n",
90 | "\n",
91 | "## Functions applied to two DataSets of the same schema\n",
92 | "\n",
93 | "Similarly, some functions return a `DataSet[A]` when they take two `DataSet[A]` as an input. For example, here a linter will see that `res` is of the type `DataSet[A]`."
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 4,
99 | "id": "3abfb0d8",
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "df_a = create_partially_filled_dataset(spark, A, {A.a: [\"a\", \"b\", \"c\"]})\n",
104 | "df_b = create_partially_filled_dataset(spark, A, {A.a: [\"d\", \"e\", \"f\"]})\n",
105 | "\n",
106 | "res = df_a.unionByName(df_b)"
107 | ]
108 | },
109 | {
110 | "attachments": {},
111 | "cell_type": "markdown",
112 | "id": "dfa33e61",
113 | "metadata": {},
114 | "source": [
115 | "The functions in this category include:\n",
116 | "\n",
117 | "* `unionByName()`\n",
118 | "* `join(..., how=\"semi\")`\n",
119 | "\n",
120 | "## Transformations\n",
121 | "\n",
122 | "Finally, the `transform()` function can also be typed. In the following example, a linter will see that `res` is of the type `DataSet[B]`."
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": 5,
128 | "id": "5b0df362",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "from typedspark import transform_to_schema\n",
133 | "from pyspark.sql.functions import lit\n",
134 | "\n",
135 | "\n",
136 | "class B(A):\n",
137 | " b: Column[StringType]\n",
138 | "\n",
139 | "\n",
140 | "def foo(df: DataSet[A]) -> DataSet[A]:\n",
141 | " return transform_to_schema(\n",
142 | " df,\n",
143 | " B,\n",
144 | " {\n",
145 | " B.b: lit(\"hi\"),\n",
146 | " },\n",
147 | " )\n",
148 | "\n",
149 | "\n",
150 | "res = create_partially_filled_dataset(\n",
151 | " spark,\n",
152 | " A,\n",
153 | " {\n",
154 | " A.a: [\"a\", \"b\", \"c\"],\n",
155 | " },\n",
156 | ").transform(foo)"
157 | ]
158 | },
159 | {
160 | "attachments": {},
161 | "cell_type": "markdown",
162 | "id": "4221be22",
163 | "metadata": {},
164 | "source": [
165 | "## Did we miss anything?\n",
166 | "\n",
167 |     "There are likely more functions that we have not yet covered. Feel free to make an issue and reach out when there is one that you'd like typedspark to support!"
168 | ]
169 | },
170 | {
171 | "attachments": {},
172 | "cell_type": "markdown",
173 | "id": "d21ad841",
174 | "metadata": {},
175 | "source": []
176 | }
177 | ],
178 | "metadata": {
179 | "kernelspec": {
180 | "display_name": "typedspark",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.11.9"
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 5
199 | }
200 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API Documentation
2 | =================
3 |
4 | .. automodule:: typedspark
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/autocomplete_in_notebooks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "c3322756",
7 | "metadata": {},
8 | "source": [
9 | "# Autocomplete in Databricks & Jupyter notebooks\n",
10 | "When we use `Catalogs`, `Databases`, `Database`, `load_table()` or `create_schema()` in a Databricks or Jupyter notebook, we also get autocomplete on the column names. No more looking at `df.columns` every minute to remember the column names!\n",
11 | "\n",
12 | "## The basics\n",
13 | "\n",
14 | "To illustrate this, let us first generate a table that we'll write to the table `person_table`."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "id": "87752202",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "from pyspark.sql import SparkSession\n",
25 | "\n",
26 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
27 | "spark.sparkContext.setLogLevel(\"ERROR\")"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "id": "6c1e5acc",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "import pandas as pd\n",
38 | "\n",
39 | "(\n",
40 | " spark.createDataFrame(\n",
41 | " pd.DataFrame(\n",
42 | " dict(\n",
43 | " name=[\"Jack\", \"John\", \"Jane\"],\n",
44 | " age=[20, 30, 40],\n",
45 | " )\n",
46 | " )\n",
47 | " ).createOrReplaceTempView(\"person_table\")\n",
48 | ")"
49 | ]
50 | },
51 | {
52 | "attachments": {},
53 | "cell_type": "markdown",
54 | "id": "4bd96763",
55 | "metadata": {},
56 | "source": [
57 | "We can now load these data using `load_table()`. Note that the `Schema` is inferred: it doesn't need to have been serialized using `typedspark`."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 3,
63 | "id": "3003dea9",
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "from typedspark import load_table\n",
68 | "\n",
69 | "df, Person = load_table(spark, \"person_table\")"
70 | ]
71 | },
72 | {
73 | "attachments": {},
74 | "cell_type": "markdown",
75 | "id": "65c183c1",
76 | "metadata": {},
77 | "source": [
78 | "You can now use `df` and `Person` just like you would in your IDE, including autocomplete!"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 4,
84 | "id": "f38e0e20",
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "+----+---+\n",
92 | "|name|age|\n",
93 | "+----+---+\n",
94 | "|John| 30|\n",
95 | "|Jane| 40|\n",
96 | "+----+---+\n",
97 | "\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "df.filter(Person.age > 25).show()"
103 | ]
104 | },
105 | {
106 | "attachments": {},
107 | "cell_type": "markdown",
108 | "id": "eb3f8c9c",
109 | "metadata": {},
110 | "source": [
111 | "## Other notebook types\n",
112 | "\n",
113 | "Auto-complete of dynamically loaded schemas (e.g. through `load_table()` or `create_schema()`) has been verified to work on Databricks, JupyterLab and Jupyter Notebook. At the time of writing, it doesn't work in VSCode and PyCharm notebooks."
114 | ]
115 | },
116 | {
117 | "attachments": {},
118 | "cell_type": "markdown",
119 | "id": "c342aec1",
120 | "metadata": {},
121 | "source": []
122 | }
123 | ],
124 | "metadata": {
125 | "kernelspec": {
126 | "display_name": "Python 3 (ipykernel)",
127 | "language": "python",
128 | "name": "python3"
129 | },
130 | "language_info": {
131 | "codemirror_mode": {
132 | "name": "ipython",
133 | "version": 3
134 | },
135 | "file_extension": ".py",
136 | "mimetype": "text/x-python",
137 | "name": "python",
138 | "nbconvert_exporter": "python",
139 | "pygments_lexer": "ipython3",
140 | "version": "3.11.9"
141 | }
142 | },
143 | "nbformat": 4,
144 | "nbformat_minor": 5
145 | }
146 |
--------------------------------------------------------------------------------
/docs/source/complex_datatypes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "72077ffb",
7 | "metadata": {},
8 | "source": [
9 | "# Other complex datatypes\n",
10 | "Spark contains several other complex data types. \n",
11 | "\n",
12 | "## MapType, ArrayType, DecimalType and DayTimeIntervalType\n",
13 | "These can be used in `typedspark` as follows:"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "id": "934c2a2d",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from typing import Literal\n",
24 | "from pyspark.sql.types import StringType\n",
25 | "from typedspark import (\n",
26 | " ArrayType,\n",
27 | " DayTimeIntervalType,\n",
28 | " DecimalType,\n",
29 | " IntervalType,\n",
30 | " MapType,\n",
31 | " Schema,\n",
32 | " Column,\n",
33 | ")\n",
34 | "\n",
35 | "\n",
36 | "class Values(Schema):\n",
37 | " array: Column[ArrayType[StringType]]\n",
38 | " map: Column[MapType[StringType, StringType]]\n",
39 | " decimal: Column[DecimalType[Literal[38], Literal[18]]]\n",
40 | " interval: Column[DayTimeIntervalType[IntervalType.HOUR, IntervalType.SECOND]]"
41 | ]
42 | },
43 | {
44 | "attachments": {},
45 | "cell_type": "markdown",
46 | "id": "271da325",
47 | "metadata": {},
48 | "source": [
49 | "## Generating DataSets\n",
50 | "\n",
51 | "You can generate `DataSets` using complex data types in the following way:"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 2,
57 | "id": "f82a8b4f",
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "from pyspark.sql import SparkSession\n",
62 | "\n",
63 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
64 | "spark.sparkContext.setLogLevel(\"ERROR\")"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "3ccb56ff",
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "+---------+--------+--------------------+--------------------+----------+-------------------+\n",
78 | "| array| map| decimal| interval| date| timestamp|\n",
79 | "+---------+--------+--------------------+--------------------+----------+-------------------+\n",
80 | "|[a, b, c]|{a -> b}|32.00000000000000...|INTERVAL '26:03:0...|2020-01-01|2020-01-01 10:15:00|\n",
81 | "+---------+--------+--------------------+--------------------+----------+-------------------+\n",
82 | "\n"
83 | ]
84 | }
85 | ],
86 | "source": [
87 | "from datetime import date, datetime, timedelta\n",
88 | "from decimal import Decimal\n",
89 | "from pyspark.sql.types import DateType, TimestampType\n",
90 | "from typedspark._utils.create_dataset import create_partially_filled_dataset\n",
91 | "\n",
92 | "\n",
93 | "class MoreValues(Values):\n",
94 | " date: Column[DateType]\n",
95 | " timestamp: Column[TimestampType]\n",
96 | "\n",
97 | "\n",
98 | "create_partially_filled_dataset(\n",
99 | " spark,\n",
100 | " MoreValues,\n",
101 | " {\n",
102 | " MoreValues.array: [[\"a\", \"b\", \"c\"]],\n",
103 | " MoreValues.map: [{\"a\": \"b\"}],\n",
104 | " MoreValues.decimal: [Decimal(32)],\n",
105 | " MoreValues.interval: [timedelta(days=1, hours=2, minutes=3, seconds=4)],\n",
106 | " MoreValues.date: [date(2020, 1, 1)],\n",
107 | " MoreValues.timestamp: [datetime(2020, 1, 1, 10, 15)],\n",
108 | " },\n",
109 | ").show()"
110 | ]
111 | },
112 | {
113 | "attachments": {},
114 | "cell_type": "markdown",
115 | "id": "8a2232b3",
116 | "metadata": {},
117 | "source": [
118 | "## Did we miss a data type?\n",
119 | "\n",
120 | "Feel free to make an issue! We can extend the list of supported data types."
121 | ]
122 | },
123 | {
124 | "attachments": {},
125 | "cell_type": "markdown",
126 | "id": "33d83e5f",
127 | "metadata": {},
128 | "source": []
129 | }
130 | ],
131 | "metadata": {
132 | "kernelspec": {
133 | "display_name": "typedspark",
134 | "language": "python",
135 | "name": "python3"
136 | },
137 | "language_info": {
138 | "codemirror_mode": {
139 | "name": "ipython",
140 | "version": 3
141 | },
142 | "file_extension": ".py",
143 | "mimetype": "text/x-python",
144 | "name": "python",
145 | "nbconvert_exporter": "python",
146 | "pygments_lexer": "ipython3",
147 | "version": "3.11.9"
148 | }
149 | },
150 | "nbformat": 4,
151 | "nbformat_minor": 5
152 | }
153 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 |
10 | from typing import Any, List
11 |
12 | project = "typedspark"
13 | copyright = "2023, Nanne Aben, Marijn Valk"
14 | author = "Nanne Aben, Marijn Valk"
15 |
16 | # -- General configuration ---------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
18 |
19 | extensions = ["sphinx.ext.autodoc", "sphinx_rtd_theme", "nbsphinx"]
20 |
21 | templates_path = ["_templates"]
22 | exclude_patterns: List[Any] = []
23 |
24 |
25 | # -- Options for HTML output -------------------------------------------------
26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
27 |
28 | html_theme = "sphinx_rtd_theme"
29 | html_static_path = ["_static"]
30 |
31 | autodoc_inherit_docstrings = False
32 |
33 | html_extra_path = ["robots.txt"]
34 |
--------------------------------------------------------------------------------
/docs/source/contributing.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Contributing
3 | ============
4 |
5 | We welcome contributions! To set up your development environment, we recommend using ``pyenv``. You can find more information on how to install ``pyenv`` and ``pyenv-virtualenv`` here:
6 |
7 | * https://github.com/pyenv/pyenv
8 | * https://github.com/pyenv/pyenv-virtualenv
9 |
10 | To set up the environment, run:
11 |
12 | .. code-block:: bash
13 |
14 | pyenv install 3.11
15 | pyenv virtualenv 3.11 typedspark
16 | pyenv activate typedspark
17 | pip install -r requirements.txt
18 | pip install -r requirements-dev.txt
19 |
20 | For a list of currently supported Python versions, we refer to ``.github/workflows/build.yml``.
21 |
22 | Note that in order to run the unit tests, you will need to set up Spark on your machine.
23 |
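Once Spark is set up, you can run the same test commands as the CI locally (a minimal sketch mirroring ``.github/workflows/build.yml``; the ``no_spark_session`` tests are run separately so that they execute without an active Spark session):

.. code-block:: bash

    python -m pytest -m no_spark_session
    coverage run -m pytest
    coverage report -m --fail-under 100
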
24 | ---------------
25 | Pre-commit hook
26 | ---------------
27 | We use ``pre-commit`` to run a number of checks on the code before it is committed. To install the pre-commit hook, run:
28 |
29 | .. code-block:: bash
30 |
31 | pre-commit install
32 |
33 | Note that this will require you to set up Spark on your machine.
34 |
35 | There are currently two steps from the CI/CD that we do not check using the pre-commit hook:
36 |
37 | * bandit
38 | * notebooks
39 |
40 | Since they rarely fail, this shouldn't be a problem. We recommend that you test these using the CI/CD pipeline.
41 |
42 | ---------
43 | Notebooks
44 | ---------
45 | If you make changes that affect the documentation, please rerun the documentation notebooks in ``docs/``. You can do so by running the following command in the root of the repository:
46 |
47 | .. code-block:: bash
48 |
49 | sh docs/run_notebooks.sh
50 |
51 | This will run all notebooks and strip the metadata afterwards, such that the diffs in the PR remain manageable.
52 |
53 | ----------------------
54 | Building documentation
55 | ----------------------
56 |
57 | You can build the documentation locally by running:
58 |
59 | .. code-block:: bash
60 |
61 | cd docs/; make clean; make html; cd ..
62 |
63 | You can find the resulting documentation in ``docs/build/html/index.html``.
64 |
--------------------------------------------------------------------------------
/docs/source/create_empty_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "f25abb69",
7 | "metadata": {},
8 | "source": [
9 | "# Easier unit testing through the creation of empty DataSets from schemas\n",
10 | "\n",
11 | "We provide helper functions to generate (partially) empty DataSets from existing schemas. This can be helpful in certain situations, such as unit testing.\n",
12 | "\n",
13 | "## Column-wise definition of your DataSets\n",
14 | "\n",
15 | "First, let us consider the column-wise definition of new `DataSets`. This is useful when you have few rows, but many columns."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": 1,
21 | "id": "25b679fc",
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from pyspark.sql import SparkSession\n",
26 | "\n",
27 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
28 | "spark.sparkContext.setLogLevel(\"ERROR\")"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "id": "7a2690c7",
35 | "metadata": {},
36 | "outputs": [
37 | {
38 | "name": "stdout",
39 | "output_type": "stream",
40 | "text": [
41 | "+----+----+----+\n",
42 | "| id|name| age|\n",
43 | "+----+----+----+\n",
44 | "|NULL|NULL|NULL|\n",
45 | "|NULL|NULL|NULL|\n",
46 | "|NULL|NULL|NULL|\n",
47 | "+----+----+----+\n",
48 | "\n"
49 | ]
50 | }
51 | ],
52 | "source": [
53 | "from typedspark import Column, Schema, create_empty_dataset, create_partially_filled_dataset\n",
54 | "from pyspark.sql.types import LongType, StringType\n",
55 | "\n",
56 | "\n",
57 | "class Person(Schema):\n",
58 | " id: Column[LongType]\n",
59 | " name: Column[StringType]\n",
60 | " age: Column[LongType]\n",
61 | "\n",
62 | "\n",
63 | "df_empty = create_empty_dataset(spark, Person)\n",
64 | "df_empty.show()"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "d5ce76e1",
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "+---+----+----+\n",
78 | "| id|name| age|\n",
79 | "+---+----+----+\n",
80 | "| 1|John|NULL|\n",
81 | "| 2|Jane|NULL|\n",
82 | "| 3|Jack|NULL|\n",
83 | "+---+----+----+\n",
84 | "\n"
85 | ]
86 | }
87 | ],
88 | "source": [
89 | "df_partially_filled = create_partially_filled_dataset(\n",
90 | " spark,\n",
91 | " Person,\n",
92 | " {\n",
93 | " Person.id: [1, 2, 3],\n",
94 | " Person.name: [\"John\", \"Jane\", \"Jack\"],\n",
95 | " },\n",
96 | ")\n",
97 | "df_partially_filled.show()"
98 | ]
99 | },
100 | {
101 | "attachments": {},
102 | "cell_type": "markdown",
103 | "id": "571cf714",
104 | "metadata": {},
105 | "source": [
106 | "## Row-wise definition of your DataSets\n",
107 | "\n",
108 | "It is also possible to define your DataSets in a row-wise fashion. This is useful for cases where you have few columns, but many rows!"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 4,
114 | "id": "a68e52bf",
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "name": "stdout",
119 | "output_type": "stream",
120 | "text": [
121 | "+----+-------+---+\n",
122 | "| id| name|age|\n",
123 | "+----+-------+---+\n",
124 | "|NULL| Alice| 20|\n",
125 | "|NULL| Bob| 30|\n",
126 | "|NULL|Charlie| 40|\n",
127 | "|NULL| Dave| 50|\n",
128 | "|NULL| Eve| 60|\n",
129 | "|NULL| Frank| 70|\n",
130 | "|NULL| Grace| 80|\n",
131 | "+----+-------+---+\n",
132 | "\n"
133 | ]
134 | }
135 | ],
136 | "source": [
137 | "create_partially_filled_dataset(\n",
138 | " spark,\n",
139 | " Person,\n",
140 | " [\n",
141 | " {Person.name: \"Alice\", Person.age: 20},\n",
142 | " {Person.name: \"Bob\", Person.age: 30},\n",
143 | " {Person.name: \"Charlie\", Person.age: 40},\n",
144 | " {Person.name: \"Dave\", Person.age: 50},\n",
145 | " {Person.name: \"Eve\", Person.age: 60},\n",
146 | " {Person.name: \"Frank\", Person.age: 70},\n",
147 | " {Person.name: \"Grace\", Person.age: 80},\n",
148 | " ],\n",
149 | ").show()"
150 | ]
151 | },
152 | {
153 | "attachments": {},
154 | "cell_type": "markdown",
155 | "id": "bd5c7c1c",
156 | "metadata": {},
157 | "source": [
158 | "## Example unit test\n",
159 | "\n",
160 | "The following code snippet shows what a full unit test using typedspark can look like."
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 5,
166 | "id": "f80ddebf",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "from pyspark.sql import SparkSession\n",
171 | "from pyspark.sql.types import LongType, StringType\n",
172 | "from typedspark import Column, DataSet, Schema, create_partially_filled_dataset, transform_to_schema\n",
173 | "from chispa.dataframe_comparer import assert_df_equality\n",
174 | "\n",
175 | "\n",
176 | "class Person(Schema):\n",
177 | " name: Column[StringType]\n",
178 | " age: Column[LongType]\n",
179 | "\n",
180 | "\n",
181 | "def birthday(df: DataSet[Person]) -> DataSet[Person]:\n",
182 | " return transform_to_schema(df, Person, {Person.age: Person.age + 1})\n",
183 | "\n",
184 | "\n",
185 | "def test_birthday(spark: SparkSession):\n",
186 | " df = create_partially_filled_dataset(\n",
187 | " spark,\n",
188 | " Person,\n",
189 | " {\n",
190 | " Person.name: [\"Alice\", \"Bob\"],\n",
191 | " Person.age: [20, 30],\n",
192 | " },\n",
193 | " )\n",
194 | "\n",
195 | " observed = birthday(df)\n",
196 | " expected = create_partially_filled_dataset(\n",
197 | " spark,\n",
198 | " Person,\n",
199 | " {\n",
200 | " Person.name: [\"Alice\", \"Bob\"],\n",
201 | " Person.age: [21, 31],\n",
202 | " },\n",
203 | " )\n",
204 | "\n",
205 | " assert_df_equality(observed, expected, ignore_row_order=True, ignore_nullable=True)"
206 | ]
207 | },
208 | {
209 | "attachments": {},
210 | "cell_type": "markdown",
211 | "id": "f5e716a8",
212 | "metadata": {},
213 | "source": []
214 | }
215 | ],
216 | "metadata": {
217 | "kernelspec": {
218 | "display_name": "typedspark",
219 | "language": "python",
220 | "name": "python3"
221 | },
222 | "language_info": {
223 | "codemirror_mode": {
224 | "name": "ipython",
225 | "version": 3
226 | },
227 | "file_extension": ".py",
228 | "mimetype": "text/x-python",
229 | "name": "python",
230 | "nbconvert_exporter": "python",
231 | "pygments_lexer": "ipython3",
232 | "version": "3.11.9"
233 | }
234 | },
235 | "nbformat": 4,
236 | "nbformat_minor": 5
237 | }
238 |
--------------------------------------------------------------------------------
/docs/source/dataset_implements.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "d3c03896",
6 | "metadata": {},
7 | "source": [
8 | "# Transformations for all schemas with a given column using DataSetImplements\n",
9 | "\n",
10 | "Let's illustrate this with an example! First, we'll define some data."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "91c48423",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from pyspark.sql import SparkSession\n",
21 | "\n",
22 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
23 | "spark.sparkContext.setLogLevel(\"ERROR\")"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "99c453be",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "from pyspark.sql.types import LongType, StringType\n",
34 | "from typedspark import (\n",
35 | " Schema,\n",
36 | " Column,\n",
37 | " create_empty_dataset,\n",
38 | ")\n",
39 | "\n",
40 | "\n",
41 | "class Person(Schema):\n",
42 | " name: Column[StringType]\n",
43 | " age: Column[LongType]\n",
44 | " job: Column[StringType]\n",
45 | "\n",
46 | "\n",
47 | "class Pet(Schema):\n",
48 | " name: Column[StringType]\n",
49 | " age: Column[LongType]\n",
50 | " type: Column[StringType]\n",
51 | "\n",
52 | "\n",
53 | "class Fruit(Schema):\n",
54 | " type: Column[StringType]\n",
55 | "\n",
56 | "\n",
57 | "person = create_empty_dataset(spark, Person)\n",
58 | "pet = create_empty_dataset(spark, Pet)\n",
59 | "fruit = create_empty_dataset(spark, Fruit)"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "ca634c83",
65 | "metadata": {},
66 | "source": [
67 | "Now, suppose we want to define a function `birthday()` that works on all schemas that contain the column `age`. With `DataSet`, we'd have to specifically indicate which schemas contain the `age` column. We could do this with for example:"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 3,
73 | "id": "c948c8d3",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "from typing import TypeVar, Union\n",
78 | "\n",
79 | "from typedspark import DataSet, transform_to_schema\n",
80 | "\n",
81 | "T = TypeVar(\"T\", bound=Union[Person, Pet])\n",
82 | "\n",
83 | "\n",
84 | "def birthday(df: DataSet[T]) -> DataSet[T]:\n",
85 | " return transform_to_schema(\n",
86 | " df,\n",
87 | " df.typedspark_schema,\n",
88 | " {Person.age: Person.age + 1},\n",
89 | " )"
90 | ]
91 | },
92 | {
93 | "cell_type": "markdown",
94 | "id": "9784804d",
95 | "metadata": {},
96 | "source": [
97 | "This can get tedious if the list of schemas with the column `age` changes, for example because new schemas are added, or because the `age` column is removed from a schema! It's also not great that we're using `Person.age` here to define the `age` column...\n",
98 | "\n",
99 | "Fortunately, we can do better! Consider the following example:"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 4,
105 | "id": "b2436106",
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "from typing import Protocol\n",
110 | "\n",
111 | "from typedspark import DataSetImplements\n",
112 | "\n",
113 | "\n",
114 | "class Age(Schema, Protocol):\n",
115 | " age: Column[LongType]\n",
116 | "\n",
117 | "\n",
118 | "T = TypeVar(\"T\", bound=Schema)\n",
119 | "\n",
120 | "\n",
121 | "def birthday(df: DataSetImplements[Age, T]) -> DataSet[T]:\n",
122 | " return transform_to_schema(\n",
123 | " df,\n",
124 | " df.typedspark_schema,\n",
125 | " {Age.age: Age.age + 1},\n",
126 | " )"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "id": "2088742c",
132 | "metadata": {},
133 | "source": [
134 | "Here, we define `Age` to be both a `Schema` and a `Protocol` ([PEP-0544](https://peps.python.org/pep-0544/)). \n",
135 | "\n",
136 | "We then define `birthday()` to:\n",
137 | "\n",
138 | "1. Take as an input `DataSetImplements[Age, T]`: a `DataSet` that implements the protocol `Age` as `T`. \n",
139 | "2. Return a `DataSet[T]`: a `DataSet` of the same type as the one that was provided.\n",
140 | "\n",
141 | "Let's see this in action!"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 5,
147 | "id": "5658210f",
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "# returns a DataSet[Person]\n",
152 | "happy_person = birthday(person)\n",
153 | "\n",
154 | "# returns a DataSet[Pet]\n",
155 | "happy_pet = birthday(pet)\n",
156 | "\n",
157 | "try:\n",
158 | " # Raises a linting error:\n",
159 | " # Argument of type \"DataSet[Fruit]\" cannot be assigned to\n",
160 | " # parameter \"df\" of type \"DataSetImplements[Age, T@birthday]\"\n",
161 | " birthday(fruit)\n",
162 | "except Exception as e:\n",
163 | " pass"
164 | ]
165 | },
166 | {
167 | "cell_type": "markdown",
168 | "id": "5bfb99ed",
169 | "metadata": {},
170 | "source": []
171 | }
172 | ],
173 | "metadata": {
174 | "kernelspec": {
175 | "display_name": "typedspark",
176 | "language": "python",
177 | "name": "python3"
178 | },
179 | "language_info": {
180 | "codemirror_mode": {
181 | "name": "ipython",
182 | "version": 3
183 | },
184 | "file_extension": ".py",
185 | "mimetype": "text/x-python",
186 | "name": "python",
187 | "nbconvert_exporter": "python",
188 | "pygments_lexer": "ipython3",
189 | "version": "3.11.9"
190 | }
191 | },
192 | "nbformat": 4,
193 | "nbformat_minor": 5
194 | }
195 |
--------------------------------------------------------------------------------
/docs/source/documentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "4b8cb87a",
7 | "metadata": {},
8 | "source": [
9 | "# Documentation of tables\n",
10 | "You can add documentation to schemas as follows:"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "d2e1e850",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from typing import Annotated\n",
21 | "from pyspark.sql.types import DateType, StringType\n",
22 | "from typedspark import Column, ColumnMeta, Schema\n",
23 | "\n",
24 | "\n",
25 | "class Person(Schema):\n",
26 | " \"\"\"Dimension table that contains information about a person.\"\"\"\n",
27 | "\n",
28 | " person_id: Annotated[\n",
29 | " Column[StringType],\n",
30 | " ColumnMeta(comment=\"Unique person id\"),\n",
31 | " ]\n",
32 | " gender: Annotated[Column[StringType], ColumnMeta(comment=\"Gender of the person\")]\n",
33 | " birthdate: Annotated[Column[DateType], ColumnMeta(comment=\"Date of birth of the person\")]\n",
34 | " job_id: Annotated[Column[StringType], ColumnMeta(comment=\"Id of the job\")]"
35 | ]
36 | },
37 | {
38 | "attachments": {},
39 | "cell_type": "markdown",
40 | "id": "302ad462",
41 | "metadata": {},
42 | "source": [
43 | "If you use Databricks and Delta Live Tables, you can make the documentation appear in the Databricks UI by using the following Delta Live Table definition:\n",
44 | "\n",
45 | "```python\n",
46 | "@dlt.table(**Person.get_dlt_kwargs())\n",
47 | "def table_definition() -> DataSet[Person]:\n",
48 | " # your table definition here\n",
49 | "```\n"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "id": "17c2b1a6",
55 | "metadata": {},
56 | "source": []
57 | }
58 | ],
59 | "metadata": {
60 | "kernelspec": {
61 | "display_name": "typedspark",
62 | "language": "python",
63 | "name": "python3"
64 | },
65 | "language_info": {
66 | "codemirror_mode": {
67 | "name": "ipython",
68 | "version": 3
69 | },
70 | "file_extension": ".py",
71 | "mimetype": "text/x-python",
72 | "name": "python",
73 | "nbconvert_exporter": "python",
74 | "pygments_lexer": "ipython3",
75 | "version": "3.11.9"
76 | }
77 | },
78 | "nbformat": 4,
79 | "nbformat_minor": 5
80 | }
81 |
--------------------------------------------------------------------------------
/docs/source/ide.rst:
--------------------------------------------------------------------------------
1 | In your IDE
2 | ===========
3 |
4 | .. toctree::
5 |
6 | schema_attributes
7 | type_checking
8 | transforming_datasets
9 | create_empty_datasets
10 | structtype_columns
11 | complex_datatypes
12 | documentation
13 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. include:: README.rst
2 |
3 | .. toctree::
4 |
5 | README
6 | ide
7 | notebook
8 | advanced
9 | contributing
10 | api
11 |
--------------------------------------------------------------------------------
/docs/source/notebook.rst:
--------------------------------------------------------------------------------
1 | In your Notebooks
2 | =================
3 |
4 | .. toctree::
5 |
6 | loading_datasets_in_notebooks
7 | create_schema_in_notebook
8 | autocomplete_in_notebooks
9 | structtypes_in_notebooks
10 |
--------------------------------------------------------------------------------
/docs/source/robots.txt:
--------------------------------------------------------------------------------
1 | User-agent: *
2 | Disallow: /
3 | Allow: /en/stable/
4 | Allow: /en/latest/
5 | Sitemap: https://typedspark.readthedocs.io/sitemap.xml
--------------------------------------------------------------------------------
/docs/source/schema_attributes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "011002f2",
7 | "metadata": {},
8 | "source": [
9 | "# Auto-complete & easier refactoring using schema attributes \n",
10 | "Schemas allow us to replaces the numerous strings throughout the code by schema attributes. Consider the following example:"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "6379b38b",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from typedspark import Column, DataSet, Schema\n",
21 | "from pyspark.sql.types import LongType, StringType\n",
22 | "from pyspark.sql.functions import col\n",
23 | "\n",
24 | "\n",
25 | "class Person(Schema):\n",
26 | " id: Column[LongType]\n",
27 | " name: Column[StringType]\n",
28 | " age: Column[LongType]\n",
29 | "\n",
30 | "\n",
31 | "def birthday(df: DataSet[Person]) -> DataSet[Person]:\n",
32 | " return DataSet[Person](\n",
33 | " df.withColumn(\"age\", col(\"age\") + 1),\n",
34 | " )"
35 | ]
36 | },
37 | {
38 | "attachments": {},
39 | "cell_type": "markdown",
40 | "id": "a03a164a",
41 | "metadata": {},
42 | "source": [
43 | "We can replace this with:"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "id": "1e6fa606",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "def birthday(df: DataSet[Person]) -> DataSet[Person]:\n",
54 | " return DataSet[Person](\n",
55 | " df.withColumn(Person.age.str, Person.age + 1),\n",
56 | " )"
57 | ]
58 | },
59 | {
60 | "attachments": {},
61 | "cell_type": "markdown",
62 | "id": "4232e8ce",
63 | "metadata": {},
64 | "source": [
65 | "Which allows:\n",
66 | "\n",
67 | "* Autocomplete of column names during coding\n",
68 | "* Easy refactoring of column names\n",
69 | "\n",
70 | "Note that we have two options when using schema attributes:\n",
71 | "\n",
72 | "* `Person.age`, which is similar to a Spark `Column` object (i.e. `col(\"age\")`)\n",
73 | "* `Person.age.str`, which is just the column name (i.e. `\"age\"`)\n",
74 | "\n",
75 | "It is usually fairly obvious which one to use. For instance, in the above example, `withColumn()` expects a string as the first argument and a `Column` object as the second argument."
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "id": "4a56cffb",
81 | "metadata": {},
82 | "source": []
83 | }
84 | ],
85 | "metadata": {
86 | "kernelspec": {
87 | "display_name": "typedspark",
88 | "language": "python",
89 | "name": "python3"
90 | },
91 | "language_info": {
92 | "codemirror_mode": {
93 | "name": "ipython",
94 | "version": 3
95 | },
96 | "file_extension": ".py",
97 | "mimetype": "text/x-python",
98 | "name": "python",
99 | "nbconvert_exporter": "python",
100 | "pygments_lexer": "ipython3",
101 | "version": "3.11.9"
102 | }
103 | },
104 | "nbformat": 4,
105 | "nbformat_minor": 5
106 | }
107 |
--------------------------------------------------------------------------------
/docs/source/structtype_columns.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "72077ffb",
7 | "metadata": {},
8 | "source": [
9 | "# StructType Columns\n",
10 | "\n",
11 | "## The basics\n",
12 | "\n",
13 | "We can define `StructType` columns in `typedspark` as follows:"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "id": "934c2a2d",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from pyspark.sql.types import IntegerType, StringType\n",
24 | "from typedspark import DataSet, StructType, Schema, Column\n",
25 | "\n",
26 | "\n",
27 | "class Values(Schema):\n",
28 | " name: Column[StringType]\n",
29 | " severity: Column[IntegerType]\n",
30 | "\n",
31 | "\n",
32 | "class Actions(Schema):\n",
33 | " consequeces: Column[StructType[Values]]"
34 | ]
35 | },
36 | {
37 | "attachments": {},
38 | "cell_type": "markdown",
39 | "id": "a6cd9cb4",
40 | "metadata": {},
41 | "source": [
42 | "We can get auto-complete (and refactoring) of the sub-columns by using:"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "id": "89b3f661",
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "def get_high_severity_actions(df: DataSet[Actions]) -> DataSet[Actions]:\n",
53 | " return df.filter(Actions.consequeces.dtype.schema.severity > 5)"
54 | ]
55 | },
56 | {
57 | "attachments": {},
58 | "cell_type": "markdown",
59 | "id": "276c6561",
60 | "metadata": {},
61 | "source": [
62 | "## Transform to schema\n",
63 | "\n",
64 | "You can use the following syntax to add `StructType` columns in `transform_to_schema()`."
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "e0f4da52",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "from pyspark.sql import SparkSession\n",
75 | "\n",
76 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
77 | "spark.sparkContext.setLogLevel(\"ERROR\")"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 4,
83 | "id": "cf40dcde",
84 | "metadata": {},
85 | "outputs": [
86 | {
87 | "name": "stdout",
88 | "output_type": "stream",
89 | "text": [
90 | "+-----------+\n",
91 | "|consequeces|\n",
92 | "+-----------+\n",
93 | "| {a, 1}|\n",
94 | "| {b, 2}|\n",
95 | "| {c, 3}|\n",
96 | "+-----------+\n",
97 | "\n"
98 | ]
99 | }
100 | ],
101 | "source": [
102 | "from typedspark import create_partially_filled_dataset, transform_to_schema, structtype_column\n",
103 | "\n",
104 | "\n",
105 | "class Input(Schema):\n",
106 | " a: Column[StringType]\n",
107 | " b: Column[IntegerType]\n",
108 | "\n",
109 | "\n",
110 | "df = create_partially_filled_dataset(\n",
111 | " spark,\n",
112 | " Input,\n",
113 | " {\n",
114 | " Input.a: [\"a\", \"b\", \"c\"],\n",
115 | " Input.b: [1, 2, 3],\n",
116 | " },\n",
117 | ")\n",
118 | "\n",
119 | "transform_to_schema(\n",
120 | " df,\n",
121 | " Actions,\n",
122 | " {\n",
123 | " Actions.consequeces: structtype_column(\n",
124 | " Actions.consequeces.dtype.schema,\n",
125 | " {\n",
126 | " Actions.consequeces.dtype.schema.name: Input.a,\n",
127 | " Actions.consequeces.dtype.schema.severity: Input.b,\n",
128 | " },\n",
129 | " )\n",
130 | " },\n",
131 | ").show()"
132 | ]
133 | },
134 | {
135 | "attachments": {},
136 | "cell_type": "markdown",
137 | "id": "271da325",
138 | "metadata": {},
139 | "source": [
140 | "Note that just like in `transform_to_schema()`, the `transformations` dictionary in `structtype_column(..., transformations)` requires columns with unique names as keys.\n",
141 | "\n",
142 | "## Generating DataSets\n",
143 | "\n",
144 | "We can generate `DataSets` with `StructType` columns as follows:"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "id": "3ccb56ff",
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "name": "stdout",
155 | "output_type": "stream",
156 | "text": [
157 | "+-----------+\n",
158 | "|consequeces|\n",
159 | "+-----------+\n",
160 | "| {NULL, 1}|\n",
161 | "| {NULL, 2}|\n",
162 | "| {NULL, 3}|\n",
163 | "+-----------+\n",
164 | "\n"
165 | ]
166 | }
167 | ],
168 | "source": [
169 | "from typedspark import create_partially_filled_dataset\n",
170 | "\n",
171 | "values = create_partially_filled_dataset(\n",
172 | " spark,\n",
173 | " Values,\n",
174 | " {\n",
175 | " Values.severity: [1, 2, 3],\n",
176 | " },\n",
177 | ")\n",
178 | "\n",
179 | "actions = create_partially_filled_dataset(\n",
180 | " spark,\n",
181 | " Actions,\n",
182 | " {\n",
183 | " Actions.consequeces: values.collect(),\n",
184 | " },\n",
185 | ")\n",
186 | "actions.show()"
187 | ]
188 | },
189 | {
190 | "attachments": {},
191 | "cell_type": "markdown",
192 | "id": "8a2232b3",
193 | "metadata": {},
194 | "source": [
195 | "Or in row-wise format:"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": 6,
201 | "id": "04a8e9d7",
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "+-----------+\n",
209 | "|consequeces|\n",
210 | "+-----------+\n",
211 | "| {a, 1}|\n",
212 | "| {b, 2}|\n",
213 | "| {c, 3}|\n",
214 | "+-----------+\n",
215 | "\n"
216 | ]
217 | }
218 | ],
219 | "source": [
220 | "from typedspark import create_structtype_row\n",
221 | "\n",
222 | "create_partially_filled_dataset(\n",
223 | " spark,\n",
224 | " Actions,\n",
225 | " [\n",
226 | " {\n",
227 | " Actions.consequeces: create_structtype_row(\n",
228 | " Values, {Values.name: \"a\", Values.severity: 1}\n",
229 | " ),\n",
230 | " },\n",
231 | " {\n",
232 | " Actions.consequeces: create_structtype_row(\n",
233 | " Values, {Values.name: \"b\", Values.severity: 2}\n",
234 | " ),\n",
235 | " },\n",
236 | " {\n",
237 | " Actions.consequeces: create_structtype_row(\n",
238 | " Values, {Values.name: \"c\", Values.severity: 3}\n",
239 | " ),\n",
240 | " },\n",
241 | " ],\n",
242 | ").show()"
243 | ]
244 | },
245 | {
246 | "attachments": {},
247 | "cell_type": "markdown",
248 | "id": "33d83e5f",
249 | "metadata": {},
250 | "source": []
251 | },
252 | {
253 | "attachments": {},
254 | "cell_type": "markdown",
255 | "id": "0f3ab181",
256 | "metadata": {},
257 | "source": []
258 | }
259 | ],
260 | "metadata": {
261 | "kernelspec": {
262 | "display_name": "typedspark",
263 | "language": "python",
264 | "name": "python3"
265 | },
266 | "language_info": {
267 | "codemirror_mode": {
268 | "name": "ipython",
269 | "version": 3
270 | },
271 | "file_extension": ".py",
272 | "mimetype": "text/x-python",
273 | "name": "python",
274 | "nbconvert_exporter": "python",
275 | "pygments_lexer": "ipython3",
276 | "version": "3.11.9"
277 | }
278 | },
279 | "nbformat": 4,
280 | "nbformat_minor": 5
281 | }
282 |
--------------------------------------------------------------------------------
/docs/source/structtypes_in_notebooks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "a006aaea",
7 | "metadata": {},
8 | "source": [
9 | "# Handling StructType columns in notebooks\n",
10 | "\n",
11 | "First, let us make some example data again."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "id": "5442cec0",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "from pyspark.sql import SparkSession\n",
22 | "\n",
23 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
24 | "spark.sparkContext.setLogLevel(\"ERROR\")"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "id": "420da6e1",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "from typedspark import Schema, Column, StructType, create_partially_filled_dataset, load_table\n",
35 | "from pyspark.sql.types import IntegerType\n",
36 | "\n",
37 | "\n",
38 | "class Values(Schema):\n",
39 | " a: Column[IntegerType]\n",
40 | " b: Column[IntegerType]\n",
41 | "\n",
42 | "\n",
43 | "class Container(Schema):\n",
44 | " values: Column[StructType[Values]]\n",
45 | "\n",
46 | "\n",
47 | "create_partially_filled_dataset(\n",
48 | " spark,\n",
49 | " Container,\n",
50 | " {\n",
51 | " Container.values: create_partially_filled_dataset(\n",
52 | " spark,\n",
53 | " Values,\n",
54 | " {Values.a: [1, 2, 3]},\n",
55 | " ).collect(),\n",
56 | " },\n",
57 | ").createOrReplaceTempView(\"structtype_table\")\n",
58 | "\n",
59 | "container, ContainerSchema = load_table(spark, \"structtype_table\", \"Container\")"
60 | ]
61 | },
62 | {
63 | "attachments": {},
64 | "cell_type": "markdown",
65 | "id": "67cdf490",
66 | "metadata": {},
67 | "source": [
68 | "Like before, we can show the schema simply by running:"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "id": "34205e88",
75 | "metadata": {},
76 | "outputs": [
77 | {
78 | "data": {
79 | "text/plain": [
80 | "\n",
81 | "from pyspark.sql.types import IntegerType\n",
82 | "\n",
83 | "from typedspark import Column, Schema, StructType\n",
84 | "\n",
85 | "\n",
86 | "class Container(Schema):\n",
87 | " values: Column[StructType[Values]]"
88 | ]
89 | },
90 | "execution_count": 3,
91 | "metadata": {},
92 | "output_type": "execute_result"
93 | }
94 | ],
95 | "source": [
96 | "ContainerSchema"
97 | ]
98 | },
99 | {
100 | "attachments": {},
101 | "cell_type": "markdown",
102 | "id": "7a6fac57",
103 | "metadata": {},
104 | "source": [
105 | "We can show the `StructType` schema using:"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 4,
111 | "id": "e88e06a9",
112 | "metadata": {},
113 | "outputs": [
114 | {
115 | "data": {
116 | "text/plain": [
117 | "\n",
118 | "from pyspark.sql.types import IntegerType\n",
119 | "\n",
120 | "from typedspark import Column, Schema\n",
121 | "\n",
122 | "\n",
123 | "class Values(Schema):\n",
124 | " a: Column[IntegerType]\n",
125 | " b: Column[IntegerType]"
126 | ]
127 | },
128 | "execution_count": 4,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | }
132 | ],
133 | "source": [
134 | "ContainerSchema.values.dtype.schema"
135 | ]
136 | },
137 | {
138 | "attachments": {},
139 | "cell_type": "markdown",
140 | "id": "f51beb47",
141 | "metadata": {},
142 | "source": [
143 | "We can also use this in queries, for example:"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 5,
149 | "id": "397c575b",
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "name": "stdout",
154 | "output_type": "stream",
155 | "text": [
156 | "+---------+\n",
157 | "| values|\n",
158 | "+---------+\n",
159 | "|{2, NULL}|\n",
160 | "|{3, NULL}|\n",
161 | "+---------+\n",
162 | "\n"
163 | ]
164 | }
165 | ],
166 | "source": [
167 | "container.filter(ContainerSchema.values.dtype.schema.a > 1).show()"
168 | ]
169 | },
170 | {
171 | "attachments": {},
172 | "cell_type": "markdown",
173 | "id": "18ea295c",
174 | "metadata": {},
175 | "source": []
176 | }
177 | ],
178 | "metadata": {
179 | "kernelspec": {
180 | "display_name": "typedspark",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.11.9"
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 5
199 | }
200 |
--------------------------------------------------------------------------------
/docs/source/subclass_column_meta.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Defining your own ColumnMeta attributes\n",
8 | "\n",
9 | "In this notebook, we will see how to define your own `ColumnMeta` attributes. This is useful when you want to add some metadata to your columns that are not already defined in the `ColumnMeta` class."
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "text/plain": [
20 | "{'id': {'comment': 'Identifies the person', 'primary_key': True},\n",
21 | " 'name': {},\n",
22 | " 'age': {}}"
23 | ]
24 | },
25 | "execution_count": 1,
26 | "metadata": {},
27 | "output_type": "execute_result"
28 | }
29 | ],
30 | "source": [
31 | "from dataclasses import dataclass\n",
32 | "from typing import Annotated\n",
33 | "from pyspark.sql.types import LongType, StringType\n",
34 | "from typedspark import ColumnMeta, Schema\n",
35 | "from typedspark._core.column import Column\n",
36 | "\n",
37 | "\n",
38 | "@dataclass\n",
39 | "class MyColumnMeta(ColumnMeta):\n",
40 | " primary_key: bool = False\n",
41 | "\n",
42 | "\n",
43 | "class Persons(Schema):\n",
44 | " id: Annotated[\n",
45 | " Column[LongType],\n",
46 | " MyColumnMeta(\n",
47 | " comment=\"Identifies the person\",\n",
48 | " primary_key=True,\n",
49 | " ),\n",
50 | " ]\n",
51 | " name: Column[StringType]\n",
52 | " age: Column[LongType]\n",
53 | "\n",
54 | "\n",
55 | "Persons.get_metadata()"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": []
62 | }
63 | ],
64 | "metadata": {
65 | "kernelspec": {
66 | "display_name": "typedspark",
67 | "language": "python",
68 | "name": "python3"
69 | },
70 | "language_info": {
71 | "codemirror_mode": {
72 | "name": "ipython",
73 | "version": 3
74 | },
75 | "file_extension": ".py",
76 | "mimetype": "text/x-python",
77 | "name": "python",
78 | "nbconvert_exporter": "python",
79 | "pygments_lexer": "ipython3",
80 | "version": "3.11.2"
81 | }
82 | },
83 | "nbformat": 4,
84 | "nbformat_minor": 2
85 | }
86 |
--------------------------------------------------------------------------------
/docs/source/subclassing_schemas.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "76ffac97",
7 | "metadata": {},
8 | "source": [
9 | "# Subclassing schemas\n",
10 | "Subclassing schemas is a useful pattern for pipelines where every next function adds a few columns."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "7725965f",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from typedspark import Column, Schema, DataSet\n",
21 | "from pyspark.sql.types import LongType, StringType\n",
22 | "from pyspark.sql.functions import lit\n",
23 | "\n",
24 | "\n",
25 | "class Person(Schema):\n",
26 | " id: Column[LongType]\n",
27 | " name: Column[StringType]\n",
28 | "\n",
29 | "\n",
30 | "class PersonWithAge(Person):\n",
31 | " age: Column[LongType]\n",
32 | "\n",
33 | "\n",
34 | "def foo(df: DataSet[Person]) -> DataSet[PersonWithAge]:\n",
35 | " return DataSet[PersonWithAge](\n",
36 | " df.withColumn(PersonWithAge.age, lit(42)),\n",
37 | " )"
38 | ]
39 | },
40 | {
41 | "attachments": {},
42 | "cell_type": "markdown",
43 | "id": "cc265385",
44 | "metadata": {},
45 | "source": [
46 | "Similarly, you can use this pattern when merging (or joining or concatenating) two datasets together."
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 2,
52 | "id": "c0d6b6a6",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "class PersonA(Schema):\n",
57 | " id: Column[LongType]\n",
58 | " name: Column[StringType]\n",
59 | "\n",
60 | "\n",
61 | "class PersonB(Schema):\n",
62 | " id: Column[LongType]\n",
63 | " age: Column\n",
64 | "\n",
65 | "\n",
66 | "class PersonAB(PersonA, PersonB):\n",
67 | " pass\n",
68 | "\n",
69 | "\n",
70 | "def foo(df_a: DataSet[PersonA], df_b: DataSet[PersonB]) -> DataSet[PersonAB]:\n",
71 | " return DataSet[PersonAB](\n",
72 | " df_a.join(df_b, PersonAB.id),\n",
73 | " )"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "id": "b24f0e09",
79 | "metadata": {},
80 | "source": []
81 | }
82 | ],
83 | "metadata": {
84 | "kernelspec": {
85 | "display_name": "typedspark",
86 | "language": "python",
87 | "name": "python3"
88 | },
89 | "language_info": {
90 | "codemirror_mode": {
91 | "name": "ipython",
92 | "version": 3
93 | },
94 | "file_extension": ".py",
95 | "mimetype": "text/x-python",
96 | "name": "python",
97 | "nbconvert_exporter": "python",
98 | "pygments_lexer": "ipython3",
99 | "version": "3.11.9"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 5
104 | }
105 |
--------------------------------------------------------------------------------
/docs/source/transforming_datasets.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "ab6cce16",
7 | "metadata": {},
8 | "source": [
9 | "# Transforming a DataSet to another schema\n",
10 | "\n",
11 | "## The basics\n",
12 | "\n",
13 | "We often come across the following pattern:"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 1,
19 | "id": "44b1c3bb",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "from pyspark.sql.types import IntegerType, StringType\n",
24 | "from typedspark import Column, Schema, DataSet\n",
25 | "\n",
26 | "\n",
27 | "class Person(Schema):\n",
28 | " name: Column[StringType]\n",
29 | " job_id: Column[IntegerType]\n",
30 | "\n",
31 | "\n",
32 | "class Job(Schema):\n",
33 | " id: Column[IntegerType]\n",
34 | " function: Column[StringType]\n",
35 | " hourly_rate: Column[IntegerType]\n",
36 | "\n",
37 | "\n",
38 | "class PersonWithJob(Person, Job):\n",
39 | " id: Column[IntegerType]\n",
40 | " name: Column[StringType]\n",
41 | " job_name: Column[StringType]\n",
42 | " rate: Column[IntegerType]\n",
43 | "\n",
44 | "\n",
45 | "def get_plumbers(persons: DataSet[Person], jobs: DataSet[Job]) -> DataSet[PersonWithJob]:\n",
46 | " return DataSet[PersonWithJob](\n",
47 | " jobs.filter(Job.function == \"plumber\")\n",
48 | " .join(persons, Job.id == Person.job_id)\n",
49 | " .withColumn(PersonWithJob.job_name.str, Job.function)\n",
50 | " .withColumn(PersonWithJob.rate.str, Job.hourly_rate)\n",
51 | " .select(*PersonWithJob.all_column_names())\n",
52 | " )"
53 | ]
54 | },
55 | {
56 | "attachments": {},
57 | "cell_type": "markdown",
58 | "id": "2ba1ded0",
59 | "metadata": {},
60 | "source": [
61 | "We can make that quite a bit more condensed:"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 2,
67 | "id": "a08d2a1f",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "from typedspark import transform_to_schema\n",
72 | "\n",
73 | "\n",
74 | "def get_plumbers(persons: DataSet[Person], jobs: DataSet[Job]) -> DataSet[PersonWithJob]:\n",
75 | " return transform_to_schema(\n",
76 | " jobs.filter(\n",
77 | " Job.function == \"plumber\",\n",
78 | " ).join(\n",
79 | " persons,\n",
80 | " Job.id == Person.job_id,\n",
81 | " ),\n",
82 | " PersonWithJob,\n",
83 | " {\n",
84 | " PersonWithJob.job_name: Job.function,\n",
85 | " PersonWithJob.rate: Job.hourly_rate,\n",
86 | " },\n",
87 | " )"
88 | ]
89 | },
90 | {
91 | "attachments": {},
92 | "cell_type": "markdown",
93 | "id": "98d00a0f",
94 | "metadata": {},
95 | "source": [
96 | "Specifically, `transform_to_schema()` has the following benefits:\n",
97 | "\n",
98 | "* No more need to cast every return statement using `DataSet[Schema](...)`\n",
99 | "* No more need to drop the columns that are not in the schema using `select(*Schema.all_column_names())`\n",
100 | "* Less verbose syntax compared to `.withColumn(...)`\n",
101 | "\n",
102 | "## Unique keys required\n",
103 | "\n",
104 | "The `transformations` dictionary in `transform_to_schema(..., transformations)` requires columns with unique names as keys. The following pattern will throw an exception."
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 3,
110 | "id": "916f8122",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "from pyspark.sql import SparkSession\n",
115 | "\n",
116 | "spark = SparkSession.Builder().config(\"spark.ui.showConsoleProgress\", \"false\").getOrCreate()\n",
117 | "spark.sparkContext.setLogLevel(\"ERROR\")"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 4,
123 | "id": "756995ae",
124 | "metadata": {},
125 | "outputs": [
126 | {
127 | "name": "stdout",
128 | "output_type": "stream",
129 | "text": [
130 | "[CANNOT_CONVERT_COLUMN_INTO_BOOL] Cannot convert column into bool: please use '&' for 'and', '|' for 'or', '~' for 'not' when building DataFrame boolean expressions.\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "from typedspark import create_partially_filled_dataset\n",
136 | "\n",
137 | "df = create_partially_filled_dataset(spark, Job, {Job.hourly_rate: [10, 20, 30]})\n",
138 | "\n",
139 | "try:\n",
140 | " transform_to_schema(\n",
141 | " df,\n",
142 | " Job,\n",
143 | " {\n",
144 | " Job.hourly_rate: Job.hourly_rate + 3,\n",
145 | " Job.hourly_rate: Job.hourly_rate * 2,\n",
146 | " },\n",
147 | " )\n",
148 | "except ValueError as e:\n",
149 | " print(e)"
150 | ]
151 | },
152 | {
153 | "attachments": {},
154 | "cell_type": "markdown",
155 | "id": "67b9285c",
156 | "metadata": {},
157 | "source": [
158 | "Instead, use one line per column"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 5,
164 | "id": "46c8833f",
165 | "metadata": {},
166 | "outputs": [
167 | {
168 | "name": "stdout",
169 | "output_type": "stream",
170 | "text": [
171 | "+----+--------+-----------+\n",
172 | "| id|function|hourly_rate|\n",
173 | "+----+--------+-----------+\n",
174 | "|NULL| NULL| 26|\n",
175 | "|NULL| NULL| 46|\n",
176 | "|NULL| NULL| 66|\n",
177 | "+----+--------+-----------+\n",
178 | "\n"
179 | ]
180 | }
181 | ],
182 | "source": [
183 | "transform_to_schema(\n",
184 | " df,\n",
185 | " Job,\n",
186 | " {\n",
187 | " Job.hourly_rate: (Job.hourly_rate + 3) * 2,\n",
188 | " },\n",
189 | ").show()"
190 | ]
191 | },
192 | {
193 | "attachments": {},
194 | "cell_type": "markdown",
195 | "id": "6546f023",
196 | "metadata": {},
197 | "source": []
198 | }
199 | ],
200 | "metadata": {
201 | "kernelspec": {
202 | "display_name": "typedspark",
203 | "language": "python",
204 | "name": "python3"
205 | },
206 | "language_info": {
207 | "codemirror_mode": {
208 | "name": "ipython",
209 | "version": 3
210 | },
211 | "file_extension": ".py",
212 | "mimetype": "text/x-python",
213 | "name": "python",
214 | "nbconvert_exporter": "python",
215 | "pygments_lexer": "ipython3",
216 | "version": "3.11.9"
217 | }
218 | },
219 | "nbformat": 4,
220 | "nbformat_minor": 5
221 | }
222 |
--------------------------------------------------------------------------------
/docs/videos/ide.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# IDE demo\n",
8 | "\n",
9 | "This notebook contains the code accompanying the IDE demo. You can find the video [here](\n",
10 | "https://github.com/kaiko-ai/typedspark/assets/47976799/e6f7fa9c-6d14-4f68-baba-fe3c22f75b67)."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from pyspark.sql.types import DateType, LongType, StringType\n",
20 | "from typedspark import DataSet, Column, Schema\n",
21 | "\n",
22 | "\n",
23 | "class Pets(Schema):\n",
24 | " pet_id: Column[LongType]\n",
25 | " owner_id: Column[LongType]\n",
26 | " pet_name: Column[StringType]\n",
27 | " species: Column[StringType]\n",
28 | " breed: Column[StringType]\n",
29 | " age: Column[LongType]\n",
30 | " birthdate: Column[DateType]\n",
31 | " gender: Column[StringType]\n",
32 | "\n",
33 | "\n",
34 | "class Vaccinations(Schema):\n",
35 | " vaccination_id: Column[LongType]\n",
36 | " pet_id: Column[LongType]\n",
37 | " vaccine_name: Column[StringType]\n",
38 | " vaccine_date: Column[DateType]\n",
39 | " next_due_date: Column[DateType]\n",
40 | "\n",
41 | "\n",
42 | "class Owners(Schema):\n",
43 | " owner_id: Column[LongType]\n",
44 | " first_name: Column[StringType]\n",
45 | " last_name: Column[StringType]\n",
46 | " email: Column[StringType]\n",
47 | " phone_number: Column[StringType]\n",
48 | " address: Column[StringType]"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 2,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "def get_dogs(pets: DataSet[Pets]) -> DataSet[Pets]:\n",
58 | " return pets.filter(Pets.species == \"dog\")"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 3,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "from chispa.dataframe_comparer import assert_df_equality\n",
68 | "from pyspark.sql import SparkSession\n",
69 | "from typedspark import create_partially_filled_dataset\n",
70 | "\n",
71 | "\n",
72 | "def test_get_dogs(spark: SparkSession):\n",
73 | " pets = create_partially_filled_dataset(\n",
74 | " spark,\n",
75 | " Pets,\n",
76 | " {\n",
77 | " Pets.pet_id: [1, 2, 3],\n",
78 | " Pets.species: [\"dog\", \"cat\", \"dog\"],\n",
79 | " },\n",
80 | " )\n",
81 | "\n",
82 | " observed = get_dogs(pets)\n",
83 | " expected = create_partially_filled_dataset(\n",
84 | " spark,\n",
85 | " Pets,\n",
86 | " {\n",
87 | " Pets.pet_id: [1, 3],\n",
88 | " Pets.species: [\"dog\", \"dog\"],\n",
89 | " },\n",
90 | " )\n",
91 | "\n",
92 | " assert_df_equality(\n",
93 | " observed,\n",
94 | " expected,\n",
95 | " ignore_row_order=True,\n",
96 | " ignore_nullable=True,\n",
97 | " )"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 4,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "from pyspark.sql.functions import concat_ws\n",
107 | "from typedspark import (\n",
108 | " register_schema_to_dataset,\n",
109 | " transform_to_schema,\n",
110 | ")\n",
111 | "\n",
112 | "\n",
113 | "class Reminder(Schema):\n",
114 | " owner_id: Column[LongType]\n",
115 | " pet_id: Column[LongType]\n",
116 | " vaccination_id: Column[LongType]\n",
117 | " full_name: Column[StringType]\n",
118 | " email_address: Column[StringType]\n",
119 | " pet_name: Column[StringType]\n",
120 | " vaccine: Column[StringType]\n",
121 | " due: Column[DateType]\n",
122 | "\n",
123 | "\n",
124 | "def find_owners_who_need_to_renew_their_pets_vaccinations(\n",
125 | " owners: DataSet[Owners],\n",
126 | " pets: DataSet[Pets],\n",
127 | " vaccinations: DataSet[Vaccinations],\n",
128 | ") -> DataSet[Reminder]:\n",
129 | " _owners = register_schema_to_dataset(owners, Owners)\n",
130 | " _pets = register_schema_to_dataset(pets, Pets)\n",
131 | " _vaccinations = register_schema_to_dataset(vaccinations, Vaccinations)\n",
132 | "\n",
133 | " return transform_to_schema(\n",
134 | " owners.join(\n",
135 | " pets,\n",
136 | " _owners.owner_id == _pets.owner_id,\n",
137 | " \"inner\",\n",
138 | " ).join(\n",
139 | " vaccinations,\n",
140 | " _pets.pet_id == _vaccinations.pet_id,\n",
141 | " \"inner\",\n",
142 | " ),\n",
143 | " Reminder,\n",
144 | " {\n",
145 | " Reminder.owner_id: _owners.owner_id,\n",
146 | " Reminder.pet_id: _pets.pet_id,\n",
147 | " Reminder.vaccination_id: _vaccinations.vaccination_id,\n",
148 | " Reminder.full_name: concat_ws(\n",
149 | " \" \",\n",
150 | " _owners.first_name,\n",
151 | " _owners.last_name,\n",
152 | " ),\n",
153 | " Reminder.email_address: _owners.email,\n",
154 | " Reminder.pet_name: _pets.pet_name,\n",
155 | " Reminder.vaccine: _vaccinations.vaccine_name,\n",
156 | " Reminder.due: _vaccinations.next_due_date,\n",
157 | " },\n",
158 | " )"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": []
167 | }
168 | ],
169 | "metadata": {
170 | "kernelspec": {
171 | "display_name": "typedspark",
172 | "language": "python",
173 | "name": "python3"
174 | },
175 | "language_info": {
176 | "codemirror_mode": {
177 | "name": "ipython",
178 | "version": 3
179 | },
180 | "file_extension": ".py",
181 | "mimetype": "text/x-python",
182 | "name": "python",
183 | "nbconvert_exporter": "python",
184 | "pygments_lexer": "ipython3",
185 | "version": "3.11.9"
186 | }
187 | },
188 | "nbformat": 4,
189 | "nbformat_minor": 2
190 | }
191 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 100
3 |
4 | [tool.isort]
5 | profile = "black"
6 | line_length = 100
7 |
8 | [tool.pylint]
9 | ignore-paths = ['^tests/.+.py$', 'setup.py', '^docs/.+.py$']
10 |
11 | [tool.pylint."messages control"]
12 | disable = "all"
13 | enable = [
14 | "empty-docstring",
15 | "missing-module-docstring",
16 | "missing-class-docstring",
17 | "missing-function-docstring",
18 | "unidiomatic-typecheck",
19 | "no-else-return",
20 | "consider-using-dict-comprehension",
21 | "dangerous-default-value",
22 | "unspecified-encoding",
23 | "unnecessary-pass",
24 | "redefined-outer-name",
25 | "invalid-name",
26 | "unused-argument",
27 | "redefined-builtin",
28 | "simplifiable-if-expression",
29 | "logging-fstring-interpolation",
30 | "inconsistent-return-statements",
31 | "consider-using-set-comprehension"
32 | ]
33 |
34 | [tool.coverage.run]
35 | source = ["typedspark/"]
36 |
37 | [tool.mypy]
38 | exclude = ['^docs\/.+\.py$']
39 |
40 | [tool.bandit]
41 | exclude_dirs = ["tests"]
--------------------------------------------------------------------------------
/renovate.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json",
3 | "extends": [
4 | "config:base"
5 | ],
6 | "autoApprove": true,
7 | "automerge": true
8 | }
9 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # pyspark
2 | pyspark==4.0.0
3 | # linters
4 | flake8==7.2.0
5 | pylint==3.3.7
6 | black[jupyter]<=25.1.0
7 | bandit==1.8.3
8 | isort==6.0.1
9 | docformatter==1.7.7
10 | mypy==1.16.0
11 | pyright<=1.1.402
12 | autoflake==2.3.1
13 | # stubs
14 | pandas-stubs<=2.2.3.250527
15 | types-setuptools==80.9.0.20250529
16 | # testing
17 | pytest==8.4.0
18 | coverage==7.8.2
19 | pandas==2.3.0
20 | setuptools==80.9.0
21 | chispa==0.11.1
22 | # notebooks
23 | nbconvert==7.16.6
24 | jupyter==1.1.1
25 | nbformat==5.10.4
26 | # readthedocs
27 | sphinx<=8.2.3
28 | sphinx-rtd-theme==3.0.2
29 | nbsphinx==0.9.7
30 | # precommit hook
31 | pre-commit==4.2.0
32 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | typing-extensions<=4.14.0
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | def get_requirements():
5 | with open("requirements.txt") as f:
6 | return f.read().splitlines()
7 |
8 |
9 | def get_long_description():
10 | with open("README.md", encoding="utf-8") as f:
11 | return f.read()
12 |
13 |
14 | setup(
15 | name="typedspark",
16 | url="https://github.com/kaiko-ai/typedspark",
17 | license="Apache-2.0",
18 | author="Nanne Aben",
19 | author_email="nanne@kaiko.ai",
20 | description="Column-wise type annotations for pyspark DataFrames",
21 | keywords="pyspark spark typing type checking annotations",
22 | long_description=get_long_description(),
23 | long_description_content_type="text/markdown",
24 | packages=find_packages(include=["typedspark", "typedspark.*"]),
25 | install_requires=get_requirements(),
26 | python_requires=">=3.9.0",
27 | classifiers=["Programming Language :: Python", "Typing :: Typed"],
28 | setuptools_git_versioning={"enabled": True},
29 | setup_requires=["setuptools-git-versioning>=2.0,<3"],
30 | package_data={"typedspark": ["py.typed"]},
31 | extras_require={
32 | "pyspark": ["pyspark"],
33 | },
34 | )
35 |
--------------------------------------------------------------------------------
/tests/_core/test_column.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Annotated
3 |
4 | import pandas as pd
5 | import pytest
6 | from pyspark.sql import SparkSession
7 | from pyspark.sql.functions import lit
8 | from pyspark.sql.types import LongType, StringType
9 |
10 | from typedspark import Column, ColumnMeta, Schema
11 | from typedspark._utils.create_dataset import create_partially_filled_dataset
12 |
13 |
14 | class A(Schema):
15 | a: Column[LongType]
16 | b: Column[StringType]
17 |
18 |
19 | def test_column(spark: SparkSession):
20 | (
21 | spark.createDataFrame(
22 | pd.DataFrame(
23 | dict(
24 | a=[1, 2, 3],
25 | )
26 | )
27 | )
28 | .filter(A.a == 1)
29 | .withColumn(A.b.str, lit("a"))
30 | )
31 |
32 |
33 | def test_column_doesnt_exist():
34 | with pytest.raises(TypeError):
35 | A.z
36 |
37 |
38 | @pytest.mark.no_spark_session
39 | def test_column_reference_without_spark_session():
40 | a = A.a
41 | assert a.str == "a"
42 |
43 |
44 | def test_column_with_deprecated_dataframe_param(spark: SparkSession):
45 | df = create_partially_filled_dataset(spark, A, {A.a: [1, 2, 3]})
46 | Column("a", dataframe=df)
47 |
48 |
49 | @dataclass
50 | class MyColumnMeta(ColumnMeta):
51 | primary_key: bool = False
52 |
53 |
54 | class Persons(Schema):
55 | id: Annotated[
56 | Column[LongType],
57 | MyColumnMeta(
58 | comment="Identifies the person",
59 | primary_key=True,
60 | ),
61 | ]
62 | name: Column[StringType]
63 | age: Column[LongType]
64 |
65 |
66 | def test_get_metadata():
67 | assert Persons.get_metadata()["id"] == {
68 | "comment": "Identifies the person",
69 | "primary_key": True,
70 | }
71 |
--------------------------------------------------------------------------------
/tests/_core/test_dataset.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import pandas as pd
4 | import pytest
5 | from pyspark import StorageLevel
6 | from pyspark.sql import DataFrame, SparkSession
7 | from pyspark.sql.types import LongType, StringType
8 |
9 | from typedspark import Column, DataSet, Schema
10 | from typedspark._core.dataset import DataSetImplements
11 | from typedspark._utils.create_dataset import create_empty_dataset
12 |
13 |
14 | class A(Schema):
15 | a: Column[LongType]
16 | b: Column[StringType]
17 |
18 |
19 | class B(Schema):
20 | a: Column[LongType]
21 | b: Column[StringType]
22 |
23 |
24 | def create_dataframe(spark: SparkSession, d):
25 | return spark.createDataFrame(pd.DataFrame(d))
26 |
27 |
28 | def test_dataset(spark: SparkSession):
29 | d = dict(
30 | a=[1, 2, 3],
31 | b=["a", "b", "c"],
32 | )
33 | df = create_dataframe(spark, d)
34 | DataSet[A](df)
35 |
36 |
37 | def test_dataset_allow_underscored_columns_not_in_schema(spark: SparkSession):
38 | d = {"a": [1, 2, 3], "b": ["a", "b", "c"], "__c": [1, 2, 3]}
39 | df = create_dataframe(spark, d)
40 | DataSet[A](df)
41 |
42 |
43 | def test_dataset_single_underscored_column_should_raise(spark: SparkSession):
44 | d = {"a": [1, 2, 3], "b": ["a", "b", "c"], "_c": [1, 2, 3]}
45 | df = create_dataframe(spark, d)
46 | with pytest.raises(TypeError):
47 | DataSet[A](df)
48 |
49 |
50 | def test_dataset_missing_colnames(spark: SparkSession):
51 | d = dict(
52 | a=[1, 2, 3],
53 | )
54 | df = create_dataframe(spark, d)
55 | with pytest.raises(TypeError):
56 | DataSet[A](df)
57 |
58 |
59 | def test_dataset_too_many_colnames(spark: SparkSession):
60 | d = dict(
61 | a=[1, 2, 3],
62 | b=["a", "b", "c"],
63 | c=[1, 2, 3],
64 | )
65 | df = create_dataframe(spark, d)
66 | with pytest.raises(TypeError):
67 | DataSet[A](df)
68 |
69 |
70 | def test_wrong_type(spark: SparkSession):
71 | d = dict(
72 | a=[1, 2, 3],
73 | b=[1, 2, 3],
74 | )
75 | df = create_dataframe(spark, d)
76 | with pytest.raises(TypeError):
77 | DataSet[A](df)
78 |
79 |
80 | def test_inherited_functions(spark: SparkSession):
81 | df = create_empty_dataset(spark, A)
82 |
83 | df.distinct()
84 | cached1: DataSet[A] = df.cache()
85 | cached2: DataSet[A] = df.persist(StorageLevel.MEMORY_AND_DISK)
86 | assert isinstance(df.filter(A.a == 1), DataSet)
87 | assert isinstance(df.where(A.a == 1), DataSet)
88 | df.orderBy(A.a)
89 | df.transform(lambda df: df)
90 |
91 | cached1.unpersist(True)
92 | cached2.unpersist(True)
93 |
94 |
95 | def test_inherited_functions_with_other_dataset(spark: SparkSession):
96 | df_a = create_empty_dataset(spark, A)
97 | df_b = create_empty_dataset(spark, A)
98 |
99 | df_a.join(df_b, A.a.str)
100 | df_a.unionByName(df_b)
101 |
102 |
103 | def test_schema_property_of_dataset(spark: SparkSession):
104 | df = create_empty_dataset(spark, A)
105 | assert df.typedspark_schema == A
106 |
107 |
108 | def test_initialize_dataset_implements(spark: SparkSession):
109 | with pytest.raises(NotImplementedError):
110 | DataSetImplements()
111 |
112 |
113 | def test_reduce(spark: SparkSession):
114 | functools.reduce(
115 | DataSet.unionByName,
116 | [create_empty_dataset(spark, A), create_empty_dataset(spark, A)],
117 | )
118 |
119 |
120 | def test_resetting_of_schema_annotations(spark: SparkSession):
121 | df = create_empty_dataset(spark, A)
122 |
123 | a: DataFrame
124 |
125 | # if no schema is specified, the annotation should be None
126 | a = DataSet(df)
127 | assert a._schema_annotations is None
128 |
129 | # when we specify a schema, the class variable will be set to A, but afterwards it should be
130 | # reset to None again when we initialize a new object without specifying a schema
131 | DataSet[A]
132 | a = DataSet(df)
133 | assert a._schema_annotations is None
134 |
135 | # and then to B
136 | a = DataSet[B](df)
137 | assert a._schema_annotations == B
138 |
139 | # and then to None again
140 | a = DataSet(df)
141 | assert a._schema_annotations is None
142 |
--------------------------------------------------------------------------------
/tests/_core/test_datatypes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql import SparkSession
3 | from pyspark.sql.types import ArrayType as SparkArrayType
4 | from pyspark.sql.types import LongType
5 | from pyspark.sql.types import MapType as SparkMapType
6 | from pyspark.sql.types import StringType, StructField
7 | from pyspark.sql.types import StructType as SparkStructType
8 |
9 | from typedspark import ArrayType, Column, DataSet, MapType, Schema, StructType, create_empty_dataset
10 |
11 |
12 | class SubSchema(Schema):
13 | a: Column[StringType]
14 | b: Column[StringType]
15 |
16 |
17 | class Example(Schema):
18 | a: Column[MapType[StringType, StringType]]
19 | b: Column[ArrayType[StringType]]
20 | c: Column[StructType[SubSchema]]
21 |
22 |
23 | def test_complex_datatypes_equals(spark: SparkSession):
24 | df = create_empty_dataset(spark, Example)
25 |
26 | assert df.schema["a"] == StructField("a", SparkMapType(StringType(), StringType()))
27 | assert df.schema["b"] == StructField("b", SparkArrayType(StringType()))
28 | structfields = [
29 | StructField("a", StringType(), True),
30 | StructField("b", StringType(), True),
31 | ]
32 | assert df.schema["c"] == StructField("c", SparkStructType(structfields))
33 |
34 |
35 | class ArrayTypeSchema(Schema):
36 | a: Column[ArrayType[StringType]]
37 |
38 |
39 | class DifferentArrayTypeSchema(Schema):
40 | a: Column[ArrayType[LongType]]
41 |
42 |
43 | class MapTypeSchema(Schema):
44 | a: Column[MapType[StringType, StringType]]
45 |
46 |
47 | class DifferentKeyMapTypeSchema(Schema):
48 | a: Column[MapType[LongType, StringType]]
49 |
50 |
51 | class DifferentValueMapTypeSchema(Schema):
52 | a: Column[MapType[StringType, LongType]]
53 |
54 |
55 | class DifferentSubSchema(Schema):
56 | a: Column[LongType]
57 | b: Column[LongType]
58 |
59 |
60 | class StructTypeSchema(Schema):
61 | a: Column[StructType[SubSchema]]
62 |
63 |
64 | class DifferentStructTypeSchema(Schema):
65 | a: Column[StructType[DifferentSubSchema]]
66 |
67 |
68 | def test_complex_datatypes_not_equals(spark: SparkSession):
69 | with pytest.raises(TypeError):
70 | df1 = create_empty_dataset(spark, ArrayTypeSchema)
71 | DataSet[DifferentArrayTypeSchema](df1)
72 |
73 | df2 = create_empty_dataset(spark, MapTypeSchema)
74 | with pytest.raises(TypeError):
75 | DataSet[DifferentKeyMapTypeSchema](df2)
76 | with pytest.raises(TypeError):
77 | DataSet[DifferentValueMapTypeSchema](df2)
78 |
79 | with pytest.raises(TypeError):
80 | df3 = create_empty_dataset(spark, StructTypeSchema)
81 | DataSet[DifferentStructTypeSchema](df3)
82 |
--------------------------------------------------------------------------------
/tests/_core/test_metadata.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pyspark.sql import SparkSession
4 | from pyspark.sql.types import LongType
5 |
6 | from typedspark import Column, DataSet, Schema, create_empty_dataset
7 | from typedspark._core.column_meta import ColumnMeta
8 |
9 |
10 | class A(Schema):
11 | a: Annotated[Column[LongType], ColumnMeta(comment="test")]
12 | b: Column[LongType]
13 |
14 |
15 | def test_add_schema_metadata(spark: SparkSession):
16 | df: DataSet[A] = create_empty_dataset(spark, A, 1)
17 | assert df.schema["a"].metadata == {"comment": "test"}
18 | assert df.schema["b"].metadata == {}
19 |
20 |
21 | class B(Schema):
22 | a: Column[LongType]
23 | b: Annotated[Column[LongType], ColumnMeta(comment="test")]
24 |
25 |
26 | def test_refresh_metadata(spark: SparkSession):
27 | df_a = create_empty_dataset(spark, A, 1)
28 | df_b = DataSet[B](df_a)
29 | assert df_b.schema["a"].metadata == {}
30 | assert df_b.schema["b"].metadata == {"comment": "test"}
31 |
--------------------------------------------------------------------------------
/tests/_schema/test_create_spark_schema.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql.types import LongType, StringType, StructField, StructType
3 |
4 | from typedspark import Column, Schema
5 |
6 |
7 | class A(Schema):
8 | a: Column[LongType]
9 | b: Column[StringType]
10 |
11 |
12 | class B(Schema):
13 | a: Column
14 |
15 |
16 | def test_create_spark_schema():
17 | result = A.get_structtype()
18 | expected = StructType(
19 | [
20 | StructField("a", LongType(), True),
21 | StructField("b", StringType(), True),
22 | ]
23 | )
24 |
25 | assert result == expected
26 |
27 |
28 | def test_create_spark_schema_with_faulty_schema():
29 | with pytest.raises(TypeError):
30 | B.get_structtype()
31 |
--------------------------------------------------------------------------------
/tests/_schema/test_get_schema_definition.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pyspark.sql.types import IntegerType, StringType
4 |
5 | from typedspark import Column, DayTimeIntervalType, IntervalType, Schema
6 | from typedspark._schema.get_schema_definition import _replace_literal, _replace_literals
7 |
8 |
9 | class A(Schema):
10 | """This is a docstring for A."""
11 |
12 | a: Annotated[Column[IntegerType], "Some column"]
13 | b: Column[StringType]
14 |
15 |
16 | def test_replace_literal():
17 | result = _replace_literal(
18 | "DayTimeIntervalType[Literal[0], Literal[1]]",
19 | replace_literals_in=DayTimeIntervalType,
20 | original="Literal[0]",
21 | replacement="IntervalType.DAY",
22 | )
23 | expected = "DayTimeIntervalType[IntervalType.DAY, Literal[1]]"
24 |
25 | assert result == expected
26 |
27 |
28 | def test_replace_literals():
29 | result = _replace_literals(
30 | "DayTimeIntervalType[Literal[0], Literal[1]]",
31 | replace_literals_in=DayTimeIntervalType,
32 | replace_literals_by=IntervalType,
33 | )
34 | expected = "DayTimeIntervalType[IntervalType.DAY, IntervalType.HOUR]"
35 |
36 | assert result == expected
37 |
38 |
39 | def test_get_schema_definition_as_string():
40 | result = A.get_schema_definition_as_string(include_documentation=True)
41 | expected = '''from typing import Annotated
42 |
43 | from pyspark.sql.types import IntegerType, StringType
44 |
45 | from typedspark import Column, ColumnMeta, Schema
46 |
47 |
48 | class A(Schema):
49 | """This is a docstring for A."""
50 |
51 | a: Annotated[Column[IntegerType], ColumnMeta(comment="Some column")]
52 | b: Annotated[Column[StringType], ColumnMeta(comment="")]
53 | '''
54 | assert result == expected
55 |
--------------------------------------------------------------------------------
/tests/_schema/test_offending_schemas.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, List, Type
2 |
3 | import pytest
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.types import StringType
6 |
7 | from typedspark import ArrayType, Column, ColumnMeta, MapType, Schema, create_empty_dataset
8 | from typedspark._core.datatypes import DecimalType
9 |
10 |
11 | class InvalidColumn(Schema):
12 | a: int
13 |
14 |
15 | class ColumnWithoutType(Schema):
16 | a: Column
17 |
18 |
19 | class AnnotationWithoutColumn(Schema):
20 | a: Annotated # type: ignore
21 |
22 |
23 | class InvalidColumnMeta(Schema):
24 | a: Annotated[StringType, str]
25 |
26 |
27 | class InvalidDataTypeWithinAnnotation(Schema):
28 | a: Annotated[str, ColumnMeta()] # type: ignore
29 |
30 |
31 | class InvalidDataType(Schema):
32 | a: Column[int] # type: ignore
33 |
34 |
35 | class ComplexTypeWithoutSubtype(Schema):
36 | a: Column[ArrayType]
37 |
38 |
39 | class ComplexTypeWithInvalidSubtype(Schema):
40 | a: Column[ArrayType[int]] # type: ignore
41 |
42 |
43 | class InvalidDataTypeWithArguments(Schema):
44 | a: Column[List[str]] # type: ignore
45 |
46 |
47 | class DecimalTypeWithoutArguments(Schema):
48 | a: Column[DecimalType] # type: ignore
49 |
50 |
51 | class DecimalTypeWithIncorrectArguments(Schema):
52 | a: Column[DecimalType[int, int]] # type: ignore
53 |
54 |
55 | offending_schemas: List[Type[Schema]] = [
56 | InvalidColumn,
57 | ColumnWithoutType,
58 | AnnotationWithoutColumn,
59 | InvalidColumnMeta,
60 | InvalidDataTypeWithinAnnotation,
61 | InvalidDataType,
62 | ComplexTypeWithoutSubtype,
63 | ComplexTypeWithInvalidSubtype,
64 | InvalidDataTypeWithArguments,
65 | ]
66 |
67 |
68 | def test_offending_schema_exceptions(spark: SparkSession):
69 | for schema in offending_schemas:
70 | with pytest.raises(TypeError):
71 | create_empty_dataset(spark, schema)
72 |
73 |
74 | def test_offending_schemas_repr_exceptions():
75 | for schema in offending_schemas:
76 | schema.get_schema_definition_as_string(generate_imports=True)
77 |
78 |
79 | def test_offending_schemas_dtype():
80 | with pytest.raises(TypeError):
81 | ColumnWithoutType.a.dtype
82 |
83 |
84 | def test_offending_schemas_runtime_error_on_load():
85 | with pytest.raises(TypeError):
86 |
87 | class WrongNumberOfArguments(Schema):
88 | a: Column[MapType[StringType]] # type: ignore
89 |
--------------------------------------------------------------------------------
/tests/_schema/test_schema.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Literal, Type
2 |
3 | import pytest
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.types import LongType, StringType, StructField, StructType
6 |
7 | import typedspark
8 | from typedspark import Column, ColumnMeta, Schema, create_partially_filled_dataset
9 | from typedspark._core.literaltype import IntervalType
10 | from typedspark._schema.schema import DltKwargs
11 |
12 |
13 | class A(Schema):
14 | a: Column[LongType]
15 | b: Column[StringType]
16 |
17 |
18 | schema_a_string = """
19 | from pyspark.sql.types import LongType, StringType
20 |
21 | from typedspark import Column, Schema
22 |
23 |
24 | class A(Schema):
25 | a: Column[LongType]
26 | b: Column[StringType]
27 | """
28 |
29 | schema_a_string_with_documentation = '''from typing import Annotated
30 |
31 | from pyspark.sql.types import LongType, StringType
32 |
33 | from typedspark import Column, ColumnMeta, Schema
34 |
35 |
36 | class A(Schema):
37 | """Add documentation here."""
38 |
39 | a: Annotated[Column[LongType], ColumnMeta(comment="")]
40 | b: Annotated[Column[StringType], ColumnMeta(comment="")]
41 | '''
42 |
43 |
44 | class B(Schema):
45 | b: Column[LongType]
46 | a: Column[StringType]
47 |
48 |
49 | class Values(Schema):
50 | a: Column[typedspark.DecimalType[Literal[38], Literal[18]]]
51 | b: Column[StringType]
52 |
53 |
54 | class ComplexDatatypes(Schema):
55 | value: Column[typedspark.StructType[Values]]
56 | items: Column[typedspark.ArrayType[StringType]]
57 | consequences: Column[typedspark.MapType[StringType, typedspark.ArrayType[StringType]]]
58 | diff: Column[typedspark.DayTimeIntervalType[IntervalType.DAY, IntervalType.SECOND]]
59 |
60 |
61 | schema_complex_datatypes = '''from typing import Annotated, Literal
62 |
63 | from pyspark.sql.types import StringType
64 |
65 | from typedspark import ArrayType, Column, ColumnMeta, DayTimeIntervalType, DecimalType, IntervalType, MapType, Schema, StructType
66 |
67 |
68 | class ComplexDatatypes(Schema):
69 | """Add documentation here."""
70 |
71 | value: Annotated[Column[StructType[test_schema.Values]], ColumnMeta(comment="")]
72 | items: Annotated[Column[ArrayType[StringType]], ColumnMeta(comment="")]
73 | consequences: Annotated[Column[MapType[StringType, ArrayType[StringType]]], ColumnMeta(comment="")]
74 | diff: Annotated[Column[DayTimeIntervalType[IntervalType.DAY, IntervalType.SECOND]], ColumnMeta(comment="")]
75 |
76 |
77 | class Values(Schema):
78 | """Add documentation here."""
79 |
80 | a: Annotated[Column[DecimalType[Literal[38], Literal[18]]], ColumnMeta(comment="")]
81 | b: Annotated[Column[StringType], ColumnMeta(comment="")]
82 | ''' # noqa: E501
83 |
84 |
85 | class PascalCase(Schema):
86 | """Schema docstring."""
87 |
88 | a: Annotated[Column[StringType], ColumnMeta(comment="some")]
89 | b: Annotated[Column[LongType], ColumnMeta(comment="other")]
90 |
91 |
92 | def test_all_column_names():
93 | assert A.all_column_names() == ["a", "b"]
94 | assert B.all_column_names() == ["b", "a"]
95 |
96 |
97 | def test_all_column_names_except_for():
98 | assert A.all_column_names_except_for(["a"]) == ["b"]
99 | assert B.all_column_names_except_for([]) == ["b", "a"]
100 | assert B.all_column_names_except_for(["b", "a"]) == []
101 |
102 |
103 | def test_get_snake_case():
104 | assert A.get_snake_case() == "a"
105 | assert PascalCase.get_snake_case() == "pascal_case"
106 |
107 |
108 | def test_get_docstring():
109 | assert A.get_docstring() == ""
110 | assert PascalCase.get_docstring() == "Schema docstring."
111 |
112 |
113 | def test_get_structtype():
114 | assert A.get_structtype() == StructType(
115 | [StructField("a", LongType(), True), StructField("b", StringType(), True)]
116 | )
117 | assert PascalCase.get_structtype() == StructType(
118 | [
119 | StructField("a", StringType(), metadata={"comment": "some"}),
120 | StructField("b", LongType(), metadata={"comment": "other"}),
121 | ]
122 | )
123 |
124 |
125 | def test_get_dlt_kwargs():
126 | assert A.get_dlt_kwargs() == DltKwargs(
127 | name="a",
128 | comment="",
129 | schema=StructType(
130 | [StructField("a", LongType(), True), StructField("b", StringType(), True)]
131 | ),
132 | )
133 |
134 | assert PascalCase.get_dlt_kwargs() == DltKwargs(
135 | name="pascal_case",
136 | comment="Schema docstring.",
137 | schema=StructType(
138 | [
139 | StructField("a", StringType(), metadata={"comment": "some"}),
140 | StructField("b", LongType(), metadata={"comment": "other"}),
141 | ]
142 | ),
143 | )
144 |
145 |
146 | def test_repr():
147 | assert repr(A) == schema_a_string
148 |
149 |
150 | @pytest.mark.parametrize(
151 | "schema, expected_schema_definition",
152 | [
153 | (A, schema_a_string_with_documentation),
154 | (ComplexDatatypes, schema_complex_datatypes),
155 | ],
156 | )
157 | def test_get_schema(schema: Type[Schema], expected_schema_definition: str):
158 | schema_definition = schema.get_schema_definition_as_string(include_documentation=True)
159 | assert schema_definition == expected_schema_definition
160 |
161 |
162 | def test_dtype_attributes(spark: SparkSession):
163 | assert isinstance(A.a.dtype, LongType)
164 | assert isinstance(ComplexDatatypes.items.dtype, typedspark.ArrayType)
165 | assert isinstance(ComplexDatatypes.value.dtype, typedspark.StructType)
166 | assert ComplexDatatypes.value.dtype.schema == Values
167 | assert isinstance(ComplexDatatypes.value.dtype.schema.b.dtype, StringType)
168 |
169 | df = create_partially_filled_dataset(
170 | spark,
171 | ComplexDatatypes,
172 | {
173 | ComplexDatatypes.value: create_partially_filled_dataset(
174 | spark,
175 | Values,
176 | {
177 | Values.b: ["a", "b", "c"],
178 | },
179 | ).collect(),
180 | },
181 | )
182 | assert df.filter(ComplexDatatypes.value.dtype.schema.b == "b").count() == 1
183 |
--------------------------------------------------------------------------------
/tests/_schema/test_structfield.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, get_type_hints
2 |
3 | import pytest
4 | from pyspark.sql.types import BooleanType, LongType, StringType, StructField
5 |
6 | from typedspark import Column, ColumnMeta, Schema
7 | from typedspark._schema.structfield import (
8 | _get_structfield_dtype,
9 | get_structfield,
10 | get_structfield_meta,
11 | )
12 |
13 |
14 | class A(Schema):
15 | a: Column[LongType]
16 | b: Column[StringType]
17 | c: Annotated[Column[StringType], ColumnMeta(comment="comment")]
18 | d: Annotated[Column[BooleanType], ColumnMeta(comment="comment2")]
19 |
20 |
21 | @pytest.fixture()
22 | def type_hints():
23 | return get_type_hints(A, include_extras=True)
24 |
25 |
26 | def test_get_structfield_dtype(type_hints):
27 | assert _get_structfield_dtype(Column[LongType], "a") == LongType()
28 | assert _get_structfield_dtype(type_hints["b"], "b") == StringType()
29 | assert (
30 | _get_structfield_dtype(
31 | Annotated[Column[StringType], ColumnMeta(comment="comment")], # type: ignore
32 | "c",
33 | )
34 | == StringType()
35 | )
36 | assert _get_structfield_dtype(type_hints["d"], "d") == BooleanType()
37 |
38 |
39 | def test_get_structfield_metadata(type_hints):
40 | assert get_structfield_meta(Column[LongType]) == ColumnMeta()
41 | assert get_structfield_meta(type_hints["b"]) == ColumnMeta()
42 | assert get_structfield_meta(
43 | Annotated[Column[StringType], ColumnMeta(comment="comment")] # type: ignore
44 | ) == ColumnMeta(comment="comment")
45 | assert get_structfield_meta(type_hints["d"]) == ColumnMeta(comment="comment2")
46 |
47 |
48 | def test_get_structfield(type_hints):
49 | assert get_structfield("a", Column[LongType]) == StructField(name="a", dataType=LongType())
50 | assert get_structfield("b", type_hints["b"]) == StructField(name="b", dataType=StringType())
51 | assert get_structfield( # type: ignore
52 | "c",
53 | Annotated[Column[StringType], ColumnMeta(comment="comment")], # type: ignore
54 | ) == StructField(name="c", dataType=StringType(), metadata={"comment": "comment"})
55 | assert get_structfield("d", type_hints["d"]) == StructField(
56 | name="d", dataType=BooleanType(), metadata={"comment": "comment2"}
57 | )
58 |
--------------------------------------------------------------------------------
/tests/_transforms/test_structtype_column.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from chispa.dataframe_comparer import assert_df_equality # type: ignore
3 | from pyspark.sql import SparkSession
4 | from pyspark.sql.types import IntegerType
5 |
6 | from typedspark import Column, Schema, StructType, structtype_column
7 | from typedspark._transforms.transform_to_schema import transform_to_schema
8 | from typedspark._utils.create_dataset import create_partially_filled_dataset
9 |
10 |
11 | class SubSchema(Schema):
12 | a: Column[IntegerType]
13 | b: Column[IntegerType]
14 |
15 |
16 | class MainSchema(Schema):
17 | a: Column[IntegerType]
18 | b: Column[StructType[SubSchema]]
19 |
20 |
21 | def test_structtype_column(spark: SparkSession):
22 | df = create_partially_filled_dataset(spark, MainSchema, {MainSchema.a: [1, 2, 3]})
23 | observed = transform_to_schema(
24 | df,
25 | MainSchema,
26 | {
27 | MainSchema.b: structtype_column(
28 | SubSchema,
29 | {SubSchema.a: MainSchema.a + 2, SubSchema.b: MainSchema.a + 4},
30 | )
31 | },
32 | )
33 | expected = create_partially_filled_dataset(
34 | spark,
35 | MainSchema,
36 | {
37 | MainSchema.a: [1, 2, 3],
38 | MainSchema.b: create_partially_filled_dataset(
39 | spark, SubSchema, {SubSchema.a: [3, 4, 5], SubSchema.b: [5, 6, 7]}
40 | ).collect(),
41 | },
42 | )
43 | assert_df_equality(observed, expected, ignore_nullable=True)
44 |
45 |
46 | def test_structtype_column_different_column_order(spark: SparkSession):
47 | df = create_partially_filled_dataset(spark, MainSchema, {MainSchema.a: [1, 2, 3]})
48 | observed = transform_to_schema(
49 | df,
50 | MainSchema,
51 | {
52 | MainSchema.b: structtype_column(
53 | SubSchema,
54 | {SubSchema.b: MainSchema.a + 4, SubSchema.a: MainSchema.a + 2},
55 | )
56 | },
57 | )
58 | expected = create_partially_filled_dataset(
59 | spark,
60 | MainSchema,
61 | {
62 | MainSchema.a: [1, 2, 3],
63 | MainSchema.b: create_partially_filled_dataset(
64 | spark, SubSchema, {SubSchema.a: [3, 4, 5], SubSchema.b: [5, 6, 7]}
65 | ).collect(),
66 | },
67 | )
68 | assert_df_equality(observed, expected, ignore_nullable=True)
69 |
70 |
71 | def test_structtype_column_partial(spark: SparkSession):
72 | df = create_partially_filled_dataset(spark, MainSchema, {MainSchema.a: [1, 2, 3]})
73 | observed = transform_to_schema(
74 | df,
75 | MainSchema,
76 | {
77 | MainSchema.b: structtype_column(
78 | SubSchema,
79 | {SubSchema.a: MainSchema.a + 2},
80 | fill_unspecified_columns_with_nulls=True,
81 | )
82 | },
83 | )
84 | expected = create_partially_filled_dataset(
85 | spark,
86 | MainSchema,
87 | {
88 | MainSchema.a: [1, 2, 3],
89 | MainSchema.b: create_partially_filled_dataset(
90 | spark,
91 | SubSchema,
92 | {SubSchema.a: [3, 4, 5], SubSchema.b: [None, None, None]},
93 | ).collect(),
94 | },
95 | )
96 | assert_df_equality(observed, expected, ignore_nullable=True)
97 |
98 |
99 | def test_structtype_column_with_double_column(spark: SparkSession):
100 | df = create_partially_filled_dataset(spark, MainSchema, {MainSchema.a: [1, 2, 3]})
101 | with pytest.raises(ValueError):
102 | transform_to_schema(
103 | df,
104 | MainSchema,
105 | {
106 | MainSchema.b: structtype_column(
107 | SubSchema,
108 | {SubSchema.a: MainSchema.a + 2, SubSchema.a: MainSchema.a + 2},
109 | )
110 | },
111 | )
112 |
--------------------------------------------------------------------------------
/tests/_transforms/test_transform_to_schema.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from chispa.dataframe_comparer import assert_df_equality # type: ignore
3 | from pyspark.sql import SparkSession
4 | from pyspark.sql.types import IntegerType, StringType
5 |
6 | from typedspark import (
7 | Column,
8 | Schema,
9 | create_empty_dataset,
10 | register_schema_to_dataset,
11 | transform_to_schema,
12 | )
13 | from typedspark._utils.create_dataset import create_partially_filled_dataset
14 |
15 |
16 | class Person(Schema):
17 | a: Column[IntegerType]
18 | b: Column[IntegerType]
19 | c: Column[IntegerType]
20 |
21 |
22 | class PersonLessData(Schema):
23 | a: Column[IntegerType]
24 | b: Column[IntegerType]
25 |
26 |
27 | class PersonDifferentData(Schema):
28 | a: Column[IntegerType]
29 | d: Column[IntegerType]
30 |
31 |
32 | def test_transform_to_schema_without_transformations(spark: SparkSession):
33 | df = create_empty_dataset(spark, Person)
34 | observed = transform_to_schema(df, PersonLessData)
35 | expected = create_empty_dataset(spark, PersonLessData)
36 | assert_df_equality(observed, expected)
37 |
38 |
39 | def test_transform_to_schema_with_transformation(spark: SparkSession):
40 | df = create_partially_filled_dataset(spark, Person, {Person.c: [1, 2, 3]})
41 | observed = transform_to_schema(df, PersonDifferentData, {PersonDifferentData.d: Person.c + 3})
42 | expected = create_partially_filled_dataset(
43 | spark, PersonDifferentData, {PersonDifferentData.d: [4, 5, 6]}
44 | )
45 | assert_df_equality(observed, expected)
46 |
47 |
48 | def test_transform_to_schema_with_missing_column(spark):
49 | df = create_partially_filled_dataset(spark, Person, {Person.c: [1, 2, 3]}).drop(Person.a)
50 | with pytest.raises(Exception):
51 | transform_to_schema(df, PersonDifferentData, {PersonDifferentData.d: Person.c + 3})
52 |
53 | observed = transform_to_schema(
54 | df,
55 | PersonDifferentData,
56 | {PersonDifferentData.d: Person.c + 3},
57 | fill_unspecified_columns_with_nulls=True,
58 | )
59 |
60 | expected = create_partially_filled_dataset(
61 | spark,
62 | PersonDifferentData,
63 | {PersonDifferentData.d: [4, 5, 6]},
64 | )
65 | assert_df_equality(observed, expected)
66 |
67 |
68 | def test_transform_to_schema_with_pre_existing_column(spark):
69 | df = create_partially_filled_dataset(spark, Person, {Person.a: [0, 1, 2], Person.c: [1, 2, 3]})
70 |
71 | observed = transform_to_schema(
72 | df,
73 | PersonDifferentData,
74 | {PersonDifferentData.d: Person.c + 3},
75 | fill_unspecified_columns_with_nulls=True,
76 | )
77 |
78 | expected = create_partially_filled_dataset(
79 | spark,
80 | PersonDifferentData,
81 | {PersonDifferentData.a: [0, 1, 2], PersonDifferentData.d: [4, 5, 6]},
82 | )
83 | assert_df_equality(observed, expected)
84 |
85 |
86 | class PersonA(Schema):
87 | name: Column[StringType]
88 | age: Column[StringType]
89 |
90 |
91 | class PersonB(Schema):
92 | name: Column[StringType]
93 | age: Column[StringType]
94 |
95 |
96 | def test_transform_to_schema_with_column_disambiguation(spark: SparkSession):
97 | df_a = create_partially_filled_dataset(
98 | spark,
99 | PersonA,
100 | {PersonA.name: ["John", "Jane", "Bob"], PersonA.age: [30, 40, 50]},
101 | )
102 | df_b = create_partially_filled_dataset(
103 | spark,
104 | PersonB,
105 | {PersonB.name: ["John", "Jane", "Bob"], PersonB.age: [31, 41, 51]},
106 | )
107 |
108 | person_a = register_schema_to_dataset(df_a, PersonA)
109 | person_b = register_schema_to_dataset(df_b, PersonB)
110 |
111 | with pytest.raises(ValueError):
112 | transform_to_schema(
113 | df_a.join(df_b, person_a.name == person_b.name),
114 | PersonA,
115 | {
116 | person_a.age: person_a.age + 3,
117 | person_b.age: person_b.age + 5,
118 | },
119 | )
120 |
121 | with pytest.raises(ValueError):
122 | transform_to_schema(
123 | df_a.join(df_b, person_a.name == person_b.name),
124 | PersonA,
125 | {
126 | PersonA.age: person_a.age,
127 | },
128 | )
129 |
130 | res = transform_to_schema(
131 | df_a.join(df_b, person_a.name == person_b.name),
132 | PersonA,
133 | {
134 | PersonA.name: person_a.name,
135 | PersonB.age: person_b.age,
136 | },
137 | )
138 | expected = create_partially_filled_dataset(
139 | spark,
140 | PersonA,
141 | {PersonA.name: ["John", "Jane", "Bob"], PersonA.age: [31, 41, 51]},
142 | )
143 | assert_df_equality(res, expected, ignore_row_order=True)
144 |
145 |
146 | def test_transform_to_schema_with_double_column(spark: SparkSession):
147 | df = create_partially_filled_dataset(spark, Person, {Person.a: [1, 2, 3], Person.b: [1, 2, 3]})
148 |
149 | with pytest.raises(ValueError):
150 | transform_to_schema(
151 | df,
152 | Person,
153 | {
154 | Person.a: Person.a + 3,
155 | Person.a: Person.a + 5,
156 | },
157 | )
158 |
159 |
160 | def test_transform_to_schema_sequential(spark: SparkSession):
161 | df = create_partially_filled_dataset(spark, Person, {Person.a: [1, 2, 3], Person.b: [1, 2, 3]})
162 |
163 | observed = transform_to_schema(
164 | df,
165 | Person,
166 | {
167 | Person.a: Person.a + 3,
168 | Person.b: Person.a + 5,
169 | },
170 | )
171 |
172 | expected = create_partially_filled_dataset(
173 | spark, Person, {Person.a: [4, 5, 6], Person.b: [9, 10, 11]}
174 | )
175 |
176 | assert_df_equality(observed, expected)
177 |
178 |
179 | def test_transform_to_schema_parallel(spark: SparkSession):
180 | df = create_partially_filled_dataset(spark, Person, {Person.a: [1, 2, 3], Person.b: [1, 2, 3]})
181 |
182 | observed = transform_to_schema(
183 | df,
184 | Person,
185 | {
186 | Person.a: Person.a + 3,
187 | Person.b: Person.a + 5,
188 | },
189 | run_sequentially=False,
190 | )
191 |
192 | expected = create_partially_filled_dataset(
193 | spark, Person, {Person.a: [4, 5, 6], Person.b: [6, 7, 8]}
194 | )
195 |
196 | assert_df_equality(observed, expected)
197 |
--------------------------------------------------------------------------------
/tests/_utils/test_create_dataset.py:
--------------------------------------------------------------------------------
1 | from decimal import Decimal
2 | from typing import Literal
3 |
4 | import pytest
5 | from chispa.dataframe_comparer import assert_df_equality # type: ignore
6 | from pyspark.sql import SparkSession
7 | from pyspark.sql.types import StringType
8 |
9 | from typedspark import (
10 | ArrayType,
11 | Column,
12 | DataSet,
13 | MapType,
14 | Schema,
15 | StructType,
16 | create_empty_dataset,
17 | create_partially_filled_dataset,
18 | create_structtype_row,
19 | )
20 | from typedspark._core.datatypes import DecimalType
21 |
22 |
23 | class A(Schema):
24 | a: Column[DecimalType[Literal[38], Literal[18]]]
25 | b: Column[StringType]
26 |
27 |
28 | def test_create_empty_dataset(spark: SparkSession):
29 | n_rows = 2
30 | result: DataSet[A] = create_empty_dataset(spark, A, n_rows)
31 |
32 | spark_schema = A.get_structtype()
33 | data = [(None, None), (None, None)]
34 | expected = spark.createDataFrame(data, spark_schema)
35 |
36 | assert_df_equality(result, expected)
37 |
38 |
39 | def test_create_partially_filled_dataset(spark: SparkSession):
40 | data = {A.a: [Decimal(x) for x in [1, 2, 3]]}
41 | result: DataSet[A] = create_partially_filled_dataset(spark, A, data)
42 |
43 | spark_schema = A.get_structtype()
44 | row_data = [(Decimal(1), None), (Decimal(2), None), (Decimal(3), None)]
45 | expected = spark.createDataFrame(row_data, spark_schema)
46 |
47 | assert_df_equality(result, expected)
48 |
49 |
50 | def test_create_partially_filled_dataset_with_different_number_of_rows(
51 | spark: SparkSession,
52 | ):
53 | with pytest.raises(ValueError):
54 | create_partially_filled_dataset(spark, A, {A.a: [1], A.b: ["a", "b"]})
55 |
56 |
57 | class B(Schema):
58 | a: Column[ArrayType[StringType]]
59 | b: Column[MapType[StringType, StringType]]
60 | c: Column[StructType[A]]
61 |
62 |
63 | def test_create_empty_dataset_with_complex_data(spark: SparkSession):
64 | df_a = create_partially_filled_dataset(spark, A, {A.a: [Decimal(x) for x in [1, 2, 3]]})
65 |
66 | result = create_partially_filled_dataset(
67 | spark,
68 | B,
69 | {
70 | B.a: [["a"], ["b", "c"], ["d"]],
71 | B.b: [{"a": "1"}, {"b": "2", "c": "3"}, {"d": "4"}],
72 | B.c: df_a.collect(),
73 | },
74 | )
75 |
76 | spark_schema = B.get_structtype()
77 | row_data = [
78 | (["a"], {"a": "1"}, (Decimal(1), None)),
79 | (["b", "c"], {"b": "2", "c": "3"}, (Decimal(2), None)),
80 | (["d"], {"d": "4"}, (Decimal(3), None)),
81 | ]
82 | expected = spark.createDataFrame(row_data, spark_schema)
83 |
84 | assert_df_equality(result, expected)
85 |
86 |
87 | def test_create_partially_filled_dataset_from_list(spark: SparkSession):
88 | result = create_partially_filled_dataset(
89 | spark,
90 | A,
91 | [
92 | {A.a: Decimal(1), A.b: "a"},
93 | {A.a: Decimal(2)},
94 | {A.b: "c", A.a: Decimal(3)},
95 | ],
96 | )
97 |
98 | spark_schema = A.get_structtype()
99 | row_data = [(Decimal(1), "a"), (Decimal(2), None), (Decimal(3), "c")]
100 | expected = spark.createDataFrame(row_data, spark_schema)
101 |
102 | assert_df_equality(result, expected)
103 |
104 |
105 | def test_create_partially_filled_dataset_from_list_with_complex_data(spark: SparkSession):
106 | result = create_partially_filled_dataset(
107 | spark,
108 | B,
109 | [
110 | {
111 | B.a: ["a"],
112 | B.b: {"a": "1"},
113 | B.c: create_structtype_row(A, {A.a: Decimal(1), A.b: "a"}),
114 | },
115 | {
116 | B.a: ["b", "c"],
117 | B.b: {"b": "2", "c": "3"},
118 | B.c: create_structtype_row(A, {A.a: Decimal(2)}),
119 | },
120 | {
121 | B.a: ["d"],
122 | B.b: {"d": "4"},
123 | B.c: create_structtype_row(A, {A.b: "c", A.a: Decimal(3)}),
124 | },
125 | ],
126 | )
127 |
128 | spark_schema = B.get_structtype()
129 | row_data = [
130 | (["a"], {"a": "1"}, (Decimal(1), "a")),
131 | (["b", "c"], {"b": "2", "c": "3"}, (Decimal(2), None)),
132 | (["d"], {"d": "4"}, (Decimal(3), "c")),
133 | ]
134 | expected = spark.createDataFrame(row_data, spark_schema)
135 |
136 | assert_df_equality(result, expected)
137 |
138 |
139 | def test_create_partially_filled_dataset_with_invalid_argument(spark: SparkSession):
140 | with pytest.raises(ValueError):
141 | create_partially_filled_dataset(spark, A, ()) # type: ignore
142 |
--------------------------------------------------------------------------------
/tests/_utils/test_load_table.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import pytest
4 | from chispa.dataframe_comparer import assert_df_equality # type: ignore
5 | from pyspark.sql import SparkSession
6 | from pyspark.sql.functions import first
7 | from pyspark.sql.types import IntegerType, StringType
8 |
9 | from typedspark import (
10 | ArrayType,
11 | Column,
12 | Databases,
13 | DecimalType,
14 | MapType,
15 | Schema,
16 | StructType,
17 | create_empty_dataset,
18 | load_table,
19 | )
20 | from typedspark._core.datatypes import DayTimeIntervalType
21 | from typedspark._core.literaltype import IntervalType
22 | from typedspark._utils.create_dataset import create_partially_filled_dataset
23 | from typedspark._utils.databases import Catalogs, _get_spark_session
24 | from typedspark._utils.load_table import create_schema
25 |
26 |
27 | class SubSchema(Schema):
28 | a: Column[IntegerType]
29 |
30 |
31 | class A(Schema):
32 | a: Column[IntegerType]
33 | b: Column[ArrayType[IntegerType]]
34 | c: Column[ArrayType[MapType[IntegerType, IntegerType]]]
35 | d: Column[DayTimeIntervalType[IntervalType.HOUR, IntervalType.MINUTE]]
36 | e: Column[DecimalType[Literal[7], Literal[2]]]
37 | value_container: Column[StructType[SubSchema]]
38 |
39 |
40 | def test_load_table(spark: SparkSession) -> None:
41 | df = create_empty_dataset(spark, A)
42 | df.createOrReplaceTempView("temp")
43 |
44 | df_loaded, schema = load_table(spark, "temp")
45 |
46 | assert_df_equality(df, df_loaded)
47 | assert schema.get_structtype() == A.get_structtype()
48 | assert schema.get_schema_name() != "A"
49 |
50 |
51 | def test_load_table_with_schema_name(spark: SparkSession) -> None:
52 | df = create_empty_dataset(spark, A)
53 | df.createOrReplaceTempView("temp")
54 |
55 | df_loaded, schema = load_table(spark, "temp", schema_name="A")
56 |
57 | assert_df_equality(df, df_loaded)
58 | assert schema.get_structtype() == A.get_structtype()
59 | assert schema.get_schema_name() == "A"
60 |
61 |
62 | class B(Schema):
63 | a: Column[StringType]
64 | b: Column[IntegerType]
65 | c: Column[StringType]
66 |
67 |
68 | def test_create_schema(spark: SparkSession) -> None:
69 | df = (
70 | create_partially_filled_dataset(
71 | spark,
72 | B,
73 | {
74 | B.a: ["a", "b!!", "c", "a", "b!!", "c", "a", "b!!", "c"],
75 | B.b: [1, 1, 1, 2, 2, 2, 3, 3, 3],
76 | B.c: ["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota"],
77 | },
78 | )
79 | .groupby(B.b)
80 | .pivot(B.a.str)
81 | .agg(first(B.c))
82 | )
83 |
84 | df, MySchema = create_schema(df, "B")
85 |
86 | assert MySchema.get_schema_name() == "B"
87 | assert "a" in MySchema.all_column_names()
88 | assert "b__" in MySchema.all_column_names()
89 | assert "c" in MySchema.all_column_names()
90 |
91 |
92 | def test_create_schema_with_duplicated_column_names(spark: SparkSession) -> None:
93 | df = (
94 | create_partially_filled_dataset(
95 | spark,
96 | B,
97 | {
98 | B.a: ["a", "b??", "c", "a", "b!!", "c", "a", "b!!", "c"],
99 | B.b: [1, 1, 1, 2, 2, 2, 3, 3, 3],
100 | B.c: ["alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta", "iota"],
101 | },
102 | )
103 | .groupby(B.b)
104 | .pivot(B.a.str)
105 | .agg(first(B.c))
106 | )
107 |
108 | with pytest.raises(ValueError):
109 | create_schema(df, "B")
110 |
111 |
112 | def test_name_of_structtype_schema(spark):
113 | df = create_empty_dataset(spark, A)
114 | df, MySchema = create_schema(df, "A")
115 |
116 | assert MySchema.value_container.dtype.schema.get_schema_name() == "ValueContainer"
117 |
118 |
119 | def test_databases_with_temp_view(spark):
120 | df = create_empty_dataset(spark, A)
121 | df.createOrReplaceTempView("table_a")
122 |
123 | db = Databases(spark)
124 | for df_loaded, schema in [db.default.table_a(), db.default.table_a.load()]: # type: ignore
125 | assert_df_equality(df, df_loaded)
126 | assert schema.get_structtype() == A.get_structtype()
127 | assert schema.get_schema_name() == "TableA"
128 | assert db.default.table_a.str == "table_a" # type: ignore
129 | assert db.default.str == "default" # type: ignore
130 |
131 |
132 | def _drop_table(spark: SparkSession, table_name: str) -> None:
133 | spark.sql(f"DROP TABLE IF EXISTS {table_name}")
134 |
135 |
136 | def test_databases_with_table(spark: SparkSession):
137 | df = create_empty_dataset(spark, A)
138 | df.write.saveAsTable("default.table_b")
139 |
140 | try:
141 | db = Databases(spark)
142 | df_loaded, schema = db.default.table_b() # type: ignore
143 |
144 | assert_df_equality(df, df_loaded)
145 | assert schema.get_structtype() == A.get_structtype()
146 | assert schema.get_schema_name() == "TableB"
147 | assert db.default.table_b.str == "default.table_b" # type: ignore
148 | assert db.default.str == "default" # type: ignore
149 | except Exception as exception:
150 | _drop_table(spark, "default.table_b")
151 | raise exception
152 |
153 | _drop_table(spark, "default.table_b")
154 |
155 |
156 | def test_databases_with_table_name_starting_with_underscore(spark: SparkSession):
157 | df = create_empty_dataset(spark, A)
158 | df.write.saveAsTable("default._table_b")
159 |
160 | try:
161 | db = Databases(spark)
162 | df_loaded, _ = db.default.u_table_b() # type: ignore
163 | assert_df_equality(df, df_loaded)
164 | assert db.default.u_table_b.str == "default._table_b" # type: ignore
165 | except Exception as exception:
166 | _drop_table(spark, "default._table_b")
167 | raise exception
168 |
169 | _drop_table(spark, "default._table_b")
170 |
171 |
172 | def test_databases_with_table_name_starting_with_underscore_with_naming_conflict(
173 | spark: SparkSession,
174 | ):
175 | df_a = create_empty_dataset(spark, A)
176 | df_b = create_empty_dataset(spark, B)
177 | df_a.write.saveAsTable("default._table_b")
178 | df_b.write.saveAsTable("default.u_table_b")
179 |
180 | try:
181 | db = Databases(spark)
182 | df_loaded, _ = db.default.u__table_b() # type: ignore
183 | assert_df_equality(df_a, df_loaded)
184 | assert db.default.u__table_b.str == "default._table_b" # type: ignore
185 |
186 | df_loaded, _ = db.default.u_table_b() # type: ignore
187 | assert_df_equality(df_b, df_loaded)
188 | assert db.default.u_table_b.str == "default.u_table_b" # type: ignore
189 | except Exception as exception:
190 | _drop_table(spark, "default._table_b")
191 | _drop_table(spark, "default.u_table_b")
192 | raise exception
193 |
194 | _drop_table(spark, "default._table_b")
195 | _drop_table(spark, "default.u_table_b")
196 |
197 |
198 | def test_catalogs(spark: SparkSession):
199 | df = create_empty_dataset(spark, A)
200 | df.write.saveAsTable("spark_catalog.default.table_b")
201 |
202 | try:
203 | db = Catalogs(spark)
204 | df_loaded, schema = db.spark_catalog.default.table_b() # type: ignore
205 |
206 | assert_df_equality(df, df_loaded)
207 | assert schema.get_structtype() == A.get_structtype()
208 | assert schema.get_schema_name() == "TableB"
209 | assert db.spark_catalog.default.table_b.str == "spark_catalog.default.table_b" # type: ignore # noqa: E501
210 | except Exception as exception:
211 | _drop_table(spark, "spark_catalog.default.table_b")
212 | raise exception
213 |
214 | _drop_table(spark, "spark_catalog.default.table_b")
215 |
216 |
217 | def test_get_spark_session(spark: SparkSession):
218 | res = _get_spark_session(None)
219 |
220 | assert res == spark
221 |
222 |
223 | @pytest.mark.no_spark_session
224 | def test_get_spark_session_without_spark_session():
225 | if SparkSession.getActiveSession() is None:
226 | with pytest.raises(ValueError):
227 | _get_spark_session(None)
228 |
--------------------------------------------------------------------------------
/tests/_utils/test_register_schema_to_dataset.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from chispa.dataframe_comparer import assert_df_equality # type: ignore
3 | from pyspark.errors import AnalysisException
4 | from pyspark.sql import SparkSession
5 | from pyspark.sql.types import IntegerType, StringType
6 |
7 | from typedspark import (
8 | Column,
9 | Schema,
10 | StructType,
11 | create_partially_filled_dataset,
12 | register_schema_to_dataset,
13 | )
14 | from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset_with_alias
15 |
16 |
17 | class Pets(Schema):
18 | name: Column[StringType]
19 | age: Column[IntegerType]
20 |
21 |
22 | class Person(Schema):
23 | a: Column[IntegerType]
24 | b: Column[IntegerType]
25 | pets: Column[StructType[Pets]]
26 |
27 |
28 | class Job(Schema):
29 | a: Column[IntegerType]
30 | c: Column[IntegerType]
31 |
32 |
33 | class Age(Schema):
34 | age_1: Column[IntegerType]
35 | age_2: Column[IntegerType]
36 |
37 |
38 | def test_register_schema_to_dataset(spark: SparkSession):
39 | df_a = create_partially_filled_dataset(spark, Person, {Person.a: [1, 2, 3]})
40 | df_b = create_partially_filled_dataset(spark, Job, {Job.a: [1, 2, 3]})
41 |
42 | with pytest.raises(AnalysisException):
43 | df_a.join(df_b, Person.a == Job.a)
44 |
45 | person = register_schema_to_dataset(df_a, Person)
46 | job = register_schema_to_dataset(df_b, Job)
47 |
48 | assert person.get_schema_name() == "Person"
49 | assert hash(person.a) != hash(Person.a)
50 |
51 | df_a.join(df_b, person.a == job.a)
52 |
53 |
54 | def test_register_schema_to_dataset_with_alias(spark: SparkSession):
55 | df = create_partially_filled_dataset(
56 | spark,
57 | Person,
58 | {
59 | Person.a: [1, 2, 3],
60 | Person.b: [1, 2, 3],
61 | Person.pets: create_partially_filled_dataset(
62 | spark,
63 | Pets,
64 | {
65 | Pets.name: ["Bobby", "Bobby", "Bobby"],
66 | Pets.age: [10, 20, 30],
67 | },
68 | ).collect(),
69 | },
70 | )
71 |
72 | with pytest.raises(AnalysisException):
73 | df_a = df.alias("a")
74 | df_b = df.alias("b")
75 | schema_a = register_schema_to_dataset(df_a, Person)
76 | schema_b = register_schema_to_dataset(df_b, Person)
77 | df_a.join(df_b, schema_a.a == schema_b.b)
78 |
79 | df_a, schema_a = register_schema_to_dataset_with_alias(df, Person, "a")
80 | df_b, schema_b = register_schema_to_dataset_with_alias(df, Person, "b")
81 | joined = df_a.join(df_b, schema_a.a == schema_b.b)
82 |
83 | res = joined.select(
84 | schema_a.pets.dtype.schema.age.alias(Age.age_1.str),
85 | schema_b.pets.dtype.schema.age.alias(Age.age_2.str),
86 | )
87 |
88 | expected = create_partially_filled_dataset(
89 | spark, Age, {Age.age_1: [10, 20, 30], Age.age_2: [10, 20, 30]}
90 | )
91 |
92 | assert_df_equality(res, expected)
93 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import pytest
5 | from pyspark.sql import SparkSession
6 |
7 |
8 | @pytest.fixture(scope="session")
9 | def spark():
10 | """Fixture for creating a spark session."""
11 | os.environ["PYSPARK_PYTHON"] = sys.executable
12 | os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
13 |
14 | spark = SparkSession.Builder().getOrCreate()
15 | yield spark
16 | spark.stop()
17 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 100
3 |
--------------------------------------------------------------------------------
/typedspark/__init__.py:
--------------------------------------------------------------------------------
1 | """Typedspark: column-wise type annotations for pyspark DataFrames."""
2 |
3 | from typedspark._core.column import Column
4 | from typedspark._core.column_meta import ColumnMeta
5 | from typedspark._core.dataset import DataSet, DataSetImplements
6 | from typedspark._core.datatypes import (
7 | ArrayType,
8 | DayTimeIntervalType,
9 | DecimalType,
10 | MapType,
11 | StructType,
12 | )
13 | from typedspark._core.literaltype import IntervalType
14 | from typedspark._schema.schema import MetaSchema, Schema
15 | from typedspark._transforms.structtype_column import structtype_column
16 | from typedspark._transforms.transform_to_schema import transform_to_schema
17 | from typedspark._utils.create_dataset import (
18 | create_empty_dataset,
19 | create_partially_filled_dataset,
20 | create_structtype_row,
21 | )
22 | from typedspark._utils.databases import Catalogs, Database, Databases
23 | from typedspark._utils.load_table import create_schema, load_table
24 | from typedspark._utils.register_schema_to_dataset import (
25 | register_schema_to_dataset,
26 | register_schema_to_dataset_with_alias,
27 | )
28 |
29 | __all__ = [
30 | "ArrayType",
31 | "Catalogs",
32 | "Column",
33 | "ColumnMeta",
34 | "Database",
35 | "Databases",
36 | "DataSet",
37 | "DayTimeIntervalType",
38 | "DecimalType",
39 | "IntervalType",
40 | "MapType",
41 | "MetaSchema",
42 | "DataSetImplements",
43 | "Schema",
44 | "StructType",
45 | "create_empty_dataset",
46 | "create_partially_filled_dataset",
47 | "create_structtype_row",
48 | "create_schema",
49 | "load_table",
50 | "register_schema_to_dataset",
51 | "register_schema_to_dataset_with_alias",
52 | "structtype_column",
53 | "transform_to_schema",
54 | ]
55 |
--------------------------------------------------------------------------------
/typedspark/_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaiko-ai/typedspark/0c99ffb7cb9eddee9121254163f74b8de25d9f6b/typedspark/_core/__init__.py
--------------------------------------------------------------------------------
/typedspark/_core/column.py:
--------------------------------------------------------------------------------
1 | """Module containing classes and functions related to TypedSpark Columns."""
2 |
3 | from logging import warn
4 | from typing import Generic, Optional, TypeVar, Union, get_args, get_origin
5 |
6 | from pyspark.sql import Column as SparkColumn
7 | from pyspark.sql import DataFrame, SparkSession
8 | from pyspark.sql.functions import col
9 | from pyspark.sql.types import DataType
10 |
11 | from typedspark._core.datatypes import StructType
12 |
13 | T = TypeVar("T", bound=DataType)
14 |
15 |
16 | class EmptyColumn(SparkColumn):
17 | """Column object to be instantiated when there is no active Spark session."""
18 |
19 | def __init__(self, *args, **kwargs) -> None: # pragma: no cover
20 | pass
21 |
22 |
23 | class Column(SparkColumn, Generic[T]):
24 | """Represents a ``Column`` in a ``Schema``. Can be used as:
25 |
26 | .. code-block:: python
27 |
28 | class A(Schema):
29 | a: Column[IntegerType]
30 | b: Column[StringType]
31 | """
32 |
33 | def __new__(
34 | cls,
35 | name: str,
36 | dataframe: Optional[DataFrame] = None,
37 | curid: Optional[int] = None,
38 | dtype: Optional[T] = None,
39 | parent: Union[DataFrame, "Column", None] = None,
40 | alias: Optional[str] = None,
41 | ):
42 | """``__new__()`` instantiates the object (prior to ``__init__()``).
43 |
44 | Here, we simply take the provided ``name``, create a pyspark
45 | ``Column`` object and cast it to a typedspark ``Column`` object.
46 |         This allows us to bypass the pyspark ``Column`` constructor in
47 | ``__init__()``, which requires parameters that may be difficult
48 | to access.
49 | """
50 | # pylint: disable=unused-argument
51 |
52 | if dataframe is not None and parent is None:
53 | parent = dataframe
54 | warn("The use of Column(dataframe=...) is deprecated, use Column(parent=...) instead.")
55 |
56 | column: SparkColumn
57 | if SparkSession.getActiveSession() is None:
58 | column = EmptyColumn() # pragma: no cover
59 | elif alias is not None:
60 | column = col(f"{alias}.{name}")
61 | elif parent is not None:
62 | column = parent[name]
63 | else:
64 | column = col(name)
65 |
66 | column.__class__ = Column # type: ignore
67 | return column
68 |
69 | def __init__(
70 | self,
71 | name: str,
72 | dataframe: Optional[DataFrame] = None,
73 | curid: Optional[int] = None,
74 | dtype: Optional[T] = None,
75 | parent: Union[DataFrame, "Column", None] = None,
76 | alias: Optional[str] = None,
77 | ):
78 | # pylint: disable=unused-argument
79 | self.str = name
80 | self._dtype = dtype if dtype is not None else DataType
81 | self._curid = curid
82 |
83 | def __hash__(self) -> int:
84 | return hash((self.str, self._curid))
85 |
86 | @property
87 | def dtype(self) -> T:
88 | """Get the datatype of the column, e.g. Column[IntegerType] -> IntegerType."""
89 | dtype = self._dtype
90 |
91 | if get_origin(dtype) == StructType:
92 | return StructType(
93 | schema=get_args(dtype)[0],
94 | parent=self,
95 | ) # type: ignore
96 |
97 | return dtype() # type: ignore
98 |
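99 | # Usage sketch (illustrative, mirroring the behaviour exercised in the tests): given the
100 | # docstring example above, ``A.a`` behaves like a regular pyspark ``Column`` named "a",
101 | # ``A.a.str`` is the plain string "a" (handy in e.g. ``df.withColumn(A.a.str, lit(1))``),
102 | # and ``A.a.dtype`` is an instance of ``IntegerType``.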
--------------------------------------------------------------------------------
/typedspark/_core/column_meta.py:
--------------------------------------------------------------------------------
1 | """Metadata for ``Column`` objects that can be accessed during runtime."""
2 |
3 | from dataclasses import asdict, dataclass
4 | from typing import Dict, Optional
5 |
6 |
7 | @dataclass
8 | class ColumnMeta:
9 | """Contains the metadata for a ``Column``. Used as:
10 |
11 | .. code-block:: python
12 |
13 | class A(Schema):
14 | a: Annotated[
15 | Column[IntegerType],
16 | ColumnMeta(
17 | comment="This is a comment",
18 | )
19 | ]
20 | """
21 |
22 | comment: Optional[str] = None
23 |
24 | def get_metadata(self) -> Optional[Dict[str, str]]:
25 | """Returns the metadata of this column."""
26 | res = {k: v for k, v in asdict(self).items() if v is not None}
27 | return res if len(res) > 0 else None
28 |
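29 | # Usage sketch (illustrative): ``get_metadata()`` drops unset fields, so
30 | # ``ColumnMeta(comment="age in years").get_metadata()`` evaluates to
31 | # ``{"comment": "age in years"}``, whereas ``ColumnMeta().get_metadata()`` returns ``None``.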
--------------------------------------------------------------------------------
/typedspark/_core/datatypes.py:
--------------------------------------------------------------------------------
1 | """Here, we make our own definitions of ``MapType``, ``ArrayType`` and ``StructType`` in
2 | order to allow e.g. for ``ArrayType[StringType]``."""
3 |
4 | from __future__ import annotations
5 |
6 | from typing import TYPE_CHECKING, Any, Dict, Generic, Type, TypeVar
7 |
8 | from pyspark.sql.types import DataType
9 |
10 | if TYPE_CHECKING: # pragma: no cover
11 | from typedspark._core.column import Column
12 | from typedspark._schema.schema import Schema
13 |
14 | _Schema = TypeVar("_Schema", bound=Schema)
15 | else:
16 | _Schema = TypeVar("_Schema")
17 |
18 | _KeyType = TypeVar("_KeyType", bound=DataType) # pylint: disable=invalid-name
19 | _ValueType = TypeVar("_ValueType", bound=DataType) # pylint: disable=invalid-name
20 | _Precision = TypeVar("_Precision", bound=int) # pylint: disable=invalid-name
21 | _Scale = TypeVar("_Scale", bound=int) # pylint: disable=invalid-name
22 | _StartField = TypeVar("_StartField", bound=int) # pylint: disable=invalid-name
23 | _EndField = TypeVar("_EndField", bound=int) # pylint: disable=invalid-name
24 |
25 |
26 | class TypedSparkDataType(DataType):
27 | """Base class for typedspark specific ``DataTypes``."""
28 |
29 | @classmethod
30 | def get_name(cls) -> str:
31 | """Return the name of the type."""
32 | return cls.__name__
33 |
34 |
35 | class StructTypeMeta(type):
36 | """Initializes the schema attribute as None.
37 |
38 | This allows for auto-complete in Databricks notebooks (uninitialized variables don't
39 | show up in auto-complete there).
40 | """
41 |
42 | def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]):
43 | dct["schema"] = None
44 | return super().__new__(cls, name, bases, dct)
45 |
46 |
47 | class StructType(Generic[_Schema], TypedSparkDataType, metaclass=StructTypeMeta):
48 | """Allows for type annotations such as:
49 |
50 | .. code-block:: python
51 |
52 | class Job(Schema):
53 | position: Column[StringType]
54 | salary: Column[LongType]
55 |
56 | class Person(Schema):
57 | job: Column[StructType[Job]]
58 | """
59 |
60 | def __init__(
61 | self,
62 | schema: Type[_Schema],
63 | parent: Column,
64 | ) -> None:
65 | self.schema = schema
66 | self.schema._parent = parent
67 |
68 |
69 | class MapType(Generic[_KeyType, _ValueType], TypedSparkDataType):
70 | """Allows for type annotations such as.
71 |
72 | .. code-block:: python
73 |
74 | class Basket(Schema):
75 | items: Column[MapType[StringType, StringType]]
76 | """
77 |
78 |
79 | class ArrayType(Generic[_ValueType], TypedSparkDataType):
80 | """Allows for type annotations such as.
81 |
82 | .. code-block:: python
83 |
84 | class Basket(Schema):
85 | items: Column[ArrayType[StringType]]
86 | """
87 |
88 |
89 | class DecimalType(Generic[_Precision, _Scale], TypedSparkDataType):
90 | """Allows for type annotations such as.
91 |
92 | .. code-block:: python
93 |
94 | class Numbers(Schema):
95 | number: Column[DecimalType[Literal[10], Literal[0]]]
96 | """
97 |
98 |
99 | class DayTimeIntervalType(Generic[_StartField, _EndField], TypedSparkDataType):
100 | """Allows for type annotations such as.
101 |
102 | .. code-block:: python
103 |
104 | class TimeInterval(Schema):
105 |             interval: Column[DayTimeIntervalType[IntervalType.HOUR, IntervalType.SECOND]]
106 | """
107 |
--------------------------------------------------------------------------------
/typedspark/_core/literaltype.py:
--------------------------------------------------------------------------------
1 | """Defines ``LiteralTypes``, e.g. ``IntervalType.DAY``, that map their class attribute
2 | to a ``Literal`` integer.
3 |
4 | Can be used for example in ``DayTimeIntervalType``.
5 | """
6 |
7 | from typing import Dict, Literal
8 |
9 |
10 | class LiteralType:
11 | """Base class for literal types, that map their class attribute to a Literal
12 | integer."""
13 |
14 | @classmethod
15 | def get_dict(cls) -> Dict[str, str]:
16 | """Returns a dictionary mapping e.g. "IntervalType.DAY" -> "Literal[0]"."""
17 | dictionary = {}
18 | for key, value in cls.__dict__.items():
19 | if key.startswith("_"):
20 | continue
21 |
22 | key = f"{cls.__name__}.{key}"
23 | value = str(value).replace("typing.", "")
24 |
25 | dictionary[key] = value
26 |
27 | return dictionary
28 |
29 | @classmethod
30 | def get_inverse_dict(cls) -> Dict[str, str]:
31 | """Returns a dictionary mapping e.g. "Literal[0]" -> "IntervalType.DAY"."""
32 | return {v: k for k, v in cls.get_dict().items()}
33 |
34 |
35 | class IntervalType(LiteralType):
36 | """Interval types for ``DayTimeIntervalType``."""
37 |
38 | DAY = Literal[0]
39 | HOUR = Literal[1]
40 | MINUTE = Literal[2]
41 | SECOND = Literal[3]
42 |
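43 | # Usage sketch (illustrative): ``IntervalType.get_dict()`` evaluates to
44 | # {"IntervalType.DAY": "Literal[0]", "IntervalType.HOUR": "Literal[1]",
45 | #  "IntervalType.MINUTE": "Literal[2]", "IntervalType.SECOND": "Literal[3]"},
46 | # and ``get_inverse_dict()`` swaps keys and values, which is how the schema-definition
47 | # printer renders ``DayTimeIntervalType[Literal[0], Literal[3]]`` back as
48 | # ``DayTimeIntervalType[IntervalType.DAY, IntervalType.SECOND]``.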
--------------------------------------------------------------------------------
/typedspark/_core/validate_schema.py:
--------------------------------------------------------------------------------
1 | """Module containing functions that are related to validating schema's at runtime."""
2 |
3 | from typing import Dict
4 |
5 | from pyspark.sql.types import ArrayType, DataType, MapType, StructField, StructType
6 |
7 | from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
8 |
9 |
10 | def validate_schema(
11 | structtype_expected: StructType, structtype_observed: StructType, schema_name: str
12 | ) -> None:
13 | """Checks whether the expected and the observed StructType match."""
14 | expected = unpack_schema(structtype_expected)
15 | observed = unpack_schema(structtype_observed)
16 |
17 | check_names(expected, observed, schema_name)
18 | check_dtypes(expected, observed, schema_name)
19 |
20 |
21 | def unpack_schema(schema: StructType) -> Dict[str, StructField]:
22 | """Converts the observed schema to a dictionary mapping column name to StructField.
23 |
24 | We ignore columns that start with ``__``.
25 | """
26 | res = {}
27 | for field in schema.fields:
28 | if field.name.startswith("__"):
29 | continue
30 | field.nullable = True
31 | field.metadata = {}
32 | res[field.name] = field
33 |
34 | return res
35 |
36 |
37 | def check_names(
38 | expected: Dict[str, StructField], observed: Dict[str, StructField], schema_name: str
39 | ) -> None:
40 | """Checks whether the observed and expected list of column names overlap.
41 |
42 | Is order insensitive.
43 | """
44 | names_observed = set(observed.keys())
45 | names_expected = set(expected.keys())
46 |
47 | diff = names_observed - names_expected
48 | if diff:
49 | diff_schema = create_schema_from_structtype(
50 | StructType([observed[colname] for colname in diff]), schema_name
51 | )
52 | raise TypeError(
53 | f"Data contains the following columns not present in schema {schema_name}: {diff}.\n\n"
54 | "If you believe these columns should be part of the schema, consider adding the "
55 | "following lines to it.\n\n"
56 | f"{diff_schema.get_schema_definition_as_string(generate_imports=False)}"
57 | )
58 |
59 | diff = names_expected - names_observed
60 | if diff:
61 | raise TypeError(
62 | f"Schema {schema_name} contains the following columns not present in data: {diff}"
63 | )
64 |
65 |
66 | def check_dtypes(
67 | schema_expected: Dict[str, StructField],
68 | schema_observed: Dict[str, StructField],
69 | schema_name: str,
70 | ) -> None:
71 | """Checks for each column whether the observed and expected data type match.
72 |
73 | Is order insensitive.
74 | """
75 | for name, structfield_expected in schema_expected.items():
76 | structfield_observed = schema_observed[name]
77 | check_dtype(
78 | name,
79 | structfield_expected.dataType,
80 | structfield_observed.dataType,
81 | schema_name,
82 | )
83 |
84 |
85 | def check_dtype(
86 | colname: str, dtype_expected: DataType, dtype_observed: DataType, schema_name: str
87 | ) -> None:
88 | """Checks whether the observed and expected data type match."""
89 | if dtype_expected == dtype_observed:
90 | return None
91 |
92 | if isinstance(dtype_expected, ArrayType) and isinstance(dtype_observed, ArrayType):
93 | return check_dtype(
94 | f"{colname}.element_type",
95 | dtype_expected.elementType,
96 | dtype_observed.elementType,
97 | schema_name,
98 | )
99 |
100 | if isinstance(dtype_expected, MapType) and isinstance(dtype_observed, MapType):
101 | check_dtype(
102 | f"{colname}.key",
103 | dtype_expected.keyType,
104 | dtype_observed.keyType,
105 | schema_name,
106 | )
107 | return check_dtype(
108 | f"{colname}.value",
109 | dtype_expected.valueType,
110 | dtype_observed.valueType,
111 | schema_name,
112 | )
113 |
114 | if isinstance(dtype_expected, StructType) and isinstance(dtype_observed, StructType):
115 | return validate_schema(dtype_expected, dtype_observed, f"{schema_name}.{colname}")
116 |
117 | raise TypeError(
118 | f"Column {colname} is of type {dtype_observed}, but {schema_name}.{colname} "
119 | + f"suggests {dtype_expected}."
120 | )
121 |
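122 | # Usage sketch (illustrative): ``validate_schema(A.get_structtype(), df.schema, "A")`` passes
123 | # silently when column names and dtypes match (order-insensitively) and raises a ``TypeError``
124 | # naming the offending columns otherwise; nullability, metadata and columns starting with
125 | # ``__`` are ignored via ``unpack_schema()``.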
--------------------------------------------------------------------------------
/typedspark/_schema/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaiko-ai/typedspark/0c99ffb7cb9eddee9121254163f74b8de25d9f6b/typedspark/_schema/__init__.py
--------------------------------------------------------------------------------
/typedspark/_schema/dlt_kwargs.py:
--------------------------------------------------------------------------------
1 | """A representation of the ``Schema`` to be used by Delta Live Tables."""
2 |
3 | from typing import Optional, TypedDict
4 |
5 | from pyspark.sql.types import StructType
6 |
7 |
8 | class DltKwargs(TypedDict):
9 | """A representation of the ``Schema`` to be used by Delta Live Tables.
10 |
11 | .. code-block:: python
12 |
13 | @dlt.table(**Person.get_dlt_kwargs())
14 | def table_definition() -> DataSet[Person]:
15 |             ...
16 | """
17 |
18 | name: str
19 | comment: Optional[str]
20 | schema: StructType
21 |
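22 | # Usage sketch (illustrative, mirroring tests/_schema/test_schema.py): for a schema
23 | # ``PascalCase`` whose docstring is "Schema docstring.", ``PascalCase.get_dlt_kwargs()``
24 | # returns ``{"name": "pascal_case", "comment": "Schema docstring.", "schema": <StructType>}``,
25 | # ready to be unpacked into the ``@dlt.table`` decorator as shown above.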
--------------------------------------------------------------------------------
/typedspark/_schema/get_schema_definition.py:
--------------------------------------------------------------------------------
1 | """Module to output a string with the ``Schema`` definition of a given ``DataFrame``."""
2 |
3 | from __future__ import annotations
4 |
5 | import re
6 | from typing import TYPE_CHECKING, Type, get_args, get_origin, get_type_hints
7 |
8 | from typedspark._core.datatypes import DayTimeIntervalType, StructType, TypedSparkDataType
9 | from typedspark._core.literaltype import IntervalType, LiteralType
10 | from typedspark._schema.get_schema_imports import get_schema_imports
11 |
12 | if TYPE_CHECKING: # pragma: no cover
13 | from typedspark._schema.schema import Schema
14 |
15 |
16 | def get_schema_definition_as_string(
17 | schema: Type[Schema],
18 | include_documentation: bool,
19 | generate_imports: bool,
20 | add_subschemas: bool,
21 | class_name: str = "MyNewSchema",
22 | ) -> str:
23 | """Return the code for a given ``Schema`` as a string.
24 |
25 | Typically used when you load a dataset using
26 |     ``load_table()`` in a notebook and you want to save the
27 | schema in your code base. When ``generate_imports`` is True, the
28 | required imports for the schema are included in the string.
29 | """
30 | imports = get_schema_imports(schema, include_documentation) if generate_imports else ""
31 | schema_string = _build_schema_definition_string(
32 | schema, include_documentation, add_subschemas, class_name
33 | )
34 |
35 | return imports + schema_string
36 |
37 |
38 | def _build_schema_definition_string(
39 | schema: Type[Schema],
40 | include_documentation: bool,
41 | add_subschemas: bool,
42 | class_name: str = "MyNewSchema",
43 | ) -> str:
44 | """Return the code for a given ``Schema`` as a string."""
45 | lines = f"class {class_name}(Schema):\n"
46 |
47 | if include_documentation:
48 | lines += _create_docstring(schema)
49 |
50 | lines += _add_lines_with_typehint(include_documentation, schema)
51 |
52 | if add_subschemas:
53 | lines += _add_subschemas(schema, add_subschemas, include_documentation)
54 |
55 | return lines
56 |
57 |
58 | def _create_docstring(schema: Type[Schema]) -> str:
59 | """Create the docstring for a given ``Schema``."""
60 | if schema.get_docstring() != "":
61 | docstring = f' """{schema.get_docstring()}"""\n\n'
62 | else:
63 | docstring = ' """Add documentation here."""\n\n'
64 | return docstring
65 |
66 |
67 | def _add_lines_with_typehint(include_documentation, schema):
68 | """Add a line with the typehint for each column in the ``Schema``."""
69 | lines = ""
70 | for col_name, col_type in get_type_hints(schema, include_extras=True).items():
71 | typehint, comment = _create_typehint_and_comment(col_type)
72 |
73 | if include_documentation:
74 | lines += f' {col_name}: Annotated[{typehint}, ColumnMeta(comment="{comment}")]\n'
75 | else:
76 | lines += f" {col_name}: {typehint}\n"
77 | return lines
78 |
79 |
80 | def _create_typehint_and_comment(col_type) -> list[str]:
81 | """Create a typehint and comment for a given column."""
82 | typehint = (
83 | str(col_type)
84 | .replace("typedspark._core.column.", "")
85 | .replace("typedspark._core.datatypes.", "")
86 | .replace("typedspark._schema.schema.", "")
87 | .replace("pyspark.sql.types.", "")
88 | .replace("typing.", "")
89 | .replace("abc.", "")
90 | )
91 | typehint, comment = _extract_comment(typehint)
92 | typehint = _replace_literals(
93 | typehint, replace_literals_in=DayTimeIntervalType, replace_literals_by=IntervalType
94 | )
95 | return [typehint, comment]
96 |
97 |
98 | def _extract_comment(typehint: str) -> tuple[str, str]:
99 | """Extract the comment from a typehint."""
100 | comment = ""
101 | if "Annotated" in typehint:
102 | match = re.search(r"Annotated\[(.*), '(.*)'\]", typehint)
103 | if match is not None:
104 | typehint, comment = match.groups()
105 | return typehint, comment
106 |
107 |
108 | def _replace_literals(
109 | typehint: str,
110 | replace_literals_in: Type[TypedSparkDataType],
111 | replace_literals_by: Type[LiteralType],
112 | ) -> str:
113 | """Replace all Literals in a LiteralType, e.g.
114 |
115 | "DayTimeIntervalType[Literal[0], Literal[1]]" ->
116 | "DayTimeIntervalType[IntervalType.DAY, IntervalType.HOUR]"
117 | """
118 | mapping = replace_literals_by.get_inverse_dict()
119 | for original, replacement in mapping.items():
120 | typehint = _replace_literal(typehint, replace_literals_in, original, replacement)
121 |
122 | return typehint
123 |
124 |
125 | def _replace_literal(
126 | typehint: str,
127 | replace_literals_in: Type[TypedSparkDataType],
128 | original: str,
129 | replacement: str,
130 | ) -> str:
131 | """Replaces a single Literal in a LiteralType, e.g.
132 |
133 | "DayTimeIntervalType[Literal[0], Literal[1]]" ->
134 | "DayTimeIntervalType[IntervalType.DAY, Literal[1]]"
135 | """
136 | return re.sub(
137 | rf"{replace_literals_in.get_name()}\[[^]]*\]",
138 | lambda x: x.group(0).replace(original, replacement),
139 | typehint,
140 | )
141 |
142 |
143 | def _add_subschemas(schema: Type[Schema], add_subschemas: bool, include_documentation: bool) -> str:
144 | """Identifies whether any ``Column`` are of the ``StructType`` type and generates
145 | their schema recursively."""
146 | lines = ""
147 | for val in get_type_hints(schema).values():
148 | args = get_args(val)
149 | if not args:
150 | continue
151 |
152 | dtype = args[0]
153 | if get_origin(dtype) == StructType:
154 | lines += "\n\n"
155 | subschema: Type[Schema] = get_args(dtype)[0]
156 | lines += _build_schema_definition_string(
157 | subschema, include_documentation, add_subschemas, subschema.get_schema_name()
158 | )
159 |
160 | return lines
161 |
--------------------------------------------------------------------------------
/typedspark/_schema/get_schema_imports.py:
--------------------------------------------------------------------------------
1 | """Builds an import statement for everything imported by a given ``Schema``."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING, Optional, Type, get_args, get_origin, get_type_hints
6 |
7 | from pyspark.sql.types import DataType
8 |
9 | from typedspark._core.datatypes import (
10 | ArrayType,
11 | DayTimeIntervalType,
12 | DecimalType,
13 | MapType,
14 | StructType,
15 | TypedSparkDataType,
16 | )
17 |
18 | if TYPE_CHECKING: # pragma: no cover
19 | from typedspark._schema.schema import Schema
20 |
21 |
22 | def get_schema_imports(schema: Type[Schema], include_documentation: bool) -> str:
23 | """Builds an import statement for everything imported by the ``Schema``."""
24 | dtypes = _get_imported_dtypes(schema)
25 | return _build_import_string(dtypes, include_documentation)
26 |
27 |
28 | def _get_imported_dtypes(schema: Type[Schema]) -> set[Type[DataType]]:
29 | """Returns a set of DataTypes that are imported by the given schema."""
30 | encountered_datatypes: set[Type[DataType]] = set()
31 | for column in get_type_hints(schema).values():
32 | args = get_args(column)
33 | if not args:
34 | continue
35 |
36 | dtype = args[0]
37 | encountered_datatypes |= _process_datatype(dtype)
38 |
39 | return encountered_datatypes
40 |
41 |
42 | def _process_datatype(dtype: Type[DataType]) -> set[Type[DataType]]:
43 | """Returns a set of DataTypes that are imported for a given DataType.
44 |
45 | Handles nested DataTypes recursively.
46 | """
47 | encountered_datatypes: set[Type[DataType]] = set()
48 |
49 | origin: Optional[Type[DataType]] = get_origin(dtype)
50 | if origin:
51 | encountered_datatypes.add(origin)
52 | else:
53 | encountered_datatypes.add(dtype)
54 |
55 | if origin == MapType:
56 | key, value = get_args(dtype)
57 | encountered_datatypes |= _process_datatype(key)
58 | encountered_datatypes |= _process_datatype(value)
59 |
60 | if origin == ArrayType:
61 | element = get_args(dtype)[0]
62 | encountered_datatypes |= _process_datatype(element)
63 |
64 | if get_origin(dtype) == StructType:
65 | subschema = get_args(dtype)[0]
66 | encountered_datatypes |= _get_imported_dtypes(subschema)
67 |
68 | return encountered_datatypes
69 |
70 |
71 | def _build_import_string(
72 | encountered_datatypes: set[Type[DataType]], include_documentation: bool
73 | ) -> str:
74 | """Returns a multiline string with the imports required for the given
75 | encountered_datatypes.
76 |
77 | Import sorting is applied.
78 |
79 |     If the schema uses IntegerType, BooleanType, StringType, this function's result would be
80 |
81 | .. code-block:: python
82 |
83 | from pyspark.sql.types import BooleanType, IntegerType, StringType
84 |
85 | from typedspark import Column, Schema
86 | """
87 | return (
88 | _typing_imports(encountered_datatypes, include_documentation)
89 | + _pyspark_imports(encountered_datatypes)
90 | + _typedspark_imports(encountered_datatypes, include_documentation)
91 | )
92 |
93 |
94 | def _typing_imports(encountered_datatypes: set[Type[DataType]], include_documentation: bool) -> str:
95 | """Returns the import statement for the typing library."""
96 | imports = []
97 |
98 | if any([dtype == DecimalType for dtype in encountered_datatypes]):
99 | imports += ["Literal"]
100 |
101 | if include_documentation:
102 | imports += ["Annotated"]
103 |
104 | if len(imports) > 0:
105 | imports = sorted(imports)
106 | imports_string = ", ".join(imports) # type: ignore
107 | return f"from typing import {imports_string}\n\n"
108 |
109 | return ""
110 |
111 |
112 | def _pyspark_imports(encountered_datatypes: set[Type[DataType]]) -> str:
113 | """Returns the import statement for the pyspark library."""
114 | dtypes = sorted(
115 | [
116 | dtype.__name__
117 | for dtype in encountered_datatypes
118 | if not issubclass(dtype, TypedSparkDataType)
119 | ]
120 | )
121 |
122 | if len(dtypes) > 0:
123 | dtypes_string = ", ".join(dtypes)
124 | return f"from pyspark.sql.types import {dtypes_string}\n\n"
125 |
126 | return ""
127 |
128 |
129 | def _typedspark_imports(
130 | encountered_datatypes: set[Type[DataType]], include_documentation: bool
131 | ) -> str:
132 | """Returns the import statement for the typedspark library."""
133 | dtypes = [
134 | dtype.__name__ for dtype in encountered_datatypes if issubclass(dtype, TypedSparkDataType)
135 | ] + ["Column", "Schema"]
136 |
137 | if any([dtype == DayTimeIntervalType for dtype in encountered_datatypes]):
138 | dtypes += ["IntervalType"]
139 |
140 | if include_documentation:
141 | dtypes.append("ColumnMeta")
142 |
143 | dtypes = sorted(dtypes)
144 |
145 | dtypes_string = ", ".join(dtypes)
146 | return f"from typedspark import {dtypes_string}\n\n\n"
147 |
--------------------------------------------------------------------------------
/typedspark/_schema/schema.py:
--------------------------------------------------------------------------------
1 | """Module containing classes and functions related to TypedSpark Schemas."""
2 |
3 | import inspect
4 | import re
5 | from typing import (
6 | Any,
7 | Dict,
8 | List,
9 | Optional,
10 | Protocol,
11 | Type,
12 | Union,
13 | _ProtocolMeta,
14 | get_args,
15 | get_type_hints,
16 | )
17 |
18 | from pyspark.sql import DataFrame
19 | from pyspark.sql.types import DataType, StructType
20 |
21 | from typedspark._core.column import Column
22 | from typedspark._schema.dlt_kwargs import DltKwargs
23 | from typedspark._schema.get_schema_definition import get_schema_definition_as_string
24 | from typedspark._schema.structfield import get_structfield
25 |
26 |
27 | class MetaSchema(_ProtocolMeta): # type: ignore
28 | """``MetaSchema`` is the metaclass of ``Schema``.
29 |
30 | It basically implements all functionality of ``Schema``. But since
31 | classes are typically considered more convenient than metaclasses,
32 | we provide ``Schema`` as the public interface.
33 |
34 | .. code-block:: python
35 |
36 | class A(Schema):
37 | a: Column[IntegerType]
38 | b: Column[StringType]
39 |
40 | DataSet[A](df)
41 |
42 | The class methods of ``Schema`` are described here.
43 | """
44 |
45 | _parent: Optional[Union[DataFrame, Column]] = None
46 | _alias: Optional[str] = None
47 | _current_id: Optional[int] = None
48 | _original_name: Optional[str] = None
49 |
50 | def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]):
51 | cls._attributes = dir(cls)
52 |
53 |         # initializes all uninitialized variables that have a type annotation to None
54 | # this allows for auto-complete in Databricks notebooks (uninitialized variables
55 | # don't show up in auto-complete there).
56 | if "__annotations__" in dct.keys():
57 | extra = {attr: None for attr in dct["__annotations__"] if attr not in dct}
58 | dct = dict(dct, **extra)
59 |
60 | return super().__new__(cls, name, bases, dct)
61 |
62 | def __repr__(cls) -> str:
63 | return f"\n{str(cls)}"
64 |
65 | def __str__(cls) -> str:
66 | return cls.get_schema_definition_as_string(add_subschemas=False)
67 |
68 | def __getattribute__(cls, name: str) -> Any:
69 | """Python base function that gets attributes.
70 |
71 | We listen here for anyone getting a ``Column`` from the ``Schema``.
72 |         Even though they're not explicitly instantiated, we can instantiate
73 | them here whenever someone attempts to get them. This allows us to do the following:
74 |
75 | .. code-block:: python
76 |
77 | class A(Schema):
78 | a: Column[IntegerType]
79 |
80 | (
81 | df.withColumn(A.a.str, lit(1))
82 | .select(A.a)
83 | )
84 | """
85 | if (
86 | name.startswith("__")
87 | or name == "_attributes"
88 | or name in cls._attributes
89 | or name in dir(Protocol)
90 | ):
91 | return object.__getattribute__(cls, name)
92 |
93 | if name in get_type_hints(cls):
94 | return Column(
95 | name,
96 | dtype=cls._get_dtype(name), # type: ignore
97 | parent=cls._parent,
98 | curid=cls._current_id,
99 | alias=cls._alias,
100 | )
101 |
102 | raise TypeError(f"Schema {cls.get_schema_name()} does not have attribute {name}.")
103 |
104 | def _get_dtype(cls, name: str) -> Type[DataType]:
105 | """Returns the datatype of a column, e.g. Column[IntegerType] -> IntegerType."""
106 | column = get_type_hints(cls)[name]
107 | args = get_args(column)
108 |
109 | if not args:
110 | raise TypeError(
111 | f"Column {cls.get_schema_name()}.{name} does not have an annotated type."
112 | )
113 |
114 | dtype = args[0]
115 | return dtype
116 |
117 | def all_column_names(cls) -> List[str]:
118 | """Returns all column names for a given schema."""
119 | return list(get_type_hints(cls).keys())
120 |
121 | def all_column_names_except_for(cls, except_for: List[str]) -> List[str]:
122 | """Returns all column names for a given schema except for the columns specified
123 | in the ``except_for`` parameter."""
124 | return list(name for name in get_type_hints(cls).keys() if name not in except_for)
125 |
126 | def get_snake_case(cls) -> str:
127 | """Return the class name transformed into snakecase."""
128 | word = cls.get_schema_name()
129 | word = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", word)
130 | word = re.sub(r"([a-z\d])([A-Z])", r"\1_\2", word)
131 | word = word.replace("-", "_")
132 | return word.lower()
133 |
134 | def get_schema_definition_as_string(
135 | cls,
136 | schema_name: Optional[str] = None,
137 | include_documentation: bool = False,
138 | generate_imports: bool = True,
139 | add_subschemas: bool = True,
140 | ) -> str:
141 | """Return the code for the ``Schema`` as a string."""
142 | if schema_name is None:
143 | schema_name = cls.get_schema_name()
144 | return get_schema_definition_as_string(
145 | cls, # type: ignore
146 | include_documentation,
147 | generate_imports,
148 | add_subschemas,
149 | schema_name,
150 | )
151 |
152 | def print_schema(
153 | cls,
154 | schema_name: Optional[str] = None,
155 | include_documentation: bool = False,
156 | generate_imports: bool = True,
157 | add_subschemas: bool = False,
158 | ): # pragma: no cover
159 | """Print the code for the ``Schema``."""
160 | print(
161 | cls.get_schema_definition_as_string(
162 | schema_name=schema_name,
163 | include_documentation=include_documentation,
164 | generate_imports=generate_imports,
165 | add_subschemas=add_subschemas,
166 | )
167 | )
168 |
169 | def get_docstring(cls) -> Union[str, None]:
170 | """Returns the docstring of the schema."""
171 | return inspect.getdoc(cls)
172 |
173 | def get_structtype(cls) -> StructType:
174 | """Creates the spark StructType for the schema."""
175 | return StructType(
176 | [
177 | get_structfield(name, column)
178 | for name, column in get_type_hints(cls, include_extras=True).items()
179 | ]
180 | )
181 |
182 | def get_dlt_kwargs(cls, name: Optional[str] = None) -> DltKwargs:
183 | """Creates a representation of the ``Schema`` to be used by Delta Live Tables.
184 |
185 | .. code-block:: python
186 |
187 | @dlt.table(**DimPatient.get_dlt_kwargs())
188 | def table_definition() -> DataSet[DimPatient]:
189 |
190 | """
191 | return {
192 | "name": name if name else cls.get_snake_case(),
193 | "comment": cls.get_docstring(),
194 | "schema": cls.get_structtype(),
195 | }
196 |
197 | def get_schema_name(cls):
198 | """Returns the name with which the schema was initialized."""
199 | return cls._original_name if cls._original_name else cls.__name__
200 |
201 | def get_metadata(cls) -> dict[str, dict[str, Any]]:
202 | """Returns the metadata of each of the columns in the schema."""
203 | return {field.name: field.metadata for field in cls.get_structtype().fields}
204 |
205 |
206 | class Schema(Protocol, metaclass=MetaSchema):
207 | # pylint: disable=empty-docstring
208 |     # Since docstrings are inherited, and since we use docstrings to
209 | # annotate tables (see MetaSchema.get_dlt_kwargs()), we have chosen
210 | # to add an empty docstring to the Schema class (otherwise the Schema
211 | # docstring would be added to any schema without a docstring).
212 | """"""
213 |
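
A minimal usage sketch of the ``Schema`` interface defined above (the ``Person`` schema and its columns are hypothetical):

    from pyspark.sql.types import IntegerType, StringType
    from typedspark import Column, Schema

    class Person(Schema):
        """People known to the system."""

        id: Column[IntegerType]
        name: Column[StringType]

    Person.id.str               # "id", resolved via MetaSchema.__getattribute__
    Person.all_column_names()   # ["id", "name"]
    Person.get_structtype()     # StructType with one StructField per column
    Person.get_dlt_kwargs()     # {"name": "person", "comment": "People known to the system.", "schema": ...}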
--------------------------------------------------------------------------------
/typedspark/_schema/structfield.py:
--------------------------------------------------------------------------------
1 | """Module responsible for generating StructFields from Columns in a Schema."""
2 |
3 | from __future__ import annotations
4 |
5 | import inspect
6 | from typing import TYPE_CHECKING, Annotated, Type, TypeVar, Union, get_args, get_origin
7 |
8 | from pyspark.sql.types import ArrayType as SparkArrayType
9 | from pyspark.sql.types import DataType
10 | from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
11 | from pyspark.sql.types import DecimalType as SparkDecimalType
12 | from pyspark.sql.types import MapType as SparkMapType
13 | from pyspark.sql.types import StructField
14 | from pyspark.sql.types import StructType as SparkStructType
15 |
16 | from typedspark._core.column import Column
17 | from typedspark._core.column_meta import ColumnMeta
18 | from typedspark._core.datatypes import (
19 | ArrayType,
20 | DayTimeIntervalType,
21 | DecimalType,
22 | MapType,
23 | StructType,
24 | TypedSparkDataType,
25 | )
26 |
27 | if TYPE_CHECKING: # pragma: no cover
28 | from typedspark._schema.schema import Schema
29 |
30 | _DataType = TypeVar("_DataType", bound=DataType) # pylint: disable=invalid-name
31 |
32 |
33 | def get_structfield(
34 | name: str,
35 | column: Union[Type[Column[_DataType]], Annotated[Type[Column[_DataType]], ColumnMeta]],
36 | ) -> StructField:
37 | """Generates a ``StructField`` for a given ``Column`` in a ``Schema``."""
38 | meta = get_structfield_meta(column)
39 |
40 | return StructField(
41 | name=name,
42 | dataType=_get_structfield_dtype(column, name),
43 | nullable=True,
44 | metadata=meta.get_metadata(),
45 | )
46 |
47 |
48 | def get_structfield_meta(
49 | column: Union[Type[Column[_DataType]], Annotated[Type[Column[_DataType]], ColumnMeta]],
50 | ) -> ColumnMeta:
51 | """Get the spark column metadata from the ``ColumnMeta`` data, when available."""
52 | return next((x for x in get_args(column) if isinstance(x, ColumnMeta)), ColumnMeta())
53 |
54 |
55 | def _get_structfield_dtype(
56 | column: Union[Type[Column[_DataType]], Annotated[Type[Column[_DataType]], ColumnMeta]],
57 | colname: str,
58 | ) -> DataType:
59 | """Get the spark ``DataType`` from the ``Column`` type annotation."""
60 | origin = get_origin(column)
61 | if origin not in [Annotated, Column]:
62 | raise TypeError(f"Column {colname} needs to be of type Column or Annotated.")
63 |
64 | if origin == Annotated:
65 | column = _get_column_from_annotation(column, colname)
66 |
67 | args = get_args(column)
68 | dtype = _get_dtype(args[0], colname)
69 | return dtype
70 |
71 |
72 | def _get_column_from_annotation(
73 | column: Annotated[Type[Column[_DataType]], ColumnMeta],
74 | colname: str,
75 | ) -> Type[Column[_DataType]]:
76 | """Takes an ``Annotation[Column[...], ...]`` and returns the ``Column[...]``."""
77 | column = get_args(column)[0]
78 | if get_origin(column) != Column:
79 | raise TypeError(f"Column {colname} needs to have a Column[] within Annotated[].")
80 |
81 | return column
82 |
83 |
84 | def _get_dtype(dtype: Type[DataType], colname: str) -> DataType:
85 | """Takes a ``DataType`` class and returns a DataType object."""
86 | origin = get_origin(dtype)
87 | if origin == ArrayType:
88 | return _extract_arraytype(dtype, colname)
89 | if origin == MapType:
90 | return _extract_maptype(dtype, colname)
91 | if origin == StructType:
92 | return _extract_structtype(dtype)
93 | if origin == DecimalType:
94 | return _extract_decimaltype(dtype)
95 | if origin == DayTimeIntervalType:
96 | return _extract_daytimeintervaltype(dtype)
97 | if (
98 | inspect.isclass(dtype)
99 | and issubclass(dtype, DataType)
100 | and not issubclass(dtype, TypedSparkDataType)
101 | ):
102 | return dtype()
103 |
104 | raise TypeError(
105 | f"Column {colname} does not have a correctly formatted DataType as a parameter."
106 | )
107 |
108 |
109 | def _extract_arraytype(arraytype: Type[DataType], colname: str) -> SparkArrayType:
110 | """Takes e.g. an ``ArrayType[StringType]`` and creates an ``ArrayType(StringType(),
111 | True)``."""
112 | params = get_args(arraytype)
113 | element_type = _get_dtype(params[0], colname)
114 | return SparkArrayType(element_type)
115 |
116 |
117 | def _extract_maptype(maptype: Type[DataType], colname: str) -> SparkMapType:
118 |     """Takes e.g. a ``MapType[StringType, StringType]`` and creates a
119 |     ``MapType(StringType(), StringType(), True)``."""
120 | params = get_args(maptype)
121 | key_type = _get_dtype(params[0], colname)
122 | value_type = _get_dtype(params[1], colname)
123 | return SparkMapType(key_type, value_type)
124 |
125 |
126 | def _extract_structtype(structtype: Type[DataType]) -> SparkStructType:
127 | """Takes a ``StructType[Schema]`` annotation and creates a
128 | ``StructType(schema_list)``, where ``schema_list`` contains all ``StructField()``
129 | defined in the ``Schema``."""
130 | params = get_args(structtype)
131 | schema: Type[Schema] = params[0]
132 | return schema.get_structtype()
133 |
134 |
135 | def _extract_decimaltype(decimaltype: Type[DataType]) -> SparkDecimalType:
136 | """Takes e.g. a ``DecimalType[Literal[10], Literal[12]]`` and returns
137 | ``DecimalType(10, 12)``."""
138 | params = get_args(decimaltype)
139 |     precision: int = _unpack_literal(params[0])
140 |     scale: int = _unpack_literal(params[1])
141 |     return SparkDecimalType(precision, scale)
142 |
143 |
144 | def _extract_daytimeintervaltype(daytimeintervaltype: Type[DataType]) -> SparkDayTimeIntervalType:
145 | """Takes e.g. a ``DayTimeIntervalType[Literal[1], Literal[2]]`` and returns
146 | ``DayTimeIntervalType(1, 2)``."""
147 | params = get_args(daytimeintervaltype)
148 | start_field: int = _unpack_literal(params[0])
149 | end_field: int = _unpack_literal(params[1])
150 | return SparkDayTimeIntervalType(start_field, end_field)
151 |
152 |
153 | def _unpack_literal(literal):
154 | """Takes as input e.g. ``Literal[10]`` and returns ``10``."""
155 | return get_args(literal)[0]
156 |
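
A small sketch of how a ``Column`` annotation becomes a ``StructField`` when calling the helpers in this module directly (the column name ``name`` and the comment are hypothetical):

    from typing import Annotated

    from pyspark.sql.types import StringType

    from typedspark import Column, ColumnMeta

    # plain column: StructField("name", StringType(), nullable=True)
    get_structfield("name", Column[StringType])

    # documented column: the comment from ColumnMeta ends up in the StructField metadata
    get_structfield("name", Annotated[Column[StringType], ColumnMeta(comment="Full name")])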
--------------------------------------------------------------------------------
/typedspark/_transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaiko-ai/typedspark/0c99ffb7cb9eddee9121254163f74b8de25d9f6b/typedspark/_transforms/__init__.py
--------------------------------------------------------------------------------
/typedspark/_transforms/rename_duplicate_columns.py:
--------------------------------------------------------------------------------
1 | """Module that handles duplicate columns in the ``DataFrame`` that are also in the
2 | schema (and hence in the resulting ``DataSet[Schema]``), but which are not handled by
3 | the transformations dictionary."""
4 |
5 | from typing import Dict, Final, Type
6 | from uuid import uuid4
7 |
8 | from pyspark.sql import Column as SparkColumn
9 |
10 | from typedspark._schema.schema import Schema
11 |
12 | ERROR_MSG: Final[
13 | str
14 | ] = """Columns {columns} are ambiguous.
15 | Please specify the transformations for these columns explicitly, for example:
16 |
17 | schema_a = register_schema_to_dataset(df_a, A)
18 | schema_b = register_schema_to_dataset(df_b, B)
19 |
20 | transform_to_schema(
21 | df_a.join(
22 | df_b,
23 | schema_a.id == schema_b.id
24 | ),
25 | C,
26 | {{
27 | C.id: schema_a.id,
28 | }}
29 | )
30 | """
31 |
32 |
33 | class RenameDuplicateColumns:
34 |     """Class that handles duplicate columns in the DataFrame that are also in the
35 | schema (and hence in the resulting ``DataSet[Schema]``), but which are not handled
36 | by the transformations dictionary.
37 |
38 | This class renames these duplicate columns to temporary names, such that we avoid
39 | ambiguous columns in the resulting ``DataSet[Schema]``.
40 | """
41 |
42 | def __init__(
43 | self,
44 | transformations: Dict[str, SparkColumn],
45 | schema: Type[Schema],
46 | dataframe_columns: list[str],
47 | ):
48 | self._temporary_key_mapping = self._create_temporary_key_mapping(
49 | transformations, dataframe_columns, schema
50 | )
51 | self._transformations = self._rename_keys_to_temporary_keys(transformations)
52 |
53 | def _create_temporary_key_mapping(
54 | self,
55 | transformations: Dict[str, SparkColumn],
56 | dataframe_columns: list[str],
57 | schema: Type[Schema],
58 | ) -> Dict[str, str]:
59 | """Creates a mapping for duplicate columns in the ``DataFrame`` to temporary
60 | names, such that we avoid ambiguous columns in the resulting
61 | ``DataSet[Schema]``."""
62 | duplicate_columns_in_dataframe = self._duplicates(dataframe_columns)
63 | schema_columns = set(schema.all_column_names())
64 | transformation_keys = set(transformations.keys())
65 |
66 | self._verify_that_all_duplicate_columns_will_be_handled(
67 | duplicate_columns_in_dataframe, schema_columns, transformation_keys
68 | )
69 |
70 | problematic_keys = duplicate_columns_in_dataframe & transformation_keys
71 |
72 | res = {}
73 | for key in problematic_keys:
74 | res[key] = self._find_temporary_name(key, dataframe_columns)
75 |
76 |         return res
77 |
78 | def _duplicates(self, lst: list) -> set:
79 | """Returns a set of the duplicates in the provided list."""
80 | return {x for x in lst if lst.count(x) > 1}
81 |
82 | def _verify_that_all_duplicate_columns_will_be_handled(
83 | self,
84 | duplicate_columns_in_dataframe: set[str],
85 | schema_columns: set[str],
86 | transformation_keys: set[str],
87 | ):
88 |         """Raises an exception if there are duplicate columns in the ``DataFrame`` that
89 |         are also in the schema (and hence in the resulting ``DataSet[Schema]``), but
90 |         which are not handled by the ``transformations`` dictionary."""
91 | unhandled_columns = (duplicate_columns_in_dataframe & schema_columns) - transformation_keys
92 | if unhandled_columns:
93 | raise ValueError(ERROR_MSG.format(columns=unhandled_columns))
94 |
95 | def _find_temporary_name(self, colname: str, dataframe_columns: list[str]) -> str:
96 | """Appends a uuid to the column name to make sure the temporary name doesn't
97 | collide with any other column names."""
98 | name = colname
99 | num = 0
100 | while name in dataframe_columns:
101 | name = f"{colname}_with_temporary_uuid_{uuid4()}"
102 | num += 1
103 | if num > 100:
104 | raise Exception("Failed to find a temporary name.") # pragma: no cover
105 |
106 | return name
107 |
108 | def _rename_keys_to_temporary_keys(
109 | self, transformations: Dict[str, SparkColumn]
110 | ) -> Dict[str, SparkColumn]:
111 | """Renames the keys in the transformations dictionary to temporary keys."""
112 | return {self._temporary_key_mapping.get(k, k): v for k, v in transformations.items()}
113 |
114 | @property
115 | def transformations(self) -> Dict[str, SparkColumn]:
116 | """Returns the transformations dictionary."""
117 | return self._transformations
118 |
119 | @property
120 | def temporary_key_mapping(self) -> Dict[str, str]:
121 | """Returns the temporary key mapping."""
122 | return self._temporary_key_mapping
123 |
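
A sketch of the behaviour (all names hypothetical): if ``id`` occurs twice after a join and is also a transformation key, it is remapped to a collision-free temporary name:

    handler = RenameDuplicateColumns(
        transformations={"id": schema_a.id, "name": schema_a.name},
        schema=C,
        dataframe_columns=["id", "id", "name"],
    )
    handler.temporary_key_mapping   # roughly {"id": "id_with_temporary_uuid_<some-uuid>"}
    handler.transformations         # same values, with the "id" key renamed to the temporary key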
--------------------------------------------------------------------------------
/typedspark/_transforms/structtype_column.py:
--------------------------------------------------------------------------------
1 | """Functionality for dealing with StructType columns."""
2 |
3 | from typing import Dict, Optional, Type
4 |
5 | from pyspark.sql import Column as SparkColumn
6 | from pyspark.sql.functions import struct
7 |
8 | from typedspark._core.column import Column
9 | from typedspark._schema.schema import Schema
10 | from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
11 |
12 |
13 | def structtype_column(
14 | schema: Type[Schema],
15 | transformations: Optional[Dict[Column, SparkColumn]] = None,
16 | fill_unspecified_columns_with_nulls: bool = False,
17 | ) -> SparkColumn:
18 | """Helps with creating new ``StructType`` columns of a certain schema, for
19 | example:
20 |
21 | .. code-block:: python
22 |
23 | transform_to_schema(
24 | df,
25 | Output,
26 | {
27 | Output.values: structtype_column(
28 | Value,
29 | {
30 | Value.a: Input.a + 2,
31 | ...
32 | }
33 | )
34 | }
35 | )
36 | """
37 | _transformations = convert_keys_to_strings(transformations)
38 |
39 | if fill_unspecified_columns_with_nulls:
40 | _transformations = add_nulls_for_unspecified_columns(_transformations, schema)
41 |
42 | _transformations = _order_columns(_transformations, schema)
43 |
44 | return struct([v.alias(k) for k, v in _transformations.items()])
45 |
46 |
47 | def _order_columns(
48 | transformations: Dict[str, SparkColumn], schema: Type[Schema]
49 | ) -> Dict[str, SparkColumn]:
50 | """Chispa's DataFrame comparer doesn't deal nicely with StructTypes whose columns
51 | are ordered differently, hence we order them the same as in the schema here."""
52 | transformations_ordered = {}
53 | for field in schema.get_structtype().fields:
54 | transformations_ordered[field.name] = transformations[field.name]
55 |
56 | return transformations_ordered
57 |
--------------------------------------------------------------------------------
/typedspark/_transforms/transform_to_schema.py:
--------------------------------------------------------------------------------
1 | """Module containing functions that are related to transformations to DataSets."""
2 |
3 | from functools import reduce
4 | from typing import Dict, Optional, Type, TypeVar, Union
5 |
6 | from pyspark.sql import Column as SparkColumn
7 | from pyspark.sql import DataFrame
8 |
9 | from typedspark._core.column import Column
10 | from typedspark._core.dataset import DataSet
11 | from typedspark._schema.schema import Schema
12 | from typedspark._transforms.rename_duplicate_columns import RenameDuplicateColumns
13 | from typedspark._transforms.utils import add_nulls_for_unspecified_columns, convert_keys_to_strings
14 |
15 | T = TypeVar("T", bound=Schema)
16 |
17 |
18 | def _do_transformations(
19 | dataframe: DataFrame, transformations: Dict[str, SparkColumn], run_sequentially: bool = True
20 | ) -> DataFrame:
21 | """Performs the transformations on the provided DataFrame."""
22 | if run_sequentially:
23 | return reduce(
24 | lambda acc, key: DataFrame.withColumn(acc, key, transformations[key]),
25 | transformations.keys(),
26 | dataframe,
27 | )
28 | return DataFrame.withColumns(dataframe, transformations)
29 |
30 |
31 | def _rename_temporary_keys_to_original_keys(
32 | dataframe: DataFrame, problematic_key_mapping: Dict[str, str]
33 | ) -> DataFrame:
34 | """Renames the temporary keys back to the original keys."""
35 | return reduce(
36 | lambda acc, key: DataFrame.withColumnRenamed(acc, problematic_key_mapping[key], key),
37 | problematic_key_mapping.keys(),
38 | dataframe,
39 | )
40 |
41 |
42 | def transform_to_schema(
43 | dataframe: DataFrame,
44 | schema: Type[T],
45 | transformations: Optional[Dict[Column, SparkColumn]] = None,
46 | fill_unspecified_columns_with_nulls: bool = False,
47 | run_sequentially: bool = True,
48 | ) -> DataSet[T]:
49 |     """Performs the ``transformations`` (if any) on the provided ``dataframe``, and
50 |     subsequently subsets the resulting DataFrame to the columns specified
51 |     in ``schema``.
52 |
53 | .. code-block:: python
54 |
55 | transform_to_schema(
56 | df_a.join(df_b, A.a == B.f),
57 | AB,
58 | {
59 | AB.a: A.a + 3,
60 | AB.b: A.b + 7,
61 | AB.i: B.i - 5,
62 | AB.j: B.j + 1,
63 | }
64 | )
65 | """
66 | transform: Union[dict[str, SparkColumn], RenameDuplicateColumns]
67 | transform = convert_keys_to_strings(transformations)
68 |
69 | if fill_unspecified_columns_with_nulls:
70 | transform = add_nulls_for_unspecified_columns(transform, schema, dataframe.columns)
71 |
72 | transform = RenameDuplicateColumns(transform, schema, dataframe.columns)
73 |
74 | return DataSet[schema]( # type: ignore
75 | dataframe.transform(_do_transformations, transform.transformations, run_sequentially)
76 | .drop(*transform.temporary_key_mapping.keys())
77 | .transform(_rename_temporary_keys_to_original_keys, transform.temporary_key_mapping)
78 | .select(*schema.all_column_names())
79 | )
80 |
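
In addition to the example in the docstring, ``fill_unspecified_columns_with_nulls`` fills in schema columns that are not listed in the transformations. A sketch with the same hypothetical schemas as above:

    transform_to_schema(
        df,
        AB,
        {AB.a: A.a + 3},
        fill_unspecified_columns_with_nulls=True,  # remaining AB columns become null, cast to their schema type
    )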
--------------------------------------------------------------------------------
/typedspark/_transforms/utils.py:
--------------------------------------------------------------------------------
1 | """Util functions for typedspark._transforms."""
2 |
3 | from typing import Dict, List, Optional, Type
4 |
5 | from pyspark.sql import Column as SparkColumn
6 | from pyspark.sql.functions import lit
7 |
8 | from typedspark._core.column import Column
9 | from typedspark._schema.schema import Schema
10 |
11 |
12 | def add_nulls_for_unspecified_columns(
13 | transformations: Dict[str, SparkColumn],
14 | schema: Type[Schema],
15 | previously_existing_columns: Optional[List[str]] = None,
16 | ) -> Dict[str, SparkColumn]:
17 | """Takes the columns from the schema that are not present in the transformation
18 |     dictionary and sets their values to Null (cast to the corresponding type defined
19 | in the schema)."""
20 | _previously_existing_columns = (
21 | [] if previously_existing_columns is None else previously_existing_columns
22 | )
23 | for field in schema.get_structtype().fields:
24 | if field.name not in transformations and field.name not in _previously_existing_columns:
25 | transformations[field.name] = lit(None).cast(field.dataType)
26 |
27 | return transformations
28 |
29 |
30 | def convert_keys_to_strings(
31 | transformations: Optional[Dict[Column, SparkColumn]],
32 | ) -> Dict[str, SparkColumn]:
33 | """Takes the Column keys in transformations and converts them to strings."""
34 | if transformations is None:
35 | return {}
36 |
37 | _transformations = {k.str: v for k, v in transformations.items()}
38 |
39 | if len(transformations) != len(_transformations):
40 | raise ValueError(
41 | "The transformations dictionary requires columns with unique names as keys. "
42 | + "It is currently not possible to have ambiguous column names here, "
43 | + "even when used in combination with register_schema_to_dataset()."
44 | )
45 |
46 | return _transformations
47 |
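
A short sketch of these helpers (``MySchema``, with columns ``a`` and ``b``, is hypothetical):

    from pyspark.sql.functions import lit

    transformations = convert_keys_to_strings({MySchema.a: lit(1)})
    # roughly {"a": <Column '1'>}

    add_nulls_for_unspecified_columns(transformations, MySchema)
    # roughly {"a": <Column '1'>, "b": <Column 'CAST(NULL AS ...)'>}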
--------------------------------------------------------------------------------
/typedspark/_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaiko-ai/typedspark/0c99ffb7cb9eddee9121254163f74b8de25d9f6b/typedspark/_utils/__init__.py
--------------------------------------------------------------------------------
/typedspark/_utils/camelcase.py:
--------------------------------------------------------------------------------
1 | """Utility function for converting from snake case to camel case."""
2 |
3 |
4 | def to_camel_case(name: str) -> str:
5 | """Converts a string to camel case."""
6 | return "".join([word.capitalize() for word in name.split("_")])
7 |
--------------------------------------------------------------------------------
/typedspark/_utils/create_dataset.py:
--------------------------------------------------------------------------------
1 | """Module containing functions related to creating a DataSet from scratch."""
2 |
3 | from typing import Any, Dict, List, Type, TypeVar, Union, get_type_hints
4 |
5 | from pyspark.sql import Row, SparkSession
6 |
7 | from typedspark._core.column import Column
8 | from typedspark._core.dataset import DataSet
9 | from typedspark._schema.schema import Schema
10 |
11 | T = TypeVar("T", bound=Schema)
12 |
13 |
14 | def create_empty_dataset(spark: SparkSession, schema: Type[T], n_rows: int = 3) -> DataSet[T]:
15 | """Creates a ``DataSet`` with ``Schema`` schema, containing ``n_rows`` rows, filled
16 | with ``None`` values.
17 |
18 | .. code-block:: python
19 |
20 | class Person(Schema):
21 | name: Column[StringType]
22 | age: Column[LongType]
23 |
24 | df = create_empty_dataset(spark, Person)
25 | """
26 | n_cols = len(get_type_hints(schema))
27 | rows = tuple([None] * n_cols)
28 | data = [rows] * n_rows
29 | spark_schema = schema.get_structtype()
30 | dataframe = spark.createDataFrame(data, spark_schema)
31 | return DataSet[schema](dataframe) # type: ignore
32 |
33 |
34 | def create_partially_filled_dataset(
35 | spark: SparkSession,
36 | schema: Type[T],
37 | data: Union[Dict[Column, List[Any]], List[Dict[Column, Any]]],
38 | ) -> DataSet[T]:
39 | """Creates a ``DataSet`` with ``Schema`` schema, where ``data`` can
40 | be defined in either of the following two ways:
41 |
42 | .. code-block:: python
43 |
44 | class Person(Schema):
45 | name: Column[StringType]
46 | age: Column[LongType]
47 | job: Column[StringType]
48 |
49 | df = create_partially_filled_dataset(
50 | spark,
51 | Person,
52 | {
53 | Person.name: ["John", "Jack", "Jane"],
54 | Person.age: [30, 40, 50],
55 | }
56 | )
57 |
58 | Or:
59 |
60 | .. code-block:: python
61 |
62 | df = create_partially_filled_dataset(
63 | spark,
64 | Person,
65 | [
66 | {Person.name: "John", Person.age: 30},
67 | {Person.name: "Jack", Person.age: 40},
68 | {Person.name: "Jane", Person.age: 50},
69 | ]
70 | )
71 |
72 | Any columns in the schema that are not present in the data will be
73 | initialized with ``None`` values.
74 | """
75 | if isinstance(data, list):
76 | col_data = _create_column_wise_data_from_list(schema, data)
77 | elif isinstance(data, dict):
78 | col_data = _create_column_wise_data_from_dict(schema, data)
79 | else:
80 | raise ValueError("The provided data is not a list or a dict.")
81 |
82 | row_data = zip(*col_data)
83 | spark_schema = schema.get_structtype()
84 | dataframe = spark.createDataFrame(row_data, spark_schema)
85 | return DataSet[schema](dataframe) # type: ignore
86 |
87 |
88 | def create_structtype_row(schema: Type[T], data: Dict[Column, Any]) -> Row:
89 | """Creates a ``Row`` with ``StructType`` schema, where ``data`` is a mapping from
90 | column to data in the respective column."""
91 | data_with_string_index = {k.str: v for k, v in data.items()}
92 | data_converted = {
93 | k: data_with_string_index[k] if k in data_with_string_index else None
94 | for k in get_type_hints(schema).keys()
95 | }
96 | return Row(**data_converted)
97 |
98 |
99 | def _create_column_wise_data_from_dict(
100 | schema: Type[T], data: Dict[Column, List[Any]]
101 | ) -> List[List[Any]]:
102 | """Converts a dict of column to data to a list of lists, where each inner list
103 | contains the data for a column."""
104 | data_converted = {k.str: v for k, v in data.items()}
105 | n_rows_unique = {len(v) for _, v in data.items()}
106 | if len(n_rows_unique) > 1:
107 | raise ValueError("The number of rows in the provided data differs per column.")
108 |
109 | n_rows = list(n_rows_unique)[0]
110 | col_data = []
111 | for col in get_type_hints(schema).keys():
112 | if col in data_converted:
113 | col_data += [data_converted[col]]
114 | else:
115 | col_data += [[None] * n_rows]
116 |
117 | return col_data
118 |
119 |
120 | def _create_column_wise_data_from_list(
121 | schema: Type[T], data: List[Dict[Column, Any]]
122 | ) -> List[List[Any]]:
123 | """Converts a list of dicts of column to data to a list of lists, where each inner
124 | list contains the data for a column."""
125 | data_converted = [{k.str: v for k, v in row.items()} for row in data]
126 |
127 | col_data = []
128 | for col in get_type_hints(schema).keys():
129 | col_data += [[row[col] if col in row else None for row in data_converted]]
130 |
131 | return col_data
132 |
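
``create_structtype_row`` has no example in its docstring; a minimal sketch (the ``Value`` schema is hypothetical):

    from pyspark.sql.types import IntegerType, StringType
    from typedspark import Column, Schema

    class Value(Schema):
        a: Column[IntegerType]
        b: Column[StringType]

    create_structtype_row(Value, {Value.a: 1})   # Row(a=1, b=None)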
--------------------------------------------------------------------------------
/typedspark/_utils/create_dataset_from_structtype.py:
--------------------------------------------------------------------------------
1 | """Utility functions for creating a ``Schema`` from a ``StructType``"""
2 |
3 | from typing import Dict, Literal, Optional, Type
4 |
5 | from pyspark.sql.types import ArrayType as SparkArrayType
6 | from pyspark.sql.types import DataType
7 | from pyspark.sql.types import DayTimeIntervalType as SparkDayTimeIntervalType
8 | from pyspark.sql.types import DecimalType as SparkDecimalType
9 | from pyspark.sql.types import MapType as SparkMapType
10 | from pyspark.sql.types import StructType as SparkStructType
11 |
12 | from typedspark._core.column import Column
13 | from typedspark._core.datatypes import (
14 | ArrayType,
15 | DayTimeIntervalType,
16 | DecimalType,
17 | MapType,
18 | StructType,
19 | )
20 | from typedspark._schema.schema import MetaSchema, Schema
21 | from typedspark._utils.camelcase import to_camel_case
22 |
23 |
24 | def create_schema_from_structtype(
25 | structtype: SparkStructType, schema_name: Optional[str] = None
26 | ) -> Type[Schema]:
27 | """Dynamically builds a ``Schema`` based on a ``DataFrame``'s ``StructType``"""
28 | type_annotations = {}
29 | attributes: Dict[str, None] = {}
30 | for column in structtype:
31 | name = column.name
32 | data_type = _extract_data_type(column.dataType, name)
33 | type_annotations[name] = Column[data_type] # type: ignore
34 | attributes[name] = None
35 |
36 | if not schema_name:
37 | schema_name = "DynamicallyLoadedSchema"
38 |
39 | schema = MetaSchema(schema_name, tuple([Schema]), attributes)
40 | schema.__annotations__ = type_annotations
41 |
42 | return schema # type: ignore
43 |
44 |
45 | def _extract_data_type(dtype: DataType, name: str) -> Type[DataType]:
46 | """Given an instance of a ``DataType``, it extracts the corresponding ``DataType``
47 | class, potentially including annotations (e.g. ``ArrayType[StringType]``)."""
48 | if isinstance(dtype, SparkArrayType):
49 | element_type = _extract_data_type(dtype.elementType, name)
50 | return ArrayType[element_type] # type: ignore
51 |
52 | if isinstance(dtype, SparkMapType):
53 | key_type = _extract_data_type(dtype.keyType, name)
54 | value_type = _extract_data_type(dtype.valueType, name)
55 | return MapType[key_type, value_type] # type: ignore
56 |
57 | if isinstance(dtype, SparkStructType):
58 | subschema = create_schema_from_structtype(dtype, to_camel_case(name))
59 | return StructType[subschema] # type: ignore
60 |
61 | if isinstance(dtype, SparkDayTimeIntervalType):
62 | start_field = dtype.startField
63 | end_field = dtype.endField
64 | return DayTimeIntervalType[Literal[start_field], Literal[end_field]] # type: ignore
65 |
66 | if isinstance(dtype, SparkDecimalType):
67 | precision = dtype.precision
68 | scale = dtype.scale
69 | return DecimalType[Literal[precision], Literal[scale]] # type: ignore
70 |
71 | return type(dtype)
72 |
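
A minimal sketch of dynamically building a ``Schema`` from an existing ``StructType`` (names hypothetical):

    from pyspark.sql.types import LongType, StringType, StructField, StructType

    structtype = StructType(
        [
            StructField("id", LongType()),
            StructField("name", StringType()),
        ]
    )
    Person = create_schema_from_structtype(structtype, schema_name="Person")
    Person.all_column_names()   # ["id", "name"]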
--------------------------------------------------------------------------------
/typedspark/_utils/databases.py:
--------------------------------------------------------------------------------
1 | """Loads all catalogs, databases and tables in a SparkSession."""
2 |
3 | from abc import ABC
4 | from datetime import datetime
5 | from typing import Any, Optional, Tuple, TypeVar
6 | from warnings import warn
7 |
8 | from pyspark.sql import Row, SparkSession
9 |
10 | from typedspark._core.dataset import DataSet
11 | from typedspark._schema.schema import Schema
12 | from typedspark._utils.camelcase import to_camel_case
13 | from typedspark._utils.load_table import load_table
14 |
15 | T = TypeVar("T", bound=Schema)
16 |
17 |
18 | class Timeout(ABC):
19 | """Warns the user if loading databases or catalogs is taking too long."""
20 |
21 | _TIMEOUT_WARNING: str
22 |
23 | def __init__(self, silent: bool, n: int): # pylint: disable=invalid-name
24 | self._start = datetime.now()
25 | self._silent = silent
26 | self._n = n
27 |
28 | def check_for_warning(self, i: int): # pragma: no cover
29 | """Checks if a warning should be issued."""
30 | if self._silent:
31 | return
32 |
33 | diff = datetime.now() - self._start
34 | if diff.seconds > 10:
35 | warn(self._TIMEOUT_WARNING.format(i, self._n))
36 | self._silent = True
37 |
38 |
39 | class DatabasesTimeout(Timeout):
40 | """Warns the user if Databases() is taking too long."""
41 |
42 | _TIMEOUT_WARNING = """
43 | Databases() is taking longer than 10 seconds. So far, {} out of {} databases have been loaded.
44 | If this is too slow, consider loading a single database using:
45 |
46 | from typedspark import Database
47 |
48 | db = Database(spark, db_name=...)
49 | """
50 |
51 |
52 | class CatalogsTimeout(Timeout):
53 | """Warns the user if Catalogs() is taking too long."""
54 |
55 | _TIMEOUT_WARNING = """
56 | Catalogs() is taking longer than 10 seconds. So far, {} out of {} catalogs have been loaded.
57 | If this is too slow, consider loading a single catalog using:
58 |
59 | from typedspark import Databases
60 |
61 | db = Databases(spark, catalog_name=...)
62 | """
63 |
64 |
65 | def _get_spark_session(spark: Optional[SparkSession]) -> SparkSession:
66 | if spark is not None:
67 | return spark
68 |
69 | spark = SparkSession.getActiveSession()
70 | if spark is not None:
71 | return spark
72 |
73 | raise ValueError("No active SparkSession found.") # pragma: no cover
74 |
75 |
76 | def _resolve_names_starting_with_an_underscore(name: str, names: list[str]) -> str:
77 | """Autocomplete is currently problematic when a name (of a table, database, or
78 |     catalog) starts with an underscore.
79 |
80 | In this case, it's considered a private attribute and it doesn't show up in the
81 | autocomplete options in your notebook. To combat this behaviour, we add a u as a
82 | prefix, followed by as many underscores as needed (up to 100) to keep the name
83 | unique.
84 | """
85 | if not name.startswith("_"):
86 | return name
87 |
88 | prefix = "u"
89 | proposed_name = prefix + name
90 | i = 0
91 | while proposed_name in names:
92 | prefix = prefix + "_"
93 | proposed_name = prefix + name
94 | i += 1
95 | if i > 100:
96 | raise Exception(
97 | "Couldn't find a unique name, even when adding 100 underscores. This seems unlikely"
98 |                 " behaviour; exiting to prevent an infinite loop."
99 | ) # pragma: no cover
100 |
101 | return proposed_name
102 |
103 |
104 | class Table:
105 | """Loads a table in a database."""
106 |
107 | def __init__(self, spark: SparkSession, db_name: str, table_name: str, is_temporary: bool):
108 | self._spark = spark
109 | self._db_name = db_name
110 | self._table_name = table_name
111 | self._is_temporary = is_temporary
112 |
113 | @property
114 | def str(self) -> str:
115 | """Returns the path to the table, e.g. ``default.person``.
116 |
117 | While temporary tables are always stored in the ``default`` db, they are saved and
118 | loaded directly from their table name, e.g. ``person``.
119 |
120 | Non-temporary tables are saved and loaded from their full name, e.g.
121 | ``default.person``.
122 | """
123 | if self._is_temporary:
124 | return self._table_name
125 |
126 | return f"{self._db_name}.{self._table_name}"
127 |
128 | def load(self) -> Tuple[DataSet[T], T]:
129 | """Loads the table as a DataSet[T] and returns the schema."""
130 | return load_table( # type: ignore
131 | self._spark,
132 | self.str,
133 | to_camel_case(self._table_name),
134 | )
135 |
136 | def __call__(self, *args: Any, **kwds: Any) -> Tuple[DataSet[T], T]:
137 | return self.load()
138 |
139 |
140 | class Database:
141 | """Loads all tables in a database."""
142 |
143 | def __init__(
144 | self,
145 | spark: Optional[SparkSession] = None,
146 | db_name: str = "default",
147 | catalog_name: Optional[str] = None,
148 | ):
149 | spark = _get_spark_session(spark)
150 |
151 | if catalog_name is None:
152 | self._db_name = db_name
153 | else:
154 | self._db_name = f"{catalog_name}.{db_name}"
155 |
156 | tables = spark.sql(f"show tables from {self._db_name}").collect()
157 | table_names = [table.tableName for table in tables]
158 |
159 | for table in tables:
160 | escaped_name = _resolve_names_starting_with_an_underscore(table.tableName, table_names)
161 | self.__setattr__(
162 | escaped_name,
163 | Table(spark, self._db_name, table.tableName, table.isTemporary),
164 | )
165 |
166 | @property
167 | def str(self) -> str:
168 | """Returns the database name."""
169 | return self._db_name
170 |
171 |
172 | class Databases:
173 | """Loads all databases and tables in a SparkSession."""
174 |
175 | def __init__(
176 | self,
177 | spark: Optional[SparkSession] = None,
178 | silent: bool = False,
179 | catalog_name: Optional[str] = None,
180 | ):
181 | spark = _get_spark_session(spark)
182 |
183 | if catalog_name is None:
184 | query = "show databases"
185 | else:
186 | query = f"show databases in {catalog_name}"
187 |
188 | databases = spark.sql(query).collect()
189 | database_names = [self._extract_db_name(database) for database in databases]
190 | timeout = DatabasesTimeout(silent, n=len(databases))
191 |
192 | for i, db_name in enumerate(database_names):
193 | timeout.check_for_warning(i)
194 | escaped_name = _resolve_names_starting_with_an_underscore(db_name, database_names)
195 | self.__setattr__(
196 | escaped_name,
197 | Database(spark, db_name, catalog_name),
198 | )
199 |
200 | def _extract_db_name(self, database: Row) -> str:
201 | """Extracts the database name from a Row.
202 |
203 | Old versions of Spark use ``databaseName``, newer versions use ``namespace``.
204 | """
205 | if hasattr(database, "databaseName"): # pragma: no cover
206 | return database.databaseName
207 | if hasattr(database, "namespace"):
208 | return database.namespace
209 |
210 | raise ValueError(f"Could not find database name in {database}.") # pragma: no cover
211 |
212 |
213 | class Catalogs:
214 | """Loads all catalogs, databases and tables in a SparkSession."""
215 |
216 | def __init__(self, spark: Optional[SparkSession] = None, silent: bool = False):
217 | spark = _get_spark_session(spark)
218 |
219 | catalogs = spark.sql("show catalogs").collect()
220 | catalog_names = [catalog.catalog for catalog in catalogs]
221 | timeout = CatalogsTimeout(silent, n=len(catalogs))
222 |
223 | for i, catalog_name in enumerate(catalog_names):
224 | escaped_name = _resolve_names_starting_with_an_underscore(catalog_name, catalog_names)
225 | timeout.check_for_warning(i)
226 | self.__setattr__(
227 | escaped_name,
228 | Databases(spark, silent=True, catalog_name=catalog_name),
229 | )
230 |
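
A usage sketch in a notebook (the ``person`` table and database names are hypothetical):

    from typedspark import Catalogs, Database, Databases

    db = Database(spark, db_name="default")
    df, Person = db.person()      # loads default.person as a DataSet along with its inferred Schema

    dbs = Databases(spark)        # all databases (and their tables) in the session
    cats = Catalogs(spark)        # all catalogs, databases and tables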
--------------------------------------------------------------------------------
/typedspark/_utils/load_table.py:
--------------------------------------------------------------------------------
1 | """Functions for loading `DataSet` and `Schema` in notebooks."""
2 |
3 | import re
4 | from typing import Dict, Optional, Tuple, Type
5 |
6 | from pyspark.sql import DataFrame, SparkSession
7 |
8 | from typedspark._core.dataset import DataSet
9 | from typedspark._schema.schema import Schema
10 | from typedspark._utils.create_dataset_from_structtype import create_schema_from_structtype
11 | from typedspark._utils.register_schema_to_dataset import register_schema_to_dataset
12 |
13 |
14 | def _replace_illegal_column_names(dataframe: DataFrame) -> DataFrame:
15 | """Replaces illegal column names with a legal version."""
16 | mapping = _create_mapping(dataframe)
17 |
18 | for column, column_renamed in mapping.items():
19 | if column != column_renamed:
20 | dataframe = dataframe.withColumnRenamed(column, column_renamed)
21 |
22 | return dataframe
23 |
24 |
25 | def _create_mapping(dataframe: DataFrame) -> Dict[str, str]:
26 |     """Maps each column name to a legal version, raising if the renaming results in duplicates."""
27 | mapping = {column: _replace_illegal_characters(column) for column in dataframe.columns}
28 | renamed_columns = list(mapping.values())
29 | duplicates = {
30 | column: column_renamed
31 | for column, column_renamed in mapping.items()
32 | if renamed_columns.count(column_renamed) > 1
33 | }
34 |
35 | if len(duplicates) > 0:
36 | raise ValueError(
37 | "You're trying to dynamically generate a Schema from a DataFrame. "
38 | + "However, typedspark has detected that the DataFrame contains duplicate columns "
39 | + "after replacing illegal characters (e.g. whitespaces, dots, etc.).\n"
40 |             + "The following columns have led to duplicates:\n"
41 | + f"{duplicates}\n\n"
42 | + "Please rename these columns in your DataFrame."
43 | )
44 |
45 | return mapping
46 |
47 |
48 | def _replace_illegal_characters(column_name: str) -> str:
49 | """Replaces illegal characters in a column name with an underscore."""
50 | return re.sub("[^A-Za-z0-9]", "_", column_name)
51 |
52 |
53 | def create_schema(
54 | dataframe: DataFrame, schema_name: Optional[str] = None
55 | ) -> Tuple[DataSet[Schema], Type[Schema]]:
56 |     """This function infers a ``Schema`` in a notebook based on the provided
57 | ``DataFrame``.
58 |
59 | This allows for autocompletion on column names, amongst other
60 | things.
61 |
62 | .. code-block:: python
63 |
64 | df, Person = create_schema(df)
65 | """
66 | dataframe = _replace_illegal_column_names(dataframe)
67 | schema = create_schema_from_structtype(dataframe.schema, schema_name)
68 | dataset = DataSet[schema](dataframe) # type: ignore
69 | schema = register_schema_to_dataset(dataset, schema)
70 | return dataset, schema
71 |
72 |
73 | def load_table(
74 | spark: SparkSession, table_name: str, schema_name: Optional[str] = None
75 | ) -> Tuple[DataSet[Schema], Type[Schema]]:
76 | """This function loads a ``DataSet``, along with its inferred ``Schema``, in a
77 | notebook.
78 |
79 | This allows for autocompletion on column names, amongst other
80 | things.
81 |
82 | .. code-block:: python
83 |
84 | df, Person = load_table(spark, "path.to.table")
85 | """
86 | dataframe = spark.table(table_name)
87 | return create_schema(dataframe, schema_name)
88 |
89 |
90 | DataFrame.to_typedspark = create_schema # type: ignore
91 |
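
Because ``create_schema`` is also attached to ``DataFrame`` (last line above), a notebook can equivalently call it as a method:

    df, Person = df.to_typedspark(schema_name="Person")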
--------------------------------------------------------------------------------
/typedspark/_utils/register_schema_to_dataset.py:
--------------------------------------------------------------------------------
1 | """Module containing functions that are related to registering schemas to DataSets."""
2 |
3 | import itertools
4 | from typing import Tuple, Type, TypeVar
5 |
6 | from typedspark._core.dataset import DataSet
7 | from typedspark._schema.schema import Schema
8 |
9 | T = TypeVar("T", bound=Schema)
10 |
11 |
12 | def _counter(count: itertools.count = itertools.count()):
13 | return next(count)
14 |
15 |
16 | def register_schema_to_dataset(dataframe: DataSet[T], schema: Type[T]) -> Type[T]:
17 | """Helps combat column ambiguity. For example:
18 |
19 | .. code-block:: python
20 |
21 | class Person(Schema):
22 | id: Column[IntegerType]
23 | name: Column[StringType]
24 |
25 | class Job(Schema):
26 | id: Column[IntegerType]
27 | salary: Column[IntegerType]
28 |
29 | class PersonWithJob(Person, Job):
30 | pass
31 |
32 | def foo(df_a: DataSet[Person], df_b: DataSet[Job]) -> DataSet[PersonWithJob]:
33 |             return DataSet[PersonWithJob](
34 | df_a.join(
35 | df_b,
36 | Person.id == Job.id
37 | )
38 | )
39 |
40 |     Calling ``foo()`` would result in an ``AnalysisException``, because Spark can't figure out
41 | whether ``id`` belongs to ``df_a`` or ``df_b``. To deal with this, you need to register
42 | your ``Schema`` to the ``DataSet``.
43 |
44 | .. code-block:: python
45 |
46 | from typedspark import register_schema_to_dataset
47 |
48 |         def foo(df_a: DataSet[Person], df_b: DataSet[Job]) -> DataSet[PersonWithJob]:
49 |             person = register_schema_to_dataset(df_a, Person)
50 |             job = register_schema_to_dataset(df_b, Job)
51 |             return DataSet[PersonWithJob](
52 | df_a.join(
53 | df_b,
54 | person.id == job.id
55 | )
56 | )
57 | """
58 |
59 | class LinkedSchema(schema): # type: ignore # pylint: disable=missing-class-docstring
60 | _parent = dataframe
61 | _current_id = _counter()
62 | _original_name = schema.get_schema_name()
63 |
64 | return LinkedSchema # type: ignore
65 |
66 |
67 | def register_schema_to_dataset_with_alias(
68 | dataframe: DataSet[T], schema: Type[T], alias: str
69 | ) -> Tuple[DataSet[T], Type[T]]:
70 |     """When dealing with self-joins, running ``register_schema_to_dataset()`` is not
71 |     enough.
72 | 
73 |     Instead, we'll need ``register_schema_to_dataset_with_alias()``, e.g.:
74 |
75 | .. code-block:: python
76 |
77 | class Person(Schema):
78 | id: Column[IntegerType]
79 | name: Column[StringType]
80 |
81 | df_a, person_a = register_schema_to_dataset_with_alias(df, Person, alias="a")
82 | df_b, person_b = register_schema_to_dataset_with_alias(df, Person, alias="b")
83 |
84 | df_a.join(df_b, person_a.id == person_b.id)
85 | """
86 |
87 | class LinkedSchema(schema): # type: ignore # pylint: disable=missing-class-docstring
88 | _current_id = _counter()
89 | _original_name = schema.get_schema_name()
90 | _alias = alias
91 |
92 | return (
93 | dataframe.alias(alias),
94 | LinkedSchema, # type: ignore
95 | )
96 |
--------------------------------------------------------------------------------
/typedspark/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaiko-ai/typedspark/0c99ffb7cb9eddee9121254163f74b8de25d9f6b/typedspark/py.typed
--------------------------------------------------------------------------------