├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── documentation.md
│   │   └── feature-request-template.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yaml
│   └── workflows
│       ├── dependencies.yml
│       ├── schedule-dependencies.yml
│       ├── style.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── LICENSE
├── Makefile
├── docs
│   ├── Makefile
│   ├── README.md
│   ├── _scripts
│   │   ├── cross-validation.py
│   │   ├── datasets.py
│   │   ├── debug-pipeline.py
│   │   ├── fairness.py
│   │   ├── feature-selection.py
│   │   ├── linear-models.py
│   │   ├── meta-models.py
│   │   ├── mixture-methods.py
│   │   ├── naive-bayes.py
│   │   ├── outliers.py
│   │   ├── pandas-pipelines.py
│   │   └── preprocessing.py
│   ├── _static
│   │   ├── contribution
│   │   │   └── contribute.png
│   │   ├── cross-validation
│   │   │   ├── example-1.png
│   │   │   ├── example-2.png
│   │   │   ├── example-3.png
│   │   │   ├── example-4.png
│   │   │   ├── example-5.png
│   │   │   ├── group-time-series-split.png
│   │   │   ├── grp-summary.md
│   │   │   ├── grp-ts.md
│   │   │   ├── kfold.png
│   │   │   ├── summary.md
│   │   │   └── ts.md
│   │   ├── datasets
│   │   │   ├── abalone.md
│   │   │   ├── abalone.png
│   │   │   ├── arrests.md
│   │   │   ├── arrests.png
│   │   │   ├── chicken.md
│   │   │   ├── chicken.png
│   │   │   ├── creditcards.md
│   │   │   ├── creditcards.png
│   │   │   ├── hearts.md
│   │   │   ├── hearts.png
│   │   │   ├── heroes.md
│   │   │   ├── heroes.png
│   │   │   ├── penguins.md
│   │   │   ├── penguins.png
│   │   │   ├── timeseries.md
│   │   │   └── timeseries.png
│   │   ├── fairness
│   │   │   ├── boston-description.txt
│   │   │   ├── demographic-parity-grid-results.png
│   │   │   ├── drop-two.png
│   │   │   ├── equal-opportunity-grid-results.png
│   │   │   ├── information-filter-coefs.md
│   │   │   ├── original-situation.png
│   │   │   ├── predict-boston-simple.png
│   │   │   ├── projections.png
│   │   │   └── use-info-filter.png
│   │   ├── feature-selection
│   │   │   └── mrmr-feature-selection-mnist.png
│   │   ├── linear-models
│   │   │   ├── grid-span-sigma-01.png
│   │   │   ├── grid-span-sigma-02.png
│   │   │   ├── grid.html
│   │   │   ├── lad-data.png
│   │   │   ├── lad-fit.png
│   │   │   ├── lowess-rolling-001.gif
│   │   │   ├── lowess-rolling-01.gif
│   │   │   ├── lowess-rolling.gif
│   │   │   ├── lowess-two-predictions.gif
│   │   │   ├── lowess.png
│   │   │   ├── lr-fit.png
│   │   │   └── quantile-fit.png
│   │   ├── logo.png
│   │   ├── meta-models
│   │   │   ├── baseline-model.png
│   │   │   ├── confusion-balanced-grid.html
│   │   │   ├── confusion-balancer-results.png
│   │   │   ├── decay-functions.png
│   │   │   ├── decay-model.png
│   │   │   ├── grouped-df.png
│   │   │   ├── grouped-dummy-model.png
│   │   │   ├── grouped-model.png
│   │   │   ├── grouped-np.png
│   │   │   ├── grouped-transform.png
│   │   │   ├── make-blobs.png
│   │   │   ├── ordinal-classification.png
│   │   │   ├── ordinal_data.md
│   │   │   ├── outlier-classifier-stacking.html
│   │   │   ├── outlier-classifier.html
│   │   │   ├── penguins.md
│   │   │   ├── skewed-data.png
│   │   │   ├── threshold-chart.png
│   │   │   └── ts-data.png
│   │   ├── mixture-methods
│   │   │   ├── gmm-classifier.png
│   │   │   ├── gmm-outlier-detector.png
│   │   │   ├── gmm-outlier-multi-threshold.png
│   │   │   └── outlier-mixture-threshold.png
│   │   ├── naive-bayes
│   │   │   ├── model-density.png
│   │   │   ├── model-results.png
│   │   │   ├── naive-bayes.png
│   │   │   └── simulated-data.png
│   │   ├── outliers
│   │   │   ├── bayesian-gmm-outlier.png
│   │   │   ├── decomposition.png
│   │   │   ├── gmm-outlier.png
│   │   │   ├── pca-outlier.png
│   │   │   ├── regr-outlier.png
│   │   │   └── umap-outlier.png
│   │   ├── preprocessing
│   │   │   ├── column-capper.png
│   │   │   ├── estimator-transformer-1.png
│   │   │   ├── estimator-transformer-2.png
│   │   │   ├── formulaic-1.md
│   │   │   ├── formulaic-2.md
│   │   │   ├── identity-transformer-1.png
│   │   │   ├── identity-transformer-2.png
│   │   │   ├── interval-encoder-1.png
│   │   │   ├── interval-encoder-2.png
│   │   │   ├── interval-encoder-3.png
│   │   │   ├── monotonic-2.png
│   │   │   ├── monotonic-3.png
│   │   │   ├── monotonic-spline-regr.png
│   │   │   ├── monotonic-spline-transform.png
│   │   │   ├── monotonic-spline.png
│   │   │   ├── rbf-data.png
│   │   │   ├── rbf-plot.png
│   │   │   └── rbf-regr.png
│   │   └── rstudio
│   │       ├── Rplot1.png
│   │       └── Rplot2.png
│   ├── api
│   │   ├── base.md
│   │   ├── common.md
│   │   ├── datasets.md
│   │   ├── decay-functions.md
│   │   ├── decomposition.md
│   │   ├── dummy.md
│   │   ├── feature-selection.md
│   │   ├── linear-model.md
│   │   ├── meta.md
│   │   ├── metrics.md
│   │   ├── mixture.md
│   │   ├── model-selection.md
│   │   ├── naive-bayes.md
│   │   ├── neighbors.md
│   │   ├── pandas-utils.md
│   │   ├── pipeline.md
│   │   ├── preprocessing.md
│   │   └── shrinkage-functions.md
│   ├── contribution.md
│   ├── generate_this_content.py
│   ├── index.md
│   ├── installation.md
│   ├── rstudio.md
│   └── user-guide
│       ├── cross-validation.md
│       ├── datasets.md
│       ├── debug-pipeline.md
│       ├── fairness.md
│       ├── feature-selection.md
│       ├── linear-models.md
│       ├── meta-models.md
│       ├── mixture-methods.md
│       ├── naive-bayes.md
│       ├── outliers.md
│       ├── pandas-pipelines.md
│       └── preprocessing.md
├── features.py
├── images
│   └── logo.png
├── mkdocs.yaml
├── pyproject.toml
├── readme.md
├── requirements
│   └── docs.txt
├── setup.py
├── sklego
│   ├── __init__.py
│   ├── base.py
│   ├── common.py
│   ├── data
│   │   ├── abalone.zip
│   │   ├── arrests.zip
│   │   ├── chickweight.zip
│   │   ├── hearts.zip
│   │   ├── heroes.zip
│   │   └── penguins.zip
│   ├── datasets.py
│   ├── decomposition
│   │   ├── __init__.py
│   │   ├── pca_reconstruction.py
│   │   └── umap_reconstruction.py
│   ├── dummy.py
│   ├── feature_selection
│   │   ├── __init__.py
│   │   └── mrmr.py
│   ├── linear_model.py
│   ├── meta
│   │   ├── __init__.py
│   │   ├── _decay_utils.py
│   │   ├── _grouped_utils.py
│   │   ├── _shrinkage_utils.py
│   │   ├── confusion_balancer.py
│   │   ├── decay_estimator.py
│   │   ├── estimator_transformer.py
│   │   ├── grouped_predictor.py
│   │   ├── grouped_transformer.py
│   │   ├── hierarchical_predictor.py
│   │   ├── ordinal_classification.py
│   │   ├── outlier_classifier.py
│   │   ├── regression_outlier_detector.py
│   │   ├── subjective_classifier.py
│   │   ├── thresholder.py
│   │   └── zero_inflated_regressor.py
│   ├── metrics.py
│   ├── mixture
│   │   ├── __init__.py
│   │   ├── bayesian_gmm_classifier.py
│   │   ├── bayesian_gmm_detector.py
│   │   ├── gmm_classifier.py
│   │   └── gmm_outlier_detector.py
│   ├── model_selection.py
│   ├── naive_bayes.py
│   ├── neighbors.py
│   ├── notinstalled.py
│   ├── pandas_utils.py
│   ├── pipeline.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── columncapper.py
│   │   ├── dictmapper.py
│   │   ├── formulaictransformer.py
│   │   ├── identitytransformer.py
│   │   ├── intervalencoder.py
│   │   ├── monotonicspline.py
│   │   ├── outlier_remover.py
│   │   ├── pandastransformers.py
│   │   ├── projections.py
│   │   ├── randomadder.py
│   │   └── repeatingbasis.py
│   ├── testing.py
│   └── this.py
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── data
    │   └── boston.arff
    ├── scripts
    │   ├── check_pip.py
    │   └── import_all.py
    ├── test_common
    │   ├── __init__.py
    │   ├── test_basics.py
    │   └── test_transformerselectormixin.py
    ├── test_datasets.py
    ├── test_estimators
    │   ├── __init__.py
    │   ├── test_deadzone.py
    │   ├── test_demographic_parity.py
    │   ├── test_equal_opportunity.py
    │   ├── test_gmm_naive_bayes.py
    │   ├── test_imbalanced_linear_regression.py
    │   ├── test_lowess.py
    │   ├── test_mixture_classifier.py
    │   ├── test_mixture_detector.py
    │   ├── test_neighbor_classifier.py
    │   ├── test_pca_reconstruction.py
    │   ├── test_probweight_regression.py
    │   ├── test_quantile_regression.py
    │   ├── test_randomregressor.py
    │   └── test_umap_reconstruction.py
    ├── test_feature_selection
    │   ├── __init__.py
    │   └── test_mrmr.py
    ├── test_meta
    │   ├── __init__.py
    │   ├── test_confusion_balancer.py
    │   ├── test_decay_estimator.py
    │   ├── test_decay_utils.py
    │   ├── test_estimatortransformer.py
    │   ├── test_grouped_predictor.py
    │   ├── test_grouped_transformer.py
    │   ├── test_hierarchical_predictor.py
    │   ├── test_ordinal_classification.py
    │   ├── test_outlier_classifier.py
    │   ├── test_regression_outlier.py
    │   ├── test_subjective_classifier.py
    │   ├── test_thresholder.py
    │   └── test_zero_inflated_regressor.py
    ├── test_metrics
    │   ├── __init__.py
    │   ├── test_correlation_score.py
    │   ├── test_equal_opportunity.py
    │   ├── test_p_percent.py
    │   └── test_subset_metric.py
    ├── test_model_selection
    │   ├── __init__.py
    │   ├── test_clusterfold.py
    │   ├── test_grouptimeseriessplit.py
    │   └── test_timegapsplit.py
    ├── test_notinstalled.py
    ├── test_pandas_utils
    │   └── test_pandas_utils.py
    ├── test_pipeline
    │   └── test_debug_pipeline.py
    └── test_preprocessing
        ├── __init__.py
        ├── test_columncapper.py
        ├── test_columndropper.py
        ├── test_columnselector.py
        ├── test_dictmapper.py
        ├── test_formulaic_transformer.py
        ├── test_identitytransformer.py
        ├── test_informationfilter.py
        ├── test_interval_encoder.py
        ├── test_monospline.py
        ├── test_orthogonal_transformer.py
        ├── test_outlier_remover.py
        ├── test_pandastypeselector.py
        ├── test_randomadder.py
        └── test_repeatingbasisfunction.py

/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: "[BUG] "
labels: bug
assignees: ''

---

In order to help us fix a potential bug we need:

- [ ] a clear and concise description of what the bug is
- [ ] code that helps us reproduce it
- [ ] **optional** a unit test that helps us catch the bug in the future
--------------------------------------------------------------------------------

/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
---
name: Documentation
about: Something is wrong with the documentation.
title: "[DOCS] "
labels: documentation
assignees: ''

---

When something is wrong with the documentation:

- [ ] please tell us clearly where the mistake is
- [ ] explain why the mistake is confusing or wrong
- [ ] suggest a better alternative or tone

It is also perfectly fine to discuss the examples in the documentation here.
--------------------------------------------------------------------------------

/.github/ISSUE_TEMPLATE/feature-request-template.md:
--------------------------------------------------------------------------------
---
name: New Feature Request
about: This is a template for a Feature Request
title: "[FEATURE]"
labels: enhancement
assignees: ''

---

Please explain clearly what you'd like to see added.

- [ ] convince us of the use case; we're open to many suggestions, but we prefer to solve problems with pipelines that are at least somewhat general
- [ ] add a screenshot if applicable (ML stuff is hard to explain with words; a picture says a thousand words)
- [ ] make sure that the feature you want is not already supported by sklearn
--------------------------------------------------------------------------------

/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
Before working on a large PR, please check with @FBruzzesi or @koaning to confirm that they agree with the direction of the PR. This discussion should take place in a [GitHub issue](https://github.com/koaning/scikit-lego/issues/new/choose) before working on the PR, unless it's a minor change like spelling in the docs.

# Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context.

Fixes #(issue)

## Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)


## Checklist:

- [ ] My code follows the style guidelines (ruff)
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation (also to the readme.md)
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added tests to check whether the new feature adheres to the sklearn convention
- [ ] New and existing unit tests pass locally with my changes

If you feel your PR is ready for a review, ping @FBruzzesi or @koaning.
--------------------------------------------------------------------------------
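The checklist above asks for tests that check whether a new feature "adheres to the sklearn convention". For orientation, a minimal sketch of what such a test can look like, using scikit-learn's own estimator checks; the import path `sklego.dummy.RandomRegressor` is assumed here from the tree above, and the repository's actual tests under `tests/` may be organized differently:

```python
# Hypothetical sklearn-convention test sketch; not the repository's own code.
from sklearn.utils.estimator_checks import parametrize_with_checks

from sklego.dummy import RandomRegressor  # assumed import path


# parametrize_with_checks generates one pytest case per sklearn API check
# (get_params/set_params round-trips, fit/predict shapes, input validation, ...)
@parametrize_with_checks([RandomRegressor()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
```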
/.github/dependabot.yaml:
--------------------------------------------------------------------------------
version: 2
updates:
  # Maintain dependencies for GitHub Actions
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "monthly"
--------------------------------------------------------------------------------

/.github/workflows/dependencies.yml:
--------------------------------------------------------------------------------
name: Check Optional Dependencies

on:
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install pytest setuptools wheel
      - name: Run Base Install
        run: |
          python -m pip install -e .
      - name: Run Checks
        run: |
          python tests/scripts/check_pip.py missing cvxpy
          python tests/scripts/check_pip.py installed scikit-learn
          python tests/scripts/import_all.py
      - name: Install cvxpy
        run: |
          python -m pip install -e ".[cvxpy]"
      - name: Run Checks
        run: |
          python tests/scripts/check_pip.py installed cvxpy scikit-learn
          python tests/scripts/import_all.py
      - name: Install All
        run: |
          python -m pip install -e ".[all]"
      - name: Run Checks
        run: |
          python tests/scripts/check_pip.py installed cvxpy formulaic scikit-learn umap-learn
      - name: Docs can Build
        run: |
          sudo apt-get update && sudo apt-get install pandoc
          python -m pip install -e ".[docs]"
          mkdocs build
--------------------------------------------------------------------------------
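The `Run Checks` steps above drive two helper scripts from `tests/scripts/`. A hypothetical sketch of what a `check_pip.py`-style script could look like; the real script in the repository may differ, this only illustrates the `missing`/`installed` command-line contract used by the workflow:

```python
# Hypothetical sketch; the actual tests/scripts/check_pip.py may differ.
import sys
from importlib import metadata


def is_installed(package: str) -> bool:
    """Return True if `package` is installed in the current environment."""
    try:
        metadata.version(package)
        return True
    except metadata.PackageNotFoundError:
        return False


if __name__ == "__main__":
    # Usage mirrors the workflow: `check_pip.py missing cvxpy`
    #                             `check_pip.py installed cvxpy scikit-learn`
    mode, packages = sys.argv[1], sys.argv[2:]
    for package in packages:
        installed = is_installed(package)
        if mode == "installed" and not installed:
            sys.exit(f"{package} is expected to be installed but is missing")
        if mode == "missing" and installed:
            sys.exit(f"{package} is expected to be missing but is installed")
```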
/.github/workflows/schedule-dependencies.yml:
--------------------------------------------------------------------------------
name: Cron Test Dependencies

on:
  workflow_dispatch:
  schedule:
    - cron: "0 0 * * *"


jobs:
  cron-base:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        python-version: ["3.10"]
        os: [macos-latest, ubuntu-latest, windows-latest]
        pre-release-dependencies: ["--pre", ""]
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        if: always()
        run: |
          python -m pip install wheel
          python -m pip install ${{ matrix.pre-release-dependencies }} -e ".[test]"
          python -m pip freeze
      - name: Test with pytest
        if: always()
        run: pytest -n auto --disable-warnings --cov=sklego -m "not cvxpy and not formulaic and not umap"

  cron-extra:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        python-version: ["3.10"]
        os: [macos-latest, ubuntu-latest, windows-latest]
        pre-release-dependencies: [
          # "--pre",
          "",
        ]
        extra: ["cvxpy", "formulaic", "umap"]
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        if: always()
        run: |
          python -m pip install wheel
          python -m pip install ${{ matrix.pre-release-dependencies }} -e ".[test,${{ matrix.extra }}]"
          python -m pip freeze
      - name: Test with pytest
        if: always()
        run: pytest -n auto --disable-warnings --cov=sklego -m "${{ matrix.extra }}"
--------------------------------------------------------------------------------

/.github/workflows/style.yml:
--------------------------------------------------------------------------------
name: Linting

on:
  push:

jobs:
  linting-python:
    name: Lint Python
    runs-on: ubuntu-latest
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install Dependencies
        run: python -m pip install ruff==0.11.7 --no-cache-dir
      - name: Run Ruff
        run: make lint
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
name: Python Test Package

on:
  pull_request:
    branches:
      - main

jobs:
  test:
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: [
          "3.9",
          "3.10",
          "3.11",
          "3.12",
          "3.13"
        ]
        exclude:
          - os: windows-latest
            python-version: "3.13"
    runs-on: ${{ matrix.os }}
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install -e ".[test-all]"
          python -m pip freeze
      - name: Test with pytest
        run: make test
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Automatically generated when building docs
docs/this.md

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
Pipfile
Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
venv*/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# editor
.vscode
.idea
.DS_Store

# Local Netlify folder
.netlify
--------------------------------------------------------------------------------

/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: check-docstring-first
      - id: check-merge-conflict
      - id: check-added-large-files
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
    hooks:
      - id: ruff  # Run the linter.
        args: [--fix, sklego, tests]
      - id: ruff-format  # Run the formatter.
        args: [sklego, tests]
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at vincentwarmerdam@gmail.com. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 vincent d warmerdam

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------

/Makefile:
--------------------------------------------------------------------------------
.PHONY: docs

install:
	python -m pip install -e ".[dev]"
	pre-commit install

test:
	pytest -n auto --disable-warnings --cov=sklego
	rm -rf .coverage*

precommit:
	pre-commit run

docs:
	mkdocs serve

docs-deploy:
	mkdocs gh-deploy

clean:
	rm -rf .pytest_cache build dist scikit_lego.egg-info .ipynb_checkpoints .coverage* .mypy_cache .ruff_cache

lint:
	ruff format sklego tests
	ruff check sklego tests --fix

check: lint precommit test clean

pypi: clean
	python setup.py sdist
	python setup.py bdist_wheel --universal
	twine upload dist/*
--------------------------------------------------------------------------------

/docs/Makefile:
--------------------------------------------------------------------------------
.PHONY: generate-all
generate-all:
	ls _scripts/*.py | xargs -n 1 -P 4 python
--------------------------------------------------------------------------------

/docs/README.md:
--------------------------------------------------------------------------------
# Docs readme

The docs folder contains the documentation for the scikit-lego package.

The documentation is generated using [Material for MkDocs][mkdocs-material], its extensions and a few plugins.
In particular, the `mkdocstrings-python` plugin is used for API rendering.

## Render locally

To render the documentation locally, run the following command from the root of the repository:

```console
make docs
```

The documentation page will then be available at [localhost][localhost].

## Remark

Most of the code output and generated plots in the documentation are produced by the scripts in the `docs/_scripts` folder, and included via the [pymdown snippets][pymdown-snippets] extension.

To regenerate the plots from scratch, it is enough to run the following commands from the root of the repository:

```console
cd docs
make generate-all
```

This runs all the scripts and saves the results in the `docs/_static` folder.

[mkdocs-material]: https://squidfunk.github.io/mkdocs-material/
[pymdown-snippets]: https://facelessuser.github.io/pymdown-extensions/extensions/snippets/
[localhost]: http://localhost:8000/
--------------------------------------------------------------------------------
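As the readme above notes, the `_scripts` files are pulled into the docs via pymdown snippets. Concretely, the scripts fence named regions with `--8<-- [start:...]` / `--8<-- [end:...]` comment markers (visible in the scripts below), and a docs page then embeds a region by name. A small illustrative sketch; the file name, section name, and dataset loader used here are made up for the example:

```python
# docs/_scripts/example.py (hypothetical file, hypothetical section name)
# --8<-- [start:load-data]
from sklego.datasets import load_penguins  # assumed loader; see sklego/datasets.py

df = load_penguins(as_frame=True)
print(df.head())
# --8<-- [end:load-data]

# A markdown page under docs/ can then embed only that region using the
# pymdown-snippets section syntax, roughly:
#     --8<-- "_scripts/example.py:load-data"
```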
/docs/_scripts/debug-pipeline.py:
--------------------------------------------------------------------------------
from pathlib import Path

_file = Path(__file__)
print(f"Executing {_file}")

_static_path = Path("_static") / _file.stem
_static_path.mkdir(parents=True, exist_ok=True)

######################################## DebugPipeline ###########################################
###################################################################################################

# --8<-- [start:setup]
import logging

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

from sklego.pipeline import DebugPipeline

logging.basicConfig(
    format=("[%(funcName)s:%(lineno)d] - %(message)s"),
    level=logging.INFO
)
# --8<-- [end:setup]

# --8<-- [start:simple-pipe]
n_samples, n_features = 3, 5
X = np.zeros((n_samples, n_features))
y = np.arange(n_samples)


class Adder(TransformerMixin, BaseEstimator):
    def __init__(self, value):
        self._value = value

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X + self._value

    def __repr__(self):
        return f"Adder(value={self._value})"


steps = [
    ("add_1", Adder(value=1)),
    ("add_10", Adder(value=10)),
    ("add_100", Adder(value=100)),
    ("add_1000", Adder(value=1000)),
]
# --8<-- [end:simple-pipe]

# --8<-- [start:simple-pipe-fit-transform]
pipe = DebugPipeline(steps)
_ = pipe.fit(X, y=y)

X_out = pipe.transform(X)
print("Transformed X:\n", X_out)
# --8<-- [end:simple-pipe-fit-transform]

# --8<-- [start:log-callback]
pipe = DebugPipeline(steps, log_callback="default")
_ = pipe.fit(X, y=y)

X_out = pipe.transform(X)
print("Transformed X:\n", X_out)
# --8<-- [end:log-callback]

# --8<-- [start:log-callback-after]
pipe = DebugPipeline(steps)
pipe.log_callback = "default"

_ = pipe.fit(X, y=y)

X_out = pipe.transform(X)
print("Transformed X:\n", X_out)
# --8<-- [end:log-callback-after]

# --8<-- [start:custom-log-callback]
def log_callback(output, execution_time, **kwargs):
    """My custom `log_callback` function

    Parameters
    ----------
    output : tuple(numpy.ndarray or pandas.DataFrame, estimator or transformer)
        The output of the current step and the step itself.
    execution_time : float
        The execution time of the step.
    """
    logger = logging.getLogger(__name__)
    step_result, step = output
    logger.info(f"[{step}] shape={step_result.shape} "
                f"nbytes={step_result.nbytes} time={execution_time}")


pipe.log_callback = log_callback
_ = pipe.fit(X, y=y)

X_out = pipe.transform(X)
print("Transformed X:\n", X_out)
# --8<-- [end:custom-log-callback]


# --8<-- [start:feature-union]
from sklearn.pipeline import FeatureUnion

pipe_w_default_log_callback = DebugPipeline(steps, log_callback='default')
pipe_w_custom_log_callback = DebugPipeline(steps, log_callback=log_callback)

pipe_union = FeatureUnion([
    ('pipe_w_default_log_callback', pipe_w_default_log_callback),
    ('pipe_w_custom_log_callback', pipe_w_custom_log_callback),
])

_ = pipe_union.fit(X, y=y)

X_out = pipe_union.transform(X)
print('Transformed X:\n', X_out)
# --8<-- [end:feature-union]


# --8<-- [start:remove]
pipe.log_callback = None
_ = pipe.fit(X, y=y)

X_out = pipe.transform(X)
print('Transformed X:\n', X_out)
# --8<-- [end:remove]
--------------------------------------------------------------------------------
/docs/_scripts/naive-bayes.py:
--------------------------------------------------------------------------------
from pathlib import Path

_file = Path(__file__)
print(f"Executing {_file}")

_static_path = Path("_static") / _file.stem
_static_path.mkdir(parents=True, exist_ok=True)

#################################### Simulated Data ######################################
###########################################################################################

# --8<-- [start:simulated-data]
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns

sns.set_theme()

n = 10000


def make_arr(mu1, mu2, std1=1, std2=1, p=0.5):
    res = np.where(
        np.random.uniform(0, 1, n) > p,
        np.random.normal(mu1, std1, n),
        np.random.normal(mu2, std2, n),
    )
    return np.expand_dims(res, 1)


np.random.seed(42)
X1 = np.concatenate([make_arr(0, 4), make_arr(0, 4)], axis=1)
X2 = np.concatenate([make_arr(-3, 7), make_arr(2, 2)], axis=1)

plt.figure(figsize=(4, 4))
plt.scatter(X1[:, 0], X1[:, 1], alpha=0.5)
plt.scatter(X2[:, 0], X2[:, 1], alpha=0.5)
plt.title("simulated dataset")
# --8<-- [end:simulated-data]

plt.savefig(_static_path / "simulated-data.png")
plt.clf()

#################################### Model results #######################################
###########################################################################################

# --8<-- [start:model-results]
from sklego.naive_bayes import GaussianMixtureNB

cmap = sns.color_palette("flare", as_cmap=True)

X = np.concatenate([X1, X2])
y = np.concatenate([np.zeros(n), np.ones(n)])
plt.figure(figsize=(8, 8))
for i, k in enumerate([1, 2]):
    mod = GaussianMixtureNB(n_components=k).fit(X, y)
    plt.subplot(220 + i * 2 + 1)
    pred = mod.predict_proba(X)[:, 0]
    plt.scatter(X[:, 0], X[:, 1], c=pred, cmap=cmap)
    plt.title(f"predict_proba k={k}")

    plt.subplot(220 + i * 2 + 2)
    pred = mod.predict(X)
    plt.scatter(X[:, 0], X[:, 1], c=pred, cmap=cmap)
    plt.title(f"predict k={k}")
# --8<-- [end:model-results]

plt.savefig(_static_path / "model-results.png")
plt.clf()

#################################### Model density #######################################
###########################################################################################

# --8<-- [start:model-density]
gmm1 = mod.gmms_[0.0]
gmm2 = mod.gmms_[1.0]
plt.figure(figsize=(8, 8))

plt.subplot(221)
plt.hist(gmm1[0].sample(n)[0], 30)
plt.title("model 1 - column 1 density")
plt.subplot(222)
plt.hist(gmm1[1].sample(n)[0], 30)
plt.title("model 1 - column 2 density")
plt.subplot(223)
plt.hist(gmm2[0].sample(n)[0], 30)
plt.title("model 2 - column 1 density")
plt.subplot(224)
plt.hist(gmm2[1].sample(n)[0], 30)
plt.title("model 2 - column 2 density")
# --8<-- [end:model-density]

plt.savefig(_static_path / "model-density.png")
plt.clf()
--------------------------------------------------------------------------------
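For reference, the model the script above exercises: `GaussianMixtureNB` is a naive Bayes classifier whose per-feature class-conditional densities are Gaussian mixtures, which is why `mod.gmms_` above holds one fitted mixture per class per column. A sketch of the usual GMM naive Bayes factorization, with K standing for `n_components` (the exact estimator details live in `sklego/naive_bayes.py`):

```latex
P(y = c \mid x) \;\propto\; P(y = c) \prod_{j=1}^{d} p_{c,j}(x_j),
\qquad
p_{c,j}(x_j) \;=\; \sum_{k=1}^{K} w_{c,j,k}\,
  \mathcal{N}\!\big(x_j \mid \mu_{c,j,k},\, \sigma_{c,j,k}^{2}\big)
```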
/docs/_scripts/pandas-pipelines.py:
--------------------------------------------------------------------------------
from pathlib import Path

_file = Path(__file__)
print(f"Executing {_file}")

_static_path = Path("_static") / _file.stem
_static_path.mkdir(parents=True, exist_ok=True)

import pandas as pd

# --8<-- [start:log-setup]
import logging

logging.basicConfig(level=logging.DEBUG)
# --8<-- [end:log-setup]

# --8<-- [start:data-setup]
from sklego.datasets import load_chicken

chickweight = load_chicken(as_frame=True)
# --8<-- [end:data-setup]

# --8<-- [start:log-step]
from sklego.pandas_utils import log_step


@log_step
def set_dtypes(chickweight):
    return chickweight.assign(
        diet=lambda d: d['diet'].astype('category'),
        chick=lambda d: d['chick'].astype('category'),
    )


chickweight.pipe(set_dtypes).head()
# --8<-- [end:log-step]

print(chickweight.pipe(set_dtypes).head())

# --8<-- [start:log-step-printfn]
@log_step(print_fn=logging.debug)
def remove_dead_chickens(chickweight):
    dead_chickens = chickweight.groupby('chick').size().loc[lambda s: s < 12]
    return chickweight.loc[lambda d: ~d['chick'].isin(dead_chickens)]


@log_step(print_fn=logging.info)
def remove_outliers(chickweight):
    return chickweight.pipe(remove_dead_chickens)


chickweight.pipe(set_dtypes).pipe(remove_outliers).head()
# --8<-- [end:log-step-printfn]

print(chickweight.pipe(set_dtypes).pipe(remove_outliers).head())

# --8<-- [start:log-step-notime]
@log_step(time_taken=False, shape=False, shape_delta=True)
def remove_dead_chickens(chickweight):
    dead_chickens = chickweight.groupby('chick').size().loc[lambda s: s < 12]
    return chickweight.loc[lambda d: ~d['chick'].isin(dead_chickens)]


chickweight.pipe(remove_dead_chickens).head()
# --8<-- [end:log-step-notime]

print(chickweight.pipe(remove_dead_chickens).head())


# --8<-- [start:log-step-extra]
from sklego.pandas_utils import log_step_extra


def count_unique_chicks(df, **kwargs):
    return "nchicks=" + str(df["chick"].nunique())


def display_message(df, msg):
    return msg


@log_step_extra(count_unique_chicks)
def start_pipe(df):
    """Get initial chick count"""
    return df


@log_step_extra(count_unique_chicks, display_message, msg="without diet 1")
def remove_diet_1_chicks(df):
    return df.loc[df["diet"] != 1]


(chickweight
 .pipe(start_pipe)
 .pipe(remove_diet_1_chicks)
 .head()
)
# --8<-- [end:log-step-extra]

print(chickweight.pipe(start_pipe).pipe(remove_diet_1_chicks).head())
--------------------------------------------------------------------------------
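For intuition about what `log_step` does under the hood, a minimal illustrative reimplementation of the idea; this is not the library's code, and the real decorator in `sklego.pandas_utils` supports more options (`print_fn`, `time_taken`, `shape`, `shape_delta`, extra functions via `log_step_extra`):

```python
# Illustrative sketch only; see sklego/pandas_utils.py for the real decorator.
import datetime as dt
from functools import wraps


def my_log_step(func):
    """Wrap a DataFrame -> DataFrame pipe step and log its name, output shape
    and wall-clock time, then pass the result through unchanged."""
    @wraps(func)
    def wrapper(df, *args, **kwargs):
        tic = dt.datetime.now()
        result = func(df, *args, **kwargs)
        elapsed = dt.datetime.now() - tic
        print(f"[{func.__name__}] shape={result.shape} time={elapsed}")
        return result
    return wrapper
```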
/docs/_static/contribution/ (binary assets): contribute.png, available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/contribution/contribute.png
--------------------------------------------------------------------------------

/docs/_static/cross-validation/ (binary plot assets): example-1.png, example-2.png,
example-3.png, example-4.png, example-5.png, group-time-series-split.png, kfold.png;
each is available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/cross-validation/<filename>
--------------------------------------------------------------------------------

/docs/_static/cross-validation/grp-summary.md:
--------------------------------------------------------------------------------
| index | observations | group | obs_per_group | ideal_group_size | diff_from_ideal_group_size |
|------:|-------------:|------:|--------------:|-----------------:|---------------------------:|
| 2000 | 3 | 0 | 4 | 4 | 0 |
| 2001 | 1 | 0 | 4 | 4 | 0 |
| 2002 | 2 | 1 | 3 | 4 | -1 |
| 2003 | 1 | 1 | 3 | 4 | -1 |
| 2004 | 5 | 2 | 5 | 4 | 1 |
| 2005 | 2 | 3 | 5 | 4 | 1 |
| 2006 | 2 | 3 | 5 | 4 | 1 |
| 2007 | 1 | 3 | 5 | 4 | 1 |
--------------------------------------------------------------------------------

/docs/_static/cross-validation/grp-ts.md:
--------------------------------------------------------------------------------
|   X |   y |
|----:|----:|
| 583 | 481 |
| 414 | 617 |
| 669 | 627 |
| 812 | 604 |
| 800 | 248 |
| 966 | 503 |
| 719 | 650 |
| 476 | 939 |
| 743 | 170 |
| 142 | 893 |
--------------------------------------------------------------------------------

/docs/_static/cross-validation/summary.md:
--------------------------------------------------------------------------------
| Start date | End date | Period | Unique days | nbr samples |
|:-----------|:---------|:-------|------------:|------------:|
| 2018-01-01 00:00:00 | 2018-01-10 00:00:00 | 9 days 00:00:00 | 10 | 10 |
| 2018-01-12 00:00:00 | 2018-01-13 00:00:00 | 1 days 00:00:00 | 2 | 2 |
| 2018-01-06 00:00:00 | 2018-01-15 00:00:00 | 9 days 00:00:00 | 10 | 10 |
| 2018-01-17 00:00:00 | 2018-01-18 00:00:00 | 1 days 00:00:00 | 2 | 2 |
| 2018-01-10 00:00:00 | 2018-01-19 00:00:00 | 9 days 00:00:00 | 10 | 10 |
| 2018-01-21 00:00:00 | 2018-01-22 00:00:00 | 1 days 00:00:00 | 2 | 2 |
| 2018-01-15 00:00:00 | 2018-01-24 00:00:00 | 9 days 00:00:00 | 10 | 10 |
| 2018-01-26 00:00:00 | 2018-01-27 00:00:00 | 1 days 00:00:00 | 2 | 2 |
--------------------------------------------------------------------------------

/docs/_static/cross-validation/ts.md:
--------------------------------------------------------------------------------
|  A |  B |  C |  y | date |
|---:|---:|---:|---:|:--------------------|
| 28 |  9 | 24 |  5 | 2018-01-30 00:00:00 |
|  5 |  0 | 19 |  1 | 2018-01-29 00:00:00 |
|  8 |  1 | 29 |  2 | 2018-01-28 00:00:00 |
| 11 |  4 | 21 | 19 | 2018-01-27 00:00:00 |
| 19 | 26 |  6 |  2 | 2018-01-26 00:00:00 |
--------------------------------------------------------------------------------

/docs/_static/datasets/abalone.md:
--------------------------------------------------------------------------------
| sex | length | diameter | height | whole_weight | shucked_weight | viscera_weight | shell_weight | rings |
|:----|-------:|---------:|-------:|-------------:|---------------:|---------------:|-------------:|------:|
| M | 0.455 | 0.365 | 0.095 | 0.514  | 0.2245 | 0.101  | 0.15  | 15 |
| M | 0.35  | 0.265 | 0.09  | 0.2255 | 0.0995 | 0.0485 | 0.07  |  7 |
| F | 0.53  | 0.42  | 0.135 | 0.677  | 0.2565 | 0.1415 | 0.21  |  9 |
| M | 0.44  | 0.365 | 0.125 | 0.516  | 0.2155 | 0.114  | 0.155 | 10 |
| I | 0.33  | 0.255 | 0.08  | 0.205  | 0.0895 | 0.0395 | 0.055 |  7 |
--------------------------------------------------------------------------------
/docs/_static/datasets/arrests.md:
--------------------------------------------------------------------------------
| released | colour | year | age | sex | employed | citizen | checks |
|:---------|:-------|-----:|----:|:-------|:---------|:--------|-------:|
| Yes | White | 2002 | 21 | Male   | Yes | Yes | 3 |
| No  | Black | 1999 | 17 | Male   | Yes | Yes | 3 |
| Yes | White | 2000 | 24 | Male   | Yes | Yes | 3 |
| No  | Black | 2000 | 46 | Male   | Yes | Yes | 1 |
| Yes | Black | 1999 | 27 | Female | Yes | Yes | 1 |
--------------------------------------------------------------------------------

/docs/_static/datasets/chicken.md:
--------------------------------------------------------------------------------
| weight | time | chick | diet |
|-------:|-----:|------:|-----:|
| 42 | 0 | 1 | 1 |
| 51 | 2 | 1 | 1 |
| 59 | 4 | 1 | 1 |
| 64 | 6 | 1 | 1 |
| 76 | 8 | 1 | 1 |
--------------------------------------------------------------------------------

/docs/_static/datasets/creditcards.md:
--------------------------------------------------------------------------------
| V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 | V11 | V12 | V13 | V14 | V15 | V16 | V17 | V18 | V19 | V20 | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|----:|-------:|------:|
| -1.35981 | -0.0727812 | 2.53635 | 1.37816 | -0.338321 | 0.462388 | 0.239599 | 0.0986979 | 0.363787 | 0.0907942 | -0.5516 | -0.617801 | -0.99139 | -0.311169 | 1.46818 | -0.470401 | 0.207971 | 0.0257906 | 0.403993 | 0.251412 | -0.0183068 | 0.277838 | -0.110474 | 0.0669281 | 0.128539 | -0.189115 | 0.133558 | -0.0210531 | 149.62 | 0 |
| 1.19186 | 0.266151 | 0.16648 | 0.448154 | 0.0600176 | -0.0823608 | -0.078803 | 0.0851017 | -0.255425 | -0.166974 | 1.61273 | 1.06524 | 0.489095 | -0.143772 | 0.635558 | 0.463917 | -0.114805 | -0.183361 | -0.145783 | -0.0690831 | -0.225775 | -0.638672 | 0.101288 | -0.339846 | 0.16717 | 0.125895 | -0.0089831 | 0.0147242 | 2.69 | 0 |
| -1.35835 | -1.34016 | 1.77321 | 0.37978 | -0.503198 | 1.8005 | 0.791461 | 0.247676 | -1.51465 | 0.207643 | 0.624501 | 0.0660837 | 0.717293 | -0.165946 | 2.34586 | -2.89008 | 1.10997 | -0.121359 | -2.26186 | 0.52498 | 0.247998 | 0.771679 | 0.909412 | -0.689281 | -0.327642 | -0.139097 | -0.0553528 | -0.0597518 | 378.66 | 0 |
| -0.966272 | -0.185226 | 1.79299 | -0.863291 | -0.0103089 | 1.2472 | 0.237609 | 0.377436 | -1.38702 | -0.0549519 | -0.226487 | 0.178228 | 0.507757 | -0.287924 | -0.631418 | -1.05965 | -0.684093 | 1.96578 | -1.23262 | -0.208038 | -0.1083 | 0.0052736 | -0.190321 | -1.17558 | 0.647376 | -0.221929 | 0.0627228 | 0.0614576 | 123.5 | 0 |
| -1.15823 | 0.877737 | 1.54872 | 0.403034 | -0.407193 | 0.0959215 | 0.592941 | -0.270533 | 0.817739 | 0.753074 | -0.822843 | 0.538196 | 1.34585 | -1.11967 | 0.175121 | -0.451449 | -0.237033 | -0.0381948 | 0.803487 | 0.408542 | -0.0094307 | 0.798278 | -0.137458 | 0.141267 | -0.20601 | 0.502292 | 0.219422 | 0.215153 | 69.99 | 0 |
--------------------------------------------------------------------------------

/docs/_static/datasets/hearts.md:
--------------------------------------------------------------------------------
| age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | target |
|----:|----:|---:|---------:|-----:|----:|--------:|--------:|------:|--------:|------:|---:|:-----|-------:|
| 63 | 1 | 1 | 145 | 233 | 1 | 2 | 150 | 0 | 2.3 | 3 | 0 | fixed      | 0 |
| 67 | 1 | 4 | 160 | 286 | 0 | 2 | 108 | 1 | 1.5 | 2 | 3 | normal     | 1 |
| 67 | 1 | 4 | 120 | 229 | 0 | 2 | 129 | 1 | 2.6 | 2 | 2 | reversible | 0 |
| 37 | 1 | 3 | 130 | 250 | 0 | 0 | 187 | 0 | 3.5 | 3 | 0 | normal     | 0 |
| 41 | 0 | 2 | 130 | 204 | 0 | 2 | 172 | 0 | 1.4 | 1 | 0 | normal     | 0 |
--------------------------------------------------------------------------------

/docs/_static/datasets/heroes.md:
--------------------------------------------------------------------------------
| name | attack_type | role | health | attack | attack_spd |
|:---------|:------|:--------|-----:|----:|-----:|
| Artanis  | Melee | Bruiser | 2470 | 111 | 1    |
| Chen     | Melee | Bruiser | 2473 |  90 | 1.11 |
| Dehaka   | Melee | Bruiser | 2434 | 100 | 1.11 |
| Imperius | Melee | Bruiser | 2450 | 122 | 0.83 |
| Leoric   | Melee | Bruiser | 2550 | 109 | 0.77 |
--------------------------------------------------------------------------------

/docs/_static/datasets/penguins.md:
--------------------------------------------------------------------------------
| species | island | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g | sex |
|:--------|:----------|------:|------:|----:|-----:|:-------|
| Adelie  | Torgersen | 39.1  | 18.7  | 181 | 3750 | male   |
| Adelie  | Torgersen | 39.5  | 17.4  | 186 | 3800 | female |
| Adelie  | Torgersen | 40.3  | 18    | 195 | 3250 | female |
| Adelie  | Torgersen | nan   | nan   | nan | nan  | nan    |
| Adelie  | Torgersen | 36.7  | 19.3  | 193 | 3450 | female |
--------------------------------------------------------------------------------
/docs/_static/datasets/timeseries.md:
--------------------------------------------------------------------------------
|    | yt        |
|---:|----------:|
|  0 | -0.335058 |
|  1 | -0.283375 |
|  2 |  0.521791 |
|  3 |  0.50202  |
|  4 |  0.310048 |
--------------------------------------------------------------------------------

/docs/_static/datasets/ (binary plot assets): abalone.png, arrests.png, chicken.png,
creditcards.png, hearts.png, heroes.png, penguins.png, timeseries.png; each is
available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/datasets/<filename>
--------------------------------------------------------------------------------

/docs/_static/fairness/boston-description.txt:
--------------------------------------------------------------------------------
.. _boston_dataset:

Boston house prices dataset
---------------------------

**Data Set Characteristics:**

    :Number of Instances: 506

    :Number of Attributes: 13 numeric/categorical predictive. Median Value (attribute 14) is usually the target.

    :Attribute Information (in order):
        - CRIM     per capita crime rate by town
        - ZN       proportion of residential land zoned for lots over 25,000 sq.ft.
        - INDUS    proportion of non-retail business acres per town
        - CHAS     Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)
        - NOX      nitric oxides concentration (parts per 10 million)
        - RM       average number of rooms per dwelling
        - AGE      proportion of owner-occupied units built prior to 1940
        - DIS      weighted distances to five Boston employment centres
        - RAD      index of accessibility to radial highways
        - TAX      full-value property-tax rate per $10,000
        - PTRATIO  pupil-teacher ratio by town
        - B        1000(Bk - 0.63)^2 where Bk is the proportion of black people by town
        - LSTAT    % lower status of the population
        - MEDV     Median value of owner-occupied homes in $1000's
--------------------------------------------------------------------------------
/docs/_static/fairness/information-filter-coefs.md:
--------------------------------------------------------------------------------
| crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat |
|-----:|---:|------:|-----:|----:|---:|----:|----:|----:|----:|--------:|--:|------:|
| -0.928146 | 1.08157  | 0.1409    | 0.68174  | -2.05672 | 2.67423 | 0.0194661  | -3.10404 | 2.66222 | -2.07678 | -2.06061 | 0.849268 | -3.74363 |
| -1.5814   | 0.911004 | -0.290074 | 0.884936 | -2.56787 | 4.2647  | -1.27073   | -3.33184 | 2.21574 | -2.05625 | -2.1546  | nan      | nan      |
| -0.763568 | 1.02805  | 0.0613932 | 0.697504 | -1.60546 | 6.84677 | -0.0579197 | -2.5376  | 1.93506 | -1.77983 | -2.79307 | nan      | nan      |
--------------------------------------------------------------------------------

/docs/_static/fairness/ (binary plot assets): demographic-parity-grid-results.png,
drop-two.png, equal-opportunity-grid-results.png, original-situation.png,
predict-boston-simple.png, projections.png, use-info-filter.png; each is available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/fairness/<filename>
--------------------------------------------------------------------------------
/docs/_static/feature-selection/ (binary plot assets): mrmr-feature-selection-mnist.png,
available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/feature-selection/mrmr-feature-selection-mnist.png
--------------------------------------------------------------------------------

/docs/_static/linear-models/ (binary plot assets): grid-span-sigma-01.png,
grid-span-sigma-02.png, lad-data.png, lad-fit.png, lowess-rolling-001.gif,
lowess-rolling-01.gif, lowess-rolling.gif, lowess-two-predictions.gif, lowess.png,
lr-fit.png, quantile-fit.png; each is available at
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/linear-models/<filename>
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/linear-models/lowess-two-predictions.gif -------------------------------------------------------------------------------- /docs/_static/linear-models/lowess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/linear-models/lowess.png -------------------------------------------------------------------------------- /docs/_static/linear-models/lr-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/linear-models/lr-fit.png -------------------------------------------------------------------------------- /docs/_static/linear-models/quantile-fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/linear-models/quantile-fit.png -------------------------------------------------------------------------------- /docs/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/logo.png -------------------------------------------------------------------------------- /docs/_static/meta-models/baseline-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/baseline-model.png -------------------------------------------------------------------------------- /docs/_static/meta-models/confusion-balancer-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/confusion-balancer-results.png -------------------------------------------------------------------------------- /docs/_static/meta-models/decay-functions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/decay-functions.png -------------------------------------------------------------------------------- /docs/_static/meta-models/decay-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/decay-model.png -------------------------------------------------------------------------------- /docs/_static/meta-models/grouped-df.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/grouped-df.png -------------------------------------------------------------------------------- /docs/_static/meta-models/grouped-dummy-model.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/grouped-dummy-model.png -------------------------------------------------------------------------------- /docs/_static/meta-models/grouped-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/grouped-model.png -------------------------------------------------------------------------------- /docs/_static/meta-models/grouped-np.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/grouped-np.png -------------------------------------------------------------------------------- /docs/_static/meta-models/grouped-transform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/grouped-transform.png -------------------------------------------------------------------------------- /docs/_static/meta-models/make-blobs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/make-blobs.png -------------------------------------------------------------------------------- /docs/_static/meta-models/ordinal-classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/ordinal-classification.png -------------------------------------------------------------------------------- /docs/_static/meta-models/ordinal_data.md: -------------------------------------------------------------------------------- 1 | | apply | pared | public | gpa | apply_codes | 2 | |:----------------|--------:|---------:|------:|--------------:| 3 | | very likely | 0 | 0 | 3.26 | 2 | 4 | | somewhat likely | 1 | 0 | 3.21 | 1 | 5 | | unlikely | 1 | 1 | 3.94 | 0 | 6 | | somewhat likely | 0 | 0 | 2.81 | 1 | 7 | | somewhat likely | 0 | 0 | 2.53 | 1 | -------------------------------------------------------------------------------- /docs/_static/meta-models/penguins.md: -------------------------------------------------------------------------------- 1 | | flipper_length_mm | body_mass_g | sex | 2 | |--------------------:|--------------:|:-------| 3 | | 181 | 3750 | male | 4 | | 186 | 3800 | female | 5 | | 195 | 3250 | female | 6 | | 193 | 3450 | female | 7 | | 190 | 3650 | male | -------------------------------------------------------------------------------- /docs/_static/meta-models/skewed-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/skewed-data.png -------------------------------------------------------------------------------- /docs/_static/meta-models/threshold-chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/threshold-chart.png 
-------------------------------------------------------------------------------- /docs/_static/meta-models/ts-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/meta-models/ts-data.png -------------------------------------------------------------------------------- /docs/_static/mixture-methods/gmm-classifier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/mixture-methods/gmm-classifier.png -------------------------------------------------------------------------------- /docs/_static/mixture-methods/gmm-outlier-detector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/mixture-methods/gmm-outlier-detector.png -------------------------------------------------------------------------------- /docs/_static/mixture-methods/gmm-outlier-multi-threshold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/mixture-methods/gmm-outlier-multi-threshold.png -------------------------------------------------------------------------------- /docs/_static/mixture-methods/outlier-mixture-threshold.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/mixture-methods/outlier-mixture-threshold.png -------------------------------------------------------------------------------- /docs/_static/naive-bayes/model-density.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/naive-bayes/model-density.png -------------------------------------------------------------------------------- /docs/_static/naive-bayes/model-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/naive-bayes/model-results.png -------------------------------------------------------------------------------- /docs/_static/naive-bayes/naive-bayes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/naive-bayes/naive-bayes.png -------------------------------------------------------------------------------- /docs/_static/naive-bayes/simulated-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/naive-bayes/simulated-data.png -------------------------------------------------------------------------------- /docs/_static/outliers/bayesian-gmm-outlier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/bayesian-gmm-outlier.png 
-------------------------------------------------------------------------------- /docs/_static/outliers/decomposition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/decomposition.png -------------------------------------------------------------------------------- /docs/_static/outliers/gmm-outlier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/gmm-outlier.png -------------------------------------------------------------------------------- /docs/_static/outliers/pca-outlier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/pca-outlier.png -------------------------------------------------------------------------------- /docs/_static/outliers/regr-outlier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/regr-outlier.png -------------------------------------------------------------------------------- /docs/_static/outliers/umap-outlier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/outliers/umap-outlier.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/column-capper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/column-capper.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/estimator-transformer-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/estimator-transformer-1.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/estimator-transformer-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/estimator-transformer-2.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/formulaic-1.md: -------------------------------------------------------------------------------- 1 | | Intercept | a | np.log(a) | b[T.no] | b[T.yes] | 2 | |------------:|----:|------------:|----------:|-----------:| 3 | | 1 | 1 | 0 | 0 | 1 | 4 | | 1 | 2 | 0.693147 | 0 | 1 | 5 | | 1 | 3 | 1.09861 | 1 | 0 | 6 | | 1 | 4 | 1.38629 | 0 | 0 | 7 | | 1 | 5 | 1.60944 | 0 | 1 | 8 | -------------------------------------------------------------------------------- /docs/_static/preprocessing/formulaic-2.md: -------------------------------------------------------------------------------- 1 | | a | np.log(a) | b[T.maybe] | b[T.no] | b[T.yes] | 2 | 
|----:|------------:|-------------:|----------:|-----------:| 3 | | 1 | 0 | 0 | 0 | 1 | 4 | | 2 | 0.693147 | 0 | 0 | 1 | 5 | | 3 | 1.09861 | 0 | 1 | 0 | 6 | | 4 | 1.38629 | 1 | 0 | 0 | 7 | | 5 | 1.60944 | 0 | 0 | 1 | 8 | -------------------------------------------------------------------------------- /docs/_static/preprocessing/identity-transformer-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/identity-transformer-1.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/identity-transformer-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/identity-transformer-2.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/interval-encoder-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/interval-encoder-1.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/interval-encoder-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/interval-encoder-2.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/interval-encoder-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/interval-encoder-3.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/monotonic-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/monotonic-2.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/monotonic-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/monotonic-3.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/monotonic-spline-regr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/monotonic-spline-regr.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/monotonic-spline-transform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/monotonic-spline-transform.png -------------------------------------------------------------------------------- 
/docs/_static/preprocessing/monotonic-spline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/monotonic-spline.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/rbf-data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/rbf-data.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/rbf-plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/rbf-plot.png -------------------------------------------------------------------------------- /docs/_static/preprocessing/rbf-regr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/preprocessing/rbf-regr.png -------------------------------------------------------------------------------- /docs/_static/rstudio/Rplot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/rstudio/Rplot1.png -------------------------------------------------------------------------------- /docs/_static/rstudio/Rplot2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/docs/_static/rstudio/Rplot2.png -------------------------------------------------------------------------------- /docs/api/base.md: -------------------------------------------------------------------------------- 1 | # Base 2 | 3 | ::: sklego.base.ClustererMeta 4 | options: 5 | show_root_full_path: true 6 | show_root_heading: true 7 | 8 | ::: sklego.base.Clusterer 9 | options: 10 | show_root_full_path: true 11 | show_root_heading: true 12 | 13 | ::: sklego.base.OutlierModelMeta 14 | options: 15 | show_root_full_path: true 16 | show_root_heading: true 17 | 18 | ::: sklego.base.OutlierModel 19 | options: 20 | show_root_full_path: true 21 | show_root_heading: true 22 | 23 | ::: sklego.base.ProbabilisticClassifierMeta 24 | options: 25 | show_root_full_path: true 26 | show_root_heading: true 27 | 28 | ::: sklego.base.ProbabilisticClassifier 29 | options: 30 | show_root_full_path: true 31 | show_root_heading: true 32 | -------------------------------------------------------------------------------- /docs/api/common.md: -------------------------------------------------------------------------------- 1 | # Common 2 | 3 | Module with common classes and functions used across the package. 
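Since the helpers in this module are easy to miss, here is a small usage sketch. The expected outputs in the comments follow the helpers' docstrings; treat this as a hedged illustration rather than authoritative reference output.

```py
from sklego.common import as_list, flatten, expanding_list

# `as_list` wraps scalars (including strings) into a list.
print(as_list("test"))    # ['test']
print(as_list([1, 2]))    # [1, 2]

# `flatten` recursively flattens nested iterables; wrapping in `list`
# materializes the values in case a generator is returned.
print(list(flatten([["one", "two"], ["three"]])))  # ['one', 'two', 'three']

# `expanding_list` builds incrementally growing sub-lists.
print(expanding_list([1, 2, 3]))  # [[1], [1, 2], [1, 2, 3]]
```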
4 | 5 | ::: sklego.common.TrainOnlyTransformerMixin 6 | options: 7 | show_root_full_path: true 8 | show_root_heading: true 9 | 10 | ::: sklego.common.as_list 11 | options: 12 | show_root_full_path: true 13 | show_root_heading: true 14 | 15 | ::: sklego.common.flatten 16 | options: 17 | show_root_full_path: true 18 | show_root_heading: true 19 | 20 | ::: sklego.common.expanding_list 21 | options: 22 | show_root_full_path: true 23 | show_root_heading: true 24 | 25 | ::: sklego.common.sliding_window 26 | options: 27 | show_root_full_path: true 28 | show_root_heading: true 29 | -------------------------------------------------------------------------------- /docs/api/datasets.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | ::: sklego.datasets.load_abalone 4 | options: 5 | show_root_full_path: true 6 | show_root_heading: true 7 | 8 | ::: sklego.datasets.load_arrests 9 | options: 10 | show_root_full_path: true 11 | show_root_heading: true 12 | 13 | ::: sklego.datasets.load_chicken 14 | options: 15 | show_root_full_path: true 16 | show_root_heading: true 17 | 18 | ::: sklego.datasets.load_heroes 19 | options: 20 | show_root_full_path: true 21 | show_root_heading: true 22 | 23 | ::: sklego.datasets.load_hearts 24 | options: 25 | show_root_full_path: true 26 | show_root_heading: true 27 | 28 | ::: sklego.datasets.load_penguins 29 | options: 30 | show_root_full_path: true 31 | show_root_heading: true 32 | 33 | ::: sklego.datasets.make_simpleseries 34 | options: 35 | show_root_full_path: true 36 | show_root_heading: true 37 | 38 | ::: sklego.datasets.fetch_creditcard 39 | options: 40 | show_root_full_path: true 41 | show_root_heading: true 42 | -------------------------------------------------------------------------------- /docs/api/decay-functions.md: -------------------------------------------------------------------------------- 1 | # Decay Functions 2 | 3 | These functions are used in the [`DecayEstimator`][decay-estimator] to generate sample weights for the wrapped model. 
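To make the idea concrete, here is a minimal sketch of what an exponential decay achieves, written with plain numpy and scikit-learn rather than the sklego internals: the most recent observation gets weight 1 and every step back in time multiplies the weight by a decay rate.

```py
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(42)
X = np.arange(100, dtype=float).reshape(-1, 1)  # assume rows are ordered in time
y = 3.0 * X.ravel() + rng.normal(size=100)

# weight_i = decay_rate ** (n - 1 - i): the last (most recent) row gets
# weight 1, older rows decay towards zero.
decay_rate = 0.95
weights = decay_rate ** np.arange(len(X) - 1, -1, -1)

model = LinearRegression().fit(X, y, sample_weight=weights)
```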
4 | 
5 | ::: sklego.meta._decay_utils.exponential_decay
6 |     options:
7 |         show_root_full_path: true
8 |         show_root_heading: true
9 | 
10 | ::: sklego.meta._decay_utils.linear_decay
11 |     options:
12 |         show_root_full_path: true
13 |         show_root_heading: true
14 | 
15 | ::: sklego.meta._decay_utils.sigmoid_decay
16 |     options:
17 |         show_root_full_path: true
18 |         show_root_heading: true
19 | 
20 | ::: sklego.meta._decay_utils.stepwise_decay
21 |     options:
22 |         show_root_full_path: true
23 |         show_root_heading: true
24 | 
25 | [decay-estimator]: ../../api/meta#sklego.meta.decay_estimator.DecayEstimator
26 | 
--------------------------------------------------------------------------------
/docs/api/decomposition.md:
--------------------------------------------------------------------------------
1 | # Decomposition
2 | 
3 | ::: sklego.decomposition.pca_reconstruction.PCAOutlierDetection
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | ::: sklego.decomposition.umap_reconstruction.UMAPOutlierDetection
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
--------------------------------------------------------------------------------
/docs/api/dummy.md:
--------------------------------------------------------------------------------
1 | # Dummy
2 | 
3 | ::: sklego.dummy.RandomRegressor
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
--------------------------------------------------------------------------------
/docs/api/feature-selection.md:
--------------------------------------------------------------------------------
1 | # Feature Selection
2 | 
3 | :::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
--------------------------------------------------------------------------------
/docs/api/linear-model.md:
--------------------------------------------------------------------------------
1 | # Linear Models
2 | 
3 | ::: sklego.linear_model.LowessRegression
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | ::: sklego.linear_model.ProbWeightRegression
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
13 | ::: sklego.linear_model.DeadZoneRegressor
14 |     options:
15 |         show_root_full_path: true
16 |         show_root_heading: true
17 | 
18 | ::: sklego.linear_model.DemographicParityClassifier
19 |     options:
20 |         show_root_full_path: true
21 |         show_root_heading: true
22 | 
23 | ::: sklego.linear_model.EqualOpportunityClassifier
24 |     options:
25 |         show_root_full_path: true
26 |         show_root_heading: true
27 | 
28 | ::: sklego.linear_model.BaseScipyMinimizeRegressor
29 |     options:
30 |         show_root_full_path: true
31 |         show_root_heading: true
32 | 
33 | ::: sklego.linear_model.ImbalancedLinearRegression
34 |     options:
35 |         show_root_full_path: true
36 |         show_root_heading: true
37 | 
38 | ::: sklego.linear_model.QuantileRegression
39 |     options:
40 |         show_root_full_path: true
41 |         show_root_heading: true
42 | 
43 | ::: sklego.linear_model.LADRegression
44 |     options:
45 |         show_root_full_path: true
46 |         show_root_heading: true
47 | 
--------------------------------------------------------------------------------
/docs/api/meta.md:
--------------------------------------------------------------------------------
1 | # Meta Models
2 | 
3 | ::: sklego.meta.confusion_balancer.ConfusionBalancer
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | ::: sklego.meta.decay_estimator.DecayEstimator
9 |
options: 10 | show_root_full_path: true 11 | show_root_heading: true 12 | 13 | ::: sklego.meta.estimator_transformer.EstimatorTransformer 14 | options: 15 | show_root_full_path: true 16 | show_root_heading: true 17 | 18 | ::: sklego.meta.grouped_predictor.GroupedPredictor 19 | options: 20 | show_root_full_path: true 21 | show_root_heading: true 22 | 23 | ::: sklego.meta.grouped_predictor.GroupedClassifier 24 | options: 25 | show_root_full_path: true 26 | show_root_heading: true 27 | 28 | ::: sklego.meta.grouped_predictor.GroupedRegressor 29 | options: 30 | show_root_full_path: true 31 | show_root_heading: true 32 | 33 | ::: sklego.meta.grouped_transformer.GroupedTransformer 34 | options: 35 | show_root_full_path: true 36 | show_root_heading: true 37 | 38 | ::: sklego.meta.ordinal_classification.OrdinalClassifier 39 | options: 40 | show_root_full_path: true 41 | show_root_heading: true 42 | 43 | ::: sklego.meta.outlier_classifier.OutlierClassifier 44 | options: 45 | show_root_full_path: true 46 | show_root_heading: true 47 | 48 | ::: sklego.meta.regression_outlier_detector.RegressionOutlierDetector 49 | options: 50 | show_root_full_path: true 51 | show_root_heading: true 52 | 53 | ::: sklego.meta.subjective_classifier.SubjectiveClassifier 54 | options: 55 | show_root_full_path: true 56 | show_root_heading: true 57 | 58 | ::: sklego.meta.thresholder.Thresholder 59 | options: 60 | show_root_full_path: true 61 | show_root_heading: true 62 | 63 | ::: sklego.meta.zero_inflated_regressor.ZeroInflatedRegressor 64 | options: 65 | show_root_full_path: true 66 | show_root_heading: true 67 | 68 | ::: sklego.meta.hierarchical_predictor.HierarchicalPredictor 69 | options: 70 | show_root_full_path: true 71 | show_root_heading: true 72 | 73 | ::: sklego.meta.hierarchical_predictor.HierarchicalClassifier 74 | options: 75 | show_root_full_path: true 76 | show_root_heading: true 77 | 78 | ::: sklego.meta.hierarchical_predictor.HierarchicalRegressor 79 | options: 80 | show_root_full_path: true 81 | show_root_heading: true 82 | -------------------------------------------------------------------------------- /docs/api/metrics.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | ::: sklego.metrics.correlation_score 4 | options: 5 | show_root_full_path: true 6 | show_root_heading: true 7 | 8 | ::: sklego.metrics.equal_opportunity_score 9 | options: 10 | show_root_full_path: true 11 | show_root_heading: true 12 | 13 | ::: sklego.metrics.p_percent_score 14 | options: 15 | show_root_full_path: true 16 | show_root_heading: true 17 | 18 | ::: sklego.metrics.subset_score 19 | options: 20 | show_root_full_path: true 21 | show_root_heading: true 22 | -------------------------------------------------------------------------------- /docs/api/mixture.md: -------------------------------------------------------------------------------- 1 | # Mixture Models 2 | 3 | :::sklego.mixture.bayesian_gmm_classifier.BayesianGMMClassifier 4 | options: 5 | show_root_full_path: true 6 | show_root_heading: true 7 | 8 | :::sklego.mixture.bayesian_gmm_detector.BayesianGMMOutlierDetector 9 | options: 10 | show_root_full_path: true 11 | show_root_heading: true 12 | 13 | :::sklego.mixture.gmm_classifier.GMMClassifier 14 | options: 15 | show_root_full_path: true 16 | show_root_heading: true 17 | 18 | :::sklego.mixture.gmm_outlier_detector.GMMOutlierDetector 19 | options: 20 | show_root_full_path: true 21 | show_root_heading: true 22 | 
--------------------------------------------------------------------------------
/docs/api/model-selection.md:
--------------------------------------------------------------------------------
1 | # Model Selection
2 | 
3 | :::sklego.model_selection.TimeGapSplit
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | :::sklego.model_selection.GroupTimeSeriesSplit
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
13 | :::sklego.model_selection.ClusterFoldValidation
14 |     options:
15 |         show_root_full_path: true
16 |         show_root_heading: true
17 | 
18 | ## `KlusterFoldValidation`
19 | 
20 | Prior to `version 0.8.2`, the `ClusterFoldValidation` class was named `KlusterFoldValidation`. The old name is deprecated and will be removed in a future release.
21 | 
--------------------------------------------------------------------------------
/docs/api/naive-bayes.md:
--------------------------------------------------------------------------------
1 | # Naive Bayes
2 | 
3 | :::sklego.naive_bayes.GaussianMixtureNB
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | :::sklego.naive_bayes.BayesianGaussianMixtureNB
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
--------------------------------------------------------------------------------
/docs/api/neighbors.md:
--------------------------------------------------------------------------------
1 | # Neighbors
2 | 
3 | :::sklego.neighbors.BayesianKernelDensityClassifier
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
--------------------------------------------------------------------------------
/docs/api/pandas-utils.md:
--------------------------------------------------------------------------------
1 | # Pandas Utils
2 | 
3 | ::: sklego.pandas_utils.add_lags
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | ::: sklego.pandas_utils.log_step
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
13 | ::: sklego.pandas_utils.log_step_extra
14 |     options:
15 |         show_root_full_path: true
16 |         show_root_heading: true
17 | 
--------------------------------------------------------------------------------
/docs/api/pipeline.md:
--------------------------------------------------------------------------------
1 | # Pipeline
2 | 
3 | Pipelines, variants of the `sklearn.pipeline.Pipeline` object.
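For a quick feel of the difference with a vanilla `Pipeline`, here is a hedged sketch (assuming the `log_callback="default"` behavior described by `default_log_callback` below; the exact log output may differ between versions):

```py
import logging

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

from sklego.pipeline import DebugPipeline

logging.basicConfig(level=logging.INFO)

X = np.random.randn(100, 3)
y = np.random.randint(0, 2, size=100)

# With log_callback="default" each intermediate step logs the shape of
# the data it produced and how long the step took.
pipe = DebugPipeline(
    [("scale", StandardScaler()), ("model", LogisticRegression())],
    log_callback="default",
)
_ = pipe.fit(X, y)
```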
4 | 
5 | ::: sklego.pipeline.DebugPipeline
6 |     options:
7 |         show_root_full_path: true
8 |         show_root_heading: true
9 | 
10 | ::: sklego.pipeline.make_debug_pipeline
11 |     options:
12 |         show_root_full_path: true
13 |         show_root_heading: true
14 | 
15 | ::: sklego.pipeline.default_log_callback
16 |     options:
17 |         show_root_full_path: true
18 |         show_root_heading: true
19 | 
--------------------------------------------------------------------------------
/docs/api/preprocessing.md:
--------------------------------------------------------------------------------
1 | # Preprocessing
2 | 
3 | :::sklego.preprocessing.columncapper.ColumnCapper
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | :::sklego.preprocessing.pandastransformers.ColumnDropper
9 |     options:
10 |         show_root_full_path: true
11 |         show_root_heading: true
12 | 
13 | :::sklego.preprocessing.pandastransformers.ColumnSelector
14 |     options:
15 |         show_root_full_path: true
16 |         show_root_heading: true
17 | 
18 | :::sklego.preprocessing.dictmapper.DictMapper
19 |     options:
20 |         show_root_full_path: true
21 |         show_root_heading: true
22 | 
23 | :::sklego.preprocessing.identitytransformer.IdentityTransformer
24 |     options:
25 |         show_root_full_path: true
26 |         show_root_heading: true
27 | 
28 | :::sklego.preprocessing.projections.InformationFilter
29 |     options:
30 |         show_root_full_path: true
31 |         show_root_heading: true
32 | 
33 | :::sklego.preprocessing.intervalencoder.IntervalEncoder
34 |     options:
35 |         show_root_full_path: true
36 |         show_root_heading: true
37 | 
38 | :::sklego.preprocessing.formulaictransformer.FormulaicTransformer
39 |     options:
40 |         show_root_full_path: true
41 |         show_root_heading: true
42 | 
43 | :::sklego.preprocessing.monotonicspline.MonotonicSplineTransformer
44 |     options:
45 |         show_root_full_path: true
46 |         show_root_heading: true
47 | 
48 | :::sklego.preprocessing.projections.OrthogonalTransformer
49 |     options:
50 |         show_root_full_path: true
51 |         show_root_heading: true
52 | 
53 | :::sklego.preprocessing.outlier_remover.OutlierRemover
54 |     options:
55 |         show_root_full_path: true
56 |         show_root_heading: true
57 | 
58 | :::sklego.preprocessing.pandastransformers.PandasTypeSelector
59 |     options:
60 |         show_root_full_path: true
61 |         show_root_heading: true
62 | 
63 | :::sklego.preprocessing.randomadder.RandomAdder
64 |     options:
65 |         show_root_full_path: true
66 |         show_root_heading: true
67 | 
68 | :::sklego.preprocessing.repeatingbasis.RepeatingBasisFunction
69 |     options:
70 |         show_root_full_path: true
71 |         show_root_heading: true
72 | 
73 | :::sklego.preprocessing.pandastransformers.TypeSelector
74 |     options:
75 |         show_root_full_path: true
76 |         show_root_heading: true
77 | 
--------------------------------------------------------------------------------
/docs/api/shrinkage-functions.md:
--------------------------------------------------------------------------------
1 | # Shrinkage
2 | 
3 | ::: sklego.meta._shrinkage_utils.ShrinkageMixin
4 |     options:
5 |         show_root_full_path: true
6 |         show_root_heading: true
7 | 
8 | ## Shrinkage Functions
9 | 
10 | The following functions are the built-in shrinkage functions available in the [`GroupedPredictor`][grouped-predictor-api] and [`HierarchicalPredictor`][hierarchical-predictor-api].
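The common idea is to blend the estimate for a (possibly small) group with estimates from higher levels of the hierarchy. As a hedged sketch of, for example, relative shrinkage (weighting each level by its number of observations) in plain numpy, not the sklego internals:

```py
import numpy as np

# Suppose the global mean was estimated on 1000 rows and the group mean
# on only 20 rows.
estimates = np.array([10.0, 25.0])   # [global mean, group mean]
counts = np.array([1000.0, 20.0])    # observations behind each estimate

weights = counts / counts.sum()      # relative shrinkage: weigh by group size
shrunk = weights @ estimates         # mostly global, nudged towards the group

print(shrunk)  # ~10.3: the small group only pulls the estimate a little
```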
11 | 
12 | ::: sklego.meta._shrinkage_utils.constant_shrinkage
13 |     options:
14 |         show_root_full_path: true
15 |         show_root_heading: true
16 | 
17 | ::: sklego.meta._shrinkage_utils.equal_shrinkage
18 |     options:
19 |         show_root_full_path: true
20 |         show_root_heading: true
21 | 
22 | ::: sklego.meta._shrinkage_utils.min_n_obs_shrinkage
23 |     options:
24 |         show_root_full_path: true
25 |         show_root_heading: true
26 | 
27 | ::: sklego.meta._shrinkage_utils.relative_shrinkage
28 |     options:
29 |         show_root_full_path: true
30 |         show_root_heading: true
31 | 
32 | [grouped-predictor-api]: ../../api/meta#sklego.meta.grouped_predictor.GroupedPredictor
33 | [hierarchical-predictor-api]: ../../api/meta#sklego.meta.hierarchical_predictor.HierarchicalPredictor
34 | 
--------------------------------------------------------------------------------
/docs/generate_this_content.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from pathlib import Path
4 | from typing import Final
5 | 
6 | from sklego.this import poem
7 | 
8 | DESTINATION_PATH: Final[Path] = Path("docs") / "this.md"
9 | 
10 | content = f"""
11 | # Import This
12 | 
13 | In Python there's a poem that you can read by importing the `this` module.
14 | 
15 | ```py
16 | import this
17 | ```
18 | 
19 | It contains wonderful lessons that the authors learned while designing the Python language.
20 | 
21 | In the same tradition we've written a poem of our own. Folks who have made significant contributions have also been asked to
22 | contribute to the poem.
23 | 
24 | You can read it via:
25 | 
26 | ```py
27 | from sklego import this
28 | ```
29 | 
30 | ```console
31 | {poem}
32 | ```
33 | """
34 | 
35 | with DESTINATION_PATH.open(mode="w") as destination:
36 |     destination.write(content)
37 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # scikit-lego
2 | 
3 | ![logo](_static/logo.png)
4 | 
5 | We love scikit-learn, but very often we find ourselves writing custom transformers, metrics and models.
6 | The goal of this project is to consolidate these into a single package that offers proper code quality and testing.
7 | This project is a collaboration between multiple companies in the Netherlands.
8 | Note that we're not formally affiliated with the scikit-learn project at all.
9 | 
10 | ## Disclaimer
11 | 
12 | LEGO® is a trademark of the LEGO Group of companies which does not sponsor, authorize or endorse this project.
13 | Also note that this package, albeit designed to be used on top of scikit-learn, is not associated with that project in any formal manner.
14 | 
15 | The goal of the package is to allow you to joyfully build with new building blocks that are scikit-learn compatible.
16 | 
17 | ## Installation
18 | 
19 | Install `scikit-lego` via pip with
20 | 
21 | ```bash
22 | pip install scikit-lego
23 | ```
24 | 
25 | For more installation options and details, check the [installation section][installation-section].
26 | 
27 | ## Usage
28 | 
29 | ```python
30 | from sklearn.preprocessing import StandardScaler
31 | from sklearn.linear_model import LogisticRegression
32 | from sklearn.pipeline import Pipeline
33 | 
34 | from sklego.preprocessing import RandomAdder
35 | 
36 | X, y = ...
37 | 
38 | mod = Pipeline([
39 |     ("scale", StandardScaler()),
40 |     ("random_noise", RandomAdder()),
41 |     ("model", LogisticRegression(solver='lbfgs'))
42 | ])
43 | 
44 | _ = mod.fit(X, y)
45 | ...
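# Once fitted, `mod` behaves like any scikit-learn estimator, e.g. `mod.predict(X)`.
# Note that `RandomAdder` only adds noise during training (it is a train-time-only
# transformer), so predictions are not perturbed.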
46 | ```
47 | 
48 | To see more examples, please refer to the [user guide section][user-guide].
49 | 
50 | [installation-section]: installation
51 | [user-guide]: user-guide/datasets
52 | 
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | 
3 | !!! warning
4 | 
5 |     This project is experimental and is in alpha. We do our best to keep things stable but you should assume that if
6 |     you do not specify a version number that certain functionality can break.
7 | 
8 | Install **scikit-lego**:
9 | 
10 | === "pip"
11 | 
12 |     ```bash
13 |     python -m pip install scikit-lego
14 |     ```
15 | 
16 | === "conda"
17 | 
18 |     ```bash
19 |     conda install -c conda-forge scikit-lego
20 |     ```
21 | 
22 | === "source/git"
23 | 
24 |     ```bash
25 |     python -m pip install git+https://github.com/koaning/scikit-lego.git
26 |     ```
27 | 
28 | === "local clone"
29 | 
30 |     ```bash
31 |     git clone https://github.com/koaning/scikit-lego.git
32 |     cd scikit-lego
33 |     python -m pip install .
34 |     ```
35 | 
36 | ## Dependency installs
37 | 
38 | Some functionality can only be used if certain dependencies are installed. This can be done by specifying the extra dependencies in square brackets after the package name.
39 | 
40 | Currently supported extras are [**cvxpy**][cvxpy], [**formulaic**][formulaic] and [**umap**][umap]. You can specify these as follows:
41 | 
42 | === "pip"
43 | 
44 |     ```bash
45 |     python -m pip install scikit-lego"[cvxpy]"
46 |     python -m pip install scikit-lego"[formulaic]"
47 |     python -m pip install scikit-lego"[umap]"
48 |     python -m pip install scikit-lego"[all]"
49 |     ```
50 | 
51 | === "local clone"
52 | 
53 |     ```bash
54 |     git clone https://github.com/koaning/scikit-lego.git
55 |     cd scikit-lego
56 | 
57 |     python -m pip install ".[cvxpy]"
58 |     python -m pip install ".[formulaic]"
59 |     python -m pip install ".[umap]"
60 |     python -m pip install ".[all]"
61 |     ```
62 | 
63 | [cvxpy]: https://www.cvxpy.org/
64 | [formulaic]: https://matthewwardrop.github.io/formulaic/
65 | [umap]: https://umap-learn.readthedocs.io/en/latest/index.html
66 | 
--------------------------------------------------------------------------------
/docs/user-guide/feature-selection.md:
--------------------------------------------------------------------------------
1 | # Feature Selection
2 | 
3 | ## Maximum Relevance Minimum Redundancy
4 | 
5 | !!! info "New in version 0.8.0"
6 | 
7 | The [`Maximum Relevance Minimum Redundancy`][MaximumRelevanceMinimumRedundancy-api] (MRMR) is an iterative feature selection method commonly used in data science to select a subset of features from a larger feature set. The goal of MRMR is to choose features that have high *relevance* to the target variable while minimizing *redundancy* among the already selected features.
8 | 
9 | MRMR is heavily dependent on the two functions used to determine relevance and redundancy. However, the paper [Maximum Relevance and Minimum Redundancy Feature Selection Methods for a Marketing Machine Learning Platform](https://arxiv.org/pdf/1908.05376.pdf) shows that using [f_classif](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) or [f_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) as the relevance function and Pearson correlation as the redundancy function is a solid choice for a wide variety of problems.
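To make the iterative procedure explicit, here is a hedged conceptual re-implementation of the greedy MRMR loop (using `f_classif` relevance and mean absolute Pearson correlation as redundancy, combined as a quotient; the actual scikit-lego implementation may differ in its details):

```py
import numpy as np
from sklearn.datasets import make_classification
from sklearn.feature_selection import f_classif

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
k = 5

relevance = f_classif(X, y)[0]          # F-statistic per feature
selected = [int(np.argmax(relevance))]  # start with the most relevant feature

while len(selected) < k:
    candidates = [i for i in range(X.shape[1]) if i not in selected]
    scores = []
    for i in candidates:
        # Redundancy: mean absolute correlation with the features chosen so far.
        redundancy = np.mean(
            [abs(np.corrcoef(X[:, i], X[:, j])[0, 1]) for j in selected]
        )
        scores.append(relevance[i] / (redundancy + 1e-12))
    selected.append(candidates[int(np.argmax(scores))])

print(selected)  # indices of the k selected features
```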
10 | 
11 | Inspired by the Medium article [Feature Selection: How To Throw Away 95% of Your Data and Get 95% Accuracy](https://towardsdatascience.com/feature-selection-how-to-throw-away-95-of-your-data-and-get-95-accuracy-ad41ca016877) we showcase a practical application using the well-known mnist dataset.
12 | 
13 | Note that although the default scikit-lego MRMR implementation uses redundancy and relevance as defined in [Maximum Relevance and Minimum Redundancy Feature Selection Methods for a Marketing Machine Learning Platform](https://arxiv.org/pdf/1908.05376.pdf), our implementation offers the possibility of defining custom functions, which may be necessary in different scenarios depending on the data.
14 | 
15 | We will compare the following list of well-known filter methods:
16 | 
17 | - F statistical test ([ANOVA F-test](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html)).
18 | - Mutual information (approximated via the sklearn implementation).
19 | 
20 | against the default scikit-lego MRMR implementation and a custom MRMR implementation aimed at selecting features that draw a smiling face on the plot showing the mnist digits.
21 | 
22 | 
23 | 
24 | ??? example "MRMR imports"
25 |     ```py
26 |     --8<-- "docs/_scripts/feature-selection.py:mrmr-commonimports"
27 |     ```
28 | 
29 | ```py title="MRMR mnist"
30 | --8<-- "docs/_scripts/feature-selection.py:mrmr-intro"
31 | ```
32 | 
33 | As custom functions, we implemented the smile redundancy and smile relevance.
34 | 
35 | ```py title="MRMR smile functions"
36 | --8<-- "docs/_scripts/feature-selection.py:mrmr-smile"
37 | ```
38 | 
39 | Then we execute the main code part.
40 | 
41 | ```py title="MRMR core"
42 | --8<-- "docs/_scripts/feature-selection.py:mrmr-core"
43 | ```
44 | 
45 | After the execution it is possible to inspect the F1-score for the selected features:
46 | 
47 | ```py title="MRMR mnist selected features"
48 | --8<-- "docs/_scripts/feature-selection.py:mrmr-selected-features"
49 | ```
50 | 
51 | ```console hl_lines="5-6"
52 | Feature selection method: f_classif
53 | F1 score: 0.854
54 | Feature selection method: mutual_info
55 | F1 score: 0.879
56 | Feature selection method: mrmr
57 | F1 score: 0.925
58 | Feature selection method: mrmr_smile
59 | F1 score: 0.849
60 | ```
61 | 
62 | The MRMR feature selection method provides better results than the other methods, although the smile technique performs rather well too.
63 | 
64 | Finally, we can take a look at the selected features.
65 | 
66 | ??? example "MRMR generate plots"
67 |     ```py
68 |     --8<-- "docs/_scripts/feature-selection.py:mrmr-plots"
69 |     ```
70 | 
71 | ![selected-features-mrmr](../_static/feature-selection/mrmr-feature-selection-mnist.png)
72 | 
73 | [MaximumRelevanceMinimumRedundancy-api]: ../../api/feature-selection#sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
74 | 
--------------------------------------------------------------------------------
/docs/user-guide/mixture-methods.md:
--------------------------------------------------------------------------------
1 | # Mixture Methods
2 | 
3 | Gaussian Mixture Models (GMMs) are flexible building blocks for other machine learning algorithms.
4 | 
5 | This is in part because they are great approximations for general probability distributions, but also because they remain somewhat interpretable even when the dataset gets very complex.
6 | 
7 | This package makes use of GMMs to construct other algorithms. In addition to the [GMMClassifier][gmm-classifier-api] and [GMMOutlierDetector][gmm-outlier-detector-api], this library also features a [BayesianGMMClassifier][bayes_gmm-classifier-api] and [BayesianGMMOutlierDetector][bayes_gmm-outlier-detector-api] as well. These methods offer pretty much the same API, but have internal methods to figure out how many components to estimate. These methods tend to take significantly more time to train, so alternatively you may also try a proper grid search to figure out the best number of components for your use-case.
8 | 
9 | ## Classification
10 | 
11 | Below is some example code of how you might use a [GMMClassifier][gmm-classifier-api] from sklego to perform classification.
12 | 
13 | ```py title="GMMClassifier"
14 | --8<-- "docs/_scripts/mixture-methods.py:gmm-classifier"
15 | ```
16 | 
17 | ![gmm-classifier](../_static/mixture-methods/gmm-classifier.png)
18 | 
19 | ## Outlier Detection
20 | 
21 | Below is some example code of how you might use a GMM from sklego to do outlier detection.
22 | 
23 | Note that the [GMMOutlierDetector][gmm-outlier-detector-api] generates prediction values that are either -1 (outlier) or +1 (normal).
24 | 
25 | ```py title="GMMOutlierDetector"
26 | --8<-- "docs/_scripts/mixture-methods.py:gmm-outlier-detector"
27 | ```
28 | 
29 | ![gmm-outlier-detector](../_static/mixture-methods/gmm-outlier-detector.png)
30 | 
31 | Remark that with a GMM there are multiple ways to select outliers. Instead of selecting points that fall beyond a likelihood quantile threshold, one can also specify how many standard deviations below the mean likelihood a given point may be.
32 | 
33 | ??? example "Different thresholds"
34 |     ```py
35 |     --8<-- "docs/_scripts/mixture-methods.py:gmm-outlier-multi-threshold"
36 |     ```
37 | 
38 | ![gmm-outlier-multi-threshold](../_static/mixture-methods/gmm-outlier-multi-threshold.png)
39 | 
40 | ### Detection Details
41 | 
42 | The outlier detection methods that we use are based on the likelihoods that come out of the estimated Gaussian Mixture.
43 | 
44 | Depending on the setting you choose we have a different method for determining if a point is inside or outside the
45 | threshold.
46 | 
47 | 1. If the `"quantile"` method is used, we take all the likelihood scores that the GMM assigns to the training dataset to determine where to set a threshold. The threshold value must be between 0 and 1 here.
48 | 2. If the `"stddev"` method is used, then the threshold value is interpreted as the number of standard deviations below the mean likelihood. We only calculate the standard deviation on the lower scores because there's usually more variance here.
49 |     !!! note
50 |         This setting allows you to be more picky than the `"quantile"` method, since it lets you be more exclusive than `"quantile"` with a threshold equal to one.
51 | 
52 | ![outlier-mixture-threshold](../_static/mixture-methods/outlier-mixture-threshold.png)
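To make the two settings concrete, here is a hedged sketch of how such cutoffs can be computed from the likelihood scores with plain scikit-learn. It is one plausible reading of the two methods, not the exact sklego internals:

```py
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))

gmm = GaussianMixture(n_components=2, random_state=0).fit(X)
scores = gmm.score_samples(X)  # log-likelihood per training point

# "quantile": with threshold=0.99, flag the lowest 1% of likelihoods.
quantile_cutoff = np.quantile(scores, 1 - 0.99)

# "stddev": flag points more than `threshold` standard deviations below the
# mean likelihood (sklego computes the spread on the lower scores only).
stddev_cutoff = scores.mean() - 3 * scores.std()

outliers_quantile = scores < quantile_cutoff
outliers_stddev = scores < stddev_cutoff
```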
example "Code for plot generation" 57 | ```py 58 | --8<-- "docs/_scripts/mixture-methods.py:outlier-mixture-threshold" 59 | ``` 60 | 61 | [gmm-classifier-api]: ../../api/mixture#sklego.mixture.gmm_classifier.GMMClassifier 62 | [bayes_gmm-classifier-api]: ../../api/mixture#sklego.mixture.bayesian_gmm_classifier.BayesianGMMClassifier 63 | [gmm-outlier-detector-api]: ../../api/mixture#sklego.mixture.gmm_outlier_detector.GMMOutlierDetector 64 | [bayes_gmm-outlier-detector-api]: ../../api/mixture#sklego.mixture.gmm_outlier_detector.BayesianGMMOutlierDetector 65 | -------------------------------------------------------------------------------- /docs/user-guide/naive-bayes.md: -------------------------------------------------------------------------------- 1 | # Naive Bayes 2 | 3 | Naive Bayes models are flexible and interpretable. In scikit-lego we've added support for a Gaussian Mixture variant of the algorithm. 4 | 5 | ![naive bayes sketch](../_static/naive-bayes/naive-bayes.png) 6 | 7 | An example of the usage of algorithm can be found below. 8 | 9 | ## Example 10 | 11 | Let's first import the dependencies and create some data. This code will create a plot of the dataset we'll try to predict. 12 | 13 | ```py title="Simulated dataset" 14 | --8<-- "docs/_scripts/naive-bayes.py:simulated-data" 15 | ``` 16 | 17 | ![simulated-data](../_static/naive-bayes/simulated-data.png) 18 | 19 | Note that this dataset would be hard to classify directly if we would be using a standard Gaussian Naive Bayes algorithm since the orange class is multipeaked over two clusters. 20 | 21 | To demonstrate this we'll run our [`GaussianMixtureNB`][gaussian-mix-nb-api] algorithm with one or two gaussians that the mixture is allowed to find. 22 | 23 | ```py title="GaussianMixtureNB model" 24 | --8<-- "docs/_scripts/naive-bayes.py:model-results" 25 | ``` 26 | 27 | ![results](../_static/naive-bayes/model-results.png) 28 | 29 | Note that the second plot fits the original much better. 30 | 31 | We can even zoom in on this second algorithm by having it sample what it believes is the distribution on each column. 32 | 33 | ??? 
example "Model density" 34 | ```py title="Model density" 35 | --8<-- "docs/_scripts/naive-bayes.py:model-density" 36 | ``` 37 | 38 | ![density](../_static/naive-bayes/model-density.png) 39 | 40 | [gaussian-mix-nb-api]: ../../api/naive-bayes#sklego.naive_bayes.GaussianMixtureNB 41 | -------------------------------------------------------------------------------- /features.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from sklego import meta, pipeline, pandas_utils, dummy, linear_model, mixture, \ 3 | naive_bayes, datasets, model_selection, preprocessing, metrics 4 | 5 | 6 | def not_in(thing, *substrings): 7 | for string in substrings: 8 | if string in thing: 9 | return False 10 | return True 11 | 12 | 13 | def print_classes(submodule): 14 | for cls in dir(submodule): 15 | if inspect.isclass(getattr(submodule, cls)): 16 | if not_in(cls, 'Mixin', 'Base'): 17 | if (cls[0].upper() == cls[0]) and (cls[0] != '_'): 18 | print(f"{submodule.__name__}.{cls}") 19 | 20 | 21 | def print_functions(submodule): 22 | for cls in dir(submodule): 23 | if inspect.isfunction(getattr(submodule, cls)): 24 | if cls[0] != '_': 25 | print(f"{submodule.__name__}.{cls}") 26 | 27 | 28 | if __name__ == "__main__": 29 | print_functions(datasets) 30 | print_functions(pandas_utils) 31 | print_classes(dummy) 32 | print_classes(linear_model) 33 | print_classes(naive_bayes) 34 | print_classes(mixture) 35 | print_classes(meta) 36 | print_classes(preprocessing) 37 | print_classes(model_selection) 38 | print_classes(pipeline) 39 | print_functions(pipeline) 40 | print_functions(metrics) 41 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/images/logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "scikit-lego" 7 | version = "0.9.5" 8 | description="A collection of lego bricks for scikit-learn pipelines" 9 | 10 | license = {file = "LICENSE"} 11 | readme = "readme.md" 12 | requires-python = ">=3.6" 13 | authors = [ 14 | {name = "Vincent D. 
Warmerdam"}, 15 | {name = "Matthijs Brouns"}, 16 | ] 17 | 18 | maintainers = [ 19 | {name = "Francesco Bruzzesi"} 20 | ] 21 | 22 | dependencies = [ 23 | "narwhals>=1.5.0", 24 | "pandas>=1.1.5", 25 | "scikit-learn>=1.0", 26 | "sklearn-compat>=0.1.3", 27 | "importlib-metadata >= 1.0; python_version < '3.8'", 28 | "importlib-resources; python_version < '3.9'", 29 | ] 30 | 31 | classifiers = [ 32 | "Programming Language :: Python :: 3", 33 | "Programming Language :: Python :: 3.8", 34 | "Programming Language :: Python :: 3.9", 35 | "Programming Language :: Python :: 3.10", 36 | "Programming Language :: Python :: 3.11", 37 | "Programming Language :: Python :: 3.12", 38 | "License :: OSI Approved :: MIT License", 39 | "Topic :: Scientific/Engineering", 40 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 41 | "Topic :: Software Development :: Libraries :: Python Modules", 42 | ] 43 | 44 | [project.urls] 45 | repository = "https://github.com/koaning/scikit-lego" 46 | issue-tracker = "https://github.com/koaning/scikit-lego/issues" 47 | documentation = "https://koaning.github.io/scikit-lego/" 48 | 49 | [project.optional-dependencies] 50 | cvxpy = ["cmake", "osqp", "cvxpy>=1.1.8", "numpy<2.0"] 51 | formulaic = ["formulaic>=0.6.0"] 52 | umap = ["umap-learn>=0.4.6", "numpy<2.0"] 53 | 54 | all = ["scikit-lego[cvxpy,formulaic,umap]"] 55 | 56 | docs = [ 57 | "mkdocs>=1.5.3", 58 | "mkdocs-autorefs>=0.5.0", 59 | "mkdocs-material>=9.4.5", 60 | "mkdocs-material-extensions>=1.2", 61 | "mkdocstrings>=0.23.0", 62 | "mkdocstrings-python>=1.7.3", 63 | ] 64 | 65 | test = [ 66 | "narwhals[polars,pyarrow]", 67 | "pytest>=6.2.5", 68 | "pytest-xdist>=1.34.0", 69 | "pytest-cov>=2.6.1", 70 | "pytest-mock>=1.6.3", 71 | ] 72 | 73 | test-all = [ 74 | "scikit-lego[all,test]", 75 | ] 76 | 77 | utils = [ 78 | "matplotlib>=3.0.2", 79 | "jupyter>=1.0.0", 80 | "jupyterlab>=0.35.4", 81 | ] 82 | 83 | dev = [ 84 | "scikit-lego[all,test,docs]", 85 | "pre-commit>=1.18.3", 86 | "ruff>=0.1.6", 87 | ] 88 | 89 | [tool.setuptools.packages.find] 90 | include = ["sklego*"] 91 | exclude = [ 92 | "docs", 93 | "images", 94 | "notebooks", 95 | "tests", 96 | ] 97 | 98 | [tool.setuptools.package-data] 99 | sklego = ["data/*.zip"] 100 | 101 | [tool.ruff] 102 | line-length = 120 103 | exclude = ["docs"] 104 | 105 | [tool.ruff.lint] 106 | extend-select = ["I", "T201"] 107 | ignore = [ 108 | "E731", # do not assign a `lambda` expression, use a `def` 109 | ] 110 | 111 | [tool.pytest.ini_options] 112 | markers = [ 113 | "cvxpy: tests that require cvxpy (deselect with '-m \"not cvxpy\"')", 114 | "formulaic: tests that require formulaic (deselect with '-m \"not formulaic\"')", 115 | "umap: tests that require umap (deselect with '-m \"not umap\"')" 116 | ] 117 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.5.3 2 | mkdocs-autorefs==0.5.0 3 | mkdocs-material==9.4.5 4 | mkdocs-material-extensions==1.2 5 | mkdocstrings==0.23.0 6 | mkdocstrings-python==1.7.3 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | setup.py is used to distribute scikit-lego as a package using setuptools and twine 4 | """ 5 | 6 | from setuptools import setup 7 | 8 | if __name__ == "__main__": 9 | setup() 10 | 
-------------------------------------------------------------------------------- /sklego/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info >= (3, 8): 4 | from importlib import metadata 5 | else: 6 | import importlib_metadata as metadata 7 | 8 | __title__ = "sklego" 9 | __version__ = metadata.version("scikit-lego") 10 | -------------------------------------------------------------------------------- /sklego/base.py: -------------------------------------------------------------------------------- 1 | from sklearn.base import OutlierMixin 2 | 3 | 4 | class ProbabilisticClassifierMeta(type): 5 | """Metaclass for `ProbabilisticClassifier`. 6 | 7 | This metaclass is responsible for checking whether a class can be considered a `ProbabilisticClassifier`. 8 | A class is considered a `ProbabilisticClassifier` if it has a "predict_proba" method. 9 | """ 10 | 11 | def __instancecheck__(self, other): 12 | """Checks if the provided object is a `ProbabilisticClassifier`. 13 | 14 | Parameters 15 | ---------- 16 | self : ProbabilisticClassifierMeta 17 | `ProbabilisticClassifierMeta` class. 18 | other : object 19 | The object to check for `ProbabilisticClassifier` compatibility. 20 | 21 | Returns 22 | ------- 23 | bool 24 | True if the object is a `ProbabilisticClassifier` (has a "predict_proba" method), False otherwise. 25 | """ 26 | return hasattr(other, "predict_proba") 27 | 28 | 29 | class ProbabilisticClassifier(metaclass=ProbabilisticClassifierMeta): 30 | """Base class for `ProbabilisticClassifier`. 31 | 32 | This base class defines the `ProbabilisticClassifier` interface, indicating that subclasses should have a 33 | "predict_proba" method. 34 | """ 35 | 36 | pass 37 | 38 | 39 | class ClustererMeta(type): 40 | """Metaclass for `Clusterer`. 41 | 42 | This metaclass is responsible for checking whether a class can be considered a `Clusterer`. 43 | A class is considered a `Clusterer` if it has a "fit_predict" method. 44 | """ 45 | 46 | def __instancecheck__(self, other): 47 | """Checks if the provided object is a `Clusterer`. 48 | 49 | Parameters 50 | ---------- 51 | self : ClustererMeta 52 | `ClustererMeta` class. 53 | other : object 54 | The object to check for `Clusterer` compatibility. 55 | 56 | Returns 57 | ------- 58 | bool 59 | True if the object is a `Clusterer` (has a "fit_predict" method), False otherwise. 60 | """ 61 | return hasattr(other, "fit_predict") 62 | 63 | 64 | class Clusterer(metaclass=ClustererMeta): 65 | """Base class for `Clusterer`. 66 | 67 | This base class defines the `Clusterer` interface, indicating that subclasses should have a "fit_predict" method. 68 | """ 69 | 70 | pass 71 | 72 | 73 | class OutlierModelMeta(type): 74 | """Metaclass for `OutlierModel`. 75 | 76 | This metaclass is responsible for checking whether a class can be considered an `OutlierModel`. 77 | A class is considered an `OutlierModel` if it is an instance of the `sklearn.base.OutlierMixin` class. 78 | """ 79 | 80 | def __instancecheck__(self, other): 81 | """ 82 | Checks if the provided object is an `OutlierModel`. 83 | 84 | Parameters 85 | ---------- 86 | self : OutlierModelMeta 87 | The `OutlierModelMeta` class. 88 | other : object 89 | The object to check for `OutlierModel` compatibility. 90 | 91 | Returns 92 | ------- 93 | bool 94 | True if the object is an `OutlierModel` (an instance of "OutlierMixin"), False otherwise.
95 | """ 96 | return isinstance(other, OutlierMixin) 97 | 98 | 99 | class OutlierModel(metaclass=OutlierModelMeta): 100 | """Base class for `OutlierModel`. 101 | 102 | This base class defines the `OutlierModel` interface, indicating that subclasses should be instances of the 103 | "OutlierMixin" class. 104 | """ 105 | 106 | pass 107 | -------------------------------------------------------------------------------- /sklego/data/abalone.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/abalone.zip -------------------------------------------------------------------------------- /sklego/data/arrests.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/arrests.zip -------------------------------------------------------------------------------- /sklego/data/chickweight.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/chickweight.zip -------------------------------------------------------------------------------- /sklego/data/hearts.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/hearts.zip -------------------------------------------------------------------------------- /sklego/data/heroes.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/heroes.zip -------------------------------------------------------------------------------- /sklego/data/penguins.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/sklego/data/penguins.zip -------------------------------------------------------------------------------- /sklego/decomposition/__init__.py: -------------------------------------------------------------------------------- 1 | from sklego.decomposition.pca_reconstruction import PCAOutlierDetection 2 | from sklego.decomposition.umap_reconstruction import UMAPOutlierDetection 3 | 4 | __all__ = ["PCAOutlierDetection", "UMAPOutlierDetection"] 5 | -------------------------------------------------------------------------------- /sklego/feature_selection/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "MaximumRelevanceMinimumRedundancy", 3 | ] 4 | 5 | from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy 6 | -------------------------------------------------------------------------------- /sklego/meta/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ConfusionBalancer", 3 | "DecayEstimator", 4 | "EstimatorTransformer", 5 | "GroupedClassifier", 6 | "GroupedPredictor", 7 | "GroupedRegressor", 8 | "GroupedTransformer", 9 | "HierarchicalClassifier", 10 | "HierarchicalPredictor", 11 | "HierarchicalRegressor", 12 | "OrdinalClassifier", 13 | "SubjectiveClassifier", 14 | "Thresholder", 15 | "RegressionOutlierDetector", 16 | 
"OutlierClassifier", 17 | "ZeroInflatedRegressor", 18 | ] 19 | from sklego.meta.confusion_balancer import ConfusionBalancer 20 | from sklego.meta.decay_estimator import DecayEstimator 21 | from sklego.meta.estimator_transformer import EstimatorTransformer 22 | from sklego.meta.grouped_predictor import GroupedClassifier, GroupedPredictor, GroupedRegressor 23 | from sklego.meta.grouped_transformer import GroupedTransformer 24 | from sklego.meta.hierarchical_predictor import HierarchicalClassifier, HierarchicalPredictor, HierarchicalRegressor 25 | from sklego.meta.ordinal_classification import OrdinalClassifier 26 | from sklego.meta.outlier_classifier import OutlierClassifier 27 | from sklego.meta.regression_outlier_detector import RegressionOutlierDetector 28 | from sklego.meta.subjective_classifier import SubjectiveClassifier 29 | from sklego.meta.thresholder import Thresholder 30 | from sklego.meta.zero_inflated_regressor import ZeroInflatedRegressor 31 | -------------------------------------------------------------------------------- /sklego/meta/_grouped_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | 5 | import narwhals.stable.v1 as nw 6 | import pandas as pd 7 | from scipy.sparse import issparse 8 | from sklearn.utils.validation import _ensure_no_complex_data 9 | from sklearn_compat.utils.validation import check_array 10 | 11 | 12 | def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame: 13 | """Converts X, y to narwhals dataframe. 14 | 15 | If it is not a supported dataframe, it uses pandas constructor as a fallback. 16 | 17 | Additionally, data checks are performed. 18 | """ 19 | # Check raw X 20 | _data_format_checks(X) 21 | 22 | # Convert X to Narwhals frame 23 | X = nw.from_native(X, strict=False, eager_only=True) 24 | if not isinstance(X, nw.DataFrame): 25 | X = nw.from_native(pd.DataFrame(X)) 26 | 27 | # Check groups and feaures values 28 | if groups: 29 | _validate_groups_values(X, groups) 30 | 31 | if check_X: 32 | check_array(X.drop(groups), **kwargs) 33 | 34 | # Convert y and assign it to the frame 35 | n_samples = X.shape[0] 36 | y_series = nw.new_series( 37 | name="tmp", values=[None] * n_samples if y is None else y, native_namespace=nw.get_native_namespace(X) 38 | ) 39 | 40 | if len(y_series) != n_samples: 41 | msg = f"Found input variables with inconsistent numbers of samples: {[n_samples, len(y_series)]}" 42 | raise ValueError(msg) 43 | 44 | return X.with_columns(__sklego_target__=y_series) 45 | 46 | 47 | def _validate_groups_values(X: nw.DataFrame, groups: List[int] | List[str]) -> None: 48 | X_cols = X.columns 49 | unexisting_cols = [g for g in groups if g not in X_cols] 50 | 51 | if len(unexisting_cols): 52 | raise ValueError(f"The following groups are not available in X: {unexisting_cols}") 53 | 54 | if X.select(nw.col(groups).is_null().any()).to_numpy().squeeze().any(): 55 | raise ValueError("Groups values have NaN") 56 | 57 | 58 | def _data_format_checks(X): 59 | """Checks that X is not sparse nor has complex dtype""" 60 | _ensure_no_complex_data(X) 61 | 62 | if issparse(X): # sklearn.validation._ensure_sparse_format to complicated 63 | msg = "Estimator does not work on sparse matrices" 64 | raise ValueError(msg) 65 | -------------------------------------------------------------------------------- /sklego/mixture/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = 
["GMMClassifier", "BayesianGMMClassifier", "GMMOutlierDetector", "BayesianGMMOutlierDetector"] 2 | 3 | from sklego.mixture.bayesian_gmm_classifier import BayesianGMMClassifier 4 | from sklego.mixture.bayesian_gmm_detector import BayesianGMMOutlierDetector 5 | from sklego.mixture.gmm_classifier import GMMClassifier 6 | from sklego.mixture.gmm_outlier_detector import GMMOutlierDetector 7 | -------------------------------------------------------------------------------- /sklego/notinstalled.py: -------------------------------------------------------------------------------- 1 | KNOWN_PACKAGES = { 2 | "cvxpy": {"version": ">=1.0.24", "extra_name": "cvxpy"}, 3 | "umap-learn": {"version": ">=0.4.6", "extra_name": "umap"}, 4 | "formulaic": {"version": ">=0.6.0", "extra_name": "formulaic"}, 5 | } 6 | 7 | 8 | class NotInstalledPackage: 9 | """Class to gracefully catch `ImportError`s for modules and packages that are not installed. 10 | 11 | Parameters 12 | ---------- 13 | package_name : str 14 | Name of the package you want to load 15 | version : str | None, default=None 16 | Version of the package 17 | 18 | Examples 19 | -------- 20 | ```py 21 | try: 22 | import thispackagedoesnotexist as package 23 | except ImportError: 24 | from sklego.notinstalled import NotInstalledPackage 25 | package = NotInstalledPackage("thispackagedoesnotexist") 26 | ``` 27 | """ 28 | 29 | def __init__(self, package_name: str, version: str = None): 30 | self.package_name = package_name 31 | package_info = KNOWN_PACKAGES.get(package_name, {}) 32 | self.version = version if version else package_info.get("version", "") 33 | 34 | extra_name = package_info.get("extra_name", None) 35 | self.pip_message = ( 36 | ( 37 | f"Install extra requirement {package_name} using " 38 | + f"`python -m pip install scikit-lego[{extra_name}]` or " 39 | + "`python -m pip install scikit-lego[all]`. " 40 | + "For more information, check the 'Dependency installs' section of the installation docs at " 41 | + "https://scikit-lego.netlify.app/install" 42 | ) 43 | if extra_name 44 | else "" 45 | ) 46 | 47 | def __getattr__(self, name): 48 | raise ImportError(f"The package {self.package_name}{self.version} is not installed. 
" + self.pip_message) 49 | -------------------------------------------------------------------------------- /sklego/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "ColumnCapper", 3 | "ColumnDropper", 4 | "ColumnSelector", 5 | "DictMapper", 6 | "FormulaicTransformer", 7 | "IdentityTransformer", 8 | "InformationFilter", 9 | "IntervalEncoder", 10 | "OrthogonalTransformer", 11 | "OutlierRemover", 12 | "PandasTypeSelector", 13 | "TypeSelector", 14 | "RandomAdder", 15 | "RepeatingBasisFunction", 16 | "MonotonicSplineTransformer", 17 | ] 18 | 19 | from sklego.preprocessing.columncapper import ColumnCapper 20 | from sklego.preprocessing.dictmapper import DictMapper 21 | from sklego.preprocessing.formulaictransformer import FormulaicTransformer 22 | from sklego.preprocessing.identitytransformer import IdentityTransformer 23 | from sklego.preprocessing.intervalencoder import IntervalEncoder 24 | from sklego.preprocessing.monotonicspline import MonotonicSplineTransformer 25 | from sklego.preprocessing.outlier_remover import OutlierRemover 26 | from sklego.preprocessing.pandastransformers import ColumnDropper, ColumnSelector, PandasTypeSelector, TypeSelector 27 | from sklego.preprocessing.projections import InformationFilter, OrthogonalTransformer 28 | from sklego.preprocessing.randomadder import RandomAdder 29 | from sklego.preprocessing.repeatingbasis import RepeatingBasisFunction 30 | -------------------------------------------------------------------------------- /sklego/preprocessing/dictmapper.py: -------------------------------------------------------------------------------- 1 | from warnings import warn 2 | 3 | import numpy as np 4 | from sklearn.base import BaseEstimator, TransformerMixin 5 | from sklearn.utils.validation import check_is_fitted 6 | from sklearn_compat.utils.validation import validate_data 7 | 8 | 9 | class DictMapper(TransformerMixin, BaseEstimator): 10 | """The `DictMapper` transformer maps the values of columns according to the input `mapper` dictionary, fall back to 11 | the `default` value if the key is not present in the dictionary. 12 | 13 | Parameters 14 | ---------- 15 | mapper : dict[..., int] 16 | The dictionary containing the mapping of the values. 17 | default : int 18 | The value to fall back to if the value is not in the mapper. 19 | 20 | Attributes 21 | ---------- 22 | n_features_in_ : int 23 | Number of features seen during `fit`. 24 | dim_ : int 25 | Deprecated, please use `n_features_in_` instead. 26 | 27 | Examples 28 | -------- 29 | ```py 30 | import pandas as pd 31 | from sklego.preprocessing.dictmapper import DictMapper 32 | from sklearn.compose import ColumnTransformer 33 | 34 | X = pd.DataFrame({ 35 | "city_pop": ["Amsterdam", "Leiden", "Utrecht", "None", "Haarlem"] 36 | }) 37 | 38 | mapper = { 39 | "Amsterdam": 1_181_817, 40 | "Leiden": 130_181, 41 | "Utrecht": 367_984, 42 | "Haarlem": 165_396, 43 | } 44 | 45 | ct = ColumnTransformer([("dictmapper", DictMapper(mapper, 0), ["city_pop"])]) 46 | X_trans = ct.fit_transform(X) 47 | X_trans 48 | # array([[1181817], 49 | # [ 130181], 50 | # [ 367984], 51 | # [ 0], 52 | # [ 165396]]) 53 | ``` 54 | """ 55 | 56 | _required_parameters = ["mapper", "default"] 57 | 58 | def __init__(self, mapper, default): 59 | self.mapper = mapper 60 | self.default = default 61 | 62 | def fit(self, X, y=None): 63 | """Checks the input data and records the number of features. 
64 | 65 | Parameters 66 | ---------- 67 | X : array-like of shape (n_samples, n_features) 68 | The data to fit. 69 | y : array-like of shape (n_samples,), default=None 70 | Ignored, present for compatibility. 71 | 72 | Returns 73 | ------- 74 | self : DictMapper 75 | The fitted transformer. 76 | """ 77 | X = validate_data(self, X=X, copy=True, dtype=None, ensure_2d=True, ensure_all_finite=False, reset=True) 78 | return self 79 | 80 | def transform(self, X): 81 | """Performs the mapping on the column(s) of `X`. 82 | 83 | Parameters 84 | ---------- 85 | X : array-like of shape (n_samples, n_features) 86 | The data for which the mapping will be applied. 87 | 88 | Returns 89 | ------- 90 | np.ndarray of shape (n_samples, n_features) 91 | The data with the mapping applied. 92 | 93 | Raises 94 | ------ 95 | ValueError 96 | If the number of columns from `X` differs from the number of columns when fitting. 97 | """ 98 | check_is_fitted(self, ["n_features_in_"]) 99 | X = validate_data(self, X=X, copy=True, dtype=None, ensure_2d=True, ensure_all_finite=False, reset=False) 100 | return np.vectorize(self.mapper.get, otypes=[int])(X, self.default) 101 | 102 | @property 103 | def dim_(self): 104 | warn( 105 | "Please use `n_features_in_` instead of `dim_`, `dim_` will be deprecated in future versions", 106 | DeprecationWarning, 107 | ) 108 | return self.n_features_in_ 109 | 110 | def _more_tags(self): 111 | return {"preserves_dtype": None, "allow_nan": True, "no_validation": True} 112 | 113 | def __sklearn_tags__(self): 114 | tags = super().__sklearn_tags__() 115 | tags.transformer_tags.preserves_dtype = [] 116 | tags.input_tags.allow_nan = True 117 | tags.no_validation = True 118 | return tags 119 | -------------------------------------------------------------------------------- /sklego/preprocessing/identitytransformer.py: -------------------------------------------------------------------------------- 1 | from sklearn.base import BaseEstimator, TransformerMixin 2 | from sklearn.utils.validation import check_is_fitted 3 | from sklearn_compat.utils.validation import _check_n_features, validate_data 4 | 5 | 6 | class IdentityTransformer(TransformerMixin, BaseEstimator): 7 | """The `IdentityTransformer` returns what it is fed. Does not apply any transformation. 8 | 9 | The reason for having it is that it allows you to build more expressive pipelines. 10 | 11 | Parameters 12 | ---------- 13 | check_X : bool, default=False 14 | Whether to validate `X` as a non-empty 2D array of finite values and attempt to cast `X` to float. 15 | If disabled, the model/pipeline is expected to handle e.g. missing, non-numeric, or non-finite values. 16 | 17 | Attributes 18 | ---------- 19 | n_samples_ : int 20 | The number of samples seen during `fit`. 21 | n_features_in_ : int 22 | The number of features seen during `fit`. 23 | shape_ : tuple[int, int] 24 | Deprecated, please use `n_samples_` and `n_features_in_` instead.
25 | 26 | Examples 27 | -------- 28 | ```py 29 | import pandas as pd 30 | from sklego.preprocessing import IdentityTransformer 31 | 32 | df = pd.DataFrame({ 33 | "name": ["Swen", "Victor", "Alex"], 34 | "length": [1.82, 1.85, 1.80], 35 | "shoesize": [42, 44, 45] 36 | }) 37 | 38 | IdentityTransformer().fit_transform(df) 39 | #      name  length  shoesize 40 | # 0    Swen    1.82        42 41 | # 1  Victor    1.85        44 42 | # 2    Alex    1.80        45 43 | 44 | # using check_X=True to validate `X` as a non-empty 2D array of finite values and attempt to cast `X` to float 45 | IdentityTransformer(check_X=True).fit_transform(df.drop(columns="name")) 46 | # array([[ 1.82, 42.  ], 47 | #        [ 1.85, 44.  ], 48 | #        [ 1.8 , 45.  ]]) 49 | ``` 50 | """ 51 | 52 | def __init__(self, check_X: bool = False): 53 | self.check_X = check_X 54 | 55 | def fit(self, X, y=None): 56 | """Checks the input data if `check_X` is enabled and records its shape. 57 | 58 | Parameters 59 | ---------- 60 | X : array-like of shape (n_samples, n_features) 61 | The data to fit. 62 | y : array-like of shape (n_samples,), default=None 63 | Ignored, present for compatibility. 64 | 65 | Returns 66 | ------- 67 | self : IdentityTransformer 68 | The fitted transformer. 69 | """ 70 | if self.check_X: 71 | X = validate_data(self, X=X, copy=True, reset=True) 72 | else: 73 | _check_n_features(self, X, reset=True) 74 | self.n_samples_, self.n_features_in_ = X.shape 75 | return self 76 | 77 | def transform(self, X): 78 | """Performs identity "transformation" on `X` - which is no transformation at all. 79 | 80 | Parameters 81 | ---------- 82 | X : array-like of shape (n_samples, n_features) 83 | Input data. 84 | 85 | Returns 86 | ------- 87 | array-like of shape (n_samples, n_features) 88 | Unchanged input data. 89 | 90 | Raises 91 | ------ 92 | ValueError 93 | If the number of columns from `X` differs from the number of columns when fitting. 94 | """ 95 | check_is_fitted(self, "n_features_in_") 96 | 97 | if self.check_X: 98 | X = validate_data(self, X=X, copy=True, reset=False) 99 | else: 100 | _check_n_features(self, X, reset=False) 101 | return X 102 | 103 | @property 104 | def shape_(self): 105 | """Returns the shape of the data seen during `fit`.""" 106 | return (self.n_samples_, self.n_features_in_) 107 | -------------------------------------------------------------------------------- /sklego/preprocessing/monotonicspline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.base import BaseEstimator, TransformerMixin 3 | from sklearn.preprocessing import SplineTransformer 4 | from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted 5 | from sklearn_compat.utils.validation import validate_data 6 | 7 | 8 | class MonotonicSplineTransformer(TransformerMixin, BaseEstimator): 9 | """The `MonotonicSplineTransformer` integrates the output of the `SplineTransformer` in an attempt to make monotonic features. 10 | 11 | This estimator is heavily inspired by [this blogpost](https://matekadlicsko.github.io/posts/monotonic-splines/) by Mate Kadlicsko. 12 | 13 | Parameters 14 | ---------- 15 | n_knots : int, default=3 16 | The number of knots to use in the spline transformation.
17 | degree : int, default=3 18 | The polynomial degree to use in the spline transformation. 19 | knots : Literal['uniform', 'quantile'], default="uniform" 20 | The `knots` argument passed to the underlying `SplineTransformer`. 21 | 22 | Attributes 23 | ---------- 24 | spline_transformer_ : dict[int, SplineTransformer] 25 | n_features_in_ : int 26 | The number of features seen in the training data. 27 | 28 | """ 29 | 30 | def __init__(self, n_knots=3, degree=3, knots="uniform"): 31 | self.n_knots = n_knots 32 | self.degree = degree 33 | self.knots = knots 34 | 35 | def fit(self, X, y=None): 36 | """Fit the `MonotonicSplineTransformer` transformer by computing the spline transformation of `X`. 37 | 38 | Parameters 39 | ---------- 40 | X : array-like of shape (n_samples, n_features) 41 | The data to transform. 42 | y : array-like of shape (n_samples,), default=None 43 | Ignored, present for compatibility. 44 | 45 | Returns 46 | ------- 47 | self : MonotonicSplineTransformer 48 | The fitted transformer. 49 | 50 | Raises 51 | ------ 52 | ValueError 53 | If `X` contains non-numeric columns. 54 | """ 55 | X = validate_data(self, X=X, copy=True, ensure_all_finite=False, dtype=FLOAT_DTYPES, reset=True) 56 | # If X contains infs, we need to replace them by nans before computing quantiles 57 | self.spline_transformer_ = { 58 | col: SplineTransformer(n_knots=self.n_knots, degree=self.degree, knots=self.knots).fit( 59 | X[:, col].reshape(-1, 1) 60 | ) 61 | for col in range(X.shape[1]) 62 | } 63 | return self 64 | 65 | def transform(self, X): 66 | """Performs the I-spline transformation on `X`. 67 | 68 | Parameters 69 | ---------- 70 | X : array-like of shape (n_samples, n_features) 71 | 72 | Returns 73 | ------- 74 | X : np.ndarray of shape (n_samples, n_out) 75 | Transformed `X` values. 76 | 77 | Raises 78 | ------ 79 | ValueError 80 | If the number of columns from `X` differs from the number of columns when fitting. 81 | """ 82 | check_is_fitted(self, "spline_transformer_") 83 | X = validate_data(self, X=X, ensure_all_finite=False, dtype=FLOAT_DTYPES, reset=False) 84 | 85 | out = [] 86 | for col in range(X.shape[1]): 87 | out.append( 88 | np.cumsum( 89 | self.spline_transformer_[col].transform(X[:, [col]])[:, ::-1], 90 | axis=1, 91 | ) 92 | ) 93 | return np.concatenate(out, axis=1) 94 | -------------------------------------------------------------------------------- /sklego/preprocessing/outlier_remover.py: -------------------------------------------------------------------------------- 1 | from sklearn import clone 2 | from sklearn.base import BaseEstimator 3 | from sklearn.utils.validation import check_is_fitted 4 | from sklearn_compat.utils.validation import _check_n_features, check_array 5 | 6 | from sklego.common import TrainOnlyTransformerMixin 7 | 8 | 9 | class OutlierRemover(TrainOnlyTransformerMixin, BaseEstimator): 10 | """The `OutlierRemover` transformer removes outliers (train-time only) using the supplied removal model. The 11 | removal model should implement `.fit()` and `.predict()` methods. 12 | 13 | Parameters 14 | ---------- 15 | outlier_detector : scikit-learn compatible estimator 16 | An outlier detector that implements `.fit()` and `.predict()` methods. 17 | refit : bool, default=True 18 | Whether or not to fit the underlying estimator during `OutlierRemover(...).fit()`. 19 | 20 | Attributes 21 | ---------- 22 | estimator_ : object 23 | The fitted outlier detector.
24 | 25 | Examples 26 | -------- 27 | ```py 28 | import numpy as np 29 | 30 | from sklearn.ensemble import IsolationForest 31 | from sklego.preprocessing import OutlierRemover 32 | 33 | np.random.seed(0) 34 | X = np.random.randn(10000, 2) 35 | 36 | isolation_forest = IsolationForest() 37 | isolation_forest.fit(X) 38 | detector_preds = isolation_forest.predict(X) 39 | 40 | outlier_remover = OutlierRemover(isolation_forest, refit=True) 41 | outlier_remover.fit(X) 42 | 43 | X_trans = outlier_remover.transform_train(X) 44 | ``` 45 | """ 46 | 47 | _required_parameters = ["outlier_detector"] 48 | 49 | def __init__(self, outlier_detector, refit=True): 50 | self.outlier_detector = outlier_detector 51 | self.refit = refit 52 | 53 | def fit(self, X, y=None): 54 | """Fit the estimator on training data `X` and `y` by fitting the underlying outlier detector if `refit` is True. 55 | 56 | Parameters 57 | ---------- 58 | X : array-like of shape (n_samples, n_features) 59 | Training data. 60 | y : array-like of shape (n_samples,), default=None 61 | Target values. 62 | 63 | Returns 64 | ------- 65 | self : OutlierRemover 66 | The fitted transformer. 67 | """ 68 | self.estimator_ = clone(self.outlier_detector) 69 | if self.refit: 70 | super().fit(X, y) 71 | self.estimator_.fit(X, y) 72 | _check_n_features(self, X, reset=True) 73 | return self 74 | 75 | def transform_train(self, X): 76 | """Removes outliers from `X` using the fitted estimator. 77 | 78 | Parameters 79 | ---------- 80 | X : array-like of shape (n_samples, n_features) 81 | The data for which the outliers will be removed. 82 | 83 | Returns 84 | ------- 85 | np.ndarray of shape (n_not_outliers, n_features) 86 | The data with the outliers removed, where `n_not_outliers = n_samples - n_outliers`. 87 | """ 88 | check_is_fitted(self, "estimator_") 89 | _check_n_features(self, X, reset=False) 90 | 91 | predictions = self.estimator_.predict(X) 92 | check_array(predictions, estimator=self.outlier_detector, ensure_2d=False) 93 | 94 | return X[predictions != -1] 95 | -------------------------------------------------------------------------------- /sklego/preprocessing/randomadder.py: -------------------------------------------------------------------------------- 1 | from warnings import warn 2 | 3 | from sklearn.base import BaseEstimator 4 | from sklearn.utils.validation import FLOAT_DTYPES, check_is_fitted, check_random_state 5 | from sklearn_compat.utils.validation import validate_data 6 | 7 | from sklego.common import TrainOnlyTransformerMixin 8 | 9 | 10 | class RandomAdder(TrainOnlyTransformerMixin, BaseEstimator): 11 | """The `RandomAdder` transformer adds random noise to the input data. 12 | 13 | This class is designed to be used during the training phase and not for transforming test data. 14 | Noise added is sampled from a normal distribution with mean 0 and standard deviation `noise`. 15 | 16 | Parameters 17 | ---------- 18 | noise : float, default=1.0 19 | The standard deviation of the normal distribution from which the noise is sampled. 20 | random_state : int | None 21 | The seed used by the random number generator. 22 | 23 | Attributes 24 | ---------- 25 | n_features_in_ : int 26 | Number of features seen during `fit`. 27 | dim_ : int 28 | Deprecated, please use `n_features_in_` instead. 
29 | 30 | Examples 31 | -------- 32 | ```py 33 | from sklearn.pipeline import Pipeline 34 | from sklearn.linear_model import LinearRegression 35 | from sklego.preprocessing import RandomAdder 36 | 37 | # Create a pipeline with the RandomAdder and a LinearRegression model 38 | pipeline = Pipeline([ 39 | ('random_adder', RandomAdder(noise=0.5, random_state=42)), 40 | ('linear_regression', LinearRegression()) 41 | ]) 42 | 43 | # Fit the pipeline with training data 44 | pipeline.fit(X_train, y_train) 45 | 46 | # Use the fitted pipeline to make predictions 47 | y_pred = pipeline.predict(X_test) 48 | ``` 49 | """ 50 | 51 | def __init__(self, noise=1, random_state=None): 52 | self.noise = noise 53 | self.random_state = random_state 54 | 55 | def fit(self, X, y): 56 | """Fit the transformer on training data `X` and `y` by checking the input data and recording the number of 57 | input features. 58 | 59 | Parameters 60 | ---------- 61 | X : array-like of shape (n_samples, n_features) 62 | Training data. 63 | y : array-like of shape (n_samples,) 64 | Target values. 65 | 66 | Returns 67 | ------- 68 | self : RandomAdder 69 | The fitted transformer. 70 | """ 71 | super().fit(X, y) 72 | X, y = validate_data(self, X=X, y=y, dtype=FLOAT_DTYPES, reset=True) 73 | 74 | return self 75 | 76 | def transform_train(self, X): 77 | r"""Transform training data by adding random noise sampled from $N(0, \text{noise})$. 78 | 79 | Parameters 80 | ---------- 81 | X : array-like of shape (n_samples, n_features) 82 | The data to which the noise will be added. 83 | 84 | Returns 85 | ------- 86 | np.ndarray of shape (n_samples, n_features) 87 | The data with the noise added. 88 | """ 89 | rs = check_random_state(self.random_state) 90 | check_is_fitted(self, ["n_features_in_"]) 91 | X = validate_data(self, X=X, dtype=FLOAT_DTYPES, reset=False) 92 | 93 | return X + rs.normal(0, self.noise, size=X.shape) 94 | 95 | @property 96 | def dim_(self): 97 | warn( 98 | "Please use `n_features_in_` instead of `dim_`, `dim_` will be deprecated in future versions", 99 | DeprecationWarning, 100 | ) 101 | return self.n_features_in_ 102 | 103 | def _more_tags(self): 104 | return {"non_deterministic": True} 105 | 106 | def __sklearn_tags__(self): 107 | tags = super().__sklearn_tags__() 108 | tags.non_deterministic = True 109 | return tags 110 | -------------------------------------------------------------------------------- /sklego/testing.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | from copy import copy 3 | 4 | from sklearn.datasets import make_classification, make_regression 5 | from sklearn.utils._testing import ignore_warnings 6 | 7 | 8 | @ignore_warnings(category=(DeprecationWarning, FutureWarning)) 9 | def check_shape_remains_same_regressor(name, regressor): 10 | """Ensure that the estimator does not change the shape of prediction.""" 11 | seed = 42 12 | 13 | for n_samples, n_feat in it.product((100, 1000, 10000), (2, 5, 10)): 14 | regr = copy(regressor) 15 | X, y = make_regression( 16 | n_samples=n_samples, 17 | n_features=n_feat, 18 | n_informative=n_feat, 19 | noise=2, 20 | random_state=seed, 21 | ) 22 | pred = regr.fit(X, y).predict(X) 23 | assert y.shape[0] == pred.shape[0] 24 | 25 | 26 | @ignore_warnings(category=(DeprecationWarning, FutureWarning)) 27 | def check_shape_remains_same_classifier(name, classifier): 28 | """Ensure that the estimator does not change the shape of prediction.""" 29 | seed = 42 30 | 31 | for n_samples, n_feat in it.product((100, 1000,
10000), (2, 5, 10)): 32 | clf = copy(classifier) 33 | X, y = make_classification( 34 | n_samples=n_samples, 35 | n_features=n_feat, 36 | n_informative=n_feat, 37 | n_redundant=0, 38 | n_repeated=0, 39 | random_state=seed, 40 | ) 41 | pred = clf.fit(X, y).predict(X) 42 | assert y.shape[0] == pred.shape[0] 43 | -------------------------------------------------------------------------------- /sklego/this.py: -------------------------------------------------------------------------------- 1 | poem = """ 2 | Roses are red, violets are blue, 3 | naming your package is really hard to undo. 4 | Haste can make decisions in one fell swoop, 5 | note that LEGO® is a trademark of the LEGO Group. 6 | It really makes sense, we do not need to bluff, 7 | LEGO does not sponsor, authorize or endorse any of this stuff. 8 | 9 | Look at all the features and look at all the extensions, 10 | the path towards ruin is paved with good intentions. 11 | Be careful with features as they tend to go sour, 12 | defer responsibility to the end user, this might just give them power. 13 | If you don't know the requirements you don't know if they're met. 14 | If you haven't gotten to where you're going, you aren't there yet. 15 | 16 | Infinity is ever present, the unknown may be ignored, 17 | not everything needs to be built, not everything needs to be explored. 18 | Change everything and you'll soon be a jerk, 19 | you may invent a new tool, not a way to work. 20 | Some problems cannot be solved in a single day, 21 | but if you can ignore them, they sometimes go away. 22 | 23 | So as we forge ahead, let's remember the creed, 24 | simplicity over complexity, our library's seed. 25 | In the maze of features, let's not lose sight, 26 | of the end goal in mind shining bright. 27 | 28 | With each new feature, a temptation to craft, 29 | but elegance is found in what we choose to subtract. 30 | For every line of code, let's ask ourselves twice, 31 | does it add clarity or is it a vice? 32 | 33 | There's a lot of power in simplicity, 34 | it keeps the approach strong, 35 | if you understand the solution better than the problem, 36 | you're doing it wrong. 
37 | """ 38 | 39 | print(poem) # noqa: T201 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import itertools as it 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import polars as pl 6 | import pytest 7 | 8 | n_vals = (10, 500) 9 | k_vals = (1, 5) 10 | np_types = (np.int32, np.float32, np.float64) 11 | 12 | 13 | def select_tests(include, exclude=[]): 14 | """Return an iterable of include with all tests whose name is not in exclude""" 15 | for test in include: 16 | if test.__name__ not in exclude: 17 | yield test 18 | 19 | 20 | @pytest.fixture(scope="module", params=[_ for _ in it.product(n_vals, k_vals, np_types)]) 21 | def random_xy_dataset_regr(request): 22 | n, k, np_type = request.param 23 | np.random.seed(42) 24 | X = np.random.normal(0, 2, (n, k)).astype(np_type) 25 | y = np.random.normal(0, 2, (n,)) 26 | return X, y 27 | 28 | 29 | @pytest.fixture(scope="module", params=[_ for _ in it.product([10, 100], [1, 2, 3], np_types)]) 30 | def random_xy_dataset_regr_small(request): 31 | n, k, np_type = request.param 32 | np.random.seed(42) 33 | X = np.random.normal(0, 2, (n, k)).astype(np_type) 34 | y = np.random.normal(0, 2, (n,)) 35 | return X, y 36 | 37 | 38 | @pytest.fixture(scope="module", params=[_ for _ in it.product(n_vals, k_vals, np_types)]) 39 | def random_xy_dataset_clf(request): 40 | n, k, np_type = request.param 41 | np.random.seed(42) 42 | X = np.random.normal(0, 2, (n, k)).astype(np_type) 43 | y = np.random.normal(0, 2, (n,)) > 0.0 44 | return X, y 45 | 46 | 47 | @pytest.fixture(scope="module", params=[_ for _ in it.product(n_vals, k_vals, np_types)]) 48 | def random_xy_dataset_multiclf(request): 49 | n, k, np_type = request.param 50 | np.random.seed(42) 51 | X = np.random.normal(0, 2, (n, k)).astype(np_type) 52 | y = pd.cut(np.random.normal(0, 2, (n,)), 3).codes 53 | return X, y 54 | 55 | 56 | @pytest.fixture(scope="module", params=[_ for _ in it.product(n_vals, k_vals, np_types)]) 57 | def random_xy_dataset_multitarget(request): 58 | n, k, np_type = request.param 59 | np.random.seed(42) 60 | X = np.random.normal(0, 2, (n, k)).astype(np_type) 61 | y = np.random.randint(0, 2, (n, k)) > 0.0 62 | return X, y 63 | 64 | 65 | @pytest.fixture(params=[pd.DataFrame, pl.DataFrame]) 66 | def funct(request): 67 | return request.param 68 | 69 | 70 | @pytest.fixture 71 | def sensitive_classification_dataset(funct): 72 | df = funct( 73 | { 74 | "x1": [1, 0, 1, 0, 1, 0, 1, 1], 75 | "x2": [0, 0, 0, 0, 0, 1, 1, 1], 76 | "y": [1, 1, 1, 0, 1, 0, 0, 0], 77 | } 78 | ) 79 | 80 | return df[["x1", "x2"]], df["y"] 81 | 82 | 83 | @pytest.fixture 84 | def sensitive_multiclass_classification_dataset(): 85 | df = pd.DataFrame( 86 | { 87 | "x1": [1, 0, 1, 0, 1, 0, 1, 1, -2, -2, -2, -2], 88 | "x2": [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1], 89 | "y": [1, 1, 1, 0, 1, 0, 0, 0, 2, 2, 0, 0], 90 | } 91 | ) 92 | return df[["x1", "x2"]], df["y"] 93 | 94 | 95 | def id_func(param): 96 | """Returns the repr of an object for usage in pytest parametrize""" 97 | return repr(param) 98 | -------------------------------------------------------------------------------- 
/tests/scripts/check_pip.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import sys 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("verb", help="installed/missing") 7 | parser.add_argument("packages", help="list of items to be there/not be there", nargs="+") 8 | 9 | if __name__ == "__main__": 10 | args = parser.parse_args() 11 | installed = subprocess.check_output([sys.executable, "-m", "pip", "freeze"]).decode("utf-8") 12 | for pkg in args.packages: 13 | if args.verb == "missing": 14 | if pkg in installed: 15 | raise ValueError(f"Expected {pkg} to not be installed.") 16 | if args.verb == "installed": 17 | if pkg not in installed: 18 | raise ValueError(f"Expected {pkg} to be installed.") 19 | -------------------------------------------------------------------------------- /tests/scripts/import_all.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from sklego import * 3 | from sklego.decomposition import * 4 | from sklego.meta import * 5 | from sklego.mixture import * 6 | from sklego.preprocessing import * 7 | from sklego.base import * 8 | from sklego.common import * 9 | from sklego.datasets import * 10 | from sklego.dummy import * 11 | from sklego.linear_model import * 12 | from sklego.metrics import * 13 | from sklego.model_selection import * 14 | from sklego.naive_bayes import * 15 | from sklego.neighbors import * 16 | from sklego.notinstalled import * 17 | from sklego.pandas_utils import * 18 | from sklego.pipeline import * 19 | from sklego.testing import * 20 | from sklego.this import * 21 | -------------------------------------------------------------------------------- /tests/test_common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_common/__init__.py -------------------------------------------------------------------------------- /tests/test_common/test_basics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sklego.common import as_list, sliding_window 4 | 5 | 6 | def test_as_list_strings(): 7 | assert as_list("test") == ["test"] 8 | assert as_list(["test1", "test2"]) == ["test1", "test2"] 9 | 10 | 11 | def test_as_list_ints(): 12 | assert as_list(123) == [123] 13 | assert as_list([1, 2, 3]) == [1, 2, 3] 14 | 15 | 16 | def test_as_list_other(): 17 | def f(): 18 | return 123 19 | 20 | assert as_list(f) == [f] 21 | assert as_list(range(1, 4)) == [1, 2, 3] 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "sequence, window_size, step_size", 26 | [([1, 2, 3, 4, 5], 2, 1), (["a", "b", "c", "d", "e"], 3, 2)], 27 | ) 28 | def test_sliding_window(sequence, window_size, step_size): 29 | windows = list(sliding_window(sequence, window_size, step_size)) 30 | assert windows[0] == sequence[:window_size] 31 | assert len(windows[0]) == window_size 32 | assert windows[1][0] == sequence[step_size] 33 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from sklego.datasets import load_abalone, load_chicken, load_hearts, load_penguins, make_simpleseries 4 | 5 | 6 | def test_chickweight1(): 7 | X, y = load_chicken(return_X_y=True) 8 | assert X.shape == (578, 3) 9 | assert 
y.shape[0] == 578 10 | 11 | 12 | def test_chickweight2(): 13 | df = load_chicken(as_frame=True) 14 | assert df.shape == (578, 4) 15 | 16 | 17 | def test_abalone1(): 18 | X, y = load_abalone(return_X_y=True) 19 | assert X.shape == (4177, 8) 20 | assert y.shape[0] == 4177 21 | 22 | 23 | def test_abalone2(): 24 | df = load_abalone(as_frame=True) 25 | assert df.shape == (4177, 9) 26 | 27 | 28 | def test_simpleseries_constant_season(): 29 | df = make_simpleseries( 30 | n_samples=365 * 2, 31 | as_frame=True, 32 | start_date="2018-01-01", 33 | trend=0, 34 | noise=0, 35 | season_trend=0, 36 | ).assign(month=lambda d: d["date"].dt.month, year=lambda d: d["date"].dt.year) 37 | agg = df.groupby(["year", "month"]).mean().reset_index() 38 | var = agg.loc[lambda d: d["month"] == 1]["yt"].var() 39 | assert var == pytest.approx(0.0, abs=0.01) 40 | 41 | 42 | def test_load_hearts(): 43 | df = load_hearts(as_frame=True) 44 | assert df.shape == (303, 14) 45 | 46 | 47 | def test_penguin1(): 48 | X, y = load_penguins(return_X_y=True) 49 | assert X.shape == (344, 6) 50 | assert y.shape[0] == 344 51 | 52 | 53 | def test_penguin2(): 54 | df = load_penguins(as_frame=True) 55 | assert df.shape == (344, 7) 56 | -------------------------------------------------------------------------------- /tests/test_estimators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_estimators/__init__.py -------------------------------------------------------------------------------- /tests/test_estimators/test_deadzone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.linear_model import DeadZoneRegressor 6 | from sklego.testing import check_shape_remains_same_regressor 7 | 8 | 9 | @parametrize_with_checks([DeadZoneRegressor()]) 10 | def test_sklearn_compatible_estimator(estimator, check): 11 | check(estimator) 12 | 13 | 14 | @pytest.fixture 15 | def dataset(): 16 | np.random.seed(42) 17 | n = 100 18 | inputs = np.concatenate([np.ones((n, 1)), np.random.normal(0, 1, (n, 1))], axis=1) 19 | targets = 3.1 * inputs[:, 0] + 2.0 * inputs[:, 1] 20 | return inputs, targets 21 | 22 | 23 | @pytest.fixture(scope="module", params=["constant", "linear", "quadratic"]) 24 | def mod(request): 25 | return DeadZoneRegressor(effect=request.param, threshold=0.3) 26 | 27 | 28 | @pytest.mark.parametrize("test_fn", [check_shape_remains_same_regressor]) 29 | def test_deadzone(test_fn): 30 | regr = DeadZoneRegressor() 31 | test_fn(DeadZoneRegressor.__name__, regr) 32 | 33 | 34 | def test_values_uniform(dataset, mod): 35 | if mod.effect == "constant": 36 | pytest.skip("Constant effect") 37 | X, y = dataset 38 | coefs = mod.fit(X, y).coef_ 39 | assert coefs[0] == pytest.approx(3.1, abs=0.2) 40 | assert coefs[1] == pytest.approx(2.0, abs=0.2) 41 | -------------------------------------------------------------------------------- /tests/test_estimators/test_gmm_naive_bayes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.naive_bayes import BayesianGaussianMixtureNB, GaussianMixtureNB 6 | 7 | 8 | @parametrize_with_checks( 9 | [ 10 | GaussianMixtureNB(), 11 | 
GaussianMixtureNB(n_components=2), 12 | BayesianGaussianMixtureNB(), 13 | BayesianGaussianMixtureNB(n_components=2), 14 | ], 15 | ) 16 | def test_sklearn_compatible_estimator(estimator, check): 17 | check(estimator) 18 | 19 | 20 | @pytest.fixture 21 | def dataset(): 22 | np.random.seed(42) 23 | return np.concatenate([np.random.normal(0, 1, (2000, 2))]) 24 | 25 | 26 | @pytest.mark.parametrize("k", [1, 5, 10]) 27 | def test_obvious_usecase(k): 28 | X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) 29 | y = np.concatenate([np.zeros(100), np.ones(100)]) 30 | assert (GaussianMixtureNB(n_components=k, max_iter=1000).fit(X, y).predict(X) == y).all() 31 | assert (BayesianGaussianMixtureNB(n_components=k, max_iter=1000).fit(X, y).predict(X) == y).all() 32 | -------------------------------------------------------------------------------- /tests/test_estimators/test_lowess.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.utils.estimator_checks import parametrize_with_checks 6 | 7 | from sklego.linear_model import LowessRegression 8 | 9 | 10 | @parametrize_with_checks([LowessRegression()]) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | check(estimator) 13 | 14 | 15 | def test_obvious_usecase(): 16 | x = np.linspace(0, 10, 100) 17 | X = x.reshape(-1, 1) 18 | y = np.ones(x.shape) 19 | y_pred = LowessRegression().fit(X, y).predict(X) 20 | assert np.isclose(y, y_pred).all() 21 | 22 | 23 | def test_custom_error_for_zero_division(): 24 | x = np.arange(0, 100) 25 | X = x.reshape(-1, 1) 26 | y = np.ones(x.shape) 27 | estimator = LowessRegression(sigma=1e-10).fit(X, y) 28 | 29 | with pytest.raises( 30 | ValueError, match=re.escape("Weights, resulting from `np.exp(-(distances**2) / self.sigma)`, are all zero.") 31 | ): 32 | # Adding an offset, otherwise X to predict would be the same as X in fit method, 33 | # leading to weight of 1 for the corresponding value. 
34 | X_pred = X[:10] + 0.5 35 | estimator.predict(X_pred) 36 | -------------------------------------------------------------------------------- /tests/test_estimators/test_mixture_classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.mixture import BayesianGMMClassifier, GMMClassifier 6 | 7 | 8 | @parametrize_with_checks([GMMClassifier(), BayesianGMMClassifier()]) 9 | def test_sklearn_compatible_estimator(estimator, check): 10 | check(estimator) 11 | 12 | 13 | @pytest.mark.parametrize("clf", [GMMClassifier(max_iter=1000), BayesianGMMClassifier(max_iter=1000)]) 14 | def test_obvious_usecase(clf): 15 | X = np.concatenate([np.random.normal(-10, 1, (100, 2)), np.random.normal(10, 1, (100, 2))]) 16 | y = np.concatenate([np.zeros(100), np.ones(100)]) 17 | assert (clf.fit(X, y).predict(X) == y).all() 18 | -------------------------------------------------------------------------------- /tests/test_estimators/test_mixture_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | 6 | from sklego.mixture import BayesianGMMOutlierDetector, GMMOutlierDetector 7 | 8 | 9 | @parametrize_with_checks( 10 | [ 11 | GMMOutlierDetector(threshold=0.999, method="quantile"), 12 | GMMOutlierDetector(threshold=2, method="stddev"), 13 | BayesianGMMOutlierDetector(threshold=0.999, method="quantile"), 14 | BayesianGMMOutlierDetector(threshold=2, method="stddev"), 15 | ] 16 | ) 17 | def test_sklearn_compatible_estimator(estimator, check): 18 | check(estimator) 19 | 20 | 21 | @pytest.fixture 22 | def dataset(): 23 | np.random.seed(42) 24 | return np.concatenate([np.random.normal(0, 1, (2000, 2))]) 25 | 26 | 27 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 28 | def test_obvious_usecase_quantile(dataset, model): 29 | mod = model(n_components=2, threshold=0.999, method="quantile").fit(dataset) 30 | assert mod.predict([[10, 10]]) == np.array([-1]) 31 | assert mod.predict([[0, 0]]) == np.array([1]) 32 | 33 | 34 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 35 | def test_obvious_usecase_stddev(dataset, model): 36 | mod = model(n_components=2, threshold=2, method="stddev").fit(dataset) 37 | assert mod.predict([[10, 10]]) == np.array([-1]) 38 | assert mod.predict([[0, 0]]) == np.array([1]) 39 | 40 | 41 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 42 | @pytest.mark.parametrize( 43 | "kwargs", 44 | [ 45 | {"threshold": 10}, 46 | {"threshold": -10}, 47 | {"threshold": -10, "method": "stddev"}, 48 | ], 49 | ) 50 | def test_value_error_threshold(dataset, model, kwargs): 51 | with pytest.raises(ValueError): 52 | model(**kwargs).fit(dataset) 53 | 54 | 55 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 56 | def test_thresh_effect_stddev(dataset, model): 57 | mod1 = model(threshold=0.5, method="stddev").fit(dataset) 58 | mod2 = model(threshold=1, method="stddev").fit(dataset) 59 | mod3 = model(threshold=2, method="stddev").fit(dataset) 60 | n_outliers1 = np.sum(mod1.predict(dataset) == -1) 61 | n_outliers2 = np.sum(mod2.predict(dataset) == -1) 62 | n_outliers3 = np.sum(mod3.predict(dataset) == -1) 63 | assert n_outliers1 > n_outliers2 64 
| assert n_outliers2 > n_outliers3 65 | 66 | 67 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 68 | def test_thresh_effect_quantile(dataset, model): 69 | mod1 = model(threshold=0.90, method="quantile").fit(dataset) 70 | mod2 = model(threshold=0.95, method="quantile").fit(dataset) 71 | mod3 = model(threshold=0.99, method="quantile").fit(dataset) 72 | n_outliers1 = np.sum(mod1.predict(dataset) == -1) 73 | n_outliers2 = np.sum(mod2.predict(dataset) == -1) 74 | n_outliers3 = np.sum(mod3.predict(dataset) == -1) 75 | assert n_outliers1 > n_outliers2 76 | assert n_outliers2 > n_outliers3 77 | 78 | 79 | @pytest.mark.parametrize("model", [GMMOutlierDetector, BayesianGMMOutlierDetector]) 80 | def test_obvious_usecase_github(model): 81 | # from this bug: https://github.com/koaning/scikit-lego/issues/225 thanks Corrie! 82 | np.random.seed(42) 83 | X = np.random.normal(-10, 1, (2000, 2)) 84 | mod = model(n_components=1, threshold=0.99).fit(X) 85 | 86 | df = pd.DataFrame( 87 | { 88 | "x1": X[:, 0], 89 | "x2": X[:, 1], 90 | "loglik": mod.score_samples(X), 91 | "prediction": mod.predict(X).astype(str), 92 | } 93 | ) 94 | assert df.loc[lambda d: d["prediction"] == "-1"].shape[0] == 20 95 | -------------------------------------------------------------------------------- /tests/test_estimators/test_neighbor_classifier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # signature(check.func.__dict__["__wrapped__"]).parameters["kind"] 4 | import numpy as np 5 | import pytest 6 | from sklearn.utils.estimator_checks import parametrize_with_checks 7 | 8 | from sklego.neighbors import BayesianKernelDensityClassifier 9 | 10 | 11 | @parametrize_with_checks([BayesianKernelDensityClassifier()]) 12 | def test_sklearn_compatible_estimator(estimator, check): 13 | if ( 14 | sys.version_info < (3, 9) 15 | and check.func.__name__ == "check_classifiers_train" 16 | and getattr(check, "keywords", {}).get("readonly_memmap") is True 17 | ): 18 | pytest.skip() 19 | 20 | check(estimator) 21 | 22 | 23 | @pytest.fixture() 24 | def simple_dataset(): 25 | # Two linearly separable mvn should have a 100% prediction accuracy 26 | x = np.concatenate([np.random.normal(-1000, 0.01, (100, 2)), np.random.normal(1000, 0.01, (100, 2))]) 27 | y = np.concatenate([np.zeros(100), np.ones(100)]) 28 | return x, y 29 | 30 | 31 | def test_trivial_classification(simple_dataset): 32 | x, y = simple_dataset 33 | model = BayesianKernelDensityClassifier().fit(x, y) 34 | assert (model.predict(x) == y).all() 35 | -------------------------------------------------------------------------------- /tests/test_estimators/test_pca_reconstruction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.decomposition import PCAOutlierDetection 6 | 7 | 8 | # Remark that as some tests only have 2 features, we need to pass less components, otherwise no outlier is detected 9 | @parametrize_with_checks([PCAOutlierDetection(n_components=1, threshold=0.05, random_state=42, variant="relative")]) 10 | def test_sklearn_compatible_estimator(estimator, check): 11 | check(estimator) 12 | 13 | 14 | @pytest.fixture 15 | def dataset(): 16 | np.random.seed(42) 17 | return np.concatenate([np.random.normal(0, 1, (2000, 10))]) 18 | 19 | 20 | def test_obvious_usecase(dataset): 21 | mod = PCAOutlierDetection(n_components=2, 
threshold=2.5, random_state=42, variant="absolute").fit(dataset) 22 | assert mod.predict([[10] * 10]) == np.array([-1]) 23 | assert mod.predict([[0.01] * 10]) == np.array([1]) 24 | -------------------------------------------------------------------------------- /tests/test_estimators/test_probweight_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.linear_model import ProbWeightRegression 6 | 7 | pytestmark = pytest.mark.cvxpy 8 | 9 | 10 | @parametrize_with_checks([ProbWeightRegression(non_negative=True), ProbWeightRegression(non_negative=False)]) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | check(estimator) 13 | 14 | 15 | def test_shape_trained_model(random_xy_dataset_regr): 16 | X, y = random_xy_dataset_regr 17 | mod_no_intercept = ProbWeightRegression() 18 | assert mod_no_intercept.fit(X, y).coef_.shape == (X.shape[1],) 19 | np.testing.assert_approx_equal(mod_no_intercept.fit(X, y).coef_.sum(), 1.0, significant=4) 20 | -------------------------------------------------------------------------------- /tests/test_estimators/test_quantile_regression.py: -------------------------------------------------------------------------------- 1 | """Test the QuantileRegression.""" 2 | 3 | from itertools import product 4 | 5 | import numpy as np 6 | import pytest 7 | from sklearn.utils.estimator_checks import parametrize_with_checks 8 | 9 | from sklego.linear_model import QuantileRegression 10 | from sklego.testing import check_shape_remains_same_regressor 11 | 12 | test_batch = [ 13 | (np.array([0, 0, 3, 0, 6]), 3), 14 | (np.array([1, 0, -2, 0, 4, 0, -5, 0, 6]), 2), 15 | (np.array([4, -4]), 0), 16 | ] 17 | 18 | 19 | def _create_dataset(coefs, intercept, noise=0.0): 20 | np.random.seed(0) 21 | size = 1_000 22 | X = np.random.randn(size, coefs.shape[0]) 23 | y = X @ coefs + intercept + noise * np.random.randn(size) 24 | 25 | return X, y 26 | 27 | 28 | @parametrize_with_checks( 29 | [ 30 | QuantileRegression(**dict(zip(["quantile", "positive", "fit_intercept", "method"], args))) 31 | for args in product([0.5, 0.3], [True, False], [True, False], ["SLSQP", "TNC", "L-BFGS-B"]) 32 | ] 33 | ) 34 | def test_sklearn_compatible_estimator(estimator, check): 35 | if check.func.__name__ in { 36 | "check_sample_weights_invariance", 37 | "check_sample_weight_equivalence_on_dense_data", 38 | "check_sample_weights_invariance", 39 | }: 40 | pytest.skip() 41 | check(estimator) 42 | 43 | 44 | @pytest.mark.parametrize("method", ["SLSQP", "TNC", "L-BFGS-B"]) 45 | @pytest.mark.parametrize("coefs, intercept", test_batch) 46 | @pytest.mark.parametrize("noise, expected", [(0.0, 0.99), (1.0, 0.9)]) 47 | def test_coefs_and_intercept(method, coefs, intercept, noise, expected): 48 | """Regression problems with different level of noise.""" 49 | X, y = _create_dataset(coefs, intercept, noise=noise) 50 | quant = QuantileRegression(method=method) 51 | quant.fit(X, y) 52 | assert quant.score(X, y) > expected 53 | 54 | 55 | @pytest.mark.parametrize("method", ["SLSQP", "TNC", "L-BFGS-B"]) 56 | @pytest.mark.parametrize("coefs, intercept", test_batch) 57 | @pytest.mark.parametrize("quantile", np.linspace(0, 1, 7)) 58 | def test_quantile(coefs, intercept, quantile, method): 59 | """Tests with noise on an easy problem. 
A good score should be possible.""" 60 | X, y = _create_dataset(coefs, intercept, noise=1.0) 61 | quant = QuantileRegression(method=method, quantile=quantile) 62 | quant.fit(X, y) 63 | 64 | np.testing.assert_almost_equal((quant.predict(X) >= y).mean(), quantile, decimal=2) 65 | 66 | 67 | @pytest.mark.parametrize("method", ["SLSQP", "TNC", "L-BFGS-B"]) 68 | @pytest.mark.parametrize("coefs, intercept", test_batch) 69 | def test_coefs_and_intercept__no_noise_positive(coefs, intercept, method): 70 | """Test with only positive coefficients.""" 71 | X, y = _create_dataset(coefs, intercept, noise=0.0) 72 | quant = QuantileRegression(method=method, positive=True) 73 | quant.fit(X, y) 74 | assert all(quant.coef_ >= 0) 75 | assert quant.score(X, y) > 0.3 76 | 77 | 78 | @pytest.mark.parametrize("coefs, intercept", test_batch) 79 | def test_coefs_and_intercept__no_noise_regularization(coefs, intercept): 80 | """Test model with regularization. The size of the coef vector should shrink the larger alpha gets.""" 81 | X, y = _create_dataset(coefs, intercept) 82 | 83 | quants = [QuantileRegression(alpha=alpha, l1_ratio=0.0).fit(X, y) for alpha in range(3)] 84 | coef_size = np.array([np.sum(quant.coef_**2) for quant in quants]) 85 | 86 | for i in range(2): 87 | assert coef_size[i] >= coef_size[i + 1] 88 | 89 | 90 | @pytest.mark.parametrize("coefs, intercept", test_batch) 91 | def test_fit_intercept_and_copy(coefs, intercept): 92 | """Test if fit_intercept and copy_X work.""" 93 | X, y = _create_dataset(coefs, intercept, noise=2.0) 94 | quant = QuantileRegression(fit_intercept=False, copy_X=False) 95 | quant.fit(X, y) 96 | 97 | assert quant.intercept_ == 0.0 98 | 99 | 100 | @pytest.mark.parametrize("test_fn", [check_shape_remains_same_regressor]) 101 | def test_quant(test_fn): 102 | regr = QuantileRegression() 103 | test_fn(QuantileRegression.__name__, regr) 104 | -------------------------------------------------------------------------------- /tests/test_estimators/test_randomregressor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.dummy import RandomRegressor 6 | 7 | 8 | @parametrize_with_checks( 9 | [RandomRegressor(strategy="normal", random_state=42), RandomRegressor(strategy="uniform", random_state=42)] 10 | ) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | if check.func.__name__ in {"check_methods_subset_invariance", "check_methods_sample_order_invariance"}: 13 | pytest.skip("RandomRegressor is not invariant") 14 | check(estimator) 15 | 16 | 17 | def test_values_uniform(random_xy_dataset_regr): 18 | X, y = random_xy_dataset_regr 19 | mod = RandomRegressor(strategy="uniform") 20 | predictions = mod.fit(X, y).predict(X) 21 | assert (predictions >= y.min()).all() 22 | assert (predictions <= y.max()).all() 23 | assert mod.min_ == pytest.approx(y.min(), abs=0.0001) 24 | assert mod.max_ == pytest.approx(y.max(), abs=0.0001) 25 | 26 | 27 | def test_values_normal(random_xy_dataset_regr): 28 | X, y = random_xy_dataset_regr 29 | mod = RandomRegressor(strategy="normal").fit(X, y) 30 | assert mod.mu_ == pytest.approx(np.mean(y), abs=0.001) 31 | assert mod.sigma_ == pytest.approx(np.std(y), abs=0.001) 32 | 33 | 34 | def test_bad_values(): 35 | np.random.seed(42) 36 | X = np.random.normal(0, 1, (10, 2)) 37 | y = np.random.normal(0, 1, (10, 1)) 38 | with pytest.raises(ValueError): 39 | RandomRegressor(strategy="foobar").fit(X, y) 40 |
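# [editor's note] Illustrative sketch, not part of the original suite: the skip above
# ("RandomRegressor is not invariant") implies that, without a fixed `random_state`,
# repeated `predict` calls draw fresh random values. A minimal check under that
# assumption, using the same fixture as the tests above:
def test_predictions_vary_between_calls(random_xy_dataset_regr):
    X, y = random_xy_dataset_regr
    mod = RandomRegressor(strategy="normal").fit(X, y)
    # two independent draws should (almost surely) differ somewhere
    assert not np.array_equal(mod.predict(X), mod.predict(X))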
-------------------------------------------------------------------------------- /tests/test_estimators/test_umap_reconstruction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.decomposition import UMAPOutlierDetection 6 | 7 | pytestmark = pytest.mark.umap 8 | 9 | 10 | @parametrize_with_checks([UMAPOutlierDetection(n_components=2, threshold=0.1, n_neighbors=3)]) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | if check.func.__name__ in { 13 | # numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend) 14 | "check_estimators_pickle", 15 | # ValueError: Need at least 2-D data 16 | # in `predict`: np.sum(np.abs(self.umap_.inverse_transform(reduced) - X), axis=1) 17 | "check_dict_unchanged", 18 | }: 19 | pytest.skip() 20 | 21 | check(estimator) 22 | 23 | 24 | @pytest.fixture 25 | def dataset(): 26 | np.random.seed(42) 27 | return np.concatenate([np.random.normal(0, 1, (200, 10))]) 28 | 29 | 30 | def test_obvious_usecase(dataset): 31 | mod = UMAPOutlierDetection( 32 | n_components=2, 33 | threshold=7.5, 34 | random_state=42, 35 | variant="absolute", 36 | n_neighbors=3, 37 | ).fit(dataset) 38 | assert mod.predict([[10] * 10]) == np.array([-1]) 39 | assert mod.predict([[0.01] * 10]) == np.array([1]) 40 | -------------------------------------------------------------------------------- /tests/test_feature_selection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_feature_selection/__init__.py -------------------------------------------------------------------------------- /tests/test_meta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_meta/__init__.py -------------------------------------------------------------------------------- /tests/test_meta/test_confusion_balancer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LogisticRegression 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.meta import ConfusionBalancer 6 | 7 | 8 | @parametrize_with_checks( 9 | [ 10 | ConfusionBalancer(estimator=LogisticRegression(solver="lbfgs"), alpha=alpha, cfm_smooth=cfm_smooth) 11 | for alpha in (0.1, 0.5, 0.9) 12 | for cfm_smooth in (0, 1, 2) 13 | ] 14 | ) 15 | def test_sklearn_compatible_estimator(estimator, check): 16 | check(estimator) 17 | 18 | 19 | def test_sum_equals_one(): 20 | np.random.seed(42) 21 | n1, n2 = 100, 500 22 | X = np.concatenate([np.random.normal(0, 1, (n1, 2)), np.random.normal(2, 1, (n2, 2))], axis=0) 23 | y = np.concatenate([np.zeros((n1, 1)), np.ones((n2, 1))], axis=0).reshape(-1) 24 | mod = ConfusionBalancer( 25 | LogisticRegression(solver="lbfgs", multi_class="multinomial", max_iter=1000), 26 | alpha=0.1, 27 | ) 28 | mod.fit(X, y) 29 | assert np.all(np.isclose(mod.predict_proba(X).sum(axis=1), 1, atol=0.001)) 30 | -------------------------------------------------------------------------------- /tests/test_meta/test_decay_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 
| import pytest 3 | from sklearn.base import is_classifier, is_regressor 4 | from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge 5 | from sklearn.neighbors import KNeighborsClassifier 6 | from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor 7 | from sklearn.utils.estimator_checks import parametrize_with_checks 8 | 9 | from sklego.meta import DecayEstimator 10 | 11 | 12 | @parametrize_with_checks( 13 | [ 14 | DecayEstimator(LinearRegression(), check_input=True, decay_func=decay_func) 15 | for decay_func in ("linear", "exponential", "sigmoid") 16 | ] 17 | ) 18 | def test_sklearn_compatible_estimator(estimator, check): 19 | if check.func.__name__ in { 20 | "check_no_attributes_set_in_init", # Setting **kwargs in init 21 | "check_regressor_multioutput", # incompatible between pre and post 1.6 22 | }: 23 | pytest.skip() 24 | 25 | check(estimator) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "mod, is_clf", 30 | [ 31 | (LinearRegression(), False), 32 | (Ridge(), False), 33 | (DecisionTreeRegressor(), False), 34 | (DecisionTreeClassifier(), True), 35 | (LogisticRegression(solver="lbfgs"), True), 36 | ], 37 | ) 38 | @pytest.mark.parametrize( 39 | "decay_func, decay_kwargs", 40 | [ 41 | ("exponential", {"decay_rate": 0.999}), 42 | ("exponential", {"decay_rate": 0.99}), 43 | ("linear", {"min_value": 0.0, "max_value": 1.0}), 44 | ("linear", {"min_value": 0.5, "max_value": 1.0}), 45 | ("sigmoid", {"growth_rate": 0.1}), 46 | ("sigmoid", {"growth_rate": None}), 47 | ("stepwise", {"n_steps": 10}), 48 | ("stepwise", {"step_size": 2}), 49 | ], 50 | ) 51 | def test_decay_weight(mod, is_clf, decay_func, decay_kwargs): 52 | X, y = np.random.normal(0, 1, (100, 100)), np.random.normal(0, 1, (100,)) 53 | 54 | if is_clf: 55 | y = (y < 0).astype(int) 56 | 57 | mod = DecayEstimator(mod, decay_func=decay_func, decay_kwargs=decay_kwargs).fit(X, y) 58 | 59 | assert np.logical_and(mod.weights_ >= 0, mod.weights_ <= 1).all() 60 | assert np.all(mod.weights_[:-1] <= mod.weights_[1:]) 61 | 62 | 63 | @pytest.mark.parametrize("mod", [KNeighborsClassifier()]) 64 | def test_throws_type_error(mod): 65 | X, y = np.random.normal(0, 1, (100, 100)), np.random.normal(0, 1, (100,)) < 0 66 | with pytest.raises(TypeError) as e: 67 | DecayEstimator(mod, decay_rate=0.95).fit(X, y) 68 | assert "sample_weight" in str(e) 69 | assert type(mod).__name__ in str(e) 70 | 71 | 72 | @pytest.mark.parametrize( 73 | "mod, is_regr", 74 | [ 75 | (LinearRegression(), True), 76 | (Ridge(), True), 77 | (DecisionTreeRegressor(), True), 78 | (LogisticRegression(), False), 79 | (DecisionTreeClassifier(), False), 80 | ], 81 | ) 82 | def test_estimator_type_regressor(mod, is_regr): 83 | mod = DecayEstimator(mod) 84 | assert mod._estimator_type == mod.model._estimator_type 85 | assert is_regressor(mod) == is_regr 86 | assert is_classifier(mod) == (not is_regr) 87 | -------------------------------------------------------------------------------- /tests/test_meta/test_decay_utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from sklego.meta._decay_utils import exponential_decay, linear_decay, sigmoid_decay, stepwise_decay 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "kwargs, context", 11 | [ 12 | ({"min_value": 0.1, "max_value": 0.9}, does_not_raise()), 13 | ({"min_value": 0.1, "max_value": 10}, does_not_raise()), 14 | ({"min_value": 0.5, "max_value": 0.1},
pytest.raises(ValueError)), 15 | ({"min_value": "abc", "max_value": 0.1}, pytest.raises(TypeError)), 16 | ], 17 | ) 18 | def test_linear_decay(kwargs, context): 19 | X, y = np.random.randn(100, 10), np.random.randn(100) 20 | 21 | with context: 22 | weights = linear_decay(X, y, **kwargs) 23 | assert np.all(weights[:-1] <= weights[1:]) 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "kwargs, context", 28 | [ 29 | ({"decay_rate": 0.9}, does_not_raise()), 30 | ({"decay_rate": 0.1}, does_not_raise()), 31 | ({"decay_rate": -1.0}, pytest.raises(ValueError)), 32 | ({"decay_rate": 2.0}, pytest.raises(ValueError)), 33 | ({"decay_rate": "abc"}, pytest.raises(TypeError)), 34 | ], 35 | ) 36 | def test_exponential_decay(kwargs, context): 37 | X, y = np.random.randn(100, 10), np.random.randn(100) 38 | 39 | with context: 40 | weights = exponential_decay(X, y, **kwargs) 41 | assert np.all(weights[:-1] <= weights[1:]) 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "kwargs, context", 46 | [ 47 | ({"min_value": 0.1, "max_value": 0.9, "n_steps": 10}, does_not_raise()), 48 | ({"n_steps": 10}, does_not_raise()), 49 | ({"step_size": 5}, does_not_raise()), 50 | ({"min_value": 0.5, "max_value": 0.1, "n_steps": 10}, pytest.raises(ValueError)), 51 | ({"min_value": "abc", "max_value": 0.1, "n_steps": 10}, pytest.raises(TypeError)), 52 | ({"n_steps": None, "step_size": None}, pytest.raises(ValueError)), 53 | ({"n_steps": 10, "step_size": 10}, pytest.raises(ValueError)), 54 | ({"n_steps": 200}, pytest.raises(ValueError)), 55 | ({"step_size": 200}, pytest.raises(ValueError)), 56 | ({"n_steps": -2}, pytest.raises(ValueError)), 57 | ({"step_size": -2}, pytest.raises(ValueError)), 58 | ({"n_steps": 2.5}, pytest.raises(TypeError)), 59 | ({"step_size": 2.5}, pytest.raises(TypeError)), 60 | ], 61 | ) 62 | def test_stepwise_decay(kwargs, context): 63 | X, y = np.random.randn(100, 10), np.random.randn(100) 64 | 65 | with context: 66 | weights = stepwise_decay(X, y, **kwargs) 67 | assert np.all(weights[:-1] <= weights[1:]) 68 | 69 | 70 | @pytest.mark.parametrize( 71 | "kwargs, context", 72 | [ 73 | ({"min_value": 0.1, "max_value": 0.9}, does_not_raise()), 74 | ({"min_value": 0.5, "max_value": 0.1}, pytest.raises(ValueError)), 75 | ({"growth_rate": 0.1}, does_not_raise()), 76 | ({"growth_rate": -0.1}, pytest.raises(ValueError)), 77 | ({"growth_rate": 1.1}, pytest.raises(ValueError)), 78 | ({"abc": 1.1}, pytest.raises(TypeError)), 79 | ], 80 | ) 81 | def test_sigmoid_decay(kwargs, context): 82 | X, y = np.random.randn(100, 10), np.random.randn(100) 83 | 84 | with context: 85 | weights = sigmoid_decay(X, y, **kwargs) 86 | assert np.all(weights[:-1] <= weights[1:]) 87 | -------------------------------------------------------------------------------- /tests/test_meta/test_ordinal_classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.linear_model import LinearRegression, LogisticRegression, RidgeClassifier 4 | 5 | from sklego.meta import OrdinalClassifier 6 | 7 | 8 | @pytest.fixture 9 | def random_xy_ordinal(): 10 | np.random.seed(42) 11 | X = np.random.normal(0, 2, (1000, 3)) 12 | y = np.select(condlist=[X[:, 0] < 2, X[:, 1] > 2], choicelist=[0, 2], default=1) 13 | return X, y 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "estimator, context, err_msg", 18 | [ 19 | (LinearRegression(), pytest.raises(ValueError), "The estimator must be a classifier."), 20 | (RidgeClassifier(), pytest.raises(ValueError), "The estimator must implement 
`.predict_proba()` method."), 21 | ], 22 | ) 23 | def test_raises_error(random_xy_ordinal, estimator, context, err_msg): 24 | X, y = random_xy_ordinal 25 | with context as exc_info: 26 | ord_clf = OrdinalClassifier(estimator=estimator) 27 | ord_clf.fit(X, y) 28 | 29 | if exc_info: 30 | assert err_msg in str(exc_info.value) 31 | 32 | 33 | @pytest.mark.parametrize("n_jobs", [-2, -1, 2, None]) 34 | @pytest.mark.parametrize("use_calibration", [True, False]) 35 | def test_can_fit_param_combination(random_xy_ordinal, n_jobs, use_calibration): 36 | X, y = random_xy_ordinal 37 | ord_clf = OrdinalClassifier(estimator=LogisticRegression(), n_jobs=n_jobs, use_calibration=use_calibration) 38 | ord_clf.fit(X, y) 39 | 40 | assert ord_clf.n_jobs == n_jobs 41 | assert ord_clf.use_calibration == use_calibration 42 | assert ord_clf.n_classes_ == 3 43 | assert ord_clf.n_features_in_ == X.shape[1] 44 | -------------------------------------------------------------------------------- /tests/test_meta/test_outlier_classifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.ensemble import IsolationForest 4 | from sklearn.linear_model import LinearRegression 5 | from sklearn.neighbors import LocalOutlierFactor 6 | from sklearn.svm import OneClassSVM 7 | from sklearn.utils.estimator_checks import parametrize_with_checks 8 | 9 | from sklego.meta import OutlierClassifier 10 | from sklego.mixture import GMMOutlierDetector 11 | 12 | 13 | @parametrize_with_checks([OutlierClassifier(GMMOutlierDetector(threshold=0.1, method="quantile"))]) 14 | def test_sklearn_compatible_estimator(estimator, check): 15 | if check.func.__name__ in { 16 | # Since `OutlierClassifier` is a classifier (`ClassifierMixin`), parametrize_with_checks feeds a classification 17 | # dataset. However this is not how this classifier is supposed to be used. 
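#   (editor's note: as the "obvious use case" test below illustrates, `OutlierClassifier`
#   maps an unsupervised detector's inlier/outlier output onto the 0/1 labels in `y`,
#   which the generic classifier checks do not model.)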
18 | "check_classifiers_train", 19 | "check_classifiers_classes", 20 | # Similarly, the original dataset could also be a regression task depending on the outlier detection algo 21 | "check_classifiers_regression_target", 22 | }: 23 | pytest.skip() 24 | 25 | check(estimator) 26 | 27 | 28 | @pytest.fixture 29 | def dataset(): 30 | np.random.seed(42) 31 | return np.random.normal(0, 1, (2000, 2)) 32 | 33 | 34 | @pytest.mark.parametrize("outlier_model", [GMMOutlierDetector(), OneClassSVM(nu=0.05), IsolationForest()]) 35 | def test_obvious_usecase(dataset, outlier_model): 36 | outlier_clf = OutlierClassifier(outlier_model) 37 | X = dataset 38 | y = (dataset.max(axis=1) > 3).astype(int) 39 | outlier_clf.fit(X, y) 40 | assert outlier_clf.predict([[10, 10]]) == np.array([1]) 41 | assert outlier_clf.predict([[0, 0]]) == np.array([0]) 42 | np.testing.assert_array_almost_equal(outlier_clf.predict_proba([[0, 0]]), np.array([[1, 0]]), decimal=3) 43 | np.testing.assert_allclose(outlier_clf.predict_proba([[10, 10]]), np.array([[0, 1]]), atol=0.2) 44 | assert isinstance(outlier_clf.score(X, y), float) 45 | 46 | 47 | def test_raises_error(dataset): 48 | regressor = LinearRegression() 49 | clf = OutlierClassifier(regressor) 50 | X = dataset 51 | y = (dataset.max(axis=1) > 3).astype(int) 52 | with pytest.raises(ValueError): 53 | clf.fit(X, y) 54 | 55 | 56 | def test_raises_error_no_decision_function(dataset): 57 | outlier_model = LocalOutlierFactor() 58 | clf_model = OutlierClassifier(outlier_model) 59 | X = dataset 60 | y = (dataset.max(axis=1) > 3).astype(int) 61 | with pytest.raises(ValueError): 62 | clf_model.fit(X, y) 63 | -------------------------------------------------------------------------------- /tests/test_meta/test_regression_outlier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import polars as pl 4 | import pyarrow as pa 5 | import pytest 6 | from sklearn.linear_model import LinearRegression, LogisticRegression 7 | from sklearn.utils.estimator_checks import parametrize_with_checks 8 | 9 | from sklego.meta import RegressionOutlierDetector 10 | 11 | 12 | @parametrize_with_checks([RegressionOutlierDetector(LinearRegression(), column=0)]) 13 | def test_sklearn_compatible_estimator(estimator, check): 14 | if check.func.__name__ in { 15 | # Since `RegressionOutlierDetector` is an outlier detector (`OutlierMixin`), parametrize_with_checks feeds an 16 | # outlier dataset. However this is not how this component is supposed to be used.
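#   (editor's note: as the examples below show, the detector instead treats one designated
#   column as a regression target and flags rows with large residuals as outliers.)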
17 | "check_outliers_train", 18 | "check_fit2d_1feature", # custom message 19 | }: 20 | pytest.skip() 21 | 22 | check(estimator) 23 | 24 | 25 | def test_obvious_example(): 26 | # generate random data for illustrative example 27 | np.random.seed(42) 28 | X = np.random.normal(0, 1, (100, 1)) 29 | y = 1 + np.sum(X, axis=1).reshape(-1, 1) + np.random.normal(0, 0.2, (100, 1)) 30 | for i in [20, 25, 50, 80]: 31 | y[i] += 10 32 | X = np.concatenate([X, y], axis=1) 33 | 34 | # fit and check that the planted outliers are flagged 35 | mod = RegressionOutlierDetector(LinearRegression(), column=1) 36 | preds = mod.fit(X).predict(X) 37 | for i in [20, 25, 50, 80]: 38 | assert preds[i] == -1 39 | 40 | 41 | @pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame, pa.table]) 42 | def test_obvious_example_dataframe(frame_func): 43 | # generate random data for illustrative example 44 | np.random.seed(42) 45 | x = np.random.normal(0, 1, 100) 46 | y = 1 + x + np.random.normal(0, 0.2, 100) 47 | for i in [20, 25, 50, 80]: 48 | y[i] += 10 49 | X = frame_func({"x": x, "y": y}) 50 | 51 | # fit and check that the planted outliers are flagged 52 | mod = RegressionOutlierDetector(LinearRegression(), column="y") 53 | preds = mod.fit(X).predict(X) 54 | for i in [20, 25, 50, 80]: 55 | assert preds[i] == -1 56 | 57 | 58 | @pytest.mark.parametrize("frame_func", [pd.DataFrame, pl.DataFrame, pa.table]) 59 | def test_raises_error(frame_func): 60 | # generate random data for illustrative example 61 | np.random.seed(42) 62 | x = np.random.normal(0, 1, 100) 63 | y = 1 + x + np.random.normal(0, 0.2, 100) 64 | for i in [20, 25, 50, 80]: 65 | y[i] += 2 66 | X = frame_func({"x": x, "y": y}) 67 | 68 | with pytest.raises(ValueError): 69 | mod = RegressionOutlierDetector(LogisticRegression(), column="y") 70 | mod.fit(X) 71 | -------------------------------------------------------------------------------- /tests/test_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_metrics/__init__.py -------------------------------------------------------------------------------- /tests/test_metrics/test_correlation_score.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.linear_model import Ridge 3 | 4 | from sklego.metrics import correlation_score 5 | 6 | 7 | def test_corr_pandas(): 8 | df = pd.DataFrame( 9 | { 10 | "x1": [1, 2, 3, 4, 5, 6, 7, 8], 11 | "x2": [0, 0, 0, 1, 0, 0, 0, 0], 12 | "y": [2, 3, 4, 6, 6, 7, 8, 9], 13 | } 14 | ) 15 | 16 | mod = Ridge().fit(df[["x1", "x2"]], df["y"]) 17 | assert abs(correlation_score("x1")(mod, df[["x1", "x2"]])) > 0.99 18 | assert abs(correlation_score("x2")(mod, df[["x1", "x2"]])) < 0.02 19 | 20 | 21 | def test_corr_numpy(): 22 | df = pd.DataFrame( 23 | { 24 | "x1": [1, 2, 3, 4, 5, 6, 7, 8], 25 | "x2": [0, 0, 0, 1, 0, 0, 0, 0], 26 | "y": [2, 3, 4, 6, 6, 7, 8, 9], 27 | } 28 | ) 29 | arr = df[["x1", "x2"]].values 30 | mod = Ridge().fit(arr, df["y"]) 31 | assert abs(correlation_score(0)(mod, arr)) > 0.99 32 | assert abs(correlation_score(1)(mod, arr)) < 0.02 33 | -------------------------------------------------------------------------------- /tests/test_metrics/test_equal_opportunity.py: -------------------------------------------------------------------------------- 1 | import types 2 | import warnings 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.linear_model import LogisticRegression 7 | from
sklearn.pipeline import make_pipeline 8 | 9 | from sklego.metrics import equal_opportunity_score 10 | from sklego.preprocessing import ColumnSelector 11 | 12 | 13 | def test_equal_opportunity_pandas(): 14 | sensitive_classification_dataset = pd.DataFrame( 15 | { 16 | "x1": [1, 0, 1, 0, 1, 0, 1, 1], 17 | "x2": [0, 0, 0, 0, 0, 1, 1, 1], 18 | "y": [1, 1, 1, 0, 1, 0, 0, 1], 19 | } 20 | ) 21 | 22 | X, y = ( 23 | sensitive_classification_dataset.drop(columns="y"), 24 | sensitive_classification_dataset["y"], 25 | ) 26 | 27 | mod_1 = types.SimpleNamespace() 28 | 29 | mod_1.predict = lambda X: np.array([1, 0, 1, 0, 1, 0, 1, 1]) 30 | assert equal_opportunity_score(sensitive_column="x2")(mod_1, X, y) == 0.75 31 | 32 | mod_1.predict = lambda X: np.array([1, 0, 1, 0, 1, 0, 0, 1]) 33 | assert equal_opportunity_score(sensitive_column="x2")(mod_1, X, y) == 0.75 34 | 35 | mod_1.predict = lambda X: np.array([1, 0, 1, 0, 1, 0, 0, 0]) 36 | assert equal_opportunity_score(sensitive_column="x2")(mod_1, X, y) == 0 37 | 38 | 39 | def test_equal_opportunity_pandas_multiclass(): 40 | sensitive_classification_dataset = pd.DataFrame( 41 | { 42 | "x1": [1, 0, 1, 0, 1, 0, 1, 1], 43 | "x2": [0, 0, 0, 0, 0, 1, 1, 1], 44 | "y": [1, 1, 1, 0, 1, 0, 0, 2], 45 | } 46 | ) 47 | 48 | X, y = ( 49 | sensitive_classification_dataset.drop(columns="y"), 50 | sensitive_classification_dataset["y"], 51 | ) 52 | 53 | mod_1 = types.SimpleNamespace() 54 | 55 | mod_1.predict = lambda X: np.array([2, 0, 1, 0, 1, 0, 1, 2]) 56 | assert ( 57 | equal_opportunity_score(sensitive_column="x2", positive_target=2)(mod_1, X, np.array([2, 0, 1, 0, 1, 0, 1, 2])) 58 | == 1 59 | ) 60 | 61 | mod_1.predict = lambda X: np.array([1, 0, 1, 0, 1, 0, 0, 1]) 62 | assert equal_opportunity_score(sensitive_column="x2", positive_target=2)(mod_1, X, y) == 0 63 | 64 | mod_1.predict = lambda X: np.array([1, 0, 1, 0, 1, 0, 0, 0]) 65 | assert equal_opportunity_score(sensitive_column="x2", positive_target=2)(mod_1, X, y) == 0 66 | 67 | 68 | def test_equal_opportunity_numpy(sensitive_classification_dataset): 69 | X, y = sensitive_classification_dataset 70 | X, y = X.to_numpy(), y.to_numpy() 71 | mod = LogisticRegression().fit(X, y) 72 | assert equal_opportunity_score(1)(mod, X, y) == 0 73 | 74 | 75 | def test_warning_is_logged(sensitive_classification_dataset): 76 | X, y = sensitive_classification_dataset 77 | mod_fair = make_pipeline(ColumnSelector("x1"), LogisticRegression()).fit(X, y) 78 | with warnings.catch_warnings(record=True) as w: 79 | # Cause all warnings to always be triggered. 80 | warnings.simplefilter("always") 81 | # Trigger a warning.
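# (editor's note: `positive_target=2` never occurs in this binary dataset, so the
# scorer presumably cannot compute a rate for a group and falls back with the
# RuntimeWarning asserted below.)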
82 | equal_opportunity_score("x2", positive_target=2)(mod_fair, X, y) 83 | assert issubclass(w[-1].category, RuntimeWarning) 84 | -------------------------------------------------------------------------------- /tests/test_metrics/test_p_percent.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | from sklearn.linear_model import LogisticRegression 5 | from sklearn.pipeline import make_pipeline 6 | 7 | from sklego.metrics import p_percent_score 8 | from sklego.preprocessing import ColumnSelector 9 | 10 | 11 | def test_p_percent_pandas(sensitive_classification_dataset): 12 | X, y = sensitive_classification_dataset 13 | mod_unfair = LogisticRegression().fit(X, y) 14 | assert p_percent_score("x2")(mod_unfair, X) == 0 15 | 16 | mod_fair = make_pipeline(ColumnSelector("x1"), LogisticRegression()).fit(X, y) 17 | assert p_percent_score("x2")(mod_fair, X) == 0.9 18 | 19 | 20 | def test_p_percent_pandas_multiclass(sensitive_multiclass_classification_dataset): 21 | X, y = sensitive_multiclass_classification_dataset 22 | mod_unfair = LogisticRegression(multi_class="ovr").fit(X, y) 23 | assert p_percent_score("x2")(mod_unfair, X) == 0 24 | assert p_percent_score("x2", positive_target=2)(mod_unfair, X) == 0 25 | 26 | mod_fair = make_pipeline(ColumnSelector("x1"), LogisticRegression()).fit(X, y) 27 | assert p_percent_score("x2")(mod_fair, X) == pytest.approx(0.9333333) 28 | assert p_percent_score("x2", positive_target=2)(mod_fair, X) == 0 29 | 30 | 31 | def test_p_percent_numpy(sensitive_classification_dataset): 32 | X, y = sensitive_classification_dataset 33 | X, y = X.to_numpy(), y.to_numpy() 34 | mod = LogisticRegression().fit(X, y) 35 | assert p_percent_score(1)(mod, X) == 0 36 | 37 | 38 | def test_warning_is_logged(sensitive_classification_dataset): 39 | X, y = sensitive_classification_dataset 40 | mod_fair = make_pipeline(ColumnSelector("x1"), LogisticRegression()).fit(X, y) 41 | with warnings.catch_warnings(record=True) as w: 42 | # Cause all warnings to always be triggered. 43 | warnings.simplefilter("always") 44 | # Trigger a warning. 
45 | p_percent_score("x2", positive_target=2)(mod_fair, X) 46 | assert issubclass(w[-1].category, RuntimeWarning) 47 | -------------------------------------------------------------------------------- /tests/test_model_selection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_model_selection/__init__.py -------------------------------------------------------------------------------- /tests/test_model_selection/test_clusterfold.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn import clone 4 | from sklearn.base import BaseEstimator 5 | from sklearn.cluster import KMeans, MiniBatchKMeans 6 | from sklearn.pipeline import make_pipeline 7 | from sklearn.preprocessing import StandardScaler 8 | 9 | from sklego.model_selection import ClusterFoldValidation 10 | from tests.conftest import id_func 11 | 12 | k_means_pipeline = make_pipeline(StandardScaler(), KMeans()) 13 | 14 | 15 | class DummyCluster(BaseEstimator): 16 | def __init__(self, n_splits=3): 17 | self.n_splits = n_splits 18 | 19 | def fit(self, X): 20 | return self 21 | 22 | def predict(self, X): 23 | return np.random.randint(0, self.n_splits, size=X.shape[0]) 24 | 25 | def fit_predict(self, X): 26 | return self.predict(X) 27 | 28 | 29 | @pytest.mark.parametrize( 30 | "cluster_method", 31 | [KMeans(), MiniBatchKMeans(), DummyCluster(), k_means_pipeline], 32 | ids=id_func, 33 | ) 34 | def test_splits_not_fitted(cluster_method, random_xy_dataset_regr): 35 | cluster_method = clone(cluster_method) 36 | X, y = random_xy_dataset_regr 37 | kf = ClusterFoldValidation(cluster_method=cluster_method) 38 | for train_index, test_index in kf.split(X): 39 | assert len(train_index) > 0 40 | assert len(test_index) > 0 41 | 42 | 43 | @pytest.mark.parametrize( 44 | "cluster_method", 45 | [KMeans(), MiniBatchKMeans(), DummyCluster(), k_means_pipeline], 46 | ids=id_func, 47 | ) 48 | def test_splits_fitted(cluster_method, random_xy_dataset_regr): 49 | cluster_method = clone(cluster_method) 50 | X, y = random_xy_dataset_regr 51 | cluster_method = cluster_method.fit(X) 52 | kf = ClusterFoldValidation(cluster_method=cluster_method) 53 | for train_index, test_index in kf.split(X): 54 | assert len(train_index) > 0 55 | assert len(test_index) > 0 56 | 57 | 58 | def test_no_split(random_xy_dataset_regr): 59 | X, y = random_xy_dataset_regr 60 | # With only one split, the method should raise a ValueError 61 | cluster_method = DummyCluster(n_splits=1) 62 | kf = ClusterFoldValidation(cluster_method=cluster_method) 63 | with pytest.raises(ValueError): 64 | for train_index, test_index in kf.split(X): 65 | assert len(train_index) > 0 66 | assert len(test_index) > 0 67 | -------------------------------------------------------------------------------- /tests/test_notinstalled.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_installed_package_works(): 5 | package_name = "pandas" 6 | version = ">=0.23.4" 7 | 8 | try: 9 | import pandas as pd 10 | except ImportError: 11 | from sklego.notinstalled import NotInstalledPackage 12 | 13 | pd = NotInstalledPackage(package_name, version=version) 14 | 15 | assert pd.__version__ 16 | 17 | 18 | def test_uninstalled_package_raises(): 19 | package_name = "thispackagedoesnotexist" 20 | version = "==1.2.3" 21 | 22 | try: 23 |
import thispackagedoesnotexist as package 24 | except ImportError: 25 | from sklego.notinstalled import NotInstalledPackage 26 | 27 | package = NotInstalledPackage(package_name, version=version) 28 | 29 | with pytest.raises(ImportError) as e: 30 | package.__version__ 31 | assert package_name in str(e) 32 | assert version in str(e) 33 | -------------------------------------------------------------------------------- /tests/test_preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/koaning/scikit-lego/17ab5926dd33a8b7b7ec0fa9cd725033f6641714/tests/test_preprocessing/__init__.py -------------------------------------------------------------------------------- /tests/test_preprocessing/test_columncapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | from sklearn.utils.validation import FLOAT_DTYPES 6 | 7 | from sklego.preprocessing import ColumnCapper 8 | 9 | 10 | @parametrize_with_checks([ColumnCapper()]) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | check(estimator) 13 | 14 | 15 | def test_quantile_range(): 16 | def expect_type_error(quantile_range): 17 | with pytest.raises(TypeError): 18 | ColumnCapper(quantile_range).fit([]) 19 | 20 | def expect_value_error(quantile_range): 21 | with pytest.raises(ValueError): 22 | ColumnCapper(quantile_range).fit([]) 23 | 24 | # Testing quantile_range type 25 | expect_type_error(quantile_range=1) 26 | expect_type_error(quantile_range="a") 27 | expect_type_error(quantile_range={}) 28 | expect_type_error(quantile_range=set()) 29 | 30 | # Testing quantile_range values 31 | # Invalid type: 32 | expect_type_error(quantile_range=("a", 90)) 33 | expect_type_error(quantile_range=(10, "a")) 34 | 35 | # Invalid limits 36 | expect_value_error(quantile_range=(-1, 90)) 37 | expect_value_error(quantile_range=(10, 110)) 38 | 39 | # Invalid order 40 | expect_value_error(quantile_range=(60, 40)) 41 | 42 | 43 | def test_interpolation(): 44 | valid_interpolations = ("linear", "lower", "higher", "midpoint", "nearest") 45 | invalid_interpolations = ("test", 42, None, [], {}, set(), 0.42) 46 | 47 | for interpolation in valid_interpolations: 48 | ColumnCapper(interpolation=interpolation) 49 | 50 | for interpolation in invalid_interpolations: 51 | with pytest.raises(ValueError): 52 | ColumnCapper(interpolation=interpolation).fit([]) 53 | 54 | 55 | @pytest.fixture() 56 | def valid_df(): 57 | return pd.DataFrame({"a": [1, np.nan, 3, 4], "b": [11, 12, np.inf, 14], "c": [21, 22, 23, 24]}) 58 | 59 | 60 | def test_X_types_and_transformed_shapes(valid_df): 61 | def expect_value_error(X, X_transform=None): 62 | if X_transform is None: 63 | X_transform = X 64 | with pytest.raises(ValueError): 65 | capper = ColumnCapper().fit(X) 66 | capper.transform(X_transform) 67 | 68 | # Fitted and transformed arrays must have the same number of columns 69 | expect_value_error(valid_df, valid_df[["a", "b"]]) 70 | 71 | invalid_dfs = [ 72 | pd.DataFrame({"a": [np.nan, np.nan, np.nan], "b": [11, 12, 13]}), 73 | pd.DataFrame({"a": [np.inf, np.inf, np.inf], "b": [11, 12, 13]}), 74 | ] 75 | 76 | for invalid_df in invalid_dfs: 77 | expect_value_error(invalid_df) # contains an invalid column ('a') 78 | expect_value_error(invalid_df["b"]) # 1d arrays should be reshaped before fitted/transformed 79 | # Like this: 80 | 
ColumnCapper().fit_transform(invalid_df["b"].values.reshape(-1, 1)) 81 | ColumnCapper().fit_transform(invalid_df["b"].values.reshape(1, -1)) 82 | 83 | capper = ColumnCapper() 84 | for X in valid_df, valid_df.values: 85 | assert capper.fit_transform(X).shape == X.shape 86 | 87 | 88 | def test_nan_inf(valid_df): 89 | # Capping infs 90 | capper = ColumnCapper(discard_infs=False) 91 | assert (capper.fit_transform(valid_df) == np.inf).sum().sum() == 0 92 | assert np.isnan(capper.fit_transform(valid_df)).sum() == 1 93 | 94 | # Discarding infs 95 | capper = ColumnCapper(discard_infs=True) 96 | assert (capper.fit_transform(valid_df) == np.inf).sum().sum() == 0 97 | assert np.isnan(capper.fit_transform(valid_df)).sum() == 2 98 | 99 | 100 | def test_dtype_regression(random_xy_dataset_regr): 101 | X, y = random_xy_dataset_regr 102 | assert ColumnCapper().fit(X, y).transform(X).dtype in FLOAT_DTYPES 103 | 104 | 105 | def test_dtype_classification(random_xy_dataset_clf): 106 | X, y = random_xy_dataset_clf 107 | assert ColumnCapper().fit(X, y).transform(X).dtype in FLOAT_DTYPES 108 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_columndropper.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | import pandas as pd 4 | import polars as pl 5 | import pytest 6 | from pandas.testing import assert_frame_equal as pandas_assert_frame_equal 7 | from polars.testing import assert_frame_equal as polars_assert_frame_equal 8 | from sklearn.pipeline import Pipeline, make_pipeline 9 | 10 | from sklego.preprocessing import ColumnDropper 11 | 12 | 13 | @pytest.fixture() 14 | def data(): 15 | return { 16 | "a": [1, 2, 3, 4, 5, 6], 17 | "b": [10, 9, 8, 7, 6, 5], 18 | "c": ["a", "b", "a", "b", "c", "c"], 19 | "d": ["b", "a", "a", "b", "a", "b"], 20 | "e": [0, 1, 0, 1, 0, 1], 21 | } 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "frame_func, assert_func", 26 | [ 27 | (pd.DataFrame, pandas_assert_frame_equal), 28 | (pl.DataFrame, polars_assert_frame_equal), 29 | ], 30 | ) 31 | @pytest.mark.parametrize( 32 | "to_drop, context", 33 | [ 34 | (["e"], does_not_raise()), # one 35 | (["a", "b"], does_not_raise()), # two 36 | ([], does_not_raise()), # none 37 | (["a", "b", "c", "d", "e"], pytest.raises(ValueError)), # all 38 | (["f"], pytest.raises(KeyError)), # not in data 39 | ], 40 | ) 41 | @pytest.mark.parametrize("wrapper", [lambda x: x, make_pipeline]) 42 | def test_drop(data, frame_func, assert_func, to_drop, context, wrapper): 43 | sub_data = {k: v for k, v in data.items() if k not in to_drop} 44 | 45 | with context: 46 | transformer = wrapper(ColumnDropper(to_drop)) 47 | result_df = transformer.fit_transform(frame_func(data)) 48 | expected_df = frame_func(sub_data) 49 | 50 | assert_func(result_df, expected_df) 51 | 52 | if not isinstance(transformer, Pipeline): 53 | assert transformer.get_feature_names() == list(sub_data.keys()) 54 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_columnselector.py: -------------------------------------------------------------------------------- 1 | from contextlib import nullcontext as does_not_raise 2 | 3 | import pandas as pd 4 | import polars as pl 5 | import pytest 6 | from pandas.testing import assert_frame_equal as pandas_assert_frame_equal 7 | from polars.testing import assert_frame_equal as polars_assert_frame_equal 8 | from sklearn.pipeline import Pipeline, 
make_pipeline 9 | 10 | from sklego.preprocessing import ColumnSelector 11 | 12 | 13 | @pytest.fixture() 14 | def data(): 15 | return { 16 | "a": [1, 2, 3, 4, 5, 6], 17 | "b": [10, 9, 8, 7, 6, 5], 18 | "c": ["a", "b", "a", "b", "c", "c"], 19 | "d": ["b", "a", "a", "b", "a", "b"], 20 | "e": [0, 1, 0, 1, 0, 1], 21 | } 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "frame_func, assert_func", 26 | [ 27 | (pd.DataFrame, pandas_assert_frame_equal), 28 | (pl.DataFrame, polars_assert_frame_equal), 29 | ], 30 | ) 31 | @pytest.mark.parametrize( 32 | "select, context", 33 | [ 34 | (["a", "b"], does_not_raise()), # two 35 | (["e"], does_not_raise()), # one 36 | (["a", "b", "c", "d", "e"], does_not_raise()), # all 37 | ([], pytest.raises(ValueError)), # none 38 | (["f"], pytest.raises(KeyError)), # not in data 39 | ], 40 | ) 41 | @pytest.mark.parametrize("wrapper", [lambda x: x, make_pipeline]) 42 | def test_select(data, frame_func, assert_func, select, context, wrapper): 43 | sub_data = {k: v for k, v in data.items() if k in select} 44 | 45 | with context: 46 | transformer = wrapper(ColumnSelector(select)) 47 | result_df = transformer.fit_transform(frame_func(data)) 48 | expected_df = frame_func(sub_data) 49 | 50 | assert_func(result_df, expected_df) 51 | 52 | if not isinstance(transformer, Pipeline): 53 | assert transformer.get_feature_names() == list(sub_data.keys()) 54 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_dictmapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | 6 | from sklego.preprocessing import DictMapper 7 | 8 | 9 | @parametrize_with_checks([DictMapper(mapper={"foo": 1}, default=-1)]) 10 | def test_sklearn_compatible_estimator(estimator, check): 11 | check(estimator) 12 | 13 | 14 | @pytest.fixture() 15 | def mapper(): 16 | return {"foo": 1, "bar": 2, "baz": 3} 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "input_array,expected_array", 21 | [ 22 | (["foo", "bar", "baz"], [1, 2, 3]), 23 | (["foo", "bar", "bar"], [1, 2, 2]), 24 | (["foo", "bar", "monty"], [1, 2, -1]), 25 | (["foo", "bar", np.nan], [1, 2, -1]), 26 | ([["foo", "bar", "baz"], ["foo", "bar", "baz"]], [[1, 2, 3], [1, 2, 3]]), 27 | ], 28 | ) 29 | def test_array(input_array, expected_array, mapper): 30 | X = np.array(input_array).reshape(-1, 1) 31 | expected = np.array(expected_array).reshape(-1, 1) 32 | result = DictMapper(mapper=mapper, default=-1).fit_transform(X) 33 | np.testing.assert_array_equal(result, expected) 34 | 35 | 36 | def test_pandas(mapper): 37 | X = pd.DataFrame(["foo", "bar", "baz"], dtype=object) 38 | expected = np.array([1, 2, 3]).reshape(-1, 1) 39 | result = DictMapper(mapper=mapper, default=-1).fit_transform(X) 40 | np.testing.assert_array_equal(result, expected) 41 | 42 | 43 | def test_no_mapper(): 44 | mapper = {} 45 | X = pd.DataFrame(["foo", "bar", "baz"], dtype=object) 46 | expected = np.array([-1, -1, -1]).reshape(-1, 1) 47 | result = DictMapper(mapper=mapper, default=-1).fit_transform(X) 48 | np.testing.assert_array_equal(result, expected) 49 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_identitytransformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.utils.estimator_checks import parametrize_with_checks 3 | 4 |
from sklego.preprocessing import IdentityTransformer 5 | 6 | 7 | @parametrize_with_checks([IdentityTransformer(check_X=True)]) 8 | def test_sklearn_compatible_estimator(estimator, check): 9 | check(estimator) 10 | 11 | 12 | def test_same_values(random_xy_dataset_regr): 13 | X, y = random_xy_dataset_regr 14 | X_new = IdentityTransformer(check_X=True).fit_transform(X) 15 | assert np.isclose(X, X_new).all() 16 | 17 | 18 | def test_nan_inf(random_xy_dataset_regr): 19 | # see https://github.com/koaning/scikit-lego/pull/527 20 | X, y = random_xy_dataset_regr 21 | X = X.astype(np.float32) 22 | X[np.random.ranf(size=X.shape) > 0.9] = np.nan 23 | X[np.random.ranf(size=X.shape) > 0.9] = -np.inf 24 | X[np.random.ranf(size=X.shape) > 0.9] = np.inf 25 | IdentityTransformer(check_X=False).fit_transform(X) 26 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_interval_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import parametrize_with_checks 4 | 5 | from sklego.preprocessing import IntervalEncoder 6 | 7 | pytestmark = pytest.mark.cvxpy 8 | 9 | 10 | @parametrize_with_checks([IntervalEncoder()]) 11 | def test_sklearn_compatible_estimator(estimator, check): 12 | check(estimator) 13 | 14 | 15 | @pytest.mark.parametrize("chunks", [1, 2, 5, 10]) 16 | def test_obvious_cases_one(random_xy_dataset_regr, chunks): 17 | X, y = random_xy_dataset_regr 18 | y = np.ones(y.shape) 19 | x_transform = IntervalEncoder(n_chunks=chunks).fit(X, y).transform(X) 20 | assert x_transform.shape == X.shape 21 | assert np.all(np.isclose(x_transform, 1.0)) 22 | 23 | 24 | @pytest.mark.parametrize("method", ["average", "normal", "increasing", "decreasing"]) 25 | def test_obvious_cases_two(random_xy_dataset_regr_small, method): 26 | X, y = random_xy_dataset_regr_small 27 | y = np.ones(y.shape) 28 | x_transform = IntervalEncoder(method=method).fit(X, y).transform(X) 29 | assert x_transform.shape == X.shape 30 | assert np.all(np.isclose(x_transform, 1.0)) 31 | 32 | 33 | def generate_dataset(start, n=600): 34 | np.random.seed(42) 35 | xs = np.arange(start, start + n) / 100 / np.pi 36 | y = np.sin(xs) + np.random.normal(0, 0.1, n) 37 | return xs.reshape(-1, 1), y 38 | 39 | 40 | @pytest.mark.parametrize("data_init", [50, 600, 1200, 2100]) 41 | def test_monotonicity_increasing(data_init): 42 | X, y = generate_dataset(start=data_init) 43 | encoder = IntervalEncoder(n_chunks=40, method="increasing") 44 | y_transformed = encoder.fit_transform(X, y).reshape(-1).round(4) 45 | for i in range(len(y_transformed) - 1): 46 | assert y_transformed[i] <= y_transformed[i + 1] 47 | 48 | 49 | @pytest.mark.parametrize("data_init", [50, 600, 1200, 2100]) 50 | def test_monotonicity_decreasing(data_init): 51 | X, y = generate_dataset(start=data_init) 52 | encoder = IntervalEncoder(n_chunks=40, method="decreasing") 53 | y_transformed = encoder.fit_transform(X, y).reshape(-1).round(4) 54 | for i in range(len(y_transformed) - 1): 55 | assert y_transformed[i] >= y_transformed[i + 1] 56 | 57 | 58 | def test_throw_valuerror_given_nonsense(): 59 | X = np.ones((10, 2)) 60 | y = np.ones(10) 61 | with pytest.raises(ValueError): 62 | IntervalEncoder(n_chunks=0).fit(X, y) 63 | with pytest.raises(ValueError): 64 | IntervalEncoder(n_chunks=-1).fit(X, y) 65 | with pytest.raises(ValueError): 66 | IntervalEncoder(span=-0.1).fit(X, y) 67 | with pytest.raises(ValueError): 68 | 
IntervalEncoder(span=2.0).fit(X, y) 69 | with pytest.raises(ValueError): 70 | IntervalEncoder(method="dinosaurhead").fit(X, y) 71 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_monospline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.preprocessing import SplineTransformer 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | 6 | from sklego.preprocessing import MonotonicSplineTransformer 7 | 8 | 9 | @parametrize_with_checks([MonotonicSplineTransformer()]) 10 | def test_sklearn_compatible_estimator(estimator, check): 11 | check(estimator) 12 | 13 | 14 | @pytest.mark.parametrize("n_knots", [3, 5]) 15 | @pytest.mark.parametrize("degree", [3, 5]) 16 | @pytest.mark.parametrize("knots", ["uniform", "quantile"]) 17 | def test_monotonic_spline_transformer(n_knots, degree, knots): 18 | X = np.random.uniform(size=(100, 10)) 19 | transformer = MonotonicSplineTransformer(n_knots=n_knots, degree=degree, knots=knots) 20 | transformer_sk = SplineTransformer(n_knots=n_knots, degree=degree, knots=knots) 21 | transformer.fit(X) 22 | transformer_sk.fit(X) 23 | out = transformer.transform(X) 24 | out_sk = transformer_sk.transform(X) 25 | 26 | # Both should have the same shape 27 | assert out.shape == out_sk.shape 28 | 29 | n_splines_per_feature = n_knots + degree - 1 30 | assert out.shape[1] == X.shape[1] * n_splines_per_feature 31 | 32 | # I-splines should be bounded by 0 and 1 33 | assert np.logical_or(out >= 0, np.isclose(out, 0)).all() 34 | assert np.logical_or(out <= 1, np.isclose(out, 1)).all() 35 | 36 | # The features should be monotonically increasing 37 | for i in range(X.shape[1]): 38 | feature = X[:, i] 39 | sorted_out = out[np.argsort(feature), i * n_splines_per_feature : (i + 1) * n_splines_per_feature] 40 | differences = np.diff(sorted_out, axis=0) 41 | 42 | # All differences should be greater than or equal to zero, up to floating point error 43 | assert np.logical_or(np.greater_equal(differences, 0), np.isclose(differences, 0)).all() 44 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_orthogonal_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | 6 | from sklego.preprocessing import OrthogonalTransformer 7 | 8 | 9 | @pytest.fixture 10 | def sample_matrix(): 11 | np.random.seed(1313) 12 | return np.random.normal(size=(50, 10)) 13 | 14 | 15 | @pytest.fixture 16 | def sample_df(sample_matrix): 17 | return pd.DataFrame(sample_matrix) 18 | 19 | 20 | @parametrize_with_checks([OrthogonalTransformer()]) 21 | def test_sklearn_compatible_estimator(estimator, check): 22 | check(estimator) 23 | 24 | 25 | def check_is_orthogonal(X, tolerance=10**-5): 26 | """ 27 | Check if X is a column-orthogonal matrix.
If X is column orthogonal, then X.T * X equals the identity matrix 28 | :param X: Matrix to check 29 | :param tolerance: Tolerance for difference caused by rounding 30 | :raises: AssertionError if X is not orthogonal 31 | """ 32 | diff_with_eye = np.dot(X.T, X) - np.eye(X.shape[1]) 33 | 34 | if np.max(np.abs(diff_with_eye)) > tolerance: 35 | raise AssertionError("X is not orthogonal") 36 | 37 | 38 | def check_is_orthonormal(X, tolerance=10**-5): 39 | """ 40 | Check if X is a column-orthonormal matrix, i.e. orthogonal with columns of norm 1. 41 | :param X: Matrix to check 42 | :param tolerance: Tolerance for difference caused by rounding 43 | :raises: AssertionError if X is not orthonormal 44 | """ 45 | # Orthonormal, so orthogonal and columns must be normalized 46 | check_is_orthogonal(X, tolerance) 47 | 48 | norms = np.linalg.norm(X, ord=2, axis=0) 49 | 50 | if (max(norms) > 1 + tolerance) or (min(norms) < 1 - tolerance): 51 | raise AssertionError("X is not orthonormal") 52 | 53 | 54 | def test_orthogonal_transformer(sample_matrix): 55 | ot = OrthogonalTransformer(normalize=False) 56 | ot.fit(X=sample_matrix) 57 | 58 | assert hasattr(ot, "inv_R_") 59 | assert hasattr(ot, "normalization_vector_") 60 | assert ot.inv_R_.shape[0] == sample_matrix.shape[1] 61 | 62 | assert all(ot.normalization_vector_ == 1) 63 | 64 | trans = ot.transform(sample_matrix) 65 | 66 | check_is_orthogonal(trans) 67 | 68 | 69 | def test_orthonormal_transformer(sample_matrix): 70 | ot = OrthogonalTransformer(normalize=True) 71 | ot.fit(X=sample_matrix) 72 | 73 | assert hasattr(ot, "inv_R_") 74 | assert hasattr(ot, "normalization_vector_") 75 | assert ot.inv_R_.shape[0] == sample_matrix.shape[1] 76 | assert ot.normalization_vector_.shape[0] == sample_matrix.shape[1] 77 | 78 | trans = ot.transform(sample_matrix) 79 | 80 | check_is_orthonormal(trans) 81 | 82 | 83 | def test_orthogonal_with_df(sample_df): 84 | ot = OrthogonalTransformer(normalize=False) 85 | ot.fit(X=sample_df) 86 | 87 | assert ot.inv_R_.shape[0] == sample_df.shape[1] 88 | 89 | trans = ot.transform(sample_df) 90 | 91 | check_is_orthogonal(trans) 92 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_outlier_remover.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.cluster import KMeans 4 | from sklearn.ensemble import IsolationForest 5 | from sklearn.pipeline import Pipeline 6 | from sklearn.utils.estimator_checks import parametrize_with_checks 7 | 8 | from sklego.mixture import GMMOutlierDetector 9 | from sklego.preprocessing import OutlierRemover 10 | 11 | 12 | @parametrize_with_checks( 13 | [ 14 | OutlierRemover(outlier_detector=GMMOutlierDetector(), refit=True), 15 | OutlierRemover(outlier_detector=IsolationForest(), refit=True), 16 | ] 17 | ) 18 | def test_sklearn_compatible_estimator(estimator, check): 19 | if check.func.__name__ in { 20 | # As this transformer removes samples, it does not follow the standard transformer contract 21 | "check_transformer_general", 22 | "check_methods_sample_order_invariance", # leads to out-of-range indexing 23 | "check_methods_subset_invariance", # leads to different shapes 24 | "check_transformer_data_not_an_array", # hash only supports a few types 25 | "check_pipeline_consistency", # Discussed in https://github.com/koaning/scikit-lego/issues/643 26 | }: 27 | pytest.skip("OutlierRemover is a TrainOnlyTransformer") 28 | check(estimator) 29 | 30 | 31 | def test_no_outliers(mocker): 32 |
mock_outlier_detector = mocker.Mock() 33 | mock_outlier_detector.fit.return_value = None 34 | mock_outlier_detector.predict.return_value = np.array([1, 1]) 35 | mocker.patch("sklego.preprocessing.outlier_remover.clone").return_value = mock_outlier_detector 36 | 37 | outlier_remover = OutlierRemover(outlier_detector=mock_outlier_detector, refit=True) 38 | outlier_remover.fit(X=np.array([[1, 1], [2, 2]])) 39 | assert len(outlier_remover.transform_train(np.array([[1, 1], [2, 2]]))) == 2 40 | 41 | 42 | def test_remove_outlier(mocker): 43 | mock_outlier_detector = mocker.Mock() 44 | mock_outlier_detector.fit.return_value = None 45 | mock_outlier_detector.predict.return_value = np.array([-1]) 46 | mocker.patch("sklego.preprocessing.outlier_remover.clone").return_value = mock_outlier_detector 47 | 48 | outlier_remover = OutlierRemover(outlier_detector=mock_outlier_detector, refit=True) 49 | outlier_remover.fit(X=np.array([[5, 5]])) 50 | assert len(outlier_remover.transform_train(np.array([[0, 0]]))) == 0 51 | 52 | 53 | def test_do_not_refit(mocker): 54 | mock_outlier_detector = mocker.Mock() 55 | mock_outlier_detector.fit.return_value = None 56 | mock_outlier_detector.predict.return_value = np.array([-1]) 57 | mocker.patch("sklego.preprocessing.outlier_remover.clone").return_value = mock_outlier_detector 58 | 59 | outlier_remover = OutlierRemover(outlier_detector=mock_outlier_detector, refit=False) 60 | outlier_remover.fit(X=np.array([[5, 5]])) 61 | mock_outlier_detector.fit.assert_not_called() 62 | 63 | 64 | def test_pipeline_integration(): 65 | np.random.seed(42) 66 | dataset = np.concatenate([np.random.normal(0, 1, (2000, 2))]) 67 | isolation_forest_remover = OutlierRemover(outlier_detector=IsolationForest()) 68 | gmm_remover = OutlierRemover(outlier_detector=GMMOutlierDetector()) 69 | pipeline = Pipeline( 70 | [ 71 | ("isolation_forest_remover", isolation_forest_remover), 72 | ("gmm_remover", gmm_remover), 73 | ("kmeans", KMeans()), 74 | ] 75 | ) 76 | pipeline.fit(dataset) 77 | pipeline.transform(dataset) 78 | -------------------------------------------------------------------------------- /tests/test_preprocessing/test_randomadder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.model_selection import train_test_split 4 | from sklearn.utils.estimator_checks import parametrize_with_checks 5 | 6 | from sklego.preprocessing import RandomAdder 7 | 8 | 9 | @parametrize_with_checks([RandomAdder()]) 10 | def test_sklearn_compatible_estimator(estimator, check): 11 | if check.func.__name__ in { 12 | "check_transformer_data_not_an_array", # hash only supports a few types 13 | }: 14 | pytest.skip("RandomAdder is a TrainOnlyTransformer") 15 | 16 | check(estimator) 17 | 18 | 19 | def test_dtype_regression(random_xy_dataset_regr): 20 | X, y = random_xy_dataset_regr 21 | assert RandomAdder().fit(X, y).transform(X).dtype == float 22 | 23 | 24 | def test_dtype_classification(random_xy_dataset_clf): 25 | X, y = random_xy_dataset_clf 26 | assert RandomAdder().fit(X, y).transform(X).dtype == float 27 | 28 | 29 | def test_only_transform_train(random_xy_dataset_clf): 30 | X, y = random_xy_dataset_clf 31 | X_train, X_test, y_train, y_test = train_test_split(X, y) 32 | 33 | random_adder = RandomAdder() 34 | random_adder.fit(X_train, y_train) 35 | 36 | assert np.all(random_adder.transform(X_train) != X_train) 37 | assert np.all(random_adder.transform(X_test) == X_test) 38 | 
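# [editor's note] Illustrative sketch, not part of the original suite. The skip comments
# here and in test_outlier_remover.py ("hash only supports a few types") suggest the
# train-only behaviour is implemented by hashing the data seen in `fit` and only
# perturbing inputs whose hash matches. A standalone check under that assumption:
import numpy as np  # already imported at the top of this file
from sklego.preprocessing import RandomAdder  # idem

def test_unseen_rows_pass_through_unchanged():
    rng = np.random.default_rng(0)
    X_train, X_new = rng.normal(size=(50, 3)), rng.normal(size=(10, 3))
    adder = RandomAdder(noise=1.0).fit(X_train)
    assert not np.array_equal(adder.transform(X_train), X_train)  # train rows get noise
    assert np.array_equal(adder.transform(X_new), X_new)  # unseen rows are unchanged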
-------------------------------------------------------------------------------- /tests/test_preprocessing/test_repeatingbasisfunction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from sklego.preprocessing import RepeatingBasisFunction 6 | from sklego.preprocessing.repeatingbasis import _RepeatingBasisFunction 7 | 8 | 9 | @pytest.fixture() 10 | def df(): 11 | return pd.DataFrame( 12 | { 13 | "a": [1, 2, 3, 4, 5, 6], 14 | "b": np.log([10, 9, 8, 7, 6, 5]), 15 | "c": ["a", "b", "a", "b", "c", "c"], 16 | "d": ["b", "a", "a", "b", "a", "b"], 17 | "e": [0, 1, 0, 1, 0, 1], 18 | } 19 | ) 20 | 21 | 22 | def test_int_indexing(df): 23 | X, y = df[["a", "b", "c", "d"]], df[["e"]] 24 | tf = RepeatingBasisFunction(column=0, n_periods=4, remainder="passthrough") 25 | assert tf.fit(X, y).transform(X).shape == (6, 7) 26 | 27 | 28 | def test_str_indexing(df): 29 | X, y = df[["a", "b", "c", "d"]], df[["e"]] 30 | tf = RepeatingBasisFunction(column="b", n_periods=4, remainder="passthrough") 31 | assert tf.fit(X, y).transform(X).shape == (6, 7) 32 | 33 | 34 | def test_drop_remainder(df): 35 | X, y = df[["a", "b", "c", "d"]], df[["e"]] 36 | tf = RepeatingBasisFunction(column="b", n_periods=4, remainder="drop") 37 | assert tf.fit(X, y).transform(X).shape == (6, 4) 38 | 39 | 40 | def test_dataframe_equals_array(df): 41 | X, y = df[["a", "b", "c", "d"]], df[["e"]] 42 | tf = RepeatingBasisFunction(column=1, n_periods=4, remainder="passthrough") 43 | df_transformed = tf.fit(X, y).transform(X) 44 | array_transformed = tf.fit(X.values, y).transform(X.values) 45 | np.testing.assert_array_equal(df_transformed, array_transformed) 46 | 47 | 48 | def test_when_rbf_helper_receives_more_than_one_col_raises_value_error(df): 49 | X, y = df[["a", "b", "c", "d"]], df[["e"]] 50 | rbf_helper_tf = _RepeatingBasisFunction() 51 | with pytest.raises(ValueError): 52 | rbf_helper_tf.fit(X, y) 53 | --------------------------------------------------------------------------------
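# [editor's note] Why the shapes asserted in test_repeatingbasisfunction.py work out:
# RepeatingBasisFunction replaces the selected column with `n_periods` repeating
# radial-basis columns. With remainder="passthrough" the three remaining input
# columns are appended, giving 4 + 3 = 7 outputs; with remainder="drop" only the
# 4 basis columns are kept, hence shapes (6, 7) and (6, 4).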