├── .coveragerc ├── .flake8 ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── .travis.yml ├── AUTHORS.rst ├── CHANGELOG.rst ├── LICENSE.txt ├── README.rst ├── docs ├── Makefile ├── _static │ └── .gitignore ├── authors.rst ├── changelog.rst ├── conf.py ├── index.rst └── license.rst ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── pywrangler │ ├── __init__.py │ ├── base.py │ ├── benchmark.py │ ├── dask │ ├── __init__.py │ ├── base.py │ └── benchmark.py │ ├── exceptions.py │ ├── pandas │ ├── __init__.py │ ├── base.py │ ├── benchmark.py │ ├── util.py │ └── wranglers │ │ ├── __init__.py │ │ └── interval_identifier.py │ ├── pyspark │ ├── __init__.py │ ├── base.py │ ├── benchmark.py │ ├── pipeline.py │ ├── testing.py │ ├── types.py │ ├── util.py │ └── wranglers │ │ ├── __init__.py │ │ └── interval_identifier.py │ ├── util │ ├── __init__.py │ ├── _pprint.py │ ├── dependencies.py │ ├── helper.py │ ├── sanitizer.py │ ├── testing │ │ ├── __init__.py │ │ ├── datatestcase.py │ │ ├── mutants.py │ │ ├── plainframe.py │ │ └── util.py │ └── types.py │ └── wranglers.py ├── tests ├── __init__.py ├── conftest.py ├── dask │ ├── __init__.py │ ├── test_base.py │ └── test_benchmark.py ├── pandas │ ├── __init__.py │ ├── test_base.py │ ├── test_benchmark.py │ ├── test_util.py │ └── wranglers │ │ ├── __init__.py │ │ └── test_interval_identifier.py ├── pyspark │ ├── __init__.py │ ├── conftest.py │ ├── test_base.py │ ├── test_benchmark.py │ ├── test_environment.py │ ├── test_pipeline.py │ ├── test_testing.py │ ├── test_util.py │ └── wranglers │ │ ├── __init__.py │ │ └── test_interval_identifier.py ├── test_base.py ├── test_benchmark.py ├── test_data │ ├── __init__.py │ └── interval_identifier.py ├── test_wranglers.py └── util │ ├── __init__.py │ ├── test_dependencies.py │ ├── test_helper.py │ ├── test_pprint.py │ ├── test_sanitizer.py │ └── testing │ ├── __init__.py │ ├── test_datatestcase.py │ ├── test_mutants.py │ ├── test_plainframe.py │ └── test_util.py ├── tox.ini 
└── travisci ├── code_coverage.sh ├── fix_paths.py ├── java_install.sh └── tox_invocation.sh /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = pywrangler 5 | omit = 6 | */PyScaffold*.egg/* 7 | */pyscaffold/contrib/* 8 | 9 | [paths] 10 | source = 11 | #src/ 12 | */site-packages/ 13 | 14 | [report] 15 | # Regexes for lines to exclude from consideration 16 | exclude_lines = 17 | # Have to re-enable the standard pragma 18 | pragma: no cover 19 | 20 | # Don't complain about missing debug-only code: 21 | def __repr__ 22 | if self\.debug 23 | 24 | # Don't complain if tests don't hit defensive assertion code: 25 | raise AssertionError 26 | raise NotImplementedError 27 | 28 | # Don't complain if non-runnable code isn't run: 29 | if 0: 30 | if __name__ == .__main__.: 31 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | 16 | # Project files 17 | .ropeproject 18 | .project 19 | .pydevproject 20 | .settings 21 | .idea 22 | tags 23 | 24 | # Package files 25 | *.egg 26 | *.eggs/ 27 | .installed.cfg 28 | *.egg-info 29 | 30 | # Unittest and coverage 31 | htmlcov/* 32 | .coverage 33 | .tox 34 | junit.xml 35 | coverage.xml 36 | .pytest_cache/ 37 | 38 | # Build and docs folder/files 39 | build/* 40 | dist/* 41 | sdist/* 42 | docs/api/* 43 | docs/_rst/* 44 | docs/_build/* 45 | cover/* 46 | MANIFEST 47 | 48 | # Per-project virtualenvs 49 | 
.venv*/ 50 | /notebooks/ 51 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | line_length=79 3 | indent=' ' 4 | skip=.tox,.venv,build,dist 5 | known_standard_library=setuptools,pkg_resources 6 | known_test=pytest 7 | known_first_party=pywrangler 8 | sections=FUTURE,STDLIB,COMPAT,TEST,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 9 | default_section=THIRDPARTY 10 | multi_line_output=3 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^docs/conf.py' 2 | 3 | repos: 4 | - repo: git://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.1.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=no'] 19 | - id: flake8 20 | 21 | - repo: https://github.com/pre-commit/mirrors-isort 22 | rev: v4.3.4 23 | hooks: 24 | - id: isort 25 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # sudo: false 2 | 3 | language: python 4 | 5 | python: 6 | # - '3.5' 7 | # - '3.6' 8 | - '3.7' 9 | 10 | env: 11 | - ENV_STRING=master 12 | # - ENV_STRING=pandas0.24.1 13 | # - ENV_STRING=pandas0.24.0 14 | # 15 | # - ENV_STRING=pandas0.23.4 16 | # - ENV_STRING=pandas0.23.3 17 | # - ENV_STRING=pandas0.23.2 18 | # - ENV_STRING=pandas0.23.1 19 | # - ENV_STRING=pandas0.23.0 20 | # 21 | # - ENV_STRING=pandas0.22.0 22 | # 23 | # - ENV_STRING=pandas0.21.1 24 | # - ENV_STRING=pandas0.21.0 25 | # 26 | # - 
ENV_STRING=pandas0.20.3 27 | # - ENV_STRING=pandas0.20.2 28 | # - ENV_STRING=pandas0.20.1 29 | # - ENV_STRING=pandas0.20.0 30 | # 31 | # - ENV_STRING=pandas0.19.2 32 | # - ENV_STRING=pandas0.19.1 33 | # - ENV_STRING=pandas0.19.0 34 | # 35 | # - ENV_STRING=pyspark2.4.0 36 | # - ENV_STRING=pyspark2.3.1 37 | 38 | # - ENV_STRING=dask1.1.5 39 | 40 | 41 | # Remove python/pandas version interactions which do not have wheels on pypi 42 | matrix: 43 | exclude: 44 | - python: '3.7' 45 | env: ENV_STRING=pandas0.22.0 46 | - python: '3.7' 47 | env: ENV_STRING=pandas0.21.1 48 | - python: '3.7' 49 | env: ENV_STRING=pandas0.21.0 50 | - python: '3.7' 51 | env: ENV_STRING=pandas0.20.3 52 | - python: '3.7' 53 | env: ENV_STRING=pandas0.20.2 54 | - python: '3.7' 55 | env: ENV_STRING=pandas0.20.1 56 | - python: '3.7' 57 | env: ENV_STRING=pandas0.20.0 58 | - python: '3.7' 59 | env: ENV_STRING=pandas0.19.2 60 | - python: '3.7' 61 | env: ENV_STRING=pandas0.19.1 62 | - python: '3.7' 63 | env: ENV_STRING=pandas0.19.0 64 | - python: '3.6' 65 | env: ENV_STRING=pandas0.19.0 66 | 67 | dist: xenial 68 | 69 | before_install: 70 | - source travisci/java_install.sh 71 | 72 | install: 73 | - travis_retry pip install --upgrade pip 74 | - travis_retry pip install --upgrade setuptools 75 | - travis_retry pip install codecov flake8 tox 76 | 77 | script: 78 | - source travisci/tox_invocation.sh 79 | 80 | after_success: 81 | - source travisci/code_coverage.sh 82 | 83 | cache: pip 84 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * mansenfranzen 6 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1.0 6 | ============= 7 | 8 | This is 
the initial release of pywrangler. 9 | 10 | - Enable raw, valid and enumerated return type for ``IntervalIdentifier`` (`#23 `_). 11 | - Enable variable sequence lengths for ``IntervalIdentifier`` (`#23 `_). 12 | - Add ``DataTestCase`` and ``TestCollection`` as standards for data centric test cases (`#23 `_). 13 | - Add computation engine independent data abstraction ``PlainFrame`` (`#23 `_). 14 | - Add ``VectorizedCumSum`` pyspark implementation for ``IntervalIdentifier`` wrangler (`#7 `_). 15 | - Add benchmark utilities for pandas, spark and dask wranglers (`#5 `_). 16 | - Add sequential ``NaiveIterator`` and vectorized ``VectorizedCumSum`` pandas implementations for ``IntervalIdentifier`` wrangler (`#2 `_). 17 | - Add ``PandasWrangler`` (`#2 `_). 18 | - Add ``IntervalIdentifier`` wrangler interface (`#2 `_). 19 | - Add ``BaseWrangler`` class defining wrangler interface (`#1 `_). 20 | - Enable ``pandas`` and ``pyspark`` testing on TravisCI (`#1 `_). 21 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 mansenfranzen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | pywrangler 3 | ========== 4 | 5 | .. image:: https://travis-ci.org/mansenfranzen/pywrangler.svg?branch=master 6 | :target: https://travis-ci.org/mansenfranzen/pywrangler 7 | 8 | .. image:: https://codecov.io/gh/mansenfranzen/pywrangler/branch/master/graph/badge.svg 9 | :target: https://codecov.io/gh/mansenfranzen/pywrangler 10 | 11 | .. image:: https://badge.fury.io/gh/mansenfranzen%2Fpywrangler.svg 12 | :target: https://badge.fury.io/gh/mansenfranzen%2Fpywrangler 13 | 14 | .. image:: https://img.shields.io/badge/code%20style-flake8-orange.svg 15 | :target: https://www.python.org/dev/peps/pep-0008/ 16 | 17 | .. image:: https://img.shields.io/badge/python-3.5+-blue.svg 18 | :target: https://www.python.org/downloads/release/python-370/ 19 | 20 | .. image:: https://img.shields.io/badge/License-MIT-blue.svg 21 | :target: https://lbesson.mit-license.org/ 22 | 23 | .. image:: https://badges.frapsoft.com/os/v1/open-source.png?v=103 24 | :target: https://github.com/ellerbrock/open-source-badges/ 25 | 26 | The pydata ecosystem provides a rich set of tools (e.g pandas, dask and pyspark) 27 | to handle most data wrangling tasks with ease. When dealing with data on a 28 | daily basis, however, one often encounters **problems which go beyond the 29 | common dataframe API usage**. 
They typically require a combination of multiple 30 | transformations and aggregations in order to achieve the desired outcome. For 31 | example, extracting intervals with given start and end values from raw time 32 | series is out of scope for native dataframe functionality. 33 | 34 | **pywrangler** accomplishes such requirements with care while exposing so 35 | called *data wranglers*. A data wrangler serves a specific use case just like 36 | the one mentioned above. It takes one or more input dataframes, applies a 37 | computation which is usually built on top of existing dataframe API, and 38 | returns one or more output dataframes. 39 | 40 | Why should I use pywrangler? 41 | ============================ 42 | 43 | - you deal with data wrangling **problems** which are **beyond common dataframe API usage** 44 | - you are looking for a **framework with consistent API** to handle your data wrangling complexities 45 | - you need implementations tailored for **small data (pandas)** and **big data (dask and pyspark)** libraries 46 | 47 | You want **well tested, documented and benchmarked solutions**? If that's the case, pywrangler might be what you're looking for. 48 | 49 | Features 50 | ======== 51 | - supports pandas, dask and pyspark as computation engines 52 | - exposes consistent scikit-learn like API 53 | - provides backwards compatibility for pandas versions from 0.19.2 upwards 54 | - emphasises extensive tests and documentation 55 | - includes type annotations 56 | 57 | Thanks 58 | ====== 59 | We like to thank the pydata stack including `numpy `_, `pandas `_, `sklearn `_, `scipy `_, `dask `_ and `pyspark `_ and many more (and the open source community in general). 60 | 61 | Notes 62 | ===== 63 | 64 | - This project is currently under active development and has no release yet. 65 | - This project has been set up using PyScaffold 3.1. For details and usage information on PyScaffold see https://pyscaffold.org/. 
66 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | AUTODOCDIR = api 10 | AUTODOCBUILD = sphinx-apidoc 11 | PROJECT = pywrangler 12 | MODULEDIR = ../src/pywrangler 13 | 14 | # User-friendly check for sphinx-build 15 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 16 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 17 | endif 18 | 19 | # Internal variables. 20 | PAPEROPT_a4 = -D latex_paper_size=a4 21 | PAPEROPT_letter = -D latex_paper_size=letter 22 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 23 | # the i18n builder cannot share the environment and doctrees with the others 24 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
25 | 26 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext doc-requirements 27 | 28 | help: 29 | @echo "Please use \`make ' where is one of" 30 | @echo " html to make standalone HTML files" 31 | @echo " dirhtml to make HTML files named index.html in directories" 32 | @echo " singlehtml to make a single large HTML file" 33 | @echo " pickle to make pickle files" 34 | @echo " json to make JSON files" 35 | @echo " htmlhelp to make HTML files and a HTML help project" 36 | @echo " qthelp to make HTML files and a qthelp project" 37 | @echo " devhelp to make HTML files and a Devhelp project" 38 | @echo " epub to make an epub" 39 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 40 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 41 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 42 | @echo " text to make text files" 43 | @echo " man to make manual pages" 44 | @echo " texinfo to make Texinfo files" 45 | @echo " info to make Texinfo files and run them through makeinfo" 46 | @echo " gettext to make PO message catalogs" 47 | @echo " changes to make an overview of all changed/added/deprecated items" 48 | @echo " xml to make Docutils-native XML files" 49 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 50 | @echo " linkcheck to check all external links for integrity" 51 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 52 | 53 | clean: 54 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 55 | 56 | $(AUTODOCDIR): $(MODULEDIR) 57 | mkdir -p $@ 58 | $(AUTODOCBUILD) -f -o $@ $^ 59 | 60 | doc-requirements: $(AUTODOCDIR) 61 | 62 | html: doc-requirements 63 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 64 | @echo 65 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 
66 | 67 | dirhtml: doc-requirements 68 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 69 | @echo 70 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 71 | 72 | singlehtml: doc-requirements 73 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 74 | @echo 75 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 76 | 77 | pickle: doc-requirements 78 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 79 | @echo 80 | @echo "Build finished; now you can process the pickle files." 81 | 82 | json: doc-requirements 83 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 84 | @echo 85 | @echo "Build finished; now you can process the JSON files." 86 | 87 | htmlhelp: doc-requirements 88 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 89 | @echo 90 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 91 | ".hhp project file in $(BUILDDIR)/htmlhelp." 92 | 93 | qthelp: doc-requirements 94 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 95 | @echo 96 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 97 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 98 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/$(PROJECT).qhcp" 99 | @echo "To view the help file:" 100 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/$(PROJECT).qhc" 101 | 102 | devhelp: doc-requirements 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $HOME/.local/share/devhelp/$(PROJECT)" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $HOME/.local/share/devhelp/$(PROJEC)" 109 | @echo "# devhelp" 110 | 111 | epub: doc-requirements 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
115 | 116 | patch-latex: 117 | find _build/latex -iname "*.tex" | xargs -- \ 118 | sed -i'' 's~includegraphics{~includegraphics\[keepaspectratio,max size={\\textwidth}{\\textheight}\]{~g' 119 | 120 | latex: doc-requirements 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | $(MAKE) patch-latex 123 | @echo 124 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 125 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 126 | "(use \`make latexpdf' here to do that automatically)." 127 | 128 | latexpdf: doc-requirements 129 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 130 | $(MAKE) patch-latex 131 | @echo "Running LaTeX files through pdflatex..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | latexpdfja: doc-requirements 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through platex and dvipdfmx..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | text: doc-requirements 142 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 143 | @echo 144 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 145 | 146 | man: doc-requirements 147 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 148 | @echo 149 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 150 | 151 | texinfo: doc-requirements 152 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 153 | @echo 154 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 155 | @echo "Run \`make' in that directory to run these through makeinfo" \ 156 | "(use \`make info' here to do that automatically)." 157 | 158 | info: doc-requirements 159 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 160 | @echo "Running Texinfo files through makeinfo..." 
161 | make -C $(BUILDDIR)/texinfo info 162 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 163 | 164 | gettext: doc-requirements 165 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 166 | @echo 167 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 168 | 169 | changes: doc-requirements 170 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 171 | @echo 172 | @echo "The overview file is in $(BUILDDIR)/changes." 173 | 174 | linkcheck: doc-requirements 175 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 176 | @echo 177 | @echo "Link check complete; look for any errors in the above output " \ 178 | "or in $(BUILDDIR)/linkcheck/output.txt." 179 | 180 | doctest: doc-requirements 181 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 182 | @echo "Testing of doctests in the sources finished, look at the " \ 183 | "results in $(BUILDDIR)/doctest/output.txt." 184 | 185 | xml: doc-requirements 186 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 187 | @echo 188 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 189 | 190 | pseudoxml: doc-requirements 191 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 192 | @echo 193 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 194 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. 
_changes: 2 | .. include:: ../CHANGELOG.rst 3 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # This file is execfile()d with the current directory set to its containing dir. 4 | # 5 | # Note that not all possible configuration values are present in this 6 | # autogenerated file. 7 | # 8 | # All configuration values have a default; values that are commented out 9 | # serve to show the default. 10 | 11 | import os 12 | import sys 13 | import inspect 14 | import shutil 15 | 16 | __location__ = os.path.join(os.getcwd(), os.path.dirname( 17 | inspect.getfile(inspect.currentframe()))) 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | sys.path.insert(0, os.path.join(__location__, '../src')) 23 | 24 | # -- Run sphinx-apidoc ------------------------------------------------------ 25 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 26 | # `sphinx-build -b html . _build/html`. See Issue: 27 | # https://github.com/rtfd/readthedocs.org/issues/1139 28 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 29 | # setup.py install" in the RTD Advanced Settings. 
30 | # Additionally it helps us to avoid running apidoc manually 31 | 32 | try: # for Sphinx >= 1.7 33 | from sphinx.ext import apidoc 34 | except ImportError: 35 | from sphinx import apidoc 36 | 37 | output_dir = os.path.join(__location__, "api") 38 | module_dir = os.path.join(__location__, "../src/pywrangler") 39 | try: 40 | shutil.rmtree(output_dir) 41 | except FileNotFoundError: 42 | pass 43 | 44 | try: 45 | import sphinx 46 | from pkg_resources import parse_version 47 | 48 | cmd_line_template = "sphinx-apidoc -f -o {outputdir} {moduledir}" 49 | cmd_line = cmd_line_template.format(outputdir=output_dir, moduledir=module_dir) 50 | 51 | args = cmd_line.split(" ") 52 | if parse_version(sphinx.__version__) >= parse_version('1.7'): 53 | args = args[1:] 54 | 55 | apidoc.main(args) 56 | except Exception as e: 57 | print("Running `sphinx-apidoc` failed!\n{}".format(e)) 58 | 59 | # -- General configuration ----------------------------------------------------- 60 | 61 | # If your documentation needs a minimal Sphinx version, state it here. 62 | # needs_sphinx = '1.0' 63 | 64 | # Add any Sphinx extension module names here, as strings. They can be extensions 65 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 66 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 67 | 'sphinx.ext.autosummary', 'sphinx.ext.viewcode', 'sphinx.ext.coverage', 68 | 'sphinx.ext.doctest', 'sphinx.ext.ifconfig', 'sphinx.ext.mathjax', 69 | 'sphinx.ext.napoleon'] 70 | 71 | # Add any paths that contain templates here, relative to this directory. 72 | templates_path = ['_templates'] 73 | 74 | # The suffix of source filenames. 75 | source_suffix = '.rst' 76 | 77 | # The encoding of source files. 78 | # source_encoding = 'utf-8-sig' 79 | 80 | # The master toctree document. 81 | master_doc = 'index' 82 | 83 | # General information about the project. 
84 | project = u'pywrangler' 85 | copyright = u'2019, mansenfranzen' 86 | 87 | # The version info for the project you're documenting, acts as replacement for 88 | # |version| and |release|, also used in various other places throughout the 89 | # built documents. 90 | # 91 | # The short X.Y version. 92 | version = '' # Is set by calling `setup.py docs` 93 | # The full version, including alpha/beta/rc tags. 94 | release = '' # Is set by calling `setup.py docs` 95 | 96 | # The language for content autogenerated by Sphinx. Refer to documentation 97 | # for a list of supported languages. 98 | # language = None 99 | 100 | # There are two options for replacing |today|: either, you set today to some 101 | # non-false value, then it is used: 102 | # today = '' 103 | # Else, today_fmt is used as the format for a strftime call. 104 | # today_fmt = '%B %d, %Y' 105 | 106 | # List of patterns, relative to source directory, that match files and 107 | # directories to ignore when looking for source files. 108 | exclude_patterns = ['_build'] 109 | 110 | # The reST default role (used for this markup: `text`) to use for all documents. 111 | # default_role = None 112 | 113 | # If true, '()' will be appended to :func: etc. cross-reference text. 114 | # add_function_parentheses = True 115 | 116 | # If true, the current module name will be prepended to all description 117 | # unit titles (such as .. function::). 118 | # add_module_names = True 119 | 120 | # If true, sectionauthor and moduleauthor directives will be shown in the 121 | # output. They are ignored by default. 122 | # show_authors = False 123 | 124 | # The name of the Pygments (syntax highlighting) style to use. 125 | pygments_style = 'sphinx' 126 | 127 | # A list of ignored prefixes for module index sorting. 128 | # modindex_common_prefix = [] 129 | 130 | # If true, keep warnings as "system message" paragraphs in the built documents. 
131 | # keep_warnings = False 132 | 133 | 134 | # -- Options for HTML output --------------------------------------------------- 135 | 136 | # The theme to use for HTML and HTML Help pages. See the documentation for 137 | # a list of builtin themes. 138 | html_theme = 'alabaster' 139 | 140 | # Theme options are theme-specific and customize the look and feel of a theme 141 | # further. For a list of options available for each theme, see the 142 | # documentation. 143 | html_theme_options = { 144 | 'sidebar_width': '300px', 145 | 'page_width': '1200px' 146 | } 147 | 148 | # Add any paths that contain custom themes here, relative to this directory. 149 | # html_theme_path = [] 150 | 151 | # The name for this set of Sphinx documents. If None, it defaults to 152 | # " v documentation". 153 | try: 154 | from pywrangler import __version__ as version 155 | except ImportError: 156 | pass 157 | else: 158 | release = version 159 | 160 | # A shorter title for the navigation bar. Default is the same as html_title. 161 | # html_short_title = None 162 | 163 | # The name of an image file (relative to this directory) to place at the top 164 | # of the sidebar. 165 | # html_logo = "" 166 | 167 | # The name of an image file (within the static path) to use as favicon of the 168 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 169 | # pixels large. 170 | # html_favicon = None 171 | 172 | # Add any paths that contain custom static files (such as style sheets) here, 173 | # relative to this directory. They are copied after the builtin static files, 174 | # so a file named "default.css" will overwrite the builtin "default.css". 175 | html_static_path = ['_static'] 176 | 177 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 178 | # using the given strftime format. 179 | # html_last_updated_fmt = '%b %d, %Y' 180 | 181 | # If true, SmartyPants will be used to convert quotes and dashes to 182 | # typographically correct entities. 
183 | # html_use_smartypants = True 184 | 185 | # Custom sidebar templates, maps document names to template names. 186 | # html_sidebars = {} 187 | 188 | # Additional templates that should be rendered to pages, maps page names to 189 | # template names. 190 | # html_additional_pages = {} 191 | 192 | # If false, no module index is generated. 193 | # html_domain_indices = True 194 | 195 | # If false, no index is generated. 196 | # html_use_index = True 197 | 198 | # If true, the index is split into individual pages for each letter. 199 | # html_split_index = False 200 | 201 | # If true, links to the reST sources are added to the pages. 202 | # html_show_sourcelink = True 203 | 204 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 205 | # html_show_sphinx = True 206 | 207 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 208 | # html_show_copyright = True 209 | 210 | # If true, an OpenSearch description file will be output, and all pages will 211 | # contain a tag referring to it. The value of this option must be the 212 | # base URL from which the finished HTML is served. 213 | # html_use_opensearch = '' 214 | 215 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 216 | # html_file_suffix = None 217 | 218 | # Output file base name for HTML help builder. 219 | htmlhelp_basename = 'pywrangler-doc' 220 | 221 | 222 | # -- Options for LaTeX output -------------------------------------------------- 223 | 224 | latex_elements = { 225 | # The paper size ('letterpaper' or 'a4paper'). 226 | # 'papersize': 'letterpaper', 227 | 228 | # The font size ('10pt', '11pt' or '12pt'). 229 | # 'pointsize': '10pt', 230 | 231 | # Additional stuff for the LaTeX preamble. 232 | # 'preamble': '', 233 | } 234 | 235 | # Grouping the document tree into LaTeX files. List of tuples 236 | # (source start file, target name, title, author, documentclass [howto/manual]). 
237 | latex_documents = [ 238 | ('index', 'user_guide.tex', u'pywrangler Documentation', 239 | u'mansenfranzen', 'manual'), 240 | ] 241 | 242 | # The name of an image file (relative to this directory) to place at the top of 243 | # the title page. 244 | # latex_logo = "" 245 | 246 | # For "manual" documents, if this is true, then toplevel headings are parts, 247 | # not chapters. 248 | # latex_use_parts = False 249 | 250 | # If true, show page references after internal links. 251 | # latex_show_pagerefs = False 252 | 253 | # If true, show URL addresses after external links. 254 | # latex_show_urls = False 255 | 256 | # Documents to append as an appendix to all manuals. 257 | # latex_appendices = [] 258 | 259 | # If false, no module index is generated. 260 | # latex_domain_indices = True 261 | 262 | # -- External mapping ------------------------------------------------------------ 263 | python_version = '.'.join(map(str, sys.version_info[0:2])) 264 | intersphinx_mapping = { 265 | 'sphinx': ('http://www.sphinx-doc.org/en/stable', None), 266 | 'python': ('https://docs.python.org/' + python_version, None), 267 | 'matplotlib': ('https://matplotlib.org', None), 268 | 'numpy': ('https://docs.scipy.org/doc/numpy', None), 269 | 'sklearn': ('http://scikit-learn.org/stable', None), 270 | 'pandas': ('http://pandas.pydata.org/pandas-docs/stable', None), 271 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), 272 | } 273 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | tswrangler 3 | ========== 4 | 5 | This is the documentation of **tswrangler**. 6 | 7 | .. note:: 8 | 9 | This is the main page of your project's `Sphinx`_ documentation. 10 | It is formatted in `reStructuredText`_. Add additional pages 11 | by creating rst-files in ``docs`` and adding them to the `toctree`_ below. 
12 | Use then `references`_ in order to link them from this page, e.g. 13 | :ref:`authors` and :ref:`changes`. 14 | 15 | It is also possible to refer to the documentation of other Python packages 16 | with the `Python domain syntax`_. By default you can reference the 17 | documentation of `Sphinx`_, `Python`_, `NumPy`_, `SciPy`_, `matplotlib`_, 18 | `Pandas`_, `Scikit-Learn`_. You can add more by extending the 19 | ``intersphinx_mapping`` in your Sphinx's ``conf.py``. 20 | 21 | The pretty useful extension `autodoc`_ is activated by default and lets 22 | you include documentation from docstrings. Docstrings can be written in 23 | `Google style`_ (recommended!), `NumPy style`_ and `classical style`_. 24 | 25 | 26 | Contents 27 | ======== 28 | 29 | .. toctree:: 30 | :maxdepth: 2 31 | 32 | License 33 | Authors 34 | Changelog 35 | Module Reference 36 | 37 | 38 | Indices and tables 39 | ================== 40 | 41 | * :ref:`genindex` 42 | * :ref:`modindex` 43 | * :ref:`search` 44 | 45 | .. _toctree: http://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html 46 | .. _reStructuredText: http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 47 | .. _references: http://www.sphinx-doc.org/en/stable/markup/inline.html 48 | .. _Python domain syntax: http://sphinx-doc.org/domains.html#the-python-domain 49 | .. _Sphinx: http://www.sphinx-doc.org/ 50 | .. _Python: http://docs.python.org/ 51 | .. _Numpy: http://docs.scipy.org/doc/numpy 52 | .. _SciPy: http://docs.scipy.org/doc/scipy/reference/ 53 | .. _matplotlib: https://matplotlib.org/contents.html# 54 | .. _Pandas: http://pandas.pydata.org/pandas-docs/stable 55 | .. _Scikit-Learn: http://scikit-learn.org/stable 56 | .. _autodoc: http://www.sphinx-doc.org/en/stable/ext/autodoc.html 57 | .. _Google style: https://github.com/google/styleguide/blob/gh-pages/pyguide.md#38-comments-and-docstrings 58 | .. _NumPy style: https://numpydoc.readthedocs.io/en/latest/format.html 59 | .. 
_classical style: http://www.sphinx-doc.org/en/stable/domains.html#info-field-lists 60 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. include:: ../LICENSE.txt 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # DEPRECATION WARNING: 3 | # 4 | # The file `requirements.txt` does not influence the package dependencies and 5 | # will not be automatically created in the next version of PyScaffold (v4.x). 6 | # 7 | # Please have look at the docs for better alternatives 8 | # (`Dependency Management` section). 9 | # ============================================================================= 10 | # 11 | # Add your pinned requirements so that they can be easily installed with: 12 | # pip install -r requirements.txt 13 | # Remember to also add them in setup.cfg but unpinned. 14 | # Example: 15 | # numpy==1.13.3 16 | # scipy==1.0 17 | 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | [metadata] 3 | name = pywrangler 4 | description = Data wrangling for time series 5 | author = mansenfranzen 6 | author-email = franz.woellert@gmail.com 7 | license = mit 8 | url = https://github.com/mansenfranzen/pywrangler 9 | long-description = file: README.rst 10 | platforms = any 11 | classifiers = 12 | Development Status :: 4 - Beta 13 | Programming Language :: Python 14 | 15 | [options] 16 | zip_safe = False 17 | packages = find: 18 | include_package_data = True 19 | package_dir = 20 | =src 21 | # DON'T CHANGE THE FOLLOWING LINE! 
IT WILL BE UPDATED BY PYSCAFFOLD! 22 | setup_requires = pyscaffold>=3.1a0,<3.2a0 23 | install_requires = 24 | pandas 25 | tabulate 26 | 27 | [options.packages.find] 28 | where = src 29 | exclude = 30 | tests 31 | 32 | [options.extras_require] 33 | testing = 34 | pytest 35 | pytest-cov 36 | tox 37 | memory_profiler 38 | pyarrow 39 | 40 | dev = 41 | sphinx 42 | twine 43 | 44 | 45 | [options.entry_points] 46 | 47 | 48 | [test] 49 | extras = True 50 | 51 | [tool:pytest] 52 | addopts = 53 | --cov pywrangler --cov-report term-missing:skip-covered --cov-report xml 54 | --verbose 55 | norecursedirs = 56 | dist 57 | build 58 | .tox 59 | markers = 60 | pandas: marks all pandas tests 61 | pyspark: marks all pyspark tests 62 | dask: marks all dask tests 63 | 64 | testpaths = tests 65 | 66 | [aliases] 67 | build = bdist_wheel 68 | release = build upload 69 | 70 | [bdist_wheel] 71 | # Use this option if your package is pure-python 72 | universal = 1 73 | 74 | [build_sphinx] 75 | source_dir = docs 76 | build_dir = docs/_build 77 | 78 | [devpi:upload] 79 | # Options for the devpi: PyPI server and packaging tool 80 | # VCS export must be deactivated since we are using setuptools-scm 81 | no-vcs = 1 82 | formats = bdist_wheel 83 | 84 | [flake8] 85 | # Some sane defaults for the code style checker flake8 86 | exclude = 87 | .tox 88 | build 89 | dist 90 | .eggs 91 | docs/conf.py 92 | 93 | [pyscaffold] 94 | # PyScaffold's parameters when the project was created. 95 | # This will be used when updating. Do not change! 96 | version = 3.1 97 | package = pywrangler 98 | extensions = 99 | pre_commit 100 | tox 101 | travis 102 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Setup file for pywrangler. 5 | Use setup.cfg to configure your project. 6 | 7 | This file was generated with PyScaffold 3.1. 
8 | PyScaffold helps you to put up the scaffold of your new Python project. 9 | Learn more under: https://pyscaffold.org/ 10 | """ 11 | import sys 12 | 13 | from pkg_resources import require, VersionConflict 14 | from setuptools import setup 15 | 16 | try: 17 | require('setuptools>=38.3') 18 | except VersionConflict: 19 | print("Error: version of setuptools is too old (<38.3)!") 20 | sys.exit(1) 21 | 22 | 23 | if __name__ == "__main__": 24 | setup(use_pyscaffold=True) 25 | -------------------------------------------------------------------------------- /src/pywrangler/__init__.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import get_distribution, DistributionNotFound 2 | 3 | try: 4 | dist_name = __name__ 5 | __version__ = get_distribution(dist_name).version 6 | except DistributionNotFound: 7 | __version__ = 'unknown' 8 | finally: 9 | del get_distribution, DistributionNotFound 10 | -------------------------------------------------------------------------------- /src/pywrangler/base.py: -------------------------------------------------------------------------------- 1 | """This module contains the BaseWrangler definition and the wrangler base 2 | classes including wrangler descriptions and parameters. 3 | 4 | """ 5 | import inspect 6 | from abc import ABC, abstractmethod 7 | 8 | from pywrangler.util import _pprint 9 | from pywrangler.util.helper import get_param_names 10 | 11 | 12 | class BaseWrangler(ABC): 13 | """Defines the basic interface common to all data wranglers. 14 | 15 | In analogy to sklearn transformers (see link below), all wranglers have to 16 | implement `fit`, `transform` and `fit_transform` methods. In addition, 17 | parameters (e.g. column names) need to be provided via the `__init__` 18 | method. Furthermore, `get_params` and `set_params` methods are required for 19 | grid search and pipeline compatibility. 20 | 21 | The `fit` method contains optional fitting (e.g. 
compute mean and variance 22 | for scaling) which sets training data dependent transformation behaviour. 23 | The `transform` method includes the actual computational transformation. 24 | The `fit_transform` either applies the former methods in sequence or adds a 25 | new implementation of both with better performance. The `__init__` method 26 | should contain any logic behind parameter parsing and conversion. 27 | 28 | In contrast to sklearn, wranglers do only accept dataframes like objects 29 | (like pandas/pyspark/dask dataframes) as inputs to `fit` and `transform`. 30 | The relevant columns and their respective meaning is provided via the 31 | `__init__` method. In addition, wranglers may accept multiple input 32 | dataframes with different shapes. Also, the number of samples may also 33 | change between input and output (which is not allowed in sklearn). The 34 | `preserves_sample_size` indicates whether sample size (number of rows) may 35 | change during transformation. 36 | 37 | The wrangler's employed computation engine is given via 38 | `computation_engine`. 39 | 40 | See also 41 | -------- 42 | https://scikit-learn.org/stable/developers/contributing.html 43 | 44 | """ 45 | 46 | @property 47 | @abstractmethod 48 | def preserves_sample_size(self) -> bool: 49 | raise NotImplementedError 50 | 51 | @property 52 | @abstractmethod 53 | def computation_engine(self) -> str: 54 | raise NotImplementedError 55 | 56 | def get_params(self) -> dict: 57 | """Retrieve all wrangler parameters set within the __init__ method. 
58 | 59 | Returns 60 | ------- 61 | param_dict: dictionary 62 | Parameter names as keys and corresponding values as values 63 | 64 | """ 65 | 66 | base_classes = [cls for cls in inspect.getmro(self.__class__) 67 | if issubclass(cls, BaseWrangler)] 68 | 69 | ignore = ["self", "args", "kwargs"] 70 | param_names = [] 71 | for cls in base_classes[::-1]: 72 | param_names.extend(get_param_names(cls.__init__, ignore)) 73 | 74 | param_dict = {x: getattr(self, x) for x in param_names} 75 | 76 | return param_dict 77 | 78 | def set_params(self, **params): 79 | """Set wrangler parameters 80 | 81 | Parameters 82 | ---------- 83 | params: dict 84 | Dictionary containing new values to be updated on wrangler. Keys 85 | have to match parameter names of wrangler. 86 | 87 | Returns 88 | ------- 89 | self 90 | 91 | """ 92 | 93 | valid_params = self.get_params() 94 | for key, value in params.items(): 95 | if key not in valid_params: 96 | raise ValueError('Invalid parameter {} for wrangler {}. ' 97 | 'Check the list of available parameters ' 98 | 'with `wrangler.get_params().keys()`.' 99 | .format(key, self)) 100 | 101 | setattr(self, key, value) 102 | 103 | return self 104 | 105 | @abstractmethod 106 | def fit(self, *args, **kwargs): 107 | raise NotImplementedError 108 | 109 | @abstractmethod 110 | def transform(self, *args, **kwargs): 111 | raise NotImplementedError 112 | 113 | @abstractmethod 114 | def fit_transform(self, *args, **kwargs): 115 | raise NotImplementedError 116 | 117 | def __repr__(self): 118 | 119 | template = '{wrangler_name} ({computation_engine})\n\n{parameters}'\ 120 | 121 | parameters = (_pprint.header("Parameters", 3) + 122 | _pprint.enumeration(self.get_params(), 3)) 123 | 124 | _repr = template.format(wrangler_name=self.__class__.__name__, 125 | computation_engine=self.computation_engine, 126 | parameters=parameters) 127 | 128 | if not self.preserves_sample_size: 129 | _repr += "\n\n Note: Does not preserve sample size." 
130 | 131 | return _repr 132 | -------------------------------------------------------------------------------- /src/pywrangler/dask/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/src/pywrangler/dask/__init__.py -------------------------------------------------------------------------------- /src/pywrangler/dask/base.py: -------------------------------------------------------------------------------- 1 | """This module contains the dask base wrangler. 2 | 3 | """ 4 | 5 | from dask.dataframe import DataFrame 6 | 7 | from pywrangler.base import BaseWrangler 8 | 9 | 10 | class DaskWrangler(BaseWrangler): 11 | """Contains methods common to all dask based wranglers. 12 | 13 | """ 14 | 15 | @property 16 | def computation_engine(self): 17 | return "dask" 18 | 19 | 20 | class DaskSingleNoFit(DaskWrangler): 21 | """Mixin class defining `fit` and `fit_transform` for all wranglers with 22 | a single data frame input and output with no fitting necessary. 23 | 24 | """ 25 | 26 | def fit(self, df: DataFrame): 27 | """Do nothing and return the wrangler unchanged. 28 | 29 | This method is just there to implement the usual API and hence work in 30 | pipelines. 31 | 32 | Parameters 33 | ---------- 34 | df: dask.dataframe.DataFrame 35 | 36 | """ 37 | 38 | return self 39 | 40 | def fit_transform(self, df: DataFrame) -> DataFrame: 41 | """Apply fit and transform in sequence at once. 42 | 43 | Parameters 44 | ---------- 45 | df: dask.dataframe.DataFrame 46 | 47 | Returns 48 | ------- 49 | result: dask.dataframe.DataFrame 50 | 51 | """ 52 | 53 | return self.fit(df).transform(df) 54 | -------------------------------------------------------------------------------- /src/pywrangler/dask/benchmark.py: -------------------------------------------------------------------------------- 1 | """This module contains benchmarking utility for dask wranglers.
2 | 3 | """ 4 | 5 | import gc 6 | import sys 7 | import warnings 8 | from typing import Callable, List, Union 9 | 10 | import numpy as np 11 | from dask.diagnostics import ResourceProfiler 12 | 13 | from pywrangler.benchmark import MemoryProfiler, TimeProfiler 14 | from pywrangler.dask.base import DaskWrangler 15 | 16 | 17 | class DaskBaseProfiler: 18 | """Define common methods for dask profiler. 19 | 20 | """ 21 | 22 | def _wrap_fit_transform(self) -> Callable: 23 | """Wrapper function to call `compute()` on wrangler's `fit_transform` 24 | to enforce computation on lazily evaluated dask graphs. 25 | 26 | Returns 27 | ------- 28 | wrapped: callable 29 | Wrapped `fit_transform` method as a function. 30 | 31 | """ 32 | 33 | def wrapped(*args, **kwargs): 34 | return self.wrangler.fit_transform(*args, **kwargs).compute() 35 | 36 | return wrapped 37 | 38 | @staticmethod 39 | def _cache_input(dfs) -> List: 40 | """Persist lazily evaluated dask input collections before profiling to 41 | capture only relevant `fit_transform`. 42 | 43 | Parameters 44 | ---------- 45 | dfs: iterable 46 | Dask collections which can be persisted. 47 | 48 | Returns 49 | ------- 50 | persisted: iterable 51 | List of computed dask collections. 52 | 53 | """ 54 | 55 | return [df.persist() for df in dfs] 56 | 57 | @staticmethod 58 | def _clear_cached_input(dfs): 59 | """Remove original reference to previously persisted dask collections 60 | to enable garbage collection to free memory. Explicitly check reference 61 | count and give warning if persisted dask collections are referenced 62 | elsewhere which would prevent memory deallocation. 63 | 64 | Parameters 65 | ---------- 66 | dfs: iterable 67 | Persisted dask collections which should be removed. 
68 | 69 | """ 70 | 71 | # ensure reference counts are updated 72 | gc.collect() 73 | 74 | # check ref counts 75 | for df in dfs: 76 | if sys.getrefcount(df) > 3: 77 | warnings.warn("Persisted dask collection is referenced " 78 | "elsewhere and prevents garbage collection", 79 | ResourceWarning) 80 | 81 | dfs.clear() 82 | 83 | 84 | class DaskTimeProfiler(TimeProfiler, DaskBaseProfiler): 85 | """Approximate time that a dask wrangler instance requires to execute the 86 | `fit_transform` step. 87 | 88 | Parameters 89 | ---------- 90 | wrangler: pywrangler.wranglers.base.BaseWrangler 91 | The wrangler instance to be profiled. 92 | repetitions: None, int, optional 93 | Number of repetitions. If `None`, `timeit.Timer.autorange` will 94 | determine a sensible default. 95 | cache_input: bool, optional 96 | Dask collections may be cached before timing execution to ensure 97 | timing measurements only capture wrangler's `fit_transform`. By 98 | default, it is disabled. 99 | 100 | Attributes 101 | ---------- 102 | measurements: list 103 | The actual profiling measurements in seconds. 104 | best: float 105 | The best measurement in seconds. 106 | median: float 107 | The median of measurements in seconds. 108 | worst: float 109 | The worst measurement in seconds. 110 | std: float 111 | The standard deviation of measurements in seconds. 112 | runs: int 113 | The number of measurements. 114 | 115 | Methods 116 | ------- 117 | profile 118 | Contains the actual profiling implementation. 119 | report 120 | Print simple report consisting of best, median, worst, standard 121 | deviation and the number of measurements. 122 | profile_report 123 | Calls profile and report in sequence. 
124 | 125 | """ 126 | 127 | def __init__(self, wrangler: DaskWrangler, 128 | repetitions: Union[None, int] = None, 129 | cache_input: bool = False): 130 | self.wrangler = wrangler 131 | self.cache_input = cache_input 132 | 133 | func = self._wrap_fit_transform() 134 | super().__init__(func, repetitions) 135 | 136 | def profile(self, *dfs, **kwargs): 137 | """Profiles timing given input dataframes `dfs` which are passed to 138 | `fit_transform`. 139 | 140 | """ 141 | 142 | if self.cache_input: 143 | dfs = self._cache_input(dfs) 144 | 145 | super().profile(*dfs, **kwargs) 146 | 147 | if self.cache_input: 148 | self._clear_cached_input(dfs) 149 | 150 | return self 151 | 152 | 153 | class DaskMemoryProfiler(MemoryProfiler, DaskBaseProfiler): 154 | """Approximate memory usage that a dask wrangler instance requires to 155 | execute the `fit_transform` step. 156 | 157 | Parameters 158 | ---------- 159 | func: callable 160 | Callable object to be memory profiled. 161 | repetitions: int, optional 162 | Number of repetitions. 163 | interval: float, optional 164 | Defines interval duration between consecutive memory usage 165 | measurements in seconds. 166 | cache_input: bool, optional 167 | Dask collections may be cached before timing execution to ensure 168 | timing measurements only capture wrangler's `fit_transform`. By 169 | default, it is disabled. 170 | 171 | Attributes 172 | ---------- 173 | measurements: list 174 | The actual profiling measurements in bytes. 175 | best: float 176 | The best measurement in bytes. 177 | median: float 178 | The median of measurements in bytes. 179 | worst: float 180 | The worst measurement in bytes. 181 | std: float 182 | The standard deviation of measurements in bytes. 183 | runs: int 184 | The number of measurements. 185 | baseline_change: float 186 | The median change in baseline memory usage across all runs in bytes. 187 | 188 | Methods 189 | ------- 190 | profile 191 | Contains the actual profiling implementation. 
192 | report 193 | Print simple report consisting of best, median, worst, standard 194 | deviation and the number of measurements. 195 | profile_report 196 | Calls profile and report in sequence. 197 | 198 | Notes 199 | ----- 200 | The implementation uses dask's own `ResourceProfiler`. 201 | 202 | """ 203 | 204 | def __init__(self, wrangler: DaskWrangler, 205 | repetitions: Union[None, int] = 5, 206 | interval: float = 0.01, 207 | cache_input: bool = False): 208 | self.wrangler = wrangler 209 | self.cache_input = cache_input 210 | 211 | func = self._wrap_fit_transform() 212 | super().__init__(func, repetitions, interval) 213 | 214 | def profile(self, *dfs, **kwargs): 215 | """Profiles timing given input dataframes `dfs` which are passed to 216 | `fit_transform`. 217 | 218 | """ 219 | 220 | if self.cache_input: 221 | dfs = self._cache_input(dfs) 222 | 223 | counter = 0 224 | baselines = [] 225 | max_usages = [] 226 | 227 | while counter < self.repetitions: 228 | gc.collect() 229 | 230 | with ResourceProfiler(dt=self.interval) as rprof: 231 | self.func(*dfs, **kwargs) 232 | 233 | mem_usages = [x.mem for x in rprof.results] 234 | baselines.append(np.min(mem_usages)) 235 | max_usages.append(np.max(mem_usages)) 236 | 237 | counter += 1 238 | 239 | self._max_usages = max_usages 240 | self._baselines = baselines 241 | self._measurements = np.subtract(max_usages, baselines).tolist() 242 | 243 | if self.cache_input: 244 | self._clear_cached_input(dfs) 245 | 246 | return self 247 | -------------------------------------------------------------------------------- /src/pywrangler/exceptions.py: -------------------------------------------------------------------------------- 1 | """The module contains package wide custom exceptions and warnings. 2 | 3 | """ 4 | 5 | 6 | class NotProfiledError(ValueError, AttributeError): 7 | """Exception class to raise if profiling results are acquired before 8 | calling `profile`. 
9 | 10 | This class inherits from both ValueError and AttributeError to help with 11 | exception handling 12 | 13 | """ 14 | -------------------------------------------------------------------------------- /src/pywrangler/pandas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/src/pywrangler/pandas/__init__.py -------------------------------------------------------------------------------- /src/pywrangler/pandas/base.py: -------------------------------------------------------------------------------- 1 | """This module contains the pandas base wrangler. 2 | 3 | """ 4 | 5 | import pandas as pd 6 | 7 | from pywrangler.base import BaseWrangler 8 | 9 | 10 | class PandasWrangler(BaseWrangler): 11 | """Pandas wrangler base class. 12 | 13 | """ 14 | 15 | @property 16 | def computation_engine(self): 17 | return "pandas" 18 | 19 | def _validate_output_shape(self, df_in: pd.DataFrame, 20 | df_out: pd.DataFrame): 21 | """If wrangler implementation preserves sample size, assert equal 22 | sample sizes between input and output dataframe. 23 | 24 | Using pandas, all data is in memory. Hence, getting shape information 25 | is cheap and this check can be done regularly (in contrast to pyspark 26 | where `df.count()` can be very expensive). 27 | 28 | Parameters 29 | ---------- 30 | df_in: pd.DataFrame 31 | Input dataframe. 32 | df_out: pd.DataFrame 33 | Output dataframe. 34 | 35 | """ 36 | 37 | if self.preserves_sample_size: 38 | shape_in = df_in.shape[0] 39 | shape_out = df_out.shape[0] 40 | 41 | if shape_in != shape_out: 42 | raise ValueError('Number of input samples ({}) does not match ' 43 | 'number of ouput samples ({}) which should ' 44 | 'be the case because wrangler is supposed to ' 45 | 'preserve the number of samples.' 
46 | .format(shape_in, shape_out)) 47 | 48 | 49 | class PandasSingleNoFit(PandasWrangler): 50 | """Mixin class defining `fit` and `fit_transform` for all wranglers with 51 | a single data frame input and output with no fitting necessary. 52 | 53 | """ 54 | 55 | def fit(self, df: pd.DataFrame): 56 | """Do nothing and return the wrangler unchanged. 57 | 58 | This method is just there to implement the usual API and hence work in 59 | pipelines. 60 | 61 | Parameters 62 | ---------- 63 | df: pd.DataFrame 64 | 65 | """ 66 | 67 | return self 68 | 69 | def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame: 70 | """Apply fit and transform in sequence at once. 71 | 72 | Parameters 73 | ---------- 74 | df: pd.DataFrame 75 | 76 | Returns 77 | ------- 78 | result: pd.DataFrame 79 | 80 | """ 81 | 82 | return self.fit(df).transform(df) 83 | -------------------------------------------------------------------------------- /src/pywrangler/pandas/benchmark.py: -------------------------------------------------------------------------------- 1 | """This module contains benchmarking utility for pandas wranglers. 2 | 3 | """ 4 | 5 | from typing import Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from pywrangler.benchmark import MemoryProfiler, TimeProfiler 11 | from pywrangler.pandas.base import PandasWrangler 12 | from pywrangler.util import sanitizer 13 | 14 | 15 | class PandasTimeProfiler(TimeProfiler): 16 | """Approximate time that a pandas wrangler instance requires to execute the 17 | `fit_transform` step. 18 | 19 | Parameters 20 | ---------- 21 | wrangler: pywrangler.wranglers.base.BaseWrangler 22 | The wrangler instance to be profiled. 23 | repetitions: None, int, optional 24 | Number of repetitions. If `None`, `timeit.Timer.autorange` will 25 | determine a sensible default. 26 | 27 | Attributes 28 | ---------- 29 | measurements: list 30 | The actual profiling measurements in seconds. 31 | best: float 32 | The best measurement in seconds. 
33 | median: float 34 | The median of measurements in seconds. 35 | worst: float 36 | The worst measurement in seconds. 37 | std: float 38 | The standard deviation of measurements in seconds. 39 | runs: int 40 | The number of measurements. 41 | 42 | Methods 43 | ------- 44 | profile 45 | Contains the actual profiling implementation. 46 | report 47 | Print simple report consisting of best, median, worst, standard 48 | deviation and the number of measurements. 49 | profile_report 50 | Calls profile and report in sequence. 51 | 52 | """ 53 | 54 | def __init__(self, wrangler: PandasWrangler, 55 | repetitions: Union[None, int] = None): 56 | self._wrangler = wrangler 57 | super().__init__(wrangler.fit_transform, repetitions) 58 | 59 | 60 | class PandasMemoryProfiler(MemoryProfiler): 61 | """Approximate memory usage that a pandas wrangler instance requires to 62 | execute the `fit_transform` step. 63 | 64 | As a key metric, `ratio` is computed. It refers to the amount of 65 | memory which is required to execute the `fit_transform` step. More 66 | concretely, it estimates how much more memory is used standardized by the 67 | input memory usage (memory usage increase during function execution divided 68 | by memory usage of input dataframes). In other words, if you have a 1GB 69 | input dataframe, and the `usage_ratio` is 5, `fit_transform` needs 5GB free 70 | memory available to succeed. A `usage_ratio` of 0.5 given a 2GB input 71 | dataframe would require 1GB free memory available for computation. 72 | 73 | Parameters 74 | ---------- 75 | wrangler: pywrangler.wranglers.pandas.base.PandasWrangler 76 | The wrangler instance to be profiled. 77 | repetitions: int 78 | The number of measurements for memory profiling. 79 | interval: float, optional 80 | Defines interval duration between consecutive memory usage 81 | measurements in seconds. 82 | 83 | Attributes 84 | ---------- 85 | measurements: list 86 | The actual profiling measurements in bytes. 
87 | best: float 88 | The best measurement in bytes. 89 | median: float 90 | The median of measurements in bytes. 91 | worst: float 92 | The worst measurement in bytes. 93 | std: float 94 | The standard deviation of measurements in bytes. 95 | runs: int 96 | The number of measurements. 97 | baseline_change: float 98 | The median change in baseline memory usage across all runs in bytes. 99 | input: int 100 | Memory usage of input dataframes in bytes. 101 | output: int 102 | Memory usage of output dataframes in bytes. 103 | ratio: float 104 | The amount of memory required for computation in units of input 105 | memory usage. 106 | 107 | Methods 108 | ------- 109 | profile 110 | Contains the actual profiling implementation. 111 | report 112 | Print simple report consisting of best, median, worst, standard 113 | deviation and the number of measurements. 114 | profile_report 115 | Calls profile and report in sequence. 116 | 117 | """ 118 | 119 | def __init__(self, wrangler: PandasWrangler, repetitions: int = 5, 120 | interval: float = 0.01): 121 | self._wrangler = wrangler 122 | 123 | super().__init__(wrangler.fit_transform, repetitions, interval) 124 | 125 | def profile(self, *dfs: pd.DataFrame, **kwargs): 126 | """Profiles the actual memory usage given input dataframes `dfs` 127 | which are passed to `fit_transform`. 128 | 129 | """ 130 | 131 | # usage input 132 | self._usage_input = self._memory_usage_dfs(*dfs) 133 | 134 | # usage output 135 | dfs_output = self._wrangler.fit_transform(*dfs) 136 | dfs_output = sanitizer.ensure_iterable(dfs_output) 137 | self._usage_output = self._memory_usage_dfs(*dfs_output) 138 | 139 | # usage during fit_transform 140 | super().profile(*dfs, **kwargs) 141 | 142 | return self 143 | 144 | @property 145 | def input(self) -> float: 146 | """Returns the memory usage of the input dataframes in bytes. 
147 | 148 | """ 149 | 150 | self._check_is_profiled(['_usage_input']) 151 | return self._usage_input 152 | 153 | @property 154 | def output(self) -> float: 155 | """Returns the memory usage of the output dataframes in bytes. 156 | 157 | """ 158 | 159 | self._check_is_profiled(['_usage_output']) 160 | return self._usage_output 161 | 162 | @property 163 | def ratio(self) -> float: 164 | """Refers to the amount of memory which is required to execute the 165 | `fit_transform` step. More concretely, it estimates how much more 166 | memory is used standardized by the input memory usage (memory usage 167 | increase during function execution divided by memory usage of input 168 | dataframes). In other words, if you have a 1GB input dataframe, and the 169 | `usage_ratio` is 5, `fit_transform` needs 5GB free memory available to 170 | succeed. A `usage_ratio` of 0.5 given a 2GB input dataframe would 171 | require 1GB free memory available for computation. 172 | 173 | """ 174 | 175 | return self.median / self.input 176 | 177 | @staticmethod 178 | def _memory_usage_dfs(*dfs: pd.DataFrame) -> int: 179 | """Return memory usage in bytes for all given dataframes. 180 | 181 | Parameters 182 | ---------- 183 | dfs: pd.DataFrame 184 | The pandas dataframes for which memory usage should be computed. 185 | 186 | Returns 187 | ------- 188 | memory_usage: int 189 | The computed memory usage in bytes. 190 | 191 | """ 192 | 193 | mem_usages = [df.memory_usage(deep=True, index=True).sum() 194 | for df in dfs] 195 | 196 | return int(np.sum(mem_usages)) 197 | -------------------------------------------------------------------------------- /src/pywrangler/pandas/util.py: -------------------------------------------------------------------------------- 1 | """This module contains utility functions (e.g. validation) commonly used by 2 | pandas wranglers. 
import numpy as np
import pandas as pd
from pandas.core.groupby.generic import DataFrameGroupBy


def validate_empty_df(df: pd.DataFrame):
    """Reject empty dataframes. Wranglers are defined to operate on
    non-empty data only, so an empty input is an error.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe to check against.

    Raises
    ------
    ValueError
        If the given dataframe contains no rows.

    """

    if df.empty:
        raise ValueError('Dataframe is empty.')


def validate_columns(df: pd.DataFrame, columns: "TYPE_COLUMNS"):
    """Ensure every requested column is present in the dataframe.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe to check against.
    columns: iterable[str]
        Columns to be validated.

    Raises
    ------
    ValueError
        If any column is missing from `df`.

    """

    # local import keeps the block self-contained; same module-level helper
    from pywrangler.util.sanitizer import ensure_iterable

    for column in ensure_iterable(columns):
        if column not in df.columns:
            raise ValueError('Column with name `{}` does not exist. '
                             'Please check parameter settings.'
                             .format(column))


def sort_values(df: pd.DataFrame,
                order_columns: "TYPE_COLUMNS",
                ascending: "TYPE_ASCENDING") -> pd.DataFrame:
    """Return the dataframe sorted by the given columns, or the dataframe
    unchanged when no order columns are provided.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe to check against.
    order_columns: TYPE_COLUMNS
        Columns to be sorted.
    ascending: TYPE_ASCENDING
        Column order.

    Returns
    -------
    df_sorted: pd.DataFrame

    """

    # guard clause: nothing to sort by -> pass the frame through untouched
    if not order_columns:
        return df

    return df.sort_values(order_columns, ascending=ascending)


def groupby(df: pd.DataFrame,
            groupby_columns: "TYPE_COLUMNS") -> DataFrameGroupBy:
    """Group a dataframe while tolerating absent group columns. Always
    returns a `DataFrameGroupBy` object.

    Parameters
    ----------
    df: pd.DataFrame
        Dataframe to check against.
    groupby_columns: TYPE_COLUMNS
        Columns to be grouped by.

    Returns
    -------
    groupby: DataFrameGroupBy

    """

    if groupby_columns:
        return df.groupby(groupby_columns)

    # no columns given: put every row into a single artificial group so the
    # caller can rely on receiving a DataFrameGroupBy either way
    return df.groupby(np.zeros(df.shape[0]))
class _BaseIntervalIdentifier(PandasSingleNoFit, IntervalIdentifier):
    """Provides `transform` and `validate_input` methods common to more than
    one implementation of the pandas interval identification wrangler.

    The `transform` has several shared responsibilities. It sorts and groups
    the data before applying the `_transform` method which needs to be
    implemented by every wrangler subclassing this mixin. In addition, it
    retains the original index of the input dataframe, ensures the resulting
    column to be of type integer and converts output to a data frame with
    parametrized target column name.

    """

    def _validate_input(self, df: pd.DataFrame):
        """Checks input data frame in regard to column names and empty data.

        Parameters
        ----------
        df: pd.DataFrame
            Dataframe to be validated.

        """

        util.validate_columns(df, self.marker_column)
        util.validate_columns(df, self.orderby_columns)
        util.validate_columns(df, self.groupby_columns)
        util.validate_empty_df(df)

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Extract interval ids from given dataframe.

        Parameters
        ----------
        df: pd.DataFrame

        Returns
        -------
        result: pd.DataFrame
            Single columned dataframe with same index as `df`.

        """

        # check input
        self._validate_input(df)

        # transform
        df_ordered = util.sort_values(df, self.orderby_columns, self.ascending)
        df_grouped = util.groupby(df_ordered, self.groupby_columns)

        # `reindex(df.index)` restores the original row order because the
        # grouped result carries the permuted (sorted) index
        df_result = df_grouped[self.marker_column] \
            .transform(self._transform) \
            .astype(int) \
            .reindex(df.index) \
            .to_frame(self.target_column_name)

        # check output
        self._validate_output_shape(df, df_result)

        return df_result


class NaiveIterator(_BaseIntervalIdentifier):
    """Most simple, sequential implementation which iterates values while
    remembering the state of start and end markers.

    """

    def _transform(self, values: pd.Series) -> List[int]:
        """Selects appropriate algorithm depending on identical/different
        start and end markers.

        """

        start_first = self.marker_start_use_first
        end_first = self.marker_end_use_first

        if self._identical_start_end_markers:
            return self._agg_identical_start_end_markers(values)
        elif self.result_type == "raw":
            return self._agg_raw_iids(values)
        elif not start_first and end_first:
            return self._generic_start_first_end(values, False)
        elif start_first and not end_first:
            return self._generic_start_last_end(values, True)
        elif start_first and end_first:
            return self._generic_start_first_end(values, True)
        elif not start_first and not end_first:
            return self._generic_start_last_end(values, False)

    def _is_start(self, value):
        # True if `value` equals the configured start marker
        return value == self.marker_start

    def _is_end(self, value):
        # True if `value` equals the configured end marker
        return value == self.marker_end

    def _agg_identical_start_end_markers(self, series: pd.Series) -> List[int]:
        """Iterates given `series` testing each value against start marker
        while increasing counter each time start marker is encountered.

        Assumes that series is already ordered and grouped.

        """

        result = []
        counter = 0

        for value in series.values:
            if self._is_start(value):
                counter += 1

            result.append(counter)

        return result

    def _agg_raw_iids(self, series: pd.Series) -> List[int]:
        """Iterates given `series` testing each value against start marker
        while increasing counter each time start or end marker (shifted) is
        encountered.

        Assumes that series is already ordered and grouped.

        """

        result = []
        counter = 0
        lag = False  # deferred increment: triggers one row after an end marker

        for value in series.values:
            if lag:
                counter += 1
                lag = False

            if self._is_start(value):
                counter += 1
            elif self._is_end(value):
                lag = True

            result.append(counter)

        return result

    def _generic_start_first_end(self, series: pd.Series, first_start: bool) \
            -> List[int]:
        """Iterates given `series` testing each value against start and end
        markers while keeping track of already instantiated intervals to
        separate valid from invalid intervals.

        Assumes that series is already ordered and grouped.

        Parameters
        ----------
        series: pd.Series
            Sorted values which contain interval data.
        first_start: bool
            Indicates if first or last start is required. If True, generates
            ids for first start. If False, generates ids for last start.

        """

        counter = 0  # counts the current interval id
        active = 0  # 0 in case no active interval, otherwise equals counter
        intermediate = []  # stores intermediate results
        result = []  # keeps track of all results

        for value in series.values:

            if self._is_start(value):
                if active and not first_start:
                    # add invalid values to result (from previous begin marker)
                    result.extend([0] * len(intermediate))

                    # start new intermediate list
                    intermediate = []

                if not active:
                    active = counter + 1

                intermediate.append(active)

            elif self._is_end(value) and active:
                # add valid interval to result
                result.extend(intermediate)
                result.append(active)

                # empty intermediate list
                intermediate = []
                active = 0

                # increase id counter since valid interval was closed
                counter += 1

            else:
                intermediate.append(active)

        # finally, add rest to result which must be invalid
        result.extend([0] * len(intermediate))

        return result

    def _generic_start_last_end(self, series: pd.Series, first_start: bool) \
            -> List[int]:
        """Iterates given `series` testing each value against start and end
        markers while keeping track of already instantiated intervals to
        separate valid from invalid intervals.

        Requires state for opened start/end markers and number of noise values
        since last end marker.

        Assumes that series is already ordered and grouped.

        Parameters
        ----------
        series: pd.Series
            Sorted values which contain interval data.
        first_start: bool
            Indicates if first or last start is required. If True, generates
            ids for first start. If False, generates ids for last start.

        """

        counter = 0  # counts the current interval id
        active_start = False  # remember opened start marker
        active_end = False  # remember opened end marker
        noise_counter = 0  # store number of noises after end marker
        intermediate = []  # store intermediate results
        result = []  # keeps track of all results

        for value in series.values:
            # handle start marker
            if self._is_start(value):
                # closing valid interval
                # NOTE: boolean logic uses `and`/`not` instead of the previous
                # bitwise `&`/`~`, which only worked by accident on bools
                if active_start and active_end:
                    result.extend(intermediate)
                    result.extend([0] * noise_counter)
                    counter += 1

                    noise_counter = 0
                    active_end = False
                    intermediate = []

                # increase counter only if start was not active previously
                elif not active_start:
                    counter += 1

                # handle last start
                elif not active_end and not first_start:
                    result.extend([0] * len(intermediate))
                    intermediate = []

                active_start = True
                intermediate.append(counter)

            # handle end marker
            elif self._is_end(value):
                if not active_start:
                    result.append(0)

                else:
                    active_end = True
                    count = len(intermediate) + noise_counter + 1
                    result.extend([counter] * count)

                    intermediate = []
                    noise_counter = 0

            # handle noise
            else:
                if active_end:
                    noise_counter += 1
                elif active_start:
                    intermediate.append(counter)
                else:
                    result.append(0)

        # handle remaining values
        if active_start and not active_end:
            result.extend([0] * len(intermediate))
        elif active_end:
            intermediate.extend([0] * noise_counter)
            result.extend(intermediate)

        return result


class VectorizedCumSum(_BaseIntervalIdentifier):
    """Sophisticated approach using multiple, vectorized operations. Using
    cumulative sum allows enumeration of intervals to avoid looping.

    """

    def _transform(self, values: pd.Series) -> List[int]:
        """Selects appropriate algorithm depending on identical/different
        start and end markers.

        """

        if self._identical_start_end_markers:
            return self._agg_identical_start_end_markers(values)

        if self.marker_start_use_first and self.result_type != "raw":
            values = self._drop_duplicated_marker(values, True)

        if not self.marker_end_use_first and self.result_type != "raw":
            values = self._drop_duplicated_marker(values, False)

        return self._last_start_first_end(values)

    def _drop_duplicated_marker(self, marker_column: pd.Series,
                                start: bool = True):
        """Modify marker column to keep only first start marker or last end
        marker.

        Parameters
        ----------
        marker_column: pd.Series
            Values for which duplicated markers will be removed.
        start: bool, optional
            Indicate which duplicates should be dropped. If True, only first
            start marker is kept. If False, only last end marker is kept.

        Returns
        -------
        dropped: pd.Series

        """

        valid_values = [self.marker_start, self.marker_end]
        denoised = marker_column.where(marker_column.isin(valid_values))

        # BUGFIX: compare against the configured marker values instead of the
        # previously hard-coded literals 1 and 2, which silently disabled
        # duplicate removal for any other marker values (e.g. strings).
        if start:
            fill = denoised.ffill()
            marker = self.marker_start
            shift = 1
        else:
            fill = denoised.bfill()
            marker = self.marker_end
            shift = -1

        # a marker is a duplicate if the previous (next, for end markers)
        # filled value is the same marker
        shifted = fill.shift(shift)
        shifted_start_only = shifted.where(fill.eq(marker))

        mask_drop = (shifted_start_only == marker_column)
        dropped = marker_column.where(~mask_drop)

        return dropped

    def _last_start_first_end(self, series: pd.Series) -> List[int]:
        """Extract shortest intervals from given dataFrame as ids.
        First, get enumeration of all intervals (valid and invalid). Every
        time a start or end marker is encountered, increase interval id by one.
        The end marker is shifted by one to include the end marker in the
        current interval. This is realized via the cumulative sum of boolean
        series of start markers and shifted end markers.

        Second, separate valid from invalid intervals by ensuring the presence
        of both start and end markers per interval id.

        Third, numerate valid intervals starting with 1 and set invalid
        intervals to 0.

        Assumes that series is already ordered and grouped.

        """

        # get boolean series with start and end markers
        bool_start = series.eq(self.marker_start)
        bool_end = series.eq(self.marker_end)

        # shifting the close marker allows cumulative sum to include the end
        bool_end_shift = bool_end.shift().fillna(False)

        # get increasing ids for intervals (in/valid) with cumsum
        iids_raw = bool_start.add(bool_end_shift).cumsum()
        if self.result_type == "raw":
            return iids_raw

        # separate valid vs invalid: ids with start AND end marker are valid
        mask_valid_ids = bool_start.add(bool_end).groupby(iids_raw).sum().eq(2)
        valid_ids = mask_valid_ids.index[mask_valid_ids].values
        mask = iids_raw.isin(valid_ids)

        if self.result_type == "valid":
            return iids_raw.where(mask, 0)

        # re-numerate ids from 1 to x and fill invalid with 0
        result = iids_raw[mask].diff().ne(0).cumsum()
        return result.reindex(series.index).fillna(0).values

    def _agg_identical_start_end_markers(self, series: pd.Series) -> List[int]:
        """Vectorized counterpart of the naive implementation: each start
        marker opens a new interval, so the cumulative sum of the boolean
        start mask directly yields the interval ids.

        Assumes that series is already ordered and grouped.

        """

        bool_start = series.eq(self.marker_start)
        return bool_start.cumsum()
class PySparkWrangler(BaseWrangler):
    """Contains methods common to all pyspark based wranglers.

    """

    @property
    def computation_engine(self):
        # identifies the backend this wrangler family targets
        return "pyspark"


class PySparkSingleNoFit(PySparkWrangler):
    """Mixin class defining `fit` and `fit_transform` for all wranglers with
    a single data frame input and output with no fitting necessary.

    """

    def fit(self, df: DataFrame):
        """Do nothing and return the wrangler unchanged.

        This method is just there to implement the usual API and hence work in
        pipelines.

        Parameters
        ----------
        df: pyspark.sql.DataFrame

        Returns
        -------
        self: PySparkSingleNoFit
            The unmodified wrangler instance.

        """

        return self

    def fit_transform(self, df: DataFrame) -> DataFrame:
        """Apply fit and transform in sequence at once.

        Parameters
        ----------
        df: pyspark.sql.DataFrame

        Returns
        -------
        result: pyspark.sql.DataFrame

        """

        return self.fit(df).transform(df)
class PySparkBaseProfiler:
    """Define common methods for pyspark profiler.

    """

    def _wrap_fit_transform(self) -> Callable:
        """Wrapper function to call `count()` on wrangler's `fit_transform`
        to enforce computation on lazily evaluated pyspark dataframes.

        Returns
        -------
        wrapped: callable
            Wrapped `fit_transform` method as a function.

        """

        def wrapped(*args, **kwargs):
            # `count()` forces full evaluation of the lazy spark plan
            return self.wrangler.fit_transform(*args, **kwargs).count()

        return wrapped

    @staticmethod
    def _cache_input(dfs: Iterable[DataFrame]):
        """Persist lazily evaluated pyspark dataframes before profiling to
        capture only relevant `fit_transform`. Apply `count` to enforce
        computation to create cached representation.

        Parameters
        ----------
        dfs: iterable
            Spark dataframes to be persisted.

        """

        for df in dfs:
            df.persist()
            # trigger computation so the persisted representation is built
            df.count()

    @staticmethod
    def _clear_cached_input(dfs: Iterable[DataFrame]):
        """Unpersist previously persisted pyspark dataframes after profiling.

        Parameters
        ----------
        dfs: iterable
            Persisted pyspark dataframes.

        """

        for df in dfs:
            df.unpersist()

            # unpersisting is best-effort; warn instead of failing
            if df.is_cached:
                warnings.warn("Spark dataframe could not be unpersisted.",
                              ResourceWarning)


class PySparkTimeProfiler(TimeProfiler, PySparkBaseProfiler):
    """Approximate time that a pyspark wrangler instance requires to execute
    the `fit_transform` step.

    Parameters
    ----------
    wrangler: pywrangler.wranglers.base.BaseWrangler
        The wrangler instance to be profiled.
    repetitions: None, int, optional
        Number of repetitions. If `None`, `timeit.Timer.autorange` will
        determine a sensible default.
    cache_input: bool, optional
        Spark dataframes may be cached before timing execution to ensure
        timing measurements only capture wrangler's `fit_transform`. By
        default, it is disabled.

    Attributes
    ----------
    measurements: list
        The actual profiling measurements in seconds.
    best: float
        The best measurement in seconds.
    median: float
        The median of measurements in seconds.
    worst: float
        The worst measurement in seconds.
    std: float
        The standard deviation of measurements in seconds.
    runs: int
        The number of measurements.

    Methods
    -------
    profile
        Contains the actual profiling implementation.
    report
        Print simple report consisting of best, median, worst, standard
        deviation and the number of measurements.
    profile_report
        Calls profile and report in sequence.

    """

    def __init__(self, wrangler: PySparkWrangler,
                 repetitions: Union[None, int] = None,
                 cache_input: bool = False):
        self.wrangler = wrangler
        self.cache_input = cache_input

        # time the count-wrapped fit_transform to force spark evaluation
        func = self._wrap_fit_transform()
        super().__init__(func, repetitions)

    def profile(self, *dfs: DataFrame, **kwargs):
        """Profiles timing given input dataframes `dfs` which are passed to
        `fit_transform`.

        If `cache_input` is enabled, input dataframes are cached before
        timing execution to ensure timing measurements only capture
        wrangler's `fit_transform`. This may cause problems if the size of
        input dataframes exceeds available memory.

        """

        if self.cache_input:
            self._cache_input(dfs)

        super().profile(*dfs, **kwargs)

        if self.cache_input:
            self._clear_cached_input(dfs)

        return self
try:
    from pandas.testing import assert_frame_equal
except ImportError:
    from pandas.util.testing import assert_frame_equal

# constant for pandas missing NULL values
PANDAS_NULL = object()


def prepare_spark_conversion(df: pd.DataFrame) -> pd.DataFrame:
    """Pandas does not distinguish NULL and NaN values. Everything null-like
    is converted to NaN. However, spark does distinguish NULL and NaN for
    example. To enable correct spark dataframe creation with NULL and NaN
    values, the `PANDAS_NULL` constant is used as a workaround to enforce NULL
    values in pyspark dataframes. Pyspark treats `None` values as NULL.

    Parameters
    ----------
    df: pd.DataFrame
        Input dataframe to be prepared.

    Returns
    -------
    df_prepared: pd.DataFrame
        Prepared dataframe for spark conversion.

    """

    # every cell equal to the PANDAS_NULL sentinel becomes None (spark NULL)
    mask_not_null = df.ne(PANDAS_NULL)
    return df.where(mask_not_null, None)


def assert_pyspark_pandas_equality(df_spark: "DataFrame",
                                   df_pandas: pd.DataFrame,
                                   orderby: "TYPE_COLUMNS" = None):
    """Compare a pyspark and pandas dataframe in regard to content equality.
    Pyspark dataframes don't have a specific index or column order due to their
    distributed nature. In contrast, a test for equality for pandas dataframes
    respects index and column order. Therefore, the test for equality between a
    pyspark and pandas dataframe will ignore index and column order on purpose.

    Testing pyspark dataframes content is most simple while converting to
    pandas dataframes and having test data as pandas dataframes, too.

    To ensure index order is ignored, both dataframes need be sorted by all or
    given columns `orderby`.

    Parameters
    ----------
    df_spark: pyspark.sql.DataFrame
        Spark dataframe to be tested for equality.
    df_pandas: pd.DataFrame
        Pandas dataframe to be tested for equality.
    orderby: iterable, optional
        Columns to be sorted for correct index order.

    Returns
    -------
    None but asserts if dataframes are not equal.

    """

    converted = df_spark.toPandas()

    # fail fast on diverging column sets, then align column order with the
    # pandas reference frame
    diff = df_pandas.columns.symmetric_difference(converted.columns)
    if not diff.empty:
        raise AssertionError("Column names do not match: {}"
                             .format(diff.tolist()))
    converted = converted[df_pandas.columns]

    # enforce identical row order
    sort_columns = orderby or df_pandas.columns.tolist()

    def normalize(frame):
        ordered = frame.sort_values(sort_columns).reset_index(drop=True)
        return ordered.where(pd.notnull(ordered), None)

    assert_frame_equal(normalize(converted),
                       normalize(df_pandas),
                       check_like=False,
                       check_dtype=False)
def ensure_column(column: Union[Column, str]) -> Column:
    """Helper function to ensure that provided column will be of type
    `pyspark.sql.column.Column`.

    Parameters
    ----------
    column: str, Column
        Column object to be casted if required.

    Returns
    -------
    ensured: Column

    """

    if isinstance(column, Column):
        return column
    else:
        return F.col(column)


def validate_columns(df: DataFrame, columns: TYPE_COLUMNS):
    """Check that columns exist in dataframe and raise error if otherwise.

    Column names are compared case-insensitively (both sides are lowercased
    before matching), mirroring spark's default column resolution.

    Parameters
    ----------
    df: pyspark.sql.DataFrame
        Dataframe to check against.
    columns: Tuple[str]
        Columns to be validated.

    """

    columns = ensure_iterable(columns)
    # lowercase once up front for a case-insensitive membership test
    compare_columns = {column.lower() for column in df.columns}

    for column in columns:
        if column.lower() not in compare_columns:
            raise ValueError('Column with name `{}` does not exist. '
                             'Please check parameter settings.'
                             .format(column))


def prepare_orderby(orderby_columns: TYPE_PYSPARK_COLUMNS,
                    ascending: TYPE_ASCENDING = True,
                    reverse: bool = False) -> List[Column]:
    """Convenient function to return orderby columns in correct
    ascending/descending order.

    Parameters
    ----------
    orderby_columns: TYPE_PYSPARK_COLUMNS
        Columns to explicitly apply an order to.
    ascending: TYPE_ASCENDING, optional
        Define order of columns via bools. True and False refer to ascending
        and descending, respectively.
    reverse: bool, optional
        Reverse the given order. By default, not activated.

    Returns
    -------
    ordered: list
        List of order columns.

    """

    # ensure columns
    orderby_columns = ensure_iterable(orderby_columns)
    orderby_columns = [ensure_column(column) for column in orderby_columns]

    # check if only True/False is given broadcast
    if isinstance(ascending, bool):
        ascending = [ascending] * len(orderby_columns)

    # ensure equal lengths, otherwise raise
    elif len(orderby_columns) != len(ascending):
        raise ValueError('`orderby_columns` and `ascending` must have '
                         'equal number of items.')

    zipped = zip(orderby_columns, ascending)

    def boolify(sort_ascending: Optional[bool]) -> bool:
        # XOR with `reverse`: flips the requested direction when reversing
        return bool(sort_ascending) != reverse

    return [column.asc() if boolify(sort_ascending) else column.desc()
            for column, sort_ascending in zipped]


class ColumnCacher:
    """Pyspark column expression cacher which enables storing of intermediate
    column expressions. PySpark column expressions can be stacked/chained. For
    example, a column expression may be a result of a conjunction of window
    functions and boolean masks for which the intermediate results are not
    stored because they are not needed for the final outcome.

    There are two valid reasons to store intermediate results. First,
    debugging requires to inspect intermediate results. Second, stacking
    column expressions seem to create more complex computation graphs. Storing
    intermediate results may help to decrease DAG complexity.

    For more, see Spark Jira: https://issues.apache.org/jira/browse/SPARK-30552

    """

    def __init__(self, df: DataFrame, mode: Union[bool, str]):
        """Initialize column cacher. Set reference to dataframe.

        Parameters
        ----------
        df: pyspark.sql.DataFrame
            DataFrame for which column caching will be activated.
        mode: bool, str
            If True, enables caching. If False, disables caching. If 'debug',
            enables caching and keeps intermediate columns (does not drop
            columns).

        """

        self.df = df
        self.mode = mode

        # maps logical names to the actual (uniquified) column names added
        self.columns = {}

        valid_modes = {True, False, "debug"}
        if mode not in valid_modes:
            raise ValueError("Parameter `mode` has to be one of the "
                             "following: {}."
                             .format(valid_modes))

    def add(self, name: str, column: Column, force=False) -> Column:
        """Add given column to dataframe. Return referenced column. Creates
        unique name which is not yet present in dataframe.

        Parameters
        ----------
        name: str
            Name of the column.
        column: pyspark.sql.column.Column
            PySpark column expression to be explicitly added to dataframe.
        force: bool, optional
            You may need to force to add a column temporarily for a given
            computation to finish even though you do not want to store
            intermediate results. This may be the case for window specs which
            rely on computed columns.

        Returns
        -------
        reference: pyspark.sql.column.Column

        """

        # caching disabled and not forced: hand back the expression unchanged
        if (self.mode is False) and (force is not True):
            return column

        # build a column name that does not clash with existing columns
        col_name = "{}_{}".format(name, len(self.columns))
        while col_name in self.df.columns:
            col_name += "_"

        self.columns[name] = col_name
        self.df = self.df.withColumn(col_name, column)

        return F.col(col_name)

    def finish(self, name, column) -> DataFrame:
        """Closes column cacher and returns dataframe representation with
        provided final result column. Intermediate columns will be dropped
        based on `mode`.

        Parameters
        ----------
        name: str
            Name of the final result column.
        column: pyspark.sql.column.Column
            Content of the final result column.

        Returns
        -------
        df: pyspark.sql.DataFrame
            Original dataframe with added column.

        """

        self.df = self.df.withColumn(name, column)

        # keep intermediate columns only when debugging
        if self.mode != "debug":
            self.df = self.df.drop(*self.columns.values())

        return self.df
184 | 185 | Parameters 186 | ---------- 187 | name: str 188 | Name of the final result column. 189 | column: pyspark.sql.column.Column 190 | Content of the final result column. 191 | 192 | Returns 193 | ------- 194 | df: pyspark.sql.DataFrame 195 | Original dataframe with added column. 196 | 197 | """ 198 | 199 | self.df = self.df.withColumn(name, column) 200 | 201 | if self.mode != "debug": 202 | self.df = self.df.drop(*self.columns.values()) 203 | 204 | return self.df 205 | -------------------------------------------------------------------------------- /src/pywrangler/pyspark/wranglers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/src/pywrangler/pyspark/wranglers/__init__.py -------------------------------------------------------------------------------- /src/pywrangler/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/src/pywrangler/util/__init__.py -------------------------------------------------------------------------------- /src/pywrangler/util/_pprint.py: -------------------------------------------------------------------------------- 1 | """This module contains helper functions for printing. 2 | 3 | """ 4 | 5 | import re 6 | import textwrap 7 | from typing import Any, List, Tuple, Union 8 | 9 | ITERABLE = Union[List[str], Tuple[str]] 10 | ENUM = Union[ITERABLE, dict] 11 | 12 | REGEX_REMOVE_WHITESPACES = re.compile(r"\s{2,}") 13 | 14 | 15 | def _join(lines: ITERABLE) -> str: 16 | """Join given lines. 17 | 18 | Parameters 19 | ---------- 20 | lines: list, tuple 21 | Iterable to join. 
22 | 23 | Returns 24 | ------- 25 | joined: str 26 | 27 | """ 28 | 29 | return "\n".join(lines) 30 | 31 | 32 | def _indent(lines: ITERABLE, indent: int = 3) -> list: 33 | """Indent given lines and optionally join. 34 | 35 | Parameters 36 | ---------- 37 | lines: list, tuple 38 | Iterable to indent. 39 | indent: int, optional 40 | Indentation count. 41 | 42 | """ 43 | 44 | spacing = " " * indent 45 | return [spacing + x for x in lines] 46 | 47 | 48 | def header(name: str, indent: int = 0, underline: str = "-") -> str: 49 | """Create columns with underline. 50 | 51 | Parameters 52 | ---------- 53 | name: str 54 | Name of title. 55 | indent: int, optional 56 | Indentation count. 57 | underline: str, optional 58 | Underline character. 59 | 60 | Returns 61 | ------- 62 | columns: str 63 | 64 | """ 65 | 66 | _indent = " " * indent 67 | 68 | _header = _indent + name 69 | _underline = _indent + underline * len(name) + "\n" 70 | 71 | return _join([_header, _underline]) 72 | 73 | 74 | def enumeration(values: ENUM, indent: int = 0, bullet_char: str = "-", 75 | align_values: bool = True, align_width: int = 0) -> str: 76 | """Create enumeration with bullet points. 77 | 78 | Parameters 79 | ---------- 80 | values: list, tuple, dict 81 | Iterable vales. If dict, creates key/value pairs.. 82 | indent: int, optional 83 | Indentation count. 84 | bullet_char: str, optional 85 | Bullet character. 86 | align_values: bool, optional 87 | If dict is provided, align all identifiers to the same column. The 88 | longest key defines the exact position. 89 | align_width: int, optional 90 | If dict is provided and `align_values` is True, manually set the align 91 | width. 
92 | 93 | Returns 94 | ------- 95 | enumeration: str 96 | 97 | """ 98 | 99 | if isinstance(values, dict): 100 | fstring = "{key:>{align_width}}: {value}" 101 | if align_values and not align_width: 102 | align_width = max([len(x) for x in values.keys()]) 103 | 104 | _values = [fstring.format(key=key, 105 | value=value, 106 | align_width=align_width) 107 | 108 | for key, value in sorted(values.items())] 109 | else: 110 | _values = values 111 | 112 | with_bullets = ["{} {}".format(bullet_char, x) for x in _values] 113 | indented = _indent(with_bullets, indent) 114 | 115 | return _join(indented) 116 | 117 | 118 | def pretty_file_size(size: float, precision: int = 2, align: str = ">", 119 | width: int = 0) -> str: 120 | """Helper function to format size in human readable format. 121 | 122 | Parameters 123 | ---------- 124 | size: float 125 | The size in bytes to be converted into human readable format. 126 | precision: int, optional 127 | Define shown precision. 128 | align: {'<', '^', '>'}, optional 129 | Format align specifier. 130 | width: int 131 | Define maximum width for number. 132 | 133 | Returns 134 | ------- 135 | human_fmt: str 136 | Human readable representation of given `size`. 
def pretty_file_size(size: float, precision: int = 2, align: str = ">",
                     width: int = 0) -> str:
    """Helper function to format size in human readable format.

    Parameters
    ----------
    size: float
        The size in bytes to be converted into human readable format.
    precision: int, optional
        Define shown precision.
    align: {'<', '^', '>'}, optional
        Format align specifier.
    width: int
        Define maximum width for number.

    Returns
    -------
    human_fmt: str
        Human readable representation of given `size`.

    Notes
    -----
    Credit to https://stackoverflow.com/questions/1094841/reusable-library-to-get-human-readable-version-of-file-size

    """  # noqa: E501

    template = "{size:{align}{width}.{precision}f} {unit}B"
    kwargs = dict(width=width, precision=precision, align=align)

    # iterate units (multiples of 1024 bytes)
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(size) < 1024.0:
            return template.format(size=size, unit=unit, **kwargs)
        size /= 1024.0

    return template.format(size=size, unit='Yi', **kwargs)


def pretty_time_duration(seconds: float, precision: int = 1, align: str = ">",
                         width: int = 0) -> str:
    """Helper function to format time duration in human readable format.

    Parameters
    ----------
    seconds: float
        The duration in seconds to be converted into human readable format.
    precision: int, optional
        Define shown precision.
    align: {'<', '^', '>'}, optional
        Format align specifier.
    width: int
        Define maximum width for number.

    Returns
    -------
    human_fmt: str
        Human readable representation of given `seconds`.

    """

    template = "{time_delta:{align}{width}.{precision}f} {unit}"

    units = [('year', 60 * 60 * 24 * 365),
             ('month', 60 * 60 * 24 * 30),
             ('d', 60 * 60 * 24),
             ('h', 60 * 60),
             ('min', 60),
             ('s', 1),
             ('ms', 1e-3),
             ('µs', 1e-6),
             ('ns', 1e-9)]

    # catch 0 value
    if seconds == 0:
        return template.format(time_delta=0,
                               align=align,
                               width=width,
                               precision=0,
                               unit="s")

    # catch negative value
    if seconds < 0:
        sign = -1
        seconds = abs(seconds)
    else:
        sign = 1

    for unit_name, unit_seconds in units:
        # `>=` (not `>`) so that exact boundaries render in the natural
        # unit, e.g. 1 second as "1.0 s" instead of "1000.0 ms"
        if seconds >= unit_seconds:
            time_delta = seconds / unit_seconds
            return template.format(time_delta=sign * time_delta,
                                   align=align,
                                   width=width,
                                   precision=precision,
                                   unit=unit_name)

    # durations below one nanosecond previously fell through the loop and
    # implicitly returned None; express them in nanoseconds instead
    return template.format(time_delta=sign * seconds / 1e-9,
                           align=align,
                           width=width,
                           precision=precision,
                           unit="ns")


def textwrap_docstring(dobject: Any, width: int = 70) -> List[str]:
    """Extract doc string from object and textwrap it with given width.
    Remove double whitespaces.

    Parameters
    ----------
    dobject: Any
        Object to extract doc string from.
    width: int, optional
        Length of text values to wrap doc string.

    Returns
    -------
    Wrapped doc string as list of lines.

    """

    if not dobject.__doc__:
        return []

    # collapse runs of 2+ whitespace characters into a single space
    sanitized = re.sub(r"\s{2,}", " ", dobject.__doc__).strip()
    return textwrap.wrap(sanitized, width=width)


def truncate(string: str, width: int, ending: str = "...") -> str:
    """Truncate string to be no longer than provided width. When truncated,
    add `ending` to shortened string as indication of truncation.

    Parameters
    ----------
    string: str
        String to be truncated.
    width: int
        Maximum amount of characters before truncation.
    ending: str, optional
        Indication string of truncation.

    Returns
    -------
    Truncated string.

    """

    if len(string) <= width:
        return string

    # clamp at 0 so a width smaller than `ending` cannot produce a negative
    # slice, which previously returned a string LONGER than `width`
    length = max(width - len(ending), 0)
    return string[:length] + ending
def raise_if_missing(import_name):
    """Checks for available import and raises with more detailed error
    message if not given.

    Parameters
    ----------
    import_name: str
        Name of the module to check.

    Raises
    ------
    ImportError
        If the module is not importable; the original exception is chained
        via `from` for full context.

    """
    try:
        importlib.import_module(import_name)

    except ImportError as e:
        msg = ("The requested functionality requires '{dep}'. "
               "However, '{dep}' is not available in the current "
               "environment with the following interpreter: "
               "'{interpreter}'. Please install '{dep}' first.\n\n"
               .format(dep=import_name, interpreter=sys.executable))

        # re-raise with the same exception type so callers catching e.g.
        # ModuleNotFoundError still work
        raise type(e)(msg) from e


def requires(*deps: str) -> Callable:
    """Decorator for callables to ensure that required dependencies are met.
    Provides more useful error message if dependency is missing.

    Parameters
    ----------
    deps: list
        List of dependencies to check.

    Returns
    -------
    decorated: callable

    Examples
    --------

    >>> @requires("dep1", "dep2")
    ... def func(a):
    ...     return a

    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # fail fast with a descriptive message before executing the
            # wrapped callable
            for dep in deps:
                raise_if_missing(dep)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def is_available(*deps: str) -> bool:
    """Check if given dependencies are available.

    Parameters
    ----------
    deps: list
        List of dependencies to check.

    Returns
    -------
    available: bool

    """

    for dep in deps:
        try:
            importlib.import_module(dep)
        except ImportError:
            return False

    return True
ITER_TYPE = Optional[Union[List[Any], Tuple[Any]]]


# TODO: Use generic for sequence type

def ensure_iterable(values: Any, seq_type: Type = list,
                    retain_none: bool = False) -> ITER_TYPE:
    """For convenience, some parameters may accept a single value (string
    for a column name) or multiple values (list of strings for column
    names). Other functions always require a list or tuple of strings. Hence,
    this function ensures that the output is always an iterable of given
    `seq_type` (list or tuple) while taking care of exceptions like
    strings.

    Parameters
    ----------
    values: Any
        Input values to be converted to an iterable of `seq_type`.
    seq_type: type
        Define return container type.
    retain_none: bool, optional
        Define behaviour if None is passed. If True, returns None. If False,
        returns an empty container of `seq_type`.

    Returns
    -------
    iterable: seq_type

    """

    # None is either retained or mapped to an empty container
    if values is None:
        return None if retain_none else seq_type()

    # strings and dataframes are iterable but still count as single values
    if isinstance(values, (str, pd.DataFrame)):
        return seq_type([values])

    # non-iterable scalars are wrapped into a single-element container
    if not isinstance(values, collections.abc.Iterable):
        return seq_type([values])

    # anything else should be ok to be converted to tuple/list
    return seq_type(values)
def concretize_abstract_wrangler(abstract_class: Type) -> Type:
    """Makes abstract wrangler classes instantiable for testing purposes by
    implementing abstract methods of `BaseWrangler`.

    Every abstract member is overridden with a thin delegation to the parent
    implementation, so the returned class is constructible while keeping the
    original behaviour untouched.

    Parameters
    ----------
    abstract_class: Type
        Class object to inherit from while overriding abstract methods.

    Returns
    -------
    concrete_class: Type
        Concrete class usable for testing.

    """

    class ConcreteWrangler(abstract_class):

        @property
        def preserves_sample_size(self):
            return super().preserves_sample_size

        @property
        def computation_engine(self):
            return super().computation_engine

    def _make_delegator(method_name):
        # forward the call to the (abstract) parent implementation
        def delegator(self, *args, **kwargs):
            parent = super(ConcreteWrangler, self)
            return getattr(parent, method_name)(*args, **kwargs)

        delegator.__name__ = method_name
        return delegator

    for method_name in ("fit", "transform", "fit_transform"):
        setattr(ConcreteWrangler, method_name, _make_delegator(method_name))

    # mirror identity of the wrapped class for readable test output
    ConcreteWrangler.__name__ = abstract_class.__name__
    ConcreteWrangler.__doc__ = abstract_class.__doc__

    return ConcreteWrangler
17 | 18 | An interval is defined as a range of values beginning with an opening 19 | marker and ending with a closing marker (e.g. the interval daylight may be 20 | defined as all events/values occurring between sunrise and sunset). Start 21 | and end marker may be identical. 22 | 23 | The interval identification wrangler assigns ids to values such that values 24 | belonging to the same interval share the same interval id. For example, all 25 | values of the first daylight interval are assigned with id 1. All values of 26 | the second daylight interval will be assigned with id 2 and so on. 27 | 28 | By default, values which do not belong to any valid interval, are assigned 29 | the value 0 by definition (please refer to `result_type` for different 30 | result types). If start and end marker are identical or the end marker is 31 | not provided, invalid values are only possible before the first start 32 | marker is encountered. 33 | 34 | Due to messy data, start and end marker may occur multiple times in 35 | sequence until its counterpart is reached. Therefore, intervals may have 36 | different spans based on different task requirements. For example, the very 37 | first start or very last start marker may define the correct start of an 38 | interval. Accordingly, four intervals can be selected by setting 39 | `marker_start_use_first` and `marker_end_use_first`. The resulting 40 | intervals are as follows: 41 | 42 | - first start / first end 43 | - first start / last end (longest interval) 44 | - last start / first end (shortest interval) 45 | - last start / last end 46 | 47 | Opening and closing markers are included in their corresponding interval. 48 | 49 | Parameters 50 | ---------- 51 | marker_column: str 52 | Name of column which contains the opening and closing markers. 53 | marker_start: Any 54 | A value defining the start of an interval. 55 | marker_end: Any, optional 56 | A value defining the end of an interval. This value is optional. 
If not 57 | given, the end marker equals the start marker. 58 | marker_start_use_first: bool 59 | Identifies if the first occurring `marker_start` of an interval is used. 60 | Otherwise the last occurring `marker_start` is used. Default is False. 61 | marker_end_use_first: bool 62 | Identifies if the first occurring `marker_end` of an interval is used. 63 | Otherwise the last occurring `marker_end` is used. Default is True. 64 | orderby_columns: str, Iterable[str], optional 65 | Column names which define the order of the data (e.g. a timestamp 66 | column). Sort order can be defined with the parameter `ascending`. 67 | groupby_columns: str, Iterable[str], optional 68 | Column names which define how the data should be grouped/split into 69 | separate entities. For distributed computation engines, groupby columns 70 | should ideally reference partition keys to avoid data shuffling. 71 | ascending: bool, Iterable[bool], optional 72 | Sort ascending vs. descending. Specify list for multiple sort orders. 73 | If a list is specified, length of the list must equal length of 74 | `order_columns`. Default is True. 75 | result_type: str, optional 76 | Defines the content of the returned result. If 'raw', interval ids 77 | will be in arbitrary order with no distinction made between valid and 78 | invalid intervals. Intervals are distinguishable by interval id but the 79 | interval id may not provide any more information. If 'valid', the 80 | result is the same as 'raw' but all invalid intervals are set to 0. 81 | If 'enumerated', the result is the same as 'valid' but interval ids 82 | increase in ascending order (as defined by order) in steps of one. 83 | target_column_name: str, optional 84 | Name of the resulting target column. 
85 | 86 | """ 87 | 88 | def __init__(self, 89 | marker_column: str, 90 | marker_start: Any, 91 | marker_end: Any = NONEVALUE, 92 | marker_start_use_first: bool = False, 93 | marker_end_use_first: bool = True, 94 | orderby_columns: TYPE_COLUMNS = None, 95 | groupby_columns: TYPE_COLUMNS = None, 96 | ascending: TYPE_ASCENDING = None, 97 | result_type: str = "enumerated", 98 | target_column_name: str = "iids"): 99 | 100 | self.marker_column = marker_column 101 | self.marker_start = marker_start 102 | self.marker_end = marker_end 103 | self.marker_start_use_first = marker_start_use_first 104 | self.marker_end_use_first = marker_end_use_first 105 | self.orderby_columns = sanitizer.ensure_iterable(orderby_columns) 106 | self.groupby_columns = sanitizer.ensure_iterable(groupby_columns) 107 | self.ascending = sanitizer.ensure_iterable(ascending) 108 | self.result_type = result_type 109 | self.target_column_name = target_column_name 110 | 111 | # check correct result type 112 | valid_result_types = {"raw", "valid", "enumerated"} 113 | if result_type not in valid_result_types: 114 | raise ValueError("Parameter `result_type` is invalid with: {}. 
" 115 | "Allowed arguments are: {}" 116 | .format(result_type, valid_result_types)) 117 | 118 | # check for identical start and end values 119 | self._identical_start_end_markers = ((marker_end == NONEVALUE) or 120 | (marker_start == marker_end)) 121 | 122 | # sanity checks for sort order 123 | if self.ascending: 124 | 125 | # check for equal number of items of order and sort columns 126 | if len(self.orderby_columns) != len(self.ascending): 127 | raise ValueError('`order_columns` and `ascending` must have ' 128 | 'equal number of items.') 129 | 130 | # check for correct sorting keywords 131 | if not all([isinstance(x, bool) for x in self.ascending]): 132 | raise ValueError('Only `True` and `False` are ' 133 | 'allowed arguments for `ascending`') 134 | 135 | # set default sort order if None is given 136 | elif self.orderby_columns: 137 | self.ascending = [True] * len(self.orderby_columns) 138 | 139 | @property 140 | def preserves_sample_size(self) -> bool: 141 | return True 142 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """pytest configuration 2 | 3 | """ 4 | 5 | import multiprocessing 6 | 7 | import pytest 8 | 9 | import pandas as pd 10 | 11 | 12 | def patch_spark_create_dataframe(): 13 | """Overwrite pyspark's default `SparkSession.createDataFrame` method to 14 | cache all test data. This has proven to be faster because identical data 15 | does not need to be converted multiple times. 
def patch_spark_create_dataframe():
    """Overwrite pyspark's default `SparkSession.createDataFrame` method to
    cache all test data. This has proven to be faster because identical data
    does not need to be converted multiple times. Out of memory should not
    occur because test data is very small and if a memory limit is reached,
    the oldest not-used dataframes should be dropped automatically.

    """

    from pyspark.sql.session import SparkSession

    cache = {}

    def wrapper(func):
        def wrapped(self, data, *args, schema=None, **kwargs):
            # create hashable key; pandas frames are keyed by their column
            # names plus raw buffer bytes, everything else by its repr
            if isinstance(data, pd.DataFrame):
                key = tuple(data.columns), data.values.tobytes()
            else:
                key = str(data), schema

            # check existent result and return cached dataframe
            if key in cache:
                return cache[key]
            else:
                result = func(self, data, *args, schema=schema, **kwargs)
                result.cache()
                cache[key] = result
                return result

        return wrapped

    SparkSession.createDataFrame = wrapper(SparkSession.createDataFrame)


@pytest.fixture(scope="session")
def spark(request):
    """Provide session wide Spark Session to avoid expensive recreation for
    each test.

    If pyspark is not available, skip tests.

    """

    try:
        patch_spark_create_dataframe()

        from pyspark.sql import SparkSession
        spark = SparkSession.builder.getOrCreate()

        # use pyarrow if available for pandas to pyspark communication;
        # NOTE: the key must use the `spark.` prefix - the previous
        # `pyspark.sql.execution.arrow.enabled` key was silently ignored
        spark.conf.set("spark.sql.execution.arrow.enabled", "true")

        # for testing, reduce the number of partitions to the number of cores
        cpu_count = multiprocessing.cpu_count()
        spark.conf.set("spark.sql.shuffle.partitions", cpu_count)

        # print pyspark ui url
        print("\nPySpark UiWebUrl:", spark.sparkContext.uiWebUrl, "\n")

        request.addfinalizer(spark.stop)
        return spark

    except ImportError:
        pytest.skip("Pyspark not available.")
@pytest.fixture
def mean_wranger():
    # NOTE(review): fixture name contains a typo ("wranger") but is kept
    # because dependent tests request it by this exact name.
    class DummyWrangler(DaskSingleNoFit):
        def transform(self, df):
            return df.mean()

    return concretize_abstract_wrangler(DummyWrangler)()


@pytest.fixture
def test_wrangler():
    """Helper fixture to generate DaskWrangler instances with parametrization
    of transform output and sleep.

    """

    def create_wrangler(size=None, result=None, sleep=0):
        """Return instance of DaskWrangler.

        Parameters
        ----------
        size: float
            Memory size in MiB to allocate during transform step.
        result: Dask DataFrame
            Define exact return value of transform step.
        sleep: float
            Define sleep interval.

        """

        class DummyWrangler(DaskSingleNoFit):
            def transform(self, df):
                if size is not None:
                    pdf = pd.DataFrame(allocate_memory(size))
                    # `npartitions` is mandatory for `dd.from_pandas`;
                    # omitting it raised a TypeError at runtime
                    df_out = dd.from_pandas(pdf, npartitions=1)
                elif result is not None:
                    df_out = result
                else:
                    df_out = dd.from_pandas(pd.DataFrame([0]), 1)

                time.sleep(sleep)
                return df_out

        return concretize_abstract_wrangler(DummyWrangler)()

    return create_wrangler
def test_dask_time_profiler_profile_return_self(test_wrangler):
    """`profile` must return the profiler instance itself (fluent API)."""

    ddf = dd.from_pandas(pd.DataFrame(np.random.rand(10, 10)), 2)

    profiler = DaskTimeProfiler(wrangler=test_wrangler(), repetitions=1)

    assert profiler.profile(ddf) is profiler


def test_dask_time_profiler_cached_faster(mean_wranger):
    """Persisting (caching) the input should yield a faster median run."""

    frame = pd.DataFrame(np.random.rand(1000000, 10))
    graph = dd.from_pandas(frame, 2).mean()

    uncached_profiler = DaskTimeProfiler(wrangler=mean_wranger,
                                         repetitions=5,
                                         cache_input=False)
    cached_profiler = DaskTimeProfiler(wrangler=mean_wranger,
                                       repetitions=5,
                                       cache_input=True)

    uncached_median = uncached_profiler.profile(graph).median
    cached_median = cached_profiler.profile(graph).median

    assert uncached_median > cached_median


def test_dask_memory_profiler_profile_return_self(test_wrangler):
    """`profile` must return the profiler itself and record a single run."""

    ddf = dd.from_pandas(pd.DataFrame(np.random.rand(10, 10)), 2)

    profiler = DaskMemoryProfiler(wrangler=test_wrangler(), repetitions=1)

    assert profiler.profile(ddf) is profiler
    assert profiler.runs == 1
def test_pandas_base_wrangler_engine():
    """The pandas base wrangler reports 'pandas' as computation engine."""

    wrangler = concretize_abstract_wrangler(PandasWrangler)()

    assert wrangler.computation_engine == "pandas"


@pytest.mark.parametrize("preserves_sample_size", [True, False])
def test_pandas_wrangler_validate_output_shape_raises(preserves_sample_size):
    """Shape validation raises only when the sample size must be preserved."""

    class DummyWrangler(PandasWrangler):
        @property
        def preserves_sample_size(self):
            return preserves_sample_size

    wrangler = concretize_abstract_wrangler(DummyWrangler)()

    df_small = pd.DataFrame([0] * 10)
    df_large = pd.DataFrame([0] * 20)

    if preserves_sample_size:
        with pytest.raises(ValueError):
            wrangler._validate_output_shape(df_small, df_large)
    else:
        wrangler._validate_output_shape(df_small, df_large)
-------------------------------------------------------------------------------- /tests/pandas/test_benchmark.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for pandas benchmarks. 2 | 3 | """ 4 | 5 | import time 6 | 7 | import pytest 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from pywrangler.benchmark import allocate_memory 13 | from pywrangler.pandas.base import PandasSingleNoFit 14 | from pywrangler.pandas.benchmark import ( 15 | PandasMemoryProfiler, 16 | PandasTimeProfiler 17 | ) 18 | from pywrangler.util.testing.util import concretize_abstract_wrangler 19 | 20 | pytestmark = pytest.mark.pandas 21 | 22 | MIB = 2 ** 20 23 | 24 | 25 | @pytest.fixture 26 | def test_wrangler(): 27 | """Helper fixture to generate PandasWrangler instances with parametrization 28 | of transform output and sleep. 29 | 30 | """ 31 | 32 | def create_wrangler(size=None, result=None, sleep=0): 33 | """Return instance of PandasWrangler. 34 | 35 | Parameters 36 | ---------- 37 | size: float 38 | Memory size in MiB to allocate during transform step. 39 | result: pd.DataFrame 40 | Define exact return value of transform step. 41 | sleep: float 42 | Define sleep interval.
43 | 44 | """ 45 | 46 | class DummyWrangler(PandasSingleNoFit): 47 | def transform(self, df): 48 | if size is not None: 49 | df_out = pd.DataFrame(allocate_memory(size)) 50 | else: 51 | df_out = pd.DataFrame(result) 52 | 53 | time.sleep(sleep) 54 | return df_out 55 | 56 | return concretize_abstract_wrangler(DummyWrangler)() 57 | 58 | return create_wrangler 59 | 60 | 61 | def test_pandas_memory_profiler_memory_usage_dfs(): 62 | df1 = pd.DataFrame(np.random.rand(10)) 63 | df2 = pd.DataFrame(np.random.rand(10)) 64 | 65 | test_input = [df1, df2] 66 | test_output = int(df1.memory_usage(index=True, deep=True).sum() + 67 | df2.memory_usage(index=True, deep=True).sum()) 68 | 69 | assert PandasMemoryProfiler._memory_usage_dfs(*test_input) == test_output 70 | 71 | 72 | def test_pandas_memory_profiler_return_self(test_wrangler): 73 | memory_profiler = PandasMemoryProfiler(test_wrangler()) 74 | 75 | assert memory_profiler is memory_profiler.profile(pd.DataFrame()) 76 | 77 | 78 | @pytest.mark.xfail(reason="Succeeds locally but sometimes fails remotely due " 79 | "to non deterministic memory management.") 80 | def test_pandas_memory_profiler_usage_median(test_wrangler): 81 | wrangler = test_wrangler(size=30, sleep=0.01) 82 | memory_profiler = PandasMemoryProfiler(wrangler) 83 | 84 | assert memory_profiler.profile(pd.DataFrame()).median > 29 * MIB 85 | 86 | 87 | def test_pandas_memory_profiler_usage_input_output(test_wrangler): 88 | df_input = pd.DataFrame(np.random.rand(1000)) 89 | df_output = pd.DataFrame(np.random.rand(10000)) 90 | 91 | test_df_input = df_input.memory_usage(index=True, deep=True).sum() 92 | test_df_output = df_output.memory_usage(index=True, deep=True).sum() 93 | 94 | wrangler = test_wrangler(result=df_output) 95 | memory_profiler = PandasMemoryProfiler(wrangler).profile(df_input) 96 | 97 | assert memory_profiler.input == test_df_input 98 | assert memory_profiler.output == test_df_output 99 | 100 | 101 | @pytest.mark.xfail(reason="Succeeds locally but 
sometimes fails remotely due " 102 | "to non deterministic memory management.") 103 | def test_pandas_memory_profiler_ratio(test_wrangler): 104 | usage_mib = 30 105 | df_input = pd.DataFrame(np.random.rand(1000000)) 106 | usage_input = df_input.memory_usage(index=True, deep=True).sum() 107 | test_output = ((usage_mib - 1) * MIB) / usage_input 108 | 109 | wrangler = test_wrangler(size=usage_mib, sleep=0.01) 110 | 111 | memory_profiler = PandasMemoryProfiler(wrangler) 112 | 113 | assert memory_profiler.profile(df_input).ratio > test_output 114 | 115 | 116 | def test_pandas_time_profiler_best(test_wrangler): 117 | """Basic test for pandas time profiler ensuring fastest timing is slower 118 | than forced sleep. 119 | 120 | """ 121 | 122 | sleep = 0.0001 123 | wrangler = test_wrangler(sleep=sleep) 124 | time_profiler = PandasTimeProfiler(wrangler, 1).profile(pd.DataFrame()) 125 | 126 | assert time_profiler.best >= sleep 127 | -------------------------------------------------------------------------------- /tests/pandas/test_util.py: -------------------------------------------------------------------------------- 1 | """This module contains pandas wrangler utility tests. 
2 | 3 | """ 4 | 5 | import pytest 6 | 7 | import pandas as pd 8 | 9 | from pywrangler.pandas import util 10 | 11 | 12 | def test_validate_empty_df_raises(): 13 | df = pd.DataFrame() 14 | 15 | with pytest.raises(ValueError): 16 | util.validate_empty_df(df) 17 | 18 | 19 | def test_validate_empty_df_not_raises(): 20 | df = pd.DataFrame([0, 0]) 21 | 22 | util.validate_empty_df(df) 23 | 24 | 25 | def test_validate_columns_raises(): 26 | df = pd.DataFrame(columns=["col1", "col2"]) 27 | 28 | with pytest.raises(ValueError): 29 | util.validate_columns(df, ("col3", "col1")) 30 | 31 | 32 | def test_validate_columns_not_raises(): 33 | df = pd.DataFrame(columns=["col1", "col2"]) 34 | 35 | util.validate_columns(df, ("col1", "col2")) 36 | 37 | 38 | def test_sort_values(): 39 | values = list(range(10)) 40 | df = pd.DataFrame({"col1": values, 41 | "col2": values}) 42 | 43 | # no sort order given 44 | assert df is util.sort_values(df, [], []) 45 | 46 | # sort order 47 | computed = util.sort_values(df, ["col1"], [False]) 48 | assert df.sort_values("col1", ascending=False).equals(computed) 49 | 50 | 51 | def test_groupby(): 52 | values = list(range(10)) 53 | df = pd.DataFrame({"col1": values, 54 | "col2": values}) 55 | 56 | conv = lambda x: {key: value.values.tolist() 57 | for key, value in x.groups.items()} 58 | 59 | # no groupby given 60 | expected = df.groupby([0]*len(values)) 61 | given = util.groupby(df, []) 62 | assert conv(expected) == conv(given) 63 | 64 | # with groupby 65 | expected = df.groupby("col1") 66 | given = util.groupby(df, ["col1"]) 67 | assert conv(expected) == conv(given) 68 | 69 | -------------------------------------------------------------------------------- /tests/pandas/wranglers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/pandas/wranglers/__init__.py 
-------------------------------------------------------------------------------- /tests/pandas/wranglers/test_interval_identifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | 4 | from tests.test_data.interval_identifier import ( 5 | CollectionGeneral, 6 | CollectionIdenticalStartEnd, 7 | CollectionMarkerSpecifics, 8 | ResultTypeRawIids, 9 | ResultTypeValidIids, 10 | CollectionNoOrderGroupBy) 11 | 12 | from pywrangler.pandas.wranglers.interval_identifier import ( 13 | NaiveIterator, 14 | VectorizedCumSum 15 | ) 16 | 17 | pytestmark = pytest.mark.pandas 18 | 19 | WRANGLER = (NaiveIterator, VectorizedCumSum) 20 | WRANGLER_IDS = [x.__name__ for x in WRANGLER] 21 | WRANGLER_KWARGS = dict(argnames='wrangler', 22 | argvalues=WRANGLER, 23 | ids=WRANGLER_IDS) 24 | 25 | 26 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 27 | @CollectionGeneral.pytest_parametrize_kwargs("marker_use") 28 | @CollectionGeneral.pytest_parametrize_testcases 29 | def test_base(testcase, wrangler, marker_use): 30 | """Tests against all available wranglers and test cases. 31 | Parameters 32 | ---------- 33 | testcase: DataTestCase 34 | Generates test data for given test case. 35 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 36 | Refers to the actual wrangler_instance being tested. See `WRANGLER`. 37 | marker_use: dict 38 | Defines the marker start/end use.
39 | """ 40 | 41 | # instantiate test case 42 | testcase_instance = testcase("pandas") 43 | 44 | # instantiate wrangler 45 | kwargs = testcase_instance.test_kwargs.copy() 46 | kwargs.update(marker_use) 47 | wrangler_instance = wrangler(**kwargs) 48 | 49 | # pass wrangler to test case 50 | kwargs = dict(merge_input=True, 51 | force_dtypes={"marker": testcase_instance.marker_dtype}) 52 | testcase_instance.test(wrangler_instance.transform, **kwargs) 53 | 54 | 55 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 56 | @CollectionIdenticalStartEnd.pytest_parametrize_testcases 57 | def test_identical_start_end(testcase, wrangler): 58 | """Tests against all available wranglers and test cases. 59 | Parameters 60 | ---------- 61 | testcase: DataTestCase 62 | Generates test data for given test case. 63 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 64 | Refers to the actual wrangler_instance being tested. See `WRANGLER`. 65 | """ 66 | 67 | # instantiate test case 68 | testcase_instance = testcase("pandas") 69 | 70 | # instantiate wrangler 71 | wrangler_instance = wrangler(**testcase_instance.test_kwargs) 72 | 73 | # pass wrangler to test case 74 | kwargs = dict(merge_input=True, 75 | force_dtypes={"marker": testcase_instance.marker_dtype}) 76 | testcase_instance.test(wrangler_instance.transform, **kwargs) 77 | 78 | 79 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 80 | @CollectionMarkerSpecifics.pytest_parametrize_testcases 81 | def test_marker_specifics(testcase, wrangler): 82 | """Tests specific `marker_start_use_first` and `marker_end_use_first` 83 | scenarios. 84 | 85 | Parameters 86 | ---------- 87 | testcase: DataTestCase 88 | Generates test data for given test case. 89 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 90 | Refers to the actual wrangler_instance being tested. See `WRANGLER`.
91 | 92 | """ 93 | 94 | # instantiate test case 95 | testcase_instance = testcase("pandas") 96 | 97 | # instantiate wrangler 98 | wrangler_instance = wrangler(**testcase_instance.test_kwargs) 99 | 100 | # pass wrangler to test case 101 | kwargs = dict(merge_input=True, 102 | force_dtypes={"marker": testcase_instance.marker_dtype}) 103 | testcase_instance.test(wrangler_instance.transform, **kwargs) 104 | 105 | 106 | @CollectionGeneral.pytest_parametrize_kwargs("marker_use") 107 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 108 | def test_result_type_raw_iids(wrangler, marker_use): 109 | """Test for correct raw iids constraints. Returned result only needs to 110 | distinguish intervals regardless of their validity. Interval ids do not 111 | need to be in specific order. 112 | 113 | Parameters 114 | ---------- 115 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 116 | Refers to the actual wrangler_instance being tested. See `WRANGLER`. 117 | marker_use: dict 118 | Contains `marker_start_use_first` and `marker_end_use_first` parameters 119 | as dict. 120 | 121 | """ 122 | 123 | testcase_instance = ResultTypeRawIids("pandas") 124 | kwargs = testcase_instance.test_kwargs.copy() 125 | kwargs.update(marker_use) 126 | 127 | wrangler_instance = wrangler(result_type="raw", **kwargs) 128 | 129 | df_input = testcase_instance.input.to_pandas() 130 | df_output = testcase_instance.output.to_pandas() 131 | df_result = wrangler_instance.transform(df_input) 132 | 133 | col = testcase_instance.target_column_name 134 | pd.testing.assert_series_equal(df_result[col].diff().ne(0), 135 | df_output[col].diff().ne(0)) 136 | 137 | 138 | @CollectionGeneral.pytest_parametrize_kwargs("marker_use") 139 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 140 | def test_result_type_valid_iids(wrangler, marker_use): 141 | """Test for correct valid iids constraints. Returned result needs to 142 | distinguish valid from invalid intervals. Invalid intervals need to be 0.
143 | 144 | Parameters 145 | ---------- 146 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 147 | Refers to the actual wrangler_instance being tested. See `WRANGLER`. 148 | marker_use: dict 149 | Contains `marker_start_use_first` and `marker_end_use_first` parameters 150 | as dict. 151 | 152 | """ 153 | 154 | testcase_instance = ResultTypeValidIids("pandas") 155 | kwargs = testcase_instance.test_kwargs.copy() 156 | kwargs.update(marker_use) 157 | 158 | wrangler_instance = wrangler(result_type="valid", **kwargs) 159 | 160 | df_input = testcase_instance.input.to_pandas() 161 | df_output = testcase_instance.output.to_pandas() 162 | df_result = wrangler_instance.transform(df_input) 163 | 164 | col = testcase_instance.target_column_name 165 | pd.testing.assert_series_equal(df_result[col].diff().ne(0), 166 | df_output[col].diff().ne(0)) 167 | 168 | pd.testing.assert_series_equal(df_result[col].eq(0), 169 | df_output[col].eq(0)) 170 | 171 | 172 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 173 | @CollectionNoOrderGroupBy.pytest_parametrize_kwargs("missing_order_group_by") 174 | @CollectionNoOrderGroupBy.pytest_parametrize_testcases 175 | def test_no_order_groupby(testcase, missing_order_group_by, wrangler): 176 | """Tests correct behaviour for missing groupby columns. 177 | 178 | Parameters 179 | ---------- 180 | testcase: DataTestCase 181 | Generates test data for given test case. 182 | missing_order_group_by: dict 183 | Defines `orderby_columns` and `groupby_columns`. 184 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 185 | Refers to the actual wrangler_instance being tested. See `WRANGLER`.
186 | 187 | """ 188 | 189 | # instantiate test case 190 | testcase_instance = testcase("pandas") 191 | 192 | # instantiate wrangler 193 | wrangler_kwargs = testcase_instance.test_kwargs.copy() 194 | wrangler_kwargs.update(missing_order_group_by) 195 | wrangler_instance = wrangler(**wrangler_kwargs) 196 | 197 | # pass wrangler to test case 198 | kwargs = dict(merge_input=True, 199 | force_dtypes={"marker": testcase_instance.marker_dtype}) 200 | testcase_instance.test(wrangler_instance.transform, **kwargs) 201 | -------------------------------------------------------------------------------- /tests/pyspark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/pyspark/__init__.py -------------------------------------------------------------------------------- /tests/pyspark/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/pyspark/conftest.py -------------------------------------------------------------------------------- /tests/pyspark/test_base.py: -------------------------------------------------------------------------------- 1 | """Test pyspark base wrangler. 
2 | 3 | isort:skip_file 4 | """ 5 | 6 | import pytest 7 | 8 | from pywrangler.util.testing.util import concretize_abstract_wrangler 9 | 10 | pytestmark = pytest.mark.pyspark # noqa: E402 11 | pyspark = pytest.importorskip("pyspark") # noqa: E402 12 | 13 | from pywrangler.pyspark.base import PySparkWrangler 14 | 15 | 16 | def test_spark_base_wrangler_engine(): 17 | wrangler = concretize_abstract_wrangler(PySparkWrangler)() 18 | 19 | assert wrangler.computation_engine == "pyspark" 20 | -------------------------------------------------------------------------------- /tests/pyspark/test_benchmark.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for pyspark benchmarks. 2 | 3 | isort:skip_file 4 | """ 5 | 6 | import time 7 | 8 | import pytest 9 | 10 | from pywrangler.util.testing.util import concretize_abstract_wrangler 11 | 12 | pytestmark = pytest.mark.pyspark # noqa: E402 13 | pyspark = pytest.importorskip("pyspark") # noqa: E402 14 | 15 | from pywrangler.pyspark.base import PySparkSingleNoFit 16 | from pywrangler.pyspark.benchmark import PySparkTimeProfiler, \ 17 | PySparkBaseProfiler 18 | 19 | SLEEP = 0.0001 20 | 21 | 22 | @pytest.fixture 23 | def wrangler_sleeps(): 24 | class DummyWrangler(PySparkSingleNoFit): 25 | def transform(self, df): 26 | time.sleep(SLEEP) 27 | return df 28 | 29 | return concretize_abstract_wrangler(DummyWrangler) 30 | 31 | 32 | def test_spark_time_profiler_fastest(spark, wrangler_sleeps): 33 | """Basic test for pyspark time profiler ensuring fastest timing is slower 34 | than forced sleep. 
35 | 36 | """ 37 | 38 | df_input = spark.range(10).toDF("col") 39 | 40 | time_profiler = PySparkTimeProfiler(wrangler_sleeps(), 1).profile(df_input) 41 | 42 | assert time_profiler.best >= SLEEP 43 | 44 | 45 | def test_spark_time_profiler_no_caching(spark, wrangler_sleeps): 46 | df_input = spark.range(10).toDF("col") 47 | 48 | PySparkTimeProfiler(wrangler_sleeps(), 1).profile(df_input) 49 | 50 | assert df_input.is_cached is False 51 | 52 | 53 | def test_spark_time_profiler_caching(spark, wrangler_sleeps): 54 | """Cache is released after profiling.""" 55 | df_input = spark.range(10).toDF("col") 56 | 57 | PySparkTimeProfiler(wrangler_sleeps(), 1, cache_input=True)\ 58 | .profile(df_input) 59 | 60 | assert df_input.is_cached is False 61 | 62 | 63 | def test_spark_base_profiler_cache_input(spark): 64 | df = spark.range(10).toDF("col") 65 | 66 | PySparkBaseProfiler._cache_input([df]) 67 | assert df.is_cached is True 68 | 69 | PySparkBaseProfiler._clear_cached_input([df]) 70 | assert df.is_cached is False 71 | -------------------------------------------------------------------------------- /tests/pyspark/test_environment.py: -------------------------------------------------------------------------------- 1 | """Check for working test environment. 2 | 3 | """ 4 | 5 | import os 6 | import subprocess 7 | 8 | import pytest 9 | 10 | pytestmark = pytest.mark.pyspark 11 | 12 | 13 | def test_java_environment(): 14 | """Pyspark requires Java to be available. It uses Py4J to start and 15 | communicate with the JVM. Py4J looks for JAVA_HOME or falls back calling 16 | java directly. This test explicitly checks for the java prerequisites for 17 | pyspark to work correctly. If errors occur regarding the instantiation of 18 | a pyspark session, this test helps to rule out potential java related 19 | causes. 
20 | 21 | """ 22 | 23 | java_home = os.environ.get("JAVA_HOME") 24 | 25 | java_version = subprocess.run(["java", "-version"], 26 | stdout=subprocess.PIPE, 27 | stderr=subprocess.STDOUT, 28 | universal_newlines=True) 29 | 30 | if (java_home is None) and (java_version.returncode != 0): 31 | raise EnvironmentError("Java setup broken.") 32 | 33 | 34 | def test_pyspark_import(): 35 | """Fail if pyspark can't be imported. This test is mandatory because other 36 | pyspark tests will be skipped if the pyspark session fixture fails. 37 | 38 | """ 39 | 40 | try: 41 | import pyspark 42 | print(pyspark.__version__) 43 | except (ImportError, ModuleNotFoundError): 44 | pytest.fail("pyspark can't be imported") 45 | 46 | 47 | def test_pyspark_pandas_interaction(spark): 48 | """Check simple interaction between pyspark and pandas. 49 | 50 | """ 51 | 52 | import pandas as pd 53 | import numpy as np 54 | 55 | df_pandas = pd.DataFrame(np.random.rand(10, 2), columns=["a", "b"]) 56 | df_spark = spark.createDataFrame(df_pandas) 57 | df_converted = df_spark.toPandas() 58 | 59 | pd.testing.assert_frame_equal(df_pandas, df_converted) 60 | -------------------------------------------------------------------------------- /tests/pyspark/test_pipeline.py: -------------------------------------------------------------------------------- 1 | """This module tests the customized pyspark pipeline. 
2 | 3 | isort:skip_file 4 | """ 5 | 6 | import pytest 7 | 8 | from pywrangler.util.testing.util import concretize_abstract_wrangler 9 | 10 | pytestmark = pytest.mark.pyspark # noqa: E402 11 | pyspark = pytest.importorskip("pyspark") # noqa: E402 12 | 13 | from pyspark.sql import functions as F 14 | from pywrangler.pyspark import pipeline 15 | from pywrangler.pyspark.pipeline import StageTransformerConverter 16 | from pywrangler.pyspark.base import PySparkSingleNoFit 17 | from pyspark.ml.param.shared import Param 18 | from pyspark.ml import Transformer 19 | 20 | 21 | @pytest.fixture 22 | def pipe(): 23 | """Create example pipeline 24 | 25 | """ 26 | 27 | def add_1(df, a=2): 28 | return df.withColumn("add1", F.col("value") + a) 29 | 30 | def add_2(df, b=4): 31 | return df.withColumn("add2", F.col("value") + b) 32 | 33 | return pipeline.Pipeline([add_1, add_2]) 34 | 35 | 36 | def test_create_getter_setter(): 37 | """Test correct creation of getter and setter methods for given name. 38 | 39 | """ 40 | 41 | result = StageTransformerConverter._create_getter_setter("Dummy") 42 | 43 | assert "getDummy" in result 44 | assert "setDummy" in result 45 | assert all([callable(x) for x in result.values()]) 46 | 47 | class MockUp: 48 | def __init__(self): 49 | self.Dummy = "Test" 50 | 51 | def _set(self, **kwargs): 52 | return kwargs 53 | 54 | def getOrDefault(self, value): 55 | return value 56 | 57 | mock = MockUp() 58 | assert result["getDummy"](mock) == "Test" 59 | assert result["setDummy"](mock, 1) == {"Dummy": 1} 60 | 61 | 62 | def test_create_param_dict(): 63 | """Test correct creation of `Param` values and setter/getter methods. 
64 | 65 | """ 66 | 67 | converter = StageTransformerConverter(lambda x: None) 68 | param_keys = {"Dummy": 1}.keys() 69 | result = converter._create_param_dict(param_keys) 70 | 71 | members_callable = ["setParams", "getParams", "setDummy", "getDummy"] 72 | 73 | assert all([x in result for x in members_callable]) 74 | assert all([callable(result[x]) for x in members_callable]) 75 | assert "Dummy" in result 76 | assert isinstance(result["Dummy"], Param) 77 | 78 | class ParamMock: 79 | def __init__(self, name): 80 | self.name = name 81 | 82 | class MockUp: 83 | def _set(self, **kwargs): 84 | return kwargs 85 | 86 | def extractParamMap(self): 87 | return {ParamMock("key"): "value"} 88 | 89 | mock = MockUp() 90 | assert result["setParams"](mock, a=2) == {"a": 2} 91 | assert result["getParams"](mock) == {"key": "value"} 92 | 93 | 94 | def test_instantiate_transformer(): 95 | """Test correct instantiation of `Transformer` subclass instance. 96 | 97 | """ 98 | 99 | converter = StageTransformerConverter(lambda x: None) 100 | 101 | params = {"Dummy": 2} 102 | dicts = converter._create_param_dict(params.keys()) 103 | dicts.update({"attribute": "value"}) 104 | 105 | instance = converter._instantiate_transformer("Name", dicts, params) 106 | 107 | assert instance.__class__.__name__ == "Name" 108 | assert issubclass(instance.__class__, Transformer) 109 | assert instance.getDummy() == 2 110 | 111 | 112 | def test_wrangler_to_spark_transformer(): 113 | """Test correct pyspark wrangler to `Transformer` conversion. 
114 | 115 | """ 116 | 117 | class DummyWrangler(PySparkSingleNoFit): 118 | """Test Doc""" 119 | 120 | def __init__(self, a=5): 121 | self.a = a 122 | 123 | def transform(self, number): 124 | return number + self.a 125 | 126 | stage_wrangler = concretize_abstract_wrangler(DummyWrangler)() 127 | instance = StageTransformerConverter(stage_wrangler).convert() 128 | 129 | assert issubclass(instance.__class__, Transformer) 130 | assert instance.__class__.__name__ == "DummyWrangler" 131 | assert instance.__doc__ == "Test Doc" 132 | assert instance.transform(10) == 15 133 | 134 | assert instance.geta() == 5 135 | instance.seta(10) 136 | assert instance.geta() == 10 137 | assert instance.transform(10) == 20 138 | 139 | 140 | def test_func_to_spark_transformer(): 141 | """Test correct python function to `Transformer` conversion. 142 | 143 | """ 144 | 145 | def dummy(number, a=5): 146 | """Test Doc""" 147 | return number + a 148 | 149 | instance = StageTransformerConverter(dummy).convert() 150 | 151 | assert issubclass(instance.__class__, Transformer) 152 | assert instance.__class__.__name__ == "dummy" 153 | assert instance.__doc__ == "Test Doc" 154 | assert instance.transform(10) == 15 155 | 156 | assert instance.geta() == 5 157 | instance.seta(10) 158 | assert instance.geta() == 10 159 | assert instance.transform(10) == 20 160 | 161 | # test passing a transformer already 162 | assert instance is StageTransformerConverter(instance).convert() 163 | 164 | # test passing invalid type 165 | with pytest.raises(ValueError): 166 | StageTransformerConverter(["Wrong Type"]).convert() 167 | 168 | 169 | def test_pipeline_locator(spark, pipe): 170 | """Test index and label access for stages and dataframe representation. 
171 | 172 | """ 173 | 174 | df_input = spark.range(10).toDF("value") 175 | df_output = df_input.withColumn("add1", F.col("value") + 2) \ 176 | .withColumn("add2", F.col("value") + 4) 177 | 178 | # test non-existent transformer 179 | with pytest.raises(ValueError): 180 | pipe(Transformer()) 181 | 182 | # test missing transformation 183 | with pytest.raises(ValueError): 184 | pipe(0) 185 | 186 | test_result = pipe.transform(df_input) 187 | 188 | stage_add_1 = pipe.stages[0] 189 | transform_add_1 = pipe._transformer.transformations[0] 190 | 191 | assert stage_add_1 is pipe[0] 192 | assert stage_add_1 is pipe["add_1"] 193 | assert stage_add_1 is pipe[stage_add_1] 194 | 195 | # test incorrect type 196 | with pytest.raises(ValueError): 197 | pipe(tuple()) 198 | 199 | # test out of bounds error 200 | with pytest.raises(IndexError): 201 | pipe(20) 202 | 203 | # test ambiguous identifier 204 | with pytest.raises(ValueError): 205 | pipe("add") 206 | 207 | # test non-existent identifier 208 | with pytest.raises(ValueError): 209 | pipe("I do not exist") 210 | 211 | assert transform_add_1 is pipe(0) 212 | assert transform_add_1 is pipe("add_1") 213 | 214 | assert test_result is pipe(1) 215 | assert test_result is pipe("add_2") 216 | 217 | assert df_output.toPandas().equals(test_result.toPandas()) 218 | 219 | 220 | def test_pipeline_cacher(spark, pipe): 221 | """Test pipeline caching functionality.
222 | 223 | """ 224 | 225 | df_input = spark.range(10).toDF("value") 226 | 227 | # test empty cache 228 | assert pipe.cache.enabled == [] 229 | 230 | # test disable on empty cache 231 | with pytest.raises(ValueError): 232 | pipe.cache.disable("add_1") 233 | 234 | pipe.cache.enable("add_2") 235 | pipe.transform(df_input) 236 | 237 | assert pipe("add_1").is_cached is False 238 | assert pipe("add_2").is_cached is True 239 | assert pipe.cache.enabled == [pipe["add_2"]] 240 | 241 | pipe.cache.enable(["add_1"]) 242 | assert pipe("add_1").is_cached is True 243 | 244 | pipe.cache.disable("add_1") 245 | assert pipe("add_1").is_cached is False 246 | assert pipe("add_2").is_cached is True 247 | 248 | pipe.cache.clear() 249 | assert pipe.cache.enabled == [] 250 | assert pipe("add_1").is_cached is False 251 | assert pipe("add_2").is_cached is False 252 | 253 | 254 | def test_pipeline_transformer(spark, pipe): 255 | """Test correct pipeline transformation. 256 | 257 | """ 258 | 259 | df_input = spark.range(10).toDF("value") 260 | 261 | assert bool(pipe._transformer) is False 262 | pipe.transform(df_input) 263 | assert bool(pipe._transformer) is True 264 | assert pipe._transformer.input_df is df_input 265 | 266 | assert [x for x in pipe._transformer] == pipe._transformer.transformations 267 | 268 | 269 | def test_pipeline_profiler(spark): 270 | """Test pipeline profiler. 
271 | 272 | """ 273 | 274 | df_input = spark.range(10).toDF("value") 275 | 276 | def add_order(df): 277 | return df.withColumn("order", F.col("value") + 5) 278 | 279 | def add_groupby(df): 280 | return df.withColumn("groupby", F.col("value") + 10) 281 | 282 | def sort(df): 283 | return df.orderBy("order") 284 | 285 | def groupby(df): 286 | return df.groupBy("groupby").agg(F.max("value")) 287 | 288 | pipe = pipeline.Pipeline(stages=[add_order, add_groupby, sort, groupby]) 289 | 290 | # test missing df 291 | with pytest.raises(ValueError): 292 | pipe.profile() 293 | 294 | # test non pipeline df before transform 295 | df_profiles = pipe.profile(df_input) 296 | 297 | assert df_profiles.loc[0, "name"] == "Input dataframe" 298 | assert df_profiles.loc[0, "rows"] == 10 299 | assert df_profiles.loc[0, "idx"] == "None" 300 | assert df_profiles.loc[1, "name"] == "add_order" 301 | assert df_profiles.loc[4, "stage_count"] == 3 302 | assert df_profiles.loc[4, "cols"] == 2 303 | assert df_profiles.loc[4, "cached"] == False # noqa E712 304 | 305 | # test pipeline profile after transform 306 | pipe.transform(df_input) 307 | 308 | df_profiles = pipe.profile() 309 | 310 | assert df_profiles.loc[0, "name"] == "Input dataframe" 311 | assert df_profiles.loc[0, "rows"] == 10 312 | assert df_profiles.loc[0, "idx"] == "None" 313 | assert df_profiles.loc[1, "name"] == "add_order" 314 | assert df_profiles.loc[4, "stage_count"] == 3 315 | assert df_profiles.loc[4, "cols"] == 2 316 | assert df_profiles.loc[4, "cached"] == False # noqa E712 317 | 318 | # add caching and test 319 | pipe.cache.enable(2) 320 | pipe.transform(df_input) 321 | 322 | df_profiles = pipe.profile() 323 | 324 | assert df_profiles.loc[3, "cached"] == True # noqa E712 325 | assert df_profiles.loc[4, "stage_count"] == 4 326 | 327 | 328 | def test_pipeline_describer(spark): 329 | """Test pipeline describer. 
330 | 331 | """ 332 | 333 | df_input = spark.range(10).toDF("value") 334 | 335 | def add_order(df): 336 | return df.withColumn("order", F.col("value") + 5) 337 | 338 | def add_groupby(df): 339 | return df.withColumn("groupby", F.col("value") + 10) 340 | 341 | def sort(df): 342 | return df.orderBy("order") 343 | 344 | def groupby(df): 345 | return df.groupBy("groupby").agg(F.max("value")) 346 | 347 | pipe = pipeline.Pipeline(stages=[add_order, add_groupby, sort, groupby]) 348 | 349 | # test missing df 350 | with pytest.raises(ValueError): 351 | pipe.profile() 352 | 353 | # test non pipeline df before transform 354 | df_descriptions = pipe.describe(df_input) 355 | 356 | assert df_descriptions.loc[0, "name"] == "Input dataframe" 357 | assert df_descriptions.loc[0, "idx"] == "None" 358 | assert df_descriptions.loc[1, "uid"] == pipe[0].uid 359 | assert df_descriptions.loc[1, "name"] == "add_order" 360 | assert df_descriptions.loc[1, "stage_count"] == 1 361 | assert df_descriptions.loc[4, "cols"] == 2 362 | assert df_descriptions.loc[4, "cached"] == False # noqa E712 363 | 364 | 365 | def test_full_pipeline(spark): 366 | """Create two stages from PySparkWrangler and native function and check 367 | against correct end result of pipeline. 
368 | 369 | """ 370 | 371 | df_input = spark.range(10).toDF("value") 372 | df_output = df_input.withColumn("add1", F.col("value") + 1) \ 373 | .withColumn("add2", F.col("value") + 2) 374 | 375 | class DummyWrangler(PySparkSingleNoFit): 376 | def __init__(self, a=5): 377 | self.a = a 378 | 379 | def transform(self, df): 380 | return df.withColumn("add1", F.col("value") + 1) 381 | 382 | stage_wrangler = concretize_abstract_wrangler(DummyWrangler)() 383 | 384 | def stage_func(df, a=2): 385 | return df.withColumn("add2", F.col("value") + 2) 386 | 387 | pipe = pipeline.Pipeline([stage_wrangler, stage_func]) 388 | test_result = pipe.transform(df_input) 389 | 390 | assert df_output.toPandas().equals(test_result.toPandas()) 391 | -------------------------------------------------------------------------------- /tests/pyspark/test_testing.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for pyspark testing utility. 2 | 3 | isort:skip_file 4 | """ 5 | 6 | import pytest 7 | import pandas as pd 8 | 9 | pytestmark = pytest.mark.pyspark # noqa: E402 10 | pyspark = pytest.importorskip("pyspark") # noqa: E402 11 | 12 | from pywrangler.pyspark.testing import assert_pyspark_pandas_equality 13 | 14 | 15 | def test_assert_spark_pandas_equality_no_assert(spark): 16 | data = [[1, 2, 3, 4], 17 | [5, 6, 7, 8]] 18 | columns = list("abcd") 19 | index = [0, 1] 20 | test_data = pd.DataFrame(data=data, columns=columns, index=index) 21 | 22 | test_input = spark.createDataFrame(test_data) 23 | test_output = test_data.reindex([1, 0]) 24 | test_output = test_output[["b", "c", "a", "d"]] 25 | 26 | assert_pyspark_pandas_equality(test_input, test_output) 27 | assert_pyspark_pandas_equality(test_input, test_output, orderby=["a"]) 28 | 29 | 30 | def test_assert_spark_pandas_equality_assert(spark): 31 | data = [[1, 2, 3, 4], 32 | [5, 6, 7, 8]] 33 | columns = list("abcd") 34 | index = [0, 1] 35 | test_data = pd.DataFrame(data=data, 
columns=columns, index=index) 36 | 37 | test_input = spark.createDataFrame(test_data) 38 | test_output = test_data.copy(deep=True) 39 | test_output.iloc[0, 0] = 100 40 | 41 | with pytest.raises(AssertionError): 42 | assert_pyspark_pandas_equality(test_input, test_output) 43 | 44 | with pytest.raises(AssertionError): 45 | assert_pyspark_pandas_equality(test_input, test_output, orderby=["a"]) 46 | -------------------------------------------------------------------------------- /tests/pyspark/test_util.py: -------------------------------------------------------------------------------- 1 | """This module contains pyspark wrangler utility tests. 2 | 3 | isort:skip_file 4 | """ 5 | 6 | import pytest 7 | import pandas as pd 8 | from pywrangler.pyspark.util import ColumnCacher 9 | from pyspark.sql import functions as F 10 | 11 | pytestmark = pytest.mark.pyspark # noqa: E402 12 | pyspark = pytest.importorskip("pyspark") # noqa: E402 13 | 14 | from pywrangler.pyspark import util 15 | 16 | 17 | def test_ensure_column(spark): 18 | assert str(F.col("a")) == str(util.ensure_column("a")) 19 | assert str(F.col("a")) == str(util.ensure_column(F.col("a"))) 20 | 21 | 22 | def test_spark_wrangler_validate_columns_raises(spark): 23 | 24 | data = {"col1": [1, 2], "col2": [3, 4]} 25 | df = spark.createDataFrame(pd.DataFrame(data)) 26 | 27 | with pytest.raises(ValueError): 28 | util.validate_columns(df, ("col3", "col1")) 29 | 30 | 31 | def test_spark_wrangler_validate_columns_not_raises(spark): 32 | 33 | data = {"col1": [1, 2], "col2": [3, 4]} 34 | df = spark.createDataFrame(pd.DataFrame(data)) 35 | 36 | util.validate_columns(df, ("col1", "col2")) 37 | util.validate_columns(df, None) 38 | 39 | 40 | def test_prepare_orderby(spark): 41 | 42 | columns = ["a", "b"] 43 | 44 | # test empty input 45 | assert util.prepare_orderby(None) == [] 46 | 47 | # test broadcast 48 | result = [F.col("a").asc(), F.col("b").asc()] 49 | assert str(result) == str(util.prepare_orderby(columns, True)) 50 | 51 | 
# test exact 52 | result = [F.col("a").asc(), F.col("b").desc()] 53 | assert str(result) == str(util.prepare_orderby(columns, [True, False])) 54 | 55 | # test reverse 56 | result = [F.col("a").asc(), F.col("b").desc()] 57 | assert str(result) == str(util.prepare_orderby(columns, [False, True], 58 | reverse=True)) 59 | 60 | # raise unequal lengths 61 | with pytest.raises(ValueError): 62 | util.prepare_orderby(columns, [True, False, True]) 63 | 64 | 65 | def test_column_cacher(spark): 66 | 67 | data = {"col1": [1, 2], "col2": [3, 4]} 68 | df = spark.createDataFrame(pd.DataFrame(data)) 69 | 70 | # check invalid mode argument 71 | with pytest.raises(ValueError): 72 | ColumnCacher(df, mode="incorrect argument") 73 | 74 | # check added column for mode = True 75 | cc = ColumnCacher(df, mode=True) 76 | cc.add("col3", F.lit(None)) 77 | assert cc.columns["col3"] in cc.df.columns 78 | 79 | # check added column for mode = debug 80 | cc = ColumnCacher(df, mode="debug") 81 | cc.add("col3", F.lit(None)) 82 | assert cc.columns["col3"] in cc.df.columns 83 | 84 | # check missing column for mode = False 85 | cc = ColumnCacher(df, mode=False) 86 | cc.add("col3", F.lit(None)) 87 | assert "col3" not in cc.columns 88 | 89 | # check added column for force = True and mode = False 90 | cc = ColumnCacher(df, mode=False) 91 | cc.add("col3", F.lit(None), force=True) 92 | assert cc.columns["col3"] in cc.df.columns 93 | 94 | # check removed columns after finish with mode = True/False 95 | cc = ColumnCacher(df, mode=False) 96 | cc.add("col3", F.lit(None)) 97 | df_result = cc.finish("col4", F.lit(None)) 98 | assert "col3" not in cc.columns 99 | assert "col4" in df_result.columns 100 | 101 | # check remaining column after finish with mode debug 102 | cc = ColumnCacher(df, mode="debug") 103 | cc.add("col3", F.lit(None)) 104 | df_result = cc.finish("col4", F.lit(None)) 105 | assert cc.columns["col3"] in df_result.columns 106 | assert "col4" in df_result.columns 107 | 108 | 
-------------------------------------------------------------------------------- /tests/pyspark/wranglers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/pyspark/wranglers/__init__.py -------------------------------------------------------------------------------- /tests/pyspark/wranglers/test_interval_identifier.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for pyspark interval identifier. 2 | isort:skip_file 3 | """ 4 | 5 | import pandas as pd 6 | import pytest 7 | from pywrangler.util.testing import PlainFrame 8 | 9 | pytestmark = pytest.mark.pyspark # noqa: E402 10 | pyspark = pytest.importorskip("pyspark") # noqa: E402 11 | 12 | from tests.test_data.interval_identifier import ( 13 | CollectionGeneral, 14 | CollectionIdenticalStartEnd, 15 | CollectionMarkerSpecifics, 16 | CollectionNoOrderGroupBy, 17 | MultipleIntervalsSpanningGroupbyExtendedTriple, 18 | ResultTypeRawIids, 19 | ResultTypeValidIids 20 | ) 21 | 22 | from pywrangler.pyspark.wranglers.interval_identifier import ( 23 | VectorizedCumSum, 24 | VectorizedCumSumAdjusted 25 | ) 26 | 27 | WRANGLER = (VectorizedCumSum, VectorizedCumSumAdjusted) 28 | WRANGLER_IDS = [x.__name__ for x in WRANGLER] 29 | WRANGLER_KWARGS = dict(argnames='wrangler', 30 | argvalues=WRANGLER, 31 | ids=WRANGLER_IDS) 32 | 33 | 34 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 35 | @CollectionGeneral.pytest_parametrize_kwargs("marker_use") 36 | @CollectionGeneral.pytest_parametrize_testcases 37 | def test_base(testcase, wrangler, marker_use): 38 | """Tests against all available wranglers and test cases. 39 | 40 | Parameters 41 | ---------- 42 | testcase: DataTestCase 43 | Generates test data for given test case. 
44 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 45 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 46 | marker_use: dict 47 | Defines the marker start/end use. 48 | 49 | """ 50 | 51 | # instantiate test case 52 | testcase_instance = testcase("pyspark") 53 | 54 | # instantiate wrangler 55 | kwargs = testcase_instance.test_kwargs.copy() 56 | kwargs.update(marker_use) 57 | wrangler_instance = wrangler(**kwargs) 58 | 59 | # pass wrangler to test case 60 | testcase_instance.test(wrangler_instance.transform) 61 | 62 | 63 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 64 | @CollectionIdenticalStartEnd.pytest_parametrize_testcases 65 | def test_identical_start_end(testcase, wrangler): 66 | """Tests against all available wranglers and test cases. 67 | 68 | Parameters 69 | ---------- 70 | testcase: DataTestCase 71 | Generates test data for given test case. 72 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 73 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 74 | 75 | """ 76 | 77 | # instantiate test case 78 | testcase_instance = testcase("pyspark") 79 | 80 | # instantiate wrangler 81 | wrangler_instance = wrangler(**testcase_instance.test_kwargs) 82 | 83 | # pass wrangler to test case 84 | testcase_instance.test(wrangler_instance.transform) 85 | 86 | 87 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 88 | @CollectionMarkerSpecifics.pytest_parametrize_testcases 89 | def test_marker_specifics(testcase, wrangler): 90 | """Tests specific `marker_start_use_first` and `marker_end_use_first` 91 | scenarios. 92 | 93 | Parameters 94 | ---------- 95 | testcase: DataTestCase 96 | Generates test data for given test case. 97 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 98 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 
99 | 100 | """ 101 | 102 | # instantiate test case 103 | testcase_instance = testcase("pyspark") 104 | 105 | # instantiate wrangler 106 | wrangler_instance = wrangler(**testcase_instance.test_kwargs) 107 | 108 | # pass wrangler to test case 109 | testcase_instance.test(wrangler_instance.transform) 110 | 111 | 112 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 113 | def test_repartition(wrangler): 114 | """Tests that repartition has no effect. 115 | 116 | Parameters 117 | ---------- 118 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 119 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 120 | 121 | """ 122 | 123 | # instantiate test case 124 | testcase_instance = MultipleIntervalsSpanningGroupbyExtendedTriple() 125 | 126 | # instantiate wrangler 127 | wrangler_instance = wrangler(**testcase_instance.test_kwargs) 128 | 129 | # pass wrangler to test case 130 | testcase_instance.test.pyspark(wrangler_instance.transform, repartition=5) 131 | 132 | 133 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 134 | def test_result_type_raw_iids(wrangler): 135 | """Test for correct raw iids constraints. Returned result only needs to 136 | distinguish intervals regardless of their validity. Interval ids do not 137 | need to be in specific order. 138 | 139 | Parameters 140 | ---------- 141 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 142 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 
143 | 144 | """ 145 | 146 | testcase_instance = ResultTypeRawIids("pandas") 147 | wrangler_instance = wrangler(result_type="raw", 148 | **testcase_instance.test_kwargs) 149 | 150 | df_input = testcase_instance.input.to_pyspark() 151 | df_output = testcase_instance.output.to_pandas() 152 | df_result = wrangler_instance.transform(df_input) 153 | df_result = (PlainFrame.from_pyspark(df_result) 154 | .to_pandas() 155 | .sort_values(testcase_instance.orderby_columns) 156 | .reset_index(drop=True)) 157 | 158 | col = testcase_instance.target_column_name 159 | pd.testing.assert_series_equal(df_result[col].diff().ne(0), 160 | df_output[col].diff().ne(0)) 161 | 162 | 163 | @CollectionGeneral.pytest_parametrize_kwargs("marker_use") 164 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 165 | def test_result_type_valid_iids(wrangler, marker_use): 166 | """Test for correct valid iids constraints. Returned result needs to 167 | distinguish valid from invalid intervals. Invalid intervals need to be 0. 168 | 169 | Parameters 170 | ---------- 171 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 172 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 173 | marker_use: dict 174 | Contains `marker_start_use_first` and `marker_end_use_first` parameters 175 | as dict. 
176 | 177 | """ 178 | 179 | testcase_instance = ResultTypeValidIids("pyspark") 180 | kwargs = testcase_instance.test_kwargs.copy() 181 | kwargs.update(marker_use) 182 | wrangler_instance = wrangler(result_type="valid", **kwargs) 183 | 184 | df_input = testcase_instance.input.to_pyspark() 185 | df_output = testcase_instance.output.to_pandas() 186 | df_result = wrangler_instance.transform(df_input) 187 | df_result = (PlainFrame.from_pyspark(df_result) 188 | .to_pandas() 189 | .sort_values(testcase_instance.orderby_columns) 190 | .reset_index(drop=True)) 191 | 192 | col = testcase_instance.target_column_name 193 | pd.testing.assert_series_equal(df_result[col].diff().ne(0), 194 | df_output[col].diff().ne(0)) 195 | 196 | pd.testing.assert_series_equal(df_result[col].eq(0), 197 | df_output[col].eq(0)) 198 | 199 | 200 | @pytest.mark.parametrize(**WRANGLER_KWARGS) 201 | @CollectionNoOrderGroupBy.pytest_parametrize_testcases 202 | def test_no_order_groupby(testcase, wrangler): 203 | """Tests correct behaviour for missing groupby columns. 204 | 205 | Parameters 206 | ---------- 207 | testcase: DataTestCase 208 | Generates test data for given test case. 209 | wrangler: pywrangler.wrangler_instance.interfaces.IntervalIdentifier 210 | Refers to the actual wrangler_instance begin tested. See `WRANGLER`. 211 | 212 | """ 213 | 214 | # instantiate test case 215 | testcase_instance = testcase("pyspark") 216 | 217 | # instantiate wrangler 218 | kwargs = testcase_instance.test_kwargs.copy() 219 | kwargs.update({'groupby_columns': None}) 220 | wrangler_instance = wrangler(**kwargs) 221 | 222 | # pass wrangler to test case 223 | testcase_instance.test(wrangler_instance.transform) 224 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | """This module contains the BaseWrangler tests. 
2 | 3 | """ 4 | 5 | import pytest 6 | 7 | from pywrangler import base 8 | from pywrangler.util.testing.util import concretize_abstract_wrangler 9 | 10 | 11 | @pytest.fixture(scope="session") 12 | def dummy_wrangler(): 13 | """Create DummyWrangler for testing BaseWrangler. 14 | 15 | """ 16 | 17 | class DummyWrangler(base.BaseWrangler): 18 | def __init__(self, arg1, kwarg1): 19 | self.arg1 = arg1 20 | self.kwarg1 = kwarg1 21 | 22 | @property 23 | def preserves_sample_size(self): 24 | return True 25 | 26 | @property 27 | def computation_engine(self): 28 | return "DummyEngine" 29 | 30 | return concretize_abstract_wrangler(DummyWrangler)("arg_val", "kwarg_val") 31 | 32 | 33 | def test_base_wrangler_not_implemented(): 34 | with pytest.raises(TypeError): 35 | base.BaseWrangler() 36 | 37 | empty_wrangler = concretize_abstract_wrangler(base.BaseWrangler)() 38 | 39 | test_attributes = ("preserves_sample_size", "computation_engine") 40 | test_methods = ("fit", "transform", "fit_transform") 41 | 42 | for test_attribute in test_attributes: 43 | with pytest.raises(NotImplementedError): 44 | getattr(empty_wrangler, test_attribute) 45 | 46 | for test_method in test_methods: 47 | with pytest.raises(NotImplementedError): 48 | getattr(empty_wrangler, test_method)() 49 | 50 | 51 | def test_base_wrangler_get_params(dummy_wrangler): 52 | test_output = {"arg1": "arg_val", "kwarg1": "kwarg_val"} 53 | 54 | assert dummy_wrangler.get_params() == test_output 55 | 56 | 57 | def test_base_wrangler_get_params_subclassed(dummy_wrangler): 58 | class SubClass(dummy_wrangler.__class__): 59 | def __init__(self, *args, new_param=2, **kwargs): 60 | super().__init__(*args, **kwargs) 61 | self.new_param = new_param 62 | 63 | test_output = {"arg1": "arg_val", "kwarg1": "kwarg_val", "new_param": 2} 64 | 65 | assert SubClass("arg_val", "kwarg_val").get_params() == test_output 66 | 67 | 68 | def test_base_wrangler_properties(dummy_wrangler): 69 | assert dummy_wrangler.preserves_sample_size is True 70 | 
assert dummy_wrangler.computation_engine == "DummyEngine" 71 | 72 | 73 | def test_base_wrangler_set_params(dummy_wrangler): 74 | dummy_wrangler.set_params(arg1="new_value") 75 | 76 | assert dummy_wrangler.arg1 == "new_value" 77 | assert dummy_wrangler.kwarg1 == "kwarg_val" 78 | 79 | 80 | def test_base_wrangler_set_params_exception(dummy_wrangler): 81 | with pytest.raises(ValueError): 82 | dummy_wrangler.set_params(not_exist=0) 83 | -------------------------------------------------------------------------------- /tests/test_benchmark.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for the benchmark utilities. 2 | 3 | """ 4 | 5 | import sys 6 | import time 7 | 8 | import pytest 9 | 10 | from pywrangler.benchmark import ( 11 | BaseProfiler, 12 | MemoryProfiler, 13 | TimeProfiler, 14 | allocate_memory 15 | ) 16 | from pywrangler.exceptions import NotProfiledError 17 | 18 | MIB = 2 ** 20 19 | 20 | 21 | @pytest.fixture() 22 | def func_no_effect(): 23 | def func(): 24 | pass 25 | 26 | return func 27 | 28 | 29 | def test_allocate_memory_empty(): 30 | memory_holder = allocate_memory(0) 31 | 32 | assert memory_holder is None 33 | 34 | 35 | def test_allocate_memory_5mb(): 36 | memory_holder = allocate_memory(5) 37 | 38 | assert sys.getsizeof(memory_holder) == 5 * (2 ** 20) 39 | 40 | 41 | def test_base_profiler_not_implemented(): 42 | base_profiler = BaseProfiler() 43 | 44 | for will_raise in ('profile', 'profile_report', 'less_is_better'): 45 | with pytest.raises(NotImplementedError): 46 | getattr(base_profiler, will_raise)() 47 | 48 | 49 | def test_base_profiler_check_is_profiled(): 50 | base_profiler = BaseProfiler() 51 | base_profiler._not_set = None 52 | base_profiler._is_set = "value" 53 | 54 | with pytest.raises(NotProfiledError): 55 | base_profiler._check_is_profiled(['_not_set']) 56 | 57 | base_profiler._check_is_profiled(['_is_set']) 58 | 59 | 60 | def 
test_base_profiler_measurements_less_is_better(capfd): 61 | measurements = range(7) 62 | 63 | class Profiler(BaseProfiler): 64 | 65 | @property 66 | def less_is_better(self): 67 | return True 68 | 69 | def profile(self, *args, **kwargs): 70 | self._measurements = measurements 71 | return self 72 | 73 | def _pretty_formatter(self, value): 74 | return "{:.0f}".format(value) 75 | 76 | base_profiler = Profiler() 77 | base_profiler.profile_report() 78 | 79 | assert base_profiler.median == 3 80 | assert base_profiler.best == 0 81 | assert base_profiler.worst == 6 82 | assert base_profiler.std == 2 83 | assert base_profiler.runs == 7 84 | assert base_profiler.measurements == measurements 85 | 86 | out, _ = capfd.readouterr() 87 | assert out == "0 < 3 < 6 ± 2 (7 runs)\n" 88 | 89 | 90 | def test_base_profiler_measurements_more_is_better(capfd): 91 | measurements = range(7) 92 | 93 | class Profiler(BaseProfiler): 94 | @property 95 | def less_is_better(self): 96 | return False 97 | 98 | def profile(self, *args, **kwargs): 99 | self._measurements = measurements 100 | return self 101 | 102 | def _pretty_formatter(self, value): 103 | return "{:.0f}".format(value) 104 | 105 | base_profiler = Profiler() 106 | base_profiler.profile_report() 107 | 108 | assert base_profiler.median == 3 109 | assert base_profiler.best == 6 110 | assert base_profiler.worst == 0 111 | assert base_profiler.std == 2 112 | assert base_profiler.runs == 7 113 | assert base_profiler.measurements == measurements 114 | 115 | out, _ = capfd.readouterr() 116 | assert out == "6 > 3 > 0 ± 2 (7 runs)\n" 117 | 118 | 119 | def test_memory_profiler_mb_to_bytes(): 120 | assert MemoryProfiler._mb_to_bytes(1) == 1048576 121 | assert MemoryProfiler._mb_to_bytes(1.5) == 1572864 122 | assert MemoryProfiler._mb_to_bytes(0.33) == 346030 123 | 124 | 125 | def test_memory_profiler_return_self(func_no_effect): 126 | memory_profiler = MemoryProfiler(func_no_effect) 127 | assert memory_profiler.profile() is memory_profiler 128 | 
129 | 130 | def test_memory_profiler_measurements(func_no_effect): 131 | baselines = [0, 1, 2, 3] 132 | max_usages = [4, 5, 7, 8] 133 | measurements = [4, 4, 5, 5] 134 | 135 | memory_profiler = MemoryProfiler(func_no_effect) 136 | memory_profiler._baselines = baselines 137 | memory_profiler._max_usages = max_usages 138 | memory_profiler._measurements = measurements 139 | 140 | assert memory_profiler.less_is_better is True 141 | assert memory_profiler.max_usages == max_usages 142 | assert memory_profiler.baselines == baselines 143 | assert memory_profiler.measurements == measurements 144 | assert memory_profiler.median == 4.5 145 | assert memory_profiler.std == 0.5 146 | assert memory_profiler.best == 4 147 | assert memory_profiler.worst == 5 148 | assert memory_profiler.baseline_change == 1 149 | assert memory_profiler.runs == 4 150 | 151 | 152 | def test_memory_profiler_no_side_effect(func_no_effect): 153 | baseline_change = MemoryProfiler(func_no_effect).profile().baseline_change 154 | 155 | assert baseline_change < 0.5 * MIB 156 | 157 | 158 | def test_memory_profiler_side_effect(): 159 | side_effect_container = [] 160 | 161 | def side_effect(): 162 | memory_holder = allocate_memory(5) 163 | side_effect_container.append(memory_holder) 164 | 165 | return memory_holder 166 | 167 | assert MemoryProfiler(side_effect).profile().baseline_change > 4.9 * MIB 168 | 169 | 170 | def test_memory_profiler_no_increase(func_no_effect): 171 | memory_profiler = MemoryProfiler(func_no_effect).profile() 172 | print(memory_profiler.measurements) 173 | 174 | assert memory_profiler.median < MIB 175 | 176 | 177 | @pytest.mark.xfail(reason="Succeeds locally but sometimes fails remotely due " 178 | "to non deterministic memory management.") 179 | def test_memory_profiler_increase(): 180 | def increase(): 181 | memory_holder = allocate_memory(30) 182 | return memory_holder 183 | 184 | assert MemoryProfiler(increase).profile().median > 29 * MIB 185 | 186 | 187 | def 
test_time_profiler_return_self(func_no_effect): 188 | time_profiler = TimeProfiler(func_no_effect, 1) 189 | assert time_profiler.profile() is time_profiler 190 | 191 | 192 | def test_time_profiler_measurements(func_no_effect): 193 | measurements = [1, 1, 3, 3] 194 | 195 | time_profiler = TimeProfiler(func_no_effect) 196 | time_profiler._measurements = measurements 197 | 198 | assert time_profiler.less_is_better is True 199 | assert time_profiler.median == 2 200 | assert time_profiler.std == 1 201 | assert time_profiler.best == 1 202 | assert time_profiler.runs == 4 203 | assert time_profiler.measurements == measurements 204 | 205 | 206 | def test_time_profiler_repetitions(func_no_effect): 207 | time_profiler = TimeProfiler(func_no_effect, repetitions=10) 208 | assert time_profiler.repetitions == 10 209 | 210 | 211 | def test_time_profiler_best(): 212 | sleep = 0.0001 213 | 214 | def dummy(): 215 | time.sleep(sleep) 216 | pass 217 | 218 | time_profiler = TimeProfiler(dummy, repetitions=1).profile() 219 | 220 | assert time_profiler.best >= sleep 221 | -------------------------------------------------------------------------------- /tests/test_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/test_data/__init__.py -------------------------------------------------------------------------------- /tests/test_wranglers.py: -------------------------------------------------------------------------------- 1 | """Test wrangler interfaces. 
2 | 3 | """ 4 | 5 | import pytest 6 | 7 | from pywrangler import wranglers 8 | from pywrangler.util.testing.util import concretize_abstract_wrangler 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def ii_kwargs(): 13 | return {"marker_column": "marker_col", 14 | "marker_start": "start", 15 | "marker_end": "end", 16 | "marker_start_use_first": False, 17 | "marker_end_use_first": True, 18 | "orderby_columns": ["col1", "col2"], 19 | "groupby_columns": ["col3", "col4"], 20 | "ascending": [True, False], 21 | "target_column_name": "abc", 22 | "result_type": "raw"} 23 | 24 | 25 | @pytest.fixture() 26 | def interval_identifier(): 27 | return concretize_abstract_wrangler(wranglers.IntervalIdentifier) 28 | 29 | 30 | def test_base_interval_identifier_init(ii_kwargs, interval_identifier): 31 | wrangler = interval_identifier 32 | bii = wrangler(**ii_kwargs) 33 | 34 | assert bii.get_params() == ii_kwargs 35 | 36 | 37 | def test_base_interval_identifier_forced_ascending(ii_kwargs, 38 | interval_identifier): 39 | forced_ascending = ii_kwargs.copy() 40 | forced_ascending["ascending"] = None 41 | 42 | wrangler = interval_identifier 43 | bii = wrangler(**forced_ascending) 44 | 45 | assert bii.ascending == [True, True] 46 | 47 | 48 | def test_base_interval_identifier_sort_length_exc(ii_kwargs, 49 | interval_identifier): 50 | incorrect_length = ii_kwargs.copy() 51 | incorrect_length["ascending"] = (True,) 52 | 53 | wrangler = interval_identifier 54 | 55 | with pytest.raises(ValueError): 56 | wrangler(**incorrect_length) 57 | 58 | 59 | def test_base_interval_identifier_sort_keyword_exc(ii_kwargs, 60 | interval_identifier): 61 | incorrect_keyword = ii_kwargs.copy() 62 | incorrect_keyword["ascending"] = ("wrong keyword", "wrong keyword too") 63 | 64 | wrangler = interval_identifier 65 | 66 | with pytest.raises(ValueError): 67 | wrangler(**incorrect_keyword) 68 | 69 | 70 | def test_base_interval_identifier_identical_markers(ii_kwargs, 71 | interval_identifier): 72 | kwargs = 
ii_kwargs.copy() 73 | del kwargs["marker_end"] 74 | 75 | wrangler = interval_identifier(**kwargs) 76 | 77 | assert wrangler._identical_start_end_markers is True 78 | 79 | 80 | def test_base_interval_identifier_identical_start_end_markers(ii_kwargs, 81 | interval_identifier): 82 | kwargs = ii_kwargs.copy() 83 | kwargs["marker_end"] = kwargs["marker_start"] 84 | 85 | wrangler = interval_identifier(**kwargs) 86 | 87 | assert wrangler._identical_start_end_markers is True 88 | 89 | 90 | def test_base_interval_identifier_non_identical_markers(ii_kwargs, 91 | interval_identifier): 92 | wrangler = interval_identifier(**ii_kwargs) 93 | 94 | assert wrangler._identical_start_end_markers is False 95 | 96 | 97 | def test_base_interval_identifier_result_type(ii_kwargs, interval_identifier): 98 | kwargs = ii_kwargs.copy() 99 | kwargs["result_type"] = "does not exist" 100 | 101 | with pytest.raises(ValueError): 102 | interval_identifier(**kwargs) 103 | -------------------------------------------------------------------------------- /tests/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/util/__init__.py -------------------------------------------------------------------------------- /tests/util/test_dependencies.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for the dependencies module. 
2 | 3 | """ 4 | import pytest 5 | from pywrangler.util import dependencies 6 | 7 | 8 | def test_raise_if_missing(): 9 | # test non raising for available package 10 | dependencies.raise_if_missing("collections") 11 | 12 | # test raising for missing package 13 | with pytest.raises(ImportError): 14 | dependencies.raise_if_missing("not_existent_package_name123") 15 | 16 | 17 | def test_requires(): 18 | def func(value, a, b=1): 19 | return value + a + b 20 | 21 | # test non raising for available package 22 | decorated = dependencies.requires("collections")(func) 23 | assert decorated(1, 1, 2) == 4 24 | 25 | # test raising for missing package 26 | decorated = dependencies.requires("not_existent_package_name123")(func) 27 | with pytest.raises(ImportError): 28 | decorated(1, 1, 2) 29 | 30 | 31 | def test_is_available(): 32 | assert dependencies.is_available("collections") is True 33 | assert dependencies.is_available("not_existent_package_name123") is False 34 | -------------------------------------------------------------------------------- /tests/util/test_helper.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for the helper module. 2 | 3 | """ 4 | 5 | from pywrangler.util.helper import get_param_names 6 | 7 | 8 | def test_get_param_names(): 9 | 10 | def func(): 11 | pass 12 | 13 | assert get_param_names(func) == [] 14 | 15 | def func1(a, b=4, c=6): 16 | pass 17 | 18 | assert get_param_names(func1) == ["a", "b", "c"] 19 | assert get_param_names(func1, ["a"]) == ["b", "c"] 20 | -------------------------------------------------------------------------------- /tests/util/test_pprint.py: -------------------------------------------------------------------------------- 1 | """Test printing helpers. 
2 | 3 | """ 4 | 5 | from pywrangler.util import _pprint 6 | 7 | 8 | def test_join(): 9 | test_input = ["a", "b", "c"] 10 | test_output = "a\nb\nc" 11 | 12 | assert _pprint._join(test_input) == test_output 13 | 14 | 15 | def test_indent(): 16 | test_input = ["a", "b", "c"] 17 | test_output = [" a", " b", " c"] 18 | 19 | assert _pprint._indent(test_input, 3) == test_output 20 | 21 | 22 | def test_header(): 23 | test_input = "Header" 24 | test_output = 'Header\n------\n' 25 | 26 | assert _pprint.header(test_input) == test_output 27 | 28 | 29 | def test_header_with_indent(): 30 | test_input = "Header" 31 | test_output = ' Header\n ------\n' 32 | 33 | assert _pprint.header(test_input, indent=3) == test_output 34 | 35 | 36 | def test_header_with_underline(): 37 | test_input = "Header" 38 | test_output = 'Header\n======\n' 39 | 40 | assert _pprint.header(test_input, underline="=") == test_output 41 | 42 | 43 | def test_enumeration_dict_align_values_false(): 44 | test_input = {"a": 1, "bb": 2} 45 | test_output = '- a: 1\n- bb: 2' 46 | 47 | assert _pprint.enumeration(test_input, align_values=False) == test_output 48 | 49 | 50 | def test_enumeration_dict_align_values(): 51 | test_input = {"a": 1, "bb": 2} 52 | test_output = '- a: 1\n- bb: 2' 53 | 54 | assert _pprint.enumeration(test_input) == test_output 55 | 56 | 57 | def test_enumeration_dict_align_values_with_align_width(): 58 | test_input = {"a": 1, "bb": 2} 59 | test_output = '- a: 1\n- bb: 2' 60 | 61 | assert _pprint.enumeration(test_input, align_width=3) == test_output 62 | 63 | 64 | def test_enumeration_list(): 65 | test_input = ["note 1", "note 2"] 66 | test_output = '- note 1\n- note 2' 67 | 68 | assert _pprint.enumeration(test_input) == test_output 69 | 70 | 71 | def test_enumeration_list_with_indent(): 72 | test_input = ["note 1", "note 2"] 73 | test_output = ' - note 1\n - note 2' 74 | 75 | assert _pprint.enumeration(test_input, indent=4) == test_output 76 | 77 | 78 | def test_enumeration_list_with_bullet(): 79 
| test_input = ["note 1", "note 2"] 80 | test_output = 'o note 1\no note 2' 81 | 82 | assert _pprint.enumeration(test_input, bullet_char="o") == test_output 83 | 84 | 85 | def test_pretty_file_size(): 86 | pfs = _pprint.pretty_file_size 87 | 88 | assert pfs(1024, precision=1, width=4) == ' 1.0 KiB' 89 | assert pfs(1024, precision=1, width=4, align="<") == '1.0 KiB' 90 | assert pfs(1024, precision=1) == '1.0 KiB' 91 | assert pfs(1024 ** 2, precision=1, width=0) == '1.0 MiB' 92 | assert pfs(1024 ** 8, precision=2, width=0) == '1.00 YiB' 93 | 94 | 95 | def test_pretty_time_duration(): 96 | ptd = _pprint.pretty_time_duration 97 | 98 | assert ptd(1.1) == "1.1 s" 99 | assert ptd(1.59, width=5) == " 1.6 s" 100 | assert ptd(1.55, width=7, precision=2) == " 1.55 s" 101 | assert ptd(1.55, width=7, precision=2, align="<") == "1.55 s" 102 | assert ptd(120, precision=2) == "2.00 min" 103 | assert ptd(5400, precision=1) == "1.5 h" 104 | assert ptd(0.5, precision=1) == "500.0 ms" 105 | assert ptd(0.0005, precision=1) == "500.0 µs" 106 | assert ptd(0.0000005, precision=1) == "500.0 ns" 107 | assert ptd(0) == "0 s" 108 | assert ptd(-1.1) == "-1.1 s" 109 | 110 | 111 | def test_textwrap_docstring(): 112 | twds = _pprint.textwrap_docstring 113 | 114 | class NoDocStr: 115 | pass 116 | 117 | assert twds(NoDocStr) == [] 118 | 119 | class Mock: 120 | """ Dummy test doc string. """ 121 | pass 122 | 123 | assert twds(Mock) == ["Dummy test doc string."] 124 | assert twds(Mock, 10) == ["Dummy test", "doc", "string."] 125 | 126 | 127 | def test_truncate(): 128 | truncate = _pprint.truncate 129 | 130 | assert truncate("foo", 20) == "foo" 131 | assert truncate("foofoofoo", 6) == "foo..." 132 | assert truncate("foofoofoo", 4, "-") == "foo-" 133 | -------------------------------------------------------------------------------- /tests/util/test_sanitizer.py: -------------------------------------------------------------------------------- 1 | """Test sanitizer functions. 
2 | 3 | """ 4 | 5 | import pytest 6 | 7 | from pywrangler.util.sanitizer import ensure_iterable 8 | 9 | 10 | @pytest.mark.parametrize(argnames="seq_type", argvalues=(list, tuple)) 11 | def test_ensure_iterable_number(seq_type): 12 | test_input = 3 13 | test_output = seq_type([3]) 14 | 15 | assert ensure_iterable(test_input, seq_type) == test_output 16 | 17 | 18 | @pytest.mark.parametrize(argnames="seq_type", argvalues=(list, tuple)) 19 | def test_ensure_iterable_string(seq_type): 20 | test_input = "test_string" 21 | test_output = seq_type(["test_string"]) 22 | 23 | assert ensure_iterable(test_input, seq_type) == test_output 24 | 25 | 26 | @pytest.mark.parametrize(argnames="seq_type", argvalues=(list, tuple)) 27 | def test_ensure_iterable_strings(seq_type): 28 | test_input = ["test1", "test2"] 29 | test_output = seq_type(["test1", "test2"]) 30 | 31 | assert ensure_iterable(test_input, seq_type) == test_output 32 | 33 | 34 | @pytest.mark.parametrize(argnames="seq_type", argvalues=(list, tuple)) 35 | def test_ensure_iterable_custom_class(seq_type): 36 | class Dummy: 37 | pass 38 | 39 | dummy = Dummy() 40 | 41 | test_input = dummy 42 | test_output = seq_type([dummy]) 43 | 44 | assert ensure_iterable(test_input, seq_type) == test_output 45 | 46 | 47 | @pytest.mark.parametrize(argnames="seq_type", argvalues=(list, tuple)) 48 | def test_ensure_iterable_none(seq_type): 49 | 50 | assert ensure_iterable(None, seq_type) == seq_type() 51 | assert ensure_iterable(None, seq_type, retain_none=True) is None 52 | -------------------------------------------------------------------------------- /tests/util/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mansenfranzen/pywrangler/2faa62b4e3a223e85b298118ba2923439e42cd22/tests/util/testing/__init__.py -------------------------------------------------------------------------------- /tests/util/testing/test_datatestcase.py: 
-------------------------------------------------------------------------------- 1 | """This module contains tests for DataTestCase. 2 | 3 | """ 4 | import pytest 5 | from pywrangler.util.testing.datatestcase import DataTestCase, TestCollection 6 | 7 | 8 | @pytest.fixture 9 | def datatestcase(): 10 | class TestCase(DataTestCase): 11 | 12 | def input(self): 13 | return self.output["col1"] 14 | 15 | def output(self): 16 | return {"col1:i": [1, 2, 3], 17 | "col2:i": [2, 3, 4]} 18 | 19 | def mutants(self): 20 | return {("col1", 0): 10} 21 | 22 | return TestCase 23 | 24 | 25 | def test_engine_tester(datatestcase): 26 | def test_func(df): 27 | return df 28 | 29 | # assert missing engine specification 30 | with pytest.raises(ValueError): 31 | datatestcase().test(test_func) 32 | 33 | # assert invalid engine 34 | with pytest.raises(ValueError): 35 | datatestcase().test(test_func, engine="not_exists") 36 | 37 | with pytest.raises(ValueError): 38 | datatestcase("not_exists").test(test_func) 39 | 40 | 41 | def test_engine_tester_pandas(datatestcase): 42 | # test correct standard behaviour 43 | def test_func(df): 44 | df = df.copy() 45 | df["col2"] = df["col1"] + 1 46 | return df 47 | 48 | datatestcase("pandas").test(test_func) 49 | datatestcase().test(test_func, engine="pandas") 50 | datatestcase().test.pandas(test_func) 51 | 52 | # check merge input column 53 | def test_func(df): 54 | return df["col1"].add(1).to_frame("col2") 55 | 56 | datatestcase("pandas").test(test_func, merge_input=True) 57 | 58 | # pass kwargs with merge input 59 | def test_func(df, add, mul=0): 60 | return df["col1"].add(add).mul(mul).to_frame("col2") 61 | 62 | datatestcase("pandas").test(test_func, 63 | test_kwargs={"mul": 1, "add": 1}, 64 | merge_input=True) 65 | 66 | 67 | def test_engine_tester_pyspark(datatestcase): 68 | from pyspark.sql import functions as F 69 | 70 | def test_func(df): 71 | return df.withColumn("col2", F.col("col1") + 1) 72 | 73 | # test correct standard behaviour 74 | 
datatestcase("pyspark").test(test_func) 75 | 76 | # check repartition 77 | datatestcase("pyspark").test(test_func, repartition=2) 78 | 79 | # pass kwargs with repartition 80 | def test_func(df, add, mul=0): 81 | return df.withColumn("col2", (F.col("col1") + add) * mul) 82 | 83 | datatestcase("pyspark").test(test_func, 84 | test_kwargs={"add": 1, "mul": 1}, 85 | repartition=2) 86 | 87 | 88 | def test_engine_tester_surviving_mutant(): 89 | """Tests for a mutant that is not killed and hence should raise an error. 90 | In this example, the mutant equals the actual correct input. 91 | """ 92 | 93 | class TestCase(DataTestCase): 94 | def input(self): 95 | return self.output["col1"] 96 | 97 | def output(self): 98 | return {"col1:i": [1, 2, 3], 99 | "col2:i": [2, 3, 4]} 100 | 101 | def mutants(self): 102 | return {("col1", 0): 1} 103 | 104 | def test_func(df): 105 | df = df.copy() 106 | df["col2"] = df["col1"] + 1 107 | return df 108 | 109 | with pytest.raises(AssertionError): 110 | TestCase().test.pandas(test_func) 111 | 112 | 113 | def test_test_collection(datatestcase): 114 | collection = TestCollection([datatestcase]) 115 | 116 | # test init 117 | assert collection.testcases == [datatestcase] 118 | assert collection.names == ["TestCase"] 119 | 120 | # test with custom parameter name 121 | parametrize = pytest.mark.parametrize 122 | param = dict(argvalues=[datatestcase], ids=["TestCase"], argnames="a") 123 | assert collection.pytest_parametrize_testcases("a") == parametrize(**param) 124 | 125 | # test with default parameter name 126 | param["argnames"] = "testcase" 127 | 128 | def func(): 129 | pass 130 | 131 | assert (collection.pytest_parametrize_testcases(func) == 132 | parametrize(**param)(func)) 133 | 134 | # test test_kwargs 135 | kwargs = {"conf1": {"param1": 1, "param2": 2}} 136 | param = dict(argvalues=[1, 2], ids=["param1", "param2"], argnames="conf1") 137 | collection = TestCollection([datatestcase], test_kwargs=kwargs) 138 | assert 
(collection.pytest_parametrize_kwargs("conf1") == 139 | parametrize(**param)) 140 | 141 | with pytest.raises(ValueError): 142 | collection.pytest_parametrize_kwargs("notexists") 143 | -------------------------------------------------------------------------------- /tests/util/testing/test_mutants.py: -------------------------------------------------------------------------------- 1 | """This module contains data mutant and mutations tests. 2 | 3 | """ 4 | from datetime import datetime 5 | 6 | import pytest 7 | from pywrangler.util.testing import mutants, plainframe 8 | 9 | 10 | def test_mutation(): 11 | """Test correct mutant instantiation. 12 | 13 | """ 14 | 15 | mutation = mutants.Mutation("foo", 1, "bar") 16 | 17 | assert mutation.column == "foo" 18 | assert mutation.row == 1 19 | assert mutation.value == "bar" 20 | assert mutation.key == ("foo", 1) 21 | 22 | 23 | def test_base_mutant(): 24 | """Test BaseMutant functionality. 25 | 26 | """ 27 | 28 | # check duplicated mutations 29 | m1 = mutants.Mutation("foo", 1, "bar") 30 | m2 = mutants.Mutation("foo", 1, "far") 31 | m3 = mutants.Mutation("bar", 0, "foo") 32 | m4 = mutants.Mutation("foo", 0, "bar") 33 | 34 | with pytest.raises(ValueError): 35 | mutants.BaseMutant._check_duplicated_mutations([m1, m2]) 36 | 37 | mutants.BaseMutant._check_duplicated_mutations([m1, m3]) 38 | 39 | # check invalid mutations 40 | df = plainframe.PlainFrame.from_dict({"foo:str": ["bar"]}) 41 | 42 | with pytest.raises(ValueError): 43 | mutants.BaseMutant._check_valid_mutations([m1], df) # check row 44 | 45 | with pytest.raises(ValueError): 46 | mutants.BaseMutant._check_valid_mutations([m3], df) # check column 47 | 48 | mutants.BaseMutant._check_valid_mutations([m4], df) 49 | 50 | 51 | def test_base_mutant_from_dict(): 52 | # test dict single 53 | raw = {("col1", 0): 1} 54 | conv = mutants.ValueMutant(column="col1", row=0, value=1) 55 | assert mutants.BaseMutant.from_dict(raw) == conv 56 | 57 | # test dict multi 58 | raw = {("col1", 
0): 1, 59 | ("col2", 0): 0} 60 | 61 | conv = mutants.MutantCollection([ 62 | mutants.ValueMutant(column="col1", row=0, value=1), 63 | mutants.ValueMutant(column="col2", row=0, value=0) 64 | ]) 65 | 66 | assert mutants.BaseMutant.from_dict(raw) == conv 67 | 68 | # test wrong type 69 | with pytest.raises(ValueError): 70 | mutants.BaseMutant.from_dict(("col1", 2)) 71 | 72 | 73 | def test_base_mutant_from_multiple_any(): 74 | func = mutants.BaseMutant.from_multiple_any 75 | 76 | # test empty result 77 | assert func([]) == [] 78 | assert func(None) == [] 79 | 80 | # test Mutant 81 | raw_single = {("col1", 0): 1} 82 | conv_single = mutants.ValueMutant(column="col1", row=0, value=1) 83 | assert func(conv_single) == [conv_single] 84 | 85 | # test dict single single 86 | assert func(raw_single) == [conv_single] 87 | 88 | # test dict single multi 89 | raw_multi = {("col1", 0): 1, 90 | ("col2", 0): 0} 91 | 92 | conv_multi = mutants.MutantCollection([ 93 | mutants.ValueMutant(column="col1", row=0, value=1), 94 | mutants.ValueMutant(column="col2", row=0, value=0) 95 | ]) 96 | 97 | assert func(raw_multi) == [conv_multi] 98 | 99 | # test list mixed 100 | raw = [raw_multi, conv_single] 101 | assert func(raw) == [conv_multi, conv_single] 102 | 103 | # test incorrect type 104 | with pytest.raises(ValueError): 105 | func(tuple([1, 2, 3])) 106 | 107 | 108 | def test_value_mutant(): 109 | """Test ValueMutant functionality. 110 | 111 | """ 112 | 113 | df = plainframe.PlainFrame.from_dict({"foo:str": ["bar"]}) 114 | df_test = plainframe.PlainFrame.from_dict({"foo:str": ["foo"]}) 115 | mutant = mutants.ValueMutant("foo", 0, "foo") 116 | mutation = mutants.Mutation("foo", 0, "foo") 117 | 118 | assert mutant.generate_mutations(df) == [mutation] 119 | assert mutant.mutate(df) == df_test 120 | 121 | 122 | def test_function_mutant(): 123 | """Test FunctionMutant functionality. 
124 | 125 | """ 126 | 127 | df = plainframe.PlainFrame.from_dict({"foo:str": ["bar"]}) 128 | df_test = plainframe.PlainFrame.from_dict({"foo:str": ["foo"]}) 129 | mutation = mutants.Mutation("foo", 0, "foo") 130 | 131 | def custom_func(df): 132 | return [mutation] 133 | 134 | mutant = mutants.FunctionMutant(custom_func) 135 | 136 | assert mutant.generate_mutations(df) == [mutation] 137 | assert mutant.mutate(df) == df_test 138 | 139 | 140 | def test_random_mutant(): 141 | """Test RandomMutant functionality. 142 | 143 | """ 144 | 145 | df = plainframe.PlainFrame.from_dict({"foo:str": ["bar"], 146 | "bar:int": [1]}) 147 | 148 | # test invalid column 149 | mutant = mutants.RandomMutant(columns=["not_exists"]) 150 | with pytest.raises(ValueError): 151 | mutant.mutate(df) 152 | 153 | # test column 154 | mutant = mutants.RandomMutant(columns=["foo"]) 155 | df_mutated = mutant.mutate(df) 156 | assert df_mutated.get_column("foo").values[0] != "bar" 157 | assert df_mutated.get_column("bar").values[0] == 1 158 | 159 | # test invalid row 160 | mutant = mutants.RandomMutant(rows=[2]) 161 | with pytest.raises(ValueError): 162 | mutant.mutate(df) 163 | 164 | # test row 165 | df_rows = plainframe.PlainFrame.from_dict({"foo:str": ["bar", "foo"]}) 166 | mutant = mutants.RandomMutant(rows=[1]) 167 | df_mutated = mutant.mutate(df_rows) 168 | assert df_mutated.get_column("foo").values[0] == "bar" 169 | assert df_mutated.get_column("foo").values[0] != "foo" 170 | 171 | # test count 172 | mutant = mutants.RandomMutant(count=1) 173 | df_mutated = mutant.mutate(df) 174 | assert ((df_mutated.get_column("foo").values[0] != "bar") != 175 | (df_mutated.get_column("bar").values[0] != 1)) 176 | 177 | # test max count 178 | mutant = mutants.RandomMutant(count=100) 179 | df_mutated = mutant.mutate(df) 180 | assert df_mutated.get_column("foo").values[0] != "bar" 181 | assert df_mutated.get_column("bar").values[0] != 1 182 | 183 | # test random funcs for all types 184 | random = 
mutants.RandomMutant._random_value 185 | date = datetime(2019, 1, 1) 186 | assert isinstance(random("bool", True), bool) 187 | assert isinstance(random("int", 1), int) 188 | assert isinstance(random("float", 1.1), float) 189 | assert isinstance(random("str", "foo"), str) 190 | assert isinstance(random("datetime", date), datetime) 191 | 192 | assert random("bool", True) is not True 193 | assert random("int", 1) != 1 194 | assert random("float", 1.1) != 1.1 195 | assert random("str", "foo") != "foo" 196 | assert random("datetime", date) != date 197 | 198 | 199 | def test_collection_mutant(): 200 | """Test MutantCollection functionality. 201 | 202 | """ 203 | 204 | # test combination 205 | df = plainframe.PlainFrame.from_dict({"foo:str": ["foo", "foo"]}) 206 | value_mutant = mutants.ValueMutant("foo", 0, "bar") 207 | func = lambda _: [mutants.Mutation("foo", 1, "bar")] 208 | func_mutant = mutants.FunctionMutant(func) 209 | 210 | result = [mutants.Mutation("foo", 0, "bar"), 211 | mutants.Mutation("foo", 1, "bar")] 212 | 213 | df_result = plainframe.PlainFrame.from_dict({"foo:str": ["bar", "bar"]}) 214 | 215 | mutant_collection = mutants.MutantCollection([value_mutant, func_mutant]) 216 | assert mutant_collection.generate_mutations(df) == result 217 | assert mutant_collection.mutate(df) == df_result 218 | 219 | 220 | def test_mutant_assertions(): 221 | """Test invalid type changes due to mutations. 222 | 223 | """ 224 | 225 | df = plainframe.PlainFrame.from_dict({"foo:str": ["foo", "foo"]}) 226 | 227 | mutant = mutants.ValueMutant("foo", 1, 2) 228 | with pytest.raises(TypeError): 229 | mutant.mutate(df) 230 | -------------------------------------------------------------------------------- /tests/util/testing/test_util.py: -------------------------------------------------------------------------------- 1 | """This module contains testing util tests. 
2 | 3 | """ 4 | 5 | import pytest 6 | from pywrangler.base import BaseWrangler 7 | from pywrangler.util.testing.util import concretize_abstract_wrangler 8 | 9 | 10 | def test_concretize_abstract_wrangler(): 11 | class Dummy(BaseWrangler): 12 | """Doc""" 13 | 14 | @property 15 | def computation_engine(self) -> str: 16 | return "engine" 17 | 18 | concrete_class = concretize_abstract_wrangler(Dummy) 19 | instance = concrete_class() 20 | 21 | assert instance.computation_engine == "engine" 22 | assert instance.__doc__ == "Doc" 23 | assert instance.__class__.__name__ == "Dummy" 24 | 25 | with pytest.raises(NotImplementedError): 26 | instance.preserves_sample_size 27 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = 3 | {py35,py36,py37}-master 4 | {py35,py36,py37}-pandas{0190,0191,0192,0200,0201,0202,0203,0210,0211,0220,0230,0231,0232,0233,0234,0240,0241} 5 | {py35,py36,py37}-pyspark{231,240} 6 | {py35,py36,py37}-dask{115} 7 | flake8 8 | dev 9 | 10 | skip_missing_interpreters = True 11 | 12 | [testenv] 13 | commands = 14 | py.test {posargs} 15 | python travisci/fix_paths.py 16 | 17 | whitelist_externals = python 18 | 19 | extras = testing 20 | 21 | deps = 22 | master: pandas 23 | master: pyspark 24 | master: dask[dataframe] 25 | 26 | pandas0241: pandas==0.24.1 27 | pandas0240: pandas==0.24.0 28 | pandas0234: pandas==0.23.4 29 | pandas0233: pandas==0.23.3 30 | pandas0232: pandas==0.23.2 31 | pandas0231: pandas==0.23.1 32 | pandas0230: pandas==0.23.0 33 | pandas0220: pandas==0.22.0 34 | pandas0211: pandas==0.21.1 35 | pandas0210: pandas==0.21.0 36 | pandas0203: pandas==0.20.3 37 | pandas0202: pandas==0.20.2 38 | pandas0201: pandas==0.20.1 39 | pandas0200: pandas==0.20.0 40 | pandas0192: pandas==0.19.2 41 | pandas0191: pandas==0.19.1 42 | pandas0190: pandas==0.19.0 43 | 44 | pyspark240: pyspark==2.4.0 45 | pyspark231: 
pyspark==2.3.1 46 | 47 | dask115: dask[dataframe]==1.1.5 48 | 49 | codecov 50 | 51 | setenv = 52 | PYWRANGLER_TEST_ENV = {envname} 53 | 54 | passenv = * 55 | 56 | [testenv:dev] 57 | basepython = python3.5 58 | usedevelop = True 59 | 60 | [testenv:flake8] 61 | skip_install = true 62 | changedir = {toxinidir} 63 | deps = flake8 64 | commands = flake8 setup.py pywrangler tests 65 | -------------------------------------------------------------------------------- /travisci/code_coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Different pandas/pyspark/dask versions are tested separately to avoid 4 | # irrelevant tests to be run. For example, no spark tests need to be run 5 | # when pandas wranglers are tested on older pandas versions. 6 | 7 | # However, code coverage drops due to many skipped tests. Therefore, there is a 8 | # master version (marked via env variables) which includes all tests for 9 | # pandas/pyspark/dask for the newest available versions which is subject to 10 | # code coverage. Non master versions will not be included in code coverage. 11 | 12 | if [[ $ENV_STRING == *"master"* ]]; then 13 | codecov -e $TOXENV 14 | fi 15 | -------------------------------------------------------------------------------- /travisci/fix_paths.py: -------------------------------------------------------------------------------- 1 | """This module fixes paths of the coverage.xml manually since codecov is 2 | unable to correctly read the coverage report generated within tox environments. 3 | 4 | More specifically, it replaces the tox file paths with `src` file paths which 5 | is expected by codecov. 
6 | 7 | """ 8 | 9 | import re 10 | 11 | REGEX_FILENAME = r'(.*)\.tox.*site-packages(.*)' 12 | regex_filename = re.compile(REGEX_FILENAME) 13 | 14 | with open("coverage.xml", "r") as infile: 15 | print("Reading original coverage.xml report.") 16 | replaced = [regex_filename.sub("\g<1>src\g<2>", line) 17 | for line in infile] 18 | 19 | with open("coverage.xml", "w") as outfile: 20 | print("Overwriting coverage.xml. Replaced {} paths.".format(len(replaced))) 21 | outfile.writelines(replaced) 22 | -------------------------------------------------------------------------------- /travisci/java_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Spark requires Java 8 in order to work properly. However, TravisCI's Ubuntu 4 | # 16.04 ships with Java 11 and Java can't be set with `jdk` when python is 5 | # selected as language. Ubuntu 14.04 does not work due to missing python 3.7 6 | # support on TravisCI which does have Java 8 as default. 7 | 8 | if [[ $ENV_STRING == *"spark"* ]] || [[ $ENV_STRING == *"master"* ]]; then 9 | # show current JAVA_HOME and java version 10 | echo "Current JAVA_HOME: $JAVA_HOME" 11 | echo "Current java -version:" 12 | java -version 13 | 14 | # install Java 8 15 | sudo add-apt-repository -y ppa:openjdk-r/ppa 16 | sudo apt-get -qq update 17 | sudo apt-get install -y openjdk-8-jdk --no-install-recommends 18 | sudo update-java-alternatives -s java-1.8.0-openjdk-amd64 19 | 20 | # change JAVA_HOME to Java 8 21 | export JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 22 | fi 23 | -------------------------------------------------------------------------------- /travisci/tox_invocation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # To test individual pandas/pyspark/dask versions, only corresponding tests 4 | # need to be run (e.g test pandas 0.19.1, and only run pandas tests while 5 | # ignoring all pyspark and dask tests). 
To achieve this, the corresponding 6 | # pytest mark is passed to pytest. 7 | 8 | if [[ $ENV_STRING == *"pandas"* ]]; then 9 | MARKS="-- -m pandas" 10 | 11 | elif [[ $ENV_STRING == *"pyspark"* ]]; then 12 | MARKS="-- -m pyspark" 13 | 14 | elif [[ $ENV_STRING == *"dask"* ]]; then 15 | MARKS="-- -m dask" 16 | fi 17 | 18 | tox -e $(echo py$TRAVIS_PYTHON_VERSION-$ENV_STRING | tr -d .) $MARKS 19 | --------------------------------------------------------------------------------