├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── config.yml │ └── issue-template.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── release.yml └── workflows │ ├── build.yml │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── pixi.lock ├── pixi.toml ├── polarify ├── __init__.py └── main.py ├── pyproject.toml └── tests ├── __init__.py ├── functions.py ├── functions_310.py ├── test_error_handling.py └── test_parse_body.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @pavelzw @0xbe7a 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/issue-template.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: New issue 3 | about: Create a new issue 4 | --- 5 | 6 | 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: / 5 | schedule: 6 | interval: monthly 7 | groups: 8 | actions: 9 | patterns: 10 | - "*" 11 | -------------------------------------------------------------------------------- /.github/release.yml: -------------------------------------------------------------------------------- 1 | changelog: 2 | exclude: 3 | labels: 4 | - ignore for release 5 | categories: 6 | - title: ✨ New features 7 | labels: 8 | - 
enhancement 9 | - title: 🐛 Bug fixes 10 | labels: 11 | - bug 12 | - title: 📝 Documentation 13 | labels: 14 | - documentation 15 | - title: ⬆️ Dependencies 16 | labels: 17 | - dependencies 18 | - title: 🚀 CI 19 | labels: 20 | - ci 21 | - title: 🤷🏻 Other changes 22 | labels: 23 | - '*' 24 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | permissions: 7 | contents: write 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | outputs: 13 | version-changed: ${{ steps.version-metadata.outputs.changed }} 14 | new-version: ${{ steps.version-metadata.outputs.newVersion }} 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.9' 20 | - name: Install Python build 21 | run: pip install build 22 | - name: Build project 23 | run: python -m build 24 | - name: Upload package 25 | uses: actions/upload-artifact@v4 26 | with: 27 | name: artifact 28 | path: dist/* 29 | - uses: Quantco/ui-actions/version-metadata@v1 30 | id: version-metadata 31 | with: 32 | file: ./pyproject.toml 33 | token: ${{ secrets.GITHUB_TOKEN }} 34 | version-extraction-override: 'regex:version = "(.*)"' 35 | 36 | release: 37 | name: Publish package 38 | if: github.event_name == 'push' && github.repository == 'Quantco/polarify' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true' 39 | needs: [build] 40 | runs-on: ubuntu-latest 41 | permissions: 42 | id-token: write 43 | contents: write 44 | environment: pypi 45 | steps: 46 | - uses: actions/download-artifact@v4 47 | with: 48 | name: artifact 49 | path: dist 50 | - name: Publish package on PyPi 51 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc 52 | - uses: actions/checkout@v4 53 | - name: Push v${{ 
needs.build.outputs.new-version }} tag 54 | run: | 55 | git tag v${{ needs.build.outputs.new-version }} 56 | git push origin v${{ needs.build.outputs.new-version }} 57 | - name: Create release 58 | uses: softprops/action-gh-release@c95fe1489396fe8a9eb87c0abf8aa5b2ef267fda 59 | with: 60 | generate_release_notes: true 61 | tag_name: v${{ needs.build.outputs.new-version }} 62 | draft: true 63 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: [main] 5 | pull_request: 6 | 7 | concurrency: 8 | group: ${{ github.workflow }}-${{ github.ref }} 9 | cancel-in-progress: true 10 | 11 | jobs: 12 | tests: 13 | name: Unit tests 14 | timeout-minutes: 15 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | env: 20 | - pl014 21 | - pl1 22 | - py39 23 | - py310 24 | - py311 25 | - py312 26 | - py313 27 | steps: 28 | - uses: actions/checkout@v4 29 | - uses: prefix-dev/setup-pixi@92815284c57faa15cd896c4d5cfb2d59f32dc43d 30 | with: 31 | environments: ${{ matrix.env }} 32 | - name: Install repository 33 | run: | 34 | pixi run -e ${{ matrix.env }} postinstall 35 | - name: Run unittests 36 | uses: pavelzw/pytest-action@510c5e90c360a185039bea56ce8b3e7e51a16507 37 | with: 38 | custom-pytest: pixi run -e ${{ matrix.env }} coverage 39 | report-title: ${{ matrix.env }} 40 | - name: Upload coverage reports to Codecov 41 | if: matrix.env == 'py312' 42 | uses: codecov/codecov-action@0565863a31f2c772f9f0395002a31e3f06189574 43 | env: 44 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 45 | file: ./coverage.xml 46 | 47 | pre-commit-checks: 48 | # TODO: switch to pixi once there is a good way 49 | name: pre-commit 50 | timeout-minutes: 15 51 | runs-on: ubuntu-latest 52 | steps: 53 | - name: Checkout branch 54 | uses: actions/checkout@v4 55 | - name: Run pre-commit 56 | uses: 
quantco/pre-commit-conda@v1 57 | 58 | lint-workflow-files: 59 | name: Lint workflow files 60 | runs-on: ubuntu-latest 61 | steps: 62 | - name: Checkout branch 63 | uses: actions/checkout@v4 64 | # https://github.com/rhysd/actionlint/blob/main/docs/usage.md#use-actionlint-on-github-actions 65 | - name: Download actionlint 66 | id: get_actionlint 67 | run: bash <(curl https://raw.githubusercontent.com/rhysd/actionlint/main/scripts/download-actionlint.bash) 68 | - name: Check workflow files 69 | run: ${{ steps.get_actionlint.outputs.executable }} -color 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Explicitly unignore hidden files for ripgrep. Individual folders can still 2 | # be ignored in the blocks below. 3 | !.* 4 | 5 | # NOTE: The block below is NOT copied sic from https://www.toptal.com/developers/gitignore. 6 | # Instead, it contains some customizations that should be taken into account when adjusting the 7 | # block. 
8 | 9 | # Created by https://www.toptal.com/developers/gitignore/api/linux,macos,direnv,python,windows,pycharm+all,visualstudiocode,vim 10 | # Edit at https://www.toptal.com/developers/gitignore?templates=linux,macos,direnv,python,windows,pycharm+all,visualstudiocode,vim 11 | 12 | ### direnv ### 13 | .direnv 14 | .envrc 15 | 16 | ### Linux ### 17 | *~ 18 | 19 | # temporary files which can be created if a process still has a handle open of a deleted file 20 | .fuse_hidden* 21 | 22 | # KDE directory preferences 23 | .directory 24 | 25 | # Linux trash folder which might appear on any partition or disk 26 | .Trash-* 27 | 28 | # .nfs files are created when an open file is removed but is still being accessed 29 | .nfs* 30 | 31 | ### macOS ### 32 | # General 33 | .DS_Store 34 | .AppleDouble 35 | .LSOverride 36 | 37 | # Icon must end with two \r 38 | Icon 39 | 40 | 41 | # Thumbnails 42 | ._* 43 | 44 | # Files that might appear in the root of a volume 45 | .DocumentRevisions-V100 46 | .fseventsd 47 | .Spotlight-V100 48 | .TemporaryItems 49 | .Trashes 50 | .VolumeIcon.icns 51 | .com.apple.timemachine.donotpresent 52 | 53 | # Directories potentially created on remote AFP share 54 | .AppleDB 55 | .AppleDesktop 56 | Network Trash Folder 57 | Temporary Items 58 | .apdisk 59 | 60 | ### macOS Patch ### 61 | # iCloud generated files 62 | *.icloud 63 | 64 | ### PyCharm+all ### 65 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 66 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 67 | 68 | # User-specific stuff 69 | .idea/**/workspace.xml 70 | .idea/**/tasks.xml 71 | .idea/**/usage.statistics.xml 72 | .idea/**/dictionaries 73 | .idea/**/shelf 74 | 75 | # AWS User-specific 76 | .idea/**/aws.xml 77 | 78 | # Generated files 79 | .idea/**/contentModel.xml 80 | 81 | # Sensitive or high-churn files 82 | .idea/**/dataSources/ 83 | .idea/**/dataSources.ids 84 | .idea/**/dataSources.local.xml 
85 | .idea/**/sqlDataSources.xml 86 | .idea/**/dynamic.xml 87 | .idea/**/uiDesigner.xml 88 | .idea/**/dbnavigator.xml 89 | 90 | # Gradle 91 | .idea/**/gradle.xml 92 | .idea/**/libraries 93 | 94 | # Gradle and Maven with auto-import 95 | # When using Gradle or Maven with auto-import, you should exclude module files, 96 | # since they will be recreated, and may cause churn. Uncomment if using 97 | # auto-import. 98 | # .idea/artifacts 99 | # .idea/compiler.xml 100 | # .idea/jarRepositories.xml 101 | # .idea/modules.xml 102 | # .idea/*.iml 103 | # .idea/modules 104 | # *.iml 105 | # *.ipr 106 | 107 | # CMake 108 | cmake-build-*/ 109 | 110 | # Mongo Explorer plugin 111 | .idea/**/mongoSettings.xml 112 | 113 | # File-based project format 114 | *.iws 115 | 116 | # IntelliJ 117 | out/ 118 | 119 | # mpeltonen/sbt-idea plugin 120 | .idea_modules/ 121 | 122 | # JIRA plugin 123 | atlassian-ide-plugin.xml 124 | 125 | # Cursive Clojure plugin 126 | .idea/replstate.xml 127 | 128 | # SonarLint plugin 129 | .idea/sonarlint/ 130 | 131 | # Crashlytics plugin (for Android Studio and IntelliJ) 132 | com_crashlytics_export_strings.xml 133 | crashlytics.properties 134 | crashlytics-build.properties 135 | fabric.properties 136 | 137 | # Editor-based Rest Client 138 | .idea/httpRequests 139 | 140 | # Android studio 3.1+ serialized cache file 141 | .idea/caches/build_file_checksums.ser 142 | 143 | ### PyCharm+all Patch ### 144 | # Ignore everything but code style settings and run configurations 145 | # that are supposed to be shared within teams. 
146 | 147 | .idea/* 148 | 149 | !.idea/codeStyles 150 | !.idea/runConfigurations 151 | 152 | ### Python ### 153 | # Byte-compiled / optimized / DLL files 154 | __pycache__/ 155 | *.py[cod] 156 | *$py.class 157 | 158 | # C extensions 159 | *.so 160 | 161 | # Distribution / packaging 162 | .Python 163 | build/ 164 | develop-eggs/ 165 | dist/ 166 | downloads/ 167 | eggs/ 168 | .eggs/ 169 | lib/ 170 | lib64/ 171 | parts/ 172 | sdist/ 173 | var/ 174 | wheels/ 175 | share/python-wheels/ 176 | *.egg-info/ 177 | .installed.cfg 178 | *.egg 179 | MANIFEST 180 | 181 | # PyInstaller 182 | # Usually these files are written by a python script from a template 183 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 184 | *.manifest 185 | *.spec 186 | 187 | # Installer logs 188 | pip-log.txt 189 | pip-delete-this-directory.txt 190 | 191 | # Unit test / coverage reports 192 | htmlcov/ 193 | .tox/ 194 | .nox/ 195 | .coverage 196 | .coverage.* 197 | .cache 198 | nosetests.xml 199 | coverage.xml 200 | *.cover 201 | *.py,cover 202 | .hypothesis/ 203 | .pytest_cache/ 204 | cover/ 205 | 206 | # Translations 207 | *.mo 208 | *.pot 209 | 210 | # Django stuff: 211 | *.log 212 | local_settings.py 213 | db.sqlite3 214 | db.sqlite3-journal 215 | 216 | # Flask stuff: 217 | instance/ 218 | .webassets-cache 219 | 220 | # Scrapy stuff: 221 | .scrapy 222 | 223 | # Sphinx documentation 224 | docs/_build/ 225 | 226 | # PyBuilder 227 | .pybuilder/ 228 | target/ 229 | 230 | # Jupyter Notebook 231 | .ipynb_checkpoints 232 | 233 | # IPython 234 | profile_default/ 235 | ipython_config.py 236 | 237 | # pyenv 238 | # For a library or package, you might want to ignore these files since the code is 239 | # intended to run in multiple environments; otherwise, check them in: 240 | # .python-version 241 | 242 | # pipenv 243 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
244 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 245 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 246 | # install all needed dependencies. 247 | #Pipfile.lock 248 | 249 | # poetry 250 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 251 | # This is especially recommended for binary packages to ensure reproducibility, and is more 252 | # commonly ignored for libraries. 253 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 254 | #poetry.lock 255 | 256 | # pixi 257 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 258 | # This is especially recommended for binary packages to ensure reproducibility, and is more 259 | # commonly ignored for libraries. 260 | .pixi 261 | 262 | # pdm 263 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 264 | #pdm.lock 265 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 266 | # in version control. 267 | # https://pdm.fming.dev/#use-with-ide 268 | .pdm.toml 269 | 270 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 271 | __pypackages__/ 272 | 273 | # Celery stuff 274 | celerybeat-schedule 275 | celerybeat.pid 276 | 277 | # SageMath parsed files 278 | *.sage.py 279 | 280 | # Environments 281 | .env 282 | .venv 283 | env/ 284 | venv/ 285 | ENV/ 286 | env.bak/ 287 | venv.bak/ 288 | 289 | # Spyder project settings 290 | .spyderproject 291 | .spyproject 292 | 293 | # Rope project settings 294 | .ropeproject 295 | 296 | # mkdocs documentation 297 | /site 298 | 299 | # mypy 300 | .mypy_cache/ 301 | .dmypy.json 302 | dmypy.json 303 | 304 | # Pyre type checker 305 | .pyre/ 306 | 307 | # pytype static type analyzer 308 | .pytype/ 309 | 310 | # Cython debug symbols 311 | cython_debug/ 312 | 313 | # PyCharm 314 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 315 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 316 | # and can be added to the global gitignore or merged into this file. For a more nuclear 317 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
318 | .idea/ 319 | 320 | ### Python Patch ### 321 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 322 | poetry.toml 323 | 324 | # ruff 325 | .ruff_cache/ 326 | 327 | # LSP config files 328 | pyrightconfig.json 329 | 330 | ### Vim ### 331 | # Swap 332 | [._]*.s[a-v][a-z] 333 | !*.svg # comment out if you don't need vector files 334 | [._]*.sw[a-p] 335 | [._]s[a-rt-v][a-z] 336 | [._]ss[a-gi-z] 337 | [._]sw[a-p] 338 | 339 | # Session 340 | Session.vim 341 | Sessionx.vim 342 | 343 | # Temporary 344 | .netrwhist 345 | # Auto-generated tag files 346 | tags 347 | # Persistent undo 348 | [._]*.un~ 349 | 350 | ### VisualStudioCode ### 351 | .vscode/ 352 | #!.vscode/settings.json 353 | #!.vscode/tasks.json 354 | #!.vscode/launch.json 355 | #!.vscode/extensions.json 356 | #!.vscode/*.code-snippets 357 | 358 | # Local History for Visual Studio Code 359 | .history/ 360 | 361 | # Built Visual Studio Code Extensions 362 | *.vsix 363 | 364 | ### VisualStudioCode Patch ### 365 | # Ignore all local history of files 366 | .history 367 | .ionide 368 | 369 | ### Windows ### 370 | # Windows thumbnail cache files 371 | Thumbs.db 372 | Thumbs.db:encryptable 373 | ehthumbs.db 374 | ehthumbs_vista.db 375 | 376 | # Dump file 377 | *.stackdump 378 | 379 | # Folder config file 380 | [Dd]esktop.ini 381 | 382 | # Recycle Bin used on file shares 383 | $RECYCLE.BIN/ 384 | 385 | # Windows Installer files 386 | *.cab 387 | *.msi 388 | *.msix 389 | *.msm 390 | *.msp 391 | 392 | # Windows shortcuts 393 | *.lnk 394 | 395 | # End of https://www.toptal.com/developers/gitignore/api/linux,macos,direnv,python,windows,pycharm+all,visualstudiocode,vim 396 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: 
check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - id: check-toml 9 | - repo: https://github.com/quantco/pre-commit-mirrors-ruff 10 | rev: 0.4.3 11 | hooks: 12 | - id: ruff-conda 13 | - id: ruff-format-conda 14 | - repo: https://github.com/quantco/pre-commit-mirrors-mypy 15 | rev: 1.10.0 16 | hooks: 17 | - id: mypy-conda 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 QuantCo Inc, Pavel Zwerschke, Bela Stoyan 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # polarIFy: Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️ 2 | 3 | ![License][license-badge] 4 | [![Build Status][build-badge]][build] 5 | [![conda-forge][conda-forge-badge]][conda-forge] 6 | [![pypi-version][pypi-badge]][pypi] 7 | [![python-version][python-version-badge]][pypi] 8 | [![codecov][codecov-badge]][codecov] 9 | 10 | [license-badge]: https://img.shields.io/github/license/quantco/polarify?style=flat-square 11 | [build-badge]: https://img.shields.io/github/actions/workflow/status/quantco/polarify/ci.yml?style=flat-square&branch=main 12 | [build]: https://github.com/quantco/polarify/actions/ 13 | [conda-forge]: https://prefix.dev/channels/conda-forge/packages/polarify 14 | [conda-forge-badge]: https://img.shields.io/conda/pn/conda-forge/polarify?style=flat-square&logoColor=white&logo=conda-forge 15 | [pypi]: https://pypi.org/project/polarify 16 | [pypi-badge]: https://img.shields.io/pypi/v/polarify.svg?style=flat-square&logo=pypi&logoColor=white 17 | [python-version-badge]: https://img.shields.io/pypi/pyversions/polarify?style=flat-square&logoColor=white&logo=python 18 | [codecov-badge]: https://img.shields.io/codecov/c/github/quantco/polarify?style=flat-square&logo=codecov 19 | [codecov]: https://codecov.io/gh/quantco/polarify 20 | 21 | Welcome to 
**polarIFy**, a Python function decorator that simplifies the way you write logical statements for Polars. With polarIFy, you can use Python's language structures like `if / elif / else` statements and transform them into `pl.when(..).then(..).otherwise(..)` statements. This makes your code more readable and less cumbersome to write. 🎉 22 | 23 | ## 🎯 Usage 24 | 25 | polarIFy can automatically transform Python functions using `if / elif / else` statements into Polars expressions. 26 | 27 | ### Basic Transformation 28 | 29 | Here's an example: 30 | 31 | ```python 32 | @polarify 33 | def signum(x: pl.Expr) -> pl.Expr: 34 | s = 0 35 | if x > 0: 36 | s = 1 37 | elif x < 0: 38 | s = -1 39 | return s 40 | ``` 41 | 42 | This gets transformed into: 43 | 44 | ```python 45 | def signum(x: pl.Expr) -> pl.Expr: 46 | return pl.when(x > 0).then(1).otherwise(pl.when(x < 0).then(-1).otherwise(0)) 47 | ``` 48 | 49 | ### Handling Multiple Statements 50 | 51 | polarIFy can also handle multiple statements like: 52 | 53 | ```python 54 | @polarify 55 | def multiple_if_statement(x: pl.Expr) -> pl.Expr: 56 | a = 1 if x > 0 else 5 57 | b = 2 if x < 0 else 2 58 | return a + b 59 | ``` 60 | 61 | which becomes: 62 | 63 | ```python 64 | def multiple_if_statement(x): 65 | return pl.when(x > 0).then(1).otherwise(5) + pl.when(x < 0).then(2).otherwise(2) 66 | ``` 67 | 68 | ### Handling Nested Statements 69 | 70 | Additionally, it can handle nested statements: 71 | 72 | ```python 73 | @polarify 74 | def nested_if_else(x: pl.Expr) -> pl.Expr: 75 | if x > 0: 76 | if x > 1: 77 | s = 2 78 | else: 79 | s = 1 80 | elif x < 0: 81 | s = -1 82 | else: 83 | s = 0 84 | return s 85 | ``` 86 | 87 | which becomes: 88 | 89 | ```python 90 | def nested_if_else(x: pl.Expr) -> pl.Expr: 91 | return pl.when(x > 0).then(pl.when(x > 1).then(2).otherwise(1)).otherwise(pl.when(x < 0).then(-1).otherwise(0)) 92 | ``` 93 | 94 | So you can still write readable row-wise python code while the `@polarify` decorator transforms it 
into a function that works with efficient polars expressions. 95 | 96 | ### Using a `polarify`d function 97 | 98 | ```python 99 | import polars as pl 100 | from polarify import polarify 101 | 102 | @polarify 103 | def complicated_operation(x: pl.Expr) -> pl.Expr: 104 | k = 0 105 | c = 2 106 | if x > 0: 107 | k = 1 108 | c = 0 109 | if x < 10: 110 | c = 1 111 | elif x < 0: 112 | k = -1 113 | return k * c 114 | 115 | 116 | df = pl.DataFrame({"x": [-1, 1, 5, 10]}) 117 | result = df.select(pl.col("x"), complicated_operation(pl.col("x"))) 118 | print(result) 119 | # shape: (4, 2) 120 | # ┌─────┬─────────┐ 121 | # │ x ┆ literal │ 122 | # │ --- ┆ --- │ 123 | # │ i64 ┆ i32 │ 124 | # ╞═════╪═════════╡ 125 | # │ -1 ┆ -2 │ 126 | # │ 1 ┆ 1 │ 127 | # │ 5 ┆ 1 │ 128 | # │ 10 ┆ 0 │ 129 | # └─────┴─────────┘ 130 | ``` 131 | 132 | ### Displaying the transpiled polars expression 133 | 134 | You can also display the transpiled polars expression by calling the `transform_func_to_new_source` function: 135 | 136 | ```python 137 | import inspect 138 | from polarify import transform_func_to_new_source 139 | def signum(x): 140 | s = 0 141 | if x > 0: 142 | s = 1 143 | elif x < 0: 144 | s = -1 145 | return s 146 | 147 | 148 | print(f"Original function:\n{inspect.getsource(signum)}") 149 | # Original function: 150 | # def signum(x): 151 | # s = 0 152 | # if x > 0: 153 | # s = 1 154 | # elif x < 0: 155 | # s = -1 156 | # return s 157 | print(f"Transformed function:\n{transform_func_to_new_source(signum)}") 158 | # Transformed function: 159 | # def signum_polarified(x): 160 | # import polars as pl 161 | # return pl.when(x > 0).then(1).otherwise(pl.when(x < 0).then(-1).otherwise(0)) 162 | ``` 163 | 164 | TODO: complicated example with nested functions 165 | 166 | ## ⚙️ How It Works 167 | 168 | polarIFy achieves this by parsing the AST (Abstract Syntax Tree) of the function and transforming the body into a Polars expression by inlining the different branches.
169 | To get a more detailed understanding of what's happening under the hood, check out our [blog post](https://tech.quantco.com/blog/polarify) explaining how polarify works! 170 | 171 | ## 💿 Installation 172 | 173 | ### conda 174 | 175 | ```bash 176 | conda install -c conda-forge polarify 177 | # or micromamba 178 | micromamba install -c conda-forge polarify 179 | # or pixi 180 | pixi add polarify 181 | ``` 182 | 183 | ### pip 184 | 185 | ```bash 186 | pip install polarify 187 | ``` 188 | 189 | ## ⚠️ Limitations 190 | 191 | polarIFy is still in an early stage of development and doesn't support the full Python language. Here's a list of the currently supported and unsupported operations: 192 | 193 | ### Supported operations 194 | 195 | - `if / else / elif` statements 196 | - binary operations (like `+`, `==`, `>`, `&`, `|`, ...) 197 | - unary operations (like `~`, `-`, `not`, ...) (TODO) 198 | - assignments (like `x = 1`) 199 | - polars expressions (like `pl.col("x")`, TODO) 200 | - side-effect-free functions that return a polars expression (can be generated by `@polarify`) (TODO) 201 | - `match` statements 202 | 203 | ### Unsupported operations 204 | 205 | - `for` loops 206 | - `while` loops 207 | - `break` statements 208 | - `:=` walrus operator 209 | - dictionary mappings in `match` statements 210 | - list matching in `match` statements 211 | - star patterns in `match` statements 212 | - functions with side-effects (`print`, `df.write_csv`, ...)
213 | 214 | ## 🚀 Benchmarks 215 | 216 | TODO: Add some benchmarks 217 | 218 | ## 📥 Development installation 219 | 220 | ```bash 221 | pixi install 222 | pixi run postinstall 223 | pixi run test 224 | ``` 225 | -------------------------------------------------------------------------------- /pixi.toml: -------------------------------------------------------------------------------- 1 | # TODO: move to pyproject.toml when pixi supports it 2 | # https://github.com/prefix-dev/pixi/issues/79 3 | [project] 4 | name = "polarify" 5 | description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️" 6 | authors = [ 7 | "Bela Stoyan ", 8 | "Pavel Zwerschke ", 9 | ] 10 | channels = ["conda-forge"] 11 | platforms = ["linux-64", "osx-arm64", "osx-64", "win-64"] 12 | 13 | [dependencies] 14 | python = ">=3.9" 15 | polars = ">=0.14.24,<2" 16 | 17 | [tasks] 18 | postinstall = "pip install --no-build-isolation --no-deps --disable-pip-version-check -e ." 19 | 20 | [feature.py39.dependencies] 21 | python = "3.9.*" 22 | [feature.py310.dependencies] 23 | python = "3.10.*" 24 | [feature.py311.dependencies] 25 | python = "3.11.*" 26 | [feature.py312.dependencies] 27 | python = "3.12.*" 28 | [feature.py313.dependencies] 29 | python = "3.13.*" 30 | [feature.pl014.dependencies] 31 | polars = "0.14.*" 32 | [feature.pl1.dependencies] 33 | polars = "1.*" 34 | 35 | [host-dependencies] 36 | python = "*" 37 | pip = "*" 38 | hatchling = "*" 39 | 40 | [feature.test.dependencies] 41 | pytest = "*" 42 | pytest-md = "*" 43 | pytest-emoji = "*" 44 | hypothesis = "*" 45 | pytest-cov = "*" 46 | [feature.test.tasks] 47 | test = "pytest" 48 | coverage = "pytest --cov=polarify --cov-report=xml" 49 | 50 | [feature.lint.dependencies] 51 | pre-commit = "*" 52 | [feature.lint.tasks] 53 | lint = "pre-commit run --all" 54 | 55 | [environments] 56 | default = ["test", "py313", "pl1"] 57 | pl014 = ["pl014", "py310", "test"] 58 | pl1 = ["pl1", "py310", "test"] 59 | py39 = ["py39", "test"] 60 | py310 = 
["py310", "test"] 61 | py311 = ["py311", "test"] 62 | py312 = ["py312", "test"] 63 | py313 = ["py313", "test"] 64 | lint = ["lint"] 65 | -------------------------------------------------------------------------------- /polarify/__init__.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import importlib.metadata 3 | import inspect 4 | import warnings 5 | from functools import wraps 6 | 7 | from .main import parse_body, transform_tree_into_expr 8 | 9 | try: 10 | __version__ = importlib.metadata.version(__name__) 11 | except importlib.metadata.PackageNotFoundError as e: 12 | warnings.warn(f"Could not determine version of {__name__}", stacklevel=1) 13 | warnings.warn(str(e), stacklevel=1) 14 | __version__ = "unknown" 15 | 16 | 17 | def transform_func_to_new_source(func) -> str: 18 | source = inspect.getsource(func) 19 | tree = ast.parse(source) 20 | func_def: ast.FunctionDef = tree.body[0] # type: ignore 21 | root_node = parse_body(func_def.body) 22 | 23 | expr = transform_tree_into_expr(root_node) 24 | 25 | # Replace the body of the function with the parsed expr 26 | # Also import polars as pl since this is used in the generated code 27 | # We don't want to rely on the user having imported polars as pl 28 | func_def.body = [ 29 | ast.Import(names=[ast.alias(name="polars", asname="pl")]), 30 | ast.Return(value=expr), 31 | ] 32 | # TODO: make this prettier 33 | func_def.decorator_list = [] 34 | func_def.name += "_polarified" 35 | 36 | # Unparse the modified AST back into source code 37 | return ast.unparse(tree) 38 | 39 | 40 | def polarify(func): 41 | new_func_code = transform_func_to_new_source(func) 42 | # Execute the new function code in the original function's globals 43 | exec_globals = func.__globals__ 44 | exec(new_func_code, exec_globals) 45 | 46 | # Get the new function from the globals 47 | new_func = exec_globals[func.__name__ + "_polarified"] 48 | 49 | @wraps(func) 50 | def wrapper(*args, **kwargs): 51 | 
# return new_func(*args, **kwargs)
# return wrapper
# --------------------------------------------------------------------------------
# /polarify/main.py:
# --------------------------------------------------------------------------------
from __future__ import annotations

import ast
import sys
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass

# True on every Python 3.9.x interpreter. ``sys.version_info`` also carries the micro
# version and release level, so ``sys.version_info <= (3, 9)`` is already False on
# 3.9.0; compare against the next minor version instead.
PY_39 = sys.version_info < (3, 10)

# TODO: give the walrus operator (ast.NamedExpr) a dedicated error message; today it is
# rejected by InlineTransformer.generic_visit as an unsupported expression type.


@dataclass
class UnresolvedCase:
    """
    An unresolved case in a conditional statement (if, match, etc.).

    Each case consists of a test expression and the State reached when the test holds.
    That state may still be open (no return reached yet), i.e. its value is not resolved.
    """

    test: ast.expr
    state: State

    def __init__(self, test: ast.expr, then: State):
        self.test = test
        self.state = then


@dataclass
class ResolvedCase:
    """
    A resolved case in a conditional statement (if, match, etc.).

    Each case consists of a test expression and the expression the branch evaluates to.
    Iterating a case yields (test, state), so it can be unpacked like a pair.
    """

    test: ast.expr
    state: ast.expr

    def __init__(self, test: ast.expr, then: ast.expr):
        self.test = test
        self.state = then

    def __iter__(self):
        return iter([self.test, self.state])


def _fold(exprs: Sequence[ast.expr | None], op: type[ast.operator]) -> ast.expr | None:
    """
    Left-fold ``exprs`` into nested ``ast.BinOp`` nodes joined by the operator ``op``.

    ``None`` entries are skipped: in a conjunction they stand for "no condition"
    (always true). Returns ``None`` if no entry is left. Callers building a disjunction
    must handle ``None`` entries themselves, because there ``None`` makes the whole
    disjunction unconditional.
    """
    result: ast.expr | None = None
    for expr in exprs:
        if expr is None:
            continue
        result = expr if result is None else ast.BinOp(left=result, op=op(), right=expr)
    return result


def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.expr:
    """
    Build ``pl.when(t1).then(v1).when(t2).then(v2)...otherwise(orelse)`` as an AST.

    Cases are chained in order, so earlier cases take precedence over later ones.
    ``orelse`` is the value used when no test holds. When ``body`` is empty there is
    nothing to test, so ``orelse`` itself is returned.
    """
    if not body:
        return orelse

    chain: ast.expr = ast.Name(id="pl", ctx=ast.Load())
    for test, then in body:
        when_node = ast.Call(
            func=ast.Attribute(value=chain, attr="when", ctx=ast.Load()),
            args=[test],
            keywords=[],
        )
        chain = ast.Call(
            func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
            args=[then],
            keywords=[],
        )
    return ast.Call(
        func=ast.Attribute(value=chain, attr="otherwise", ctx=ast.Load()),
        args=[orelse],
        keywords=[],
    )


# ruff: noqa: N802
class InlineTransformer(ast.NodeTransformer):
    """
    Substitute recorded local assignments into an expression.

    Names that have a recorded assignment are replaced by a copy of the assigned
    expression, and ``a if cond else b`` is rewritten into
    ``pl.when(cond).then(a).otherwise(b)``. Only node types with a ``visit_*`` method
    below are accepted; anything else raises ValueError from ``generic_visit``.
    """

    def __init__(self, assignments: dict[str, ast.expr]):
        self.assignments = assignments

    @classmethod
    def inline_expr(cls, expr: ast.expr, assignments: dict[str, ast.expr]) -> ast.expr:
        """Return an inlined copy of ``expr``; ``expr`` itself is not modified."""
        expr = cls(assignments).visit(deepcopy(expr))
        assert isinstance(expr, ast.expr)
        return expr

    def visit_Name(self, node: ast.Name) -> ast.expr:
        if node.id in self.assignments:
            # Stored values were already inlined when they were assigned, so they only
            # refer to names that were unbound at that point. Visiting them again would
            # resolve those names against *later* assignments (``s = x; x = 5; return s``
            # would yield 5) and recurse forever on ``x = x + 1``. Copy so that no node
            # object is shared between trees.
            return deepcopy(self.assignments[node.id])
        return node

    def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute:
        node.value = self.visit(node.value)
        return node

    def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp:
        node.left = self.visit(node.left)
        node.right = self.visit(node.right)
        return node

    def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.UnaryOp:
        node.operand = self.visit(node.operand)
        return node

    def visit_Call(self, node: ast.Call) -> ast.Call:
        if isinstance(node.func, ast.Attribute):
            # Method calls such as ``s.abs()`` need their receiver inlined; plain
            # function names (``f(...)``) are left untouched, as before.
            node.func = self.visit(node.func)
        node.args = [self.visit(arg) for arg in node.args]
        node.keywords = [ast.keyword(arg=k.arg, value=self.visit(k.value)) for k in node.keywords]
        return node

    def visit_IfExp(self, node: ast.IfExp) -> ast.Call:
        test = self.visit(node.test)
        body = self.visit(node.body)
        orelse = self.visit(node.orelse)
        return build_polars_when_then_otherwise([ResolvedCase(test, body)], orelse)

    def visit_Constant(self, node: ast.Constant) -> ast.Constant:
        return node

    def visit_Compare(self, node: ast.Compare) -> ast.Compare:
        if len(node.comparators) > 1:
            raise ValueError("Polars can't handle chained comparisons")
        node.left = self.visit(node.left)
        node.comparators = [self.visit(c) for c in node.comparators]
        return node

    def generic_visit(self, node):
        raise ValueError(f"Unsupported expression type: {type(node)}")


@dataclass
class UnresolvedState:
    """
    When an execution flow is not finished (i.e., not returned) in a function, we need
    to keep track of the assignments. Every stored value is already inlined.
    """

    assignments: dict[str, ast.expr]

    def handle_assign(self, stmt: ast.Assign):
        """
        Record the bindings made by ``stmt``.

        Mirrors Python's semantics: the right-hand side is evaluated once, against the
        assignments that existed *before* the statement, and only then bound to each
        target from left to right. This matters for ``a, b = b, a`` and
        ``x = y = x + 1``. Supported targets are names and (nested) list/tuple targets
        that unpack a literal list/tuple.

        Raises ValueError for unsupported targets (e.g. starred), for unpacking a
        non-literal value, and for a length mismatch between target and value.
        """
        pending: list[tuple[str, ast.expr]] = []

        def collect(target: ast.expr, value: ast.expr) -> None:
            # Pair every target name with its raw value expression; nothing is bound yet.
            if isinstance(target, ast.Name):
                pending.append((target.id, value))
            elif isinstance(target, (ast.List, ast.Tuple)):
                if not isinstance(value, (ast.List, ast.Tuple)):
                    raise ValueError(
                        f"Assignment target is {type(target)}, but value is {type(value)}"
                    )
                for elt in target.elts:
                    if isinstance(elt, ast.Starred):
                        raise ValueError(
                            f"Unsupported expression type inside assignment target: {type(elt)}"
                        )
                if len(target.elts) != len(value.elts):
                    raise ValueError(
                        f"Cannot unpack {len(value.elts)} values into {len(target.elts)} targets"
                    )
                for sub_target, sub_value in zip(target.elts, value.elts):
                    collect(sub_target, sub_value)
            else:
                raise ValueError(
                    f"Unsupported expression type inside assignment target: {type(target)}"
                )

        for target in stmt.targets:
            collect(target, stmt.value)

        # Inline every value against the pre-statement state before binding any name.
        resolved = [
            (name, InlineTransformer.inline_expr(value, self.assignments))
            for name, value in pending
        ]
        for name, value in resolved:
            self.assignments[name] = value


@dataclass
class ReturnState:
    """
    The (already inlined) expression of a return statement.
    """

    expr: ast.expr


@dataclass
class ConditionalState:
    """
    A list of conditional cases plus the state taken when none of them holds.
    Each case consists of a test expression and a state.
    """

    body: Sequence[UnresolvedCase]
    orelse: State


@dataclass
class State:
    """
    A state in the execution flow.
    Either unresolved assignments, a return statement, or a conditional state.
    """

    node: UnresolvedState | ReturnState | ConditionalState

    def translate_match(
        self,
        subj: ast.expr | Sequence[ast.expr] | ast.Tuple,
        pattern: ast.pattern,
        guard: ast.expr | None = None,
    ) -> ast.expr | None:
        """
        Translate one match_case into a boolean AST expression.

        ``subj`` is the match subject: a single expression (e.g. ``x`` or
        ``2 * x + 1``) or a tuple of expressions. ``pattern`` is the case pattern and
        ``guard`` its optional ``if`` clause. Supported patterns are MatchValue,
        MatchAs, MatchOr, and MatchSequence (the latter only against a tuple subject).
        Capture names are recorded as assignments on this state as a side effect.

        Returns the condition under which the case matches, already combined with the
        guard, or ``None`` if the case matches unconditionally.
        Raises ValueError for unsupported patterns.
        """
        if isinstance(pattern, ast.MatchValue):
            equality = ast.Compare(left=subj, ops=[ast.Eq()], comparators=[pattern.value])
            return _fold([guard, equality], ast.BitAnd)

        if isinstance(pattern, ast.MatchAs):
            # ``case <sub> as name``: match the sub-pattern first, then bind the name,
            # mirroring Python's binding order. ``case _`` / ``case name`` have no
            # sub-pattern and therefore always match.
            condition = (
                None if pattern.pattern is None else self.translate_match(subj, pattern.pattern)
            )
            if pattern.name is not None:
                self.handle_assign(
                    ast.Assign(
                        targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
                        value=subj,
                    )
                )
            return _fold([guard, condition], ast.BitAnd)

        if isinstance(pattern, ast.MatchOr):
            alternatives = [self.translate_match(subj, alt) for alt in pattern.patterns]
            # One unconditional alternative makes the whole alternation unconditional.
            matched = (
                None if any(alt is None for alt in alternatives) else _fold(alternatives, ast.BitOr)
            )
            # The guard belongs to the case as a whole, not just to the first alternative.
            return _fold([guard, matched], ast.BitAnd)

        if isinstance(pattern, ast.MatchSequence):
            if any(isinstance(sub, ast.MatchStar) for sub in pattern.patterns):
                raise ValueError("starred patterns are not supported.")
            if not isinstance(subj, ast.Tuple):
                # TODO: Use polars list operations in the future
                raise ValueError("Matching lists is not supported.")
            if len(subj.elts) != len(pattern.patterns):
                raise ValueError(
                    f"Sequence pattern of length {len(pattern.patterns)} cannot match "
                    f"a subject tuple of length {len(subj.elts)}."
                )
            conditions = [
                self.translate_match(elt, sub) for elt, sub in zip(subj.elts, pattern.patterns)
            ]
            return _fold([guard, *conditions], ast.BitAnd)

        raise ValueError(
            f"Incompatible match and subject types: {type(pattern)} and {type(subj)}."
        )

    def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
        """Apply an assignment to every still-open branch; returned branches ignore it."""
        if isinstance(expr, ast.AnnAssign):
            if expr.value is None:
                # A bare annotation (``s: int``) binds nothing.
                return
            expr = ast.Assign(targets=[expr.target], value=expr.value)

        if isinstance(self.node, UnresolvedState):
            self.node.handle_assign(expr)
        elif isinstance(self.node, ConditionalState):
            for case in self.node.body:
                case.state.handle_assign(expr)
            self.node.orelse.handle_assign(expr)

    def handle_if(self, stmt: ast.If):
        """
        Turn an open state into a conditional on ``stmt.test`` (both branches start from
        a copy of the current assignments), or forward the statement into every branch
        of an existing conditional. Returned branches ignore it.
        """
        if isinstance(self.node, UnresolvedState):
            self.node = ConditionalState(
                body=[
                    UnresolvedCase(
                        InlineTransformer.inline_expr(stmt.test, self.node.assignments),
                        parse_body(stmt.body, copy(self.node.assignments)),
                    )
                ],
                orelse=parse_body(stmt.orelse, copy(self.node.assignments)),
            )
        elif isinstance(self.node, ConditionalState):
            for case in self.node.body:
                case.state.handle_if(stmt)
            self.node.orelse.handle_if(stmt)

    def handle_return(self, value: ast.expr):
        """Close every still-open branch with the inlined ``value``."""
        if isinstance(self.node, UnresolvedState):
            self.node = ReturnState(
                expr=InlineTransformer.inline_expr(value, self.node.assignments)
            )
        elif isinstance(self.node, ConditionalState):
            for case in self.node.body:
                case.state.handle_return(value)
            self.node.orelse.handle_return(value)

    def handle_match(self, stmt: ast.Match):
        """
        Turn an open state into a conditional with one case per reachable match_case,
        or forward the statement into every branch of an existing conditional.
        """

        def ignore_case(case: ast.match_case) -> bool:
            # if the length of the pattern is not equal to the length of the subject, python ignores the case
            return (
                isinstance(case.pattern, ast.MatchSequence)
                and isinstance(stmt.subject, ast.Tuple)
                and len(stmt.subject.elts) != len(case.pattern.patterns)
            ) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))

        if isinstance(self.node, UnresolvedState):
            assignments = self.node.assignments
            cases: list[UnresolvedCase] = []
            orelse_body: list[ast.stmt] = []
            for case in stmt.cases:
                if ignore_case(case):
                    continue
                # translate_match records capture names in ``assignments`` (mutating it),
                # so this case's test and body -- and every later case -- see them.
                test = self.translate_match(stmt.subject, case.pattern, case.guard)
                if test is None:
                    # Irrefutable case (``case _:``, ``case name:``). Python only allows it
                    # as the last case, so its body becomes the otherwise branch.
                    orelse_body = case.body
                    break
                cases.append(
                    UnresolvedCase(
                        InlineTransformer.inline_expr(test, assignments),
                        parse_body(case.body, copy(assignments)),
                    )
                )
            self.node = ConditionalState(
                body=cases,
                orelse=parse_body(orelse_body, copy(assignments)),
            )
        elif isinstance(self.node, ConditionalState):
            for case in self.node.body:
                case.state.handle_match(stmt)
            self.node.orelse.handle_match(stmt)


def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
    """
    Parse a list of statements into a State tree.

    ``assignments`` (mutated in place) holds already-inlined values of local variables
    known before the first statement. Statements after a top-level ``return`` are
    ignored; ``pass`` and bare constant expressions (e.g. docstrings) are no-ops.
    Raises ValueError for unsupported statements or a ``return`` without a value.
    """
    if assignments is None:
        assignments = {}
    state = State(UnresolvedState(assignments))
    for stmt in full_body:
        if isinstance(stmt, (ast.Assign, ast.AnnAssign)):
            state.handle_assign(stmt)
        elif isinstance(stmt, ast.If):
            state.handle_if(stmt)
        elif isinstance(stmt, ast.Return):
            if stmt.value is None:
                raise ValueError("return needs a value")
            state.handle_return(stmt.value)
            break
        elif not PY_39 and isinstance(stmt, ast.Match):
            # ast.Match only exists on Python 3.10+; short-circuit before touching it.
            state.handle_match(stmt)
        elif isinstance(stmt, ast.Pass) or (
            isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant)
        ):
            continue
        else:
            raise ValueError(f"Unsupported statement type: {type(stmt)}")
    return state


def transform_tree_into_expr(node: State) -> ast.expr:
    """
    Collapse a State tree into a single expression.

    Conditional states become ``pl.when(...).then(...).otherwise(...)`` chains. A
    conditional without any case (every case was unmatchable) collapses to its orelse.
    Raises ValueError("Not all branches return") if some path never reaches a return.
    """
    if isinstance(node.node, ReturnState):
        return node.node.expr
    if isinstance(node.node, ConditionalState):
        resolved = [
            ResolvedCase(case.test, transform_tree_into_expr(case.state))
            for case in node.node.body
        ]
        return build_polars_when_then_otherwise(
            resolved, transform_tree_into_expr(node.node.orelse)
        )
    raise ValueError("Not all branches return")


# --------------------------------------------------------------------------------
# /pyproject.toml:
# --------------------------------------------------------------------------------
# [build-system]
# requires = ["hatchling"]
# build-backend = "hatchling.build"
#
# [project]
# name = "polarify"
# description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
# version = "0.2.1"
# readme = "README.md"
# license = { file = "LICENSE" }
# requires-python = ">=3.9"
# authors = [
#     { name = "Bela Stoyan", email = "bela.stoyan@quantco.com" },
#     { name = "Pavel Zwerschke", email = "pavel.zwerschke@quantco.com" },
# ]
# classifiers = [
#     "Programming Language :: Python :: 3",
#     "Programming Language :: Python :: 3.9",
#     "Programming Language :: Python :: 3.10",
#     "Programming Language :: Python :: 3.11",
#     "Programming Language :: Python :: 3.12",
#     "Programming Language :: Python :: 3.13",
# ]
# dependencies = ["polars >=0.14.24,<2"]
#
# [project.urls]
# Homepage = "https://github.com/quantco/polarify"
#
# [tool.hatch.build.targets.sdist]
# include = ["/polarify"]
#
# [tool.ruff]
# line-length = 100
# target-version = "py39"
#
# [tool.ruff.lint]
# select = [
#     # pyflakes
#     "F",
#     # pycodestyle
#     "E",
#     "W",
#     # flake8-builtins
#     "A",
#     # flake8-bugbear
#     "B",
#     # flake8-comprehensions
#     "C4",
#     # flake8-simplify
#     "SIM",
#     # flake8-unused-arguments
#     "ARG",
#     # pylint
#     "PL",
    # tidy
    "TID",
    # isort
    "I",
    # pep8-naming
    "N",
    # pyupgrade
    "UP",
]
ignore = [
    # may cause conflicts with ruff formatter
    "E501",
    "W191",
]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Quantco/polarify/8325d8cac05889e93bdeb0196882ce5ea50abb64/tests/__init__.py
--------------------------------------------------------------------------------
/tests/functions.py:
--------------------------------------------------------------------------------
# ruff: noqa
# ruff must not change the AST of the test functions, even if they are semantically equivalent.
# Fixture functions of one argument `x`. test_parse_body.py calls each one row by row on
# integers and, after polarify, on polars.col("x"); test_error_handling.py polarifies the
# unsupported ones. Only `#` comments may be added here: a docstring would change the AST.
import sys

# `match` statements do not parse before Python 3.10, so those fixtures live in
# functions_310.py and are only imported on 3.10+.
if sys.version_info >= (3, 10):
    from .functions_310 import functions_310, unsupported_functions_310
else:
    functions_310 = []
    unsupported_functions_310 = []


def signum(x):
    s = 0
    if x > 0:
        s = 1
    elif x < 0:
        s = -1
    return s


def signum_no_default(x):
    if x > 0:
        return 1
    elif x < 0:
        return -1
    return 0


def nested_partial_return_with_assignments(x):
    if x > 0:
        s = 1
        if x > 1:
            s = 2
            return s + x
        else:
            s = -1
    else:
        return -5 - x
    return s * x


def early_return(x):
    if x > 0:
        return 1
    return 0


def assign_both_branches(x):
    if x > 0:
        s = 1
    else:
        s = -1
    return s


def unary_expr(x):
    s = -x
    return s


# Helper called by call_expr; the call itself must survive polarify unchanged.
def call_target_identity(x):
    return x


def call_expr(x):
    k = x * 2
    s = call_target_identity(k + 3)
    return s


def if_expr(x):
    s = 1 if x > 0 else -1
    return s


def if_expr2(x):
    s = 1 + (x if x > 0 else -1)
    return s


def if_expr3(x):
    s = 1 + ((3 if x < 10 else 5) if x > 0 else -1)
    return s


def compare_expr(x):
    if (0 < x) & (x < 10):
        s = 1
    else:
        s = 2
    return s


# Uses `and` (ast.BoolOp); listed in unsupported_functions below.
def bool_op(x):
    if (0 < x) and (x < 10):
        return 0
    else:
        return 1


# Chained comparison; listed in unsupported_functions below.
def chained_compare_expr(x):
    if 0 < x < 10:
        s = 1
    else:
        s = 2
    return s


def walrus_expr(x):
    if (y := x + 1) > 0:
        s = 1
    else:
        s = -1
    return s * y


def return_nothing(x):
    if x > 0:
        return
    else:
        return 1


def no_return(x):
    s = x


def return_end(x):
    s = x
    return


def annotated_assign(x):
    s: int = 15
    return s + x


def conditional_assign(x):
    s = 1
    if x > 0:
        s = 2
    b = 3
    return b


def return_constant(x):
    return 1


def return_constant_2(x):
    return 1 + 2


def return_unconditional_constant(x):
    if x > 0:
        s = 1
    else:
        s = 2
    return 1


def return_constant_additional_assignments(x):
    s = 2
    return 1


def return_conditional_constant(x):
    if x > 0:
        return 1
    return 0


def multiple_if(x):
    s = 1
    if x > 0:
        s = 2
    if x > 1:
        s = 3
    return s


def multiple_if_else(x):
    if x > 0:
        s = 1
    elif x < 0:
        s = -1
    else:
        s = 0
    return s


def nested_if_else(x):
    if x > 0:
        if x > 1:
            s = 2
        else:
            s = 1
    elif x < 0:
        s = -1
    else:
        s = 0
    return s


def nested_if_else_expr(x):
    if x > 0:
        s = 2 if x > 1 else 1
    elif x < 0:
        s = -1
    else:
        s = 0
    return s


def assignments_inside_branch(x):
    if x > 0:
        s = 1
        s = s + 1
        s = x * s
    elif x < 0:
        s = -1
        s = s - 1
        s = x
    else:
        s = 0
    return s


def override_default(x):
    s = 0
    if x > 0:
        s = 10
    return x * s


def no_if_else(x):
    s = x * 10
    k = x - 3
    k = k * 2
    return s * k


def two_if_expr(x):
    a = 1 if x > 0 else 5
    b = 2 if x < 0 else 2
    return a + b


def multiple_equals(x):
    a = b = 1
    return x + a + b


def tuple_assignments(x):
    a, b = 1, x
    return x + a + b


def list_assignments(x):
    [a, b] = 1, x
    return x + a + b


def different_type_assignments(x):
    [a, b] = {1, 2}
    return x


def star_assignments(x):
    b, *a = [1, 2]
    return x


def global_variable(x):
    global a
    a = 1
    return x + a


# Fixtures used as params of the `funcs` fixture in test_parse_body.py: each polarified
# version is compared with applying the original function row by row.
functions = [
    signum,
    early_return,
    assign_both_branches,
    unary_expr,
    call_expr,
    if_expr,
    if_expr2,
    if_expr3,
    compare_expr,
    multiple_if_else,
    nested_if_else,
    nested_if_else_expr,
    assignments_inside_branch,
    override_default,
    no_if_else,
    two_if_expr,
    signum_no_default,
    nested_partial_return_with_assignments,
    multiple_equals,
    tuple_assignments,
    list_assignments,
    annotated_assign,
    conditional_assign,
    multiple_if,
    return_unconditional_constant,
    return_conditional_constant,
    *functions_310,
]

# Also run by test_parse_body.py, but marked xfail (reason "not implemented").
xfail_functions = [
    walrus_expr,
    # our test setup does not work with literal expressions
    return_constant,
    return_constant_2,
    return_constant_additional_assignments,
    different_type_assignments,
    star_assignments,
    global_variable,
]

# test_error_handling.py asserts that polarify(function) raises ValueError with a
# message matching the given regex.
unsupported_functions = [
    # function, match string in error message
    (chained_compare_expr, "Polars can't handle chained comparisons"),
    (bool_op, "ast.BoolOp"),  # TODO: make error message more specific
    (return_end, "return needs a value"),
    (no_return, "Not all branches return"),
    (return_nothing, "return needs a value"),
    *unsupported_functions_310,
]
--------------------------------------------------------------------------------
/tests/functions_310.py:
--------------------------------------------------------------------------------
# ruff: noqa
# ruff must not change the AST of the test functions, even if they are semantically equivalent.
# Fixtures using `match`; functions.py imports this module only on Python 3.10+.


def match_case(x):
    s = 0
    match x:
        case 0:
            s = 1
        case 2:
            s = -1
        case _:
            s = 0
    return s


def match_with_or(x):
    match x:
        case 0 | 1:
            return 0
        case 2:
            return 2 * x
        case 3:
            return 3 * x
    return x


def match_sequence(x):
    match x:
        case 0, 1:
            return 0
        case 2:
            return 2 * x
        case 3:
            return 3 * x
    return x


def match_sequence_with_brackets(x):
    match x:
        case [0, 1]:
            return 0
        case 2:
            return 2 * x
        case 3:
            return 3 * x
    return x


def match_assignments_inside_branch(x):
    match x:
        case 0:
            return 0
        case 1:
            return 2 * x
        case 2:
            return 3 * x
    return x


def nested_match(x):
    match x:
        case 0:
            match x:
                case 0:
                    return 1
                case 1:
                    return 2
            return 3
        case 1:
            return 4
    return 5


def match_compare_expr(x):
    match x:
        case 0:
            return 2
        case 1:
            return 1
        case 10:
            return 2
    return 1


def match_nested_partial_return_with_assignments(x):
    match x:
        case 0:
            return -5 - x
        case 1:
            return 1 * x
        case 2:
            return 2 + x
    return -1 * x


def match_signum(x):
    s = 0
    match x:
        case 0:
            s = 1
        case 2:
            s = -1
        case 3:
            s = 0
    return s


def match_sequence_star(x):
    match x:
        case 0, *other:
            return other
        case 1:
            return 1
        case 2:
            return 2
    return x


def match_multiple_variables(x):
    y = 3
    match x, y:
        case 1, 3:
            return 1
        case _:
            return 5


def match_with_guard(x):
    match x:
        case 5 if x > 3:
            return 1
        case _:
            return 5


def match_with_guard_variable(x):
    match x:
        case y if y > 5:
            return 1
        case _:
            return 5


def match_with_guard_multiple_variable(x):
    y = 3
    match x, y:
        case 1, z if z > 3:
            return 1
        case z, 3 if z > 3:
            return 2
        case _:
            return 5


def match_sequence_incomplete(x):
    y = 2
    z = 3
    match x, y, z:
        case 0, 1, 2:
            return 0
        case 1, 2:
            return 1
        case 2:
            return 2
    return x


def match_mapping(x):
    match x:
        case {1: 2}:
            return 1
        case _:
            return x


def multiple_match(x):
    match x:
        case 0:
            return 1
        case 1:
            return 2
    match x:
        case 0:
            return 3
        case 1:
            return 4
    return x


def match_with_assignment(x):
    match x:
        case y if x > 1:
            y = y * 2
            return y
        case _:
            return x


def match_with_assignment_hard(x):
    match x:
        case y if x > 1:
            y = y * 2
        case _:
            return x

    return y + 2


def match_complex_subject(x):
    match x + 2:
        case 3:
            return 1
        case _:
            return x


def match_guarded_match_as_no_return(x):
    match x:
        case 1:
            return 0
        case _ if x > 1:
            return 2


def match_guarded_match_as(x):
    match x:
        case 1:
            return 0
        case _ if x > 1:
            return 2

    return 3


def match_sequence_unmatchable_case_smaller(x):
    y = 2
    z = None

    match x, y, z:
        case 1, 2:
            return 1
        case _:
            return x


def match_sequence_unmatchable_case_larger(x):
    y = 2
    z = None

    match x, y:
        case 1, 2, 3:
            return 1
        case _:
            return x * 2


def match_sequence_unmatchable_case_smaller_return(x):
    y = 1
    z = 2

    match x, y, z:
        case 1, 2:
            x = 4
            return 1
    return x


def match_sequence_unmatchable_case(x):
    y = 1
    z = 2

    match x, y, z:
        case 1, 2:
            return 1
        case 3, 4:
            return -1
        case 1, 2, 3:
            return 2
    return x


def match_guard_no_assignation(x):
    match x:
        case _ if x > 1:
            return 0
        case _:
            return 2


# Appended to `functions` in functions.py.
functions_310 = [
    nested_match,
    match_assignments_inside_branch,
    match_signum,
    match_nested_partial_return_with_assignments,
    match_compare_expr,
    match_case,
    match_with_or,
    match_multiple_variables,
    match_with_guard,
    match_with_guard_variable,
    match_with_guard_multiple_variable,
    match_sequence_incomplete,
    multiple_match,
    match_with_assignment,
    match_with_assignment_hard,
    match_complex_subject,
    match_guarded_match_as,
    match_guard_no_assignation,
    match_sequence_unmatchable_case,
    match_sequence_unmatchable_case_smaller,
    match_sequence_unmatchable_case_smaller_return,
    match_sequence_unmatchable_case_larger,
]

# Appended to `unsupported_functions` in functions.py: (function, error-message regex).
unsupported_functions_310 = [
    (match_mapping, "ast.MatchMapping"),
    (match_sequence_star, "starred patterns are not supported."),
    (match_sequence, "Matching lists is not supported."),
    (match_sequence_with_brackets, "Matching lists is not supported."),
    (match_guarded_match_as_no_return, "Not all branches return"),
]
--------------------------------------------------------------------------------
/tests/test_error_handling.py:
--------------------------------------------------------------------------------
import pytest

from polarify import polarify

from .functions import unsupported_functions


# Each unsupported fixture must make polarify raise a ValueError matching its regex.
@pytest.mark.parametrize("func_match", unsupported_functions)
def test_unsupported_functions(func_match):
    func, match = func_match
    with pytest.raises(ValueError, match=match):
        polarify(func)
--------------------------------------------------------------------------------
/tests/test_parse_body.py:
--------------------------------------------------------------------------------
import inspect

import polars
import pytest
from hypothesis import given
from hypothesis.strategies import integers
from packaging.version import Version
from polars import __version__ as _pl_version
from polars.testing import assert_frame_equal
from polars.testing.parametric import column, dataframes

from .functions import functions, xfail_functions

from polarify import polarify, transform_func_to_new_source

pl_version = Version(_pl_version)


# Yields (polarified function, original function) for every fixture; the xfail fixtures
# are wrapped in pytest.param with an xfail mark. Both sources are printed to ease
# debugging of failures.
@pytest.fixture(
    scope="module",
    params=functions
    + [pytest.param(f, marks=pytest.mark.xfail(reason="not implemented")) for f in xfail_functions],
)
def funcs(request):
    original_func = request.param
    transformed_func = polarify(original_func)
    original_func_unparsed = inspect.getsource(original_func)
    # build ast from transformed function as format as string
    transformed_func_unparsed = transform_func_to_new_source(original_func)
    print(f"Original:\n{original_func_unparsed}\nTransformed:\n{transformed_func_unparsed}")
    return transformed_func, original_func


# chunking + apply is broken for polars < 0.18.1
# https://github.com/pola-rs/polars/pull/9211
# only relevant for our test setup, not for the library itself
@given(
    df=dataframes(
        column("x", dtype=polars.Int64, strategy=integers(-100, 100)),
        min_size=1,
        chunked=False if pl_version < Version("0.18.1") else None,
    )
)
def test_transform_function(df: polars.DataFrame, funcs):
    # Evaluate the polarified expression on column "x" and compare it with the original
    # function applied row by row; dtypes are not compared.
    x = polars.col("x")
    transformed_func, original_func = funcs

    if pl_version < Version("0.19.0"):
        df_with_transformed_func = df.select(transformed_func(x).alias("apply"))
        df_with_applied_func = df.apply(lambda r: original_func(r[0]))
    else:
        df_with_transformed_func = df.select(transformed_func(x).alias("map"))
        df_with_applied_func = df.map_rows(lambda r: original_func(r[0]))

    # The keyword was renamed from check_dtype to check_dtypes in polars 0.20.
    if pl_version < Version("0.20"):
        assert_frame_equal(
            df_with_transformed_func,
            df_with_applied_func,
            check_dtype=False,
        )
    else:
        assert_frame_equal(
            df_with_transformed_func,
            df_with_applied_func,
            check_dtypes=False,
        )
--------------------------------------------------------------------------------