├── .gitattributes ├── .github ├── FUNDING.yml └── workflows │ ├── CI.yml │ └── python-tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE.txt ├── Makefile ├── README.md ├── SKLEARN_COMPATIBILITY.md ├── benchmarks ├── benchmark.py ├── benchmarks.ipynb ├── generate_benchmark_data.py ├── linear_regression.ipynb ├── requirements-benchmarks.txt ├── test_linear_regression.py └── test_metrics.py ├── docs ├── compat.md ├── eda.md ├── expr_iter.md ├── expr_knn.md ├── expr_linear.md ├── index.md ├── linear_models.md ├── metrics.md ├── num.md ├── pipeline.md ├── polars_ds.md ├── requirements-docs.txt ├── sample_and_split.md ├── spatial.md ├── stats.md ├── string.md └── ts_features.md ├── examples ├── auto_complete.png ├── basics.ipynb ├── dependency.parquet ├── eda.ipynb ├── pipeline.ipynb └── sample_and_split.ipynb ├── mkdocs.yml ├── pyproject.toml ├── python └── polars_ds │ ├── __init__.py │ ├── _utils.py │ ├── compat │ ├── __init__.py │ └── _compat.py │ ├── config.py │ ├── eda │ ├── __init__.py │ ├── diagnosis.py │ └── plots.py │ ├── exprs │ ├── __init__.py │ ├── expr_balltree.py │ ├── expr_iter.py │ ├── expr_knn.py │ ├── expr_linear.py │ ├── metrics.py │ ├── num.py │ ├── stats.py │ ├── string.py │ └── ts_features.py │ ├── linear_models.py │ ├── modeling │ ├── __init__.py │ ├── _step.py │ ├── pipeline.py │ └── transforms.py │ ├── partition │ ├── __init__.py │ └── partition.py │ ├── sample_and_split │ ├── __init__.py │ └── sample_and_split.py │ ├── spatial.py │ └── typing.py ├── requirements.txt ├── rust-toolchain.toml ├── src ├── arkadia │ ├── kdt.rs │ ├── leaf.rs │ ├── mod.rs │ ├── neighbor.rs │ └── utils.rs ├── lib.rs ├── linalg │ ├── glm_solvers.rs │ ├── link_functions.rs │ ├── lr_online_solvers.rs │ ├── lr_solvers.rs │ └── mod.rs ├── num_ext │ ├── ball_tree.rs │ ├── benford.rs │ ├── cond_entropy.rs │ ├── convolve.rs │ ├── entrophies.rs │ ├── fft.rs │ ├── float_extras.rs │ ├── gcd_lcm.rs │ ├── haversine.rs │ ├── isotonic_regression.rs │ ├── iterations.rs │ ├── jaccard.rs │ ├── knn.rs │ ├── lempel_ziv.rs │ ├── linear_regression.rs │ ├── linear_regression_f32.rs │ ├── mod.rs │ ├── mutual_info.rs │ ├── pca.rs │ ├── psi.rs │ ├── subseq_sim.rs │ ├── target_encode.rs │ ├── tp_fp.rs │ ├── trapz.rs │ ├── welch.rs │ └── woe_iv.rs ├── pymodels │ ├── mod.rs │ ├── py_glm.rs │ ├── py_kdt.rs │ └── py_lr.rs ├── stats │ ├── chi2.rs │ ├── fstats.rs │ ├── kendall_tau.rs │ ├── ks.rs │ ├── mann_whitney_u.rs │ ├── mod.rs │ ├── normal_test.rs │ ├── sample.rs │ ├── t_test.rs │ └── xi_corr.rs ├── stats_utils │ ├── beta.rs │ ├── gamma.rs │ ├── mod.rs │ └── normal.rs ├── str_ext │ ├── fuzz.rs │ ├── generic_str_distancer.rs │ ├── hamming.rs │ ├── inflections.rs │ ├── jaro.rs │ ├── lcs_seq.rs │ ├── lcs_str.rs │ ├── levenshtein.rs │ ├── mod.rs │ ├── nearest_str.rs │ ├── osa.rs │ ├── overlap.rs │ ├── sorensen_dice.rs │ ├── str_cleaning.rs │ ├── str_jaccard.rs │ └── tversky.rs └── utils │ └── mod.rs └── tests ├── requirements-test.txt ├── test_compat.py ├── test_linear_models.py ├── test_many.py ├── test_metrics.py ├── test_spatial_queries.py ├── test_string.py └── test_transforms.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb -linguist-detectable 2 | tests/* -linguist-detectable 3 | benchmarks/* -linguist-detectable 4 | examples/* -linguist-detectable -------------------------------------------------------------------------------- /.github/FUNDING.yml: 
-------------------------------------------------------------------------------- 1 | github: abstractqqq -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | permissions: 4 | id-token: write 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | tags: 11 | - "*" 12 | pull_request: 13 | workflow_dispatch: 14 | inputs: 15 | # Latest commit to include with the release. If omitted, use the latest commit on the main branch. 16 | sha: 17 | description: Commit SHA 18 | type: string 19 | 20 | defaults: 21 | run: 22 | shell: bash 23 | 24 | env: 25 | PYTHON_VERSION: '3.9' 26 | 27 | jobs: 28 | create-sdist: 29 | runs-on: ubuntu-latest 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | package: [polars_ds] 34 | 35 | steps: 36 | - uses: actions/checkout@v4 37 | with: 38 | ref: ${{ inputs.sha }} 39 | 40 | - name: Set up Python 41 | uses: actions/setup-python@v5 42 | with: 43 | python-version: ${{ env.PYTHON_VERSION }} 44 | 45 | - name: Create source distribution 46 | uses: PyO3/maturin-action@v1 47 | with: 48 | command: sdist 49 | args: > 50 | --manifest-path Cargo.toml 51 | --out dist 52 | maturin-version: 1.7.4 53 | 54 | - name: Test sdist 55 | run: | 56 | pip install -r requirements.txt 57 | pip install typing_extensions 58 | pip install --force-reinstall --verbose dist/*.tar.gz 59 | python -c 'import polars_ds as pds' 60 | python -c 'from polars_ds import linear_models' 61 | python -c 'from polars_ds.spatial import *' 62 | python -c 'from polars_ds.sample_and_split import *' 63 | python -c 'from polars_ds.exprs.ts_features import *' 64 | 65 | - name: Upload sdist 66 | uses: actions/upload-artifact@v4 67 | with: 68 | name: sdist-${{ matrix.package }} 69 | path: dist/*.tar.gz 70 | 71 | 72 | build-wheels: 73 | runs-on: ${{ matrix.os }} 74 | needs: [create-sdist] 75 | strategy: 76 | fail-fast: false 77 | matrix: 78 | package: [polars_ds] 79 | os: [ubuntu-latest, macos-13, windows-latest] 80 | architecture: [x86-64, aarch64] 81 | exclude: 82 | - os: windows-latest 83 | architecture: aarch64 84 | 85 | steps: 86 | - uses: actions/checkout@v4 87 | with: 88 | ref: ${{ inputs.sha }} 89 | 90 | - name: Set up Python 91 | uses: actions/setup-python@v5 92 | with: 93 | python-version: ${{ env.PYTHON_VERSION }} 94 | 95 | - name: Determine CPU features for x86-64 96 | id: features 97 | if: matrix.architecture == 'x86-64' 98 | 99 | # env: 100 | # IS_LTS_CPU: ${{ matrix.package == 'polars_ds_lts_cpu' }} 101 | 102 | # if [[ "$IS_LTS_CPU" = true ]]; then 103 | # FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b 104 | # CC_FEATURES="-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16" 105 | # else 106 | # fi 107 | run: | 108 | TUNE_CPU=skylake 109 | FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+cmpxchg16b,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq,+movbe 110 | CC_FEATURES="-msse3 -mssse3 -msse4.1 -msse4.2 -mpopcnt -mcx16 -mavx -mavx2 -mfma -mbmi -mbmi2 -mlzcnt -mpclmul -mmovbe" 111 | 112 | echo "features=$FEATURES" >> $GITHUB_OUTPUT 113 | echo "tune_cpu=$TUNE_CPU" >> $GITHUB_OUTPUT 114 | echo "cc_features=$CC_FEATURES" >> $GITHUB_OUTPUT 115 | 116 | - name: Set RUSTFLAGS for x86-64 117 | if: matrix.architecture == 'x86-64' 118 | env: 119 | FEATURES: ${{ steps.features.outputs.features }} 120 | TUNE_CPU: ${{ steps.features.outputs.tune_cpu }} 121 | CC_FEATURES: ${{ steps.features.outputs.cc_features }} 122 | # CFG: ${{ matrix.package == 'polars_ds_lts_cpu' && 
'--cfg allocator="default"' || '' }} 123 | # add $CFG 124 | run: | 125 | if [[ -z "$TUNE_CPU" ]]; then 126 | echo "RUSTFLAGS=-C target-feature=$FEATURES" >> $GITHUB_ENV 127 | echo "CFLAGS=$CC_FEATURES" >> $GITHUB_ENV 128 | else 129 | echo "RUSTFLAGS=-C target-feature=$FEATURES -Z tune-cpu=$TUNE_CPU" >> $GITHUB_ENV 130 | echo "CFLAGS=$CC_FEATURES -mtune=$TUNE_CPU" >> $GITHUB_ENV 131 | fi 132 | 133 | - name: Set Rust target for aarch64 134 | if: matrix.architecture == 'aarch64' 135 | id: target 136 | run: | 137 | TARGET=$( 138 | if [[ "${{ matrix.os }}" == "macos-13" ]]; then 139 | echo "aarch64-apple-darwin"; 140 | else 141 | echo "aarch64-unknown-linux-gnu"; 142 | fi 143 | ) 144 | echo "target=$TARGET" >> $GITHUB_OUTPUT 145 | 146 | - name: Set jemalloc for aarch64 Linux 147 | if: matrix.architecture == 'aarch64' && matrix.os == 'ubuntu-latest' 148 | run: | 149 | echo "JEMALLOC_SYS_WITH_LG_PAGE=16" >> $GITHUB_ENV 150 | 151 | - name: Build wheel 152 | uses: PyO3/maturin-action@v1 153 | with: 154 | command: build 155 | target: ${{ steps.target.outputs.target }} 156 | args: > 157 | --release 158 | --manifest-path Cargo.toml 159 | --out dist 160 | manylinux: ${{ matrix.architecture == 'aarch64' && '2_24' || 'auto' }} 161 | maturin-version: 1.7.4 162 | 163 | - name: Test wheel 164 | # Only test on x86-64 for now as this matches the runner architecture 165 | if: matrix.architecture == 'x86-64' 166 | run: | 167 | pip install --force-reinstall --verbose dist/*.whl 168 | pip install typing_extensions 169 | pip install -r requirements.txt 170 | python -c 'import polars_ds' 171 | python -c 'from polars_ds import linear_models' 172 | python -c 'from polars_ds.spatial import *' 173 | python -c 'from polars_ds.sample_and_split import *' 174 | 175 | - name: Upload wheel 176 | uses: actions/upload-artifact@v4 177 | with: 178 | name: wheel-${{ matrix.package }}-${{ matrix.os }}-${{ matrix.architecture }} 179 | path: dist/*.whl 180 | 181 | release: 182 | name: Release 183 | runs-on: ubuntu-latest 184 | if: "startsWith(github.ref, 'refs/tags/')" 185 | needs: [build-wheels, create-sdist] 186 | permissions: 187 | id-token: write 188 | steps: 189 | - uses: actions/download-artifact@v4 190 | with: 191 | pattern: wheel-* 192 | merge-multiple: true 193 | 194 | - uses: actions/download-artifact@v4 195 | with: 196 | pattern: sdist-* 197 | merge-multiple: true 198 | 199 | - name: Publish to PyPI 200 | uses: PyO3/maturin-action@v1 201 | with: 202 | command: upload 203 | args: --non-interactive --skip-existing * 204 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | name: Test Python 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - Cargo.lock 7 | - python/** 8 | - src/** 9 | - .github/workflows/python-tests.yml 10 | push: 11 | branches: 12 | - main 13 | paths: 14 | - Cargo.lock 15 | - python/** 16 | - src/** 17 | - .github/workflows/python-tests.yml 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.ref }} 21 | cancel-in-progress: true 22 | 23 | env: 24 | RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down 25 | RUST_BACKTRACE: 1 26 | PYTHONUTF8: 1 27 | 28 | defaults: 29 | run: 30 | shell: bash 31 | 32 | jobs: 33 | test-python: 34 | runs-on: ${{ matrix.os }} 35 | strategy: 36 | fail-fast: false 37 | matrix: 38 | os: [ubuntu-latest] 39 | python-version: ['3.9', '3.12', '3.13'] 40 | # include: 41 | # - os: windows-latest 42 | 
# python-version: '3.13' 43 | 44 | steps: 45 | - uses: actions/checkout@v4 46 | 47 | - name: Set up Python 48 | uses: actions/setup-python@v5 49 | with: 50 | python-version: ${{ matrix.python-version }} 51 | 52 | - name: Create virtual environment 53 | env: 54 | BIN: ${{ matrix.os == 'windows-latest' && 'Scripts' || 'bin' }} 55 | run: | 56 | python -m venv .venv 57 | echo "$GITHUB_WORKSPACE/.venv/$BIN" >> $GITHUB_PATH 58 | echo "VIRTUAL_ENV=$GITHUB_WORKSPACE/.venv" >> $GITHUB_ENV 59 | 60 | - name: Install package 61 | run: | 62 | pip install -r requirements.txt 63 | maturin develop --release 64 | 65 | - name: Install Python dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | pip install jupyter ipython ipykernel nbconvert pytest 69 | pip install -r tests/requirements-test.txt 70 | 71 | - name: Test Notebooks 72 | run: | 73 | jupyter execute examples/basics.ipynb 74 | jupyter execute examples/pipeline.ipynb 75 | jupyter execute examples/eda.ipynb 76 | jupyter execute examples/sample_and_split.ipynb 77 | 78 | - name: Test Pytests 79 | run: pytest tests/test_* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .pickle 3 | 4 | # Local, quick adhoc test only purpose 5 | tests/*.ipynb 6 | tests/test.ipynb 7 | tests/sample.csv 8 | tests/test*.csv 9 | 10 | /target 11 | 12 | # Mkdocs 13 | site/ 14 | 15 | # Ruff 16 | .ruff_cache/ 17 | 18 | # Memray 19 | tests/*.bin 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | .pytest_cache/ 24 | *.py[cod] 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | .venv/ 32 | env/ 33 | bin/ 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | include/ 44 | man/ 45 | venv/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | pip-selfcheck.json 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | 63 | # Translations 64 | *.mo 65 | 66 | # Mr Developer 67 | .mr.developer.cfg 68 | .project 69 | .pydevproject 70 | 71 | # Rope 72 | .ropeproject 73 | 74 | # Django stuff: 75 | *.log 76 | *.pot 77 | 78 | .DS_Store 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyCharm 84 | .idea/ 85 | 86 | # VSCode 87 | .vscode/ 88 | 89 | # Pyenv 90 | .python-version 91 | 92 | # Polars Extension 93 | .so 94 | .dll -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.1 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | types_or: [ python, pyi] 9 | args: [ --fix ] 10 | # Run the formatter. 
11 | - id: ruff-format 12 | types_or: [ python, pyi] -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for MkDocs projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the version of Python and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.11" 12 | 13 | mkdocs: 14 | configuration: mkdocs.yml 15 | 16 | # Optionally declare the Python requirements required to build your docs 17 | python: 18 | install: 19 | - requirements: docs/requirements-docs.txt -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Simple Guidelines 2 | 3 | For all feature related work, it would be great to ask yourself the following questions before submitting a PR: 4 | 5 | 1. Is your code correct? Provide proof of correctness and at least one Python-side test. It is ok to test against well-known packages. Don't forget to add to requirements-test.txt if more packages need to be downloaded for tests. If it is a highly technical topic, can you provide some reference docs? (A minimal test sketch is shown at the end of this file.) 6 | 2. Is your code as fast as or faster than SciPy, NumPy, or Scikit-learn? **It is ok to be slower**. Do you have any idea why? Is it because of data structures? Some data input/copy issues that make performance hard to achieve? Again, it is ok to be slower when the functionality provides more convenience to users. It will be great to provide some notes/comments. We have the power of Rust, so theoretically we can optimize many things away. 7 | 3. Are you using a lot of unwraps in your Rust code? Are these unwraps justified? Same for unsafe code. 8 | 4. If an additional dependency is needed, how much of it is really used? Will it bloat the package? What other features can we write with the additional dependency? I would discourage the addition of a dependency if we are using only 1 or 2 functions out of that package. 9 | 5. **Everything can be discussed**. 10 | 6. New features are generally welcome if they are commonly used in some field. 11 | 12 | 13 | ## Remember to run these before committing: 14 | 1. pre-commit. We use ruff. 15 | 2. cargo fmt 16 | 17 | ## How to get started? 18 | 19 | Take a look at the Makefile. Set up your environment first. Then take a look at the tutorial [here](https://github.com/MarcoGorelli/polars-plugins-tutorial), and grasp the basics of maturin [here](https://www.maturin.rs/tutorial). 20 | 21 | Then find an issue/feature that you want to improve/implement! 22 | 23 | ## A word on Doc, Typo related PRs 24 | 25 | For docs and typo fix PRs, we welcome changes that: 26 | 27 | 1. Fix actual typos, but please do not open a separate PR for each typo. 28 | 29 | 2. Add explanations, docstrings for previously undocumented features/code. 30 | 31 | 3. Improve clarification for terms, explanations, docs, or docstrings. 32 | 33 | 4. Fix actual broken UI/style components in doc/readme. 34 | 35 | Simple stylistic changes/reformatting that don't result in any significant change in looks, or don't fix any previously noted problems, will not be approved. 36 | 37 | Please understand, and thank you for your time. 
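To make point 1 above concrete, here is a minimal sketch of a Python-side test that checks a PDS query against a well-known package (SciPy), in the same spirit as the tests under benchmarks/ and tests/. The tolerance and the assumption that `pds.query_mad` and `scipy.stats.median_abs_deviation` use the same unscaled definition of MAD are illustrative, not guaranteed here.

```python
import numpy as np
import polars_ds as pds
import scipy.stats


def test_mad_against_scipy():
    # Small random frame generated with PDS itself, as the benchmark tests do.
    df = pds.frame(size=2_000).select(pds.random(0.0, 1.0).alias("x"))
    pds_mad = df.select(pds.query_mad("x")).item()
    scipy_mad = scipy.stats.median_abs_deviation(df["x"].to_numpy())
    # Assumes both sides compute the unscaled median absolute deviation.
    assert np.isclose(pds_mad, scipy_mad, rtol=1e-6)
```

As the guidelines above note, any package needed only for tests like this belongs in tests/requirements-test.txt.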
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "polars_ds" 3 | version = "0.9.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "_polars_ds" 8 | crate-type = ["cdylib"] 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | # PyO3 14 | numpy = "0.23" # matrix interop 15 | pyo3 = {version = "0.23", features = ["abi3-py39", "extension-module"]} 16 | pyo3-polars = {version = "0.20", features = ["derive", "dtype-array", "dtype-struct"]} 17 | # Polars 18 | polars = {version = "0.46", features = ["performant", "lazy", 19 | "diff", "array_count", "abs", "cross_join", "rank", "log", 20 | "cum_agg", "round_series", "nightly","dtype-array", "dtype-struct", "dtype-i128"], default-features = false} 21 | polars-arrow = "0.46.0" 22 | # Numerical / Linear Algebra 23 | rand = "0.8.5" 24 | rand_distr = "0.4.3" 25 | realfft = "3.3.0" 26 | num = "0.4.1" 27 | ordered-float = "4.2.0" 28 | approx = "*" 29 | faer = {version = "0.22", default-features = false, features = ["nightly", "rayon", "sparse-linalg"]} 30 | faer-ext = { version = "0.6.0", features = ["numpy", "ndarray"] } 31 | faer-traits = {version = "0.22"} 32 | ndarray = {version="0.16"} 33 | cfavml = {version = "0.3.0", features=["nightly"]} # easy simd, wait till Pulp gets better then we can replace 34 | # Data Structures, Iteration Helpers 35 | itertools = "0.12.0" 36 | ahash = ">=0.8.5" 37 | hashbrown = {version = "0.15", features=["nightly"]} 38 | # Serialization 39 | serde = {version = "*", features=["derive"]} 40 | # String related 41 | rapidfuzz = "0.5.0" 42 | inflections = "1.1.1" 43 | # Deprecated 44 | unicode-normalization = "0.1.23" 45 | 46 | [profile.release] 47 | codegen-units = 1 48 | strip = "symbols" 49 | lto = "fat" 50 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 T. Qin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/bin/bash 2 | 3 | VENV=.venv 4 | 5 | ifeq ($(OS),Windows_NT) 6 | VENV_BIN=$(VENV)/Scripts 7 | else 8 | VENV_BIN=$(VENV)/bin 9 | endif 10 | 11 | .venv: 12 | python3 -m venv $(VENV) 13 | $(MAKE) requirements 14 | 15 | requirements: .venv 16 | @unset CONDA_PREFIX \ 17 | && $(VENV_BIN)/python -m pip install --upgrade uv \ 18 | && $(VENV_BIN)/uv pip install --upgrade -r requirements.txt \ 19 | && $(VENV_BIN)/uv pip install --upgrade -r tests/requirements-test.txt \ 20 | && $(VENV_BIN)/uv pip install --upgrade -r docs/requirements-docs.txt \ 21 | 22 | dev-release: .venv 23 | unset CONDA_PREFIX && \ 24 | source .venv/bin/activate && maturin develop --release -m Cargo.toml 25 | 26 | pre-commit: .venv 27 | cargo fmt 28 | pre-commit run --all-files 29 | 30 | clean: 31 | rm -f examples/*.json 32 | rm -f examples/*.pickle 33 | rm -rf .ruff_cache/ 34 | rm -rf .pytest_cache/ 35 | -------------------------------------------------------------------------------- /SKLEARN_COMPATIBILITY.md: -------------------------------------------------------------------------------- 1 | # Not Really Standard 2 | 3 | Scikit-learn is a great package, but that doesn't mean it is perfect or that it cannot be improved. I use Scikit-learn daily and I think I have my fair share of critiques of it. There are many other things to criticize about Scikit-learn: it is slow, true parallelism is hard to get working, it is not Dataframe-centric, the functions/Transformers it provides are not expressive enough, and its objects are not easy to serialize or turn into JSON. All these issues are addressed or partially addressed by PDS (polars_ds). However, my biggest complaint is about the Pipeline API in Scikit-learn. 4 | 5 | ![auto_complete](examples/auto_complete.png) 6 | 7 | PDS offers an extensible class for the most common ML pipeline transforms. You can also extend and customize it, which you can find in examples/pipeline.ipynb. You can even completely describe a pipeline with transforms using dictionaries and lists if you want. You don't have to remember the names of these transformers because your linter will auto-complete them for you. The pipeline construction is designed to be dataframe-centric and as expressive as possible: any data scientist who spends a few seconds looking at it will know what it will do. Everything in PDS just comes with the package and requires no dependencies, not even SciPy or NumPy (except when interacting with NumPy is necessary). 8 | 9 | I also don't like the idea of putting a model in a Pipeline, which Scikit-learn allows but which is not a good idea in my opinion. For instance, there is no `.predict`, `.predict_proba` in polars_ds pipelines. Here are three reasons for my design decision: 10 | 11 | 1. Doing so complicates the expected output of the pipeline and makes the object too bloated. In practice, people track the **raw features**, **the transformed features**, and also the **model scores**. How do we get all 3 data points without evaluating the pipeline multiple times? In sklearn, the solution is that you will have to manually exclude the model, let the rest of the transformations run, then evaluate the model. This raises the question: why should models be included in the pipeline in the first place? 12 | 13 | 2. Feature transformations and model training should be kept in isolated environments.
The PDS pipeline will only take in raw features and produce transformed features. Models are separate entities and should be kept and managed separately. Say you are doing hyper-parameter tuning and you only want to tune the hyper-parameters of the model. But your model is inside a sklearn pipeline. What will happen every time you try a new set of hyper-parameters? The whole pipeline will run, despite the fact that the data transformation part is identical each time. To avoid this, you will have to cache the data with options provided by sklearn. Why not do it naturally (using a data-transformation-only pipeline + a model), instead of relying on additional configuration? For things like caching, it is often better to let the user manage it themselves than to come up with an option that might solve the issue for some people. Adding an option like that only adds complexity to the software and burden on the maintainers. 14 | 15 | 3. A cleaner and smaller API. The scikit-learn pipeline API provides calls like `.predict_proba()`, `.predict_log_proba`, `.score`, etc. However, depending on the contents of the pipeline, they may not be available or may not be what one expects. Too many function calls only work situationally. This can be naturally solved if the user manages the data transformation and the model separately. By following this separation of concerns, the user always knows exactly what is being used as the model and immediately knows what is and is not available on the model object. 16 | 17 | As it stands now, polars_ds pipelines can be turned into a Sklearn pipeline, but not the other way around, and I do not plan to support the other way around. Since polars_ds is Polars native, it can run much faster and has a (surprisingly) more flexible API in many cases. Passing selectors such as `cs.numerics()` or `pl.col("*").exclude(...)` to a sklearn-style pipeline step is simply impossible because of their lack of native Polars support. 18 | 19 | I think there is a lack of spirit in exploring new APIs that actually facilitate developer experience, and I do not want to confine myself to the pre-defined world of "works with Scikit-learn". That said, common names are respected and work as intended. The accepted and most common terminologies are used to name the functions and methods. E.g. `.fit(...)` and `.transform(...)` are still used in Polars DS in the Scikit-learn sense. 20 | 21 | # Wrapping Polars DS pipelines inside a Sklearn Transformer 22 | 23 | It is possible, though not recommended. If there is a transform that you really want, please open a feature request. Thank you! You can find the dataset used in the wrapper example below in the ../examples/ folder on GitHub. 
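Before the wrapper, here is a minimal sketch of the recommended pattern argued for above: a data-transformation-only pipeline plus a separately managed model. It reuses the `Blueprint` calls from the wrapper example below; the column names, the choice of `Ridge` as the model, and the assumption that the transformed frame still contains a numeric `approved` target are illustrative only.

```python
import polars as pl
import polars.selectors as cs
import polars_ds.pipeline as pds_pipe
from sklearn.linear_model import Ridge

df = pl.read_parquet("../examples/dependency.parquet")

# Fit the data-transformation-only pipeline once.
bp = (
    pds_pipe.Blueprint(df, name="example", target="approved", lowercase=True)
    .filter("city_category is not null")
    .select(cs.numeric())  # keep it simple for this sketch: numeric columns only
    .impute(["existing_emi"], method="median")
)
pipe = bp.materialize()
transformed = pipe.transform(df)  # transformed features, computed once

# The model lives outside the pipeline, so hyper-parameter tuning reuses the
# already-transformed data instead of re-running the transforms every time.
X = transformed.drop("approved").to_numpy()
y = transformed["approved"].to_numpy()
models = {alpha: Ridge(alpha=alpha).fit(X, y) for alpha in (0.01, 0.1, 1.0)}
```

Note that the wrapper example below also references `pl` and `cs`, so it assumes `import polars as pl` and `import polars.selectors as cs` as well.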
24 | 25 | ```python 26 | 27 | import polars_ds.pipeline as pds_pipe 28 | from sklearn.pipeline import Pipeline 29 | from sklearn.base import BaseEstimator, TransformerMixin 30 | 31 | class CustomPDSTransformer(BaseEstimator, TransformerMixin): 32 | 33 | def __init__(self): 34 | self.pipe = None 35 | 36 | def fit(self, df, y=None): 37 | # specify all the rules for the transform here 38 | bp = ( 39 | pds_pipe.Blueprint(df, name = "example", target = "approved", lowercase=True) 40 | .filter( 41 | "city_category is not null" # or equivalently, you can do: pl.col("city_category").is_not_null() 42 | ) 43 | .select(cs.numeric() | cs.by_name(["gender", "employer_category1", "city_category", "test_col"])) 44 | .linear_impute(features = ["var1", "existing_emi"], target = "loan_period") 45 | .impute(["existing_emi"], method = "median") 46 | ) 47 | self.pipe = bp.materialize() 48 | return self 49 | 50 | def transform(self, df, y=None): 51 | return self.pipe.transform(df) 52 | 53 | # --------------------------------------------------------------- 54 | 55 | df = pl.read_parquet("../examples/dependency.parquet") 56 | df.head() 57 | 58 | pipe = Pipeline( 59 | steps=[ 60 | ("CustomPDSTransformer", CustomPDSTransformer()) 61 | ] 62 | ) 63 | df_transformed = pipe.fit_transform(df) 64 | df_transformed 65 | ``` -------------------------------------------------------------------------------- /benchmarks/benchmark.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import timeit 3 | import unicodedata 4 | from pathlib import Path 5 | from typing import Callable 6 | 7 | import polars as pl 8 | import polars_ds as pds 9 | 10 | BASE_PATH = Path(__file__).resolve().parents[0] 11 | 12 | TIMING_RUNS = 10 13 | 14 | 15 | class Bench: 16 | def __init__( 17 | self, 18 | df: pl.DataFrame, 19 | sizes: list[int] = [100, 1_000, 10_000, 100_000, 1_000_000], 20 | timing_runs: int = 10, 21 | ): 22 | self.benchmark_data = {"Function": [], "Size": [], "Time": []} 23 | self.df = df 24 | self.sizes = sizes 25 | self.timing_runs = timing_runs 26 | 27 | def run(self, funcs: list[Callable]): 28 | for n_rows in self.sizes: 29 | df = self.df.sample(n_rows, seed=208) 30 | 31 | for f in funcs: 32 | func_name = f.func.__name__ if isinstance(f, functools.partial) else f.__name__ 33 | time = timeit.timeit(lambda: f(df), number=self.timing_runs) 34 | 35 | self.benchmark_data["Function"].append(func_name) 36 | self.benchmark_data["Size"].append(n_rows) 37 | self.benchmark_data["Time"].append(time) 38 | 39 | return self 40 | 41 | def save(self, file: Path): 42 | pl.DataFrame(self.benchmark_data).write_parquet(file) 43 | 44 | 45 | def python_remove_non_ascii(df: pl.DataFrame): 46 | df.select( 47 | pl.col("RANDOM_STRING").map_elements( 48 | lambda s: s.encode("ascii", errors="ignore").decode(), 49 | return_dtype=pl.String, 50 | ) 51 | ) 52 | 53 | 54 | def regex_remove_non_ascii(df: pl.DataFrame): 55 | df.select(pl.col("RANDOM_STRING").str.replace_all(r"[^\p{Ascii}]", "")) 56 | 57 | 58 | def pds_remove_non_ascii(df: pl.DataFrame): 59 | df.select(pds.replace_non_ascii("RANDOM_STRING")) 60 | 61 | 62 | def python_remove_diacritics(df: pl.DataFrame): 63 | df.select( 64 | pl.col("RANDOM_STRING").map_elements( 65 | lambda s: unicodedata.normalize("NFD", s).encode("ASCII", "ignore"), 66 | return_dtype=pl.String, 67 | ) 68 | ) 69 | 70 | 71 | def pds_remove_diacritics(df: pl.DataFrame): 72 | df.select(pds.remove_diacritics("RANDOM_STRING")) 73 | 74 | 75 | def python_normalize_string(df: 
pl.DataFrame): 76 | df.select( 77 | pl.col("RANDOM_STRING").map_elements( 78 | lambda s: unicodedata.normalize("NFD", s), return_dtype=pl.String 79 | ) 80 | ) 81 | 82 | 83 | def pds_normalize_string(df: pl.DataFrame): 84 | df.select(pds.normalize_string("RANDOM_STRING", "NFD")) 85 | 86 | 87 | def python_map_words(df: pl.DataFrame, mapping: dict[str, str]): 88 | df.select( 89 | pl.col("RANDOM_ADDRESS").map_elements( 90 | lambda s: " ".join(mapping.get(word, word) for word in s.split()), 91 | return_dtype=pl.String, 92 | ) 93 | ) 94 | 95 | 96 | def regex_map_words(df: pl.DataFrame, mapping: dict[str, str]): 97 | expr = pl.col("RANDOM_ADDRESS") 98 | for k, v in mapping.items(): 99 | expr = expr.str.replace_all(k, v) 100 | df.select(expr) 101 | 102 | 103 | def pds_map_words(df: pl.DataFrame, mapping: dict[str, str]): 104 | df.select(pds.map_words("RANDOM_ADDRESS", mapping)) 105 | 106 | 107 | def python_normalize_whitespace(df: pl.DataFrame): 108 | df.select( 109 | pl.col("RANDOM_STRING").map_elements(lambda s: " ".join(s.split()), return_dtype=pl.String) 110 | ) 111 | 112 | 113 | def python_normalize_whitespace_only_spaces(df: pl.DataFrame): 114 | df.select( 115 | pl.col("RANDOM_STRING").map_elements( 116 | lambda s: " ".join(s.split(" ")), return_dtype=pl.String 117 | ) 118 | ) 119 | 120 | 121 | def expr_normalize_whitespace_only_spaces(df: pl.DataFrame): 122 | df.select(pl.col("RANDOM_STRING").str.split(" ").list.join(" ")) 123 | 124 | 125 | def pds_normalize_whitespace(df: pl.DataFrame): 126 | df.select(pds.normalize_whitespace("RANDOM_STRING")) 127 | 128 | 129 | def pds_normalize_whitespace_only_spaces(df: pl.DataFrame): 130 | df.select(pds.normalize_whitespace("RANDOM_STRING", only_spaces=True)) 131 | 132 | 133 | def main(): 134 | benchmark_df = pl.read_parquet(BASE_PATH / "benchmark_df.parquet") 135 | 136 | map_words_mapping = { 137 | "Apt.": "Apartment", 138 | "NY": "New York", 139 | "CT": "Connecticut", 140 | "Street": "ST", 141 | "Bypass": "BYP", 142 | "GA": "Georgia", 143 | "Parkways": "Pkwy", 144 | "PA": "Pennsylvania", 145 | } 146 | 147 | Bench(benchmark_df).run( 148 | [ 149 | python_remove_non_ascii, 150 | regex_remove_non_ascii, 151 | pds_remove_non_ascii, 152 | python_remove_diacritics, 153 | pds_remove_diacritics, 154 | python_normalize_string, 155 | pds_normalize_string, 156 | functools.partial(python_map_words, mapping=map_words_mapping), 157 | functools.partial( 158 | regex_map_words, 159 | mapping={f"\b{k}\b": v for k, v in map_words_mapping.items()}, 160 | ), 161 | functools.partial(pds_map_words, mapping=map_words_mapping), 162 | python_normalize_whitespace, 163 | python_normalize_whitespace_only_spaces, 164 | expr_normalize_whitespace_only_spaces, 165 | pds_normalize_whitespace, 166 | pds_normalize_whitespace_only_spaces, 167 | ] 168 | ).save(BASE_PATH / "benchmark_data.parquet") 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /benchmarks/generate_benchmark_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from pathlib import Path 4 | 5 | import polars as pl 6 | from faker import Faker 7 | 8 | random.seed(208) 9 | 10 | BASE_PATH = Path(__file__).resolve().parents[0] 11 | 12 | INT32_RANGE = (-(2**31), 2**31 - 1) 13 | 14 | 15 | UNICODE_ALPHABET = [ 16 | chr(code_point) 17 | for current_range in [ 18 | (0x0021, 0x0021), 19 | (0x0023, 0x0026), 20 | (0x0028, 0x007E), 21 | (0x00A1, 0x00AC), 22 | 
(0x00AE, 0x00FF), 23 | (0x0100, 0x017F), 24 | (0x0180, 0x024F), 25 | (0x2C60, 0x2C7F), 26 | (0x16A0, 0x16F0), 27 | (0x0370, 0x0377), 28 | (0x037A, 0x037E), 29 | (0x0384, 0x038A), 30 | (0x038C, 0x038C), 31 | ] 32 | for code_point in range(current_range[0], current_range[1] + 1) 33 | ] 34 | 35 | 36 | def random_unicode(length: int) -> str: 37 | # https://stackoverflow.com/questions/1477294/generate-random-utf-8-string-in-python 38 | return "".join(random.choices(UNICODE_ALPHABET, k=length)) 39 | 40 | 41 | def random_ascii(length: int) -> str: 42 | return "".join(random.choices(string.printable, k=length)) 43 | 44 | 45 | class DataGenerator: 46 | def __init__(self, height: int): 47 | self.height = height 48 | 49 | def random_unicode_column(self, min_length: int, max_length: int) -> list[str]: 50 | column = [] 51 | for _ in range(self.height): 52 | length = random.randint(min_length, max_length) 53 | column.append(random_unicode(length)) 54 | 55 | return column 56 | 57 | def random_ascii_column(self, min_length: int, max_length: int) -> list[str]: 58 | column = [] 59 | for _ in range(self.height): 60 | length = random.randint(min_length, max_length) 61 | column.append(random_ascii(length)) 62 | 63 | return column 64 | 65 | def random_address_column(self) -> list[str]: 66 | column = [] 67 | fake = Faker() 68 | for i in range(self.height): 69 | Faker.seed(i) 70 | column.append(fake.address()) 71 | 72 | return column 73 | 74 | def random_integer_column(self, min_value: int, max_value: int) -> list[int]: 75 | return [random.randint(min_value, max_value) for _ in range(self.height)] 76 | 77 | def random_float_column(self, min_value: float, max_value: float) -> list[float]: 78 | return [random.uniform(min_value, max_value) for _ in range(self.height)] 79 | 80 | def generate(self) -> pl.DataFrame: 81 | min_length = 1 82 | max_length = 50 83 | 84 | return pl.DataFrame( 85 | { 86 | "RANDOM_STRING": self.random_unicode_column(min_length, max_length), 87 | "RANDOM_ASCII": self.random_ascii_column(min_length, max_length), 88 | "RANDOM_ADDRESS": self.random_address_column(), 89 | } 90 | ) 91 | 92 | 93 | def main(): 94 | DataGenerator(1_000_000).generate().write_parquet(BASE_PATH / "benchmark_df.parquet") 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /benchmarks/requirements-benchmarks.txt: -------------------------------------------------------------------------------- 1 | faker 2 | polars 3 | plotnine -------------------------------------------------------------------------------- /benchmarks/test_linear_regression.py: -------------------------------------------------------------------------------- 1 | import polars as pl 2 | import polars_ds as pds 3 | import numpy as np 4 | import scipy 5 | import pytest 6 | from polars_ds.linear_models import LR 7 | from sklearn.linear_model import Lasso, LinearRegression, Ridge 8 | 9 | SEED = 208 10 | 11 | SIZE = 100_000 12 | DF = ( 13 | pds.frame(size=SIZE) 14 | .select( 15 | pds.random(0.0, 1.0).alias("x1"), 16 | pds.random(0.0, 1.0).alias("x2"), 17 | pds.random(0.0, 1.0).alias("x3"), 18 | pds.random(0.0, 1.0).alias("x4"), 19 | pds.random(0.0, 1.0).alias("x5"), 20 | pds.random_int(0, 4).alias("code"), 21 | pl.Series(name="id", values=range(SIZE)), 22 | ) 23 | .with_columns( 24 | y=pl.col("x1") * 0.5 25 | + pl.col("x2") * 0.25 26 | - pl.col("x3") * 0.15 27 | + pl.col("x4") * 0.2 28 | - pl.col("x5") * 0.13 29 | + pds.random() * 0.0001, 30 | ) 31 | ) 32 | 33 | 34 | # Prepare 
data for Scikit-learn. We assume the Scikit-learn + Pandas combination. 35 | # One can simply replace to_pandas() by to_numpy() to test the Scikit-learn + NumPy combination 36 | PD_DF = DF.to_pandas() 37 | X_VARS = ["x1", "x2", "x3", "x4", "x5"] 38 | Y = ["y"] 39 | 40 | SIZES = [1_000, 10_000, 50_000, 100_000] 41 | 42 | 43 | @pytest.mark.parametrize("n", SIZES) 44 | @pytest.mark.benchmark(group="linear_on_matrix") 45 | def test_pds_linear_regression_on_matrix(benchmark, n): 46 | df = DF.sample(n=n, seed=SEED) 47 | X = df.select(*X_VARS).to_numpy() 48 | y = df.select(*Y).to_numpy() 49 | 50 | @benchmark 51 | def func(): 52 | model = LR() 53 | model.fit(X, y) 54 | 55 | 56 | @pytest.mark.parametrize("n", SIZES) 57 | @pytest.mark.benchmark(group="linear_on_matrix") 58 | def test_sklearn_linear_regression_on_matrix(benchmark, n): 59 | df = DF.sample(n=n, seed=SEED) 60 | X = df.select(*X_VARS).to_numpy() 61 | y = df.select(*Y).to_numpy() 62 | 63 | @benchmark 64 | def func(): 65 | reg = LinearRegression(fit_intercept=False, n_jobs=-1) 66 | reg.fit(X, y) 67 | 68 | 69 | @pytest.mark.parametrize("n", SIZES) 70 | @pytest.mark.benchmark(group="linear_on_matrix") 71 | def test_numpy_linear_regression_on_matrix(benchmark, n): 72 | df = DF.sample(n=n, seed=SEED) 73 | X = df.select(*X_VARS).to_numpy() 74 | y = df.select(*Y).to_numpy() 75 | 76 | @benchmark 77 | def func(): 78 | _ = np.linalg.lstsq(X, y) 79 | 80 | 81 | @pytest.mark.parametrize("n", SIZES) 82 | @pytest.mark.benchmark(group="linear_on_matrix") 83 | def test_scipy_linear_regression_on_matrix(benchmark, n): 84 | df = DF.sample(n=n, seed=SEED) 85 | X = df.select(*X_VARS).to_numpy() 86 | y = df.select(*Y).to_numpy() 87 | 88 | @benchmark 89 | def func(): 90 | _ = scipy.linalg.lstsq(X, y, check_finite=False) 91 | 92 | 93 | @pytest.mark.parametrize("n", SIZES) 94 | @pytest.mark.benchmark(group="linear_on_df") 95 | def test_pds_linear_regression_on_df(benchmark, n): 96 | df = DF.sample(n=n, seed=SEED) 97 | 98 | @benchmark 99 | def func(): 100 | df.select( 101 | pds.lin_reg( 102 | *X_VARS, 103 | target=Y[0], 104 | ) 105 | ) 106 | 107 | 108 | @pytest.mark.parametrize("n", SIZES) 109 | @pytest.mark.benchmark(group="linear_on_df") 110 | def test_sklearn_linear_regression_on_df(benchmark, n): 111 | df = PD_DF.sample(n=n, random_state=SEED) 112 | 113 | @benchmark 114 | def func(): 115 | reg = LinearRegression(fit_intercept=False, n_jobs=-1) 116 | reg.fit(df[X_VARS], df[Y]) 117 | 118 | 119 | @pytest.mark.parametrize("n", SIZES) 120 | @pytest.mark.benchmark(group="lasso_on_df") 121 | def test_pds_lasso_on_df(benchmark, n): 122 | df = DF.sample(n=n, seed=SEED) 123 | 124 | @benchmark 125 | def func(): 126 | df.select(pds.lin_reg(*X_VARS, target=Y[0], l1_reg=0.1)) 127 | 128 | 129 | @pytest.mark.parametrize("n", SIZES) 130 | @pytest.mark.benchmark(group="lasso_on_df") 131 | def test_sklearn_lasso_on_df(benchmark, n): 132 | df = PD_DF.sample(n=n, random_state=SEED) 133 | 134 | @benchmark 135 | def func(): 136 | reg = Lasso(alpha=0.1, fit_intercept=False) 137 | reg.fit(df[X_VARS], df[Y]) 138 | 139 | 140 | @pytest.mark.parametrize("n", SIZES) 141 | @pytest.mark.benchmark(group="ridge_svd_on_df") 142 | def test_ridge_svd_on_df(benchmark, n): 143 | df = DF.sample(n=n, seed=SEED) 144 | 145 | @benchmark 146 | def func(): 147 | df.select(pds.lin_reg(*X_VARS, target=Y[0], l2_reg=0.1, solver="svd")) 148 | 149 | 150 | @pytest.mark.parametrize("n", SIZES) 151 | @pytest.mark.benchmark(group="ridge_svd_on_df") 152 | def test_sklearn_ridge_svd_on_df(benchmark, n): 153 | df = 
PD_DF.sample(n=n, random_state=SEED) 154 | 155 | @benchmark 156 | def func(): 157 | reg = Ridge(alpha=0.1, fit_intercept=False, solver="svd") 158 | reg.fit(df[X_VARS], df[Y]) 159 | 160 | 161 | @pytest.mark.parametrize("n", SIZES) 162 | @pytest.mark.benchmark(group="ridge_cholesky") 163 | def test_ridge_cholesky_on_df(benchmark, n): 164 | df = DF.sample(n=n, seed=SEED) 165 | 166 | @benchmark 167 | def func(): 168 | df.select(pds.lin_reg(*X_VARS, target=Y[0], l2_reg=0.1, solver="cholesky")) 169 | 170 | 171 | @pytest.mark.parametrize("n", SIZES) 172 | @pytest.mark.benchmark(group="ridge_cholesky") 173 | def test_sklearn_ridge_cholesky_on_df(benchmark, n): 174 | df = PD_DF.sample(n=n, random_state=SEED) 175 | 176 | @benchmark 177 | def func(): 178 | reg = Ridge(alpha=0.1, fit_intercept=False, solver="cholesky") 179 | reg.fit(df[X_VARS], df[Y]) 180 | -------------------------------------------------------------------------------- /benchmarks/test_metrics.py: -------------------------------------------------------------------------------- 1 | import polars as pl 2 | import polars_ds as pds 3 | import pytest 4 | import scipy.stats 5 | import sklearn.metrics 6 | 7 | SEED = 208 8 | 9 | SIZE = 1_000_000 10 | DF = ( 11 | pds.random_data(size=SIZE, n_cols=2) 12 | .rename({"feature_1": "y_true_score", "feature_2": "y_score"}) 13 | .with_columns( 14 | pl.col("y_true_score").ge(0.5).alias("y_true"), 15 | ) 16 | .drop("row_num") 17 | ) 18 | 19 | MUTLICLASS_DF = ( 20 | pds.random_data(size=SIZE, n_cols=2) 21 | .with_columns( 22 | pds.random_int(0, 2).alias("y_true"), 23 | pl.concat_list("feature_1", "feature_2").alias("y_score"), 24 | ) 25 | .drop("feature_1", "feature_2") 26 | ) 27 | 28 | SIZES = [1_000, 10_000, 100_000, 1_000_000] 29 | 30 | 31 | @pytest.mark.parametrize("n", SIZES) 32 | @pytest.mark.benchmark(group="mad") 33 | def test_mad(benchmark, n): 34 | df = DF.sample(n=n, seed=SEED) 35 | 36 | @benchmark 37 | def func(): 38 | df.select(pds.query_mad("y_score")) 39 | 40 | 41 | @pytest.mark.parametrize("n", SIZES) 42 | @pytest.mark.benchmark(group="mad") 43 | def test_scipy_mad(benchmark, n): 44 | df = DF.sample(n=n, seed=SEED) 45 | 46 | @benchmark 47 | def func(): 48 | scipy.stats.median_abs_deviation(df["y_score"]) 49 | 50 | 51 | @pytest.mark.parametrize("n", SIZES) 52 | @pytest.mark.benchmark(group="r2") 53 | def test_r2(benchmark, n): 54 | df = DF.sample(n=n, seed=SEED) 55 | 56 | @benchmark 57 | def func(): 58 | df.select(pds.query_r2("y_true_score", "y_score")) 59 | 60 | 61 | @pytest.mark.parametrize("n", SIZES) 62 | @pytest.mark.benchmark(group="r2") 63 | def test_sklearn_r2(benchmark, n): 64 | df = DF.sample(n=n, seed=SEED) 65 | 66 | @benchmark 67 | def func(): 68 | sklearn.metrics.r2_score(df["y_true_score"], df["y_score"]) 69 | 70 | 71 | @pytest.mark.parametrize("n", SIZES) 72 | def test_log_cosh(benchmark, n): 73 | df = DF.sample(n=n, seed=SEED) 74 | 75 | @benchmark 76 | def func(): 77 | df.select(pds.query_log_cosh("y_true_score", "y_score")) 78 | 79 | 80 | @pytest.mark.parametrize("n", SIZES) 81 | def test_hubor_loss(benchmark, n): 82 | df = DF.sample(n=n, seed=SEED) 83 | 84 | @benchmark 85 | def func(): 86 | df.select(pds.query_hubor_loss("y_true_score", "y_score", delta=0.2)) 87 | 88 | 89 | @pytest.mark.parametrize("n", SIZES) 90 | def test_l2(benchmark, n): 91 | df = DF.sample(n=n, seed=SEED) 92 | 93 | @benchmark 94 | def func(): 95 | df.select(pds.query_l2("y_true_score", "y_score")) 96 | 97 | 98 | @pytest.mark.parametrize("n", SIZES) 99 | def test_l1(benchmark, n): 100 | df = 
DF.sample(n=n, seed=SEED) 101 | 102 | @benchmark 103 | def func(): 104 | df.select(pds.query_l1("y_true_score", "y_score")) 105 | 106 | 107 | @pytest.mark.parametrize("n", SIZES) 108 | def test_l_inf(benchmark, n): 109 | df = DF.sample(n=n, seed=SEED) 110 | 111 | @benchmark 112 | def func(): 113 | df.select(pds.query_l_inf("y_true_score", "y_score")) 114 | 115 | 116 | @pytest.mark.parametrize("n", SIZES) 117 | def test_log_loss(benchmark, n): 118 | df = DF.sample(n=n, seed=SEED) 119 | 120 | @benchmark 121 | def func(): 122 | df.select(pds.query_log_loss("y_true_score", "y_score")) 123 | 124 | 125 | @pytest.mark.parametrize("n", SIZES) 126 | def test_mape(benchmark, n): 127 | df = DF.sample(n=n, seed=SEED) 128 | 129 | @benchmark 130 | def func(): 131 | df.select(pds.query_mape("y_true_score", "y_score")) 132 | 133 | 134 | @pytest.mark.parametrize("n", SIZES) 135 | def test_smape(benchmark, n): 136 | df = DF.sample(n=n, seed=SEED) 137 | 138 | @benchmark 139 | def func(): 140 | df.select(pds.query_smape("y_true_score", "y_score")) 141 | 142 | 143 | @pytest.mark.parametrize("n", SIZES) 144 | def test_msle(benchmark, n): 145 | df = DF.sample(n=n, seed=SEED) 146 | 147 | @benchmark 148 | def func(): 149 | df.select(pds.query_msle("y_true_score", "y_score")) 150 | 151 | 152 | @pytest.mark.parametrize("n", SIZES) 153 | @pytest.mark.benchmark(group="roc_auc") 154 | def test_roc_auc(benchmark, n): 155 | df = DF.sample(n=n, seed=SEED) 156 | 157 | @benchmark 158 | def func(): 159 | df.select(pds.query_roc_auc("y_true", "y_score")) 160 | 161 | 162 | @pytest.mark.parametrize("n", SIZES) 163 | @pytest.mark.benchmark(group="roc_auc") 164 | def test_sklearn_roc_auc(benchmark, n): 165 | df = DF.sample(n=n, seed=SEED) 166 | 167 | @benchmark 168 | def func(): 169 | sklearn.metrics.roc_auc_score(df["y_true"], df["y_score"]) 170 | 171 | 172 | @pytest.mark.parametrize("n", SIZES) 173 | def test_binary_metrics(benchmark, n): 174 | df = DF.sample(n=n, seed=SEED) 175 | 176 | @benchmark 177 | def func(): 178 | df.select(pds.query_binary_metrics("y_true", "y_score")) 179 | 180 | 181 | @pytest.mark.parametrize("n", SIZES) 182 | def test_confusion_matrix(benchmark, n): 183 | df = DF.sample(n=n, seed=SEED) 184 | 185 | @benchmark 186 | def func(): 187 | df.select(pds.query_confusion_matrix("y_true", "y_score")) 188 | 189 | 190 | @pytest.mark.parametrize("n", SIZES) 191 | def test_multi_roc_auc(benchmark, n): 192 | df = MUTLICLASS_DF.sample(n=n, seed=SEED) 193 | 194 | @benchmark 195 | def func(): 196 | df.select(pds.query_multi_roc_auc("y_true", "y_score", n_classes=3)) 197 | -------------------------------------------------------------------------------- /docs/compat.md: -------------------------------------------------------------------------------- 1 | ## Compatibility with Arrays 2 | 3 | ::: polars_ds.compat -------------------------------------------------------------------------------- /docs/eda.md: -------------------------------------------------------------------------------- 1 | ## Explorative Data Analysis 2 | 3 | ::: polars_ds.eda -------------------------------------------------------------------------------- /docs/expr_iter.md: -------------------------------------------------------------------------------- 1 | ## Iteration Helper Expressions 2 | 3 | ::: polars_ds.exprs.expr_iter -------------------------------------------------------------------------------- /docs/expr_knn.md: -------------------------------------------------------------------------------- 1 | ## K Nearest Neighbor Related Queries 2 | 3 | 
::: polars_ds.exprs.expr_knn -------------------------------------------------------------------------------- /docs/expr_linear.md: -------------------------------------------------------------------------------- 1 | ## Linear Models Related Queries 2 | 3 | ::: polars_ds.exprs.expr_linear -------------------------------------------------------------------------------- /docs/linear_models.md: -------------------------------------------------------------------------------- 1 | ## Linear Models 2 | 3 | ::: polars_ds.linear_models -------------------------------------------------------------------------------- /docs/metrics.md: -------------------------------------------------------------------------------- 1 | ## Extension for ML Metrics/Losses 2 | 3 | ::: polars_ds.exprs.metrics -------------------------------------------------------------------------------- /docs/num.md: -------------------------------------------------------------------------------- 1 | ## Extension for General Numerical Features/Metrics/Quantities 2 | 3 | ::: polars_ds.exprs.num -------------------------------------------------------------------------------- /docs/pipeline.md: -------------------------------------------------------------------------------- 1 | ## Polars Native Machine Learning Pipeline 2 | 3 | ::: polars_ds.modeling.pipeline -------------------------------------------------------------------------------- /docs/polars_ds.md: -------------------------------------------------------------------------------- 1 | ## Additional Expressions 2 | 3 | ::: polars_ds 4 | options: 5 | filters: ["^__init__$"] -------------------------------------------------------------------------------- /docs/requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.6 2 | mkdocstrings[python]==0.24.0 3 | mkdocs-material==9.6 4 | mkdocs-section-index==0.3.9 5 | pytkdocs[numpy-style]==0.16.2 6 | griffe==0.48 -------------------------------------------------------------------------------- /docs/sample_and_split.md: -------------------------------------------------------------------------------- 1 | ## Sample and Split 2 | 3 | ::: polars_ds.sample_and_split -------------------------------------------------------------------------------- /docs/spatial.md: -------------------------------------------------------------------------------- 1 | ## Standalone Kd-Tree Model 2 | 3 | ::: polars_ds.spatial -------------------------------------------------------------------------------- /docs/stats.md: -------------------------------------------------------------------------------- 1 | ## Extension for Statistical Tests and Samples 2 | 3 | ::: polars_ds.exprs.stats -------------------------------------------------------------------------------- /docs/string.md: -------------------------------------------------------------------------------- 1 | ## Extension for String Manipulation and Metrics 2 | 3 | ::: polars_ds.exprs.string -------------------------------------------------------------------------------- /docs/ts_features.md: -------------------------------------------------------------------------------- 1 | ## Feature Engineering Queries and Time Series Features 2 | 3 | ::: polars_ds.exprs.ts_features -------------------------------------------------------------------------------- /examples/auto_complete.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abstractqqq/polars_ds_extension/31e44245c214ad1b464c3458de0960f61a076145/examples/auto_complete.png -------------------------------------------------------------------------------- /examples/dependency.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abstractqqq/polars_ds_extension/31e44245c214ad1b464c3458de0960f61a076145/examples/dependency.parquet -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Polars-ds API References 2 | site_url: https://polars-ds-extension.readthedocs.io/en/latest/ 3 | 4 | use_directory_urls: false 5 | 6 | nav: 7 | - Home: index.md 8 | - Explorative Data Analysis: eda.md 9 | - Pipeline: pipeline.md 10 | - ML Metrics/Loss: metrics.md 11 | - Linear Regression as Polars Expr: expr_linear.md 12 | - Numerical Functions: num.md 13 | - Statistics: stats.md 14 | - String Related: string.md 15 | - Itertools for Polars Lists: expr_iter.md 16 | - Sample and Split: sample_and_split.md 17 | - Linear Models: linear_models.md 18 | - Spatial Models: spatial.md 19 | - KNN as Polars Expr: expr_knn.md 20 | - Time Series Features: ts_features.md 21 | - Compatibility with Arrays: compat.md 22 | - Miscellaneous: polars_ds.md 23 | 24 | theme: 25 | name: material 26 | locale: en 27 | highlightjs: true 28 | 29 | plugins: 30 | - search 31 | - section-index 32 | - mkdocstrings: 33 | handlers: 34 | python: 35 | paths: [python] 36 | options: 37 | summary: true 38 | docstring_style: numpy 39 | show_submodules: true -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.7.4"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "polars_ds" 7 | requires-python = ">=3.9" 8 | version = "0.9.0" 9 | 10 | license = { file = "LICENSE.txt" } 11 | classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Programming Language :: Rust", 14 | "Programming Language :: Python :: Implementation :: CPython", 15 | "Programming Language :: Python :: Implementation :: PyPy", 16 | "License :: OSI Approved :: MIT License", 17 | ] 18 | authors = [{ name = "Tianren Qin", email = "tq9695@gmail.com" }] 19 | dependencies = [ 20 | "polars >= 1.4.0, != 1.25, != 1.26", 21 | 'typing-extensions; python_version <= "3.11"', 22 | ] 23 | 24 | keywords = ["polars-extension", "scientific-computing", "data-science"] 25 | 26 | [project.optional-dependencies] 27 | plot = ["great-tables>=0.9", "graphviz>=0.20", "altair >= 5.4.0", "vegafusion[embed]"] 28 | models = ["numpy>=1.16"] 29 | compat = ["numpy>=1.16"] 30 | all = ["great-tables>=0.9", "graphviz>=0.20", "numpy>=1.16", "altair >= 5.4.0", "vegafusion[embed]"] 31 | 32 | [tool.maturin] 33 | strip = true 34 | python-source = "python" 35 | features = ["pyo3/extension-module"] 36 | module-name = "polars_ds._polars_ds" 37 | 38 | [tool.ruff] 39 | line-length = 100 40 | fix = true 41 | src = ["python"] 42 | 43 | [tool.ruff.format] 44 | docstring-code-format = true 45 | 46 | [tool.pytest.ini_options] 47 | testpaths = ["tests"] 48 | -------------------------------------------------------------------------------- /python/polars_ds/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import 
importlib.metadata 3 | import polars as pl 4 | 5 | __version__ = importlib.metadata.version("polars_ds") 6 | 7 | # Internal dependencies 8 | from polars_ds.exprs import * # noqa F403 9 | 10 | 11 | def frame(size: int = 2_000, index_name: str = "row_num") -> pl.DataFrame: 12 | """ 13 | Generates a frame with only an index (row number) column. 14 | This is a convenience function to be chained with pds.random(...) when running simulations and tests. 15 | 16 | Parameters 17 | ---------- 18 | size 19 | The total number of rows in this dataframe 20 | index_name 21 | The name of the index column 22 | """ 23 | return pl.DataFrame({index_name: range(size)}) 24 | -------------------------------------------------------------------------------- /python/polars_ds/_utils.py: -------------------------------------------------------------------------------- 1 | """Not meant for outside use.""" 2 | 3 | from __future__ import annotations 4 | 5 | import polars as pl 6 | from typing import Sequence 7 | from pathlib import Path 8 | from polars.plugins import register_plugin_function 9 | 10 | # Only need this 11 | _PLUGIN_PATH = Path(__file__).parent 12 | 13 | # V1.18 Introduces a Int128 dtype 14 | # _IS_POLARS_V1_18 = pl.__version__.startswith("1.18.") 15 | 16 | 17 | def pl_plugin( 18 | *, 19 | symbol: str, 20 | args: Sequence[pl.Series | pl.Expr], 21 | kwargs: dict[str, str | int | float | bool] | None = None, 22 | is_elementwise: bool = False, 23 | returns_scalar: bool = False, 24 | changes_length: bool = False, 25 | cast_to_supertype: bool = False, 26 | pass_name_to_apply: bool = False, 27 | ) -> pl.Expr: 28 | return register_plugin_function( 29 | plugin_path=_PLUGIN_PATH, 30 | args=args, 31 | function_name=symbol, 32 | kwargs=kwargs, 33 | is_elementwise=is_elementwise, 34 | returns_scalar=returns_scalar, 35 | changes_length=changes_length, 36 | cast_to_supertype=cast_to_supertype, 37 | pass_name_to_apply=pass_name_to_apply, 38 | ) 39 | 40 | 41 | # Auxiliary functions for type conversions 42 | def str_to_expr(e: str | pl.Expr | int | float) -> pl.Expr: 43 | """ 44 | Turns a string or numeric literal into an expression 45 | 46 | Parameters 47 | ---------- 48 | e 49 | Either a str representing a column name, a numeric literal, or an expression 50 | """ 51 | if isinstance(e, str): 52 | return pl.col(e) 53 | elif isinstance(e, (int, float)): 54 | return pl.lit(e) 55 | elif isinstance(e, pl.Expr): 56 | return e 57 | else: 58 | raise ValueError("Input must either be a string or a Polars expression.") 59 | -------------------------------------------------------------------------------- /python/polars_ds/compat/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compatibility with other Dataframes. 3 | 4 | This module provides compatibility with other dataframe libraries that: 5 | 6 | 1. Have a notion of Series 7 | 2. The Series implements the array protocol, which means it can be translated to a NumPy array via 8 | the .__array__() method. 9 | 10 | Since most dataframe libraries can turn their Series into NumPy (or vice versa) with zero copy, 11 | this compatibility layer has very little overhead. The only constraint is that the dataframe 12 | must be eager, in the sense that data is already loaded in memory. The reason for this is that 13 | the notion of a Series doesn't really exist in the lazy world, and lazy columns cannot be turned 14 | into NumPy arrays. 15 | 16 | When using this compatibility, the output is always a Polars Series. 
This is because the output 17 | type could be Polars struct/list Series, which are Polars-specific types. It is up to the user 18 | what to do with the output. 19 | 20 | For example, in order to use PDS with a Pandas dataframe, say df_pd: pd.DataFrame, one needs to write 21 | 22 | >>> from polars_ds.compat import compat as pds2 23 | >>> # Output is a Polars Series. 24 | >>> pds2.query_roc_auc(df_pd["actual"], df_pd["predicted"]) 25 | >>> # For more advanced queries 26 | >>> pds2.lin_reg( 27 | >>> df_pd["x1"], df_pd["x2"], df_pd["x3"], 28 | >>> target = df_pd["y"], 29 | >>> return_pred = True 30 | >>> ) 31 | 32 | Question: if the output is still Polars, then the user must still use both Polars and Pandas. 33 | Why bother with compatibility? 34 | 35 | Here are some answers I consider to be true (or self-promotion :)) 36 | 37 | 1. PDS is a very lightweight package that can reduce dependencies in your project. 38 | 2. For projects with mixed dataframes, it is sometimes not a good idea to cast the 39 | entire Pandas (or other) dataframe to Polars. 40 | 3. Some PDS functions are faster than SciPy / Sklearn equivalents. 41 | 4. For ad-hoc analysis that involves, say, something like linear regression, PDS is easier to 42 | use than other packages. 43 | """ 44 | 45 | from ._compat import compat 46 | 47 | import warnings 48 | warnings.warn( 49 | "The compatibility layer is considered experimental.", 50 | stacklevel=2 51 | ) -------------------------------------------------------------------------------- /python/polars_ds/compat/_compat.py: -------------------------------------------------------------------------------- 1 | import polars as pl 2 | import numpy as np 3 | from typing import Any, Callable 4 | import polars_ds as pds 5 | 6 | # Everything in __init__.py of polars_ds that this shouldn't be able to call 7 | CANNOT_CALL = { 8 | "frame", 9 | "str_to_expr", 10 | "pl", 11 | "annotations", 12 | "__version__", 13 | "warn_len_compare" 14 | } 15 | 16 | __all__ = ["compat"] 17 | 18 | class _Compat(): 19 | 20 | @staticmethod 21 | def _try_into_series(x:Any, name:str) -> Any: 22 | """ 23 | Try to map the input to a Polars Series by going through a NumPy array. If 24 | this is not possible, return the original input.
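For instance, a NumPy array, or any object exposing __array__ (e.g. a pandas Series), is wrapped as pl.lit(pl.Series(...)); a plain Python object without __array__ is returned unchanged.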
25 | """ 26 | if isinstance(x, np.ndarray): 27 | return pl.lit(pl.Series(name=name, values=x)) 28 | elif isinstance(x, pl.Series): 29 | return pl.lit(x) 30 | elif hasattr(x, "__array__"): 31 | return pl.lit(pl.Series(name=name, values=x.__array__())) 32 | else: 33 | return x 34 | 35 | def __getattr__(self, name:str) -> pl.Series: 36 | if name in CANNOT_CALL: 37 | raise ValueError(f"`{name}` exists but doesn't work in compat mode.") 38 | 39 | func = getattr(pds, name) 40 | def compat_wrapper(*args, **kwargs) -> Callable: 41 | positionals = list(args) 42 | if len(positionals) <= 0: 43 | raise ValueError("There must be at least 1 positional argument!") 44 | 45 | new_args = ( 46 | _Compat._try_into_series(x, name = str(i)) 47 | for i, x in enumerate(positionals) 48 | ) 49 | new_kwargs = { 50 | n: _Compat._try_into_series(v, name = n) 51 | for n, v in kwargs.items() 52 | } 53 | # An eager df, drop output col, so a pl.Series 54 | return ( 55 | pl.select( 56 | func(*new_args, **new_kwargs).alias("__output__") 57 | ).drop_in_place("__output__") 58 | .rename(name.replace("query_", "")) 59 | ) 60 | 61 | return compat_wrapper 62 | 63 | compat: _Compat = _Compat() 64 | 65 | -------------------------------------------------------------------------------- /python/polars_ds/config.py: -------------------------------------------------------------------------------- 1 | # Configs used in transforms and pipelines 2 | # STREAM_IN_TRANSFORM: bool = False 3 | # Level of optimiztion and memory usage, etc. 4 | # Is there a better way to do this? 5 | 6 | 7 | LIN_REG_EXPR_F64 = True 8 | """ 9 | If true, all linear regression expression will use f64 as the default data type 10 | in the underlying implementation. If fase, f32 will be used. This only controls 11 | linear regression expressions. 12 | 13 | The memory footprint will be smaller, but it is possible to have slower speed than f64 for 14 | multiple reasons: 15 | 1. If input data is already in f64, then using f32 will incur additional casts, slowing 16 | down the process. 17 | 2. If input data is not big enough, there won't be any noticeable difference in runtime. 
18 | 19 | """ 20 | 21 | 22 | def _lin_reg_expr_symbol(x: str) -> str: 23 | if LIN_REG_EXPR_F64: 24 | return x 25 | else: 26 | return x + "_f32" 27 | -------------------------------------------------------------------------------- /python/polars_ds/eda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abstractqqq/polars_ds_extension/31e44245c214ad1b464c3458de0960f61a076145/python/polars_ds/eda/__init__.py -------------------------------------------------------------------------------- /python/polars_ds/exprs/__init__.py: -------------------------------------------------------------------------------- 1 | from .expr_knn import * # noqa : F403 2 | from .expr_linear import * # noqa : F403 3 | from .metrics import * # noqa : F403 4 | from .num import * # noqa : F403 5 | from .stats import * # noqa : F403 6 | from .string import * # noqa : F403 7 | from .ts_features import * # noqa : F403 8 | from .expr_iter import * # noqa : F403 9 | -------------------------------------------------------------------------------- /python/polars_ds/exprs/expr_iter.py: -------------------------------------------------------------------------------- 1 | """Iteration related helper expressions""" 2 | 3 | from __future__ import annotations 4 | 5 | import polars as pl 6 | 7 | # Internal dependencies 8 | from polars_ds._utils import pl_plugin, str_to_expr 9 | 10 | __all__ = ["combinations", "product"] 11 | 12 | 13 | def product(s1: str | pl.Expr, s2: str | pl.Expr) -> pl.Expr: 14 | """ 15 | Get the cartesian product of two series. Only non-nulls values will be used. 16 | 17 | Parameters 18 | ---------- 19 | s1 20 | The first column / series 21 | s2 22 | The second column / series 23 | 24 | Examples 25 | -------- 26 | >>> df = pl.DataFrame({ 27 | >>> "a": [1, 2] 28 | >>> , "b": [4, 5] 29 | >>> }) 30 | >>> df.select( 31 | >>> pds.product("a", "b") 32 | >>> ) 33 | shape: (4, 1) 34 | ┌───────────┐ 35 | │ a │ 36 | │ --- │ 37 | │ list[i64] │ 38 | ╞═══════════╡ 39 | │ [1, 4] │ 40 | │ [1, 5] │ 41 | │ [2, 4] │ 42 | │ [2, 5] │ 43 | └───────────┘ 44 | 45 | >>> df = pl.DataFrame({ 46 | >>> "a": [[1,2], [3,4]] 47 | >>> , "b": [[3], [1, 2]] 48 | >>> }).with_row_index() 49 | >>> df 50 | shape: (2, 3) 51 | ┌───────┬───────────┬───────────┐ 52 | │ index ┆ a ┆ b │ 53 | │ --- ┆ --- ┆ --- │ 54 | │ u32 ┆ list[i64] ┆ list[i64] │ 55 | ╞═══════╪═══════════╪═══════════╡ 56 | │ 0 ┆ [1, 2] ┆ [3] │ 57 | │ 1 ┆ [3, 4] ┆ [1, 2] │ 58 | └───────┴───────────┴───────────┘ 59 | 60 | >>> df.group_by( 61 | >>> "index" 62 | >>> ).agg( 63 | >>> pds.product( 64 | >>> pl.col("a").list.explode() 65 | >>> , pl.col("b").list.explode() 66 | >>> ).alias("product") 67 | >>> ) 68 | shape: (2, 2) 69 | ┌───────┬────────────────────────────┐ 70 | │ index ┆ product │ 71 | │ --- ┆ --- │ 72 | │ u32 ┆ list[list[i64]] │ 73 | ╞═══════╪════════════════════════════╡ 74 | │ 0 ┆ [[1, 3], [2, 3]] │ 75 | │ 1 ┆ [[3, 1], [3, 2], … [4, 2]] │ 76 | └───────┴────────────────────────────┘ 77 | """ 78 | return pl_plugin( 79 | symbol="pl_product", 80 | args=[str_to_expr(s1).drop_nulls(), str_to_expr(s2).drop_nulls()], 81 | changes_length=True, 82 | ) 83 | 84 | 85 | def combinations(source: str | pl.Expr, k: int, unique: bool = False) -> pl.Expr: 86 | """ 87 | Get all k-combinations of non-null values in source. This is an expensive operation, as 88 | n choose k can grow very fast. 
89 | 90 | Parameters 91 | ---------- 92 | source 93 | Input source column, must have numeric or string type 94 | k 95 | The k in N choose k 96 | unique 97 | Whether to run .unique() on the source column 98 | 99 | Examples 100 | -------- 101 | >>> df = pl.DataFrame({ 102 | >>> "category": ["a", "a", "a", "b", "b"], 103 | >>> "values": [1, 2, 3, 4, 5] 104 | >>> }) 105 | >>> df.select( 106 | >>> pds.combinations("values", 3) 107 | >>> ) 108 | shape: (10, 1) 109 | ┌───────────┐ 110 | │ values │ 111 | │ --- │ 112 | │ list[i64] │ 113 | ╞═══════════╡ 114 | │ [1, 2, 3] │ 115 | │ [1, 2, 4] │ 116 | │ [1, 2, 5] │ 117 | │ [1, 3, 4] │ 118 | │ [1, 3, 5] │ 119 | │ [1, 4, 5] │ 120 | │ [2, 3, 4] │ 121 | │ [2, 3, 5] │ 122 | │ [2, 4, 5] │ 123 | │ [3, 4, 5] │ 124 | └───────────┘ 125 | >>> df.group_by("category").agg( 126 | >>> pds.combinations("values", 2) 127 | >>> ) 128 | shape: (2, 2) 129 | ┌──────────┬──────────────────────────┐ 130 | │ category ┆ values │ 131 | │ --- ┆ --- │ 132 | │ str ┆ list[list[i64]] │ 133 | ╞══════════╪══════════════════════════╡ 134 | │ a ┆ [[1, 2], [1, 3], [2, 3]] │ 135 | │ b ┆ [[4, 5]] │ 136 | └──────────┴──────────────────────────┘ 137 | """ 138 | s = ( 139 | str_to_expr(source).unique().drop_nulls().sort() 140 | if unique 141 | else str_to_expr(source).drop_nulls() 142 | ) 143 | return pl_plugin( 144 | symbol="pl_combinations", 145 | args=[s], 146 | changes_length=True, 147 | kwargs={ 148 | "k": k, 149 | }, 150 | ) 151 | -------------------------------------------------------------------------------- /python/polars_ds/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import Blueprint, Pipeline, FitStep -------------------------------------------------------------------------------- /python/polars_ds/modeling/_step.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import polars as pl 4 | from io import StringIO 5 | from enum import Enum 6 | from dataclasses import dataclass 7 | from polars_ds.typing import FitTransformFunc, ExprTransform 8 | from typing import List, Sequence, Dict 9 | from polars._typing import IntoExprColumn 10 | 11 | 12 | @dataclass 13 | class FitStep: # Not a FittedStep 14 | func: FitTransformFunc 15 | cols: IntoExprColumn | None 16 | exclude: List[str] 17 | 18 | # Here we allow IntoExprColumn as input so that users can use selectors, or other polars expressions 19 | # to specify input columns, which adds flexibility. 20 | # We still need real column names so that the functions in transforms.py will work. 
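    # For example (sketch): a FitStep created with cols=cs.numeric() will, at fit time, receive the
    # concrete numeric column names of the incoming frame, minus anything listed in `exclude`.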
21 | def fit(self, df: pl.DataFrame | pl.LazyFrame) -> ExprTransform: 22 | if self.cols is None: 23 | return self.func(df) 24 | else: 25 | real_cols: List[str] = [ 26 | x 27 | for x in df.lazy().select(self.cols).collect_schema().names() 28 | if x not in self.exclude 29 | ] 30 | return self.func(df, real_cols) 31 | 32 | 33 | class PLContext(Enum): 34 | SELECT = "select" 35 | WITH_COLUMNS = "with_columns" 36 | SQL = "sql" 37 | FILTER = "filter" 38 | 39 | 40 | class PipelineStep: 41 | def __init__(self, action: str | pl.Expr | Sequence[pl.Expr], context: str | PLContext): 42 | self.context = context if isinstance(context, PLContext) else PLContext(context) 43 | if isinstance(action, pl.Expr): 44 | self.exprs = [action] 45 | elif isinstance(action, str) and self.context == PLContext.SQL: 46 | self.exprs = [action] 47 | elif hasattr(action, "__iter__"): 48 | self.exprs = list(action) 49 | if any(not isinstance(e, pl.Expr) for e in self.exprs): 50 | raise ValueError( 51 | "When input is a list, all elements in the list must be polars expressions." 52 | ) 53 | else: 54 | raise ValueError( 55 | "A pipeline step must be either an expression or a list of expressions, or a str in SQL Context." 56 | ) 57 | 58 | @staticmethod 59 | def from_json(step_dict: dict) -> "PipelineStep": 60 | context, json_actions = step_dict["context"], step_dict["action"] 61 | step_context = PLContext(context) 62 | if step_context in (PLContext.SELECT, PLContext.WITH_COLUMNS): 63 | actions = [pl.Expr.deserialize(StringIO(e), format="json") for e in json_actions] 64 | elif step_context == PLContext.SQL: 65 | actions = str(json_actions) # SQL only needs the string 66 | elif step_context == PLContext.FILTER: 67 | actions = [pl.Expr.deserialize(StringIO(json_actions), format="json")] 68 | else: 69 | raise ValueError("Input is not a valid PDS pipeline.") 70 | 71 | return PipelineStep(action=actions, context=step_context) 72 | 73 | def __iter__(self): 74 | return self.exprs.__iter__() 75 | 76 | def to_json(self) -> Dict: 77 | if self.context == PLContext.SELECT: 78 | return { 79 | "context": "select", 80 | "action": [e.meta.serialize(format="json") for e in self.exprs], 81 | } 82 | elif self.context == PLContext.WITH_COLUMNS: 83 | return { 84 | "context": "with_columns", 85 | "action": [e.meta.serialize(format="json") for e in self.exprs], 86 | } 87 | elif self.context == PLContext.SQL: 88 | return {"context": "sql", "action": self.exprs[0]} 89 | elif self.context == PLContext.FILTER: 90 | return {"context": "filter", "action": self.exprs[0].meta.serialize(format="json")} 91 | else: 92 | raise ValueError(f"Unknown context: {self.context}") 93 | 94 | def apply_df(self, df: pl.LazyFrame | pl.DataFrame) -> pl.LazyFrame | pl.DataFrame: 95 | if self.context == PLContext.SELECT: 96 | return df.select(self.exprs) 97 | elif self.context == PLContext.WITH_COLUMNS: 98 | return df.with_columns(self.exprs) 99 | elif self.context == PLContext.SQL: 100 | return pl.SQLContext(df=df, eager=False).execute(self.exprs[0]) 101 | elif self.context == PLContext.FILTER: 102 | return df.filter(self.exprs[0]) 103 | -------------------------------------------------------------------------------- /python/polars_ds/partition/__init__.py: -------------------------------------------------------------------------------- 1 | from .partition import PartitionHelper -------------------------------------------------------------------------------- /python/polars_ds/partition/partition.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import polars as pl 4 | import polars.selectors as cs 5 | import warnings 6 | # Typing 7 | from collections.abc import Callable 8 | from typing import List, Dict, Any 9 | # Internal Dependencies 10 | from polars_ds.typing import PolarsFrame 11 | 12 | class PartitionHelper(): 13 | """ 14 | A transitory convenience class. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | df: PolarsFrame, 20 | by: str | List[str] | None, 21 | separator: str = "|", 22 | whole_df_name: str = "df" 23 | ): 24 | """ 25 | Creates a Partition Result 26 | 27 | Parameters 28 | ---------- 29 | df 30 | Either a Polars dataframe or a Lazyframe 31 | by 32 | Either None, or a string or a list of strings representing column names. If by 33 | is None, the entire df will be considered a partition. 34 | separator 35 | Separator for concatenating the names of different parts, if the partition is done 36 | by multiple columns 37 | whole_df_name 38 | If by is None, the name for the whole df. 39 | """ 40 | if by is None: 41 | self.parts: Dict[str, pl.DataFrame] = {whole_df_name: df.lazy().collect()} 42 | else: 43 | cols = df.select( 44 | (cs.by_name(by)) & (cs.string() | cs.categorical() | cs.boolean()) 45 | ).collect_schema().names() 46 | 47 | all_ok = cols[0] == by if isinstance(by, str) else sorted(cols) == sorted(by) 48 | if not all_ok: 49 | raise ValueError("Currently this only supports partitions by str, bool or categorical columns.") 50 | 51 | self.parts = { 52 | separator.join((str(k) for k in keys)): value 53 | for keys, value in df.lazy().collect().partition_by(by=by, as_dict=True).items() 54 | } 55 | 56 | def __repr__(self) -> str: 57 | output = "" 58 | for part, df in self.parts.items(): 59 | output += f"Paritition: {part}\n" 60 | output += df.__repr__() + "\n" 61 | return output 62 | 63 | def head(self, n:int = 5) -> Dict[str, pl.DataFrame]: 64 | return {k: df.head(n) for k, df in self.parts.items()} 65 | 66 | def names(self) -> List[str]: 67 | return list(self.parts.keys()) 68 | 69 | def get(self, part:str) -> pl.DataFrame | None: 70 | return self.parts.get(part, None) 71 | 72 | def apply(self, func: Callable[[str, pl.DataFrame], Any]) -> Dict[str, Any]: 73 | """ 74 | Apply an arbitrary function to all parts in this partition. 75 | 76 | Parameters 77 | ---------- 78 | func 79 | A function that takes in a str and a pl.DataFrame and outputs anything. The string 80 | represents the name of the segment. Note: this is usually a partial/lambda function with 81 | all other arguments provided. 82 | """ 83 | output = {} 84 | for part, df in self.parts.items(): 85 | try: 86 | output[part] = func(part, df) 87 | except Exception as e: 88 | warnings.warn( 89 | f"An error occured while processing for the part: {part}. 
This partition is omitted.\nOriginal Error Message: {e}" 90 | , stacklevel = 2 91 | ) 92 | 93 | return output -------------------------------------------------------------------------------- /python/polars_ds/sample_and_split/__init__.py: -------------------------------------------------------------------------------- 1 | from .sample_and_split import * -------------------------------------------------------------------------------- /python/polars_ds/typing.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Literal, List, Callable, Union 3 | import sys 4 | import polars as pl 5 | 6 | if sys.version_info >= (3, 10): 7 | from typing import TypeAlias 8 | else: # 3.9, 3.8 9 | from typing_extensions import TypeAlias 10 | 11 | # Custom "Enum" Types 12 | DetrendMethod: TypeAlias = Literal["linear", "mean"] 13 | Alternative: TypeAlias = Literal["two-sided", "less", "greater"] 14 | Distance: TypeAlias = Literal["l1", "l2", "sql2", "inf", "cosine", "haversine"] 15 | KdtDistance: TypeAlias = Literal["l1", "l2", "sql2", "inf"] 16 | ConvMode: TypeAlias = Literal["same", "left", "right", "full", "valid"] 17 | ConvMethod: TypeAlias = Literal["fft", "direct"] 18 | CorrMethod: TypeAlias = Literal["pearson", "spearman", "xi", "kendall", "bicor"] 19 | SimpleImputeMethod: TypeAlias = Literal["mean", "median", "mode"] 20 | SimpleScaleMethod: TypeAlias = Literal["min_max", "standard", "abs_max"] 21 | Noise: TypeAlias = Literal["gaussian", "uniform"] 22 | LRMethods: TypeAlias = Literal["normal", "l2", "l1"] 23 | LRSolverMethods: TypeAlias = Literal["svd", "qr", "cholesky"] 24 | NullPolicy: TypeAlias = Literal["raise", "skip", "one", "zero", "ignore"] 25 | MultiAUCStrategy: TypeAlias = Literal["weighted", "macro"] 26 | EncoderDefaultStrategy: TypeAlias = Literal["mean", "null", "zero"] 27 | # Copy of Polars 28 | QuantileMethod: TypeAlias = Literal["nearest", "higher", "lower", "midpoint", "linear"] 29 | 30 | # Other Custom Types 31 | PolarsFrame: TypeAlias = Union[pl.DataFrame, pl.LazyFrame] 32 | ExprTransform: TypeAlias = Union[pl.Expr, List[pl.Expr]] 33 | # Need ... 34 | FitTransformFunc: TypeAlias = Callable[[PolarsFrame, List[str]], ExprTransform] 35 | 36 | # # For compatibility 37 | # IntoNumpy: TypeAlias = Union["Sequence[float]", "Sequence[int]"] 38 | # """Anything which can be converted to a NumPy numeric array. 39 | 40 | # Examples: 41 | # >>> from polars_ds.typing import IntoNumPy 42 | # >>> def agnostic_to_numpy(s: IntoNumpy) -> np.ndarray: 43 | # ... 
return s.to_numpy() 44 | # """ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Needed for local development 2 | maturin[patchelf]>=1.7; sys_platform == "linux" 3 | maturin>=1.7; sys_platform != "linux" 4 | polars 5 | pre-commit 6 | ipykernel 7 | numpy 8 | # nbformat>=4.2.0 # Need this if we have plotly 9 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly-2025-04-19" -------------------------------------------------------------------------------- /src/arkadia/leaf.rs: -------------------------------------------------------------------------------- 1 | use num::Float; 2 | 3 | #[derive(Clone, Copy, Debug)] 4 | pub struct Leaf<'a, T: Float, A> { 5 | pub item: A, 6 | pub row_vec: &'a [T], 7 | } 8 | 9 | #[derive(Clone)] 10 | pub struct OwnedLeaf { 11 | pub item: A, 12 | pub row_vec: Vec, 13 | } 14 | 15 | impl<'a, T: Float, A: Copy> From<(A, &'a [T])> for Leaf<'a, T, A> { 16 | fn from(value: (A, &'a [T])) -> Self { 17 | Leaf { 18 | item: value.0, 19 | row_vec: value.1, 20 | } 21 | } 22 | } 23 | 24 | impl From<(A, &[T])> for OwnedLeaf { 25 | fn from(value: (A, &[T])) -> Self { 26 | OwnedLeaf { 27 | item: value.0, 28 | row_vec: value.1.to_vec(), 29 | } 30 | } 31 | } 32 | 33 | pub trait KdLeaf { 34 | fn dim(&self) -> usize; 35 | 36 | fn value_at(&self, idx: usize) -> T; 37 | 38 | fn vec(&self) -> &[T]; 39 | 40 | fn get_item(&self) -> A; 41 | } 42 | 43 | impl<'a, T: Float, A: Copy> KdLeaf for Leaf<'a, T, A> { 44 | fn dim(&self) -> usize { 45 | self.row_vec.len() 46 | } 47 | 48 | fn value_at(&self, idx: usize) -> T { 49 | self.row_vec[idx] 50 | } 51 | 52 | fn vec(&self) -> &'a [T] { 53 | self.row_vec 54 | } 55 | 56 | fn get_item(&self) -> A { 57 | self.item 58 | } 59 | } 60 | 61 | impl<'a, T: Float, A: Copy> KdLeaf for OwnedLeaf { 62 | fn dim(&self) -> usize { 63 | self.row_vec.len() 64 | } 65 | 66 | fn value_at(&self, idx: usize) -> T { 67 | self.row_vec[idx] 68 | } 69 | 70 | fn vec(&self) -> &[T] { 71 | self.row_vec.as_slice() 72 | } 73 | 74 | fn get_item(&self) -> A { 75 | self.item.clone() 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/arkadia/neighbor.rs: -------------------------------------------------------------------------------- 1 | /// NB: Neighbor, search result 2 | /// (Data, and distance) 3 | use num::Float; 4 | 5 | pub struct NB { 6 | pub dist: T, 7 | pub item: A, 8 | } 9 | 10 | impl NB { 11 | pub fn to_item(self) -> A { 12 | self.item 13 | } 14 | 15 | pub fn to_dist(self) -> T { 16 | self.dist 17 | } 18 | 19 | pub fn to_pair(self) -> (T, A) { 20 | (self.dist, self.item) 21 | } 22 | /// Is the neighbor almost equal to the point itself? 23 | pub fn identity(&self) -> bool { 24 | self.dist <= T::epsilon() 25 | } 26 | } 27 | 28 | impl PartialEq for NB { 29 | fn eq(&self, other: &Self) -> bool { 30 | self.dist == other.dist 31 | } 32 | } 33 | 34 | impl PartialOrd for NB { 35 | fn partial_cmp(&self, other: &Self) -> Option { 36 | self.dist.partial_cmp(&other.dist) 37 | } 38 | } 39 | 40 | impl Eq for NB {} 41 | 42 | impl Ord for NB { 43 | // Unwrap is safe because in all use cases, the data should not contain any non-finite values. 
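    // Ordering is by distance alone, so NB values can sit in binary heaps or sorted buffers
    // while nearest neighbors are being collected.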
44 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 45 | self.dist.partial_cmp(&other.dist).unwrap() 46 | } 47 | 48 | fn max(self, other: Self) -> Self 49 | where 50 | Self: Sized, 51 | { 52 | std::cmp::max_by(self, other, |a, b| a.dist.partial_cmp(&b.dist).unwrap()) 53 | } 54 | 55 | fn min(self, other: Self) -> Self 56 | where 57 | Self: Sized, 58 | { 59 | std::cmp::min_by(self, other, |a, b| a.dist.partial_cmp(&b.dist).unwrap()) 60 | } 61 | 62 | fn clamp(self, min: Self, max: Self) -> Self 63 | where 64 | Self: Sized, 65 | Self: PartialOrd, 66 | { 67 | assert!(min <= max); 68 | if self.dist < min.dist { 69 | min 70 | } else if self.dist > max.dist { 71 | max 72 | } else { 73 | self 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/arkadia/utils.rs: -------------------------------------------------------------------------------- 1 | use crate::arkadia::leaf::{Leaf, OwnedLeaf}; 2 | use num::Float; 3 | 4 | #[derive(Clone, Default)] 5 | pub enum SplitMethod { 6 | MIDPOINT, // min + (max - min) / 2 7 | #[default] 8 | MEDIAN, 9 | } 10 | 11 | impl From for SplitMethod { 12 | fn from(balanced: bool) -> Self { 13 | if balanced { 14 | Self::MEDIAN 15 | } else { 16 | Self::MIDPOINT 17 | } 18 | } 19 | } 20 | 21 | pub fn suggest_capacity(dim: usize) -> usize { 22 | if dim < 5 { 23 | 10 24 | } else if dim < 10 { 25 | 20 26 | } else if dim < 15 { 27 | 40 28 | } else if dim < 20 { 29 | 100 30 | } else { 31 | 4098 32 | } 33 | } 34 | 35 | pub fn slice_to_leaves<'a, T: Float + 'static, A: Copy>( 36 | slice: &'a [T], 37 | row_len: usize, 38 | values: &'a [A], 39 | ) -> Vec> { 40 | values 41 | .iter() 42 | .copied() 43 | .zip(slice.chunks_exact(row_len)) 44 | .map(|pair| pair.into()) 45 | .collect() 46 | } 47 | 48 | pub fn slice_to_owned_leaves( 49 | slice: &[T], 50 | row_len: usize, 51 | values: &[A], 52 | ) -> Vec> { 53 | values 54 | .iter() 55 | .copied() 56 | .zip(slice.chunks_exact(row_len)) 57 | .map(|pair| pair.into()) 58 | .collect() 59 | } 60 | 61 | pub fn slice_to_empty_leaves<'a, T: Float + 'static>( 62 | slice: &'a [T], 63 | row_len: usize, 64 | ) -> Vec> { 65 | slice 66 | .chunks_exact(row_len) 67 | .map(|row| ((), row).into()) 68 | .collect() 69 | } 70 | 71 | // pub fn matrix_to_leaves<'a, T: Float + 'static, A: Copy>( 72 | // matrix: &'a ArrayView2<'a, T>, 73 | // values: &'a [A], 74 | // ) -> Vec> { 75 | // values 76 | // .iter() 77 | // .copied() 78 | // .zip(matrix.rows()) 79 | // .map(|pair| pair.into()) 80 | // .collect::>() 81 | // } 82 | 83 | // pub fn matrix_to_leaves_w_row_num<'a, T: Float + 'static>( 84 | // matrix: &'a ArrayView2<'a, T>, 85 | // ) -> Vec> { 86 | // matrix 87 | // .rows() 88 | // .into_iter() 89 | // .enumerate() 90 | // .map(|pair| pair.into()) 91 | // .collect::>() 92 | // } 93 | 94 | // pub fn matrix_to_empty_leaves<'a, T: Float + 'static>( 95 | // matrix: &'a ArrayView2<'a, T>, 96 | // ) -> Vec> { 97 | // matrix 98 | // .rows() 99 | // .into_iter() 100 | // .map(|row| ((), row).into()) 101 | // .collect::>() 102 | // } 103 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(float_gamma)] 2 | 3 | mod arkadia; 4 | mod linalg; 5 | mod num_ext; 6 | mod pymodels; 7 | mod stats; 8 | mod stats_utils; 9 | mod str_ext; 10 | mod utils; 11 | 12 | use pyo3::{ 13 | pymodule, 14 | types::{PyModule, PyModuleMethods}, 15 | Bound, PyResult, Python, 16 | }; 17 | 18 | 
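// The add_class calls below register the Python-facing model classes from `pymodels`
// (the linear regression, GLM and kd-tree wrappers) on the `_polars_ds` extension module.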
#[pymodule] 19 | #[pyo3(name = "_polars_ds")] 20 | fn _polars_ds(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { 21 | m.add_class::()?; 22 | m.add_class::()?; 23 | m.add_class::()?; 24 | m.add_class::()?; 25 | Ok(()) 26 | } 27 | 28 | use pyo3_polars::PolarsAllocator; 29 | #[global_allocator] 30 | static ALLOC: PolarsAllocator = PolarsAllocator::new(); 31 | -------------------------------------------------------------------------------- /src/linalg/link_functions.rs: -------------------------------------------------------------------------------- 1 | /// Defines the link functions for the GLM. 2 | use faer_traits::RealField; 3 | use num::Float; 4 | 5 | #[derive(Debug, Clone, Copy, PartialEq)] 6 | pub enum LinkFunction { 7 | Identity, // Normal 8 | Log, // Poisson 9 | Logit, // Binomial 10 | Inverse, // Gamma 11 | } 12 | 13 | impl LinkFunction { 14 | /// g(μ) 15 | pub fn link(&self, mu: T) -> T { 16 | match self { 17 | LinkFunction::Identity => mu, 18 | LinkFunction::Log => mu.ln(), 19 | LinkFunction::Logit => { 20 | // logit(p) = ln(p/(1-p)) 21 | let one = T::one(); 22 | (mu / (one - mu)).ln() 23 | }, 24 | LinkFunction::Inverse => mu.recip(), 25 | } 26 | } 27 | 28 | /// g^(-1)(η) 29 | pub fn inv_link(&self, eta: T) -> T { 30 | match self { 31 | LinkFunction::Identity => eta, 32 | LinkFunction::Log => eta.exp(), 33 | LinkFunction::Logit => { 34 | // inv_logit(x) = exp(x)/(1+exp(x)) 35 | let one = T::one(); 36 | let exp_eta = eta.exp(); 37 | exp_eta / (one + exp_eta) 38 | }, 39 | LinkFunction::Inverse => eta.recip(), 40 | } 41 | } 42 | 43 | /// Computes g'(μ) 44 | pub fn link_deriv(&self, mu: T) -> T { 45 | match self { 46 | LinkFunction::Identity => T::one(), 47 | LinkFunction::Log => mu.recip(), 48 | LinkFunction::Logit => { 49 | // d/dp logit(p) = 1/(p(1-p)) 50 | let one = T::one(); 51 | one / (mu * (one - mu)) 52 | }, 53 | LinkFunction::Inverse => -mu.recip().powi(2), 54 | } 55 | } 56 | } 57 | 58 | #[derive(Debug, Clone, Copy, PartialEq)] 59 | pub enum VarianceFunction { 60 | Gaussian, 61 | Poisson, 62 | Binomial, 63 | Gamma, 64 | } 65 | 66 | impl VarianceFunction { 67 | /// Variance(μ) 68 | pub fn variance(&self, mu: T) -> T { 69 | match self { 70 | VarianceFunction::Gaussian => T::one(), 71 | VarianceFunction::Poisson => mu, 72 | VarianceFunction::Binomial => { 73 | let one = T::one(); 74 | mu * (one - mu) 75 | }, 76 | VarianceFunction::Gamma => mu.powi(2), 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/num_ext/benford.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::list_u32_output; 2 | use num::Integer; 3 | use polars::prelude::*; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | fn first_digit(u: T) -> T { 7 | // Does the compiler know how to optimize this? 
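    // Strip digits from the right: each pass keeps the current last digit in d, so when v
    // reaches zero, d holds the most significant (leading) digit of u.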
8 | let ten = (0..10).fold(T::zero(), |acc: T, _| acc + T::one()); 9 | let mut v = u; 10 | let mut d = T::zero(); 11 | while v != T::zero() { 12 | d = v % ten; 13 | v = v / ten; 14 | } 15 | d 16 | } 17 | 18 | #[polars_expr(output_type_func=list_u32_output)] 19 | fn pl_benford_law(inputs: &[Series]) -> PolarsResult { 20 | let mut out = [0; 10]; 21 | let s = &inputs[0]; 22 | match s.dtype() { 23 | DataType::UInt8 => { 24 | let ss = s.u8().unwrap(); 25 | for x in ss.into_no_null_iter() { 26 | let d = first_digit(x); 27 | out[d as usize] += 1; 28 | } 29 | } 30 | DataType::UInt16 => { 31 | let ss = s.u16().unwrap(); 32 | for x in ss.into_no_null_iter() { 33 | let d = first_digit(x); 34 | out[d as usize] += 1; 35 | } 36 | } 37 | DataType::UInt32 => { 38 | let ss = s.u32().unwrap(); 39 | for x in ss.into_no_null_iter() { 40 | let d = first_digit(x); 41 | out[d as usize] += 1; 42 | } 43 | } 44 | DataType::UInt64 => { 45 | let ss = s.u64().unwrap(); 46 | for x in ss.into_no_null_iter() { 47 | let d = first_digit(x); 48 | out[d as usize] += 1; 49 | } 50 | } 51 | DataType::Int8 => { 52 | let ss = s.i8().unwrap(); 53 | for x in ss.into_no_null_iter() { 54 | let d = first_digit(x); // could be negative 55 | out[d.abs() as usize] += 1; 56 | } 57 | } 58 | DataType::Int16 => { 59 | let ss = s.i16().unwrap(); 60 | for x in ss.into_no_null_iter() { 61 | let d = first_digit(x); // could be negative 62 | out[d.abs() as usize] += 1; 63 | } 64 | } 65 | DataType::Int32 => { 66 | let ss = s.i32().unwrap(); 67 | for x in ss.into_no_null_iter() { 68 | let d = first_digit(x); // could be negative 69 | out[d.abs() as usize] += 1; 70 | } 71 | } 72 | DataType::Int64 => { 73 | let ss = s.i64().unwrap(); 74 | for x in ss.into_no_null_iter() { 75 | let d = first_digit(x); // could be negative 76 | out[d.abs() as usize] += 1; 77 | } 78 | } 79 | DataType::Float32 => { 80 | let ss = s.f32().unwrap(); 81 | for x in ss.into_no_null_iter() { 82 | if x.is_finite() { 83 | let x_char = x 84 | .abs() 85 | .to_string() 86 | .chars() 87 | .find(|c| *c != '0' && *c != '.') 88 | .unwrap_or('0'); 89 | let idx = x_char.to_digit(10).unwrap() as usize; 90 | out[idx] += 1; 91 | } 92 | } 93 | } 94 | DataType::Float64 => { 95 | let ss = s.f64().unwrap(); 96 | for x in ss.into_no_null_iter() { 97 | if x.is_finite() { 98 | let x_char = x 99 | .abs() 100 | .to_string() 101 | .chars() 102 | .find(|c| *c != '0' && *c != '.') 103 | .unwrap_or('0'); 104 | let idx = x_char.to_digit(10).unwrap() as usize; 105 | out[idx] += 1; 106 | } 107 | } 108 | } 109 | _ => return Err(PolarsError::ComputeError("Invalid incoming type.".into())), 110 | } 111 | 112 | let mut list_builder: ListPrimitiveChunkedBuilder = 113 | ListPrimitiveChunkedBuilder::new("first_digit_count".into(), 1, 9, DataType::UInt32); 114 | 115 | list_builder.append_slice(&out[1..]); 116 | let out = list_builder.finish(); 117 | Ok(out.into_series()) 118 | } 119 | -------------------------------------------------------------------------------- /src/num_ext/cond_entropy.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::*; 2 | use pyo3_polars::derive::polars_expr; 3 | 4 | #[polars_expr(output_type=Float64)] 5 | fn pl_conditional_entropy(inputs: &[Series]) -> PolarsResult { 6 | let df = df!("x" => inputs[0].clone(), "y" => inputs[1].clone())?; 7 | let mut out = df 8 | .lazy() 9 | .group_by([col("x"), col("y")]) 10 | .agg([len().alias("cnt")]) 11 | .with_columns([ 12 | (col("cnt").sum().cast(DataType::Float64).over([col("y")]) 13 | / 
col("cnt").sum().cast(DataType::Float64)) 14 | .alias("p(y)"), 15 | (col("cnt").cast(DataType::Float64) / col("cnt").sum().cast(DataType::Float64)) 16 | .alias("p(x,y)"), 17 | ]) 18 | .select([(lit(-1.0_f64) 19 | * ((col("p(x,y)") / col("p(y)")) 20 | .log(std::f64::consts::E) 21 | .dot(col("p(x,y)")))) 22 | .alias("H(x|y)")]) 23 | .collect()?; 24 | 25 | out.drop_in_place("H(x|y)") 26 | .map(|s| s.as_materialized_series().clone()) 27 | } 28 | -------------------------------------------------------------------------------- /src/num_ext/fft.rs: -------------------------------------------------------------------------------- 1 | /// Performs forward FFT. 2 | /// Since data in dataframe are always real numbers, only realfft 3 | /// is implemented and. 5-10x slower than NumPy for small data (~ a few thousands rows) 4 | /// but is slighly faster once data gets bigger. 5 | use crate::utils::complex_output; 6 | use itertools::Either; 7 | use num::complex::Complex64; 8 | use polars::prelude::*; 9 | use pyo3_polars::derive::polars_expr; 10 | use realfft::RealFftPlanner; 11 | // Optimization ideas: small size, e.g. <= 2048, always allocate a fixed sized slice? 12 | // 2^n padding in the general case 13 | 14 | #[polars_expr(output_type_func=complex_output)] 15 | fn pl_rfft(inputs: &[Series]) -> PolarsResult { 16 | let s = inputs[0].f64()?; 17 | let n = inputs[1].u32()?; 18 | let mut n = n.get(0).unwrap_or(s.len() as u32) as usize; 19 | let return_full = inputs[2].bool()?; 20 | let return_full = return_full.get(0).unwrap_or(false); 21 | 22 | let mut input_vec = match s.to_vec_null_aware() { 23 | Either::Left(v) => Ok(v), 24 | Either::Right(_) => Err(PolarsError::ComputeError( 25 | "FFT: Input should not contain nulls.".into(), 26 | )), 27 | }?; 28 | 29 | if n > input_vec.len() { 30 | input_vec.extend(vec![0.; n.abs_diff(input_vec.len())]); 31 | } else { 32 | input_vec.truncate(n); 33 | } 34 | let input_len = input_vec.len(); 35 | 36 | let mut planner = RealFftPlanner::::new(); 37 | let r2c = planner.plan_fft_forward(input_len); 38 | 39 | let mut spectrum: Vec = r2c.make_output_vec(); 40 | let _ = r2c.process(&mut input_vec, &mut spectrum); 41 | 42 | n = if return_full { 43 | input_vec.len() // full length 44 | } else { 45 | spectrum.len() // simplified output of rfft 46 | }; 47 | 48 | let mut builder = 49 | ListPrimitiveChunkedBuilder::::new("complex".into(), n, 2, DataType::Float64); 50 | 51 | if return_full { 52 | for c in spectrum.iter() { 53 | builder.append_slice(&[c.re, c.im]) 54 | } 55 | if input_len % 2 == 0 { 56 | let take_n = (input_len >> 1).abs_diff(1); 57 | for c in spectrum.into_iter().rev().skip(1).take(take_n) { 58 | builder.append_slice(&[c.re, -c.im]); 59 | } 60 | } else { 61 | let take_n = input_len >> 1; 62 | for c in spectrum.into_iter().rev().take(take_n) { 63 | builder.append_slice(&[c.re, -c.im]); 64 | } 65 | } 66 | } else { 67 | for c in spectrum { 68 | builder.append_slice(&[c.re, c.im]) 69 | } 70 | } 71 | 72 | let out = builder.finish(); 73 | out.cast(&DataType::Array(Box::new(DataType::Float64), 2)) 74 | } 75 | -------------------------------------------------------------------------------- /src/num_ext/gcd_lcm.rs: -------------------------------------------------------------------------------- 1 | /// GCD and LCM for integers in dataframe. 
2 | use polars::prelude::*; 3 | use polars_core::prelude::arity::binary_elementwise_values; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | #[polars_expr(output_type=Int32)] 7 | fn pl_gcd(inputs: &[Series]) -> PolarsResult { 8 | let ca1 = inputs[0].i32()?; 9 | let ca2 = inputs[1].i32()?; 10 | if ca2.len() == 1 { 11 | let b = ca2.get(0).unwrap(); 12 | let out: Int32Chunked = ca1.apply_values(|a| num::integer::gcd(a, b)); 13 | Ok(out.into_series()) 14 | } else if ca1.len() == ca2.len() { 15 | let out: Int32Chunked = binary_elementwise_values(ca1, ca2, num::integer::gcd); 16 | Ok(out.into_series()) 17 | } else { 18 | Err(PolarsError::ShapeMismatch( 19 | "Inputs must have the same length.".into(), 20 | )) 21 | } 22 | } 23 | 24 | #[polars_expr(output_type=Int32)] 25 | fn pl_lcm(inputs: &[Series]) -> PolarsResult { 26 | let ca1 = inputs[0].i32()?; 27 | let ca2 = inputs[1].i32()?; 28 | if ca2.len() == 1 { 29 | let b = ca2.get(0).unwrap(); 30 | let out: Int32Chunked = ca1.apply_values(|a| num::integer::lcm(a, b)); 31 | Ok(out.into_series()) 32 | } else if ca1.len() == ca2.len() { 33 | let out: Int32Chunked = binary_elementwise_values(ca1, ca2, num::integer::lcm); 34 | Ok(out.into_series()) 35 | } else { 36 | Err(PolarsError::ShapeMismatch( 37 | "Inputs must have the same length.".into(), 38 | )) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/num_ext/haversine.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::{float_output, haversine_elementwise}; 2 | use num::Float; 3 | use polars::prelude::*; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | fn naive_haversine( 7 | x_lat: &ChunkedArray, 8 | x_long: &ChunkedArray, 9 | y_lat: &ChunkedArray, 10 | y_long: &ChunkedArray, 11 | ) -> PolarsResult> 12 | where 13 | T: PolarsFloatType, 14 | T::Native: Float, 15 | { 16 | if (y_lat.len() == y_long.len()) && (y_lat.len() == 1) { 17 | let e_lat = y_lat.get(0).unwrap(); 18 | let e_long = y_long.get(0).unwrap(); 19 | let out: ChunkedArray = x_lat 20 | .into_iter() 21 | .zip(x_long.into_iter()) 22 | .map(|(x_lat, x_long)| { 23 | let x_lat = x_lat?; 24 | let x_long = x_long?; 25 | Some(haversine_elementwise(x_lat, x_long, e_lat, e_long)) 26 | }) 27 | .collect(); 28 | Ok(out) 29 | } else if x_lat.len() == x_long.len() 30 | && x_long.len() == y_lat.len() 31 | && y_lat.len() == y_long.len() 32 | { 33 | let out: ChunkedArray = x_lat 34 | .into_iter() 35 | .zip(x_long.into_iter()) 36 | .zip(y_lat.into_iter()) 37 | .zip(y_long.into_iter()) 38 | .map(|(((x_lat, x_long), y_lat), y_long)| { 39 | let x_lat = x_lat?; 40 | let x_long = x_long?; 41 | let y_lat = y_lat?; 42 | let y_long = y_long?; 43 | Some(haversine_elementwise(x_lat, x_long, y_lat, y_long)) 44 | }) 45 | .collect(); 46 | 47 | Ok(out) 48 | } else { 49 | Err(PolarsError::ShapeMismatch( 50 | "Inputs must have the same length or one of them must be a scalar.".into(), 51 | )) 52 | } 53 | } 54 | 55 | #[polars_expr(output_type_func=float_output)] 56 | fn pl_haversine(inputs: &[Series]) -> PolarsResult { 57 | let out = match inputs[0].dtype() { 58 | DataType::Float32 => { 59 | let x_lat = inputs[0].f32().unwrap(); 60 | let x_long = inputs[1].f32().unwrap(); 61 | let y_lat = inputs[2].f32().unwrap(); 62 | let y_long = inputs[3].f32().unwrap(); 63 | let out = naive_haversine(x_lat, x_long, y_lat, y_long)?; 64 | out.into_series() 65 | } 66 | DataType::Float64 => { 67 | let x_lat = inputs[0].f64().unwrap(); 68 | let x_long = inputs[1].f64().unwrap(); 69 | 
let y_lat = inputs[2].f64().unwrap(); 70 | let y_long = inputs[3].f64().unwrap(); 71 | let out = naive_haversine(x_lat, x_long, y_lat, y_long)?; 72 | out.into_series() 73 | } 74 | _ => return Err(PolarsError::ComputeError("Data type not supported.".into())), 75 | }; 76 | Ok(out) 77 | } 78 | -------------------------------------------------------------------------------- /src/num_ext/isotonic_regression.rs: -------------------------------------------------------------------------------- 1 | use itertools::Either; 2 | use polars::prelude::*; 3 | use pyo3_polars::derive::polars_expr; 4 | use serde::Deserialize; 5 | 6 | // Reference: https://github.com/scipy/scipy/blob/v1.14.1/scipy/optimize/_isotonic.py 7 | // https://github.com/scipy/scipy/blob/v1.14.1/scipy/optimize/_pava/pava_pybind.cpp 8 | // https://www.jstatsoft.org/article/view/v102c01 9 | 10 | // The code here has to be compiled in --release 11 | // Otherwise, a mysterious error will occur. Thank you compiler! 12 | 13 | #[derive(Deserialize, Debug)] 14 | pub(crate) struct IsotonicRegKwargs { 15 | pub(crate) has_weights: bool, 16 | pub(crate) increasing: bool, 17 | } 18 | 19 | fn isotonic_regression(x: &mut [f64], w: &mut [f64], r: &mut [usize]) { 20 | let n = x.len(); 21 | r[0] = 0; 22 | r[1] = 1; 23 | let mut b: usize = 0; 24 | let mut xb_pre = x[b]; 25 | let mut wb_pre = w[b]; 26 | 27 | let mut i: usize = 1; 28 | 29 | while i < n { 30 | b += 1; 31 | let mut xb = x[i]; 32 | let mut wb = w[i]; 33 | if xb_pre >= xb { 34 | b -= 1; 35 | let mut sb = wb_pre * xb_pre + wb * xb; 36 | wb += wb_pre; 37 | xb = sb / wb; 38 | while i + 1 < n && xb >= x[i + 1] { 39 | i += 1; 40 | sb += w[i] * x[i]; 41 | wb += w[i]; 42 | xb = sb / wb; 43 | } 44 | while b > 0 && x[b - 1] >= xb { 45 | b -= 1; 46 | sb += w[b] * x[b]; 47 | wb += w[b]; 48 | xb = sb / wb; 49 | } 50 | } 51 | xb_pre = xb; 52 | x[b] = xb; 53 | 54 | wb_pre = wb; 55 | w[b] = wb; 56 | 57 | r[b + 1] = i + 1; 58 | i += 1; 59 | } 60 | 61 | let mut f = n - 1; 62 | for k in (0..=b).rev() { 63 | // println!("{}", k); 64 | let t = r[k]; 65 | let xk = x[k]; 66 | for i in t..=f { 67 | x[i] = xk; 68 | } 69 | f = t - 1; 70 | } 71 | } 72 | 73 | #[polars_expr(output_type=Float64)] 74 | fn pl_isotonic_regression(inputs: &[Series], kwargs: IsotonicRegKwargs) -> PolarsResult { 75 | // Not sure why increasing = False doesn't give the right result 76 | let y = inputs[0].f64()?; 77 | let increasing = kwargs.increasing; 78 | 79 | if y.len() <= 1 { 80 | return Ok(y.clone().into_series()); 81 | } 82 | 83 | let mut y = match y.to_vec_null_aware() { 84 | Either::Left(v) => Ok(v), 85 | Either::Right(_) => Err(PolarsError::ComputeError( 86 | "Input should not contain nulls.".into(), 87 | )), 88 | }?; 89 | 90 | let has_weights = kwargs.has_weights; // True then weights are given, false then use 1s. 91 | let mut w = if has_weights { 92 | let weight = inputs[1].f64()?; 93 | if weight.len() != y.len() { 94 | Err(PolarsError::ComputeError( 95 | "Weights should not contain nulls and must be the same length as y.".into(), 96 | )) 97 | } else { 98 | match weight.to_vec_null_aware() { 99 | Either::Left(mut w) => { 100 | if w.iter().any(|x| *x <= 0.) 
{ 101 | Err(PolarsError::ComputeError( 102 | "Weight should not contain negative values.".into(), 103 | )) 104 | } else { 105 | if !increasing { 106 | w.reverse(); 107 | } 108 | Ok(w) 109 | } 110 | } 111 | Either::Right(_) => Err(PolarsError::ComputeError( 112 | "Weight should not contain nulls.".into(), 113 | )), 114 | } 115 | } 116 | } else { 117 | Ok(vec![1f64; y.len()]) 118 | }?; 119 | 120 | if !increasing { 121 | y.reverse(); 122 | } 123 | 124 | let mut r = vec![0; y.len() + 1]; 125 | isotonic_regression(&mut y, &mut w, &mut r); 126 | if !increasing { 127 | y.reverse(); 128 | } 129 | 130 | let output = Float64Chunked::from_vec("".into(), y); 131 | Ok(output.into_series()) 132 | } 133 | -------------------------------------------------------------------------------- /src/num_ext/iterations.rs: -------------------------------------------------------------------------------- 1 | use itertools::Itertools; 2 | use polars::prelude::*; 3 | use pyo3_polars::derive::polars_expr; 4 | use serde::Deserialize; 5 | 6 | #[derive(Deserialize, Debug)] 7 | pub(crate) struct CombinationKwargs { 8 | pub(crate) k: usize, 9 | } 10 | 11 | fn itertools_output(fields: &[Field]) -> PolarsResult { 12 | Ok(Field::new( 13 | "".into(), 14 | DataType::List(Box::new(fields[0].dtype().clone())), 15 | )) 16 | } 17 | 18 | fn count_combinations(n: usize, k: usize) -> usize { 19 | if k > n { 20 | 0 21 | } else { 22 | (1..=k.min(n - k)).fold(1, |acc, val| acc * (n - val + 1) / val) 23 | } 24 | } 25 | 26 | fn get_combinations(ca: &ChunkedArray, k: usize) -> Series 27 | where 28 | T: PolarsNumericType, 29 | { 30 | let mut builder: ListPrimitiveChunkedBuilder = ListPrimitiveChunkedBuilder::new( 31 | "".into(), 32 | count_combinations(ca.len(), k), 33 | k, 34 | T::get_dtype(), 35 | ); 36 | 37 | for comb in ca.into_no_null_iter().combinations(k) { 38 | builder.append_slice(&comb); 39 | } 40 | 41 | let ca = builder.finish(); 42 | ca.into_series() 43 | } 44 | 45 | fn get_product(ca1: &ChunkedArray, ca2: &ChunkedArray) -> Series 46 | where 47 | T: PolarsNumericType, 48 | { 49 | let mut builder: ListPrimitiveChunkedBuilder = 50 | ListPrimitiveChunkedBuilder::new("".into(), ca1.len() * ca2.len(), 2, T::get_dtype()); 51 | 52 | for a in ca1.into_no_null_iter() { 53 | for b in ca2.into_no_null_iter() { 54 | builder.append_slice(&[a, b]); 55 | } 56 | } 57 | 58 | let ca = builder.finish(); 59 | ca.into_series() 60 | } 61 | 62 | fn get_combinations_str(ca: &StringChunked, k: usize) -> Series { 63 | let mut builder: ListStringChunkedBuilder = 64 | ListStringChunkedBuilder::new("".into(), count_combinations(ca.len(), k), k); 65 | 66 | for comb in ca.into_no_null_iter().combinations(k) { 67 | builder.append_values_iter(comb.into_iter()); 68 | } 69 | 70 | let ca = builder.finish(); 71 | ca.into_series() 72 | } 73 | 74 | fn get_product_str(ca1: &StringChunked, ca2: &StringChunked) -> Series { 75 | let mut builder: ListStringChunkedBuilder = 76 | ListStringChunkedBuilder::new("".into(), ca1.len() * ca2.len(), 2); 77 | 78 | for a in ca1.into_no_null_iter() { 79 | for b in ca2.into_no_null_iter() { 80 | builder.append_values_iter([a, b].into_iter()); 81 | } 82 | } 83 | 84 | let ca = builder.finish(); 85 | ca.into_series() 86 | } 87 | 88 | #[polars_expr(output_type_func=itertools_output)] 89 | fn pl_combinations(inputs: &[Series], kwargs: CombinationKwargs) -> PolarsResult { 90 | let s = &inputs[0]; 91 | let k = kwargs.k; 92 | 93 | if s.len() < k { 94 | return Err(PolarsError::ComputeError( 95 | "Source has < k (unique) values.".into(), 96 | )); 
97 | } 98 | 99 | match s.dtype() { 100 | DataType::UInt8 => Ok(get_combinations(s.u8().unwrap(), k)), 101 | DataType::UInt16 => Ok(get_combinations(s.u16().unwrap(), k)), 102 | DataType::UInt32 => Ok(get_combinations(s.u32().unwrap(), k)), 103 | DataType::UInt64 => Ok(get_combinations(s.u64().unwrap(), k)), 104 | DataType::Int8 => Ok(get_combinations(s.i8().unwrap(), k)), 105 | DataType::Int16 => Ok(get_combinations(s.i16().unwrap(), k)), 106 | DataType::Int32 => Ok(get_combinations(s.i32().unwrap(), k)), 107 | DataType::Int64 => Ok(get_combinations(s.i64().unwrap(), k)), 108 | DataType::Int128 => Ok(get_combinations(s.i128().unwrap(), k)), 109 | DataType::Float32 => Ok(get_combinations(s.f32().unwrap(), k)), 110 | DataType::Float64 => Ok(get_combinations(s.f64().unwrap(), k)), 111 | DataType::String => Ok(get_combinations_str(s.str().unwrap(), k)), 112 | _ => Err(PolarsError::ComputeError("Unsupported data type.".into())), 113 | } 114 | } 115 | 116 | #[polars_expr(output_type_func=itertools_output)] 117 | fn pl_product(inputs: &[Series]) -> PolarsResult { 118 | let s1 = &inputs[0]; 119 | let s2 = &inputs[1]; 120 | 121 | if s1.dtype() != s2.dtype() { 122 | return Err(PolarsError::ComputeError( 123 | "Dtype of first input series is not the same as the second.".into(), 124 | )); 125 | } 126 | 127 | match s1.dtype() { 128 | DataType::UInt8 => Ok(get_product(s1.u8().unwrap(), s2.u8().unwrap())), 129 | DataType::UInt16 => Ok(get_product(s1.u16().unwrap(), s2.u16().unwrap())), 130 | DataType::UInt32 => Ok(get_product(s1.u32().unwrap(), s2.u32().unwrap())), 131 | DataType::UInt64 => Ok(get_product(s1.u64().unwrap(), s2.u64().unwrap())), 132 | DataType::Int8 => Ok(get_product(s1.i8().unwrap(), s2.i8().unwrap())), 133 | DataType::Int16 => Ok(get_product(s1.i16().unwrap(), s2.i16().unwrap())), 134 | DataType::Int32 => Ok(get_product(s1.i32().unwrap(), s2.i32().unwrap())), 135 | DataType::Int64 => Ok(get_product(s1.i64().unwrap(), s2.i64().unwrap())), 136 | DataType::Int128 => Ok(get_product(s1.i128().unwrap(), s2.i128().unwrap())), 137 | DataType::Float32 => Ok(get_product(s1.f32().unwrap(), s2.f32().unwrap())), 138 | DataType::Float64 => Ok(get_product(s1.f64().unwrap(), s2.f64().unwrap())), 139 | DataType::String => Ok(get_product_str(s1.str().unwrap(), s2.str().unwrap())), 140 | _ => Err(PolarsError::ComputeError("Unsupported data type.".into())), 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/num_ext/lempel_ziv.rs: -------------------------------------------------------------------------------- 1 | use hashbrown::HashSet; 2 | use polars::prelude::*; 3 | use pyo3_polars::derive::polars_expr; 4 | 5 | #[polars_expr(output_type=UInt32)] 6 | fn pl_lempel_ziv_complexity(inputs: &[Series]) -> PolarsResult { 7 | let bools = inputs[0].bool()?; 8 | let bits: Vec = bools 9 | .into_iter() 10 | .map(|op_b| op_b.unwrap_or_default()) 11 | .collect(); 12 | 13 | let mut ind: usize = 0; 14 | let mut inc: usize = 1; 15 | let mut sub_strings: HashSet<&[bool]> = HashSet::new(); 16 | while ind + inc <= bits.len() { 17 | let subseq: &[bool] = &bits[ind..ind + inc]; 18 | if sub_strings.contains(subseq) { 19 | inc += 1; 20 | } else { 21 | sub_strings.insert(subseq); 22 | ind += inc; 23 | inc = 1; 24 | } 25 | } 26 | let c = sub_strings.len(); 27 | Ok(Series::from_iter([c as u32])) 28 | } 29 | -------------------------------------------------------------------------------- /src/num_ext/mod.rs: 
-------------------------------------------------------------------------------- 1 | mod benford; 2 | mod cond_entropy; 3 | mod convolve; 4 | mod entrophies; 5 | mod fft; 6 | mod float_extras; 7 | mod gcd_lcm; 8 | mod haversine; 9 | mod isotonic_regression; 10 | mod iterations; 11 | mod jaccard; 12 | mod knn; 13 | mod lempel_ziv; 14 | mod linear_regression; 15 | mod linear_regression_f32; 16 | mod mutual_info; 17 | mod pca; 18 | mod psi; 19 | mod subseq_sim; 20 | mod target_encode; 21 | mod tp_fp; 22 | mod trapz; 23 | mod welch; 24 | mod woe_iv; 25 | 26 | // mod ball_tree; 27 | -------------------------------------------------------------------------------- /src/num_ext/mutual_info.rs: -------------------------------------------------------------------------------- 1 | // use itertools::Itertools; 2 | // use polars::prelude::*; 3 | // use polars_core::utils::rayon::iter::{ParallelBridge, ParallelIterator}; 4 | // use pyo3_polars::derive::polars_expr; 5 | // use ordered_float::OrderedFloat; 6 | // use serde::Deserialize; 7 | 8 | // #[derive(Deserialize)] 9 | // pub(crate) struct KthNBKwargs { 10 | // pub(crate) k: usize, 11 | // parallel: bool, 12 | // } 13 | 14 | // fn dist_from_kth_nb(data: &[f64], x:f64, k:usize) -> f64 { 15 | // // Distance from the kth Neighbor 16 | // // Not the most efficient 17 | // // Doesn't dedup if there are identical distances 18 | // if k >= data.len() { 19 | // f64::NAN 20 | // } else { 21 | // let x = OrderedFloat(x); 22 | // let ordered_data = unsafe { 23 | // std::mem::transmute::<&[f64], &[OrderedFloat]>(data) 24 | // }; 25 | 26 | // let index = match ordered_data.binary_search(&x) { 27 | // Ok(i) => i, 28 | // Err(j) => j 29 | // }; 30 | // let min_i = index.saturating_sub(k); 31 | // let max_i = (index + k + 1).min(data.len()); 32 | 33 | // let distances = (min_i..max_i) 34 | // .map(|i| OrderedFloat((x - data[i]).abs())) 35 | // .sorted_unstable() 36 | // .collect::>(); 37 | 38 | // *distances[k] 39 | 40 | // // let mut rank = (min_i..max_i).map(|i| (x - data[i]).abs()).collect::>(); 41 | // // rank.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap()); 42 | // // match rank.get(k) { 43 | // // Some(x) => *x, 44 | // // None => { 45 | // // println!("Rank {:?}", rank); 46 | // // println!("Min I {:?}", min_i); 47 | // // println!("Max I {:?}", max_i); 48 | // // f64::NAN 49 | // // }, 50 | // // } 51 | // } 52 | // // *(rank.get(k).unwrap_or(&f64::NAN)) 53 | // } 54 | 55 | // #[polars_expr(output_type=Float64)] 56 | // fn _pl_dist_from_kth_nb(inputs: &[Series], kwargs: KthNBKwargs) -> PolarsResult { 57 | // // k-th nearest neighbor (1d) is a quantity needed during the computation of Mutual Info Score 58 | // // X: NaN filled with Null in Python 59 | // // This is a special helper function only used in mutual_info_score 60 | // let x = inputs[0].f64()?; 61 | // let data = x.drop_nulls().sort(false); 62 | // let data = data.cont_slice()?; 63 | // let k = kwargs.k; 64 | 65 | // let output = if kwargs.parallel { 66 | // x 67 | // .into_iter() 68 | // .par_bridge() 69 | // .map(|op_y| 70 | // op_y.map(|y| dist_from_kth_nb(data, y, k)) 71 | // ).collect::() 72 | // } else { 73 | // x.apply_values(|y| dist_from_kth_nb(data, y, k)) 74 | // }; 75 | 76 | // Ok(output.into_series()) 77 | // } 78 | -------------------------------------------------------------------------------- /src/num_ext/pca.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::{to_f64_vec_fail_on_nulls, to_f64_vec_without_nulls, IndexOrder}; 2 
| use faer::{ 3 | dyn_stack::{MemBuffer, MemStack}, 4 | linalg::svd::ComputeSvdVectors, 5 | prelude::*, 6 | }; 7 | use polars::prelude::*; 8 | use pyo3_polars::derive::polars_expr; 9 | 10 | pub fn singular_values_output(_: &[Field]) -> PolarsResult { 11 | Ok(Field::new( 12 | "singular_values".into(), 13 | DataType::List(Box::new(DataType::Float64)), 14 | )) 15 | } 16 | 17 | pub fn pca_output(_: &[Field]) -> PolarsResult { 18 | let singular_value = Field::new("singular_value".into(), DataType::Float64); 19 | let weights = Field::new( 20 | "weight_vector".into(), 21 | DataType::List(Box::new(DataType::Float64)), 22 | ); 23 | Ok(Field::new( 24 | "pca".into(), 25 | DataType::Struct(vec![singular_value, weights]), 26 | )) 27 | } 28 | 29 | pub fn principal_components_output(fields: &[Field]) -> PolarsResult { 30 | let components: Vec<_> = (0..fields.len()) 31 | .map(|i| Field::new(format!("pc{}", i + 1).into(), DataType::Float64)) 32 | .collect(); 33 | Ok(Field::new( 34 | "principal_components".into(), 35 | DataType::Struct(components), 36 | )) 37 | } 38 | 39 | #[polars_expr(output_type_func=singular_values_output)] 40 | fn pl_singular_values(inputs: &[Series]) -> PolarsResult { 41 | let nrows = inputs[0].len(); 42 | let mat_slice = to_f64_vec_without_nulls(inputs, IndexOrder::Fortran)?; 43 | let mat = MatRef::from_column_major_slice(&mat_slice, nrows, mat_slice.len() / nrows); 44 | 45 | let (m, n) = mat.shape(); 46 | let compute = ComputeSvdVectors::Thin; 47 | 48 | let dim = Ord::min(mat.nrows(), mat.ncols()); 49 | 50 | let mut s = vec![0f64; dim]; 51 | let cs = ColMut::from_slice_mut(&mut s); 52 | 53 | let par = Par::rayon(0); 54 | 55 | faer::linalg::svd::svd( 56 | mat, 57 | cs.as_diagonal_mut(), 58 | None, 59 | None, 60 | par, 61 | MemStack::new(&mut MemBuffer::new(faer::linalg::svd::svd_scratch::( 62 | m, 63 | n, 64 | compute, 65 | compute, 66 | par, 67 | default(), 68 | ))), 69 | Default::default(), 70 | ) 71 | .map_err(|_| PolarsError::ComputeError("SVD algorithm did not converge.".into()))?; 72 | 73 | let mut list_builder: ListPrimitiveChunkedBuilder = 74 | ListPrimitiveChunkedBuilder::new("singular_values".into(), 1, dim, DataType::Float64); 75 | 76 | list_builder.append_slice(&s); 77 | let out = list_builder.finish(); 78 | Ok(out.into_series()) 79 | } 80 | 81 | #[polars_expr(output_type_func=principal_components_output)] 82 | fn pl_principal_components(inputs: &[Series]) -> PolarsResult { 83 | let k = inputs[0].u32()?; 84 | let k = k.get(0).unwrap() as usize; 85 | 86 | let nrows = inputs[1].len(); 87 | let mat_slice = to_f64_vec_fail_on_nulls(&inputs[1..], IndexOrder::Fortran)?; 88 | let mat = MatRef::from_column_major_slice(&mat_slice, nrows, mat_slice.len() / nrows); 89 | 90 | let columns = if nrows < k { 91 | (0..k) 92 | .map(|i| Series::from_vec(format!("pc{}", i + 1).into(), vec![f64::NAN]).into_column()) 93 | .collect::>() 94 | } else { 95 | let (m, n) = mat.shape(); 96 | let compute = ComputeSvdVectors::Thin; 97 | let dim = Ord::min(mat.nrows(), mat.ncols()); 98 | let mut s = Col::::zeros(dim); 99 | let mut v = Mat::::zeros(dim, dim); 100 | let par = Par::rayon(0); 101 | faer::linalg::svd::svd( 102 | mat, 103 | s.as_diagonal_mut(), 104 | None, 105 | Some(v.as_mut()), 106 | par, 107 | MemStack::new(&mut MemBuffer::new(faer::linalg::svd::svd_scratch::( 108 | m, 109 | n, 110 | compute, 111 | compute, 112 | par, 113 | default(), 114 | ))), 115 | Default::default(), 116 | ) 117 | .map_err(|_| PolarsError::ComputeError("SVD algorithm did not converge.".into()))?; 118 | 119 | let 
components = mat * v; 120 | 121 | (0..k) 122 | .map(|i| { 123 | let name = format!("pc{}", i + 1); 124 | let s = Float64Chunked::from_slice(name.into(), components.col_as_slice(i)); 125 | s.into_series().into_column() 126 | }) 127 | .collect::>() 128 | }; 129 | 130 | let ca = StructChunked::from_columns("principal_components".into(), nrows, &columns)?; 131 | Ok(ca.into_series()) 132 | } 133 | 134 | #[polars_expr(output_type_func=pca_output)] 135 | fn pl_pca(inputs: &[Series]) -> PolarsResult { 136 | let nrows = inputs[0].len(); 137 | let mat_slice = to_f64_vec_without_nulls(inputs, IndexOrder::Fortran)?; 138 | let mat = MatRef::from_column_major_slice(&mat_slice, nrows, mat_slice.len() / nrows); 139 | 140 | let (m, n) = mat.shape(); 141 | let dim = Ord::min(mat.nrows(), mat.ncols()); 142 | let mut s = vec![0f64; dim]; 143 | let cs = ColMut::from_slice_mut(&mut s); 144 | let mut v = Mat::::zeros(dim, dim); 145 | let par = Par::rayon(0); 146 | let compute = ComputeSvdVectors::Thin; 147 | 148 | faer::linalg::svd::svd( 149 | mat, 150 | cs.as_diagonal_mut(), 151 | None, 152 | Some(v.as_mut()), 153 | par, 154 | MemStack::new(&mut MemBuffer::new(faer::linalg::svd::svd_scratch::( 155 | m, 156 | n, 157 | compute, 158 | compute, 159 | par, 160 | default(), 161 | ))), 162 | Default::default(), 163 | ) 164 | .map_err(|_| PolarsError::ComputeError("SVD algorithm did not converge.".into()))?; 165 | 166 | let mut builder: PrimitiveChunkedBuilder = 167 | PrimitiveChunkedBuilder::new("singular_value".into(), dim); 168 | let mut list_builder: ListPrimitiveChunkedBuilder = 169 | ListPrimitiveChunkedBuilder::new("weight_vector".into(), dim, dim, DataType::Float64); 170 | 171 | for i in 0..v.nrows() { 172 | builder.append_value(s[i]); 173 | list_builder.append_slice(v.col_as_slice(i)); 174 | } 175 | 176 | let out1 = builder.finish(); 177 | let out2 = list_builder.finish(); 178 | let out = StructChunked::from_series( 179 | "pca".into(), 180 | out1.len(), 181 | [&out1.into_series(), &out2.into_series()].into_iter(), 182 | )?; 183 | Ok(out.into_series()) 184 | } 185 | -------------------------------------------------------------------------------- /src/num_ext/psi.rs: -------------------------------------------------------------------------------- 1 | use ordered_float::OrderedFloat; 2 | use polars::prelude::*; 3 | use pyo3_polars::derive::polars_expr; 4 | 5 | fn psi_report_output(_: &[Field]) -> PolarsResult { 6 | let breakpoints = Field::new("<=".into(), DataType::Float64); 7 | let baseline_pct = Field::new("baseline_pct".into(), DataType::Float64); 8 | let actual_pct = Field::new("actual_pct".into(), DataType::Float64); 9 | let psi_bins = Field::new("psi_bin".into(), DataType::Float64); 10 | let v: Vec = vec![breakpoints, baseline_pct, actual_pct, psi_bins]; 11 | Ok(Field::new("psi_report".into(), DataType::Struct(v))) 12 | } 13 | 14 | /// Computes counts in each bucket given by the breakpoints in 15 | /// a PSI computation. This returns the count for the first series 16 | /// and the count for the second series. 
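///
/// A rough sketch of the bucketing with hypothetical values (the last breakpoint is
/// assumed to be +INF, as noted below):
///
/// ```ignore
/// let counts = psi_with_bps_helper(&[5.0, 12.0, 25.0], &[10.0, 20.0, f64::INFINITY]);
/// assert_eq!(counts, vec![1, 1, 1]); // one value falls in each "<=" bucket
/// ```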
17 | /// This assumes the breakpoints (bp)'s last value is always INF 18 | #[inline(always)] 19 | fn psi_with_bps_helper(s: &[f64], bp: &[f64]) -> Vec { 20 | // s: data 21 | // bp: breakpoints 22 | 23 | // safe, data at this stage is gauranteed to be finite 24 | let s = unsafe { std::mem::transmute::<&[f64], &[OrderedFloat]>(s) }; 25 | 26 | let bp = unsafe { std::mem::transmute::<&[f64], &[OrderedFloat]>(bp) }; 27 | 28 | let mut c = vec![0u32; bp.len()]; 29 | for x in s { 30 | let i = match bp.binary_search(x) { 31 | Ok(j) => j, 32 | Err(k) => k, 33 | }; 34 | c[i] += 1; 35 | } 36 | c 37 | } 38 | 39 | /// Helper function to create PSI reports for numeric PSI computations. 40 | #[inline(always)] 41 | fn psi_frame(bp: &[f64], bp_name: &str, cnt1: &[u32], cnt2: &[u32]) -> PolarsResult { 42 | let b = Float64Chunked::from_slice("".into(), bp); 43 | let c1 = UInt32Chunked::from_slice("".into(), cnt1); 44 | let c2 = UInt32Chunked::from_slice("".into(), cnt2); 45 | 46 | let df = df!( 47 | bp_name => b, 48 | "cnt_baseline" => c1, 49 | "cnt_actual" => c2, 50 | )? 51 | .lazy(); 52 | 53 | Ok(df 54 | .with_columns([ 55 | (col("cnt_baseline").cast(DataType::Float64) 56 | / col("cnt_baseline").sum().cast(DataType::Float64)) 57 | .clip_min(lit(0.0001)) 58 | .alias("baseline_pct"), 59 | (col("cnt_actual").cast(DataType::Float64) 60 | / col("cnt_actual").sum().cast(DataType::Float64)) 61 | .clip_min(lit(0.0001)) 62 | .alias("actual_pct"), 63 | ]) 64 | .select([ 65 | col(bp_name), 66 | col("baseline_pct"), 67 | col("actual_pct"), 68 | ((col("baseline_pct") - col("actual_pct")) 69 | * ((col("baseline_pct") / col("actual_pct")).log(std::f64::consts::E))) 70 | .alias("psi_bin"), 71 | ])) 72 | } 73 | 74 | /// Computs PSI with custom breakpoints and returns a report 75 | #[polars_expr(output_type_func=psi_report_output)] 76 | fn pl_psi_w_bps(inputs: &[Series]) -> PolarsResult { 77 | let data1 = inputs[0].f64().unwrap(); 78 | let data2 = inputs[1].f64().unwrap(); 79 | let breakpoints = inputs[2].f64().unwrap(); 80 | 81 | let binding = data1.rechunk(); 82 | let s1 = binding.cont_slice().unwrap(); 83 | let binding = data2.rechunk(); 84 | let s2 = binding.cont_slice().unwrap(); 85 | 86 | let binding = breakpoints.rechunk(); 87 | let bp = binding.cont_slice().unwrap(); 88 | 89 | let c1 = psi_with_bps_helper(s1, bp); 90 | let c2 = psi_with_bps_helper(s2, bp); 91 | 92 | let psi_report = psi_frame(bp, "<=", &c1, &c2)?.collect()?; 93 | Ok(psi_report.into_struct("".into()).into_series()) 94 | } 95 | 96 | /// Numeric PSI report 97 | #[polars_expr(output_type_func=psi_report_output)] 98 | fn pl_psi_report(inputs: &[Series]) -> PolarsResult { 99 | // The new data 100 | let new = inputs[0].f64().unwrap(); 101 | // The breaks learned from baseline/reference 102 | let brk = inputs[1].f64().unwrap(); 103 | // The cnts for the baseline/reference 104 | let cnt = inputs[2].u32().unwrap(); 105 | 106 | let binding = new.rechunk(); 107 | let data_new = binding.cont_slice().unwrap(); 108 | let binding = brk.rechunk(); 109 | let ref_brk = binding.cont_slice().unwrap(); 110 | let binding = cnt.rechunk(); 111 | let ref_cnt = binding.cont_slice().unwrap(); 112 | 113 | let new_cnt = psi_with_bps_helper(data_new, ref_brk); 114 | let psi_report = psi_frame(ref_brk, "<=", ref_cnt, &new_cnt)?.collect()?; 115 | 116 | Ok(psi_report.into_struct("".into()).into_series()) 117 | } 118 | 119 | /// Discrete PSI report 120 | #[polars_expr(output_type_func=psi_report_output)] 121 | fn pl_psi_discrete_report(inputs: &[Series]) -> PolarsResult { 122 | let 
df1 = df!( 123 | "actual_cat" => &inputs[0], // data cats 124 | "actual_cnt" => &inputs[1], // data cnt 125 | )?; 126 | let df2 = df!( 127 | "baseline_cat" => &inputs[2], // ref cats 128 | "baseline_cnt" => &inputs[3], // ref cnt 129 | )?; 130 | 131 | let psi_report = df1 132 | .lazy() 133 | .join( 134 | df2.lazy(), 135 | [col("actual_cat")], 136 | [col("baseline_cat")], 137 | JoinArgs::new(JoinType::Full), 138 | ) 139 | .with_columns([ 140 | col("baseline_cnt").fill_null(0), 141 | col("actual_cnt").fill_null(0), 142 | ]) 143 | .with_columns([ 144 | (col("baseline_cnt").cast(DataType::Float64) 145 | / col("baseline_cnt").sum().cast(DataType::Float64)) 146 | .clip_min(lit(0.0001)) 147 | .alias("baseline_pct"), 148 | (col("actual_cnt").cast(DataType::Float64) 149 | / col("actual_cnt").sum().cast(DataType::Float64)) 150 | .clip_min(lit(0.0001)) 151 | .alias("actual_pct"), 152 | ]) 153 | .select([ 154 | col("baseline_cat").alias("baseline_category"), 155 | col("actual_cat").alias("actual_category"), 156 | col("baseline_pct"), 157 | col("actual_pct"), 158 | ((col("baseline_pct") - col("actual_pct")) 159 | * ((col("baseline_pct") / col("actual_pct")).log(std::f64::consts::E))) 160 | .alias("psi_bin"), 161 | ]) 162 | .collect()?; 163 | 164 | Ok(psi_report.into_struct("".into()).into_series()) 165 | } 166 | -------------------------------------------------------------------------------- /src/num_ext/subseq_sim.rs: -------------------------------------------------------------------------------- 1 | /// Subsequence similarity related queries 2 | use cfavml; 3 | use polars::prelude::*; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | utils::rayon::{ 8 | iter::{IntoParallelIterator, ParallelIterator}, 9 | slice::ParallelSlice, 10 | }, 11 | POOL, 12 | }, 13 | }; 14 | use serde::Deserialize; 15 | 16 | #[derive(Deserialize, Debug)] 17 | pub(crate) struct SubseqQueryKwargs { 18 | pub(crate) threshold: f64, 19 | pub(crate) parallel: bool, 20 | } 21 | 22 | #[polars_expr(output_type=UInt32)] 23 | fn pl_subseq_sim_cnt_l2( 24 | inputs: &[Series], 25 | context: CallerContext, 26 | kwargs: SubseqQueryKwargs, 27 | ) -> PolarsResult { 28 | let seq = inputs[0].f64()?; 29 | let seq = seq.cont_slice().unwrap(); 30 | let query = inputs[1].f64()?; 31 | let query = query.cont_slice().unwrap(); 32 | 33 | if query.len() > seq.len() { 34 | return Err(PolarsError::ComputeError( 35 | "Not enough data points for the query.".into(), 36 | )); 37 | } 38 | 39 | let threshold = kwargs.threshold; 40 | let par = kwargs.parallel && !context.parallel(); 41 | let window_size = query.len(); 42 | 43 | let n = if par { 44 | seq.par_windows(window_size) 45 | .map(|w| (cfavml::squared_euclidean(query, w) < threshold) as u32) 46 | .sum() 47 | } else { 48 | if window_size < 16 { 49 | seq.windows(window_size).fold(0u32, |acc, w| { 50 | let d = w 51 | .into_iter() 52 | .copied() 53 | .zip(query.into_iter().copied()) 54 | .fold(0., |acc, (x, y)| acc + (x - y) * (x - y)); 55 | acc + (d < threshold) as u32 56 | }) 57 | } else { 58 | seq.windows(window_size).fold(0u32, |acc, w| { 59 | acc + (cfavml::squared_euclidean(query, w) < threshold) as u32 60 | }) 61 | } 62 | }; 63 | 64 | let output = UInt32Chunked::from_slice("".into(), &[n]); 65 | Ok(output.into_series()) 66 | } 67 | 68 | #[polars_expr(output_type=UInt32)] 69 | fn pl_subseq_sim_cnt_zl2( 70 | inputs: &[Series], 71 | context: CallerContext, 72 | kwargs: SubseqQueryKwargs, 73 | ) -> PolarsResult { 74 | let seq = inputs[0].f64()?; 75 | let seq = 
seq.cont_slice().unwrap(); 76 | let query = inputs[1].f64()?; // is already z normalized 77 | let query = query.cont_slice().unwrap(); 78 | 79 | let rolling_mean = inputs[2].f64()?; 80 | let rolling_mean = rolling_mean.cont_slice()?; 81 | let rolling_var = inputs[3].f64()?; 82 | let rolling_var = rolling_var.cont_slice()?; 83 | 84 | let threshold = kwargs.threshold; 85 | let par = kwargs.parallel && !context.parallel(); 86 | let window_size = query.len(); 87 | 88 | let total_windows = seq.len() + 1 - window_size; 89 | 90 | let n = if par { 91 | let n_threads = POOL.current_num_threads(); 92 | let windows = seq.windows(window_size).collect::>(); 93 | let splits = crate::utils::split_offsets(total_windows, n_threads); 94 | splits 95 | .into_par_iter() 96 | .map(|(offset, len)| { 97 | let mut acc: u32 = 0; 98 | for (i, w) in windows[offset..offset + len].iter().enumerate() { 99 | let actual_i = i + offset; 100 | let normalized = w 101 | .iter() 102 | .map(|x| (x - rolling_mean[actual_i]) / rolling_var[actual_i].sqrt()) 103 | .collect::>(); 104 | acc += (cfavml::squared_euclidean(query, &normalized) < threshold) as u32; 105 | } 106 | acc 107 | }) 108 | .sum() 109 | } else { 110 | seq.windows(window_size) 111 | .enumerate() 112 | .fold(0u32, |acc, (i, w)| { 113 | let normalized = w 114 | .iter() 115 | .map(|x| (x - rolling_mean[i]) / rolling_var[i].sqrt()) 116 | .collect::>(); 117 | acc + (cfavml::squared_euclidean(query, &normalized) < threshold) as u32 118 | }) 119 | }; 120 | 121 | let output = UInt32Chunked::from_slice("".into(), &[n]); 122 | Ok(output.into_series()) 123 | } 124 | -------------------------------------------------------------------------------- /src/num_ext/target_encode.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::*; 2 | use pyo3_polars::derive::polars_expr; 3 | use serde::Deserialize; 4 | 5 | fn target_encode_output(_: &[Field]) -> PolarsResult { 6 | let values = Field::new("value".into(), DataType::String); 7 | let to = Field::new("to".into(), DataType::Float64); 8 | let v: Vec = vec![values, to]; 9 | Ok(Field::new("target_encoded".into(), DataType::Struct(v))) 10 | } 11 | 12 | #[derive(Deserialize, Debug)] 13 | pub(crate) struct TargetEncodeKwargs { 14 | pub(crate) min_samples_leaf: f64, 15 | pub(crate) smoothing: f64, 16 | } 17 | 18 | #[inline(always)] 19 | fn get_target_encode_frame( 20 | discrete_col: &Series, 21 | target: &Series, 22 | target_mean: f64, 23 | min_samples_leaf: f64, 24 | smoothing: f64, 25 | ) -> PolarsResult { 26 | let df = df!( 27 | "value" => discrete_col.cast(&DataType::String)?, 28 | "target" => target 29 | )?; 30 | 31 | Ok(df 32 | .lazy() 33 | .drop_nulls(None) 34 | .group_by([col("value")]) 35 | .agg([len().alias("cnt"), col("target").mean().alias("cond_p")]) 36 | .with_column( 37 | (lit(1f64) 38 | / (lit(1f64) 39 | + ((-(col("cnt").cast(DataType::Float64) - lit(min_samples_leaf)) 40 | / lit(smoothing)) 41 | .exp()))) 42 | .alias("alpha"), 43 | ) 44 | .select([ 45 | col("value"), 46 | (col("alpha") * col("cond_p") + (lit(1f64) - col("alpha")) * lit(target_mean)) 47 | .alias("to"), 48 | ])) 49 | } 50 | 51 | #[polars_expr(output_type_func=target_encode_output)] 52 | fn pl_target_encode(inputs: &[Series], kwargs: TargetEncodeKwargs) -> PolarsResult { 53 | // Inputs[0] and inputs[1] are the string column and the target respectively 54 | 55 | let target_mean = inputs[2].f64()?; 56 | let target_mean = target_mean.get(0).unwrap(); 57 | 58 | let min_samples_leaf = kwargs.min_samples_leaf; 
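    // A quick note on the smoothing (computed in `get_target_encode_frame` above): each
    // category with `cnt` rows and conditional target mean `cond_p` is encoded as
    //     alpha * cond_p + (1 - alpha) * target_mean,
    //     alpha = 1 / (1 + exp(-(cnt - min_samples_leaf) / smoothing)),
    // so sparse categories are shrunk toward the global target mean.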
59 | let smoothing = kwargs.smoothing; 60 | 61 | let encoding_frame = get_target_encode_frame( 62 | &inputs[0], 63 | &inputs[1], 64 | target_mean, 65 | min_samples_leaf, 66 | smoothing, 67 | )? 68 | .collect()?; 69 | 70 | Ok(encoding_frame 71 | .into_struct("target_encoded".into()) 72 | .into_series()) 73 | } 74 | -------------------------------------------------------------------------------- /src/num_ext/trapz.rs: -------------------------------------------------------------------------------- 1 | /// Integration via Trapezoidal rule. 2 | use cfavml; 3 | use polars::prelude::*; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | #[inline(always)] 7 | pub fn trapz(y: &[f64], x: &[f64]) -> f64 { 8 | // x.len() == y.len() checked 9 | if x.len() == 1 && y.len() == 1 { 10 | y[0] * x[0] * -0.5 // y[0] * (-x[0]) * 0.5 11 | } else { 12 | let mut y_d = vec![0.; y.len() - 1]; 13 | cfavml::add_vector(&y[1..], &y[..y.len() - 1], &mut y_d); 14 | let mut x_d = vec![0.; y.len() - 1]; 15 | cfavml::sub_vector(&x[1..], &x[..x.len() - 1], &mut x_d); 16 | 0.5 * cfavml::dot(&y_d, &x_d) 17 | } 18 | } 19 | 20 | pub fn trapz_dx(y: &[f64], dx: f64) -> f64 { 21 | let s = y[1..y.len() - 1].iter().sum::(); 22 | let ss = 0.5 * (y.get(0).unwrap_or(&0.) + y.last().unwrap_or(&0.)); 23 | dx * (s + ss) 24 | } 25 | 26 | #[polars_expr(output_type=Float64)] 27 | fn pl_trapz(inputs: &[Series]) -> PolarsResult { 28 | let y = inputs[0].f64()?; 29 | let x = inputs[1].f64()?; 30 | if y.len() < 1 || x.has_nulls() || y.has_nulls() { 31 | let ca = Float64Chunked::from_slice("".into(), &[f64::NAN]); 32 | return Ok(ca.into_series()); 33 | } 34 | 35 | let y = y.cont_slice()?; 36 | if x.len() == 1 && y.len() > 1 { 37 | let dx = x.get(0).unwrap(); 38 | let ca = Float64Chunked::from_slice("".into(), &[trapz_dx(y, dx)]); 39 | Ok(ca.into_series()) 40 | } else if x.len() == y.len() { 41 | let x = x.cont_slice()?; 42 | let ca = Float64Chunked::from_slice("".into(), &[trapz(y, x)]); 43 | Ok(ca.into_series()) 44 | } else { 45 | Err(PolarsError::ComputeError( 46 | "Input must have the same length or x must be a scalar.".into(), 47 | )) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/num_ext/welch.rs: -------------------------------------------------------------------------------- 1 | // use num::{complex::Complex64, Zero}; 2 | // use polars::prelude::*; 3 | // use pyo3_polars::derive::polars_expr; 4 | // use rustfft::FftPlanner; 5 | // use serde::Deserialize; 6 | 7 | // #[derive(Deserialize, Debug)] 8 | // pub(crate) struct WelchKwargs { 9 | // pub(crate) window_size: usize, 10 | // pub(crate) overlap_size: usize, 11 | // } 12 | 13 | // fn welch(s:&[f64], window_length:usize, overlap_size:usize) -> f64 { 14 | 15 | // let mut left_idx: usize = 0; 16 | // let mut planner = FftPlanner::::new(); 17 | // let r2c = planner.plan_fft_forward(window_length); 18 | // let mut scratch = vec![Complex64::zero(); r2c.get_inplace_scratch_len()]; 19 | // let mut sums = 0f64; 20 | // let mut count = 0u32; 21 | // while left_idx + window_length <= s.len() { 22 | 23 | // let right_idx = left_idx + window_length; 24 | // let mut vec = s[left_idx..right_idx] 25 | // .iter() 26 | // .map(|x| x.into()) 27 | // .collect::>(); 28 | 29 | // let _ = r2c.process_with_scratch(&mut vec, &mut scratch); 30 | // sums += vec 31 | // .into_iter() 32 | // .fold(0f64, |acc, z| acc + z.re * z.re + z.im * z.im); 33 | 34 | // count += 1; 35 | // left_idx += overlap_size; 36 | // } 37 | 38 | // sums / count as f64 39 | // } 40 | 41 
| // #[polars_expr(output_type=Float64)] 42 | // fn pl_psd( 43 | // inputs: &[Series], 44 | // kwargs: WelchKwargs, 45 | // ) -> PolarsResult { 46 | 47 | // let s = inputs[0].f64()?; 48 | // let s = s.cont_slice()?; 49 | // let result = welch(s, kwargs.window_size, kwargs.overlap_size); 50 | // Ok(Series::from_iter([result])) 51 | // } 52 | -------------------------------------------------------------------------------- /src/num_ext/woe_iv.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::*; 2 | use pyo3_polars::derive::polars_expr; 3 | 4 | fn woe_output(_: &[Field]) -> PolarsResult { 5 | let value = Field::new("value".into(), DataType::String); 6 | let woe: Field = Field::new("woe".into(), DataType::Float64); 7 | let v: Vec = vec![value, woe]; 8 | Ok(Field::new("woe_output".into(), DataType::Struct(v))) 9 | } 10 | 11 | fn iv_output(_: &[Field]) -> PolarsResult { 12 | let value = Field::new("value".into(), DataType::String); 13 | let iv: Field = Field::new("iv".into(), DataType::Float64); 14 | let v: Vec = vec![value, iv]; 15 | Ok(Field::new("iv_output".into(), DataType::Struct(v))) 16 | } 17 | 18 | /// Get a lazyframe needed to compute Weight Of Evidence. 19 | /// Inputs[0] by default is the discrete bins / categories (cast to String at Python side) 20 | /// Inputs[1] by default is the target (0s and 1s) 21 | /// Nulls will be droppped 22 | fn get_woe_frame(discrete_col: &Series, target: &Series) -> PolarsResult { 23 | let df = df!( 24 | "value" => discrete_col, 25 | "target" => target, 26 | )?; 27 | // Here we are adding 1 to make sure the event/non-event (goods/bads) are nonzero, 28 | // so that the computation will not yield inf as output. 29 | let out = df 30 | .lazy() 31 | .drop_nulls(None) 32 | .group_by([col("value")]) 33 | .agg([ 34 | len().cast(DataType::Float64).alias("cnt"), 35 | col("target").sum().cast(DataType::Float64).alias("goods"), 36 | ]) 37 | .select([ 38 | col("value"), 39 | ((col("goods") + lit(1f64)) / (col("goods").sum() + lit(2f64))).alias("good_pct"), 40 | ((col("cnt") - col("goods") + lit(1f64)) 41 | / (col("cnt").sum() - col("goods").sum() + lit(2f64))) 42 | .alias("bad_pct"), 43 | ]) 44 | .with_column( 45 | (col("good_pct") / col("bad_pct")) 46 | .log(std::f64::consts::E) 47 | .alias("woe"), 48 | ); 49 | Ok(out) 50 | } 51 | 52 | /// WOE for each bin/category 53 | #[polars_expr(output_type_func=woe_output)] 54 | fn pl_woe_discrete(inputs: &[Series]) -> PolarsResult { 55 | let df = get_woe_frame(&inputs[0], &inputs[1])? 56 | .select([col("value"), col("woe")]) 57 | .collect()?; 58 | 59 | Ok(df.into_struct("woe_output".into()).into_series()) 60 | } 61 | 62 | /// Information Value for each bin/category 63 | /// The information value for this column/feature will be the sum. 64 | #[polars_expr(output_type_func=iv_output)] 65 | fn pl_iv(inputs: &[Series]) -> PolarsResult { 66 | let df = get_woe_frame(&inputs[0], &inputs[1])? 
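        // Per-bin information value is (good_pct - bad_pct) * woe; summing the "iv"
        // column over all bins gives the feature-level IV. The +1 / +2 terms in
        // `get_woe_frame` keep the ratio finite when a bin has no events or non-events.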
67 | .select([ 68 | col("value"), 69 | ((col("good_pct") - col("bad_pct")) * col("woe")).alias("iv"), 70 | ]) 71 | .collect()?; 72 | 73 | Ok(df.into_struct("iv_output".into()).into_series()) 74 | } 75 | -------------------------------------------------------------------------------- /src/pymodels/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod py_kdt; 2 | pub mod py_lr; 3 | -------------------------------------------------------------------------------- /src/pymodels/py_glm.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abstractqqq/polars_ds_extension/31e44245c214ad1b464c3458de0960f61a076145/src/pymodels/py_glm.rs -------------------------------------------------------------------------------- /src/stats/chi2.rs: -------------------------------------------------------------------------------- 1 | use super::{generic_stats_output, simple_stats_output}; 2 | use crate::stats_utils::gamma; 3 | use polars::prelude::*; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | fn chi2_full_output(fields: &[Field]) -> PolarsResult { 7 | let s = Field::new("statistic".into(), DataType::Float64); 8 | let p = Field::new("pvalue".into(), DataType::Float64); 9 | let dof = Field::new("dof".into(), DataType::UInt32); 10 | let f1 = fields[0].clone(); 11 | let f2 = fields[1].clone(); 12 | let ef = Field::new("E[freq]".into(), DataType::Float64); 13 | let v: Vec = vec![s, p, dof, f1, f2, ef]; 14 | Ok(Field::new("chi2_full".into(), DataType::Struct(v))) 15 | } 16 | 17 | fn _chi2_helper(inputs: &[Series]) -> PolarsResult<(LazyFrame, usize, usize)> { 18 | // Return a df with necessary values to compute chi2, together 19 | // with nrows and ncols 20 | let s1_name = "s1"; 21 | let s2_name = "s2"; 22 | 23 | let u1 = inputs[0].unique()?; 24 | let u1_len = u1.len(); 25 | let u2 = inputs[1].unique()?; 26 | let u2_len = u2.len(); 27 | // Get the cartesian product 28 | let df1 = df!(s1_name => u1)?.lazy(); 29 | let df2 = df!(s2_name => u2)?.lazy(); 30 | let cross = df1.cross_join(df2, None); 31 | 32 | // Create a "fake" contigency table 33 | let s1 = inputs[0].clone(); 34 | let s2 = inputs[1].clone(); 35 | let df3 = df!(s1_name => s1, s2_name => s2)? 36 | .lazy() 37 | .group_by([col(s1_name), col(s2_name)]) 38 | .agg([len().cast(DataType::UInt64).alias("ob")]); 39 | 40 | let df4 = cross 41 | .join( 42 | df3, 43 | [col(s1_name), col(s2_name)], 44 | [col(s1_name), col(s2_name)], 45 | JoinArgs::new(JoinType::Left), 46 | ) 47 | .with_column(col("ob").fill_null(0)); 48 | 49 | // Compute the statistic 50 | let frame = df4.with_columns([((col("ob").sum().over([s2_name]) 51 | * col("ob").sum().over([s1_name])) 52 | .cast(DataType::Float64) 53 | / col("ob").sum().cast(DataType::Float64)) 54 | .alias("ex")]); 55 | 56 | Ok((frame, u1_len, u2_len)) 57 | } 58 | 59 | fn _chi2_pvalue(stats: f64, dof: usize) -> PolarsResult { 60 | // The p value for chi2 61 | let p = if stats.is_nan() { 62 | f64::NAN 63 | } else { 64 | let (shape, rate) = (dof as f64 / 2., 0.5); 65 | let p = gamma::sf(stats, shape, rate).map_err(|e| PolarsError::ComputeError(e.into())); 66 | p? 
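        // Chi-square with `dof` degrees of freedom is Gamma(shape = dof / 2, rate = 0.5),
        // so the Gamma survival function above is the chi-square p-value.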
67 | }; 68 | Ok(p) 69 | } 70 | 71 | #[polars_expr(output_type_func=simple_stats_output)] 72 | fn pl_chi2(inputs: &[Series]) -> PolarsResult { 73 | let (df, u1_len, u2_len) = _chi2_helper(inputs)?; 74 | 75 | let mut final_df = df 76 | .select([ 77 | ((col("ob").cast(DataType::Float64) - col("ex")).pow(2) / col("ex")) 78 | .sum() 79 | .alias("output"), 80 | ]) 81 | .collect()?; 82 | 83 | // Get the statistic 84 | let out = final_df.drop_in_place("output").unwrap(); 85 | let stats = out.f64()?; 86 | let stats = stats.get(0).unwrap_or(f64::NAN); 87 | // Compute p value. It is a special case of Gamma distribution 88 | let dof = u1_len.abs_diff(1) * u2_len.abs_diff(1); 89 | let p = _chi2_pvalue(stats, dof)?; 90 | generic_stats_output(stats, p) 91 | } 92 | 93 | #[polars_expr(output_type_func=chi2_full_output)] 94 | fn pl_chi2_full(inputs: &[Series]) -> PolarsResult { 95 | let s1_name = inputs[0].name(); 96 | let s2_name = inputs[1].name(); 97 | 98 | let (df, u1_len, u2_len) = _chi2_helper(inputs)?; 99 | // cheap clone 100 | let mut df2 = df 101 | .clone() 102 | .select([ 103 | col("s1").alias(s1_name.clone()), 104 | col("s2").alias(s2_name.clone()), 105 | col("ex").alias("E[freq]"), 106 | ]) 107 | .collect()?; 108 | let ef = df2.drop_in_place("E[freq]").unwrap(); 109 | let s1 = df2.drop_in_place(s1_name).unwrap(); 110 | let s2 = df2.drop_in_place(s2_name).unwrap(); 111 | 112 | let mut final_df = df 113 | .select([ 114 | ((col("ob").cast(DataType::Float64) - col("ex")).pow(2) / col("ex")) 115 | .sum() 116 | .alias("output"), 117 | ]) 118 | .collect()?; 119 | 120 | // Get the statistic 121 | let out = final_df.drop_in_place("output").unwrap(); 122 | let stats = out.f64()?; 123 | let stats = stats.get(0).unwrap_or(f64::NAN); 124 | let dof = u1_len.abs_diff(1) * u2_len.abs_diff(1); 125 | let p = _chi2_pvalue(stats, dof)?; 126 | let stats_column = Column::new_scalar("statistic".into(), stats.into(), 1); 127 | let pval_column = Column::new_scalar("pvalue".into(), p.into(), 1); 128 | let dof_column = Column::new_scalar("dof".into(), (dof as u32).into(), 1); 129 | 130 | let ca = StructChunked::from_columns( 131 | "chi2_full".into(), 132 | ef.len(), 133 | &[stats_column, pval_column, dof_column, s1, s2, ef], 134 | )?; 135 | Ok(ca.into_series()) 136 | } 137 | -------------------------------------------------------------------------------- /src/stats/fstats.rs: -------------------------------------------------------------------------------- 1 | use super::simple_stats_output; 2 | use crate::stats_utils::beta::fisher_snedecor_sf; 3 | use crate::utils::{columns_to_vec, IndexOrder}; 4 | /// Multiple F-statistics at once and F test 5 | use core::f64; 6 | use itertools::Itertools; 7 | use polars::frame::column::Column; 8 | use polars::prelude::*; 9 | use pyo3_polars::derive::polars_expr; 10 | 11 | /// Use inputs[0] as the target column (discrete, indicating the groups) 12 | /// and inputs[i] as the column to run F-test against the target, i > 0. 13 | #[polars_expr(output_type_func=simple_stats_output)] 14 | fn pl_f_test(inputs: &[Series]) -> PolarsResult { 15 | // Use a df to make the computations parallel. 
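    // Roughly, for each feature this assembles the one-way ANOVA statistic
    //     F = (between-class SS / (n_classes - 1)) / (within-class SS / (n_samples - n_classes))
    // via the group-by aggregations below; the scaling by the degrees of freedom happens
    // after the frame is collected.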
16 | // Column at index 0 is the target column 17 | let v = inputs 18 | .into_iter() 19 | .enumerate() 20 | .map(|(i, s)| Column::new(i.to_string().into(), s)) 21 | .collect_vec(); 22 | let n_cols = v.len(); 23 | 24 | let df = DataFrame::new(v)?.lazy(); 25 | // inputs[0] is the group 26 | // all the rest should numerical 27 | let mut step_one: Vec = Vec::with_capacity(inputs.len() * 2 - 1); 28 | step_one.push(len().cast(DataType::Float64).alias("cnt")); 29 | let mut step_two: Vec = Vec::with_capacity(inputs.len() + 1); 30 | step_two.push(col("cnt").sum().cast(DataType::UInt32).alias("n_samples")); 31 | step_two.push(col("0").count().cast(DataType::UInt32).alias("n_classes")); 32 | 33 | for i in 1..n_cols { 34 | let name = i.to_string(); 35 | let name = name.as_str(); 36 | let n_sum = format!("{}_sum", i); 37 | let n_sum = n_sum.as_str(); 38 | let n_var = format!("{}_var", i); 39 | let n_var = n_var.as_str(); 40 | step_one.push(col(name).sum().alias(n_sum)); 41 | step_one.push(col(name).var(0).alias(n_var)); 42 | let p1: Expr = (col(n_sum).cast(DataType::Float64) / col("cnt").cast(DataType::Float64) 43 | - (col(n_sum).sum().cast(DataType::Float64) 44 | / col("cnt").sum().cast(DataType::Float64))) 45 | .pow(2); 46 | let p2 = col(n_var).dot(col("cnt").cast(DataType::Float64)); 47 | 48 | step_two.push(p1.dot(col("cnt").cast(DataType::Float64)) / p2) 49 | } 50 | 51 | let mut reference = df 52 | .group_by([col("0")]) 53 | .agg(step_one) 54 | .select(step_two) 55 | .collect()?; 56 | 57 | let n_samples = reference.drop_in_place("n_samples").unwrap(); 58 | let n_classes = reference.drop_in_place("n_classes").unwrap(); 59 | let n_samples = n_samples.u32()?; 60 | let n_classes = n_classes.u32()?; 61 | let n_samples = n_samples.get(0).unwrap_or(0); 62 | let n_classes = n_classes.get(0).unwrap_or(0); 63 | 64 | if n_classes <= 1 || n_samples <= 1 { 65 | return Err(PolarsError::ComputeError( 66 | "F-stats: n_classes <= 1 in target or n_samples <= 1.".into(), 67 | )); 68 | } 69 | 70 | let df_btw_class = n_classes.abs_diff(1) as f64; 71 | let df_in_class = n_samples.abs_diff(n_classes) as f64; 72 | // Note: reference is a df with 1 row. We need to get the stats out 73 | // fstats is 2D but with 1 row. 74 | 75 | let scale = df_in_class / df_btw_class; 76 | 77 | let mut fstats = columns_to_vec::(reference.take_columns(), IndexOrder::C)?; 78 | fstats.iter_mut().for_each(|v| *v = *v * scale); 79 | 80 | let out_p: Vec = fstats 81 | .iter() 82 | .map(|x| fisher_snedecor_sf(*x, df_btw_class, df_in_class).unwrap_or(f64::NAN)) 83 | .collect(); 84 | 85 | let s1 = Column::new("statistic".into(), fstats); 86 | let s2 = Column::new("pvalue".into(), out_p); 87 | 88 | let ca = StructChunked::from_columns("f-test".into(), s1.len(), &[s1, s2])?; 89 | Ok(ca.into_series()) 90 | } 91 | -------------------------------------------------------------------------------- /src/stats/kendall_tau.rs: -------------------------------------------------------------------------------- 1 | /// O(nlogn) implementation of Kendall's Tau correlation 2 | /// Implemented by translating the Java code: 3 | /// https://www.hipparchus.org/xref/org/hipparchus/stat/correlation/KendallsCorrelation.html 4 | use polars::prelude::*; 5 | use pyo3_polars::derive::polars_expr; 6 | 7 | #[polars_expr(output_type=Float64)] 8 | pub fn pl_kendall_tau(inputs: &[Series]) -> PolarsResult { 9 | let name = inputs[0].name(); 10 | let mut binding = df!("x" => &inputs[0], "y" => &inputs[1])? 
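        // Tau-b with tie corrections: discordant pairs are counted as merge-sort swaps on y
        // (after sorting by x, then y), and the result is
        //     (n_pairs - tied_x - tied_y + tied_xy - 2 * swaps)
        //         / sqrt((n_pairs - tied_x) * (n_pairs - tied_y)),
        // following the Hipparchus translation referenced above.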
11 | .lazy() 12 | .filter(col("x").is_not_null().and(col("y").is_not_null())) 13 | .sort(["x", "y"], Default::default()) 14 | .collect()?; 15 | let df = binding.align_chunks(); 16 | 17 | let n = df.height(); 18 | if n <= 1 { 19 | return Ok(Series::from_vec(name.clone(), vec![f64::NAN])); 20 | } 21 | 22 | let n_pairs = ((n * (n - 1)) >> 1) as i64; 23 | 24 | let x = df.drop_in_place("x").unwrap(); 25 | let y = df.drop_in_place("y").unwrap(); 26 | let x = x.u32().unwrap(); 27 | let y = y.u32().unwrap(); 28 | let x = x.cont_slice().unwrap(); 29 | let y = y.cont_slice().unwrap(); 30 | 31 | let mut tied_x: i64 = 0; 32 | let mut tied_xy: i64 = 0; 33 | let mut consecutive_x_ties: i64 = 1; 34 | let mut consecutive_xy_ties: i64 = 1; 35 | let (mut xj, mut yj) = (x[0], y[0]); 36 | for i in 1..x.len() { 37 | // i current, j prev 38 | let (xi, yi) = (x[i], y[i]); 39 | if xi == xj { 40 | consecutive_x_ties += 1; 41 | if yi == yj { 42 | consecutive_xy_ties += 1; 43 | } else { 44 | tied_xy += (consecutive_xy_ties * (consecutive_xy_ties - 1)) >> 1; 45 | consecutive_xy_ties = 1; 46 | } 47 | } else { 48 | tied_x += (consecutive_x_ties * (consecutive_x_ties - 1)) >> 1; 49 | consecutive_x_ties = 1; 50 | tied_xy += (consecutive_xy_ties * (consecutive_xy_ties - 1)) >> 1; 51 | consecutive_xy_ties = 1; 52 | } 53 | xj = xi; 54 | yj = yi; 55 | } 56 | 57 | tied_x += (consecutive_x_ties * (consecutive_x_ties - 1)) >> 1; 58 | tied_xy += (consecutive_xy_ties * (consecutive_xy_ties - 1)) >> 1; 59 | 60 | let mut swaps: usize = 0; 61 | let mut xx = x.to_vec(); 62 | let mut yy = y.to_vec(); 63 | let mut x_copy = x.to_vec(); 64 | let mut y_copy = y.to_vec(); 65 | let mut seg_size: usize = 1; 66 | while seg_size < n { 67 | let mut offset: usize = 0; 68 | while offset < n { 69 | let mut i = offset; 70 | let i_end = (i + seg_size).min(n); 71 | let mut j = i_end; 72 | let j_end = (j + seg_size).min(n); 73 | 74 | let mut copy_loc = offset; 75 | while (i < i_end) || (j < j_end) { 76 | if i < i_end { 77 | if j < j_end { 78 | if yy[i] <= yy[j] { 79 | x_copy[copy_loc] = xx[i]; 80 | y_copy[copy_loc] = yy[i]; 81 | i += 1; 82 | } else { 83 | x_copy[copy_loc] = xx[j]; 84 | y_copy[copy_loc] = yy[j]; 85 | j += 1; 86 | swaps += i_end - i; 87 | } 88 | } else { 89 | x_copy[copy_loc] = xx[i]; 90 | y_copy[copy_loc] = yy[i]; 91 | i += 1; 92 | } 93 | } else { 94 | x_copy[copy_loc] = xx[j]; 95 | y_copy[copy_loc] = yy[j]; 96 | j += 1; 97 | } 98 | copy_loc += 1; 99 | } 100 | offset += seg_size << 1; // multiply by 2 101 | } 102 | std::mem::swap(&mut xx, &mut x_copy); 103 | std::mem::swap(&mut yy, &mut y_copy); 104 | seg_size <<= 1; 105 | } 106 | 107 | let mut tied_y: i64 = 0; 108 | let mut consecutive_y_ties: i64 = 1; 109 | let mut prev = yy[0]; 110 | for i in 1..n { 111 | if yy[i] == prev { 112 | consecutive_y_ties += 1; 113 | } else { 114 | tied_y += (consecutive_y_ties * (consecutive_y_ties - 1)) >> 1; 115 | consecutive_y_ties = 1; 116 | } 117 | prev = yy[i]; 118 | } 119 | tied_y += (consecutive_y_ties * (consecutive_y_ties - 1)) >> 1; 120 | 121 | let nc_m_nd = n_pairs - tied_x - tied_y + tied_xy - ((swaps << 1) as i64); 122 | // Prevent overflow 123 | let denom = (((n_pairs - tied_x) as f64) * ((n_pairs - tied_y) as f64)).sqrt(); 124 | let out = nc_m_nd as f64 / denom; 125 | Ok(Series::from_vec(name.clone(), vec![out])) 126 | } 127 | -------------------------------------------------------------------------------- /src/stats/ks.rs: -------------------------------------------------------------------------------- 1 | /// KS statistics. 
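/// The two-sample statistic is D = sup_x |F1(x) - F2(x)|, the largest gap between the two
/// empirical CDFs. In place of a p-value, `ks_2samp` returns the approximate two-sided
/// rejection threshold sqrt(-ln(alpha / 2) / 2 * (n1 + n2) / (n1 * n2)) alongside the
/// statistic, for the caller to compare against.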
2 | use super::{generic_stats_output, simple_stats_output}; 3 | use ordered_float::OrderedFloat; 4 | use polars::prelude::*; 5 | use pyo3_polars::derive::polars_expr; 6 | 7 | #[inline(always)] 8 | fn binary_search_right(arr: &[OrderedFloat], t: &OrderedFloat) -> usize { 9 | // Can likely get rid of the partial_cmp, because I have gauranteed the values to be finite 10 | let mut left = 0; 11 | let mut right = arr.len(); 12 | 13 | while left < right { 14 | let mid = left + ((right - left) >> 1); 15 | match arr[mid].cmp(t) { 16 | std::cmp::Ordering::Greater => right = mid, 17 | _ => left = mid + 1, 18 | } 19 | } 20 | left 21 | } 22 | 23 | /// Currently only supports two-sided. Won't be too hard to do add one-sided? I hope. 24 | /// Reference: 25 | /// https://github.com/scipy/scipy/blob/v1.11.3/scipy/stats/_stats_py.py#L8644-L8875 26 | /// Instead of returning a pvalue, the D_n_m quantity is returned, see 27 | /// https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test 28 | #[inline] 29 | fn ks_2samp(v1: &[f64], v2: &[f64], alpha: f64) -> (f64, f64) { 30 | // It is possible to not do binary search because v1 and v2 are already sorted. 31 | // But that makes the algorithm more complicated. 32 | 33 | let n1: f64 = v1.len() as f64; 34 | let n2: f64 = v2.len() as f64; 35 | 36 | let v1 = unsafe { std::mem::transmute::<&[f64], &[OrderedFloat]>(v1) }; 37 | 38 | let v2 = unsafe { std::mem::transmute::<&[f64], &[OrderedFloat]>(v2) }; 39 | 40 | // Follow SciPy's trick to compute the difference between two CDFs 41 | let stats = v1 42 | .iter() 43 | .chain(v2.iter()) 44 | .map(|x| { 45 | ( 46 | (binary_search_right(v1, x) as f64) / n1, 47 | (binary_search_right(v2, x) as f64) / n2, 48 | ) 49 | }) 50 | .fold(f64::MIN, |acc, (x, y)| acc.max((x - y).abs())); 51 | 52 | // This differs from SciPy, since I am assuming we are doing two-sided test 53 | 54 | let c_alpha = (-0.5 * (alpha / 2.0).ln()).abs(); 55 | let p_estimate = (c_alpha * (n1 + n2) / (n1 * n2)).sqrt(); 56 | (stats, p_estimate) 57 | } 58 | 59 | #[polars_expr(output_type_func=simple_stats_output)] 60 | fn pl_ks_2samp(inputs: &[Series]) -> PolarsResult { 61 | let s1 = inputs[0].f64()?; // input is sorted, one chunk, gauranteed by Python side code 62 | let s2 = inputs[1].f64()?; // input is sorted, one chunk, gauranteed by Python side code 63 | let alpha = inputs[2].f64()?; 64 | let alpha = alpha.get(0).unwrap(); 65 | 66 | if (s1.len() <= 30) || (s2.len() <= 30) { 67 | generic_stats_output(0f64, f64::NAN) 68 | } else { 69 | let v1 = s1.cont_slice().unwrap(); 70 | let v2 = s2.cont_slice().unwrap(); 71 | let (s, p) = ks_2samp(v1, v2, alpha); 72 | generic_stats_output(s, p) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/stats/mann_whitney_u.rs: -------------------------------------------------------------------------------- 1 | /// Mann-Whitney U Statistics 2 | use super::{generic_stats_output, simple_stats_output, Alternative}; 3 | use crate::stats_utils::normal; 4 | use polars::prelude::*; 5 | use pyo3_polars::derive::polars_expr; 6 | 7 | fn mann_whitney_tie_sum(ranks: &Float64Chunked) -> f64 { 8 | // NaN won't exist in ranks. 
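    // Accumulates the tie correction sum_i (t_i^3 - t_i) over runs of equal ranks
    // (written below as t * (t + 1) * (t - 1)); it feeds the variance of U in
    // `pl_mann_whitney_u`. Ranks are assumed to arrive pre-sorted so equal values are adjacent.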
9 | let mut rank_number = f64::NAN; 10 | let mut rank_cnt: f64 = 0f64; 11 | let mut accumulant = 0f64; 12 | for v in ranks.into_no_null_iter() { 13 | if v == rank_number { 14 | rank_cnt += 1.; 15 | } else { 16 | accumulant += rank_cnt * (rank_cnt + 1.0) * (rank_cnt - 1.0); 17 | rank_number = v; 18 | rank_cnt = 1.0; 19 | } 20 | } 21 | accumulant 22 | } 23 | 24 | #[polars_expr(output_type_func=simple_stats_output)] 25 | fn pl_mann_whitney_u(inputs: &[Series]) -> PolarsResult { 26 | // Reference: https://github.com/scipy/scipy/blob/v1.13.1/scipy/stats/_mannwhitneyu.py#L177 27 | 28 | let u1 = inputs[0].f64().unwrap(); 29 | let u1 = u1.get(0).unwrap(); 30 | 31 | let u2 = inputs[1].f64().unwrap(); 32 | let u2 = u2.get(0).unwrap(); 33 | 34 | let mean = inputs[2].f64().unwrap(); 35 | let mean = mean.get(0).unwrap(); 36 | 37 | // Custom RLE 38 | let sorted_ranks = inputs[3].f64().unwrap(); 39 | let n = sorted_ranks.len() as f64; 40 | let tie_term_sum = mann_whitney_tie_sum(sorted_ranks); 41 | let std_ = ((mean / 6.0) * ((n + 1.0) - tie_term_sum / (n * (n - 1.0)))).sqrt(); 42 | 43 | let alt = inputs[4].str()?; 44 | let alt = alt.get(0).unwrap(); 45 | let alt = Alternative::from(alt); 46 | 47 | let (u, factor) = match alt { 48 | // if I use min here, always wrong p value. But wikipedia says it is min. I wonder wtf.. 49 | Alternative::TwoSided => (u1.max(u2), 2.0), 50 | Alternative::Less => (u2, 1.0), 51 | Alternative::Greater => (u1, 1.0), 52 | }; 53 | 54 | let p = if std_ == 0. { 55 | 0. 56 | } else { 57 | // -0.5 is some continuity adjustment. See Scipy's impl 58 | (factor * normal::sf_unchecked(u, mean + 0.5, std_)).clamp(0., 1.) 59 | }; 60 | generic_stats_output(u1, p) 61 | } 62 | -------------------------------------------------------------------------------- /src/stats/mod.rs: -------------------------------------------------------------------------------- 1 | mod chi2; 2 | mod fstats; 3 | mod kendall_tau; 4 | mod ks; 5 | mod mann_whitney_u; 6 | mod normal_test; 7 | mod sample; 8 | mod t_test; 9 | mod xi_corr; 10 | 11 | use polars::prelude::*; 12 | 13 | pub fn simple_stats_output(_: &[Field]) -> PolarsResult { 14 | let s = Field::new("statistic".into(), DataType::Float64); 15 | let p = Field::new("pvalue".into(), DataType::Float64); 16 | let v: Vec = vec![s, p]; 17 | Ok(Field::new("".into(), DataType::Struct(v))) 18 | } 19 | 20 | pub enum Alternative { 21 | TwoSided, 22 | Less, 23 | Greater, 24 | } 25 | 26 | impl From<&str> for Alternative { 27 | fn from(s: &str) -> Alternative { 28 | match s.to_lowercase().as_str() { 29 | "two-sided" | "two" => Alternative::TwoSided, 30 | "less" => Alternative::Less, 31 | "greater" => Alternative::Greater, 32 | _ => Alternative::TwoSided, 33 | } 34 | } 35 | } 36 | 37 | #[inline] 38 | fn generic_stats_output(statistic: f64, pvalue: f64) -> PolarsResult { 39 | let s = Series::from_vec("statistic".into(), vec![statistic]); 40 | let p = Series::from_vec("pvalue".into(), vec![pvalue]); 41 | let out = StructChunked::from_series("".into(), 1, [&s, &p].into_iter())?; 42 | Ok(out.into_series()) 43 | } 44 | -------------------------------------------------------------------------------- /src/stats/normal_test.rs: -------------------------------------------------------------------------------- 1 | /// Here we implement the test as in SciPy: 2 | /// https://github.com/scipy/scipy/blob/v1.11.4/scipy/stats/_stats_py.py#L1836-L1996 3 | /// 4 | /// It is a method based on Kurtosis and Skew, and the Chi-2 distribution. 5 | /// 6 | /// References: 7 | /// [1] D'Agostino, R. B. 
(1971), "An omnibus test of normality for 8 | /// moderate and large sample size", Biometrika, 58, 341-348 9 | /// [2] https://www.stata.com/manuals/rsktest.pdf 10 | use super::{generic_stats_output, simple_stats_output}; 11 | use crate::stats_utils::gamma; 12 | use polars::prelude::*; 13 | use pyo3_polars::derive::polars_expr; 14 | 15 | /// Returns the skew test statistic, no pvalue, add p value if needed 16 | fn skew_test_statistic(skew: f64, n: usize) -> f64 { 17 | let n = n as f64; 18 | let y = skew * ((n + 1.) * (n + 3.) / (6. * (n - 2.))).sqrt(); 19 | let beta2 = 3. * (n.powi(2) + 27. * n - 70.) * (n + 1.) * (n + 3.) 20 | / ((n - 2.) * (n + 5.) * (n + 7.) * (n + 9.)); 21 | let w2 = (2. * (beta2 - 1.)).sqrt() - 1.; 22 | let alpha = (2. / (w2 - 1.)).sqrt(); 23 | 24 | let tmp = y / alpha; 25 | let z = (tmp + (tmp.powi(2) + 1.).sqrt()).ln() / (w2.ln() * 0.5).sqrt(); 26 | z 27 | } 28 | 29 | /// Returns the kurtosis test statistic, no pvalue, add p value if needed 30 | fn kurtosis_test_statistic(kur: f64, n: usize) -> f64 { 31 | let n = n as f64; 32 | let e = 3.0 * (n - 1.) / (n + 1.); 33 | let var = 24.0 * n * (n - 2.) * (n - 3.) / ((n + 1.).powi(2) * (n + 3.) * (n + 5.)); 34 | let x = (kur - e) / var.sqrt(); 35 | let root_beta_1 = 6. * (n.powi(2) - 5. * n + 2.) / ((n + 7.) * (n + 9.)); 36 | let root_beta_2 = (6. * (n + 3.) * (n + 5.) / (n * (n - 2.) * (n - 3.))).sqrt(); 37 | let root_beta = root_beta_1 * root_beta_2; 38 | 39 | let a = 6. + (8. / root_beta) * (2. / root_beta + (1. + 4. / root_beta.powi(2)).sqrt()); 40 | 41 | let tmp = 2. / (9. * a); 42 | let denom = 1. + x * (2. / (a - 4.)).sqrt(); 43 | if denom == 0. { 44 | println!("Kurtosis test: Division by 0 encountered."); 45 | f64::NAN 46 | } else { 47 | let term1 = 1. - tmp; 48 | let term2 = ((1. - 2. / a) / denom.abs()).cbrt(); 49 | let z = (term1 - term2) / tmp.sqrt(); 50 | z 51 | } 52 | } 53 | 54 | #[polars_expr(output_type_func=simple_stats_output)] 55 | fn pl_normal_test(inputs: &[Series]) -> PolarsResult { 56 | let skew = inputs[0].f64()?; 57 | let skew = skew.get(0).unwrap(); 58 | 59 | let kurtosis = inputs[1].f64()?; 60 | let kurtosis = kurtosis.get(0).unwrap(); 61 | 62 | let n = inputs[2].u32()?; 63 | let n = n.get(0).unwrap() as usize; 64 | 65 | if n < 20 { 66 | return Err(PolarsError::ComputeError( 67 | "Normal Test: Input should have non-null length >= 20.".into(), 68 | )); 69 | } 70 | 71 | let s = skew_test_statistic(skew, n); 72 | let k = kurtosis_test_statistic(kurtosis, n); 73 | 74 | let k2 = s * s + k * k; // the statistics 75 | 76 | // Define gamma 77 | // Shape = (degree of freedom (2) / 2, rate = 0.5) 78 | let (shape, rate) = (1., 0.5); 79 | let p = gamma::sf(k2, shape, rate).map_err(|e| PolarsError::ComputeError(e.into()))?; 80 | generic_stats_output(k2, p) 81 | } 82 | -------------------------------------------------------------------------------- /src/stats/t_test.rs: -------------------------------------------------------------------------------- 1 | use super::{generic_stats_output, simple_stats_output, Alternative}; 2 | use crate::{stats, stats_utils::beta}; 3 | /// Student's t test and Welch's t test. 4 | use core::f64; 5 | use polars::prelude::*; 6 | use pyo3_polars::derive::polars_expr; 7 | 8 | #[inline] 9 | fn ttest_ind(m1: f64, m2: f64, v1: f64, v2: f64, n: f64, alt: Alternative) -> (f64, f64) { 10 | let num = m1 - m2; 11 | // ((var1 + var2) / 2 ).sqrt() * (2./n).sqrt() can be simplified as below 12 | let denom = ((v1 + v2) / n).sqrt(); 13 | if denom == 0. 
{ 14 | println!("T Test: Division by 0 encountered."); 15 | (f64::NAN, f64::NAN) 16 | } else { 17 | let t = num / denom; 18 | let df = 2. * n - 2.; 19 | let p = match alt { 20 | Alternative::Less => beta::student_t_sf(-t, df).unwrap_or(f64::NAN), 21 | Alternative::Greater => beta::student_t_sf(t, df).unwrap_or(f64::NAN), 22 | Alternative::TwoSided => match beta::student_t_sf(t.abs(), df) { 23 | Ok(p) => 2.0 * p, 24 | Err(_) => f64::NAN, 25 | }, 26 | }; 27 | (t, p) 28 | } 29 | } 30 | 31 | #[inline] 32 | fn ttest_1samp( 33 | mean: f64, 34 | pop_mean: f64, 35 | var: f64, 36 | n: f64, 37 | alt: Alternative, 38 | ) -> Result<(f64, f64), String> { 39 | let num = mean - pop_mean; 40 | let denom = (var / n).sqrt(); 41 | if denom == 0. { 42 | Err("T Test: Division by 0 encountered.".into()) 43 | } else { 44 | let t = num / denom; 45 | let df = n - 1.; 46 | let p = match alt { 47 | Alternative::Less => beta::student_t_sf(-t, df), 48 | Alternative::Greater => beta::student_t_sf(t, df), 49 | Alternative::TwoSided => match beta::student_t_sf(t.abs(), df) { 50 | Ok(p) => Ok(2.0 * p), 51 | Err(e) => Err(e), 52 | }, 53 | }; 54 | let p = p?; 55 | Ok((t, p)) 56 | } 57 | } 58 | 59 | #[inline] 60 | fn welch_t( 61 | m1: f64, 62 | m2: f64, 63 | v1: f64, 64 | v2: f64, 65 | n1: f64, 66 | n2: f64, 67 | alt: Alternative, 68 | ) -> Result<(f64, f64), String> { 69 | let num = m1 - m2; 70 | let vn1 = v1 / n1; 71 | let vn2 = v2 / n2; 72 | let denom = (vn1 + vn2).sqrt(); 73 | if denom == 0. { 74 | Err("T Test: Division by 0 encountered.".into()) 75 | } else { 76 | let t = num / denom; 77 | let df = (vn1 + vn2).powi(2) / (vn1.powi(2) / (n1 - 1.) + (vn2.powi(2) / (n2 - 1.))); 78 | let p = match alt { 79 | // the distribution is approximately student t 80 | Alternative::Less => beta::student_t_sf(-t, df), 81 | Alternative::Greater => beta::student_t_sf(t, df), 82 | Alternative::TwoSided => match beta::student_t_sf(t.abs(), df) { 83 | Ok(p) => Ok(2.0 * p), 84 | Err(e) => Err(e), 85 | }, 86 | }; 87 | let p = p?; 88 | Ok((t, p)) 89 | } 90 | } 91 | 92 | #[polars_expr(output_type_func=simple_stats_output)] 93 | fn pl_ttest_2samp(inputs: &[Series]) -> PolarsResult { 94 | let mean1 = inputs[0].f64()?; 95 | let mean1 = mean1.get(0).unwrap_or(f64::NAN); 96 | let mean2 = inputs[1].f64()?; 97 | let mean2 = mean2.get(0).unwrap_or(f64::NAN); 98 | let var1 = inputs[2].f64()?; 99 | let var1 = var1.get(0).unwrap_or(f64::NAN); 100 | let var2 = inputs[3].f64()?; 101 | let var2 = var2.get(0).unwrap_or(f64::NAN); 102 | let n = inputs[4].u64()?; 103 | let n = n.get(0).unwrap() as f64; 104 | 105 | let alt = inputs[5].str()?; 106 | let alt = alt.get(0).unwrap(); 107 | let alt = stats::Alternative::from(alt); 108 | 109 | let valid = mean1.is_finite() && mean2.is_finite() && var1.is_finite() && var2.is_finite(); 110 | if !valid { 111 | return Err(PolarsError::ComputeError( 112 | "T Test: Sample Mean/Std is found to be NaN or Inf.".into(), 113 | )); 114 | } 115 | let (t, p) = ttest_ind(mean1, mean2, var1, var2, n, alt); 116 | generic_stats_output(t, p) 117 | } 118 | 119 | #[polars_expr(output_type_func=simple_stats_output)] 120 | fn pl_welch_t(inputs: &[Series]) -> PolarsResult { 121 | let mean1 = inputs[0].f64()?; 122 | let mean1 = mean1.get(0).unwrap(); 123 | let mean2 = inputs[1].f64()?; 124 | let mean2 = mean2.get(0).unwrap(); 125 | let var1 = inputs[2].f64()?; 126 | let var1 = var1.get(0).unwrap(); 127 | let var2 = inputs[3].f64()?; 128 | let var2 = var2.get(0).unwrap(); 129 | let n1 = inputs[4].u64()?; 130 | let n1 = n1.get(0).unwrap() as 
f64; 131 | let n2 = inputs[5].u64()?; 132 | let n2 = n2.get(0).unwrap() as f64; 133 | 134 | let alt = inputs[6].str()?; 135 | let alt = alt.get(0).unwrap(); 136 | let alt = stats::Alternative::from(alt); 137 | 138 | // No need to check for validity because input is sanitized. 139 | 140 | let (t, p) = welch_t(mean1, mean2, var1, var2, n1, n2, alt) 141 | .map_err(|e| PolarsError::ComputeError(e.into()))?; 142 | 143 | generic_stats_output(t, p) 144 | } 145 | 146 | #[polars_expr(output_type_func=simple_stats_output)] 147 | fn pl_ttest_1samp(inputs: &[Series]) -> PolarsResult { 148 | let mean = inputs[0].f64()?; 149 | let mean = mean.get(0).unwrap(); 150 | let pop_mean = inputs[1].f64()?; 151 | let pop_mean = pop_mean.get(0).unwrap(); 152 | let var = inputs[2].f64()?; 153 | let var = var.get(0).unwrap(); 154 | let n = inputs[3].u64()?; 155 | let n = n.get(0).unwrap() as f64; 156 | 157 | let alt = inputs[4].str()?; 158 | let alt = alt.get(0).unwrap(); 159 | let alt = stats::Alternative::from(alt); 160 | 161 | // No need to check for validity because input is sanitized. 162 | 163 | let (t, p) = ttest_1samp(mean, pop_mean, var, n, alt) 164 | .map_err(|e| PolarsError::ComputeError(e.into()))?; 165 | 166 | generic_stats_output(t, p) 167 | } 168 | -------------------------------------------------------------------------------- /src/stats/xi_corr.rs: -------------------------------------------------------------------------------- 1 | use super::simple_stats_output; 2 | use crate::stats_utils::normal; 3 | use polars::{prelude::*, series::ops::NullBehavior}; 4 | use pyo3_polars::derive::polars_expr; 5 | 6 | fn _xi_corr(inputs: &[Series]) -> PolarsResult { 7 | // Input 0 should be x.rank(method="random") 8 | // Input 1 should be y.rank(method="max").cast(pl.Float64).alias("r") 9 | // Input 2 should be (-y).rank(method="max").cast(pl.Float64).alias("l") 10 | 11 | let df = df!("x_rk" => &inputs[0], "r" => &inputs[1], "l" => &inputs[2])?.lazy(); 12 | Ok(df 13 | .sort(["x_rk"], Default::default()) 14 | .select([(lit(1.0) 15 | - ((len().cast(DataType::Float64) / lit(2.0)) 16 | * col("r").diff(1, NullBehavior::Ignore).abs().sum()) 17 | / (col("l") * (len() - col("l"))).sum()) 18 | .alias("statistic")]) 19 | .collect()? 20 | .drop_in_place("statistic") 21 | .unwrap() 22 | .as_materialized_series() 23 | .clone()) 24 | } 25 | 26 | #[polars_expr(output_type=Float64)] 27 | pub fn pl_xi_corr(inputs: &[Series]) -> PolarsResult { 28 | _xi_corr(inputs) 29 | } 30 | 31 | #[polars_expr(output_type_func=simple_stats_output)] 32 | pub fn pl_xi_corr_w_p(inputs: &[Series]) -> PolarsResult { 33 | let n = inputs[0].len(); 34 | let corr = _xi_corr(inputs)?; 35 | let p: f64 = if n < 30 { 36 | f64::NAN 37 | } else { 38 | let sqrt_n = (n as f64).sqrt(); 39 | let c = corr.f64().unwrap(); 40 | let c = c.get(0).unwrap(); 41 | // Two sided 42 | normal::sf_unchecked(sqrt_n * c.abs() / (0.4f64).sqrt(), 0., 1.0) * 2.0 43 | }; 44 | let p = Series::from_vec("pvalue".into(), vec![p]); 45 | let out = StructChunked::from_series("xi_corr".into(), 1, [&corr, &p].into_iter())?; 46 | Ok(out.into_series()) 47 | } 48 | -------------------------------------------------------------------------------- /src/stats_utils/mod.rs: -------------------------------------------------------------------------------- 1 | /// This submodule is mostly taken from the project statrs. See credit section in README.md 2 | /// I do not want to add it as a dependency because a lot of what it offers won't fit. 
3 | /// 4 | /// MIT License 5 | /// Copyright (c) 2016 Michael Ma 6 | /// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 7 | /// documentation files (the "Software"), to deal in the Software without restriction, including without limitation 8 | /// the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 9 | /// and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 10 | /// The above copyright notice and this permission notice shall be included in all copies or substantial portions 11 | /// of the Software. 12 | /// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 13 | /// THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 14 | /// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 15 | /// WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR 16 | /// THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | pub mod beta; 18 | pub mod gamma; 19 | pub mod normal; 20 | 21 | pub const PREC_ACC: f64 = 0.0000000000000011102230246251565; 22 | pub const LN_PI: f64 = 1.1447298858494001741434273513530587116472948129153; 23 | //pub const LN_SQRT_2PI: f64 = 0.91893853320467274178032973640561763986139747363778; 24 | pub const LN_2_SQRT_E_OVER_PI: f64 = 0.6207822376352452223455184457816472122518527279025978; 25 | -------------------------------------------------------------------------------- /src/str_ext/fuzz.rs: -------------------------------------------------------------------------------- 1 | /// Returns a simple ratio between two strings or `None` if `ratio < score_cutoff` 2 | /// The simple ratio is 3 | use crate::utils::split_offsets; 4 | use polars::prelude::{arity::binary_elementwise_values, *}; 5 | use pyo3_polars::{ 6 | derive::{polars_expr, CallerContext}, 7 | export::polars_core::{ 8 | error::PolarsError, 9 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 10 | POOL, 11 | }, 12 | }; 13 | use rapidfuzz::fuzz::{ratio, RatioBatchComparator}; 14 | 15 | #[polars_expr(output_type=UInt32)] 16 | fn pl_fuzz(inputs: &[Series], context: CallerContext) -> PolarsResult { 17 | let ca1 = inputs[0].str()?; 18 | let ca2 = inputs[1].str()?; 19 | let parallel = inputs[2].bool()?; 20 | let parallel = parallel.get(0).unwrap(); 21 | let can_parallel = parallel && !context.parallel(); 22 | if ca2.len() == 1 { 23 | let r = ca2.get(0).unwrap(); 24 | let batched = RatioBatchComparator::new(r.chars()); 25 | let out: Float64Chunked = if can_parallel { 26 | let n_threads = POOL.current_num_threads(); 27 | let splits = split_offsets(ca1.len(), n_threads); 28 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 29 | let s1 = ca1.slice(offset as i64, len); 30 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 31 | batched.similarity(s.chars()) 32 | }); 33 | out.downcast_iter().cloned().collect::>() 34 | }); 35 | let chunks = POOL.install(|| chunks_iter.collect::>()); 36 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 37 | } else { 38 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| batched.similarity(s.chars())) 39 | }; 40 | Ok(out.into_series()) 41 | } else if ca1.len() == ca2.len() { 42 | let out: Float64Chunked = if can_parallel { 43 | let n_threads = 
POOL.current_num_threads(); 44 | let splits = split_offsets(ca1.len(), n_threads); 45 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 46 | let s1 = ca1.slice(offset as i64, len); 47 | let s2 = ca2.slice(offset as i64, len); 48 | let out: Float64Chunked = 49 | binary_elementwise_values(&s1, &s2, |x, y| ratio(x.chars(), y.chars())); 50 | out.downcast_iter().cloned().collect::<Vec<_>>() 51 | }); 52 | let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>()); 53 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 54 | } else { 55 | binary_elementwise_values(ca1, ca2, |w1, w2| ratio(w1.chars(), w2.chars())) 56 | }; 57 | Ok(out.into_series()) 58 | } else { 59 | Err(PolarsError::ShapeMismatch( 60 | "Inputs must have the same length or one of them must be a scalar.".into(), 61 | )) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/str_ext/generic_str_distancer.rs: -------------------------------------------------------------------------------- 1 | use polars::{ 2 | prelude::{ 3 | arity::binary_elementwise_values, DataType, Float64Chunked, Series, StringChunked, 4 | UInt32Chunked, 5 | }, 6 | series::IntoSeries, 7 | }; 8 | use pyo3_polars::export::polars_core::{ 9 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 10 | POOL, 11 | }; 12 | /// Polars Series-wise generic str distancers 13 | use rapidfuzz::distance::{damerau_levenshtein, jaro, lcs_seq, levenshtein, osa}; 14 | 15 | use crate::utils::split_offsets; 16 | 17 | // Str Distance Related Helper Functions 18 | pub trait StdBatchedStrDistancer { 19 | fn distance(&self, s: &str) -> u32; 20 | fn normalized_similarity(&self, s: &str) -> f64; 21 | } 22 | 23 | macro_rules! StdBatchedStrDistanceImpl { 24 | ($batch_struct: ty) => { 25 | impl StdBatchedStrDistancer for $batch_struct { 26 | fn distance(&self, s: &str) -> u32 { 27 | self.distance(s.chars()) as u32 28 | } 29 | 30 | fn normalized_similarity(&self, s: &str) -> f64 { 31 | self.normalized_similarity(s.chars()) 32 | } 33 | } 34 | }; 35 | } 36 | 37 | StdBatchedStrDistanceImpl!(lcs_seq::BatchComparator<char>); 38 | StdBatchedStrDistanceImpl!(osa::BatchComparator<char>); 39 | StdBatchedStrDistanceImpl!(levenshtein::BatchComparator<char>); 40 | StdBatchedStrDistanceImpl!(damerau_levenshtein::BatchComparator<char>); 41 | StdBatchedStrDistanceImpl!(jaro::BatchComparator<char>); 42 | 43 | // ------------------------------------------------------------------------------------- 44 | 45 | pub fn generic_batched_distance<T>(batched: T, ca: &StringChunked, parallel: bool) -> Series 46 | where 47 | T: StdBatchedStrDistancer + std::marker::Sync, 48 | { 49 | let out: UInt32Chunked = if parallel { 50 | let n_threads = POOL.current_num_threads(); 51 | let splits = split_offsets(ca.len(), n_threads); 52 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 53 | let s1 = ca.slice(offset as i64, len); 54 | let out: UInt32Chunked = 55 | s1.apply_nonnull_values_generic(DataType::UInt32, |s| batched.distance(s)); 56 | out.downcast_iter().cloned().collect::<Vec<_>>() 57 | }); 58 | let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>()); 59 | UInt32Chunked::from_chunk_iter(ca.name().clone(), chunks.into_iter().flatten()) 60 | } else { 61 | ca.apply_nonnull_values_generic(DataType::UInt32, |s| batched.distance(s)) 62 | }; 63 | out.into_series() 64 | } 65 | 66 | pub fn generic_batched_sim<T>(batched: T, ca: &StringChunked, parallel: bool) -> Series 67 | where 68 | T: StdBatchedStrDistancer + std::marker::Sync, 69 | { 70 | let out: Float64Chunked = if parallel { 71 | let n_threads = POOL.current_num_threads(); 72 | let splits = split_offsets(ca.len(), n_threads); 73 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 74 | let s1 = ca.slice(offset as i64, len); 75 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 76 | batched.normalized_similarity(s) 77 | }); 78 | out.downcast_iter().cloned().collect::<Vec<_>>() 79 | }); 80 | let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>()); 81 | Float64Chunked::from_chunk_iter(ca.name().clone(), chunks.into_iter().flatten()) 82 | } else { 83 | ca.apply_nonnull_values_generic(DataType::Float64, |s| batched.normalized_similarity(s)) 84 | }; 85 | out.into_series() 86 | } 87 | 88 | pub fn generic_binary_distance( 89 | func: fn(&str, &str) -> u32, 90 | ca1: &StringChunked, 91 | ca2: &StringChunked, 92 | parallel: bool, 93 | ) -> Series { 94 | let out: UInt32Chunked = if parallel { 95 | let n_threads = POOL.current_num_threads(); 96 | let splits = split_offsets(ca1.len(), n_threads); 97 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 98 | let s1 = ca1.slice(offset as i64, len); 99 | let s2 = ca2.slice(offset as i64, len); 100 | let out: UInt32Chunked = binary_elementwise_values(&s1, &s2, func); 101 | out.downcast_iter().cloned().collect::<Vec<_>>() 102 | }); 103 | let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>()); 104 | UInt32Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 105 | } else { 106 | binary_elementwise_values(ca1, ca2, func) 107 | }; 108 | out.into_series() 109 | } 110 | 111 | pub fn generic_binary_sim( 112 | func: fn(&str, &str) -> f64, 113 | ca1: &StringChunked, 114 | ca2: &StringChunked, 115 | parallel: bool, 116 | ) -> Series { 117 | let out: Float64Chunked = if parallel { 118 | let n_threads = POOL.current_num_threads(); 119 | let splits = split_offsets(ca1.len(), n_threads); 120 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 121 | let s1 = ca1.slice(offset as i64, len); 122 | let s2 = ca2.slice(offset as i64, len); 123 | let out: Float64Chunked = binary_elementwise_values(&s1, &s2, func); 124 | out.downcast_iter().cloned().collect::<Vec<_>>() 125 | }); 126 | let chunks = POOL.install(|| chunks_iter.collect::<Vec<_>>()); 127 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 128 | } else { 129 | binary_elementwise_values(ca1, ca2, func) 130 | }; 131 | out.into_series() 132 | } 133 | -------------------------------------------------------------------------------- /src/str_ext/inflections.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::*; 2 | use pyo3_polars::derive::polars_expr; 3 | 4 | #[polars_expr(output_type=String)] 5 | fn pl_to_camel(inputs: &[Series]) -> PolarsResult { 6 | let ca = inputs[0].str()?; 7 | let out: StringChunked = ca.apply_values(|s| inflections::case::to_camel_case(s).into()); 8 | Ok(out.into_series()) 9 | } 10 | 11 | #[polars_expr(output_type=String)] 12 | fn pl_to_snake(inputs: &[Series]) -> PolarsResult { 13 | let ca = inputs[0].str()?; 14 | let out: StringChunked = ca.apply_values(|s| inflections::case::to_snake_case(s).into()); 15 | Ok(out.into_series()) 16 | } 17 | 18 | #[polars_expr(output_type=String)] 19 | fn pl_to_pascal(inputs: &[Series]) -> PolarsResult { 20 | let ca = inputs[0].str()?; 21 | let out: StringChunked = ca.apply_values(|s| inflections::case::to_pascal_case(s).into()); 22 | Ok(out.into_series()) 23 | } 24 | 25 | #[polars_expr(output_type=String)] 26 | fn
pl_to_constant(inputs: &[Series]) -> PolarsResult { 27 | let ca = inputs[0].str()?; 28 | let out: StringChunked = ca.apply_values(|s| inflections::case::to_constant_case(s).into()); 29 | Ok(out.into_series()) 30 | } 31 | -------------------------------------------------------------------------------- /src/str_ext/jaro.rs: -------------------------------------------------------------------------------- 1 | use super::generic_str_distancer::{generic_batched_sim, generic_binary_sim}; 2 | use crate::utils::split_offsets; 3 | use polars::prelude::{arity::binary_elementwise_values, *}; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | error::PolarsError, 8 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 9 | POOL, 10 | }, 11 | }; 12 | use rapidfuzz::distance::{jaro, jaro_winkler}; 13 | 14 | #[inline] 15 | fn jaro_sim(s1: &str, s2: &str) -> f64 { 16 | jaro::normalized_similarity(s1.chars(), s2.chars()) 17 | } 18 | 19 | #[inline] 20 | fn jw_sim(s1: &str, s2: &str, weight: f64) -> f64 { 21 | jaro_winkler::normalized_similarity_with_args( 22 | s1.chars(), 23 | s2.chars(), 24 | &jaro_winkler::Args::default().prefix_weight(weight), 25 | ) 26 | } 27 | 28 | #[polars_expr(output_type=Float64)] 29 | fn pl_jaro(inputs: &[Series], context: CallerContext) -> PolarsResult { 30 | let ca1 = inputs[0].str()?; 31 | let ca2 = inputs[1].str()?; 32 | let parallel = inputs[2].bool()?; 33 | let parallel = parallel.get(0).unwrap(); 34 | let can_parallel = parallel && !context.parallel(); 35 | if ca2.len() == 1 { 36 | let r = ca2.get(0).unwrap(); 37 | let batched = jaro::BatchComparator::new(r.chars()); 38 | Ok(generic_batched_sim(batched, ca1, can_parallel)) 39 | } else if ca1.len() == ca2.len() { 40 | Ok(generic_binary_sim(jaro_sim, ca1, ca2, can_parallel)) 41 | } else { 42 | Err(PolarsError::ShapeMismatch( 43 | "Inputs must have the same length or one of them must be a scalar.".into(), 44 | )) 45 | } 46 | } 47 | 48 | #[polars_expr(output_type=Float64)] 49 | fn pl_jw(inputs: &[Series], context: CallerContext) -> PolarsResult { 50 | let ca1 = inputs[0].str()?; 51 | let ca2 = inputs[1].str()?; 52 | let weight = inputs[2].f64()?; 53 | let weight = weight.get(0).unwrap_or(0.1); 54 | let parallel = inputs[3].bool()?; 55 | let parallel = parallel.get(0).unwrap(); 56 | let can_parallel = parallel && !context.parallel(); 57 | if ca2.len() == 1 { 58 | let r = ca2.get(0).unwrap(); 59 | let batched = jaro_winkler::BatchComparator::new(r.chars()); 60 | let out: Float64Chunked = if can_parallel { 61 | let n_threads = POOL.current_num_threads(); 62 | let splits = split_offsets(ca1.len(), n_threads); 63 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 64 | let s1 = ca1.slice(offset as i64, len); 65 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 66 | batched.similarity_with_args( 67 | s.chars(), 68 | &jaro_winkler::Args::default().prefix_weight(weight), 69 | ) 70 | }); 71 | out.downcast_iter().cloned().collect::>() 72 | }); 73 | let chunks = POOL.install(|| chunks_iter.collect::>()); 74 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 75 | } else { 76 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| { 77 | batched.similarity_with_args( 78 | s.chars(), 79 | &jaro_winkler::Args::default().prefix_weight(weight), 80 | ) 81 | }) 82 | }; 83 | Ok(out.into_series()) 84 | } else if ca1.len() == ca2.len() { 85 | let out: Float64Chunked = if parallel { 86 | let n_threads = 
POOL.current_num_threads(); 87 | let splits = split_offsets(ca1.len(), n_threads); 88 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 89 | let s1 = ca1.slice(offset as i64, len); 90 | let s2 = ca2.slice(offset as i64, len); 91 | let out: Float64Chunked = 92 | binary_elementwise_values(&s1, &s2, |s1, s2| jw_sim(s1, s2, weight)); 93 | out.downcast_iter().cloned().collect::>() 94 | }); 95 | let chunks = POOL.install(|| chunks_iter.collect::>()); 96 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 97 | } else { 98 | binary_elementwise_values(ca1, ca2, |s1, s2| jw_sim(s1, s2, weight)) 99 | }; 100 | Ok(out.into_series()) 101 | } else { 102 | Err(PolarsError::ShapeMismatch( 103 | "Inputs must have the same length or one of them must be a scalar.".into(), 104 | )) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/str_ext/lcs_str.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::{arity::binary_elementwise_values, *}; 2 | use pyo3_polars::{ 3 | derive::{polars_expr, CallerContext}, 4 | export::polars_core::{ 5 | utils::rayon::iter::{IntoParallelIterator, ParallelIterator}, POOL 6 | }, 7 | }; 8 | use crate::utils::split_offsets; 9 | 10 | /// Finds the longest common substring between two input strings. 11 | /// 12 | /// This function uses a space-optimized dynamic programming approach to solve the 13 | /// longest common substring problem. It constructs a 2D matrix conceptually, 14 | /// but only stores two rows (`prev_row` and `curr_row`) at any given time, 15 | /// where `curr_row[j]` stores the length of the longest common suffix 16 | /// of `s1[0...i-1]` and `s2[0...j-1]`. 17 | /// 18 | /// The space complexity is reduced from O(len1 * len2) to O(min(len1, len2)). 19 | /// The time complexity remains O(len1 * len2). 20 | /// 21 | /// See Wikipedia for more details 22 | pub fn lcs_substr_extract(s1: &str, s2: &str) -> String { 23 | // Convert string slices to character vectors for easier indexing. 24 | // This handles multi-byte UTF-8 characters correctly. 25 | let c1: Vec = s1.chars().collect(); 26 | let c2: Vec = s2.chars().collect(); 27 | 28 | let (longer_chars, shorter_chars) = if c1.len() >= c2.len() { 29 | (c1, c2) 30 | } else { 31 | (c2, c1) 32 | }; 33 | 34 | let longer_len = longer_chars.len(); 35 | let shorter_len = shorter_chars.len(); 36 | 37 | // Handle edge cases where one or both strings are empty. 38 | if longer_len == 0 || shorter_len == 0 { 39 | return String::new(); 40 | } 41 | 42 | // Variables to keep track of the maximum length found and the 43 | // ending index of the longest common substring within `longer_chars`. 44 | let mut max_len = 0; 45 | let mut end_index_in_longer_chars = 0; 46 | 47 | // Use two rows for the dynamic programming table: `dp_prev` and `dp_curr`. 48 | // Each row's size is based on the length of the shorter string + 1. 49 | // This ensures space complexity is O(min(len1, len2)). 50 | let mut dp_prev = vec![0; shorter_len + 1]; 51 | let mut dp_curr = vec![0; shorter_len + 1]; 52 | 53 | // Fill the dp table using only two rows. 54 | // The outer loop iterates through the `longer_chars`. 55 | // The inner loop iterates through the `shorter_chars`. 56 | for i in 1..=longer_len { 57 | for j in 1..=shorter_len { 58 | // If the current characters from both strings match, 59 | // the length of the common suffix is extended. 
60 | // It's 1 plus the value from the diagonal element in the `dp_prev` row. 61 | if longer_chars[i - 1] == shorter_chars[j - 1] { 62 | dp_curr[j] = 1 + dp_prev[j - 1]; 63 | 64 | // If the newly calculated length is greater than the current maximum, 65 | // update `max_len` and store the ending index in `longer_chars`. 66 | if dp_curr[j] > max_len { 67 | max_len = dp_curr[j]; 68 | end_index_in_longer_chars = i - 1; // 0-based index of the last char in `longer_chars` 69 | } 70 | } else { 71 | // If characters do not match, there is no common suffix ending at these positions, 72 | // so the length is 0. 73 | dp_curr[j] = 0; 74 | } 75 | } 76 | // After processing the current row, copy its contents to `dp_prev` 77 | dp_prev.copy_from_slice(&dp_curr); 78 | } 79 | 80 | // If `max_len` is 0, it means no common substring was found. 81 | if max_len == 0 { 82 | return String::new(); 83 | } 84 | 85 | // Extract the longest common substring from `longer_chars` 86 | // using the calculated `max_len` and `end_index_in_longer_chars`. 87 | // The starting index is derived by subtracting `max_len` and adding 1. 88 | let start_index_in_longer_chars = end_index_in_longer_chars - max_len + 1; 89 | longer_chars[start_index_in_longer_chars..=end_index_in_longer_chars] 90 | .iter() 91 | .collect::() 92 | } 93 | 94 | #[polars_expr(output_type=String)] 95 | fn pl_lcs_substr(inputs: &[Series], context: CallerContext) -> PolarsResult { 96 | let ca1 = inputs[0].str()?; 97 | let ca2 = inputs[1].str()?; 98 | let parallel = inputs[2].bool()?; 99 | let parallel = parallel.get(0).unwrap(); 100 | let can_parallel = parallel && !context.parallel(); 101 | if ca2.len() == 1 { 102 | let r = ca2.get(0).unwrap(); 103 | if can_parallel { 104 | let ca = ca1 105 | .par_iter() 106 | .map(|ss| ss.map(|s| lcs_substr_extract(s, r))) 107 | .collect::(); 108 | Ok(ca.into_series()) 109 | } else { 110 | let ca = ca1.apply_values(|s| lcs_substr_extract(s, r).into()); 111 | Ok(ca.into_series()) 112 | } 113 | } else if ca1.len() == ca2.len() { 114 | if can_parallel { 115 | let n_threads = POOL.current_num_threads(); 116 | let splits = split_offsets(ca1.len(), n_threads); 117 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 118 | let s1 = ca1.slice(offset as i64, len); 119 | let s2 = ca2.slice(offset as i64, len); 120 | let out: StringChunked = binary_elementwise_values(&s1, &s2, lcs_substr_extract); 121 | out.downcast_iter().cloned().collect::>() 122 | }); 123 | let chunks = POOL.install(|| chunks_iter.collect::>()); 124 | let ca = StringChunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()); 125 | Ok(ca.into_series()) 126 | } else { 127 | let ca: StringChunked = 128 | binary_elementwise_values(ca1, ca2, lcs_substr_extract); 129 | Ok(ca.into_series()) 130 | } 131 | } else { 132 | Err(PolarsError::ShapeMismatch( 133 | "Inputs must have the same length or one of them must be a scalar.".into(), 134 | )) 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/str_ext/mod.rs: -------------------------------------------------------------------------------- 1 | mod fuzz; 2 | mod generic_str_distancer; 3 | mod hamming; 4 | mod inflections; 5 | mod jaro; 6 | mod lcs_seq; 7 | mod lcs_str; 8 | mod levenshtein; 9 | mod nearest_str; 10 | mod osa; 11 | mod overlap; 12 | mod sorensen_dice; 13 | mod str_cleaning; 14 | mod str_jaccard; 15 | mod tversky; 16 | 17 | // Hashbrown has better perf than Rust's HashSet 18 | use hashbrown::HashSet; 19 | 20 | #[inline(always)] 21 | 
pub fn str_set_sim_helper(w1: &str, w2: &str, ngram: usize) -> (usize, usize, usize) { 22 | // output: set 1 size, set 2 size, intersection size 23 | 24 | let w1_len = w1.len(); 25 | let w2_len = w2.len(); 26 | 27 | // as long as intersection size is 0, output will be correct 28 | if (w1_len == 0) || (w2_len == 0) { 29 | return (0, 0, 0); 30 | } 31 | 32 | // Both are nonempty 33 | // Another version that has slices of size <= n? 34 | let s1: HashSet<&[u8]> = if w1_len < ngram { 35 | HashSet::from_iter([w1.as_bytes()]) 36 | } else { 37 | HashSet::from_iter(w1.as_bytes().windows(ngram)) 38 | }; 39 | 40 | let s2: HashSet<&[u8]> = if w2_len < ngram { 41 | HashSet::from_iter([w2.as_bytes()]) 42 | } else { 43 | HashSet::from_iter(w2.as_bytes().windows(ngram)) 44 | }; 45 | 46 | let intersection = s1.intersection(&s2).count(); 47 | (s1.len(), s2.len(), intersection) 48 | } 49 | -------------------------------------------------------------------------------- /src/str_ext/nearest_str.rs: -------------------------------------------------------------------------------- 1 | use polars::prelude::*; 2 | use pyo3_polars::derive::polars_expr; 3 | use rapidfuzz::distance::{hamming, levenshtein}; 4 | use serde::Deserialize; 5 | 6 | #[derive(Deserialize, Debug)] 7 | pub(crate) struct NearestStrKwargs { 8 | pub(crate) word: String, 9 | pub(crate) metric: String, 10 | pub(crate) threshold: usize, 11 | } 12 | 13 | fn levenshtein_nearest<'a>(s: &'a StringChunked, cutoff: usize, word: String) -> Option<&'a str> { 14 | let batched = levenshtein::BatchComparator::new(word.chars()); 15 | // Most similar == having smallest distance 16 | let mut best: usize = usize::MAX; 17 | let mut actual_cutoff = levenshtein::Args::default().score_cutoff(cutoff); 18 | let mut nearest_str: Option<&str> = None; 19 | for arr in s.downcast_iter() { 20 | for w in arr.values_iter() { 21 | if let Some(d) = batched.distance_with_args(w.chars(), &actual_cutoff) { 22 | if d == 0 { 23 | return Some(w); 24 | } else if d < best { 25 | best = d; 26 | nearest_str = Some(w); 27 | actual_cutoff = actual_cutoff.score_cutoff(best); 28 | } 29 | } 30 | } 31 | } 32 | nearest_str 33 | } 34 | 35 | fn hamming_nearest<'a>(s: &'a StringChunked, cutoff: usize, word: String) -> Option<&'a str> { 36 | let batched = hamming::BatchComparator::new(word.chars()); 37 | let mut actual_cutoff = hamming::Args::default().score_cutoff(cutoff); 38 | let mut best: usize = usize::MAX; 39 | let mut nearest_str: Option<&str> = None; 40 | 41 | for arr in s.downcast_iter() { 42 | for w in arr.values_iter() { 43 | if let Ok(ss) = batched.distance_with_args(w.chars(), &actual_cutoff) { 44 | if let Some(d) = ss { 45 | if d == 0 { 46 | return Some(w); 47 | } else if d < best { 48 | best = d; 49 | nearest_str = Some(w); 50 | actual_cutoff = actual_cutoff.score_cutoff(best); 51 | } 52 | } 53 | } 54 | } 55 | } 56 | nearest_str 57 | } 58 | 59 | #[polars_expr(output_type=String)] 60 | pub fn pl_nearest_str(inputs: &[Series], kwargs: NearestStrKwargs) -> PolarsResult { 61 | let s = inputs[0].str()?; 62 | let word = kwargs.word; 63 | let cutoff = kwargs.threshold; 64 | let func = match kwargs.metric.as_str() { 65 | "hamming" => hamming_nearest, 66 | _ => levenshtein_nearest, 67 | }; 68 | let mut builder = StringChunkedBuilder::new(s.name().clone(), 1); 69 | builder.append_option(func(s, cutoff, word)); 70 | let ca = builder.finish(); 71 | Ok(ca.into_series()) 72 | } 73 | -------------------------------------------------------------------------------- /src/str_ext/osa.rs: 
-------------------------------------------------------------------------------- 1 | use super::generic_str_distancer::{ 2 | generic_batched_distance, generic_batched_sim, generic_binary_distance, generic_binary_sim, 3 | }; 4 | use polars::prelude::*; 5 | use pyo3_polars::derive::{polars_expr, CallerContext}; 6 | use rapidfuzz::distance::osa; 7 | 8 | #[inline(always)] 9 | fn osa(s1: &str, s2: &str) -> u32 { 10 | osa::distance(s1.chars(), s2.chars()) as u32 11 | } 12 | 13 | #[inline(always)] 14 | fn osa_sim(s1: &str, s2: &str) -> f64 { 15 | osa::normalized_similarity(s1.chars(), s2.chars()) 16 | } 17 | 18 | #[polars_expr(output_type=UInt32)] 19 | fn pl_osa(inputs: &[Series], context: CallerContext) -> PolarsResult { 20 | let ca1 = inputs[0].str()?; 21 | let ca2 = inputs[1].str()?; 22 | let parallel = inputs[2].bool()?; 23 | let parallel = parallel.get(0).unwrap(); 24 | let can_parallel = parallel && !context.parallel(); 25 | if ca2.len() == 1 { 26 | let r = ca2.get(0).unwrap(); 27 | let batched = osa::BatchComparator::new(r.chars()); 28 | Ok(generic_batched_distance(batched, ca1, can_parallel)) 29 | } else if ca1.len() == ca2.len() { 30 | Ok(generic_binary_distance(osa, ca1, ca2, can_parallel)) 31 | } else { 32 | Err(PolarsError::ShapeMismatch( 33 | "Inputs must have the same length or one of them must be a scalar.".into(), 34 | )) 35 | } 36 | } 37 | 38 | #[polars_expr(output_type=Float64)] 39 | fn pl_osa_sim(inputs: &[Series], context: CallerContext) -> PolarsResult { 40 | let ca1 = inputs[0].str()?; 41 | let ca2 = inputs[1].str()?; 42 | let parallel = inputs[2].bool()?; 43 | let parallel = parallel.get(0).unwrap(); 44 | let can_parallel = parallel && !context.parallel(); 45 | if ca2.len() == 1 { 46 | let r = ca2.get(0).unwrap(); 47 | let batched = osa::BatchComparator::new(r.chars()); 48 | Ok(generic_batched_sim(batched, ca1, can_parallel)) 49 | } else if ca1.len() == ca2.len() { 50 | Ok(generic_binary_sim(osa_sim, ca1, ca2, can_parallel)) 51 | } else { 52 | Err(PolarsError::ShapeMismatch( 53 | "Inputs must have the same length or one of them must be a scalar.".into(), 54 | )) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/str_ext/overlap.rs: -------------------------------------------------------------------------------- 1 | use super::str_set_sim_helper; 2 | use crate::utils::split_offsets; 3 | use polars::prelude::{arity::binary_elementwise_values, *}; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 8 | POOL, 9 | }, 10 | }; 11 | 12 | // Todo. 13 | // Can optimize the case when ca1 is scalar. No need to regenerate the hashset in that case. 
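// ---- [Editor's sketch: not part of the original overlap.rs. A minimal illustration of the
// Todo above, assuming the broadcast branch (`ca2.len() == 1`) is the hot path: build the
// ngram set of the scalar string once, outside the per-row loop, instead of rebuilding it
// inside `str_set_sim_helper` for every row. `scalar_ngram_set` and
// `set_sim_with_precomputed` are hypothetical names; the byte-window ngram scheme mirrors
// `str_set_sim_helper` in mod.rs. The same idea applies to sorensen_dice.rs, str_jaccard.rs
// and tversky.rs.]
#[allow(dead_code)]
fn scalar_ngram_set(w: &str, ngram: usize) -> hashbrown::HashSet<&[u8]> {
    // Build the ngram set once for the scalar (broadcast) side.
    if w.len() < ngram {
        hashbrown::HashSet::from_iter([w.as_bytes()])
    } else {
        hashbrown::HashSet::from_iter(w.as_bytes().windows(ngram))
    }
}

#[allow(dead_code)]
fn set_sim_with_precomputed(
    w1: &str,
    s2: &hashbrown::HashSet<&[u8]>,
    ngram: usize,
) -> (usize, usize, usize) {
    // Only the per-row set is rebuilt; the scalar side's set is reused across rows.
    let s1 = scalar_ngram_set(w1, ngram);
    let intersection = s1.iter().filter(|k| s2.contains(*k)).count();
    (s1.len(), s2.len(), intersection)
}
// In the `ca2.len() == 1` branch, `let r_set = scalar_ngram_set(r, ngram);` would be computed
// once up front, and the row closure would call `set_sim_with_precomputed(s, &r_set, ngram)`.
// ---- [End of editor's sketch] ----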
14 | 15 | #[inline(always)] 16 | fn overlap_coeff(w1: &str, w2: &str, ngram: usize) -> f64 { 17 | let (s1, s2, intersect) = str_set_sim_helper(w1, w2, ngram); 18 | (intersect as f64) / ((s1.min(s2)) as f64) 19 | } 20 | 21 | #[polars_expr(output_type=Float64)] 22 | fn pl_overlap_coeff(inputs: &[Series], context: CallerContext) -> PolarsResult { 23 | let ca1 = inputs[0].str()?; 24 | let ca2 = inputs[1].str()?; 25 | 26 | // ngram size 27 | let ngram = inputs[2].u32()?; 28 | let ngram = ngram.get(0).unwrap() as usize; 29 | // parallel 30 | let parallel = inputs[3].bool()?; 31 | let parallel = parallel.get(0).unwrap(); 32 | let can_parallel = parallel && !context.parallel(); 33 | 34 | if ca2.len() == 1 { 35 | let r = ca2.get(0).unwrap(); 36 | let out: Float64Chunked = if can_parallel { 37 | let n_threads = POOL.current_num_threads(); 38 | let splits = split_offsets(ca1.len(), n_threads); 39 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 40 | let s1 = ca1.slice(offset as i64, len); 41 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 42 | overlap_coeff(s, r, ngram) 43 | }); 44 | out.downcast_iter().cloned().collect::>() 45 | }); 46 | let chunks = POOL.install(|| chunks_iter.collect::>()); 47 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 48 | } else { 49 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| overlap_coeff(s, r, ngram)) 50 | }; 51 | Ok(out.into_series()) 52 | } else if ca1.len() == ca2.len() { 53 | let out: Float64Chunked = if can_parallel { 54 | let n_threads = POOL.current_num_threads(); 55 | let splits = split_offsets(ca1.len(), n_threads); 56 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 57 | let s1 = ca1.slice(offset as i64, len); 58 | let s2 = ca2.slice(offset as i64, len); 59 | let out: Float64Chunked = 60 | binary_elementwise_values(&s1, &s2, |x, y| overlap_coeff(x, y, ngram)); 61 | out.downcast_iter().cloned().collect::>() 62 | }); 63 | 64 | let chunks = POOL.install(|| chunks_iter.collect::>()); 65 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 66 | } else { 67 | binary_elementwise_values(ca1, ca2, |x, y| overlap_coeff(x, y, ngram)) 68 | }; 69 | Ok(out.into_series()) 70 | } else { 71 | Err(PolarsError::ShapeMismatch( 72 | "Inputs must have the same length or one of them must be a scalar.".into(), 73 | )) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/str_ext/sorensen_dice.rs: -------------------------------------------------------------------------------- 1 | use super::str_set_sim_helper; 2 | use crate::utils::split_offsets; 3 | use polars::prelude::{arity::binary_elementwise_values, *}; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 8 | POOL, 9 | }, 10 | }; 11 | 12 | // Todo. 13 | // Can optimize the case when ca1 is scalar. No need to regenerate the hashset in that case. 
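// [Editor's note, not in the original file: the Sørensen–Dice coefficient computed below
// relates to the ngram Jaccard of str_jaccard.rs by D = 2J / (1 + J). For example, with
// 2-grams, "abcd" -> {ab, bc, cd} and "abcf" -> {ab, bc, cf} share 2 of their 3 ngrams each,
// so J = 2 / (3 + 3 - 2) = 0.5 and D = (2 * 2) / (3 + 3) = 2/3 = 2 * 0.5 / (1 + 0.5).]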
14 | 15 | #[inline(always)] 16 | fn sorensen_dice(w1: &str, w2: &str, ngram: usize) -> f64 { 17 | let (s1, s2, intersect) = str_set_sim_helper(w1, w2, ngram); 18 | ((2 * intersect) as f64) / ((s1 + s2) as f64) 19 | } 20 | 21 | #[polars_expr(output_type=Float64)] 22 | fn pl_sorensen_dice(inputs: &[Series], context: CallerContext) -> PolarsResult { 23 | let ca1 = inputs[0].str()?; 24 | let ca2 = inputs[1].str()?; 25 | 26 | // ngram size 27 | let ngram = inputs[2].u32()?; 28 | let ngram = ngram.get(0).unwrap() as usize; 29 | // parallel 30 | let parallel = inputs[3].bool()?; 31 | let parallel = parallel.get(0).unwrap(); 32 | let can_parallel = parallel && !context.parallel(); 33 | 34 | if ca2.len() == 1 { 35 | let r = ca2.get(0).unwrap(); 36 | let out: Float64Chunked = if can_parallel { 37 | let n_threads = POOL.current_num_threads(); 38 | let splits = split_offsets(ca1.len(), n_threads); 39 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 40 | let s1 = ca1.slice(offset as i64, len); 41 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 42 | sorensen_dice(s, r, ngram) 43 | }); 44 | out.downcast_iter().cloned().collect::>() 45 | }); 46 | let chunks = POOL.install(|| chunks_iter.collect::>()); 47 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 48 | } else { 49 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| sorensen_dice(s, r, ngram)) 50 | }; 51 | Ok(out.into_series()) 52 | } else if ca1.len() == ca2.len() { 53 | let out: Float64Chunked = if can_parallel { 54 | let n_threads = POOL.current_num_threads(); 55 | let splits = split_offsets(ca1.len(), n_threads); 56 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 57 | let s1 = ca1.slice(offset as i64, len); 58 | let s2 = ca2.slice(offset as i64, len); 59 | let out: Float64Chunked = 60 | binary_elementwise_values(&s1, &s2, |x, y| sorensen_dice(x, y, ngram)); 61 | out.downcast_iter().cloned().collect::>() 62 | }); 63 | let chunks = POOL.install(|| chunks_iter.collect::>()); 64 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 65 | } else { 66 | binary_elementwise_values(ca1, ca2, |x, y| sorensen_dice(x, y, ngram)) 67 | }; 68 | Ok(out.into_series()) 69 | } else { 70 | Err(PolarsError::ShapeMismatch( 71 | "Inputs must have the same length or the second of them must be a scalar.".into(), 72 | )) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/str_ext/str_cleaning.rs: -------------------------------------------------------------------------------- 1 | use itertools::Itertools; 2 | use polars::prelude::*; 3 | use pyo3_polars::derive::polars_expr; 4 | use unicode_normalization::UnicodeNormalization; 5 | 6 | enum NormalForm { 7 | NFC, 8 | NFKC, 9 | NFD, 10 | NFKD, 11 | } 12 | 13 | impl TryFrom for NormalForm { 14 | type Error = PolarsError; 15 | fn try_from(value: String) -> PolarsResult { 16 | match value.to_uppercase().as_ref() { 17 | "NFC" => Ok(Self::NFC), 18 | "NFKC" => Ok(Self::NFKC), 19 | "NFD" => Ok(Self::NFD), 20 | "NFKD" => Ok(Self::NFKD), 21 | _ => Err(PolarsError::ComputeError("Unknown NormalizeForm.".into())), 22 | } 23 | } 24 | } 25 | 26 | #[polars_expr(output_type=String)] 27 | fn remove_non_ascii(inputs: &[Series]) -> PolarsResult { 28 | let ca = inputs[0].str()?; 29 | let out = 30 | ca.apply_into_string_amortized(|s, buf| *buf = s.chars().filter(char::is_ascii).collect()); 31 | Ok(out.into_series()) 32 | } 33 | 34 | 
#[polars_expr(output_type=String)] 35 | fn remove_diacritics(inputs: &[Series]) -> PolarsResult { 36 | let ca = inputs[0].str()?; 37 | let out = 38 | ca.apply_into_string_amortized(|s, buf| *buf = s.nfd().filter(char::is_ascii).collect()); 39 | Ok(out.into_series()) 40 | } 41 | 42 | #[derive(serde::Deserialize)] 43 | struct NormalizeKwargs { 44 | form: String, 45 | } 46 | 47 | #[polars_expr(output_type=String)] 48 | fn normalize_string(inputs: &[Series], kwargs: NormalizeKwargs) -> PolarsResult { 49 | let ca = inputs[0].str()?; 50 | let form: NormalForm = kwargs.form.try_into()?; 51 | let out = match form { 52 | NormalForm::NFC => ca.apply_into_string_amortized(|val, buf| *buf = val.nfc().collect()), 53 | NormalForm::NFKC => ca.apply_into_string_amortized(|val, buf| *buf = val.nfkc().collect()), 54 | NormalForm::NFD => ca.apply_into_string_amortized(|val, buf| *buf = val.nfd().collect()), 55 | NormalForm::NFKD => ca.apply_into_string_amortized(|val, buf| *buf = val.nfkd().collect()), 56 | }; 57 | Ok(out.into_series()) 58 | } 59 | 60 | #[derive(serde::Deserialize)] 61 | struct MapWordsKwargs { 62 | mapping: ahash::HashMap, 63 | } 64 | 65 | #[polars_expr(output_type=String)] 66 | fn map_words(inputs: &[Series], kwargs: MapWordsKwargs) -> PolarsResult { 67 | let ca = inputs[0].str()?; 68 | let mapping = kwargs.mapping; 69 | let out = ca.apply_into_string_amortized(|s, buf| { 70 | buf.push_str( 71 | s.split_whitespace() 72 | .map(|word| mapping.get(word).map_or(word, |v| v)) 73 | .join(" ") 74 | .as_ref(), 75 | ) 76 | }); 77 | Ok(out.into_series()) 78 | } 79 | 80 | #[polars_expr(output_type=String)] 81 | fn normalize_whitespace(inputs: &[Series]) -> PolarsResult { 82 | let ca = inputs[0].str()?; 83 | let out = ca.apply_into_string_amortized(|s, buf| *buf = s.split_whitespace().join(" ")); 84 | Ok(out.into_series()) 85 | } 86 | -------------------------------------------------------------------------------- /src/str_ext/str_jaccard.rs: -------------------------------------------------------------------------------- 1 | use super::str_set_sim_helper; 2 | use crate::utils::split_offsets; 3 | use polars::prelude::{arity::binary_elementwise_values, *}; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | error::PolarsError, 8 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 9 | POOL, 10 | }, 11 | }; 12 | 13 | // Todo. 14 | // Can optimize the case when ca1 is scalar. No need to regenerate the hashset in that case. 
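// [Editor's addition, not part of the original str_jaccard.rs: a small worked example of the
// ngram Jaccard computed below, using the byte-window scheme of `str_set_sim_helper`.
// With 2-grams, "abcd" -> {ab, bc, cd} and "abcf" -> {ab, bc, cf}; the intersection is
// {ab, bc}, so the similarity is 2 / (3 + 3 - 2) = 0.5.]
#[cfg(test)]
mod str_jaccard_worked_example {
    #[test]
    fn bigram_jaccard_of_abcd_and_abcf_is_half() {
        assert!((super::str_jaccard("abcd", "abcf", 2) - 0.5).abs() < 1e-12);
    }
}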
15 | 16 | #[inline(always)] 17 | fn str_jaccard(w1: &str, w2: &str, ngram: usize) -> f64 { 18 | let (s1, s2, intersect) = str_set_sim_helper(w1, w2, ngram); 19 | (intersect as f64) / ((s1 + s2 - intersect) as f64) 20 | } 21 | 22 | #[polars_expr(output_type=Float64)] 23 | fn pl_str_jaccard(inputs: &[Series], context: CallerContext) -> PolarsResult { 24 | let ca1 = inputs[0].str()?; 25 | let ca2 = inputs[1].str()?; 26 | 27 | // ngram size 28 | let ngram = inputs[2].u32()?; 29 | let ngram = ngram.get(0).unwrap() as usize; 30 | // parallel 31 | let parallel = inputs[3].bool()?; 32 | let parallel = parallel.get(0).unwrap(); 33 | let can_parallel = parallel && !context.parallel(); 34 | 35 | if ca2.len() == 1 { 36 | let r = ca2.get(0).unwrap(); // .unwrap(); 37 | let out: Float64Chunked = if can_parallel { 38 | let n_threads = POOL.current_num_threads(); 39 | let splits = split_offsets(ca1.len(), n_threads); 40 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 41 | let s1 = ca1.slice(offset as i64, len); 42 | let out: Float64Chunked = s1 43 | .apply_nonnull_values_generic(DataType::Float64, |s| str_jaccard(s, r, ngram)); 44 | out.downcast_iter().cloned().collect::>() 45 | }); 46 | 47 | let chunks = POOL.install(|| chunks_iter.collect::>()); 48 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 49 | } else { 50 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| str_jaccard(s, r, ngram)) 51 | }; 52 | Ok(out.into_series()) 53 | } else if ca1.len() == ca2.len() { 54 | let out: Float64Chunked = if can_parallel { 55 | let n_threads = POOL.current_num_threads(); 56 | let splits = split_offsets(ca1.len(), n_threads); 57 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 58 | let s1 = ca1.slice(offset as i64, len); 59 | let s2 = ca2.slice(offset as i64, len); 60 | let out: Float64Chunked = 61 | binary_elementwise_values(&s1, &s2, |x, y| str_jaccard(x, y, ngram)); 62 | out.downcast_iter().cloned().collect::>() 63 | }); 64 | 65 | let chunks = POOL.install(|| chunks_iter.collect::>()); 66 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 67 | } else { 68 | binary_elementwise_values(ca1, ca2, |x, y| str_jaccard(x, y, ngram)) 69 | }; 70 | Ok(out.into_series()) 71 | } else { 72 | Err(PolarsError::ShapeMismatch( 73 | "Inputs must have the same length or the second of them must be a scalar.".into(), 74 | )) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/str_ext/tversky.rs: -------------------------------------------------------------------------------- 1 | use super::str_set_sim_helper; 2 | use crate::utils::split_offsets; 3 | use polars::prelude::{arity::binary_elementwise_values, *}; 4 | use pyo3_polars::{ 5 | derive::{polars_expr, CallerContext}, 6 | export::polars_core::{ 7 | error::PolarsError, 8 | utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, 9 | POOL, 10 | }, 11 | }; 12 | 13 | // Todo. 14 | // Can optimize the case when ca1 is scalar. No need to regenerate the hashset in that case. 
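// [Editor's note, not in the original file: Tversky similarity generalises the other set
// measures in this module. With alpha = beta = 1 the denominator below becomes
// |A ∩ B| + |A \ B| + |B \ A| = |A ∪ B|, i.e. the ngram Jaccard of str_jaccard.rs; with
// alpha = beta = 0.5 it reduces to Sørensen–Dice. For alpha != beta the measure is
// asymmetric, so swapping the two string inputs can change the score.]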
15 | 16 | #[inline(always)] 17 | fn tversky_sim(w1: &str, w2: &str, ngram: usize, alpha: f64, beta: f64) -> f64 { 18 | let (s1, s2, intersect) = str_set_sim_helper(w1, w2, ngram); 19 | let s1ms2 = s1.abs_diff(intersect) as f64; 20 | let s2ms1 = s2.abs_diff(intersect) as f64; 21 | (intersect as f64) / (intersect as f64 + alpha * s1ms2 + beta * s2ms1) 22 | } 23 | 24 | #[polars_expr(output_type=Float64)] 25 | fn pl_tversky_sim(inputs: &[Series], context: CallerContext) -> PolarsResult { 26 | let ca1 = inputs[0].str()?; 27 | let ca2 = inputs[1].str()?; 28 | 29 | // ngram size 30 | let ngram = inputs[2].u32()?; 31 | let ngram = ngram.get(0).unwrap() as usize; 32 | 33 | // Alpha and beta params 34 | let alpha = inputs[3].f64()?; 35 | let alpha = alpha.get(0).unwrap(); 36 | let beta = inputs[4].f64()?; 37 | let beta = beta.get(0).unwrap(); 38 | // parallel 39 | let parallel = inputs[5].bool()?; 40 | let parallel = parallel.get(0).unwrap(); 41 | let can_parallel = parallel && !context.parallel(); 42 | 43 | if ca2.len() == 1 { 44 | let r = ca2.get(0).unwrap(); 45 | let out: Float64Chunked = if can_parallel { 46 | let n_threads = POOL.current_num_threads(); 47 | let splits = split_offsets(ca1.len(), n_threads); 48 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 49 | let s1 = ca1.slice(offset as i64, len); 50 | let out: Float64Chunked = s1.apply_nonnull_values_generic(DataType::Float64, |s| { 51 | tversky_sim(s, r, ngram, alpha, beta) 52 | }); 53 | out.downcast_iter().cloned().collect::>() 54 | }); 55 | 56 | let chunks = POOL.install(|| chunks_iter.collect::>()); 57 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 58 | } else { 59 | ca1.apply_nonnull_values_generic(DataType::Float64, |s| { 60 | tversky_sim(s, r, ngram, alpha, beta) 61 | }) 62 | }; 63 | Ok(out.into_series()) 64 | } else if ca1.len() == ca2.len() { 65 | let out: Float64Chunked = if can_parallel { 66 | let n_threads = POOL.current_num_threads(); 67 | let splits = split_offsets(ca1.len(), n_threads); 68 | let chunks_iter = splits.into_par_iter().map(|(offset, len)| { 69 | let s1 = ca1.slice(offset as i64, len); 70 | let s2 = ca2.slice(offset as i64, len); 71 | let out: Float64Chunked = binary_elementwise_values(&s1, &s2, |x, y| { 72 | tversky_sim(x, y, ngram, alpha, beta) 73 | }); 74 | out.downcast_iter().cloned().collect::>() 75 | }); 76 | 77 | let chunks = POOL.install(|| chunks_iter.collect::>()); 78 | Float64Chunked::from_chunk_iter(ca1.name().clone(), chunks.into_iter().flatten()) 79 | } else { 80 | binary_elementwise_values(ca1, ca2, |x, y| tversky_sim(x, y, ngram, alpha, beta)) 81 | }; 82 | Ok(out.into_series()) 83 | } else { 84 | Err(PolarsError::ShapeMismatch( 85 | "Inputs must have the same length or one of them must be a scalar.".into(), 86 | )) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /tests/requirements-test.txt: -------------------------------------------------------------------------------- 1 | # requirements for testing 2 | maturin[patchelf]>=1.7; sys_platform == "linux" 3 | maturin>=1.7; sys_platform != "linux" 4 | numpy 5 | scikit-learn<=1.5 6 | scipy 7 | pyarrow 8 | pandas 9 | pytest 10 | pytest-benchmark 11 | xicor 12 | category_encoders 13 | copent 14 | astropy 15 | graphviz 16 | altair 17 | vegafusion[embed] 18 | vl-convert-python>=1.6 19 | great-tables>=0.9 20 | statsmodels -------------------------------------------------------------------------------- /tests/test_compat.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | import polars_ds.exprs.expr_linear as pds_linear 5 | import polars_ds.exprs.num as pds_num 6 | import polars_ds.exprs.string as pds_str 7 | import polars_ds.exprs.stats as pds_stats 8 | import polars_ds.exprs.ts_features as pds_ts 9 | import polars_ds.exprs.expr_knn as pds_knn 10 | import polars_ds.exprs.metrics as pds_metrics 11 | 12 | from polars_ds.compat import compat as pds2 13 | 14 | # --- All functions can be Wrapped --- 15 | # Test compatibility works (able to wrap the function for all the expression functions pds provides) 16 | 17 | def test_lin_reg_works(): 18 | 19 | # If this doesn't fail, the returned function is correct. 20 | # The rest of the usage depends on the user. 21 | for expr in pds_linear.__all__: 22 | _ = getattr(pds2, expr) 23 | else: 24 | assert True 25 | 26 | def test_metrics_works(): 27 | 28 | for expr in pds_metrics.__all__: 29 | _ = getattr(pds2, expr) 30 | else: 31 | assert True 32 | 33 | def test_num_works(): 34 | 35 | for expr in pds_num.__all__: 36 | _ = getattr(pds2, expr) 37 | else: 38 | assert True 39 | 40 | def test_str_works(): 41 | 42 | for expr in pds_str.__all__: 43 | _ = getattr(pds2, expr) 44 | else: 45 | assert True 46 | 47 | def test_stats_works(): 48 | 49 | for expr in pds_stats.__all__: 50 | _ = getattr(pds2, expr) 51 | else: 52 | assert True 53 | 54 | def test_ts_works(): 55 | 56 | for expr in pds_ts.__all__: 57 | _ = getattr(pds2, expr) 58 | else: 59 | assert True 60 | 61 | def test_expr_knn_works(): 62 | 63 | for expr in pds_knn.__all__: 64 | _ = getattr(pds2, expr) 65 | else: 66 | assert True 67 | 68 | 69 | -------------------------------------------------------------------------------- /tests/test_linear_models.py: -------------------------------------------------------------------------------- 1 | import polars as pl 2 | import polars_ds as pds 3 | import pytest 4 | import numpy as np 5 | from sklearn.linear_model import LinearRegression 6 | 7 | from polars_ds.linear_models import OnlineLR, ElasticNet, LR 8 | 9 | 10 | def test_lr_null_policies_for_np(): 11 | from polars_ds.linear_models import _handle_nans_in_np 12 | 13 | size = 5_000 14 | df = ( 15 | pds.frame(size=size) 16 | .select( 17 | pds.random(0.0, 1.0).alias("x1"), 18 | pds.random(0.0, 1.0).alias("x2"), 19 | pds.random(0.0, 1.0).alias("x3"), 20 | ) 21 | .with_row_index() 22 | .with_columns( 23 | x1=pl.when(pl.col("x1") > 0.5).then(None).otherwise(pl.col("x1")), 24 | y=pl.col("x1") + pl.col("x2") * 0.2 - 0.3 * pl.col("x3"), 25 | ) 26 | .with_columns(is_null=pl.col("x1").is_null()) 27 | ) 28 | nulls = df.select("is_null").to_numpy().flatten() 29 | x = df.select("x1", "x2", "x3").to_numpy() 30 | y = df.select("y").to_numpy() 31 | 32 | x_nan, _ = _handle_nans_in_np(x, y, "ignore") 33 | assert np.all(np.isnan(x_nan[nulls][:, 0])) 34 | 35 | with pytest.raises(Exception) as exc_info: 36 | _handle_nans_in_np(x, y, "raise") 37 | assert str(exc_info.value) == "Nulls found in X or y." 
38 | 39 | x_skipped, _ = _handle_nans_in_np(x, y, "skip") 40 | assert np.all(x_skipped == x[~nulls]) 41 | 42 | x_zeroed, _ = _handle_nans_in_np(x, y, "zero") 43 | assert np.all( 44 | x_zeroed[nulls][:, 0] == 0.0 45 | ) # checking out the first column because only that has nulls 46 | 47 | x_one, _ = _handle_nans_in_np(x, y, "one") 48 | assert np.all( 49 | x_one[nulls][:, 0] == 1.0 50 | ) # checking out the first column because only that has nulls 51 | 52 | 53 | @pytest.mark.parametrize("solver", ["svd", "cholesky", "qr"]) 54 | def test_lr(solver): 55 | ols = LR(False, 0.0, solver) 56 | df = ( 57 | pds.frame(size=5000) 58 | .select( 59 | pds.random(0.0, 1.0).alias("x1"), 60 | pds.random(0.0, 1.0).alias("x2"), 61 | pds.random(0.0, 1.0).alias("x3"), 62 | ) 63 | .with_row_index() 64 | .with_columns( 65 | y=pl.col("x1") + pl.col("x2") * 0.2 - 0.3 * pl.col("x3") + pds.random() * 0.0001 66 | ) 67 | ) 68 | X = df.select("x1", "x2", "x3").to_numpy() 69 | y = df.select("y").to_numpy() 70 | 71 | ols.fit(X, y) 72 | coeffs = ols.coeffs() 73 | sk_ols = LinearRegression(fit_intercept=False) 74 | sk_ols.fit(X, y) 75 | sklearn_coeffs = sk_ols.coef_ 76 | assert np.all(np.abs(coeffs - sklearn_coeffs) < 1e-6) 77 | 78 | 79 | def test_online_lr(): 80 | 81 | size = 5000 82 | df = ( 83 | pds.frame(size=size) 84 | .select( 85 | pds.random(0.0, 1.0).alias("x1"), 86 | pds.random(0.0, 1.0).alias("x2"), 87 | pds.random(0.0, 1.0).alias("x3"), 88 | ) 89 | .with_row_index() 90 | .with_columns( 91 | y=pl.col("x1") + pl.col("x2") * 0.2 - 0.3 * pl.col("x3") + pds.random() * 0.0001 92 | ) 93 | ) 94 | X = df.select("x1", "x2", "x3").to_numpy() 95 | y = df.select("y").to_numpy() 96 | 97 | olr = OnlineLR() # no bias, normal 98 | sk_lr = LinearRegression(fit_intercept=False) 99 | 100 | olr.fit(X[:10], y[:10]) 101 | coeffs = olr.coeffs() 102 | sk_lr.fit(X[:10], y[:10]) 103 | sklearn_coeffs = sk_lr.coef_ 104 | 105 | pred = olr.predict(X[:10]).flatten() 106 | sk_pred = sk_lr.predict(X[:10]).flatten() 107 | assert np.all(np.abs(pred - sk_pred) < 1e-6) 108 | assert np.all(np.abs(coeffs - sklearn_coeffs) < 1e-6) 109 | 110 | for i in range(10, 20): 111 | olr.update(X[i], y[i]) 112 | coeffs = olr.coeffs() 113 | sk_lr = LinearRegression(fit_intercept=False) 114 | sk_lr.fit(X[: i + 1], y[: i + 1]) 115 | sklearn_coeffs = sk_lr.coef_ 116 | assert np.all(np.abs(coeffs - sklearn_coeffs) < 1e-6) 117 | 118 | 119 | def _test_elastic_net(add_bias: bool = False): 120 | import sklearn.linear_model as lm 121 | 122 | l1_reg = 0.1 123 | l2_reg = 0.1 124 | alpha = l1_reg + l2_reg 125 | l1_ratio = l1_reg / (l1_reg + l2_reg) 126 | 127 | df = ( 128 | pds.frame(size=5000) 129 | .select( 130 | pds.random(0.0, 1.0).alias("x1"), 131 | pds.random(0.0, 1.0).alias("x2"), 132 | pds.random(0.0, 1.0).alias("x3"), 133 | ) 134 | .with_row_index() 135 | .with_columns( 136 | y=pl.col("x1") + pl.col("x2") * 0.2 - 0.3 * pl.col("x3"), 137 | ) 138 | .with_columns(is_null=pl.col("x1").is_null()) 139 | ) 140 | 141 | X = df.select("x1", "x2", "x3").to_numpy() 142 | y = df.select("y").to_numpy() 143 | en = ElasticNet(l1_reg=l1_reg, l2_reg=l2_reg, has_bias=add_bias) 144 | elastic = lm.ElasticNet(alpha=alpha, l1_ratio=l1_ratio, fit_intercept=add_bias) 145 | 146 | en.fit(X, y) 147 | pds_res = en.coeffs() 148 | elastic.fit(X, y) 149 | sklearn_res = elastic.coef_ 150 | assert np.all(np.abs(pds_res - sklearn_res) < 1e-4) 151 | 152 | if add_bias is True: 153 | pds_bias = en.bias() 154 | sklearn_bias = elastic.intercept_ 155 | assert np.all(np.abs(pds_bias - sklearn_bias) < 1e-4) 156 | 
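# [Editor's addition, not part of the original test file.]
# _test_elastic_net above maps polars_ds' (l1_reg, l2_reg) onto sklearn's (alpha, l1_ratio)
# via alpha = l1_reg + l2_reg and l1_ratio = l1_reg / (l1_reg + l2_reg). Inverting that:
# alpha * l1_ratio == l1_reg and alpha * (1 - l1_ratio) == l2_reg, so both libraries receive
# the same pair of penalty weights. A library-independent sanity check of that arithmetic
# (the helper name _check_elastic_net_param_mapping is an editor's invention and is not
# collected by pytest):
def _check_elastic_net_param_mapping(l1_reg: float, l2_reg: float) -> None:
    alpha = l1_reg + l2_reg
    l1_ratio = l1_reg / (l1_reg + l2_reg)
    assert abs(alpha * l1_ratio - l1_reg) < 1e-12
    assert abs(alpha * (1.0 - l1_ratio) - l2_reg) < 1e-12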
157 | 158 | def test_elastic_net(): 159 | _test_elastic_net(add_bias=False) 160 | _test_elastic_net(add_bias=True) 161 | -------------------------------------------------------------------------------- /tests/test_spatial_queries.py: -------------------------------------------------------------------------------- 1 | import polars_ds as pds 2 | import numpy as np 3 | from polars_ds.spatial import KDTree as KDT 4 | from scipy.spatial import KDTree 5 | 6 | 7 | def test_kdtree(): 8 | size = 2000 9 | df = pds.frame(size=size).with_columns(*(pds.random().alias(f"var{i}") for i in range(3))) 10 | X = df.select(f"var{i}" for i in range(3)).to_numpy(order="c") 11 | 12 | pds_tree = KDT(X, distance="l2") 13 | scipy_tree = KDTree(X, copy_data=True) 14 | 15 | distances_pds, indices_pds = pds_tree.knn(X, k=10, parallel=False) 16 | distances_scipy, indices_scipy = scipy_tree.query(X, k=10, p=2) 17 | 18 | assert np.all(distances_pds == distances_scipy) 19 | assert np.all(indices_pds.astype(np.int64) == indices_scipy.astype(np.int64)) 20 | 21 | within_pds = pds_tree.within(X, r=0.1, sort=False) 22 | within_scipy = scipy_tree.query_ball_point(X, r=0.1, p=2, return_sorted=False) 23 | assert all(set(x) == set(y) for x, y in zip(within_pds, within_scipy)) 24 | 25 | within_count = pds_tree.within_count(X, r=0.1) 26 | assert all(int(n) == len(pts) for n, pts in zip(within_count, within_pds)) 27 | --------------------------------------------------------------------------------