├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── actions │ └── setup-poetry-env │ │ └── action.yml ├── pull_request_template.md └── workflows │ ├── ci.yml │ ├── mkdocs.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── NOTICE-binary.md ├── README.md ├── chispa ├── __init__.py ├── bcolors.py ├── column_comparer.py ├── common_enums.py ├── dataframe_comparer.py ├── default_formats.py ├── formatting │ ├── __init__.py │ ├── format_string.py │ ├── formats.py │ └── formatting_config.py ├── number_helpers.py ├── py.typed ├── row_comparer.py ├── rows_comparer.py ├── schema_comparer.py └── structfield_comparer.py ├── ci └── environment-py39.yml ├── docs ├── gen_ref_pages.py └── index.md ├── images ├── columns_not_approx_equal.png ├── columns_not_equal_error.png ├── custom_formats.png ├── df_not_equal_underlined.png ├── dfs_not_approx_equal.png ├── dfs_not_equal_error.png ├── dfs_not_equal_error_old.png ├── ignore_column_order_false.png ├── ignore_row_order_false.png ├── ignore_row_order_false_old.png ├── nullable_off_error.png └── schemas_not_approx_equal.png ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── data └── tree_string │ ├── it_prints_correctly_for_wide_schemas.txt │ ├── it_prints_correctly_for_wide_schemas_different_lengths.txt │ ├── it_prints_correctly_for_wide_schemas_ignore_metadata.txt │ ├── it_prints_correctly_for_wide_schemas_ignore_nullable.txt │ ├── it_prints_correctly_for_wide_schemas_multiple_nested_structs.txt │ └── it_prints_correctly_for_wide_schemas_with_metadata.txt ├── formatting ├── test_formats.py ├── test_formatting_config.py └── test_terminal_string_formatter.py ├── spark.py ├── test_column_comparer.py ├── test_dataframe_comparer.py ├── test_deprecated.py ├── test_readme_examples.py ├── test_row_comparer.py ├── test_rows_comparer.py ├── test_schema_comparer.py └── test_structfield_comparer.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a bug report to help us improve 4 | labels: "bug" 5 | --- 6 | 7 | **Describe the bug** 8 | 9 | 10 | 11 | **To Reproduce** 12 | 13 | Steps to reproduce the behavior: 14 | 15 | 1. ... 16 | 2. ... 17 | 3. ... 18 | 19 | **Expected behavior** 20 | 21 | 22 | 23 | **System [please complete the following information]:** 24 | 25 | - OS: e.g. [Ubuntu 18.04] 26 | - Python Version: [e.g. Python 3.8] 27 | - PySpark version: [e.g. PySpark 3.5.1] 28 | 29 | **Additional context** 30 | 31 | 32 | 33 | **Are you planning on creating a PR?** 34 | 35 | 36 | 37 | - [ ] I'm planning to make a pull-request 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest a new feature 4 | labels: "enhancement" 5 | --- 6 | 7 | **Is your feature request related to a problem? 
Please describe.** 8 | 9 | 10 | 11 | **Describe the solution you would like** 12 | 13 | 14 | 15 | **Additional context** 16 | 17 | 18 | 19 | **Are you planning on creating a PR?** 20 | 21 | 22 | 23 | - [ ] I'm planning to make a pull-request 24 | -------------------------------------------------------------------------------- /.github/actions/setup-poetry-env/action.yml: -------------------------------------------------------------------------------- 1 | name: "setup-poetry-env" 2 | description: "Composite action to setup the Python and poetry environment." 3 | 4 | inputs: 5 | python-version: 6 | required: false 7 | description: "The python version to use" 8 | default: "3.11" 9 | pyspark-version: 10 | required: false 11 | description: "The pyspark version to use" 12 | default: "3.5.1" 13 | with-docs: 14 | required: false 15 | description: "Install the docs dependency group" 16 | default: 'false' 17 | 18 | runs: 19 | using: "composite" 20 | steps: 21 | - name: Set up python 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: ${{ inputs.python-version }} 25 | 26 | - name: Install Poetry 27 | env: 28 | # renovate: datasource=pypi depName=poetry 29 | POETRY_VERSION: "1.8.3" 30 | run: curl -sSL https://install.python-poetry.org | python - -y 31 | shell: bash 32 | 33 | - name: Add Poetry to Path 34 | run: echo "$HOME/.local/bin" >> $GITHUB_PATH 35 | shell: bash 36 | 37 | - name: Configure Poetry virtual environment in project 38 | run: poetry config virtualenvs.in-project true 39 | shell: bash 40 | 41 | - name: Load cached venv 42 | id: cached-poetry-dependencies 43 | uses: actions/cache@v3 44 | with: 45 | path: .venv 46 | key: venv-${{ runner.os }}-${{ inputs.python-version }}-${{ inputs.pyspark-version }}-${{ inputs.with-docs }}-${{ hashFiles('poetry.lock') }} 47 | 48 | - name: Install dependencies 49 | run: | 50 | if [[ "${{ inputs.with-docs }}" == "true" ]]; then 51 | poetry install --no-interaction --with mkdocs 52 | else 53 | poetry install --no-interaction 54 | fi 55 | poetry run pip install pyspark==${{ inputs.pyspark-version }} 56 | shell: bash 57 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 58 | 59 | - name: Print python and pyspark version 60 | run: | 61 | poetry run python --version 62 | poetry run pyspark --version 63 | shell: bash 64 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | **PR Checklist** 2 | 3 | - [ ] A description of the changes is added to the description of this PR. 4 | - [ ] If there is a related issue, make sure it is linked to this PR. 5 | - [ ] If you've fixed a bug or added code that should be tested, add tests! 
6 | - [ ] If you've added or modified a feature, documentation in `docs` is updated 7 | 8 | **Description of changes** 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Unit tests 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize, reopened] 6 | push: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | quality: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Check out 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up the environment 18 | uses: ./.github/actions/setup-poetry-env 19 | 20 | - name: Check lock file 21 | run: poetry lock --check 22 | 23 | - name: Run code quality checks 24 | run: poetry run make check 25 | 26 | test: 27 | runs-on: ubuntu-latest 28 | strategy: 29 | matrix: 30 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 31 | pyspark-version: ["3.4.3", "3.5.1"] 32 | include: 33 | - python-version: "3.8" 34 | pyspark-version: "3.3.4" 35 | - python-version: "3.9" 36 | pyspark-version: "3.3.4" 37 | - python-version: "3.10" 38 | pyspark-version: "3.3.4" 39 | fail-fast: false 40 | defaults: 41 | run: 42 | shell: bash 43 | steps: 44 | - name: Check out 45 | uses: actions/checkout@v3 46 | 47 | - name: Set up the environment 48 | uses: ./.github/actions/setup-poetry-env 49 | with: 50 | python-version: ${{ matrix.python-version }} 51 | pyspark-version: ${{ matrix.pyspark-version }} 52 | 53 | - name: Run tests 54 | run: poetry run pytest tests 55 | 56 | check-docs: 57 | runs-on: ubuntu-latest 58 | steps: 59 | - name: Check out 60 | uses: actions/checkout@v3 61 | 62 | - name: Set up the environment 63 | uses: ./.github/actions/setup-poetry-env 64 | with: 65 | with-docs: true 66 | 67 | - name: Check if documentation can be built 68 | run: poetry run mkdocs build -s 69 | -------------------------------------------------------------------------------- /.github/workflows/mkdocs.yml: -------------------------------------------------------------------------------- 1 | name: MKDocs deploy 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up the environment 17 | uses: ./.github/actions/setup-poetry-env 18 | with: 19 | with-docs: true 20 | 21 | - name: Setup GH 22 | run: | 23 | sudo apt update && sudo apt install -y git 24 | git config user.name 'github-actions[bot]' 25 | git config user.email 'github-actions[bot]@users.noreply.github.com' 26 | - name: Build and Deploy 27 | run: 28 | poetry run mkdocs gh-deploy --force 29 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | branches: [main] 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/p/chispa 14 | permissions: 15 | id-token: write 16 | steps: 17 | - name: Check out 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up the environment 21 | uses: ./.github/actions/setup-poetry-env 22 | 23 | - name: Export tag 24 | id: vars 25 | run: | 26 | tag=${GITHUB_REF#refs/*/} 27 | version=${tag#v} 28 | echo tag=$tag >> $GITHUB_OUTPUT 29 | echo version=$version >> $GITHUB_OUTPUT 30 | 31 | - name: Build Python package 32 | run: poetry 
build 33 | 34 | - name: Publish to PyPi 35 | uses: pypa/gh-action-pypi-publish@release/v1 36 | 37 | deploy-docs: 38 | runs-on: ubuntu-latest 39 | needs: publish 40 | steps: 41 | - name: Check out 42 | uses: actions/checkout@v4 43 | 44 | - name: Set up the environment 45 | uses: ./.github/actions/setup-poetry-env 46 | 47 | - name: Deploy documentation 48 | run: poetry run mkdocs gh-deploy --force 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_store 2 | .python_version 3 | 4 | # Emacs 5 | .dir-locals.el 6 | 7 | # VSCode 8 | .vscode 9 | 10 | # Below are sections from https://github.com/github/gitignore/blob/main/Python.gitignore 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | .pybuilder/ 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 83 | __pypackages__/ 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # mkdocs documentation 95 | /site 96 | 97 | # mypy 98 | .mypy_cache/ 99 | .dmypy.json 100 | dmypy.json 101 | 102 | # pytype static type analyzer 103 | .pytype/ 104 | 105 | # Cython debug symbols 106 | cython_debug/ 107 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: "v4.6.0" 5 | hooks: 6 | - id: check-case-conflict 7 | - id: check-merge-conflict 8 | - id: check-toml 9 | - id: check-yaml 10 | - id: end-of-file-fixer 11 | - id: trailing-whitespace 12 | args: [--markdown-linebreak-ext=md] 13 | 14 | - repo: https://github.com/astral-sh/ruff-pre-commit 15 | rev: "v0.5.2" 16 | hooks: 17 | - id: ruff 18 | args: [--exit-non-zero-on-fix] 19 | - id: ruff-format 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to chispa 2 | 3 | ## Contributing 4 | 5 | // TODO 6 | 7 | ## Making a New Release 8 | 9 | ### Step 1: Update the CHANGELOG.md 10 | 11 | 1. Create a Pull Request (PR) to update the `CHANGELOG.md`. 12 | - Go to the "Create new release" section. 
13 | - Use "Auto-generate release notes" to help create the changelog. 14 | - Edit the notes if needed. 15 | 16 | ### Step 2: Merge the PR 17 | 18 | 1. After review and approval, merge the PR into the main branch. 19 | 20 | ### Step 3: Create a New Release 21 | 22 | 1. Go to the "Releases" section. 23 | - Click "Draft a new release". 24 | - Set the tag version as `vx.y.z` (e.g., `v0.11.0`). 25 | - Copy the updated `CHANGELOG.md` to the release description. 26 | - Publish the release. 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Matthew Powers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install 2 | install: ## Install the Poetry environment 3 | @echo "Creating virtual environment using Poetry" 4 | @poetry install 5 | 6 | .PHONY: check 7 | check: ## Run code quality checks 8 | @echo "Running pre-commit hooks" 9 | @poetry run pre-commit run -a 10 | @poetry run mypy chispa 11 | 12 | .PHONY: test 13 | test: ## Run unit tests 14 | @echo "Running unit tests" 15 | @poetry run pytest tests --cov=chispa --cov-report=term 16 | 17 | .PHONY: test-cov-html 18 | test-cov-html: ## Run unit tests and create a coverage report 19 | @echo "Running unit tests and generating HTML report" 20 | @poetry run pytest tests --cov=chispa --cov-report=html 21 | 22 | .PHONY: test-cov-xml 23 | test-cov-xml: ## Run unit tests and create a coverage report in xml format 24 | @echo "Running unit tests and generating XML report" 25 | @poetry run pytest tests --cov=chispa --cov-report=xml 26 | 27 | .PHONY: build 28 | build: clean-build ## Build wheel and sdist files using Poetry 29 | @echo "Creating wheel and sdist files" 30 | @poetry build 31 | 32 | .PHONY: clean-build 33 | clean-build: ## clean build artifacts 34 | @rm -rf dist 35 | 36 | .PHONY: publish 37 | publish: ## Publish a release to PyPI 38 | @echo "Publishing: Dry run." 39 | @poetry config pypi-token.pypi $(PYPI_TOKEN) 40 | @poetry publish --dry-run 41 | @echo "Publishing." 
42 | @poetry publish 43 | 44 | .PHONY: build-and-publish 45 | build-and-publish: build publish ## Build and publish 46 | 47 | .PHONY: docs-test 48 | docs-test: ## Test if documentation can be built without warnings or errors 49 | @poetry run mkdocs build -s 50 | 51 | .PHONY: docs 52 | docs: ## Build and serve the documentation 53 | @poetry run mkdocs serve 54 | 55 | # Inspired by https://marmelab.com/blog/2016/02/29/auto-documented-makefile.html 56 | .PHONY: help 57 | help: ## Show help for the commands 58 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' 59 | 60 | .DEFAULT_GOAL := help 61 | -------------------------------------------------------------------------------- /NOTICE-binary.md: -------------------------------------------------------------------------------- 1 | # apache spark 2 | Apache Spark 3 | Copyright 2014 and onwards The Apache Software Foundation. 4 | 5 | # findspark 6 | Copyright (c) 2015, Min RK All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without modification, 9 | are permitted provided that the following conditions are met: 10 | 11 | Redistributions of source code must retain the above copyright notice, this list 12 | of conditions and the following disclaimer. 13 | 14 | Redistributions in binary form must reproduce the above copyright notice, this list of 15 | conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 16 | 17 | Neither the name of findspark nor the names of its contributors may be used to endorse 18 | or promote products derived from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, 21 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 25 | STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | # pre-commit 29 | Copyright (c) 2014 pre-commit dev team: Anthony Sottile, Ken Struys 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in 39 | all copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 47 | THE SOFTWARE. 48 | 49 | # prettytable 50 | Copyright (c) 2009-2014 Luke Maurits 51 | All rights reserved. 52 | With contributions from: 53 | * Chris Clark 54 | * Klein Stephane 55 | * John Filleau 56 | * Vladimir Vrzić 57 | 58 | Redistribution and use in source and binary forms, with or without 59 | modification, are permitted provided that the following conditions are met: 60 | 61 | * Redistributions of source code must retain the above copyright notice, 62 | this list of conditions and the following disclaimer. 63 | * Redistributions in binary form must reproduce the above copyright notice, 64 | this list of conditions and the following disclaimer in the documentation 65 | and/or other materials provided with the distribution. 66 | * The name of the author may not be used to endorse or promote products 67 | derived from this software without specific prior written permission. 68 | 69 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 70 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 71 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 72 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 73 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 74 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 75 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 76 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 77 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 78 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 79 | POSSIBILITY OF SUCH DAMAGE. 80 | 81 | # pytest 82 | The MIT License (MIT) 83 | 84 | Copyright (c) 2004 Holger Krekel and others 85 | 86 | Permission is hereby granted, free of charge, to any person obtaining a copy of 87 | this software and associated documentation files (the "Software"), to deal in 88 | the Software without restriction, including without limitation the rights to 89 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 90 | of the Software, and to permit persons to whom the Software is furnished to do 91 | so, subject to the following conditions: 92 | 93 | The above copyright notice and this permission notice shall be included in all 94 | copies or substantial portions of the Software. 95 | 96 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 97 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 98 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 99 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 100 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 101 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 102 | SOFTWARE. 
103 | 104 | # pytest-describe 105 | Permission is hereby granted, free of charge, to any person obtaining a copy 106 | of this software and associated documentation files (the "Software"), to deal 107 | in the Software without restriction, including without limitation the rights 108 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 109 | copies of the Software, and to permit persons to whom the Software is 110 | furnished to do so, subject to the following conditions: 111 | 112 | The above copyright notice and this permission notice shall be included in all 113 | copies or substantial portions of the Software. 114 | 115 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 116 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 117 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 118 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 119 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 120 | 121 | # pytest-cov 122 | The MIT License 123 | 124 | Copyright (c) 2010 Meme Dough 125 | 126 | Permission is hereby granted, free of charge, to any person obtaining a copy 127 | of this software and associated documentation files (the "Software"), to deal 128 | in the Software without restriction, including without limitation the rights 129 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 130 | copies of the Software, and to permit persons to whom the Software is 131 | furnished to do so, subject to the following conditions: 132 | 133 | The above copyright notice and this permission notice shall be included in 134 | all copies or substantial portions of the Software. 135 | 136 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 137 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 138 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 139 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 140 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 141 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 142 | THE SOFTWARE. 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chispa 2 | 3 | [![build](https://github.com/MrPowers/chispa/actions/workflows/ci.yml/badge.svg)](https://github.com/MrPowers/chispa/actions/workflows/ci.yml) 4 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/chispa) 5 | [![PyPI version](https://badge.fury.io/py/chispa.svg)](https://badge.fury.io/py/chispa) 6 | 7 | chispa provides fast PySpark test helper methods that output descriptive error messages. 8 | 9 | This library makes it easy to write high-quality PySpark code. 10 | 11 | Fun fact: "chispa" means Spark in Spanish ;) 12 | 13 | ## Installation 14 | 15 | Install the latest version with `pip install chispa`. 16 | 17 | If you use Poetry, add this library as a development dependency with `poetry add chispa -G dev`. 18 | 19 | ## Column equality 20 | 21 | Suppose you have a function that removes the non-word characters from a string. 
22 | 23 | ```python 24 | def remove_non_word_characters(col): 25 | return F.regexp_replace(col, "[^\\w\\s]+", "") 26 | ``` 27 | 28 | Create a `SparkSession` so you can create DataFrames. 29 | 30 | ```python 31 | from pyspark.sql import SparkSession 32 | 33 | spark = (SparkSession.builder 34 | .master("local") 35 | .appName("chispa") 36 | .getOrCreate()) 37 | ``` 38 | 39 | Create a DataFrame with a column that contains strings with non-word characters, run the `remove_non_word_characters` function, and check that all these characters are removed with the chispa `assert_column_equality` method. 40 | 41 | ```python 42 | import pytest 43 | 44 | from chispa.column_comparer import assert_column_equality 45 | import pyspark.sql.functions as F 46 | 47 | def test_remove_non_word_characters_short(): 48 | data = [ 49 | ("jo&&se", "jose"), 50 | ("**li**", "li"), 51 | ("#::luisa", "luisa"), 52 | (None, None) 53 | ] 54 | df = (spark.createDataFrame(data, ["name", "expected_name"]) 55 | .withColumn("clean_name", remove_non_word_characters(F.col("name")))) 56 | assert_column_equality(df, "clean_name", "expected_name") 57 | ``` 58 | 59 | Let's write another test that'll fail to see how the descriptive error message lets you easily debug the underlying issue. 60 | 61 | Here's the failing test: 62 | 63 | ```python 64 | def test_remove_non_word_characters_nice_error(): 65 | data = [ 66 | ("matt7", "matt"), 67 | ("bill&", "bill"), 68 | ("isabela*", "isabela"), 69 | (None, None) 70 | ] 71 | df = (spark.createDataFrame(data, ["name", "expected_name"]) 72 | .withColumn("clean_name", remove_non_word_characters(F.col("name")))) 73 | assert_column_equality(df, "clean_name", "expected_name") 74 | ``` 75 | 76 | Here's the nicely formatted error message: 77 | 78 | ![ColumnsNotEqualError](https://raw.githubusercontent.com/MrPowers/chispa/main/images/columns_not_equal_error.png) 79 | 80 | You can see the `matt7` / `matt` row of data is what's causing the error (note it's highlighted in red). The other rows are colored blue because they're equal. 81 | 82 | ## DataFrame equality 83 | 84 | We can also test the `remove_non_word_characters` method by creating two DataFrames and verifying that they're equal. 85 | 86 | Creating two DataFrames is slower and requires more code, but comparing entire DataFrames is necessary for some tests. 87 | 88 | ```python 89 | from chispa.dataframe_comparer import * 90 | 91 | def test_remove_non_word_characters_long(): 92 | source_data = [ 93 | ("jo&&se",), 94 | ("**li**",), 95 | ("#::luisa",), 96 | (None,) 97 | ] 98 | source_df = spark.createDataFrame(source_data, ["name"]) 99 | 100 | actual_df = source_df.withColumn( 101 | "clean_name", 102 | remove_non_word_characters(F.col("name")) 103 | ) 104 | 105 | expected_data = [ 106 | ("jo&&se", "jose"), 107 | ("**li**", "li"), 108 | ("#::luisa", "luisa"), 109 | (None, None) 110 | ] 111 | expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) 112 | 113 | assert_df_equality(actual_df, expected_df) 114 | ``` 115 | 116 | Let's write another test that'll return an error, so you can see the descriptive error message. 
117 | 118 | ```python 119 | def test_remove_non_word_characters_long_error(): 120 | source_data = [ 121 | ("matt7",), 122 | ("bill&",), 123 | ("isabela*",), 124 | (None,) 125 | ] 126 | source_df = spark.createDataFrame(source_data, ["name"]) 127 | 128 | actual_df = source_df.withColumn( 129 | "clean_name", 130 | remove_non_word_characters(F.col("name")) 131 | ) 132 | 133 | expected_data = [ 134 | ("matt7", "matt"), 135 | ("bill&", "bill"), 136 | ("isabela*", "isabela"), 137 | (None, None) 138 | ] 139 | expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) 140 | 141 | assert_df_equality(actual_df, expected_df) 142 | ``` 143 | 144 | Here's the nicely formatted error message: 145 | 146 | ![DataFramesNotEqualError](https://raw.githubusercontent.com/MrPowers/chispa/main/images/dfs_not_equal_error.png) 147 | 148 | ### Ignore row order 149 | 150 | You can easily compare DataFrames, ignoring the order of the rows. The content of the DataFrames is usually what matters, not the order of the rows. 151 | 152 | Here are the contents of `df1`: 153 | 154 | ``` 155 | +--------+ 156 | |some_num| 157 | +--------+ 158 | | 1| 159 | | 2| 160 | | 3| 161 | +--------+ 162 | ``` 163 | 164 | Here are the contents of `df2`: 165 | 166 | ``` 167 | +--------+ 168 | |some_num| 169 | +--------+ 170 | | 2| 171 | | 1| 172 | | 3| 173 | +--------+ 174 | ``` 175 | 176 | Here's how to confirm `df1` and `df2` are equal when the row order is ignored. 177 | 178 | ```python 179 | assert_df_equality(df1, df2, ignore_row_order=True) 180 | ``` 181 | 182 | If you don't specify to `ignore_row_order` then the test will error out with this message: 183 | 184 | ![ignore_row_order_false](https://raw.githubusercontent.com/MrPowers/chispa/main/images/ignore_row_order_false.png) 185 | 186 | The rows aren't ordered by default because sorting slows down the function. 187 | 188 | ### Ignore column order 189 | 190 | This section explains how to compare DataFrames, ignoring the order of the columns. 191 | 192 | Suppose you have the following `df1`: 193 | 194 | ``` 195 | +----+----+ 196 | |num1|num2| 197 | +----+----+ 198 | | 1| 7| 199 | | 2| 8| 200 | | 3| 9| 201 | +----+----+ 202 | ``` 203 | 204 | Here are the contents of `df2`: 205 | 206 | ``` 207 | +----+----+ 208 | |num2|num1| 209 | +----+----+ 210 | | 7| 1| 211 | | 8| 2| 212 | | 9| 3| 213 | +----+----+ 214 | ``` 215 | 216 | Here's how to compare the equality of `df1` and `df2`, ignoring the column order: 217 | 218 | ```python 219 | assert_df_equality(df1, df2, ignore_column_order=True) 220 | ``` 221 | 222 | Here's the error message you'll see if you run `assert_df_equality(df1, df2)`, without ignoring the column order. 223 | 224 | ![ignore_column_order_false](https://raw.githubusercontent.com/MrPowers/chispa/main/images/ignore_column_order_false.png) 225 | 226 | ### Ignore specific columns 227 | 228 | This section explains how to compare DataFrames, ignoring specific columns. 
229 | 230 | Suppose you have the following `df1`: 231 | 232 | ``` 233 | +------------+-------------+ 234 | | name | clean_name | 235 | +------------+-------------+ 236 | | "matt7" | "matt7" | 237 | | "bill&" | "bill" | 238 | | "isabela*" | "isabela" | 239 | | "None" | "None" | 240 | +------------+-------------+ 241 | ``` 242 | 243 | Here are the contents of `df2`: 244 | 245 | ``` 246 | +------------+-------------+ 247 | | name | clean_name | 248 | +------------+-------------+ 249 | | "matt7" | "matt" | 250 | | "bill&" | "bill" | 251 | | "isabela*" | "isabela" | 252 | | "None" | "None" | 253 | +------------+-------------+ 254 | ``` 255 | 256 | Here's how to compare the equality of `df1` and `df2`, ignoring the column `clean_name`: 257 | 258 | ```python 259 | assert_df_equality(df1, df2, ignore_columns=["clean_name"]) 260 | ``` 261 | 262 | Here's the error message you'll see if you run `assert_df_equality(df1, df2)`, without ignoring the column `clean_name`. 263 | 264 | ![ignore_columns_none](https://raw.githubusercontent.com/MrPowers/chispa/main/images/dfs_not_equal_error.png) 265 | 266 | ### Ignore nullability 267 | 268 | Each column in a schema has three properties: a name, data type, and nullable property. The column can accept null values if `nullable` is set to true. 269 | 270 | You'll sometimes want to ignore the nullable property when making DataFrame comparisons. 271 | 272 | Suppose you have the following `df1`: 273 | 274 | ``` 275 | +-----+---+ 276 | | name|age| 277 | +-----+---+ 278 | | juan| 7| 279 | |bruna| 8| 280 | +-----+---+ 281 | ``` 282 | 283 | And this `df2`: 284 | 285 | ``` 286 | +-----+---+ 287 | | name|age| 288 | +-----+---+ 289 | | juan| 7| 290 | |bruna| 8| 291 | +-----+---+ 292 | ``` 293 | 294 | You might be surprised to find that in this example, `df1` and `df2` are not equal and will error out with this message: 295 | 296 | ![nullable_off_error](https://raw.githubusercontent.com/MrPowers/chispa/main/images/nullable_off_error.png) 297 | 298 | Examine the code in this contrived example to better understand the error: 299 | 300 | ```python 301 | def ignore_nullable_property(): 302 | s1 = StructType([ 303 | StructField("name", StringType(), True), 304 | StructField("age", IntegerType(), True)]) 305 | df1 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s1) 306 | s2 = StructType([ 307 | StructField("name", StringType(), True), 308 | StructField("age", IntegerType(), False)]) 309 | df2 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s2) 310 | assert_df_equality(df1, df2) 311 | ``` 312 | 313 | You can ignore the nullable property when assessing equality by adding a flag: 314 | 315 | ```python 316 | assert_df_equality(df1, df2, ignore_nullable=True) 317 | ``` 318 | 319 | Elements contained within an `ArrayType()` also have a nullable property, in addition to the nullable property of the column schema. These are also ignored when passing `ignore_nullable=True`. 
320 | 321 | Again, examine the following code to understand the error that `ignore_nullable=True` bypasses: 322 | 323 | ```python 324 | def ignore_nullable_property_array(): 325 | s1 = StructType([ 326 | StructField("name", StringType(), True), 327 | StructField("coords", ArrayType(DoubleType(), True), True),]) 328 | df1 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s1) 329 | s2 = StructType([ 330 | StructField("name", StringType(), True), 331 | StructField("coords", ArrayType(DoubleType(), False), True),]) 332 | df2 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s2) 333 | assert_df_equality(df1, df2) 334 | ``` 335 | 336 | ### Allow NaN equality 337 | 338 | Python has NaN (not a number) values and two NaN values are not considered equal by default. Create two NaN values, compare them, and confirm they're not considered equal by default. 339 | 340 | ```python 341 | nan1 = float('nan') 342 | nan2 = float('nan') 343 | nan1 == nan2 # False 344 | ``` 345 | 346 | pandas considers NaN values to be equal by default, but this library requires you to set a flag to consider two NaN values to be equal. 347 | 348 | ```python 349 | assert_df_equality(df1, df2, allow_nan_equality=True) 350 | ``` 351 | 352 | ## Customize formatting 353 | 354 | You can specify custom formats for the printed error messages as follows: 355 | 356 | ```python 357 | from chispa import FormattingConfig 358 | 359 | formats = FormattingConfig( 360 | mismatched_rows={"color": "light_yellow"}, 361 | matched_rows={"color": "cyan", "style": "bold"}, 362 | mismatched_cells={"color": "purple"}, 363 | matched_cells={"color": "blue"}, 364 | ) 365 | 366 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) 367 | ``` 368 | 369 | or similarly: 370 | 371 | ```python 372 | from chispa import FormattingConfig, Color, Style 373 | 374 | formats = FormattingConfig( 375 | mismatched_rows={"color": Color.LIGHT_YELLOW}, 376 | matched_rows={"color": Color.CYAN, "style": Style.BOLD}, 377 | mismatched_cells={"color": Color.PURPLE}, 378 | matched_cells={"color": Color.BLUE}, 379 | ) 380 | 381 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) 382 | ``` 383 | 384 | You can also define these formats in `conftest.py` and inject them via a fixture: 385 | 386 | ```python 387 | @pytest.fixture() 388 | def chispa_formats(): 389 | return FormattingConfig( 390 | mismatched_rows={"color": "light_yellow"}, 391 | matched_rows={"color": "cyan", "style": "bold"}, 392 | mismatched_cells={"color": "purple"}, 393 | matched_cells={"color": "blue"}, 394 | ) 395 | 396 | def test_shows_assert_basic_rows_equality(chispa_formats): 397 | ... 398 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=chispa_formats) 399 | ``` 400 | 401 | ![custom_formats](https://raw.githubusercontent.com/MrPowers/chispa/main/images/custom_formats.png) 402 | 403 | ## Approximate column equality 404 | 405 | We can check if columns are approximately equal, which is especially useful for floating number comparisons. 406 | 407 | Here's a test that creates a DataFrame with two floating point columns and verifies that the columns are approximately equal. In this example, values are considered approximately equal if the difference is less than 0.1. 
409 | ```python 410 | def test_approx_col_equality_same(): 411 | data = [ 412 | (1.1, 1.1), 413 | (2.2, 2.15), 414 | (3.3, 3.37), 415 | (None, None) 416 | ] 417 | df = spark.createDataFrame(data, ["num1", "num2"]) 418 | assert_approx_column_equality(df, "num1", "num2", 0.1) 419 | ``` 420 | 421 | Here's an example of a test with columns that are not approximately equal. 422 | 423 | ```python 424 | def test_approx_col_equality_different(): 425 | data = [ 426 | (1.1, 1.1), 427 | (2.2, 2.15), 428 | (3.3, 5.0), 429 | (None, None) 430 | ] 431 | df = spark.createDataFrame(data, ["num1", "num2"]) 432 | assert_approx_column_equality(df, "num1", "num2", 0.1) 433 | ``` 434 | 435 | This failing test will output a readable error message so the issue is easy to debug. 436 | 437 | ![ColumnsNotEqualError](https://raw.githubusercontent.com/MrPowers/chispa/main/images/columns_not_approx_equal.png) 438 | 439 | ## Approximate DataFrame equality 440 | 441 | Let's create two DataFrames and confirm they're approximately equal. 442 | 443 | ```python 444 | def test_approx_df_equality_same(): 445 | data1 = [ 446 | (1.1, "a"), 447 | (2.2, "b"), 448 | (3.3, "c"), 449 | (None, None) 450 | ] 451 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 452 | 453 | data2 = [ 454 | (1.05, "a"), 455 | (2.13, "b"), 456 | (3.3, "c"), 457 | (None, None) 458 | ] 459 | df2 = spark.createDataFrame(data2, ["num", "letter"]) 460 | 461 | assert_approx_df_equality(df1, df2, 0.1) 462 | ``` 463 | 464 | The `assert_approx_df_equality` method is smart and will only perform approximate equality operations for floating point numbers in DataFrames. It'll perform regular equality for strings and other types. 465 | 466 | Let's perform an approximate equality comparison for two DataFrames that are not equal. 467 | 468 | ```python 469 | def test_approx_df_equality_different(): 470 | data1 = [ 471 | (1.1, "a"), 472 | (2.2, "b"), 473 | (3.3, "c"), 474 | (None, None) 475 | ] 476 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 477 | 478 | data2 = [ 479 | (1.1, "a"), 480 | (5.0, "b"), 481 | (3.3, "z"), 482 | (None, None) 483 | ] 484 | df2 = spark.createDataFrame(data2, ["num", "letter"]) 485 | 486 | assert_approx_df_equality(df1, df2, 0.1) 487 | ``` 488 | 489 | Here's the pretty error message that's outputted: 490 | 491 | ![DataFramesNotEqualError](https://raw.githubusercontent.com/MrPowers/chispa/main/images/dfs_not_approx_equal.png) 492 | 493 | ## Schema mismatch messages 494 | 495 | DataFrame equality messages perform schema comparisons before analyzing the actual content of the DataFrames. DataFrames that don't have the same schemas should error out as fast as possible. 496 | 497 | Let's compare a DataFrame that has a string column and an integer column with a DataFrame that has two integer columns to observe the schema mismatch message. 498 | 499 | ```python 500 | def test_schema_mismatch_message(): 501 | data1 = [ 502 | (1, "a"), 503 | (2, "b"), 504 | (3, "c"), 505 | (None, None) 506 | ] 507 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 508 | 509 | data2 = [ 510 | (1, 6), 511 | (2, 7), 512 | (3, 8), 513 | (None, None) 514 | ] 515 | df2 = spark.createDataFrame(data2, ["num", "num2"]) 516 | 517 | assert_df_equality(df1, df2) 518 | ``` 519 | 520 | Here's the error message: 521 | 522 | ![SchemasNotEqualError](https://raw.githubusercontent.com/MrPowers/chispa/main/images/schemas_not_approx_equal.png) 523 | 524 | ## Supported PySpark / Python versions 525 | 526 | chispa currently supports PySpark 2.4+ and Python 3.5+. 
527 | 528 | Use chispa v0.8.2 if you're using an older Python version. 529 | 530 | PySpark 2 support will be dropped when chispa 1.x is released. 531 | 532 | ## Benchmarks 533 | 534 | TODO: Need to benchmark these methods vs. the spark-testing-base ones 535 | 536 | ## Developing chispa on your local machine 537 | 538 | You are encouraged to clone and/or fork this repo. 539 | 540 | This project uses [Poetry](https://python-poetry.org/) for packaging and dependency management. 541 | 542 | * Set up the virtual environment with `poetry install` 543 | * Run the tests with `poetry run pytest tests` 544 | 545 | Studying the codebase is a great way to learn about PySpark! 546 | 547 | ## Contributing 548 | 549 | Anyone is encouraged to submit a pull request, open an issue, or file a bug report. 550 | 551 | We're happy to promote folks to be library maintainers if they make good contributions. 552 | -------------------------------------------------------------------------------- /chispa/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import sys 5 | from glob import glob 6 | from typing import Callable 7 | 8 | from pyspark.sql import DataFrame 9 | 10 | # Add PySpark to the library path based on the value of SPARK_HOME if pyspark is not already in our path 11 | try: 12 | from pyspark import context  # noqa: F401 13 | except ImportError: 14 | # We need to add PySpark; try to use findspark, or fall back to finding it manually 15 | try: 16 | import findspark  # type: ignore[import-untyped] 17 | 18 | findspark.init() 19 | except ImportError: 20 | try: 21 | spark_home = os.environ["SPARK_HOME"] 22 | sys.path.append(os.path.join(spark_home, "python")) 23 | py4j_src_zip = glob(os.path.join(spark_home, "python", "lib", "py4j-*-src.zip")) 24 | if len(py4j_src_zip) == 0: 25 | raise ValueError( 26 | "py4j source archive not found in {}".format(os.path.join(spark_home, "python", "lib")) 27 | ) 28 | else: 29 | py4j_src_zip = sorted(py4j_src_zip)[::-1] 30 | sys.path.append(py4j_src_zip[0]) 31 | except KeyError: 32 | print("Can't find Apache Spark. 
Please set environment variable SPARK_HOME to root of installation!") 33 | exit(-1) 34 | 35 | from chispa.default_formats import DefaultFormats 36 | from chispa.formatting import Color, Format, FormattingConfig, Style 37 | 38 | from .column_comparer import ( 39 | ColumnsNotEqualError, 40 | assert_approx_column_equality, 41 | assert_column_equality, 42 | ) 43 | from .dataframe_comparer import ( 44 | DataFramesNotEqualError, 45 | assert_approx_df_equality, 46 | assert_df_equality, 47 | ) 48 | from .rows_comparer import assert_basic_rows_equality 49 | 50 | 51 | class Chispa: 52 | def __init__(self, formats: FormattingConfig | None = None) -> None: 53 | if not formats: 54 | self.formats = FormattingConfig() 55 | elif isinstance(formats, FormattingConfig): 56 | self.formats = formats 57 | else: 58 | self.formats = FormattingConfig._from_arbitrary_dataclass(formats) 59 | 60 | def assert_df_equality( 61 | self, 62 | df1: DataFrame, 63 | df2: DataFrame, 64 | ignore_nullable: bool = False, 65 | transforms: list[Callable] | None = None, # type: ignore[type-arg] 66 | allow_nan_equality: bool = False, 67 | ignore_column_order: bool = False, 68 | ignore_row_order: bool = False, 69 | underline_cells: bool = False, 70 | ignore_metadata: bool = False, 71 | ignore_columns: list[str] | None = None, 72 | ) -> None: 73 | return assert_df_equality( 74 | df1, 75 | df2, 76 | ignore_nullable, 77 | transforms, 78 | allow_nan_equality, 79 | ignore_column_order, 80 | ignore_row_order, 81 | underline_cells, 82 | ignore_metadata, 83 | ignore_columns, 84 | self.formats, 85 | ) 86 | 87 | 88 | __all__ = ( 89 | "DataFramesNotEqualError", 90 | "assert_df_equality", 91 | "assert_approx_df_equality", 92 | "ColumnsNotEqualError", 93 | "assert_column_equality", 94 | "assert_approx_column_equality", 95 | "assert_basic_rows_equality", 96 | "Style", 97 | "Color", 98 | "FormattingConfig", 99 | "Format", 100 | "Chispa", 101 | "DefaultFormats", 102 | ) 103 | -------------------------------------------------------------------------------- /chispa/bcolors.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | 5 | 6 | class bcolors: 7 | NC = "\033[0m" # No Color, reset all 8 | 9 | Bold = "\033[1m" 10 | Underlined = "\033[4m" 11 | Blink = "\033[5m" 12 | Inverted = "\033[7m" 13 | Hidden = "\033[8m" 14 | 15 | Black = "\033[30m" 16 | Red = "\033[31m" 17 | Green = "\033[32m" 18 | Yellow = "\033[33m" 19 | Blue = "\033[34m" 20 | Purple = "\033[35m" 21 | Cyan = "\033[36m" 22 | LightGray = "\033[37m" 23 | DarkGray = "\033[30m" 24 | LightRed = "\033[31m" 25 | LightGreen = "\033[32m" 26 | LightYellow = "\033[93m" 27 | LightBlue = "\033[34m" 28 | LightPurple = "\033[35m" 29 | LightCyan = "\033[36m" 30 | White = "\033[97m" 31 | 32 | # Style 33 | Bold = "\033[1m" 34 | Underline = "\033[4m" 35 | 36 | def __init__(self) -> None: 37 | warnings.warn("The `bcolors` class is deprecated and will be removed in a future version.", DeprecationWarning) 38 | 39 | 40 | def blue(s: str) -> str: 41 | warnings.warn("The `blue` function is deprecated and will be removed in a future version.", DeprecationWarning) 42 | return bcolors.LightBlue + str(s) + bcolors.LightRed 43 | 44 | 45 | def line_blue(s: str) -> str: 46 | return bcolors.LightBlue + s + bcolors.NC 47 | 48 | 49 | def line_red(s: str) -> str: 50 | return bcolors.LightRed + s + bcolors.NC 51 | 52 | 53 | def underline_text(input_text: str) -> str: 54 | """ 55 | Takes an input string and returns a white, 
underlined string (based on PrettyTable formatting) 56 | """ 57 | warnings.warn( 58 | "The `underline_text` function is deprecated and will be removed in a future version.", DeprecationWarning 59 | ) 60 | return bcolors.White + bcolors.Underline + input_text + bcolors.NC + bcolors.LightRed 61 | -------------------------------------------------------------------------------- /chispa/column_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from prettytable import PrettyTable 4 | from pyspark.sql import DataFrame 5 | 6 | from chispa.formatting import blue 7 | 8 | 9 | class ColumnsNotEqualError(Exception): 10 | """The columns are not equal""" 11 | 12 | pass 13 | 14 | 15 | def assert_column_equality(df: DataFrame, col_name1: str, col_name2: str) -> None: 16 | rows = df.select(col_name1, col_name2).collect() 17 | col_name_1_elements = [x[0] for x in rows] 18 | col_name_2_elements = [x[1] for x in rows] 19 | if col_name_1_elements != col_name_2_elements: 20 | zipped = list(zip(col_name_1_elements, col_name_2_elements)) 21 | t = PrettyTable([col_name1, col_name2]) 22 | for elements in zipped: 23 | if elements[0] == elements[1]: 24 | t.add_row([blue(str(elements[0])), blue(str(elements[1]))]) 25 | else: 26 | t.add_row([str(elements[0]), str(elements[1])]) 27 | raise ColumnsNotEqualError("\n" + t.get_string()) 28 | 29 | 30 | def assert_approx_column_equality(df: DataFrame, col_name1: str, col_name2: str, precision: float) -> None: 31 | rows = df.select(col_name1, col_name2).collect() 32 | col_name_1_elements = [x[0] for x in rows] 33 | col_name_2_elements = [x[1] for x in rows] 34 | all_rows_equal = True 35 | zipped = list(zip(col_name_1_elements, col_name_2_elements)) 36 | t = PrettyTable([col_name1, col_name2]) 37 | for elements in zipped: 38 | first = blue(str(elements[0])) 39 | second = blue(str(elements[1])) 40 | # when one is None and the other isn't, they're not equal 41 | if (elements[0] is None) != (elements[1] is None): 42 | all_rows_equal = False 43 | t.add_row([str(elements[0]), str(elements[1])]) 44 | # when both are None, they're equal 45 | elif elements[0] is None and elements[1] is None: 46 | t.add_row([first, second]) 47 | # when the diff is less than the threshold, they're approximately equal 48 | elif abs(elements[0] - elements[1]) < precision: 49 | t.add_row([first, second]) 50 | # otherwise, they're not equal 51 | else: 52 | all_rows_equal = False 53 | t.add_row([str(elements[0]), str(elements[1])]) 54 | if all_rows_equal is False: 55 | raise ColumnsNotEqualError("\n" + t.get_string()) 56 | -------------------------------------------------------------------------------- /chispa/common_enums.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import Enum 4 | 5 | 6 | class OutputFormat(str, Enum): 7 | TABLE = "table" 8 | TREE = "tree" 9 | 10 | 11 | class TypeName(str, Enum): 12 | ARRAY = "array" 13 | STRUCT = "struct" 14 | -------------------------------------------------------------------------------- /chispa/dataframe_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import reduce 4 | from typing import Callable 5 | 6 | from pyspark.sql import DataFrame 7 | 8 | from chispa.formatting import FormattingConfig 9 | from chispa.row_comparer import are_rows_approx_equal, are_rows_equal_enhanced 10 | 
from chispa.rows_comparer import ( 11 | assert_basic_rows_equality, 12 | assert_generic_rows_equality, 13 | ) 14 | from chispa.schema_comparer import assert_schema_equality 15 | 16 | 17 | class DataFramesNotEqualError(Exception): 18 | """The DataFrames are not equal""" 19 | 20 | pass 21 | 22 | 23 | def assert_df_equality( 24 | df1: DataFrame, 25 | df2: DataFrame, 26 | ignore_nullable: bool = False, 27 | transforms: list[Callable] | None = None, # type: ignore[type-arg] 28 | allow_nan_equality: bool = False, 29 | ignore_column_order: bool = False, 30 | ignore_row_order: bool = False, 31 | underline_cells: bool = False, 32 | ignore_metadata: bool = False, 33 | ignore_columns: list[str] | None = None, 34 | formats: FormattingConfig | None = None, 35 | ) -> None: 36 | if not formats: 37 | formats = FormattingConfig() 38 | elif not isinstance(formats, FormattingConfig): 39 | formats = FormattingConfig._from_arbitrary_dataclass(formats) 40 | 41 | if transforms is None: 42 | transforms = [] 43 | if ignore_column_order: 44 | transforms.append(lambda df: df.select(sorted(df.columns))) 45 | if ignore_columns: 46 | transforms.append(lambda df: df.drop(*ignore_columns)) 47 | if ignore_row_order: 48 | transforms.append(lambda df: df.sort(df.columns)) 49 | 50 | df1 = reduce(lambda acc, fn: fn(acc), transforms, df1) 51 | df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) 52 | 53 | assert_schema_equality(df1.schema, df2.schema, ignore_nullable, ignore_metadata) 54 | 55 | if allow_nan_equality: 56 | assert_generic_rows_equality( 57 | df1.collect(), 58 | df2.collect(), 59 | are_rows_equal_enhanced, 60 | {"allow_nan_equality": True}, 61 | underline_cells=underline_cells, 62 | formats=formats, 63 | ) 64 | else: 65 | assert_basic_rows_equality( 66 | df1.collect(), 67 | df2.collect(), 68 | underline_cells=underline_cells, 69 | formats=formats, 70 | ) 71 | 72 | 73 | def are_dfs_equal(df1: DataFrame, df2: DataFrame) -> bool: 74 | if df1.schema != df2.schema: 75 | return False 76 | if df1.collect() != df2.collect(): 77 | return False 78 | return True 79 | 80 | 81 | def assert_approx_df_equality( 82 | df1: DataFrame, 83 | df2: DataFrame, 84 | precision: float, 85 | ignore_nullable: bool = False, 86 | transforms: list[Callable] | None = None, # type: ignore[type-arg] 87 | allow_nan_equality: bool = False, 88 | ignore_column_order: bool = False, 89 | ignore_row_order: bool = False, 90 | ignore_columns: list[str] | None = None, 91 | formats: FormattingConfig | None = None, 92 | ) -> None: 93 | if not formats: 94 | formats = FormattingConfig() 95 | elif not isinstance(formats, FormattingConfig): 96 | formats = FormattingConfig._from_arbitrary_dataclass(formats) 97 | 98 | if transforms is None: 99 | transforms = [] 100 | if ignore_column_order: 101 | transforms.append(lambda df: df.select(sorted(df.columns))) 102 | if ignore_columns: 103 | transforms.append(lambda df: df.drop(*ignore_columns)) 104 | if ignore_row_order: 105 | transforms.append(lambda df: df.sort(df.columns)) 106 | 107 | df1 = reduce(lambda acc, fn: fn(acc), transforms, df1) 108 | df2 = reduce(lambda acc, fn: fn(acc), transforms, df2) 109 | 110 | assert_schema_equality(df1.schema, df2.schema, ignore_nullable) 111 | 112 | if precision != 0: 113 | assert_generic_rows_equality( 114 | df1.collect(), 115 | df2.collect(), 116 | are_rows_approx_equal, 117 | {"precision": precision, "allow_nan_equality": allow_nan_equality}, 118 | formats=formats, 119 | ) 120 | elif allow_nan_equality: 121 | assert_generic_rows_equality( 122 | df1.collect(), df2.collect(), 
are_rows_equal_enhanced, {"allow_nan_equality": True}, formats=formats 123 | ) 124 | else: 125 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=formats) 126 | -------------------------------------------------------------------------------- /chispa/default_formats.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from dataclasses import dataclass, field 5 | 6 | 7 | @dataclass 8 | class DefaultFormats: 9 | """ 10 | This class is now deprecated and should be removed in a future release. 11 | """ 12 | 13 | mismatched_rows: list[str] = field(default_factory=lambda: ["red"]) 14 | matched_rows: list[str] = field(default_factory=lambda: ["blue"]) 15 | mismatched_cells: list[str] = field(default_factory=lambda: ["red", "underline"]) 16 | matched_cells: list[str] = field(default_factory=lambda: ["blue"]) 17 | 18 | def __post_init__(self) -> None: 19 | warnings.warn( 20 | "DefaultFormats is deprecated. Use `chispa.formatting.FormattingConfig` instead.", DeprecationWarning 21 | ) 22 | -------------------------------------------------------------------------------- /chispa/formatting/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from chispa.formatting.format_string import blue, format_string 4 | from chispa.formatting.formats import RESET, Color, Format, Style 5 | from chispa.formatting.formatting_config import FormattingConfig 6 | 7 | __all__ = ("Style", "Color", "FormattingConfig", "Format", "format_string", "RESET", "blue") 8 | -------------------------------------------------------------------------------- /chispa/formatting/format_string.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from chispa.formatting.formats import RESET, Color, Format 4 | 5 | 6 | def format_string(input_string: str, format: Format) -> str: 7 | if not format.color and not format.style: 8 | return input_string 9 | 10 | formatted_string = input_string 11 | codes = [] 12 | 13 | if format.style: 14 | for style in format.style: 15 | codes.append(style.value) 16 | 17 | if format.color: 18 | codes.append(format.color.value) 19 | 20 | formatted_string = "".join(codes) + formatted_string + RESET 21 | return formatted_string 22 | 23 | 24 | def blue(string: str) -> str: 25 | return Color.LIGHT_BLUE + string + Color.LIGHT_RED 26 | -------------------------------------------------------------------------------- /chispa/formatting/formats.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | 6 | RESET = "\033[0m" 7 | 8 | 9 | class Color(str, Enum): 10 | """ 11 | Enum for terminal colors. 12 | Each color is represented by its corresponding ANSI escape code. 13 | """ 14 | 15 | BLACK = "\033[30m" 16 | RED = "\033[31m" 17 | GREEN = "\033[32m" 18 | YELLOW = "\033[33m" 19 | BLUE = "\033[34m" 20 | PURPLE = "\033[35m" 21 | CYAN = "\033[36m" 22 | LIGHT_GRAY = "\033[37m" 23 | DARK_GRAY = "\033[90m" 24 | LIGHT_RED = "\033[91m" 25 | LIGHT_GREEN = "\033[92m" 26 | LIGHT_YELLOW = "\033[93m" 27 | LIGHT_BLUE = "\033[94m" 28 | LIGHT_PURPLE = "\033[95m" 29 | LIGHT_CYAN = "\033[96m" 30 | WHITE = "\033[97m" 31 | 32 | 33 | class Style(str, Enum): 34 | """ 35 | Enum for text styles. 
36 | Each style is represented by its corresponding ANSI escape code. 37 | """ 38 | 39 | BOLD = "\033[1m" 40 | UNDERLINE = "\033[4m" 41 | BLINK = "\033[5m" 42 | INVERT = "\033[7m" 43 | HIDE = "\033[8m" 44 | 45 | 46 | @dataclass 47 | class Format: 48 | """ 49 | Data class to represent text formatting with color and style. 50 | 51 | Attributes: 52 | color (Color | None): The color for the text. 53 | style (list[Style] | None): A list of styles for the text. 54 | """ 55 | 56 | color: Color | None = None 57 | style: list[Style] | None = None 58 | 59 | @classmethod 60 | def from_dict(cls, format_dict: dict[str, str | list[str]]) -> Format: 61 | """ 62 | Create a Format instance from a dictionary. 63 | 64 | Args: 65 | format_dict (dict): A dictionary with keys 'color' and/or 'style'. 66 | """ 67 | if not isinstance(format_dict, dict): 68 | raise ValueError("Input must be a dictionary") 69 | 70 | valid_keys = {"color", "style"} 71 | invalid_keys = set(format_dict) - valid_keys 72 | if invalid_keys: 73 | raise ValueError(f"Invalid keys in format dictionary: {invalid_keys}. Valid keys are {valid_keys}") 74 | 75 | if isinstance(format_dict.get("color"), list): 76 | raise TypeError("The value for key 'color' should be a string, not a list!") 77 | color = cls._get_color_enum(format_dict.get("color")) # type: ignore[arg-type] 78 | 79 | style = format_dict.get("style") 80 | if isinstance(style, str): 81 | styles = [cls._get_style_enum(style)] 82 | elif isinstance(style, list): 83 | styles = [cls._get_style_enum(s) for s in style] 84 | else: 85 | styles = None 86 | 87 | return cls(color=color, style=styles) # type: ignore[arg-type] 88 | 89 | @classmethod 90 | def from_list(cls, values: list[str]) -> Format: 91 | """ 92 | Create a Format instance from a list of strings. 93 | 94 | Args: 95 | values (list[str]): A list of strings representing colors and styles. 96 | """ 97 | if not all(isinstance(value, str) for value in values): 98 | raise ValueError("All elements in the list must be strings") 99 | 100 | color = None 101 | styles = [] 102 | valid_colors = [c.name.lower() for c in Color] 103 | valid_styles = [s.name.lower() for s in Style] 104 | 105 | for value in values: 106 | if value in valid_colors: 107 | color = Color[value.upper()] 108 | elif value in valid_styles: 109 | styles.append(Style[value.upper()]) 110 | else: 111 | raise ValueError( 112 | f"Invalid value: {value}. Valid values are colors: {valid_colors} and styles: {valid_styles}" 113 | ) 114 | 115 | return cls(color=color, style=styles if styles else None) 116 | 117 | @staticmethod 118 | def _get_color_enum(color: Color | str | None) -> Color | None: 119 | if isinstance(color, Color): 120 | return color 121 | elif isinstance(color, str): 122 | try: 123 | return Color[color.upper()] 124 | except KeyError: 125 | valid_colors = [c.name.lower() for c in Color] 126 | raise ValueError(f"Invalid color name: {color}. Valid color names are {valid_colors}") 127 | return None 128 | 129 | @staticmethod 130 | def _get_style_enum(style: Style | str | None) -> Style | None: 131 | if isinstance(style, Style): 132 | return style 133 | elif isinstance(style, str): 134 | try: 135 | return Style[style.upper()] 136 | except KeyError: 137 | valid_styles = [f.name.lower() for f in Style] 138 | raise ValueError(f"Invalid style name: {style}. 
Valid style names are {valid_styles}") 139 | return None 140 | -------------------------------------------------------------------------------- /chispa/formatting/formatting_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from typing import Any, ClassVar 5 | 6 | from chispa.default_formats import DefaultFormats 7 | from chispa.formatting.formats import Color, Format, Style 8 | 9 | 10 | class FormattingConfig: 11 | """ 12 | Class to manage and parse formatting configurations. 13 | """ 14 | 15 | VALID_KEYS: ClassVar = {"color", "style"} 16 | 17 | def __init__( 18 | self, 19 | mismatched_rows: Format | dict[str, str | list[str]] = Format(Color.RED), 20 | matched_rows: Format | dict[str, str | list[str]] = Format(Color.BLUE), 21 | mismatched_cells: Format | dict[str, str | list[str]] = Format(Color.RED, [Style.UNDERLINE]), 22 | matched_cells: Format | dict[str, str | list[str]] = Format(Color.BLUE), 23 | ): 24 | """ 25 | Initializes the FormattingConfig with given or default formatting. 26 | 27 | Each of the arguments can be provided as a `Format` object or a dictionary with the following keys: 28 | - 'color': A string representing a color name, which should be one of the valid colors: 29 | ['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', 30 | 'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 31 | 'light_purple', 'light_cyan', 'white']. 32 | - 'style': A string or list of strings representing styles, which should be one of the valid styles: 33 | ['bold', 'underline', 'blink', 'invert', 'hide']. 34 | 35 | Args: 36 | mismatched_rows (Format | dict): Format or dictionary for mismatched rows. 37 | matched_rows (Format | dict): Format or dictionary for matched rows. 38 | mismatched_cells (Format | dict): Format or dictionary for mismatched cells. 39 | matched_cells (Format | dict): Format or dictionary for matched cells. 40 | 41 | Raises: 42 | ValueError: If the dictionary contains invalid keys or values. 43 | """ 44 | self.mismatched_rows: Format = self._parse_format(mismatched_rows) 45 | self.matched_rows: Format = self._parse_format(matched_rows) 46 | self.mismatched_cells: Format = self._parse_format(mismatched_cells) 47 | self.matched_cells: Format = self._parse_format(matched_cells) 48 | 49 | def _parse_format(self, format: Format | dict[str, str | list[str]]) -> Format: 50 | if isinstance(format, Format): 51 | return format 52 | elif isinstance(format, dict): 53 | return Format.from_dict(format) 54 | raise ValueError("Invalid format type. Must be Format or dict.") 55 | 56 | @classmethod 57 | def _from_arbitrary_dataclass(cls, instance: Any) -> FormattingConfig: 58 | """ 59 | Converts an instance of an arbitrary class with specified fields to a FormattingConfig instance. 60 | This method is purely for backwards compatibility and should be removed in a future release, 61 | together with the `DefaultFormats` class. 62 | """ 63 | 64 | if not isinstance(instance, DefaultFormats): 65 | warnings.warn( 66 | "Using an arbitrary dataclass is deprecated. 
Use `chispa.formatting.FormattingConfig` instead.", 67 | DeprecationWarning, 68 | ) 69 | 70 | mismatched_rows = Format.from_list(getattr(instance, "mismatched_rows")) 71 | matched_rows = Format.from_list(getattr(instance, "matched_rows")) 72 | mismatched_cells = Format.from_list(getattr(instance, "mismatched_cells")) 73 | matched_cells = Format.from_list(getattr(instance, "matched_cells")) 74 | 75 | return cls( 76 | mismatched_rows=mismatched_rows, 77 | matched_rows=matched_rows, 78 | mismatched_cells=mismatched_cells, 79 | matched_cells=matched_cells, 80 | ) 81 | -------------------------------------------------------------------------------- /chispa/number_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from decimal import Decimal 5 | from typing import Any 6 | 7 | 8 | def isnan(x: Any) -> bool: 9 | try: 10 | return math.isnan(x) 11 | except TypeError: 12 | return False 13 | 14 | 15 | def nan_safe_equality(x: int | float, y: int | float | Decimal) -> bool: 16 | return (x == y) or (isnan(x) and isnan(y)) 17 | 18 | 19 | def nan_safe_approx_equality(x: int | float, y: int | float, precision: float | Decimal) -> bool: 20 | return (abs(x - y) <= precision) or (isnan(x) and isnan(y)) 21 | -------------------------------------------------------------------------------- /chispa/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/chispa/py.typed -------------------------------------------------------------------------------- /chispa/row_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | from pyspark.sql import Row 6 | 7 | from chispa.number_helpers import nan_safe_approx_equality, nan_safe_equality 8 | 9 | 10 | def are_rows_equal(r1: Row, r2: Row) -> bool: 11 | return r1 == r2 12 | 13 | 14 | def are_rows_equal_enhanced(r1: Row | None, r2: Row | None, allow_nan_equality: bool) -> bool: 15 | if r1 is None and r2 is None: 16 | return True 17 | if r1 is None or r2 is None: 18 | return False 19 | d1 = r1.asDict() 20 | d2 = r2.asDict() 21 | if allow_nan_equality: 22 | for key in d1.keys() & d2.keys(): 23 | if not (nan_safe_equality(d1[key], d2[key])): 24 | return False 25 | return True 26 | else: 27 | return r1 == r2 28 | 29 | 30 | def are_rows_approx_equal(r1: Row | None, r2: Row | None, precision: float, allow_nan_equality: bool = False) -> bool: 31 | if r1 is None and r2 is None: 32 | return True 33 | if r1 is None or r2 is None: 34 | return False 35 | d1 = r1.asDict() 36 | d2 = r2.asDict() 37 | allEqual = True 38 | for key in d1.keys() & d2.keys(): 39 | if isinstance(d1[key], float) and isinstance(d2[key], float): 40 | if allow_nan_equality and not (nan_safe_approx_equality(d1[key], d2[key], precision)): 41 | allEqual = False 42 | elif not (allow_nan_equality) and math.isnan(abs(d1[key] - d2[key])): 43 | allEqual = False 44 | elif abs(d1[key] - d2[key]) > precision: 45 | allEqual = False 46 | elif d1[key] != d2[key]: 47 | allEqual = False 48 | return allEqual 49 | -------------------------------------------------------------------------------- /chispa/rows_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from itertools import zip_longest 4 | from typing import Any, 
Callable 5 | 6 | from prettytable import PrettyTable 7 | from pyspark.sql import Row 8 | 9 | import chispa 10 | from chispa.formatting import FormattingConfig, format_string 11 | 12 | 13 | def assert_basic_rows_equality( 14 | rows1: list[Row], rows2: list[Row], underline_cells: bool = False, formats: FormattingConfig | None = None 15 | ) -> None: 16 | if not formats: 17 | formats = FormattingConfig() 18 | elif not isinstance(formats, FormattingConfig): 19 | formats = FormattingConfig._from_arbitrary_dataclass(formats) 20 | 21 | if rows1 != rows2: 22 | t = PrettyTable(["df1", "df2"]) 23 | zipped = list(zip_longest(rows1, rows2)) 24 | all_rows_equal = True 25 | 26 | for r1, r2 in zipped: 27 | if r1 is None and r2 is not None: 28 | t.add_row([None, format_string(str(r2), formats.mismatched_rows)]) 29 | all_rows_equal = False 30 | elif r1 is not None and r2 is None: 31 | t.add_row([format_string(str(r1), formats.mismatched_rows), None]) 32 | all_rows_equal = False 33 | else: 34 | r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) 35 | r1_string = [] 36 | r2_string = [] 37 | for r1_field, r2_field in r_zipped: 38 | if r1[r1_field] != r2[r2_field]: 39 | all_rows_equal = False 40 | r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) 41 | r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) 42 | else: 43 | r1_string.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) 44 | r2_string.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) 45 | r1_res = ", ".join(r1_string) 46 | r2_res = ", ".join(r2_string) 47 | 48 | t.add_row([r1_res, r2_res]) 49 | if all_rows_equal is False: 50 | raise chispa.DataFramesNotEqualError("\n" + t.get_string()) 51 | 52 | 53 | def assert_generic_rows_equality( 54 | rows1: list[Row], 55 | rows2: list[Row], 56 | row_equality_fun: Callable, # type: ignore[type-arg] 57 | row_equality_fun_args: dict[str, Any], 58 | underline_cells: bool = False, 59 | formats: FormattingConfig | None = None, 60 | ) -> None: 61 | if not formats: 62 | formats = FormattingConfig() 63 | elif not isinstance(formats, FormattingConfig): 64 | formats = FormattingConfig._from_arbitrary_dataclass(formats) 65 | 66 | df1_rows = rows1 67 | df2_rows = rows2 68 | zipped = list(zip_longest(df1_rows, df2_rows)) 69 | t = PrettyTable(["df1", "df2"]) 70 | all_rows_equal = True 71 | for r1, r2 in zipped: 72 | # rows are not equal when one is None and the other isn't 73 | if (r1 is None) ^ (r2 is None): 74 | all_rows_equal = False 75 | t.add_row([ 76 | format_string(str(r1), formats.mismatched_rows), 77 | format_string(str(r2), formats.mismatched_rows), 78 | ]) 79 | # rows are equal 80 | elif row_equality_fun(r1, r2, **row_equality_fun_args): 81 | r1_string = ", ".join(map(lambda f: f"{f}={r1[f]}", r1.__fields__)) 82 | r2_string = ", ".join(map(lambda f: f"{f}={r2[f]}", r2.__fields__)) 83 | t.add_row([ 84 | format_string(r1_string, formats.matched_rows), 85 | format_string(r2_string, formats.matched_rows), 86 | ]) 87 | # otherwise, rows aren't equal 88 | else: 89 | r_zipped = list(zip_longest(r1.__fields__, r2.__fields__)) 90 | r1_string_list: list[str] = [] 91 | r2_string_list: list[str] = [] 92 | for r1_field, r2_field in r_zipped: 93 | if r1[r1_field] != r2[r2_field]: 94 | all_rows_equal = False 95 | r1_string_list.append(format_string(f"{r1_field}={r1[r1_field]}", formats.mismatched_cells)) 96 | r2_string_list.append(format_string(f"{r2_field}={r2[r2_field]}", formats.mismatched_cells)) 97 | 
else: 98 | r1_string_list.append(format_string(f"{r1_field}={r1[r1_field]}", formats.matched_cells)) 99 | r2_string_list.append(format_string(f"{r2_field}={r2[r2_field]}", formats.matched_cells)) 100 | r1_res = ", ".join(r1_string_list) 101 | r2_res = ", ".join(r2_string_list) 102 | 103 | t.add_row([r1_res, r2_res]) 104 | if all_rows_equal is False: 105 | raise chispa.DataFramesNotEqualError("\n" + t.get_string()) 106 | -------------------------------------------------------------------------------- /chispa/schema_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import typing 4 | from itertools import zip_longest 5 | 6 | from prettytable import PrettyTable 7 | from pyspark.sql.types import StructField, StructType 8 | 9 | from chispa.bcolors import bcolors, line_blue, line_red 10 | from chispa.common_enums import OutputFormat, TypeName 11 | from chispa.formatting import blue 12 | 13 | 14 | class SchemasNotEqualError(Exception): 15 | """The schemas are not equal""" 16 | 17 | pass 18 | 19 | 20 | def print_schema_diff( 21 | s1: StructType, 22 | s2: StructType, 23 | ignore_nullable: bool, 24 | ignore_metadata: bool, 25 | output_format: OutputFormat = OutputFormat.TABLE, 26 | ) -> None: 27 | if output_format == OutputFormat.TABLE: 28 | schema_diff_table: PrettyTable = create_schema_comparison_table(s1, s2, ignore_nullable, ignore_metadata) 29 | print(schema_diff_table) 30 | elif output_format == OutputFormat.TREE: 31 | schema_diff_tree: str = create_schema_comparison_tree(s1, s2, ignore_nullable, ignore_metadata) 32 | print(schema_diff_tree) 33 | else: 34 | raise ValueError(f"output_format must be one of {OutputFormat.__members__}") 35 | 36 | 37 | def create_schema_comparison_tree(s1: StructType, s2: StructType, ignore_nullable: bool, ignore_metadata: bool) -> str: 38 | def parse_schema_as_tree(s: StructType, indent: int) -> tuple[list[str], list[StructField]]: 39 | tree_lines = [] 40 | fields = [] 41 | 42 | for struct_field in s: 43 | nullable = "(nullable = true)" if struct_field.nullable else "(nullable = false)" 44 | struct_field_type = struct_field.dataType.typeName() 45 | 46 | struct_prefix = f"{indent * ' '}|{'-' * 2}" 47 | struct_as_string = f"{struct_field.name}: {struct_field_type} {nullable}" 48 | 49 | tree_lines += [f"{struct_prefix} {struct_as_string}"] 50 | 51 | if not struct_field_type == TypeName.STRUCT: 52 | fields += [struct_field] 53 | continue 54 | 55 | tree_line_nested, fields_nested = parse_schema_as_tree(struct_field.dataType, indent + 4) # type: ignore[arg-type] 56 | 57 | fields += [struct_field] 58 | tree_lines += tree_line_nested 59 | fields += fields_nested 60 | 61 | return tree_lines, fields 62 | 63 | tree_space = 6 64 | s1_tree, s1_fields = parse_schema_as_tree(s1, 0) 65 | s2_tree, s2_fields = parse_schema_as_tree(s2, 0) 66 | 67 | widest_line = max(len(line) for line in s1_tree) 68 | longest_tree = max(len(s1_tree), len(s2_tree)) 69 | schema_gap = widest_line + tree_space 70 | 71 | tree = "\nschema1".ljust(schema_gap) + "schema2\n" 72 | for i in range(longest_tree): 73 | line1 = line2 = "" 74 | s1_field = s2_field = None 75 | 76 | if i < len(s1_tree): 77 | line1 = s1_tree[i] 78 | s1_field = s1_fields[i] 79 | if i < len(s2_tree): 80 | line2 = s2_tree[i] 81 | s2_field = s2_fields[i] 82 | 83 | tree_line = line1.ljust(schema_gap) + line2 84 | 85 | if are_structfields_equal(s1_field, s2_field, ignore_nullable, ignore_metadata): 86 | tree += line_blue(tree_line) + "\n" 87 | else: 88 
| tree += line_red(tree_line) + "\n" 89 | 90 | tree += bcolors.NC 91 | return tree 92 | 93 | 94 | def create_schema_comparison_table( 95 | s1: StructType, s2: StructType, ignore_nullable: bool, ignore_metadata: bool 96 | ) -> PrettyTable: 97 | t = PrettyTable(["schema1", "schema2"]) 98 | zipped = list(zip_longest(s1, s2)) 99 | for sf1, sf2 in zipped: 100 | if are_structfields_equal(sf1, sf2, ignore_nullable, ignore_metadata): 101 | t.add_row([blue(str(sf1)), blue(str(sf2))]) 102 | else: 103 | t.add_row([sf1, sf2]) 104 | return t 105 | 106 | 107 | def check_if_schemas_are_wide(s1: StructType, s2: StructType) -> bool: 108 | contains_nested_structs = any(sf.dataType.typeName() == TypeName.STRUCT for sf in s1) or any( 109 | sf.dataType.typeName() == TypeName.STRUCT for sf in s2 110 | ) 111 | contains_many_columns = len(s1) > 10 or len(s2) > 10 112 | return contains_nested_structs or contains_many_columns 113 | 114 | 115 | def handle_schemas_not_equal(s1: StructType, s2: StructType, ignore_nullable: bool, ignore_metadata: bool) -> None: 116 | schemas_are_wide = check_if_schemas_are_wide(s1, s2) 117 | if schemas_are_wide: 118 | error_message = create_schema_comparison_tree(s1, s2, ignore_nullable, ignore_metadata) 119 | else: 120 | t = create_schema_comparison_table(s1, s2, ignore_nullable, ignore_metadata) 121 | error_message = "\n" + t.get_string() 122 | raise SchemasNotEqualError(error_message) 123 | 124 | 125 | def assert_schema_equality( 126 | s1: StructType, s2: StructType, ignore_nullable: bool = False, ignore_metadata: bool = False 127 | ) -> None: 128 | if not ignore_nullable and not ignore_metadata: 129 | assert_basic_schema_equality(s1, s2) 130 | else: 131 | assert_schema_equality_full(s1, s2, ignore_nullable, ignore_metadata) 132 | 133 | 134 | def assert_schema_equality_full( 135 | s1: StructType, s2: StructType, ignore_nullable: bool = False, ignore_metadata: bool = False 136 | ) -> None: 137 | def inner(s1: StructType, s2: StructType, ignore_nullable: bool, ignore_metadata: bool) -> bool: 138 | if len(s1) != len(s2): 139 | return False 140 | zipped = list(zip_longest(s1, s2)) 141 | for sf1, sf2 in zipped: 142 | if not are_structfields_equal(sf1, sf2, ignore_nullable, ignore_metadata): 143 | return False 144 | return True 145 | 146 | if not inner(s1, s2, ignore_nullable, ignore_metadata): 147 | handle_schemas_not_equal(s1, s2, ignore_nullable, ignore_metadata) 148 | 149 | 150 | # deprecate this 151 | # perhaps it is a little faster, but do we really need this? 152 | # I think schema equality operations are really fast to begin with 153 | def assert_basic_schema_equality(s1: StructType, s2: StructType) -> None: 154 | if s1 != s2: 155 | handle_schemas_not_equal(s1, s2, ignore_nullable=False, ignore_metadata=False) 156 | 157 | 158 | # deprecate this. ignore_nullable should be a flag. 159 | def assert_schema_equality_ignore_nullable(s1: StructType, s2: StructType) -> None: 160 | if not are_schemas_equal_ignore_nullable(s1, s2): 161 | handle_schemas_not_equal(s1, s2, ignore_nullable=True, ignore_metadata=False) 162 | 163 | 164 | # deprecate this. ignore_nullable should be a flag. 
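# Illustrative aside (editor's sketch, not part of the original module): for a
# schema pair that differs only in nullability, the flag-based API above and the
# deprecated helper below agree. Assuming a plain PySpark environment:
#
#     from pyspark.sql.types import StringType, StructField, StructType
#
#     s1 = StructType([StructField("name", StringType(), True)])
#     s2 = StructType([StructField("name", StringType(), False)])
#     assert are_schemas_equal_ignore_nullable(s1, s2)
#     assert_schema_equality(s1, s2, ignore_nullable=True)  # does not raise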
165 | def are_schemas_equal_ignore_nullable(s1: StructType, s2: StructType, ignore_metadata: bool = False) -> bool: 166 | if len(s1) != len(s2): 167 | return False 168 | zipped = list(zip_longest(s1, s2)) 169 | for sf1, sf2 in zipped: 170 | if not are_structfields_equal(sf1, sf2, True, ignore_metadata): 171 | return False 172 | return True 173 | 174 | 175 | # "ignore_nullability" should be "ignore_nullable" for consistent terminology 176 | def are_structfields_equal( 177 | sf1: StructField | None, sf2: StructField | None, ignore_nullability: bool = False, ignore_metadata: bool = False 178 | ) -> bool: 179 | if not ignore_nullability and not ignore_metadata: 180 | return sf1 == sf2 181 | else: 182 | if sf1 is None or sf2 is None: 183 | if sf1 is None and sf2 is None: 184 | return True 185 | else: 186 | return False 187 | if sf1.name != sf2.name: 188 | return False 189 | if not ignore_metadata and sf1.metadata != sf2.metadata: 190 | return False 191 | else: 192 | return are_datatypes_equal_ignore_nullable(sf1.dataType, sf2.dataType, ignore_metadata) # type: ignore[no-any-return, no-untyped-call] 193 | 194 | 195 | # deprecate this 196 | @typing.no_type_check 197 | def are_datatypes_equal_ignore_nullable(dt1, dt2, ignore_metadata: bool = False) -> bool: 198 | """Checks if datatypes are equal, descending into structs and arrays to 199 | ignore nullability. 200 | """ 201 | if dt1.typeName() == dt2.typeName(): 202 | # Account for array types by inspecting elementType. 203 | if dt1.typeName() == TypeName.ARRAY: 204 | return are_datatypes_equal_ignore_nullable(dt1.elementType, dt2.elementType, ignore_metadata) 205 | elif dt1.typeName() == TypeName.STRUCT: 206 | return are_schemas_equal_ignore_nullable(dt1, dt2, ignore_metadata) 207 | else: 208 | # Some data types have additional attributes (e.g. precision and scale for Decimal), 209 | # and the type equality check must also check for equality of these attributes. 210 | return vars(dt1) == vars(dt2) 211 | else: 212 | return False 213 | -------------------------------------------------------------------------------- /chispa/structfield_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from chispa.schema_comparer import are_structfields_equal 4 | 5 | __all__ = ("are_structfields_equal",) 6 | -------------------------------------------------------------------------------- /ci/environment-py39.yml: -------------------------------------------------------------------------------- 1 | name: chispa_test_env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - pytest 8 | - pytest-describe 9 | - pyspark 10 | - findspark 11 | -------------------------------------------------------------------------------- /docs/gen_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation. 
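For every module under chispa/, a stub page is generated whose "::: <module>" directive mkdocstrings expands into rendered API documentation, and a literate navigation summary is written alongside it.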
2 | Script was taken from 3 | https://mkdocstrings.github.io/recipes/#automatic-code-reference-pages 4 | """ 5 | 6 | from __future__ import annotations 7 | 8 | from pathlib import Path 9 | 10 | import mkdocs_gen_files 11 | 12 | nav = mkdocs_gen_files.Nav() 13 | 14 | for path in sorted(Path(".").rglob("chispa/**/*.py")): 15 | module_path = path.relative_to(".").with_suffix("") 16 | doc_path = path.relative_to(".").with_suffix(".md") 17 | full_doc_path = Path("reference", doc_path) 18 | 19 | parts = tuple(module_path.parts) 20 | 21 | if parts[-1] == "__init__": 22 | parts = parts[:-1] 23 | doc_path = doc_path.with_name("index.md") 24 | full_doc_path = full_doc_path.with_name("index.md") 25 | elif parts[-1] == "__main__": 26 | continue 27 | 28 | nav[parts] = doc_path.as_posix() # 29 | 30 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 31 | ident = ".".join(parts) 32 | fd.write(f"::: {ident}") 33 | 34 | mkdocs_gen_files.set_edit_path(full_doc_path, path) 35 | 36 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 37 | nav_file.writelines(nav.build_literate_nav()) 38 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | {!README.md!} 2 | -------------------------------------------------------------------------------- /images/columns_not_approx_equal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/columns_not_approx_equal.png -------------------------------------------------------------------------------- /images/columns_not_equal_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/columns_not_equal_error.png -------------------------------------------------------------------------------- /images/custom_formats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/custom_formats.png -------------------------------------------------------------------------------- /images/df_not_equal_underlined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/df_not_equal_underlined.png -------------------------------------------------------------------------------- /images/dfs_not_approx_equal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/dfs_not_approx_equal.png -------------------------------------------------------------------------------- /images/dfs_not_equal_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/dfs_not_equal_error.png -------------------------------------------------------------------------------- /images/dfs_not_equal_error_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/dfs_not_equal_error_old.png 
-------------------------------------------------------------------------------- /images/ignore_column_order_false.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/ignore_column_order_false.png -------------------------------------------------------------------------------- /images/ignore_row_order_false.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/ignore_row_order_false.png -------------------------------------------------------------------------------- /images/ignore_row_order_false_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/ignore_row_order_false_old.png -------------------------------------------------------------------------------- /images/nullable_off_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/nullable_off_error.png -------------------------------------------------------------------------------- /images/schemas_not_approx_equal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/images/schemas_not_approx_equal.png -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Chispa 2 | edit_uri: edit/main/docs/ 3 | repo_name: MrPowers/chispa 4 | repo_url: https://github.com/MrPowers/chispa 5 | site_url: https://mrpowers.github.io/chispa/ 6 | site_description: PySpark test helper methods with beautiful error messages. 7 | site_author: Matthew Powers 8 | copyright: Maintained by Matthew. 9 | 10 | nav: 11 | - Home: index.md 12 | - API Docs: reference/SUMMARY.md 13 | 14 | plugins: 15 | - search 16 | - gen-files: 17 | scripts: 18 | - docs/gen_ref_pages.py 19 | - section-index 20 | - mkdocstrings: 21 | default_handler: python 22 | handlers: 23 | python: 24 | options: 25 | docstring_style: google 26 | docstring_options: 27 | show_if_no_docstring: true 28 | show_source: true 29 | 30 | theme: 31 | name: material 32 | features: 33 | - content.action.edit 34 | - content.code.copy 35 | - navigation.footer 36 | palette: 37 | - media: "(prefers-color-scheme: light)" 38 | scheme: default 39 | primary: indigo 40 | accent: amber 41 | toggle: 42 | icon: material/brightness-7 43 | name: Switch to dark mode 44 | - media: "(prefers-color-scheme: dark)" 45 | scheme: slate 46 | primary: indigo 47 | accent: amber 48 | toggle: 49 | icon: material/brightness-4 50 | name: Switch to light mode 51 | icon: 52 | repo: fontawesome/brands/github 53 | 54 | extra: 55 | social: 56 | - icon: fontawesome/brands/github 57 | link: https://github.com/MrPowers/chispa 58 | - icon: fontawesome/brands/python 59 | link: https://pypi.org/project/chispa/ 60 | 61 | markdown_extensions: 62 | - admonition 63 | - attr_list 64 | - md_in_html 65 | - pymdownx.details 66 | - pymdownx.superfences 67 | - toc: 68 | permalink: true 69 | - pymdownx.arithmatex: 70 | generic: true 71 | - markdown_include.include: 72 | base_path: . 
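# Editor's note: `markdown_include` with `base_path: .` resolves the
# `{!README.md!}` directive in docs/index.md against the repository root,
# which keeps the rendered home page in sync with the README.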
73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "chispa" 3 | version = "0.11.1" 4 | description = "Pyspark test helper library" 5 | authors = ["Matthew Powers "] 6 | maintainers = [ 7 | "Semyon Sinchenko ", 8 | "Florian Maas " 9 | ] 10 | repository = "https://github.com/MrPowers/chispa" 11 | documentation = "https://mrpowers.github.io/chispa" 12 | readme = "README.md" 13 | license = "MIT" 14 | keywords = ['apachespark', 'spark', 'pyspark', 'pytest'] 15 | classifiers = [ 16 | "Development Status :: 3 - Alpha", 17 | "Environment :: Console", 18 | "Framework :: Pytest", 19 | "Intended Audience :: Developers", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Topic :: Software Development :: Libraries :: Python Modules", 30 | "Topic :: Software Development :: Quality Assurance", 31 | "Topic :: Software Development :: Testing", 32 | ] 33 | 34 | [tool.poetry.dependencies] 35 | python = ">=3.8,<4.0" 36 | prettytable = "^3.10.2" 37 | 38 | [tool.poetry.group.dev.dependencies] 39 | pytest = "7.4.2" 40 | pyspark = ">3.0.0" 41 | findspark = "1.4.2" 42 | pytest-describe = "^2.1.0" 43 | pytest-cov = "^5.0.0" 44 | pre-commit = "3.3.3" 45 | mypy = "^1.11.0" 46 | 47 | [tool.poetry.group.mkdocs.dependencies] 48 | mkdocs = "^1.6.0" 49 | mkdocstrings-python = "*" 50 | mkdocs-gen-files = "*" 51 | mkdocs-literate-nav = "*" 52 | mkdocs-section-index = "*" 53 | markdown-include = "*" 54 | mkdocs-material = "*" 55 | 56 | [tool.poetry.group.mkdocs] 57 | optional = true 58 | 59 | [build-system] 60 | requires = ["poetry-core"] 61 | build-backend = "poetry.core.masonry.api" 62 | 63 | [tool.ruff] 64 | target-version = "py39" 65 | line-length = 120 66 | fix = true 67 | 68 | [tool.ruff.format] 69 | preview = true 70 | 71 | [tool.ruff.lint] 72 | select = ["E", "F", "I", "RUF", "UP"] 73 | ignore = [ 74 | # Line too long 75 | "E501" 76 | ] 77 | 78 | [tool.ruff.lint.flake8-type-checking] 79 | strict = true 80 | 81 | [tool.ruff.lint.per-file-ignores] 82 | "tests/*" = ["S101", "S603"] 83 | 84 | [tool.ruff.lint.isort] 85 | required-imports = ["from __future__ import annotations"] 86 | 87 | [tool.mypy] 88 | files = ["chispa"] 89 | explicit_package_bases = true 90 | disallow_any_unimported = true 91 | enable_error_code = [ 92 | "ignore-without-code", 93 | "redundant-expr", 94 | "truthy-bool", 95 | ] 96 | strict = true 97 | pretty = true 98 | show_error_codes = true 99 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrPowers/chispa/57db7b6c40c9f841e29d574982404a24bd92bef8/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from chispa.formatting import FormattingConfig 6 | 7 | 8 | @pytest.fixture() 9 | def my_formats(): 10 | return 
FormattingConfig( 11 | mismatched_rows={"color": "light_yellow"}, 12 | matched_rows={"color": "cyan", "style": "bold"}, 13 | mismatched_cells={"color": "purple"}, 14 | matched_cells={"color": "blue"}, 15 | ) 16 | 17 | 18 | @pytest.fixture() 19 | def my_chispa(): 20 | return FormattingConfig( 21 | mismatched_rows={"color": "light_yellow"}, 22 | matched_rows={"color": "cyan", "style": "bold"}, 23 | mismatched_cells={"color": "purple"}, 24 | matched_cells={"color": "blue"}, 25 | ) 26 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[34m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[34m|-- age: integer (nullable = true) |-- age: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_number: integer (nullable = true) |-- fav_number: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_numbers: array (nullable = true) |-- fav_numbers: array (nullable = true)\x1b[0m\n\x1b[31m|-- fav_colors: struct (nullable = true) |-- fav_colors: struct (nullable = true)\x1b[0m\n\x1b[31m |-- red: integer (nullable = true) |-- orange: integer (nullable = true)\x1b[0m\n\x1b[34m |-- green: integer (nullable = true) |-- green: integer (nullable = true)\x1b[0m\n\x1b[31m |-- blue: integer (nullable = true) |-- yellow: integer (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas_different_lengths.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[34m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[34m|-- age: integer (nullable = true) |-- age: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_number: integer (nullable = true) |-- fav_number: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_numbers: array (nullable = true) |-- fav_numbers: array (nullable = true)\x1b[0m\n\x1b[31m|-- fav_colors: struct (nullable = true) |-- fav_colors: struct (nullable = true)\x1b[0m\n\x1b[31m |-- red: integer (nullable = true) |-- orange: integer (nullable = true)\x1b[0m\n\x1b[34m |-- green: integer (nullable = true) |-- green: integer (nullable = true)\x1b[0m\n\x1b[31m |-- blue: integer (nullable = true) |-- yellow: integer (nullable = true)\x1b[0m\n\x1b[31m |-- purple: integer (nullable = true)\x1b[0m\n\x1b[31m |-- phone_number: string (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas_ignore_metadata.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[34m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[34m|-- age: integer (nullable = true) |-- age: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_number: integer (nullable = true) |-- fav_number: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_numbers: array (nullable = true) |-- fav_numbers: array (nullable = true)\x1b[0m\n\x1b[31m|-- fav_colors: struct (nullable = true) |-- fav_colors: struct (nullable = true)\x1b[0m\n\x1b[31m |-- red: integer (nullable = true) |-- orange: integer (nullable = true)\x1b[0m\n\x1b[34m |-- green: integer (nullable = true) |-- green: integer (nullable = 
true)\x1b[0m\n\x1b[31m |-- blue: integer (nullable = true) |-- yellow: integer (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas_ignore_nullable.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[34m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[34m|-- age: integer (nullable = true) |-- age: integer (nullable = false)\x1b[0m\n\x1b[34m|-- fav_number: integer (nullable = true) |-- fav_number: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_numbers: array (nullable = true) |-- fav_numbers: array (nullable = false)\x1b[0m\n\x1b[31m|-- fav_colors: struct (nullable = true) |-- fav_colors: struct (nullable = true)\x1b[0m\n\x1b[31m |-- red: integer (nullable = true) |-- orange: integer (nullable = true)\x1b[0m\n\x1b[34m |-- green: integer (nullable = true) |-- green: integer (nullable = false)\x1b[0m\n\x1b[31m |-- blue: integer (nullable = true) |-- yellow: integer (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas_multiple_nested_structs.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[34m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[31m|-- fav_genres: struct (nullable = true) |-- fav_genres: struct (nullable = true)\x1b[0m\n\x1b[31m |-- rock: struct (nullable = true) |-- rock: struct (nullable = true)\x1b[0m\n\x1b[34m |-- metal: integer (nullable = true) |-- metal: integer (nullable = true)\x1b[0m\n\x1b[31m |-- punk: integer (nullable = true) |-- classic: integer (nullable = true)\x1b[0m\n\x1b[34m |-- electronic: struct (nullable = true) |-- electronic: struct (nullable = true)\x1b[0m\n\x1b[34m |-- house: integer (nullable = true) |-- house: integer (nullable = true)\x1b[0m\n\x1b[34m |-- dubstep: integer (nullable = true) |-- dubstep: integer (nullable = true)\x1b[0m\n\x1b[31m |-- pop: struct (nullable = true)\x1b[0m\n\x1b[31m |-- pop: integer (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/data/tree_string/it_prints_correctly_for_wide_schemas_with_metadata.txt: -------------------------------------------------------------------------------- 1 | '\nschema1 schema2\n\x1b[31m|-- name: string (nullable = true) |-- name: string (nullable = true)\x1b[0m\n\x1b[34m|-- age: integer (nullable = true) |-- age: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_number: integer (nullable = true) |-- fav_number: integer (nullable = true)\x1b[0m\n\x1b[34m|-- fav_numbers: array (nullable = true) |-- fav_numbers: array (nullable = true)\x1b[0m\n\x1b[31m|-- fav_colors: struct (nullable = true) |-- fav_colors: struct (nullable = true)\x1b[0m\n\x1b[31m |-- red: integer (nullable = true) |-- orange: integer (nullable = true)\x1b[0m\n\x1b[34m |-- green: integer (nullable = true) |-- green: integer (nullable = true)\x1b[0m\n\x1b[31m |-- blue: integer (nullable = true) |-- yellow: integer (nullable = true)\x1b[0m\n\x1b[0m' 2 | -------------------------------------------------------------------------------- /tests/formatting/test_formats.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | 
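# Editor's aside (illustrative sketch, not part of the original test module):
# the parsers exercised below normalise plain strings into the enums, e.g.
#
#     from chispa.formatting import Color, Format, Style
#
#     fmt = Format.from_list(["blue", "bold"])
#     assert fmt.color is Color.BLUE
#     assert fmt.style == [Style.BOLD]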
import pytest 6 | 7 | from chispa.formatting import Color, Format, Style 8 | 9 | 10 | def test_format_from_dict_valid(): 11 | format_dict = {"color": "blue", "style": ["bold", "underline"]} 12 | format_instance = Format.from_dict(format_dict) 13 | assert format_instance.color == Color.BLUE 14 | assert format_instance.style == [Style.BOLD, Style.UNDERLINE] 15 | 16 | 17 | def test_format_from_dict_invalid_color(): 18 | format_dict = {"color": "invalid_color", "style": ["bold"]} 19 | with pytest.raises(ValueError) as exc_info: 20 | Format.from_dict(format_dict) 21 | assert str(exc_info.value) == ( 22 | "Invalid color name: invalid_color. Valid color names are " 23 | "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " 24 | "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " 25 | "'light_cyan', 'white']" 26 | ) 27 | 28 | 29 | def test_format_from_dict_invalid_style(): 30 | format_dict = {"color": "blue", "style": ["invalid_style"]} 31 | with pytest.raises(ValueError) as exc_info: 32 | Format.from_dict(format_dict) 33 | assert str(exc_info.value) == ( 34 | "Invalid style name: invalid_style. Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']" 35 | ) 36 | 37 | 38 | def test_format_from_dict_invalid_key(): 39 | format_dict = {"invalid_key": "value"} 40 | try: 41 | Format.from_dict(format_dict) 42 | except ValueError as e: 43 | error_message = str(e) 44 | assert re.match( 45 | r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}", 46 | error_message, 47 | ) 48 | 49 | 50 | def test_format_from_list_valid(): 51 | values = ["blue", "bold", "underline"] 52 | format_instance = Format.from_list(values) 53 | assert format_instance.color == Color.BLUE 54 | assert format_instance.style == [Style.BOLD, Style.UNDERLINE] 55 | 56 | 57 | def test_format_from_list_invalid_color(): 58 | values = ["invalid_color", "bold"] 59 | with pytest.raises(ValueError) as exc_info: 60 | Format.from_list(values) 61 | assert str(exc_info.value) == ( 62 | "Invalid value: invalid_color. Valid values are colors: " 63 | "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " 64 | "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " 65 | "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']" 66 | ) 67 | 68 | 69 | def test_format_from_list_invalid_style(): 70 | values = ["blue", "invalid_style"] 71 | with pytest.raises(ValueError) as exc_info: 72 | Format.from_list(values) 73 | assert str(exc_info.value) == ( 74 | "Invalid value: invalid_style. 
Valid values are colors: " 75 | "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " 76 | "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " 77 | "'light_cyan', 'white'] and styles: ['bold', 'underline', 'blink', 'invert', 'hide']" 78 | ) 79 | 80 | 81 | def test_format_from_list_non_string_elements(): 82 | values = ["blue", 123] 83 | with pytest.raises(ValueError) as exc_info: 84 | Format.from_list(values) 85 | assert str(exc_info.value) == "All elements in the list must be strings" 86 | 87 | 88 | def test_format_from_dict_empty(): 89 | format_dict = {} 90 | format_instance = Format.from_dict(format_dict) 91 | assert format_instance.color is None 92 | assert format_instance.style is None 93 | 94 | 95 | def test_format_from_list_empty(): 96 | values = [] 97 | format_instance = Format.from_list(values) 98 | assert format_instance.color is None 99 | assert format_instance.style is None 100 | -------------------------------------------------------------------------------- /tests/formatting/test_formatting_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import re 4 | 5 | import pytest 6 | 7 | from chispa.formatting import Color, FormattingConfig, Style 8 | 9 | 10 | def test_default_mismatched_rows(): 11 | config = FormattingConfig() 12 | assert config.mismatched_rows.color == Color.RED 13 | assert config.mismatched_rows.style is None 14 | 15 | 16 | def test_default_matched_rows(): 17 | config = FormattingConfig() 18 | assert config.matched_rows.color == Color.BLUE 19 | assert config.matched_rows.style is None 20 | 21 | 22 | def test_default_mismatched_cells(): 23 | config = FormattingConfig() 24 | assert config.mismatched_cells.color == Color.RED 25 | assert config.mismatched_cells.style == [Style.UNDERLINE] 26 | 27 | 28 | def test_default_matched_cells(): 29 | config = FormattingConfig() 30 | assert config.matched_cells.color == Color.BLUE 31 | assert config.matched_cells.style is None 32 | 33 | 34 | def test_custom_mismatched_rows(): 35 | config = FormattingConfig(mismatched_rows={"color": "green", "style": ["bold", "underline"]}) 36 | assert config.mismatched_rows.color == Color.GREEN 37 | assert config.mismatched_rows.style == [Style.BOLD, Style.UNDERLINE] 38 | 39 | 40 | def test_custom_matched_rows(): 41 | config = FormattingConfig(matched_rows={"color": "yellow"}) 42 | assert config.matched_rows.color == Color.YELLOW 43 | assert config.matched_rows.style is None 44 | 45 | 46 | def test_custom_mismatched_cells(): 47 | config = FormattingConfig(mismatched_cells={"color": "purple", "style": ["blink"]}) 48 | assert config.mismatched_cells.color == Color.PURPLE 49 | assert config.mismatched_cells.style == [Style.BLINK] 50 | 51 | 52 | def test_custom_matched_cells(): 53 | config = FormattingConfig(matched_cells={"color": "cyan", "style": ["invert", "hide"]}) 54 | assert config.matched_cells.color == Color.CYAN 55 | assert config.matched_cells.style == [Style.INVERT, Style.HIDE] 56 | 57 | 58 | def test_invalid_color(): 59 | with pytest.raises(ValueError) as exc_info: 60 | FormattingConfig(mismatched_rows={"color": "invalid_color"}) 61 | assert str(exc_info.value) == ( 62 | "Invalid color name: invalid_color. 
Valid color names are " 63 | "['black', 'red', 'green', 'yellow', 'blue', 'purple', 'cyan', 'light_gray', " 64 | "'dark_gray', 'light_red', 'light_green', 'light_yellow', 'light_blue', 'light_purple', " 65 | "'light_cyan', 'white']" 66 | ) 67 | 68 | 69 | def test_invalid_style(): 70 | with pytest.raises(ValueError) as exc_info: 71 | FormattingConfig(mismatched_rows={"style": ["invalid_style"]}) 72 | assert str(exc_info.value) == ( 73 | "Invalid style name: invalid_style. Valid style names are " "['bold', 'underline', 'blink', 'invert', 'hide']" 74 | ) 75 | 76 | 77 | def test_invalid_key(): 78 | try: 79 | FormattingConfig(mismatched_rows={"invalid_key": "value"}) 80 | except ValueError as e: 81 | error_message = str(e) 82 | assert re.match( 83 | r"Invalid keys in format dictionary: \{'invalid_key'\}. Valid keys are \{('color', 'style'|'style', 'color')\}", 84 | error_message, 85 | ) 86 | -------------------------------------------------------------------------------- /tests/formatting/test_terminal_string_formatter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from chispa.formatting import RESET, format_string 4 | from chispa.formatting.formats import Color, Format, Style 5 | 6 | 7 | def test_format_with_enum_inputs(): 8 | format = Format(color=Color.BLUE, style=[Style.BOLD, Style.UNDERLINE]) 9 | formatted_string = format_string("Hello, World!", format) 10 | expected_string = f"{Style.BOLD.value}{Style.UNDERLINE.value}{Color.BLUE.value}Hello, World!{RESET}" 11 | assert formatted_string == expected_string 12 | 13 | 14 | def test_format_with_no_style(): 15 | format = Format(color=Color.GREEN, style=[]) 16 | formatted_string = format_string("Hello, World!", format) 17 | expected_string = f"{Color.GREEN.value}Hello, World!{RESET}" 18 | assert formatted_string == expected_string 19 | 20 | 21 | def test_format_with_no_color(): 22 | format = Format(color=None, style=[Style.BLINK]) 23 | formatted_string = format_string("Hello, World!", format) 24 | expected_string = f"{Style.BLINK.value}Hello, World!{RESET}" 25 | assert formatted_string == expected_string 26 | 27 | 28 | def test_format_with_no_color_or_style(): 29 | format = Format(color=None, style=[]) 30 | formatted_string = format_string("Hello, World!", format) 31 | expected_string = "Hello, World!" 
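# editor's note: format_string short-circuits when neither a color nor a style
# is set, so no ANSI codes and no trailing RESET are emitted here.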
32 | assert formatted_string == expected_string 33 | -------------------------------------------------------------------------------- /tests/spark.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pyspark.sql import SparkSession 4 | 5 | spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() 6 | -------------------------------------------------------------------------------- /tests/test_column_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from chispa import ColumnsNotEqualError, assert_approx_column_equality, assert_column_equality 6 | 7 | from .spark import spark 8 | 9 | 10 | def describe_assert_column_equality(): 11 | def it_throws_error_with_data_mismatch(): 12 | data = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] 13 | df = spark.createDataFrame(data, ["name", "expected_name"]) 14 | with pytest.raises(ColumnsNotEqualError): 15 | assert_column_equality(df, "name", "expected_name") 16 | 17 | def it_doesnt_throw_without_mismatch(): 18 | data = [("jose", "jose"), ("li", "li"), ("luisa", "luisa"), (None, None)] 19 | df = spark.createDataFrame(data, ["name", "expected_name"]) 20 | assert_column_equality(df, "name", "expected_name") 21 | 22 | def it_works_with_integer_values(): 23 | data = [(1, 1), (10, 10), (8, 8), (None, None)] 24 | df = spark.createDataFrame(data, ["num1", "num2"]) 25 | assert_column_equality(df, "num1", "num2") 26 | 27 | 28 | def describe_assert_approx_column_equality(): 29 | def it_works_with_no_mismatches(): 30 | data = [(1.1, 1.1), (1.0004, 1.0005), (0.4, 0.45), (None, None)] 31 | df = spark.createDataFrame(data, ["num1", "num2"]) 32 | assert_approx_column_equality(df, "num1", "num2", 0.1) 33 | 34 | def it_throws_when_difference_is_bigger_than_precision(): 35 | data = [(1.5, 1.1), (1.0004, 1.0005), (0.4, 0.45)] 36 | df = spark.createDataFrame(data, ["num1", "num2"]) 37 | with pytest.raises(ColumnsNotEqualError): 38 | assert_approx_column_equality(df, "num1", "num2", 0.1) 39 | 40 | def it_throws_when_comparing_floats_with_none(): 41 | data = [(1.1, 1.1), (2.2, 2.2), (3.3, None)] 42 | df = spark.createDataFrame(data, ["num1", "num2"]) 43 | with pytest.raises(ColumnsNotEqualError): 44 | assert_approx_column_equality(df, "num1", "num2", 0.1) 45 | 46 | def it_throws_when_comparing_none_with_floats(): 47 | data = [(1.1, 1.1), (2.2, 2.2), (None, 3.3)] 48 | df = spark.createDataFrame(data, ["num1", "num2"]) 49 | with pytest.raises(ColumnsNotEqualError): 50 | assert_approx_column_equality(df, "num1", "num2", 0.1) 51 | -------------------------------------------------------------------------------- /tests/test_dataframe_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | import pytest 6 | from pyspark.sql.types import IntegerType, StringType, StructField, StructType 7 | 8 | from chispa import DataFramesNotEqualError, assert_approx_df_equality, assert_df_equality 9 | from chispa.dataframe_comparer import are_dfs_equal 10 | from chispa.schema_comparer import SchemasNotEqualError 11 | 12 | from .spark import spark 13 | 14 | 15 | def describe_assert_df_equality(): 16 | def it_throws_with_schema_mismatches(): 17 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 18 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 19 | data2 = [("bob", 
"jose"), ("li", "li"), ("luisa", "laura")] 20 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 21 | with pytest.raises(SchemasNotEqualError): 22 | assert_df_equality(df1, df2) 23 | 24 | def it_can_work_with_different_row_orders(): 25 | data1 = [(1, "jose"), (2, "li")] 26 | df1 = spark.createDataFrame(data1, ["num", "name"]) 27 | data2 = [(2, "li"), (1, "jose")] 28 | df2 = spark.createDataFrame(data2, ["num", "name"]) 29 | assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) 30 | 31 | def it_can_work_with_different_row_orders_with_a_flag(): 32 | data1 = [(1, "jose"), (2, "li")] 33 | df1 = spark.createDataFrame(data1, ["num", "name"]) 34 | data2 = [(2, "li"), (1, "jose")] 35 | df2 = spark.createDataFrame(data2, ["num", "name"]) 36 | assert_df_equality(df1, df2, ignore_row_order=True) 37 | 38 | def it_can_work_with_different_row_and_column_orders(): 39 | data1 = [(1, "jose"), (2, "li")] 40 | df1 = spark.createDataFrame(data1, ["num", "name"]) 41 | data2 = [("li", 2), ("jose", 1)] 42 | df2 = spark.createDataFrame(data2, ["name", "num"]) 43 | assert_df_equality(df1, df2, ignore_row_order=True, ignore_column_order=True) 44 | 45 | def it_raises_for_row_insensitive_with_diff_content(): 46 | data1 = [(1, "XXXX"), (2, "li")] 47 | df1 = spark.createDataFrame(data1, ["num", "name"]) 48 | data2 = [(2, "li"), (1, "jose")] 49 | df2 = spark.createDataFrame(data2, ["num", "name"]) 50 | with pytest.raises(DataFramesNotEqualError): 51 | assert_df_equality(df1, df2, transforms=[lambda df: df.sort(df.columns)]) 52 | 53 | def it_throws_with_schema_column_order_mismatch(): 54 | data1 = [(1, "jose"), (2, "li")] 55 | df1 = spark.createDataFrame(data1, ["num", "name"]) 56 | data2 = [("jose", 1), ("li", 1)] 57 | df2 = spark.createDataFrame(data2, ["name", "num"]) 58 | with pytest.raises(SchemasNotEqualError): 59 | assert_df_equality(df1, df2) 60 | 61 | def it_does_not_throw_on_schema_column_order_mismatch_with_transforms(): 62 | data1 = [(1, "jose"), (2, "li")] 63 | df1 = spark.createDataFrame(data1, ["num", "name"]) 64 | data2 = [("jose", 1), ("li", 2)] 65 | df2 = spark.createDataFrame(data2, ["name", "num"]) 66 | assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) 67 | 68 | def it_throws_with_schema_mismatch(): 69 | data1 = [(1, "jose"), (2, "li")] 70 | df1 = spark.createDataFrame(data1, ["num", "different_name"]) 71 | data2 = [("jose", 1), ("li", 2)] 72 | df2 = spark.createDataFrame(data2, ["name", "num"]) 73 | with pytest.raises(SchemasNotEqualError): 74 | assert_df_equality(df1, df2, transforms=[lambda df: df.select(sorted(df.columns))]) 75 | 76 | def it_throws_with_content_mismatches(): 77 | data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] 78 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 79 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 80 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 81 | with pytest.raises(DataFramesNotEqualError): 82 | assert_df_equality(df1, df2) 83 | 84 | def it_throws_with_length_mismatches(): 85 | data1 = [("jose", "jose"), ("li", "li"), ("laura", "laura")] 86 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 87 | data2 = [("jose", "jose"), ("li", "li")] 88 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 89 | with pytest.raises(DataFramesNotEqualError): 90 | assert_df_equality(df1, df2) 91 | 92 | def it_can_consider_nan_values_equal(): 93 | data1 = [(float("nan"), "jose"), (2.0, "li")] 94 | df1 = spark.createDataFrame(data1, 
["num", "name"]) 95 | data2 = [(float("nan"), "jose"), (2.0, "li")] 96 | df2 = spark.createDataFrame(data2, ["num", "name"]) 97 | assert_df_equality(df1, df2, allow_nan_equality=True) 98 | 99 | def it_does_not_consider_nan_values_equal_by_default(): 100 | data1 = [(float("nan"), "jose"), (2.0, "li")] 101 | df1 = spark.createDataFrame(data1, ["num", "name"]) 102 | data2 = [(float("nan"), "jose"), (2.0, "li")] 103 | df2 = spark.createDataFrame(data2, ["num", "name"]) 104 | with pytest.raises(DataFramesNotEqualError): 105 | assert_df_equality(df1, df2, allow_nan_equality=False) 106 | 107 | def it_can_ignore_metadata(): 108 | rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] 109 | schema1 = StructType([ 110 | StructField("name", StringType(), True, {"hi": "no"}), 111 | StructField("age", IntegerType(), True), 112 | ]) 113 | schema2 = StructType([ 114 | StructField("name", StringType(), True, {"hi": "whatever"}), 115 | StructField("age", IntegerType(), True), 116 | ]) 117 | df1 = spark.createDataFrame(rows_data, schema1) 118 | df2 = spark.createDataFrame(rows_data, schema2) 119 | assert_df_equality(df1, df2, ignore_metadata=True) 120 | 121 | def it_catches_mismatched_metadata(): 122 | rows_data = [("jose", 1), ("li", 2), ("luisa", 3)] 123 | schema1 = StructType([ 124 | StructField("name", StringType(), True, {"hi": "no"}), 125 | StructField("age", IntegerType(), True), 126 | ]) 127 | schema2 = StructType([ 128 | StructField("name", StringType(), True, {"hi": "whatever"}), 129 | StructField("age", IntegerType(), True), 130 | ]) 131 | df1 = spark.createDataFrame(rows_data, schema1) 132 | df2 = spark.createDataFrame(rows_data, schema2) 133 | with pytest.raises(SchemasNotEqualError): 134 | assert_df_equality(df1, df2) 135 | 136 | def it_can_ignore_columns(): 137 | data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 138 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 139 | data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")] 140 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 141 | assert_df_equality(df1, df2, ignore_columns=["expected_name"]) 142 | 143 | def it_throws_when_dfs_are_not_same_with_ignored_columns(): 144 | data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 145 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 146 | data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")] 147 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 148 | with pytest.raises(DataFramesNotEqualError): 149 | assert assert_df_equality(df1, df2, ignore_columns=["name"]) 150 | 151 | def it_works_when_sorting_and_dropping_columns(): 152 | data1 = [("b", "jose", 10), ("a", "jose", 20)] 153 | df1 = spark.createDataFrame(data1, ["ignore_me", "name", "score"]) 154 | data2 = [("a", "jose", 10), ("b", "jose", 20)] 155 | df2 = spark.createDataFrame(data2, ["ignore_me", "name", "score"]) 156 | assert_df_equality(df1, df2, ignore_columns=["ignore_me"], ignore_row_order=True) 157 | 158 | 159 | def describe_are_dfs_equal(): 160 | def it_returns_false_with_schema_mismatches(): 161 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 162 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 163 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 164 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 165 | assert are_dfs_equal(df1, df2) is False 166 | 167 | def it_returns_false_with_content_mismatches(): 168 | data1 = [("jose", "jose"), ("li", "li"), ("luisa", "laura")] 169 | df1 = spark.createDataFrame(data1, ["name", 
"expected_name"]) 170 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 171 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 172 | assert are_dfs_equal(df1, df2) is False 173 | 174 | def it_returns_true_when_dfs_are_same(): 175 | data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 176 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 177 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 178 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 179 | assert are_dfs_equal(df1, df2) is True 180 | 181 | 182 | def describe_assert_approx_df_equality(): 183 | def it_throws_with_content_mismatch(): 184 | data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (1.0, None)] 185 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 186 | data2 = [(1.0, "jose"), (1.05, "li"), (1.0, "laura"), (None, "hi")] 187 | df2 = spark.createDataFrame(data2, ["num", "expected_name"]) 188 | with pytest.raises(DataFramesNotEqualError): 189 | assert_approx_df_equality(df1, df2, 0.1) 190 | 191 | def it_throws_with_with_length_mismatch(): 192 | data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] 193 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 194 | data2 = [(1.0, "jose"), (1.05, "li")] 195 | df2 = spark.createDataFrame(data2, ["num", "expected_name"]) 196 | with pytest.raises(DataFramesNotEqualError): 197 | assert_approx_df_equality(df1, df2, 0.1) 198 | 199 | def it_does_not_throw_with_no_mismatch(): 200 | data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] 201 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 202 | data2 = [(1.0, "jose"), (1.05, "li"), (1.2, "laura"), (None, None)] 203 | df2 = spark.createDataFrame(data2, ["num", "expected_name"]) 204 | assert_approx_df_equality(df1, df2, 0.1) 205 | 206 | def it_does_not_throw_with_different_row_col_order(): 207 | data1 = [(1.0, "jose"), (1.1, "li"), (1.2, "laura"), (None, None)] 208 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 209 | data2 = [("li", 1.05), ("laura", 1.2), (None, None), ("jose", 1.0)] 210 | df2 = spark.createDataFrame(data2, ["expected_name", "num"]) 211 | assert_approx_df_equality(df1, df2, 0.1, ignore_row_order=True, ignore_column_order=True) 212 | 213 | def it_does_not_throw_with_nan_values(): 214 | data1 = [ 215 | (1.0, "jose"), 216 | (1.1, "li"), 217 | (1.2, "laura"), 218 | (None, None), 219 | (float("nan"), "buk"), 220 | ] 221 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 222 | data2 = [ 223 | (1.0, "jose"), 224 | (1.05, "li"), 225 | (1.2, "laura"), 226 | (None, None), 227 | (math.nan, "buk"), 228 | ] 229 | df2 = spark.createDataFrame(data2, ["num", "expected_name"]) 230 | assert_approx_df_equality(df1, df2, 0.1, allow_nan_equality=True) 231 | 232 | def it_can_ignore_columns(): 233 | data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 234 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 235 | data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")] 236 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 237 | assert_approx_df_equality(df1, df2, 0.1, ignore_columns=["expected_name"]) 238 | 239 | def it_throws_when_dfs_are_not_same_with_ignored_columns(): 240 | data1 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 241 | df1 = spark.createDataFrame(data1, ["name", "expected_name"]) 242 | data2 = [("bob", "jose"), ("li", "boo"), ("luisa", "boo")] 243 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 244 | with 
/tests/test_deprecated.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from dataclasses import dataclass 5 | 6 | import pytest 7 | 8 | from chispa import DataFramesNotEqualError, assert_basic_rows_equality 9 | from chispa.bcolors import bcolors, blue, underline_text 10 | from chispa.default_formats import DefaultFormats 11 | from chispa.formatting import FormattingConfig 12 | 13 | from .spark import spark 14 | 15 | 16 | def test_default_formats_deprecation_warning(): 17 | with warnings.catch_warnings(record=True) as w: 18 | warnings.simplefilter("always") 19 | DefaultFormats() 20 | assert len(w) == 1 21 | assert issubclass(w[-1].category, DeprecationWarning) 22 | assert "DefaultFormats is deprecated" in str(w[-1].message) 23 | 24 | 25 | def test_that_default_formats_still_works(): 26 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 27 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 28 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 29 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 30 | with pytest.raises(DataFramesNotEqualError): 31 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=DefaultFormats()) 32 | 33 | 34 | def test_deprecated_arbitrary_dataclass(): 35 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 36 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 37 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 38 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 39 | 40 | @dataclass 41 | class CustomFormats: 42 | mismatched_rows = ["green"] # noqa: RUF012 43 | matched_rows = ["yellow"] # noqa: RUF012 44 | mismatched_cells = ["purple", "bold"] # noqa: RUF012 45 | matched_cells = ["cyan"] # noqa: RUF012 46 | 47 | with warnings.catch_warnings(record=True) as w: 48 | warnings.simplefilter("always") 49 | try: 50 | assert_basic_rows_equality(df1.collect(), df2.collect(), formats=CustomFormats()) 51 | # The call above must raise, so the assert below is unreachable. 52 | # pytest.raises is not used here because the recorded warning must be verified inside this block. 53 | assert False 54 | except DataFramesNotEqualError: 55 | assert len(w) == 1 56 | assert issubclass(w[-1].category, DeprecationWarning) 57 | assert "Using an arbitrary dataclass is deprecated." 
in str(w[-1].message) 58 | 59 | 60 | def test_invalid_value_in_default_formats(): 61 | @dataclass 62 | class InvalidFormats: 63 | mismatched_rows = ["green"] # noqa: RUF012 64 | matched_rows = ["yellow"] # noqa: RUF012 65 | mismatched_cells = ["purple", "invalid"] # noqa: RUF012 66 | matched_cells = ["cyan"] # noqa: RUF012 67 | 68 | with pytest.raises(ValueError): 69 | FormattingConfig._from_arbitrary_dataclass(InvalidFormats()) 70 | 71 | 72 | def test_bcolors_deprecation(): 73 | with pytest.warns(DeprecationWarning, match="The `bcolors` class is deprecated"): 74 | _ = bcolors() 75 | 76 | 77 | def test_blue_deprecation(): 78 | with pytest.warns(DeprecationWarning, match="The `blue` function is deprecated"): 79 | result = blue("test") 80 | assert result == "\033[34mtest\033[31m" 81 | 82 | 83 | def test_underline_text_deprecation(): 84 | with pytest.warns(DeprecationWarning, match="The `underline_text` function is deprecated"): 85 | result = underline_text("test") 86 | assert result == "\033[97m\033[4mtest\033[0m\033[31m" 87 | -------------------------------------------------------------------------------- /tests/test_readme_examples.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pyspark.sql.functions as F 4 | import pytest 5 | from pyspark.sql import SparkSession 6 | from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType, StructField, StructType 7 | 8 | from chispa import ( 9 | ColumnsNotEqualError, 10 | DataFramesNotEqualError, 11 | assert_approx_column_equality, 12 | assert_approx_df_equality, 13 | assert_basic_rows_equality, 14 | assert_column_equality, 15 | assert_df_equality, 16 | ) 17 | from chispa.schema_comparer import SchemasNotEqualError 18 | 19 | 20 | def remove_non_word_characters(col): 21 | return F.regexp_replace(col, "[^\\w\\s]+", "") 22 | 23 | 24 | spark = SparkSession.builder.master("local").appName("chispa").getOrCreate() 25 | 26 | 27 | def describe_column_equality(): 28 | def test_removes_non_word_characters_short(): 29 | data = [ 30 | ("jo&&se", "jose"), 31 | ("**li**", "li"), 32 | ("#::luisa", "luisa"), 33 | (None, None), 34 | ] 35 | df = spark.createDataFrame(data, ["name", "expected_name"]).withColumn( 36 | "clean_name", remove_non_word_characters(F.col("name")) 37 | ) 38 | assert_column_equality(df, "clean_name", "expected_name") 39 | 40 | def test_remove_non_word_characters_nice_error(): 41 | data = [ 42 | ("matt7", "matt"), 43 | ("bill&", "bill"), 44 | ("isabela*", "isabela"), 45 | (None, None), 46 | ] 47 | df = spark.createDataFrame(data, ["name", "expected_name"]).withColumn( 48 | "clean_name", remove_non_word_characters(F.col("name")) 49 | ) 50 | # assert_column_equality(df, "clean_name", "expected_name") 51 | with pytest.raises(ColumnsNotEqualError): 52 | assert_column_equality(df, "clean_name", "expected_name") 53 | 54 | 55 | def describe_dataframe_equality(): 56 | def test_remove_non_word_characters_long(): 57 | source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)] 58 | source_df = spark.createDataFrame(source_data, ["name"]) 59 | actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) 60 | expected_data = [ 61 | ("jo&&se", "jose"), 62 | ("**li**", "li"), 63 | ("#::luisa", "luisa"), 64 | (None, None), 65 | ] 66 | expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) 67 | assert_df_equality(actual_df, expected_df) 68 | 69 | def test_remove_non_word_characters_long_error(): 70 | 
source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] 71 | source_df = spark.createDataFrame(source_data, ["name"]) 72 | actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) 73 | expected_data = [ 74 | ("matt7", "matt"), 75 | ("bill&", "bill"), 76 | ("isabela*", "isabela"), 77 | (None, None), 78 | ] 79 | expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) 80 | # assert_df_equality(actual_df, expected_df) 81 | with pytest.raises(DataFramesNotEqualError): 82 | assert_df_equality(actual_df, expected_df) 83 | 84 | def ignore_row_order(): 85 | df1 = spark.createDataFrame([(1,), (2,), (3,)], ["some_num"]) 86 | df2 = spark.createDataFrame([(2,), (1,), (3,)], ["some_num"]) 87 | # assert_df_equality(df1, df2) 88 | assert_df_equality(df1, df2, ignore_row_order=True) 89 | 90 | def ignore_column_order(): 91 | df1 = spark.createDataFrame([(1, 7), (2, 8), (3, 9)], ["num1", "num2"]) 92 | df2 = spark.createDataFrame([(7, 1), (8, 2), (9, 3)], ["num2", "num1"]) 93 | assert_df_equality(df1, df2, ignore_column_order=True) 94 | 95 | def ignore_nullable_property(): 96 | s1 = StructType([ 97 | StructField("name", StringType(), True), 98 | StructField("age", IntegerType(), True), 99 | ]) 100 | df1 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s1) 101 | s2 = StructType([ 102 | StructField("name", StringType(), True), 103 | StructField("age", IntegerType(), False), 104 | ]) 105 | df2 = spark.createDataFrame([("juan", 7), ("bruna", 8)], s2) 106 | assert_df_equality(df1, df2, ignore_nullable=True) 107 | 108 | def ignore_nullable_property_array(): 109 | s1 = StructType([ 110 | StructField("name", StringType(), True), 111 | StructField("coords", ArrayType(DoubleType(), True), True), 112 | ]) 113 | df1 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s1) 114 | s2 = StructType([ 115 | StructField("name", StringType(), True), 116 | StructField("coords", ArrayType(DoubleType(), False), True), 117 | ]) 118 | df2 = spark.createDataFrame([("juan", [1.42, 3.5]), ("bruna", [2.76, 3.2])], s2) 119 | assert_df_equality(df1, df2, ignore_nullable=True) 120 | 121 | def consider_nan_values_equal(): 122 | data1 = [(float("nan"), "jose"), (2.0, "li")] 123 | df1 = spark.createDataFrame(data1, ["num", "name"]) 124 | data2 = [(float("nan"), "jose"), (2.0, "li")] 125 | df2 = spark.createDataFrame(data2, ["num", "name"]) 126 | assert_df_equality(df1, df2, allow_nan_equality=True) 127 | 128 | def it_prints_underline_message(): 129 | data = [ 130 | (None, None), 131 | ("jose", 42), 132 | ("li", 99), 133 | ("rick", 28), 134 | ("funny", 33), 135 | ] 136 | df1 = spark.createDataFrame(data, ["firstname", "age"]) 137 | data = [ 138 | (None, None), 139 | ("lou", 42), 140 | ("li", 99), 141 | ("rick", 66), 142 | ] 143 | df2 = spark.createDataFrame(data, ["firstname", "age"]) 144 | with pytest.raises(DataFramesNotEqualError): 145 | assert_df_equality(df1, df2, underline_cells=True) 146 | 147 | def it_shows_assert_basic_rows_equality(my_formats): 148 | data = [ 149 | (None, None), 150 | ("jose", 42), 151 | ("li", 99), 152 | ("rick", 28), 153 | ("funny", 33), 154 | ] 155 | df1 = spark.createDataFrame(data, ["firstname", "age"]) 156 | data = [ 157 | (None, None), 158 | ("lou", 42), 159 | ("li", 99), 160 | ("rick", 66), 161 | ] 162 | df2 = spark.createDataFrame(data, ["firstname", "age"]) 163 | # assert_basic_rows_equality(df1.collect(), df2.collect(), formats=my_formats) 164 | with pytest.raises(DataFramesNotEqualError): 165 | 
assert_basic_rows_equality(df1.collect(), df2.collect(), underline_cells=True) 166 | 167 | 168 | def describe_assert_approx_column_equality(): 169 | def test_approx_col_equality_same(): 170 | data = [(1.1, 1.1), (2.2, 2.15), (3.3, 3.37), (None, None)] 171 | df = spark.createDataFrame(data, ["num1", "num2"]) 172 | assert_approx_column_equality(df, "num1", "num2", 0.1) 173 | 174 | def test_approx_col_equality_different(): 175 | data = [(1.1, 1.1), (2.2, 2.15), (3.3, 5.0), (None, None)] 176 | df = spark.createDataFrame(data, ["num1", "num2"]) 177 | with pytest.raises(ColumnsNotEqualError): 178 | assert_approx_column_equality(df, "num1", "num2", 0.1) 179 | 180 | def test_approx_df_equality_same(): 181 | data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] 182 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 183 | data2 = [(1.05, "a"), (2.13, "b"), (3.3, "c"), (None, None)] 184 | df2 = spark.createDataFrame(data2, ["num", "letter"]) 185 | assert_approx_df_equality(df1, df2, 0.1) 186 | 187 | def test_approx_df_equality_different(): 188 | data1 = [(1.1, "a"), (2.2, "b"), (3.3, "c"), (None, None)] 189 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 190 | data2 = [(1.1, "a"), (5.0, "b"), (3.3, "z"), (None, None)] 191 | df2 = spark.createDataFrame(data2, ["num", "letter"]) 192 | # assert_approx_df_equality(df1, df2, 0.1) 193 | with pytest.raises(DataFramesNotEqualError): 194 | assert_approx_df_equality(df1, df2, 0.1) 195 | 196 | 197 | def describe_schema_mismatch_messages(): 198 | def test_schema_mismatch_message(): 199 | data1 = [(1, "a"), (2, "b"), (3, "c"), (None, None)] 200 | df1 = spark.createDataFrame(data1, ["num", "letter"]) 201 | data2 = [(1, 6), (2, 7), (3, 8), (None, None)] 202 | df2 = spark.createDataFrame(data2, ["num", "num2"]) 203 | with pytest.raises(SchemasNotEqualError): 204 | assert_df_equality(df1, df2) 205 | 206 | 207 | def test_remove_non_word_characters_long_error(my_chispa): 208 | source_data = [("matt7",), ("bill&",), ("isabela*",), (None,)] 209 | source_df = spark.createDataFrame(source_data, ["name"]) 210 | actual_df = source_df.withColumn("clean_name", remove_non_word_characters(F.col("name"))) 211 | expected_data = [ 212 | ("matt7", "matt"), 213 | ("bill&", "bill"), 214 | ("isabela*", "isabela"), 215 | (None, None), 216 | ] 217 | expected_df = spark.createDataFrame(expected_data, ["name", "clean_name"]) 218 | # my_chispa.assert_df_equality(actual_df, expected_df) 219 | with pytest.raises(DataFramesNotEqualError): 220 | assert_df_equality(actual_df, expected_df) 221 | -------------------------------------------------------------------------------- /tests/test_row_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pyspark.sql import Row 4 | 5 | from chispa.row_comparer import are_rows_approx_equal, are_rows_equal, are_rows_equal_enhanced 6 | 7 | 8 | def test_are_rows_equal(): 9 | assert are_rows_equal(Row("bob", "jose"), Row("li", "li")) is False 10 | assert are_rows_equal(Row("luisa", "laura"), Row("luisa", "laura")) is True 11 | assert are_rows_equal(Row(None, None), Row(None, None)) is True 12 | 13 | 14 | def test_are_rows_equal_enhanced(): 15 | assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), False) is False 16 | assert are_rows_equal_enhanced(Row(n1="luisa", n2="laura"), Row(n1="luisa", n2="laura"), False) is True 17 | assert are_rows_equal_enhanced(Row(n1=None, n2=None), Row(n1=None, n2=None), False) is True 18 | 19 | 
assert are_rows_equal_enhanced(Row(n1="bob", n2="jose"), Row(n1="li", n2="li"), True) is False 20 | assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1=float("nan"), n2="jose"), True) is True 21 | assert are_rows_equal_enhanced(Row(n1=float("nan"), n2="jose"), Row(n1="hi", n2="jose"), True) is False 22 | 23 | 24 | def test_are_rows_approx_equal(): 25 | assert are_rows_approx_equal(Row(num=1.1, first_name="li"), Row(num=1.05, first_name="li"), 0.1) is True 26 | assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.0, first_name="laura"), 0.1) is True 27 | assert are_rows_approx_equal(Row(num=5.0, first_name="laura"), Row(num=5.9, first_name="laura"), 0.1) is False 28 | assert are_rows_approx_equal(Row(num=None, first_name=None), Row(num=None, first_name=None), 0.1) is True 29 | -------------------------------------------------------------------------------- /tests/test_rows_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from chispa import DataFramesNotEqualError, assert_basic_rows_equality 6 | 7 | from .spark import spark 8 | 9 | 10 | def describe_assert_basic_rows_equality(): 11 | def it_throws_with_row_mismatches(): 12 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 13 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 14 | data2 = [("bob", "jose"), ("li", "li"), ("luisa", "laura")] 15 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 16 | with pytest.raises(DataFramesNotEqualError): 17 | assert_basic_rows_equality(df1.collect(), df2.collect()) 18 | 19 | def it_throws_when_rows_have_different_lengths(): 20 | data1 = [(1, "jose"), (2, "li"), (3, "laura"), (4, "bill")] 21 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 22 | data2 = [(1, "jose"), (2, "li"), (3, "laura")] 23 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 24 | with pytest.raises(DataFramesNotEqualError): 25 | assert_basic_rows_equality(df1.collect(), df2.collect()) 26 | 27 | def it_works_when_rows_are_the_same(): 28 | data1 = [(1, "jose"), (2, "li"), (3, "laura")] 29 | df1 = spark.createDataFrame(data1, ["num", "expected_name"]) 30 | data2 = [(1, "jose"), (2, "li"), (3, "laura")] 31 | df2 = spark.createDataFrame(data2, ["name", "expected_name"]) 32 | assert_basic_rows_equality(df1.collect(), df2.collect()) 33 | -------------------------------------------------------------------------------- /tests/test_schema_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | from pyspark.sql.types import ArrayType, DecimalType, DoubleType, IntegerType, StringType, StructField, StructType 5 | 6 | from chispa.schema_comparer import ( 7 | SchemasNotEqualError, 8 | are_schemas_equal_ignore_nullable, 9 | are_structfields_equal, 10 | assert_schema_equality, 11 | assert_schema_equality_ignore_nullable, 12 | create_schema_comparison_tree, 13 | ) 14 | 15 | 16 | def describe_assert_schema_equality(): 17 | def it_does_nothing_when_equal(): 18 | s1 = StructType([ 19 | StructField("name", StringType(), True), 20 | StructField("age", IntegerType(), True), 21 | ]) 22 | s2 = StructType([ 23 | StructField("name", StringType(), True), 24 | StructField("age", IntegerType(), True), 25 | ]) 26 | assert_schema_equality(s1, s2) 27 | 28 | def it_throws_when_column_names_differ(): 29 | s1 = StructType([ 30 | StructField("HAHA", StringType(), True), 31 
| StructField("age", IntegerType(), True), 32 | ]) 33 | s2 = StructType([ 34 | StructField("name", StringType(), True), 35 | StructField("age", IntegerType(), True), 36 | ]) 37 | with pytest.raises(SchemasNotEqualError): 38 | assert_schema_equality(s1, s2) 39 | 40 | def it_throws_when_schema_lengths_differ(): 41 | s1 = StructType([ 42 | StructField("name", StringType(), True), 43 | StructField("age", IntegerType(), True), 44 | ]) 45 | s2 = StructType([ 46 | StructField("name", StringType(), True), 47 | StructField("age", IntegerType(), True), 48 | StructField("fav_number", IntegerType(), True), 49 | ]) 50 | with pytest.raises(SchemasNotEqualError): 51 | assert_schema_equality(s1, s2) 52 | 53 | def it_throws_when_data_types_differ(): 54 | s1 = StructType([ 55 | StructField("name", StringType(), True), 56 | StructField("age", IntegerType(), True), 57 | StructField("height", DecimalType(10, 2), True), 58 | ]) 59 | s2 = StructType([ 60 | StructField("name", StringType(), True), 61 | StructField("age", IntegerType(), True), 62 | StructField("height", DecimalType(10, 3), True), 63 | ]) 64 | with pytest.raises(SchemasNotEqualError): 65 | assert_schema_equality(s1, s2) 66 | 67 | def it_throws_when_data_types_differ_with_enabled_ignore_nullability(): 68 | s1 = StructType([ 69 | StructField("name", StringType(), True), 70 | StructField("age", IntegerType(), True), 71 | StructField("height", DecimalType(10, 2), True), 72 | ]) 73 | s2 = StructType([ 74 | StructField("name", StringType(), True), 75 | StructField("age", IntegerType(), True), 76 | StructField("height", DecimalType(10, 3), True), 77 | ]) 78 | with pytest.raises(SchemasNotEqualError): 79 | assert_schema_equality(s1, s2, ignore_nullable=True) 80 | 81 | def it_throws_when_data_types_differ_with_enabled_ignore_metadata(): 82 | s1 = StructType([ 83 | StructField("name", StringType(), True), 84 | StructField("age", IntegerType(), True), 85 | StructField("height", DecimalType(10, 2), True), 86 | ]) 87 | s2 = StructType([ 88 | StructField("name", StringType(), True), 89 | StructField("age", IntegerType(), True), 90 | StructField("height", DecimalType(10, 3), True), 91 | ]) 92 | with pytest.raises(SchemasNotEqualError): 93 | assert_schema_equality(s1, s2, ignore_metadata=True) 94 | 95 | 96 | def describe_tree_string(): 97 | def it_prints_correctly_for_wide_schemas(): 98 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas.txt") as f: 99 | expected = f.read() 100 | 101 | s1 = StructType([ 102 | StructField("name", StringType(), True), 103 | StructField("age", IntegerType(), True), 104 | StructField("fav_number", IntegerType(), True), 105 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 106 | StructField( 107 | "fav_colors", 108 | StructType([ 109 | StructField("red", IntegerType(), True), 110 | StructField("green", IntegerType(), True), 111 | StructField("blue", IntegerType(), True), 112 | ]), 113 | ), 114 | ]) 115 | 116 | s2 = StructType([ 117 | StructField("name", StringType(), True), 118 | StructField("age", IntegerType(), True), 119 | StructField("fav_number", IntegerType(), True), 120 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 121 | StructField( 122 | "fav_colors", 123 | StructType([ 124 | StructField("orange", IntegerType(), True), 125 | StructField("green", IntegerType(), True), 126 | StructField("yellow", IntegerType(), True), 127 | ]), 128 | ), 129 | ]) 130 | 131 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=False, ignore_metadata=False) 132 | 133 | assert 
repr(result) + "\n" == expected 134 | 135 | def it_prints_correctly_for_wide_schemas_multiple_nested_structs(): 136 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas_multiple_nested_structs.txt") as f: 137 | expected = f.read() 138 | 139 | s1 = StructType([ 140 | StructField("name", StringType(), True), 141 | StructField( 142 | "fav_genres", 143 | StructType([ 144 | StructField( 145 | "rock", 146 | StructType([ 147 | StructField("metal", IntegerType(), True), 148 | StructField("punk", IntegerType(), True), 149 | ]), 150 | True, 151 | ), 152 | StructField( 153 | "electronic", 154 | StructType([ 155 | StructField("house", IntegerType(), True), 156 | StructField("dubstep", IntegerType(), True), 157 | ]), 158 | True, 159 | ), 160 | ]), 161 | ), 162 | ]) 163 | 164 | s2 = StructType([ 165 | StructField("name", StringType(), True), 166 | StructField( 167 | "fav_genres", 168 | StructType([ 169 | StructField( 170 | "rock", 171 | StructType([ 172 | StructField("metal", IntegerType(), True), 173 | StructField("classic", IntegerType(), True), 174 | ]), 175 | True, 176 | ), 177 | StructField( 178 | "electronic", 179 | StructType([ 180 | StructField("house", IntegerType(), True), 181 | StructField("dubstep", IntegerType(), True), 182 | ]), 183 | True, 184 | ), 185 | StructField( 186 | "pop", 187 | StructType([ 188 | StructField("pop", IntegerType(), True), 189 | ]), 190 | True, 191 | ), 192 | ]), 193 | ), 194 | ]) 195 | 196 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=False, ignore_metadata=False) 197 | assert repr(result) + "\n" == expected 198 | 199 | def it_prints_correctly_for_wide_schemas_ignore_nullable(): 200 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas_ignore_nullable.txt") as f: 201 | expected = f.read() 202 | 203 | s1 = StructType([ 204 | StructField("name", StringType(), True), 205 | StructField("age", IntegerType(), True), 206 | StructField("fav_number", IntegerType(), True), 207 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 208 | StructField( 209 | "fav_colors", 210 | StructType([ 211 | StructField("red", IntegerType(), True), 212 | StructField("green", IntegerType(), True), 213 | StructField("blue", IntegerType(), True), 214 | ]), 215 | ), 216 | ]) 217 | 218 | s2 = StructType([ 219 | StructField("name", StringType(), True), 220 | StructField("age", IntegerType(), False), 221 | StructField("fav_number", IntegerType(), True), 222 | StructField("fav_numbers", ArrayType(IntegerType(), True), False), 223 | StructField( 224 | "fav_colors", 225 | StructType([ 226 | StructField("orange", IntegerType(), True), 227 | StructField("green", IntegerType(), False), 228 | StructField("yellow", IntegerType(), True), 229 | ]), 230 | ), 231 | ]) 232 | 233 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=True, ignore_metadata=False) 234 | 235 | assert repr(result) + "\n" == expected 236 | 237 | def it_prints_correctly_for_wide_schemas_different_lengths(): 238 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas_different_lengths.txt") as f: 239 | expected = f.read() 240 | 241 | s1 = StructType([ 242 | StructField("name", StringType(), True), 243 | StructField("age", IntegerType(), True), 244 | StructField("fav_number", IntegerType(), True), 245 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 246 | StructField( 247 | "fav_colors", 248 | StructType([ 249 | StructField("red", IntegerType(), True), 250 | StructField("green", IntegerType(), True), 251 | 
StructField("blue", IntegerType(), True), 252 | ]), 253 | ), 254 | ]) 255 | 256 | s2 = StructType([ 257 | StructField("name", StringType(), True), 258 | StructField("age", IntegerType(), True), 259 | StructField("fav_number", IntegerType(), True), 260 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 261 | StructField( 262 | "fav_colors", 263 | StructType([ 264 | StructField("orange", IntegerType(), True), 265 | StructField("green", IntegerType(), True), 266 | StructField("yellow", IntegerType(), True), 267 | StructField("purple", IntegerType(), True), 268 | ]), 269 | ), 270 | StructField("phone_number", StringType(), True), 271 | ]) 272 | 273 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=False, ignore_metadata=False) 274 | assert repr(result) + "\n" == expected 275 | 276 | def it_prints_correctly_for_wide_schemas_ignore_metadata(): 277 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas_ignore_metadata.txt") as f: 278 | expected = f.read() 279 | 280 | s1 = StructType([ 281 | StructField("name", StringType(), True, {"foo": "bar"}), 282 | StructField("age", IntegerType(), True), 283 | StructField("fav_number", IntegerType(), True), 284 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 285 | StructField( 286 | "fav_colors", 287 | StructType([ 288 | StructField("red", IntegerType(), True), 289 | StructField("green", IntegerType(), True), 290 | StructField("blue", IntegerType(), True), 291 | ]), 292 | ), 293 | ]) 294 | 295 | s2 = StructType([ 296 | StructField("name", StringType(), True, {"foo": "baz"}), 297 | StructField("age", IntegerType(), True), 298 | StructField("fav_number", IntegerType(), True), 299 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 300 | StructField( 301 | "fav_colors", 302 | StructType([ 303 | StructField("orange", IntegerType(), True), 304 | StructField("green", IntegerType(), True), 305 | StructField("yellow", IntegerType(), True), 306 | ]), 307 | ), 308 | ]) 309 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=False, ignore_metadata=True) 310 | assert repr(result) + "\n" == expected 311 | 312 | def it_prints_correctly_for_wide_schemas_with_metadata(): 313 | with open("tests/data/tree_string/it_prints_correctly_for_wide_schemas_with_metadata.txt") as f: 314 | expected = f.read() 315 | 316 | s1 = StructType([ 317 | StructField("name", StringType(), True, {"foo": "bar"}), 318 | StructField("age", IntegerType(), True), 319 | StructField("fav_number", IntegerType(), True), 320 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 321 | StructField( 322 | "fav_colors", 323 | StructType([ 324 | StructField("red", IntegerType(), True), 325 | StructField("green", IntegerType(), True), 326 | StructField("blue", IntegerType(), True), 327 | ]), 328 | ), 329 | ]) 330 | 331 | s2 = StructType([ 332 | StructField("name", StringType(), True, {"foo": "baz"}), 333 | StructField("age", IntegerType(), True), 334 | StructField("fav_number", IntegerType(), True), 335 | StructField("fav_numbers", ArrayType(IntegerType(), True), True), 336 | StructField( 337 | "fav_colors", 338 | StructType([ 339 | StructField("orange", IntegerType(), True), 340 | StructField("green", IntegerType(), True), 341 | StructField("yellow", IntegerType(), True), 342 | ]), 343 | ), 344 | ]) 345 | 346 | result = create_schema_comparison_tree(s1, s2, ignore_nullable=False, ignore_metadata=False) 347 | assert repr(result) + "\n" == expected 348 | 349 | 350 | def 
describe_assert_schema_equality_ignore_nullable(): 351 | def it_has_good_error_messages_for_different_sized_schemas(): 352 | s1 = StructType([ 353 | StructField("name", StringType(), True), 354 | StructField("age", IntegerType(), True), 355 | ]) 356 | s2 = StructType([ 357 | StructField("name", StringType(), False), 358 | StructField("age", IntegerType(), True), 359 | StructField("something", IntegerType(), True), 360 | StructField("else", IntegerType(), True), 361 | ]) 362 | with pytest.raises(SchemasNotEqualError): 363 | assert_schema_equality_ignore_nullable(s1, s2) 364 | 365 | def it_does_nothing_when_equal(): 366 | s1 = StructType([ 367 | StructField("name", StringType(), True), 368 | StructField("age", IntegerType(), True), 369 | ]) 370 | s2 = StructType([ 371 | StructField("name", StringType(), True), 372 | StructField("age", IntegerType(), True), 373 | ]) 374 | assert_schema_equality_ignore_nullable(s1, s2) 375 | 376 | def it_does_nothing_when_only_nullable_flag_is_different(): 377 | s1 = StructType([ 378 | StructField("name", StringType(), True), 379 | StructField("age", IntegerType(), True), 380 | ]) 381 | s2 = StructType([ 382 | StructField("name", StringType(), True), 383 | StructField("age", IntegerType(), False), 384 | ]) 385 | assert_schema_equality_ignore_nullable(s1, s2) 386 | 387 | 388 | def describe_are_schemas_equal_ignore_nullable(): 389 | def it_returns_true_when_only_nullable_flag_is_different(): 390 | s1 = StructType([ 391 | StructField("name", StringType(), True), 392 | StructField("age", IntegerType(), True), 393 | StructField("coords", ArrayType(DoubleType(), True), True), 394 | ]) 395 | s2 = StructType([ 396 | StructField("name", StringType(), True), 397 | StructField("age", IntegerType(), False), 398 | StructField("coords", ArrayType(DoubleType(), True), False), 399 | ]) 400 | assert are_schemas_equal_ignore_nullable(s1, s2) is True 401 | 402 | def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): 403 | s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) 404 | s2 = StructType([StructField("coords", ArrayType(DoubleType(), False), True)]) 405 | assert are_schemas_equal_ignore_nullable(s1, s2) is True 406 | 407 | def it_returns_true_when_only_nullable_flag_is_different_within_nested_array_element(): 408 | s1 = StructType([StructField("coords", ArrayType(ArrayType(DoubleType(), True), True), True)]) 409 | s2 = StructType([StructField("coords", ArrayType(ArrayType(DoubleType(), False), True), True)]) 410 | assert are_schemas_equal_ignore_nullable(s1, s2) is True 411 | 412 | def it_returns_false_when_the_element_type_is_different_within_array(): 413 | s1 = StructType([StructField("coords", ArrayType(DoubleType(), True), True)]) 414 | s2 = StructType([StructField("coords", ArrayType(IntegerType(), True), True)]) 415 | assert are_schemas_equal_ignore_nullable(s1, s2) is False 416 | 417 | def it_returns_false_when_column_names_differ(): 418 | s1 = StructType([ 419 | StructField("blah", StringType(), True), 420 | StructField("age", IntegerType(), True), 421 | ]) 422 | s2 = StructType([ 423 | StructField("name", StringType(), True), 424 | StructField("age", IntegerType(), False), 425 | ]) 426 | assert are_schemas_equal_ignore_nullable(s1, s2) is False 427 | 428 | def it_returns_false_when_columns_have_different_order(): 429 | s1 = StructType([ 430 | StructField("blah", StringType(), True), 431 | StructField("age", IntegerType(), True), 432 | ]) 433 | s2 = StructType([ 434 | StructField("age", IntegerType(), False), 435 
| StructField("blah", StringType(), True), 436 | ]) 437 | assert are_schemas_equal_ignore_nullable(s1, s2) is False 438 | 439 | 440 | def describe_are_structfields_equal(): 441 | def it_returns_true_when_only_nullable_flag_is_different_within_array_element(): 442 | s1 = StructField("coords", ArrayType(DoubleType(), True), True) 443 | s2 = StructField("coords", ArrayType(DoubleType(), False), True) 444 | assert are_structfields_equal(s1, s2, True) is True 445 | 446 | def it_returns_false_when_the_element_type_is_different_within_array(): 447 | s1 = StructField("coords", ArrayType(DoubleType(), True), True) 448 | s2 = StructField("coords", ArrayType(IntegerType(), True), True) 449 | assert are_structfields_equal(s1, s2, True) is False 450 | 451 | def it_returns_true_when_the_element_type_is_same_within_struct(): 452 | s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) 453 | s2 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) 454 | assert are_structfields_equal(s1, s2, True) is True 455 | 456 | def it_returns_false_when_the_element_type_is_different_within_struct(): 457 | s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) 458 | s2 = StructField("coords", StructType([StructField("hello", IntegerType(), True)]), True) 459 | assert are_structfields_equal(s1, s2, True) is False 460 | 461 | def it_returns_false_when_the_element_name_is_different_within_struct(): 462 | s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) 463 | s2 = StructField("coords", StructType([StructField("world", DoubleType(), True)]), True) 464 | assert are_structfields_equal(s1, s2, True) is False 465 | 466 | def it_returns_true_when_different_nullability_within_struct(): 467 | s1 = StructField("coords", StructType([StructField("hello", DoubleType(), True)]), True) 468 | s2 = StructField("coords", StructType([StructField("hello", DoubleType(), False)]), True) 469 | assert are_structfields_equal(s1, s2, True) is True 470 | 471 | def it_returns_false_when_metadata_differs(): 472 | s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) 473 | s2 = StructField("coords", StringType(), True, {"hi": "no"}) 474 | assert are_structfields_equal(s1, s2, ignore_nullability=True, ignore_metadata=False) is False 475 | 476 | def it_allows_metadata_to_be_ignored(): 477 | s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) 478 | s2 = StructField("coords", StringType(), True, {"hi": "no"}) 479 | assert are_structfields_equal(s1, s2, ignore_nullability=False, ignore_metadata=True) is True 480 | 481 | def it_allows_nullability_and_metadata_to_be_ignored(): 482 | s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) 483 | s2 = StructField("coords", StringType(), False, {"hi": "no"}) 484 | assert are_structfields_equal(s1, s2, ignore_nullability=True, ignore_metadata=True) is True 485 | 486 | def it_returns_true_when_metadata_is_the_same(): 487 | s1 = StructField("coords", StringType(), True, {"hi": "whatever"}) 488 | s2 = StructField("coords", StringType(), True, {"hi": "whatever"}) 489 | assert are_structfields_equal(s1, s2, ignore_nullability=True, ignore_metadata=False) is True 490 | -------------------------------------------------------------------------------- /tests/test_structfield_comparer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pyspark.sql.types import 
ArrayType, DoubleType, IntegerType, StructField, StructType 4 | 5 | from chispa.structfield_comparer import are_structfields_equal 6 | 7 | 8 | def describe_are_structfields_equal(): 9 | def it_returns_true_when_structfields_are_the_same(): 10 | sf1 = StructField("hi", IntegerType(), True) 11 | sf2 = StructField("hi", IntegerType(), True) 12 | assert are_structfields_equal(sf1, sf2) is True 13 | 14 | def it_returns_false_when_column_names_are_different(): 15 | sf1 = StructField("hello", IntegerType(), True) 16 | sf2 = StructField("hi", IntegerType(), True) 17 | assert are_structfields_equal(sf1, sf2) is False 18 | 19 | def it_returns_false_when_nullable_property_is_different(): 20 | sf1 = StructField("hi", IntegerType(), False) 21 | sf2 = StructField("hi", IntegerType(), True) 22 | assert are_structfields_equal(sf1, sf2) is False 23 | 24 | def it_can_perform_nullability_insensitive_comparisons(): 25 | sf1 = StructField("hi", IntegerType(), False) 26 | sf2 = StructField("hi", IntegerType(), True) 27 | assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is True 28 | 29 | def it_returns_true_when_nested_types_are_the_same(): 30 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 31 | sf2 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 32 | assert are_structfields_equal(sf1, sf2) is True 33 | 34 | def it_returns_false_when_nested_names_are_different(): 35 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 36 | sf2 = StructField("hi", StructType([StructField("developer", IntegerType(), False)]), False) 37 | assert are_structfields_equal(sf1, sf2) is False 38 | 39 | def it_returns_false_when_nested_types_are_different(): 40 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 41 | sf2 = StructField("hi", StructType([StructField("world", DoubleType(), False)]), False) 42 | assert are_structfields_equal(sf1, sf2) is False 43 | 44 | def it_returns_false_when_nested_types_have_different_nullability(): 45 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 46 | sf2 = StructField("hi", StructType([StructField("world", IntegerType(), True)]), False) 47 | assert are_structfields_equal(sf1, sf2) is False 48 | 49 | def it_returns_false_when_nested_types_are_different_with_ignore_nullable_true(): 50 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 51 | sf2 = StructField("hi", StructType([StructField("developer", IntegerType(), False)]), False) 52 | assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is False 53 | 54 | def it_returns_true_when_nested_types_have_different_nullability_with_ignore_null_true(): 55 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 56 | sf2 = StructField("hi", StructType([StructField("world", IntegerType(), True)]), False) 57 | assert are_structfields_equal(sf1, sf2, ignore_nullability=True) is True 58 | 59 | def it_returns_true_when_inner_metadata_is_different_but_ignored(): 60 | sf1 = StructField("hi", StructType([StructField("world", IntegerType(), False)]), False) 61 | sf2 = StructField("hi", StructType([StructField("world", IntegerType(), False, {"a": "b"})]), False) 62 | assert are_structfields_equal(sf1, sf2, ignore_metadata=True) is True 63 | 64 | def it_returns_true_when_inner_array_metadata_is_different_but_ignored(): 65 | sf1 = StructField( 66 | "hi", 67 | ArrayType( 68 | 
StructType([ 69 | StructField("world", IntegerType(), True, {"comment": "Comment"}), 70 | ]), 71 | True, 72 | ), 73 | True, 74 | ) 75 | sf2 = StructField( 76 | "hi", 77 | ArrayType( 78 | StructType([ 79 | StructField("world", IntegerType(), True, {"comment": "Some other comment"}), 80 | ]), 81 | True, 82 | ), 83 | True, 84 | ) 85 | assert are_structfields_equal(sf1, sf2, ignore_metadata=True) is True 86 | --------------------------------------------------------------------------------
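Taken together, these test modules document chispa's public comparison surface end to end. As a closing illustration, here is a minimal self-contained pytest module using only assertions and flags exercised above; the local-mode SparkSession mirrors the pattern in tests/test_readme_examples.py, and the data values are copied from the tests. This is a usage sketch, not part of the library.

from pyspark.sql import SparkSession

from chispa import assert_approx_df_equality, assert_df_equality

spark = SparkSession.builder.master("local").appName("chispa").getOrCreate()


def test_exact_equality_ignoring_row_order():
    df1 = spark.createDataFrame([(1,), (2,), (3,)], ["some_num"])
    df2 = spark.createDataFrame([(2,), (1,), (3,)], ["some_num"])
    # With ignore_row_order=True, differing row order does not fail the assertion.
    assert_df_equality(df1, df2, ignore_row_order=True)


def test_approximate_float_equality():
    df1 = spark.createDataFrame([(1.1, "a"), (2.2, "b")], ["num", "letter"])
    df2 = spark.createDataFrame([(1.05, "a"), (2.13, "b")], ["num", "letter"])
    # Floats may differ by up to the absolute precision passed as the third argument.
    assert_approx_df_equality(df1, df2, 0.1)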