├── .github ├── ISSUE_TEMPLATE │   ├── bug_report.md │   ├── feature_request.md │   └── other-issue.md ├── PULL_REQUEST_TEMPLATE │   └── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │   ├── cronjob_unit_tests.yml │   ├── publish_to_pypi.yml │   └── unit_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENCE ├── README.md ├── VISION.md ├── docs ├── api │   ├── feature_elimination.md │   ├── model_interpret.md │   ├── sample_similarity.md │   └── utils.md ├── discussion │   └── nb_rfecv_vs_shaprfecv.ipynb ├── howto │   ├── grouped_data.ipynb │   └── reproducibility.ipynb ├── img │   ├── Probatus_P.png │   ├── Probatus_P_white.png │   ├── earlystoppingshaprfecv.png │   ├── logo_large.png │   ├── logo_large_white.png │   ├── model_interpret_dep.png │   ├── model_interpret_importance.png │   ├── model_interpret_sample.png │   ├── model_interpret_summary.png │   ├── resemblance_model_schema.png │   ├── sample_similarity_permutation_importance.png │   ├── sample_similarity_shap_importance.png │   ├── sample_similarity_shap_summary.png │   └── shaprfecv.png ├── index.md └── tutorials │   ├── nb_automatic_best_num_features.ipynb │   ├── nb_custom_scoring.ipynb │   ├── nb_sample_similarity.ipynb │   ├── nb_shap_dependence.ipynb │   ├── nb_shap_feature_elimination.ipynb │   ├── nb_shap_model_interpreter.ipynb │   └── nb_shap_variance_penalty_and_results_comparison.ipynb ├── mkdocs.yml ├── probatus ├── __init__.py ├── feature_elimination │   ├── __init__.py │   ├── early_stopping_feature_elimination.py │   └── feature_elimination.py ├── interpret │   ├── __init__.py │   ├── model_interpret.py │   └── shap_dependence.py ├── sample_similarity │   ├── __init__.py │   └── resemblance_model.py └── utils │   ├── __init__.py │   ├── _utils.py │   ├── arrayfuncs.py │   ├── base_class_interface.py │   ├── exceptions.py │   ├── scoring.py │   └── shap_helpers.py ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── docs ├── __init__.py ├── test_docstring.py └── test_notebooks.py ├── feature_elimination └── test_feature_elimination.py ├── interpret ├── __init__.py ├── test_model_interpret.py └── test_shap_dependence.py ├── mocks.py ├── sample_similarity ├── __init__.py └── test_resemblance_model.py └── utils ├── __init__.py ├── test_base_class.py └── test_utils_array_funcs.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | --- 8 | 9 | **Describe the bug** 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | **Environment (please complete the following information):** 14 | 15 | - probatus version [e.g. 1.5.0] (`pip list | grep probatus`) 16 | - python version (`python -V`) 17 | - OS: [e.g. macOS, windows, linux] 18 | 19 | **To Reproduce** 20 | 21 | Code or steps to reproduce the error 22 | 23 | ```python 24 | # Put your code here 25 | ``` 26 | 27 | **Error traceback** 28 | 29 | If applicable, please provide the full traceback of the error. 30 | 31 | ``` 32 | # Put traceback here 33 | ``` 34 | 35 | **Expected behavior** 36 | 37 | A clear and concise description of what you expected to happen.
38 | 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Problem Description** 11 | A clear and concise description of what the problem is, e.g. probatus currently does not allow using this model. 12 | 13 | **Desired Outcome** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Solution Outline** 17 | If you have an idea of what the solution should look like, please describe it. 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other-issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other Issue 3 | about: Other kind of issue 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Issue Description** 11 | Please describe your issue. 12 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Pull Request 3 | about: Propose changes to the codebase 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | Please provide a summary of the changes you are proposing with this Pull Request. Include the motivation and context. Highlight any key changes. If this PR fixes an open issue, mention the #issue_number in this PR. 12 | 13 | ## Type of change 14 | Please delete options that are not relevant. 15 | 16 | - [ ] Bug fix (non-breaking change which fixes an issue) 17 | - [ ] New feature (non-breaking change which adds functionality) 18 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 19 | - [ ] Documentation update 20 | 21 | ## Checklist: 22 | Go over all the following points, and put an "x" in all the boxes that apply. If you're unsure about any of these, don't hesitate to ask. We're here to help! 23 | 24 | - [ ] My code follows the code style of this project. 25 | - [ ] My change requires a change to the documentation/notebooks. 26 | - [ ] I have updated the documentation or added a notebook accordingly. 27 | - [ ] I have tested the code locally according to the **CONTRIBUTING.md** document. 28 | - [ ] I have added tests to cover my changes. 29 | - [ ] All new and existing tests passed. 30 | 31 | ## How Has This Been Tested? 32 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. List any relevant details for your test configuration. 33 | 34 | ## Screenshots (if applicable): 35 | Include screenshots or gifs of the changes you have made. This is especially useful for UI changes. 36 | 37 | ## Additional Context: 38 | Add any other context or screenshots about the pull request here. This could include reasons for changes, decisions made, or anything you want the reviewer to know.
39 | 40 | --- 41 | 42 | **For more information on contributing, please see the [CONTRIBUTING.md](../../CONTRIBUTING.md) document.** 43 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # pyproject.toml 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/cronjob_unit_tests.yml: -------------------------------------------------------------------------------- 1 | name: Cron Test Dependencies 2 | 3 | # Controls when the action will run. 4 | # Every Sunday at 4:05 5 | # See https://crontab.guru/#5_4_*_*_0 6 | on: 7 | schedule: 8 | - cron: "5 4 * * 0" 9 | 10 | jobs: 11 | run: 12 | name: Run unit tests 13 | runs-on: ${{ matrix.os }} 14 | strategy: 15 | matrix: 16 | build: [macos, ubuntu, windows] 17 | include: 18 | - build: macos 19 | os: macos-latest 20 | - build: ubuntu 21 | os: ubuntu-latest 22 | - build: windows 23 | os: windows-latest 24 | python-version: [3.9, "3.10", "3.11", "3.12"] 25 | steps: 26 | - uses: actions/checkout@master 27 | 28 | - name: Get latest CMake and Ninja 29 | uses: lukka/get-cmake@latest 30 | with: 31 | cmakeVersion: latest 32 | ninjaVersion: latest 33 | 34 | - name: Install LIBOMP on Macos runners 35 | if: runner.os == 'macOS' 36 | run: | 37 | brew install libomp 38 | 39 | - name: Setup Python 40 | uses: actions/setup-python@master 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | 44 | - name: Install Python dependencies 45 | run: | 46 | pip3 install --upgrade setuptools pip 47 | pip3 install ".[all]" 48 | 49 | - name: Run linters 50 | run: | 51 | pre-commit install 52 | pre-commit run --all-files 53 | 54 | - name: Run (unit) tests 55 | env: 56 | TEST_NOTEBOOKS: 0 57 | run: | 58 | pytest --cov=probatus/binning --cov=probatus/metric_volatility --cov=probatus/missing_values --cov=probatus/sample_similarity --cov=probatus/stat_tests --cov=probatus/utils --cov=probatus/interpret/ --ignore=tests/interpret/test_inspector.py --cov-report=xml 59 | pyflakes probatus 60 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@master 12 | - name: Set up Python 13 | uses: actions/setup-python@master 14 | with: 15 | python-version: "3.10" 16 | - name: Install dependencies 17 | run: | 18 | pip3 install --upgrade setuptools pip 19 | pip3 install ".[all]" 20 | - name: Run unit tests 21 | run: | 22 | pytest 23 | - name: Build package & publish to PyPI 24 | env: 25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 26 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 27 | run: | 28 | pip3 install --upgrade wheel twine build 29 | python -m build 30 | twine upload
dist/* 31 | - name: Deploy mkdocs site 32 | run: | 33 | mkdocs gh-deploy --force 34 | -------------------------------------------------------------------------------- /.github/workflows/unit_tests.yml: -------------------------------------------------------------------------------- 1 | name: Development 2 | on: 3 | # Trigger the workflow on push or pull request, 4 | # but only for the main branch 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | jobs: 10 | run: 11 | name: Run unit tests 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | build: [macos, ubuntu, windows] 16 | include: 17 | - build: macos 18 | os: macos-latest 19 | - build: ubuntu 20 | os: ubuntu-latest 21 | - build: windows 22 | os: windows-latest 23 | python-version: [3.9, "3.10", "3.11", "3.12"] 24 | steps: 25 | - uses: actions/checkout@master 26 | 27 | - name: Get latest CMake and Ninja 28 | uses: lukka/get-cmake@latest 29 | with: 30 | cmakeVersion: latest 31 | ninjaVersion: latest 32 | 33 | - name: Install LIBOMP on Macos runners 34 | if: runner.os == 'macOS' 35 | run: | 36 | brew install libomp 37 | 38 | - name: Setup Python 39 | uses: actions/setup-python@master 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | 43 | - name: Install Python dependencies 44 | run: | 45 | pip3 install --upgrade setuptools pip 46 | pip3 install ".[all]" 47 | 48 | - name: Run linters 49 | run: | 50 | pre-commit install 51 | pre-commit run --all-files 52 | 53 | - name: Run (unit) tests 54 | env: 55 | TEST_NOTEBOOKS: 0 56 | run: | 57 | pytest --cov=probatus/binning --cov=probatus/metric_volatility --cov=probatus/missing_values --cov=probatus/sample_similarity --cov=probatus/stat_tests --cov=probatus/utils --cov=probatus/interpret/ --ignore=tests/interpret/test_inspector.py --cov-report=xml 58 | pyflakes probatus 59 | 60 | - name: Upload coverage to Codecov 61 | if: github.ref == 'refs/heads/main' 62 | uses: codecov/codecov-action@v1 63 | with: 64 | token: ${{ secrets.CODECOV_TOKEN }} 65 | file: ./coverage.xml 66 | flags: unittests 67 | fail_ci_if_error: false 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | notebooks/private 2 | notebooks/explore 3 | .vscode/ 4 | .idea/ 5 | docs/source/.ipynb_checkpoints/ 6 | notebooks/.ipynb_checkpoints/ 7 | 8 | 9 | # Created by https://www.gitignore.io/api/macos,python,pycharm,jupyternotebooks,visualstudiocode 10 | # Edit at https://www.gitignore.io/?templates=macos,python,pycharm,jupyternotebooks,visualstudiocode 11 | 12 | ### JupyterNotebooks ### 13 | # gitignore template for Jupyter Notebooks 14 | # website: http://jupyter.org/ 15 | 16 | .ipynb_checkpoints 17 | */.ipynb_checkpoints/* 18 | 19 | # IPython 20 | profile_default/ 21 | ipython_config.py 22 | 23 | # Remove previous ipynb_checkpoints 24 | # git rm -r .ipynb_checkpoints/ 25 | 26 | ### macOS ### 27 | # General 28 | .DS_Store 29 | .AppleDouble 30 | .LSOverride 31 | 32 | # Icon must end with two \r 33 | Icon 34 | 35 | # Thumbnails 36 | ._* 37 | 38 | # Files that might appear in the root of a volume 39 | .DocumentRevisions-V100 40 | .fseventsd 41 | .Spotlight-V100 42 | .TemporaryItems 43 | .Trashes 44 | .VolumeIcon.icns 45 | .com.apple.timemachine.donotpresent 46 | 47 | # Directories potentially created on remote AFP share 48 | .AppleDB 49 | .AppleDesktop 50 | Network Trash Folder 51 | Temporary Items 52 | .apdisk 53 | 54 | ### PyCharm ### 55 | # Covers JetBrains IDEs: IntelliJ,
RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 56 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 57 | 58 | # User-specific stuff 59 | .idea/**/workspace.xml 60 | .idea/**/tasks.xml 61 | .idea/**/usage.statistics.xml 62 | .idea/**/dictionaries 63 | .idea/**/shelf 64 | 65 | # Generated files 66 | .idea/**/contentModel.xml 67 | 68 | # Sensitive or high-churn files 69 | .idea/**/dataSources/ 70 | .idea/**/dataSources.ids 71 | .idea/**/dataSources.local.xml 72 | .idea/**/sqlDataSources.xml 73 | .idea/**/dynamic.xml 74 | .idea/**/uiDesigner.xml 75 | .idea/**/dbnavigator.xml 76 | 77 | # Gradle 78 | .idea/**/gradle.xml 79 | .idea/**/libraries 80 | 81 | # Gradle and Maven with auto-import 82 | # When using Gradle or Maven with auto-import, you should exclude module files, 83 | # since they will be recreated, and may cause churn. Uncomment if using 84 | # auto-import. 85 | # .idea/modules.xml 86 | # .idea/*.iml 87 | # .idea/modules 88 | # *.iml 89 | # *.ipr 90 | 91 | # CMake 92 | cmake-build-*/ 93 | 94 | # Mongo Explorer plugin 95 | .idea/**/mongoSettings.xml 96 | 97 | # File-based project format 98 | *.iws 99 | 100 | # IntelliJ 101 | out/ 102 | 103 | # mpeltonen/sbt-idea plugin 104 | .idea_modules/ 105 | 106 | # JIRA plugin 107 | atlassian-ide-plugin.xml 108 | 109 | # Cursive Clojure plugin 110 | .idea/replstate.xml 111 | 112 | # Crashlytics plugin (for Android Studio and IntelliJ) 113 | com_crashlytics_export_strings.xml 114 | crashlytics.properties 115 | crashlytics-build.properties 116 | fabric.properties 117 | 118 | # Editor-based Rest Client 119 | .idea/httpRequests 120 | 121 | # Android studio 3.1+ serialized cache file 122 | .idea/caches/build_file_checksums.ser 123 | 124 | # Add after merge review issue #79 125 | .vscode/ 126 | 127 | ### PyCharm Patch ### 128 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 129 | 130 | # *.iml 131 | # modules.xml 132 | # .idea/misc.xml 133 | # *.ipr 134 | 135 | # Sonarlint plugin 136 | .idea/**/sonarlint/ 137 | 138 | # SonarQube Plugin 139 | .idea/**/sonarIssues.xml 140 | 141 | # Markdown Navigator plugin 142 | .idea/**/markdown-navigator.xml 143 | .idea/**/markdown-navigator/ 144 | 145 | ### Python ### 146 | # Byte-compiled / optimized / DLL files 147 | __pycache__/ 148 | *.py[cod] 149 | *$py.class 150 | 151 | # C extensions 152 | *.so 153 | 154 | # Distribution / packaging 155 | .Python 156 | build/ 157 | develop-eggs/ 158 | dist/ 159 | downloads/ 160 | eggs/ 161 | .eggs/ 162 | lib/ 163 | lib64/ 164 | parts/ 165 | sdist/ 166 | var/ 167 | wheels/ 168 | pip-wheel-metadata/ 169 | share/python-wheels/ 170 | *.egg-info/ 171 | .installed.cfg 172 | *.egg 173 | MANIFEST 174 | 175 | # PyInstaller 176 | # Usually these files are written by a python script from a template 177 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
178 | *.manifest 179 | *.spec 180 | 181 | # Installer logs 182 | pip-log.txt 183 | pip-delete-this-directory.txt 184 | 185 | # Unit test / coverage reports 186 | htmlcov/ 187 | .tox/ 188 | .nox/ 189 | .coverage 190 | .coverage.* 191 | .cache 192 | nosetests.xml 193 | coverage.xml 194 | *.cover 195 | .hypothesis/ 196 | .pytest_cache/ 197 | 198 | # Translations 199 | *.mo 200 | *.pot 201 | 202 | # Scrapy stuff: 203 | .scrapy 204 | 205 | # Sphinx documentation 206 | docs/_build/ 207 | 208 | # PyBuilder 209 | target/ 210 | 211 | # pyenv 212 | .python-version 213 | 214 | # pipenv 215 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 216 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 217 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 218 | # install all needed dependencies. 219 | #Pipfile.lock 220 | 221 | # virtualenv 222 | env/ 223 | venv/ 224 | 225 | # celery beat schedule file 226 | celerybeat-schedule 227 | 228 | # SageMath parsed files 229 | *.sage.py 230 | 231 | # Spyder project settings 232 | .spyderproject 233 | .spyproject 234 | 235 | # Rope project settings 236 | .ropeproject 237 | 238 | # Mr Developer 239 | .mr.developer.cfg 240 | .project 241 | .pydevproject 242 | 243 | # mkdocs documentation 244 | /site 245 | 246 | # mypy 247 | .mypy_cache/ 248 | .dmypy.json 249 | dmypy.json 250 | 251 | # Pyre type checker 252 | .pyre/ 253 | 254 | ### VisualStudioCode ### 255 | .vscode/* 256 | !.vscode/settings.json 257 | !.vscode/tasks.json 258 | !.vscode/launch.json 259 | !.vscode/extensions.json 260 | 261 | ### VisualStudioCode Patch ### 262 | # Ignore all local history of files 263 | .history 264 | 265 | # End of https://www.gitignore.io/api/macos,python,pycharm,jupyternotebooks,visualstudiocode 266 | 267 | # Catboost-related files 268 | catboost* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: check-case-conflict # Different OSes 6 | name: "Check case conflict: Naming of files is compatible with all OSes" 7 | - id: check-docstring-first 8 | name: "Check docstring first: Ensures Docstring present and first" 9 | - id: detect-private-key 10 | name: "Detect private key: Prevent commit of env related keys" 11 | - id: trailing-whitespace 12 | name: "Trailing whitespace: Remove empty spaces" 13 | - repo: https://github.com/nbQA-dev/nbQA 14 | rev: 1.8.5 15 | hooks: 16 | - id: nbqa-ruff 17 | name: "ruff nb: Check for errors, styling issues and complexity" 18 | - id: nbqa-mypy 19 | name: "mypy nb: Static type checking" 20 | - id: nbqa-isort 21 | name: "isort nb: Sort file imports" 22 | - id: nbqa-pyupgrade 23 | name: "pyupgrade nb: Updates code to Python 3.9+ code convention" 24 | args: [&py_version --py38-plus] 25 | - id: nbqa-black 26 | name: "black nb: PEP8 compliant code formatter" 27 | - repo: local 28 | hooks: 29 | - id: mypy 30 | name: "mypy: Static type checking" 31 | entry: mypy 32 | language: system 33 | types: [python] 34 | - repo: local 35 | hooks: 36 | - id: ruff-check 37 | name: "Ruff: Check for errors, styling issues and complexity, and fixes issues if possible (including import order)" 38 | entry: ruff check 39 | language: system 40 | args: [--fix, --no-cache] 41 | - id: ruff-format 42 | name: "Ruff: format 
code in line with PEP8" 43 | entry: ruff format 44 | language: system 45 | args: [--no-cache] 46 | - repo: local 47 | hooks: 48 | - id: codespell 49 | name: "codespell: Check for spelling mistakes" 50 | entry: codespell 51 | language: system 52 | types: [python] 53 | args: [-L mot] # Skip the word "mot" 54 | - repo: https://github.com/asottile/pyupgrade 55 | rev: v3.4.0 56 | hooks: 57 | - id: pyupgrade 58 | name: "pyupgrade: Updates code to Python 3.9+ code convention" 59 | args: [*py_version] 60 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guide 2 | 3 | `Probatus` aims to provide a set of tools that can speed up common workflows around validating regressors & classifiers and the data used to train them. 4 | We're very much open to contributions but there are some things to keep in mind: 5 | 6 | - Discuss the feature and implementation you want to add on GitHub before you write a PR for it. On disagreements, maintainer(s) will have the final word. 7 | - Features need a somewhat general use case. If the use case is very niche it will be hard for us to consider maintaining it. 8 | - If you’re going to add a feature, consider if you could help out in the maintenance of it. 9 | - When issues or pull requests are not going to be resolved or merged, they should be closed as soon as possible. This is kinder than deciding this after a long period. Our issue tracker should reflect work to be done. 10 | 11 | That said, there are many ways to contribute to Probatus, including: 12 | 13 | - Contribution to code 14 | - Improving the documentation 15 | - Reviewing merge requests 16 | - Investigating bugs 17 | - Reporting issues 18 | 19 | Starting out with open source? See the guide [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/) and have a look at [our issues labelled *good first issue*](https://github.com/ing-bank/probatus/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). 20 | 21 | ## Setup 22 | 23 | Development install: 24 | 25 | ```shell 26 | pip install -e '.[all]' 27 | ``` 28 | 29 | Unit testing: 30 | 31 | ```shell 32 | pytest 33 | ``` 34 | 35 | We use [pre-commit](https://pre-commit.com/) hooks to ensure code styling. Install with: 36 | 37 | ```shell 38 | pre-commit install 39 | ``` 40 | 41 | Once the hooks are installed (which you are encouraged to do), run the following command before committing your work: 42 | 43 | ```shell 44 | pre-commit run --all-files 45 | ``` 46 | 47 | This allows you to quickly see whether your changes still need adaptation before a pull request can be accepted. 48 | 49 | ## Standards 50 | 51 | - Python 3.9+ 52 | - Follow [PEP8](http://pep8.org/) as closely as possible (except line length) 53 | - [google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/) 54 | - Git: Include a short description of *what* and *why* was done, *how* can be seen in the code. Use present tense, imperative mood 55 | - Git: limit the length of the first line to 72 chars. You can use multiple messages to specify a second (longer) line: `git commit -m "Patch load function" -m "This is a much longer explanation of what was done"` 56 | 57 | 58 | ### Code structure 59 | 60 | * Model validation modules assume that trained models passed for validation are developed in a scikit-learn framework (i.e. have predict_proba and other standard functions), or follow a scikit-learn API e.g. XGBoost. 61 | * Every python file used for model validation needs to be in `/probatus/` 62 | * Class structure for a given module should have a base class and specific functionality classes that inherit from base. If a given module implements only a single way of computing the output, the base class is not required. 63 | * Functions should be as short as possible in terms of lines of code. If a lot of code is needed, try to split snippets of code into separate functions. This makes the code more readable, and easier to test. 64 | * Classes follow the probatus API structure (a minimal sketch follows this list): 65 | * Each class implements `fit()`, `compute()` and `fit_compute()` methods. `fit()` is used to fit an object with provided data (unless no fit is required), and `compute()` calculates the output e.g. DataFrame with a report for the user. Lastly, `fit_compute()` applies one after the other. 66 | * If applicable, the `plot()` method presents the user with the appropriate graphs. 67 | * For `compute()` and `plot()`, check if the object is fitted first.
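For illustration, below is a minimal, hypothetical skeleton of such a class; the class name and method bodies are placeholders rather than actual probatus code:

```python
import pandas as pd


class MyValidationModule:
    """Hypothetical skeleton illustrating the probatus API structure."""

    def __init__(self, model):
        self.model = model
        self.fitted = False

    def fit(self, X, y):
        # Fit the object with the provided data.
        self.model.fit(X, y)
        self.fitted = True
        return self

    def compute(self):
        # Check if the object is fitted first.
        if not self.fitted:
            # probatus keeps its custom exceptions in probatus/utils/exceptions.py;
            # a plain RuntimeError is used here for illustration.
            raise RuntimeError("This object is not fitted yet. Run fit() first.")
        # Calculate the output, e.g. a DataFrame with a report for the user.
        return pd.DataFrame()

    def fit_compute(self, X, y):
        # Apply fit() and compute() one after the other.
        self.fit(X, y)
        return self.compute()
```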
68 | 69 | 70 | ### Documentation 71 | 72 | Documentation is a crucial part of the project because it ensures the usability of the package. We develop the docs in the following way: 73 | 74 | * We use [mkdocs](https://www.mkdocs.org/) with [mkdocs-material](https://squidfunk.github.io/mkdocs-material/) theme. The `docs/` folder contains all the relevant documentation. 75 | * We use `mkdocs serve` to view the documentation locally. Use it to test the documentation every time you make any changes. 76 | * Maintainers can deploy the docs using `mkdocs gh-deploy`. The documentation is deployed to `https://ing-bank.github.io/probatus/`. 77 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | Copyright (c) ING Bank N.V. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | [![pytest](https://github.com/ing-bank/probatus/workflows/Development/badge.svg)](https://github.com/ing-bank/probatus/actions?query=workflow%3A%22Development%22) 4 | [![PyPi Version](https://img.shields.io/pypi/pyversions/probatus)](#) 5 | [![PyPI](https://img.shields.io/pypi/v/probatus)](#) 6 | [![PyPI - Downloads](https://img.shields.io/pypi/dm/probatus)](#) 7 | ![GitHub contributors](https://img.shields.io/github/contributors/ing-bank/probatus) 8 | 9 | # Probatus 10 | 11 | ## Overview 12 | 13 | **Probatus** is a Python package that helps validate regression & (multiclass) classification models and the data used to develop them. Main features: 14 | 15 | - [probatus.interpret](https://ing-bank.github.io/probatus/api/model_interpret.html) provides SHAP-based model interpretation tools 16 | - [probatus.sample_similarity](https://ing-bank.github.io/probatus/api/sample_similarity.html) to compare two datasets using resemblance modelling, e.g. `train` with out-of-time `test`. 17 | - [probatus.feature_elimination.ShapRFECV](https://ing-bank.github.io/probatus/api/feature_elimination.html) provides cross-validated Recursive Feature Elimination using SHAP feature importance. 18 | 19 | ## Installation 20 | 21 | ```bash 22 | pip install probatus 23 | ``` 24 | 25 | ## Documentation 26 | 27 | Documentation at [ing-bank.github.io/probatus/](https://ing-bank.github.io/probatus/). 28 | 29 | You can also check out blog posts about Probatus: 30 | 31 | - [Open-sourcing ShapRFECV — Improved feature selection powered by SHAP.](https://medium.com/ing-blog/open-sourcing-shaprfecv-improved-feature-selection-powered-by-shap-994fe7861560) 32 | - [Model Explainability — How to choose the right tool?](https://medium.com/ing-blog/model-explainability-how-to-choose-the-right-tool-6c5eabd1a46a) 33 | 34 | ## Contributing 35 | 36 | To learn more about making a contribution to Probatus, please see [`CONTRIBUTING.md`](CONTRIBUTING.md). 37 | -------------------------------------------------------------------------------- /VISION.md: -------------------------------------------------------------------------------- 1 | # The Vision 2 | 3 | This page describes the main principles that drive the development of `Probatus`, as well as the general directions in which the development of the package will be heading. 4 | 5 | ## The Purpose 6 | 7 | `Probatus` started as a side project of Data Scientists at ING Bank. 8 | Later, we decided to open-source it in order to share the tools and enable collaboration with the Data Science community. 9 | 10 | We mainly focus on analyzing the following aspects of building classification models: 11 | - Model input: the quality of the dataset and how to prepare it for modelling, 12 | - Model performance: the quality of the model and the stability of the results,
13 | - Model interpretation: understanding the model's decision making. 14 | 15 | Our main goals are: 16 | - Continue maintaining the tools that we have built, and make sure that they are well documented and tested 17 | - Continuously extend the functionality available in the package 18 | - Build a community of users, who use the package in day-to-day work and learn from each other, while contributing to Probatus 19 | 20 | ## The Principles 21 | 22 | The main principles that drive the development of `Probatus` are the following: 23 | 24 | - Usefulness - any tool that we build should be useful for a broad range of users, 25 | - Simplicity - we prefer steps that are simple to understand and analyze over state-of-the-art complexity, 26 | - Usability - the developed functionality must have good documentation, a consistent API, and work flawlessly with scikit-learn compatible models, 27 | - Reliability - the code that is available for the users should be well tested and reliable, and bugs should be fixed as soon as they are detected. 28 | 29 | ## The Roadmap 30 | 31 | We are open to new ideas, so if you can think of a feature that fits the vision, make an [issue](https://github.com/ing-bank/Probatus/issues) and help us further develop this package. -------------------------------------------------------------------------------- /docs/api/feature_elimination.md: -------------------------------------------------------------------------------- 1 | # Feature Elimination 2 | 3 | This module focuses on feature elimination and contains the following class: 4 | 5 | - [ShapRFECV][probatus.feature_elimination.feature_elimination.ShapRFECV]: Performs Backwards Recursive Feature Elimination using SHAP feature importance. It supports binary classification and regression models, as well as hyperparameter optimization at every feature elimination step. For LightGBM, XGBoost and CatBoost it also supports early stopping of the model fitting process, which can serve as an alternative regularization technique to hyperparameter optimization of the number of base trees in gradient boosted tree models. This is particularly useful when dealing with large datasets. 6 | 7 | ::: probatus.feature_elimination.feature_elimination 8 | -------------------------------------------------------------------------------- /docs/api/model_interpret.md: -------------------------------------------------------------------------------- 1 | # Model Interpretation using SHAP 2 | 3 | The aim of this module is to provide tools for model interpretation using the [SHAP](https://shap.readthedocs.io/en/latest/) library. 4 | The class below is a convenience wrapper that implements multiple plots for tree-based & linear models. 5 | 6 | ::: probatus.interpret.model_interpret 7 | ::: probatus.interpret.shap_dependence 8 | -------------------------------------------------------------------------------- /docs/api/sample_similarity.md: -------------------------------------------------------------------------------- 1 | # Sample Similarity 2 | 3 | The goal of the sample similarity module is to understand how different two samples are from a multivariate perspective. 4 | 5 | One of the ways to measure this is the Resemblance Model. Having two datasets - say X1 and X2 - one can analyse how easy it is to recognize which dataset a randomly selected row comes from. The Resemblance Model assigns label 0 to the dataset X1 and label 1 to X2, and trains a binary classification model to predict which sample a given row comes from. 6 | By looking at the test AUC, one can conclude that the samples have a different distribution if the AUC is significantly higher than 0.5. Furthermore, by analysing feature importance one can understand which features have predictive power. A minimal sketch of this idea is shown below.
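The snippet below sketches this idea using plain scikit-learn; the synthetic datasets X1 and X2 and the choice of model are illustrative assumptions, and the classes listed further down implement refined versions of this workflow:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Illustrative samples: in practice X1 and X2 could be e.g. train and out-of-time test sets
rng = np.random.default_rng(42)
columns = [f"f{i}" for i in range(5)]
X1 = pd.DataFrame(rng.normal(0.0, 1.0, size=(1000, 5)), columns=columns)
X2 = pd.DataFrame(rng.normal(0.3, 1.0, size=(1000, 5)), columns=columns)

# Assign label 0 to rows from X1 and label 1 to rows from X2
X = pd.concat([X1, X2], ignore_index=True)
y = np.concatenate([np.zeros(len(X1)), np.ones(len(X2))])

# Train a binary classifier to predict which sample a given row comes from
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# A test AUC significantly above 0.5 indicates that the samples differ in distribution
auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
print(f"Resemblance model test AUC: {auc:.3f}")
```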
7 | 8 | 9 | 10 | 11 | The following features are implemented: 12 | 13 | - [SHAPImportanceResemblance (Recommended)][probatus.sample_similarity.resemblance_model.SHAPImportanceResemblance]: The class applies the SHAP library in order to interpret the tree-based resemblance model. 15 | - [PermutationImportanceResemblance][probatus.sample_similarity.resemblance_model.PermutationImportanceResemblance]: 16 | The class applies permutation feature importance in order to understand which features the current model relies on the most. The higher the importance of the feature, the more a given feature possibly differs in X2 compared to X1. The importance indicates how much the test AUC drops if a given feature is permuted. 17 | 18 | 19 | ::: probatus.sample_similarity.resemblance_model 20 | 21 | -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | # Utility Functions 2 | 3 | This module contains various smaller functionalities that can be used across the `probatus` package. 4 | 5 | ::: probatus.utils.scoring 6 | -------------------------------------------------------------------------------- /docs/howto/grouped_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to work with grouped data\n", 8 | "\n", 9 | "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ing-bank/probatus/blob/master/docs/howto/grouped_data.ipynb)\n", 10 | "\n", 11 | "A property that often appears in Data Science problems is a natural grouping of the data. You could, for instance, have multiple samples for the same customer. In such a case, you need to make sure that all samples from a given group are in the same fold, e.g. in Cross-Validation.\n", 12 | "\n", 13 | "Let's prepare a dataset with groups." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%%capture\n", 23 | "!pip install probatus" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "text/plain": [ 34 | "[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]" 35 | ] 36 | }, 37 | "execution_count": 1, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "from sklearn.datasets import make_classification\n", 44 | "\n", 45 | "X, y = make_classification(n_samples=100, n_features=10, random_state=42)\n", 46 | "groups = [i % 5 for i in range(100)]\n", 47 | "groups[:10]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "The integers in the `groups` variable indicate the group id to which a given sample belongs.\n", 55 | "\n", 56 | "One of the easiest ways to ensure that the data is split using the information about groups is using `from sklearn.model_selection import GroupKFold`. 
You can also read more about other ways of splitting data with groups in sklearn [here](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data)." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 2, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from sklearn.model_selection import GroupKFold\n", 66 | "\n", 67 | "cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "Such a variable can be passed to the `cv` parameter in `probatus`, as well as to hyperparameter optimization classes, e.g. `RandomizedSearchCV`." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "from sklearn.ensemble import RandomForestClassifier\n", 84 | "from sklearn.model_selection import RandomizedSearchCV\n", 85 | "\n", 86 | "from probatus.feature_elimination import ShapRFECV\n", 87 | "\n", 88 | "model = RandomForestClassifier(random_state=42)\n", 89 | "\n", 90 | "param_grid = {\n", 91 | " \"n_estimators\": [5, 7, 10],\n", 92 | " \"max_leaf_nodes\": [3, 5, 7, 10],\n", 93 | "}\n", 94 | "search = RandomizedSearchCV(model, param_grid, cv=cv, n_iter=1, random_state=42)\n", 95 | "\n", 96 | "shap_elimination = ShapRFECV(model=search, step=0.2, cv=cv, scoring=\"roc_auc\", n_jobs=3, random_state=42)\n", 97 | "report = shap_elimination.fit_compute(X, y)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": {
\n", 109 | "\n", 122 | "\n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | "
num_featuresfeatures_seteliminated_featurestrain_metric_meantrain_metric_stdval_metric_meanval_metric_std
110[0, 1, 2, 3, 4, 5, 6, 7, 8, 9][6, 7]0.9995620.0008760.9549450.090110
28[0, 1, 2, 3, 4, 5, 8, 9][5]0.9991180.0010810.9455130.089606
37[0, 1, 2, 3, 4, 8, 9][4]0.9995590.0005480.9287490.137507
46[0, 1, 2, 3, 8, 9][8]0.9991790.0010510.9692880.058854
55[0, 1, 2, 3, 9][9]0.9997480.0002370.9617670.066540
64[0, 1, 2, 3][1]0.9994330.0007000.9508160.090982
73[0, 2, 3][0]0.9991200.0007290.9705960.051567
82[2, 3][3]0.9994960.0006170.9386390.117736
91[2][]0.9984240.0018190.9383390.097936
\n", 228 | "
" 229 | ], 230 | "text/plain": [ 231 | " num_features features_set eliminated_features \\\n", 232 | "1 10 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [6, 7] \n", 233 | "2 8 [0, 1, 2, 3, 4, 5, 8, 9] [5] \n", 234 | "3 7 [0, 1, 2, 3, 4, 8, 9] [4] \n", 235 | "4 6 [0, 1, 2, 3, 8, 9] [8] \n", 236 | "5 5 [0, 1, 2, 3, 9] [9] \n", 237 | "6 4 [0, 1, 2, 3] [1] \n", 238 | "7 3 [0, 2, 3] [0] \n", 239 | "8 2 [2, 3] [3] \n", 240 | "9 1 [2] [] \n", 241 | "\n", 242 | " train_metric_mean train_metric_std val_metric_mean val_metric_std \n", 243 | "1 0.999562 0.000876 0.954945 0.090110 \n", 244 | "2 0.999118 0.001081 0.945513 0.089606 \n", 245 | "3 0.999559 0.000548 0.928749 0.137507 \n", 246 | "4 0.999179 0.001051 0.969288 0.058854 \n", 247 | "5 0.999748 0.000237 0.961767 0.066540 \n", 248 | "6 0.999433 0.000700 0.950816 0.090982 \n", 249 | "7 0.999120 0.000729 0.970596 0.051567 \n", 250 | "8 0.999496 0.000617 0.938639 0.117736 \n", 251 | "9 0.998424 0.001819 0.938339 0.097936 " 252 | ] 253 | }, 254 | "execution_count": 4, 255 | "metadata": {}, 256 | "output_type": "execute_result" 257 | } 258 | ], 259 | "source": [ 260 | "report" 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "Python 3", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.10.13" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 4 285 | } 286 | -------------------------------------------------------------------------------- /docs/howto/reproducibility.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# How to ensure reproducibility of the results" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ing-bank/probatus/blob/master/docs/howto/reproducibility.ipynb)\n", 15 | "\n", 16 | "This page describes how to make sure that the analysis that you perform using `probatus` is fully reproducible.\n", 17 | "\n", 18 | "There are two factors that influence reproducibility of the results:\n", 19 | "\n", 20 | "- Inputs of `probatus` modules,\n", 21 | "- The `random_state` of `probatus` modules.\n", 22 | "\n", 23 | "The below sections cover how to ensure reproducibility of the results by controling these aspects.\n", 24 | "\n", 25 | "## Inputs of probatus modules\n", 26 | "\n", 27 | "There are various parameters that modules of probatus take as input. Below we will cover the most often occurring ones.\n", 28 | "\n", 29 | "### Static dataset\n", 30 | "\n", 31 | "When using `probatus`, one of the most crucial aspects is the provided dataset. Therefore, the first thing to do is to ensure that the passed dataset does not change along the way. \n", 32 | "\n", 33 | "Below is a code snipped of random data preparation. In sklearn, you can ensure this by setting the `random_state` parameter. You will probably use a different dataset in your projects, but always make sure that the input data is static." 
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "%%capture\n", 43 | "!pip install probatus" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 1, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "from sklearn.datasets import make_classification\n", 53 | "\n", 54 | "X, y = make_classification(n_samples=100, n_features=10, random_state=42)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "### Static data splits\n", 62 | "\n", 63 | "Whenever you split the data in any way, you need to make sure that the splits are always the same. \n", 64 | "\n", 65 | "If you use the `train_test_split` functionality from sklearn, this can be enforced by setting the `random_state` parameter. \n", 66 | "\n", 67 | "Another crucial aspect is how you use the `cv` parameter, which defines the folds settings that you will use in the experiments. If the `cv` is set to an integer, you don't need to worry about it - the `random_state` of `probatus` will take care of it. However, if you want to pass a custom cv generator object, you have to set the `random_state` there as well.\n", 68 | "\n", 69 | "Below are some examples of static splits:" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "from sklearn.model_selection import StratifiedKFold, train_test_split\n", 79 | "\n", 80 | "# Static train/test split\n", 81 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n", 82 | "\n", 83 | "# Static CV settings\n", 84 | "cv1 = 5\n", 85 | "cv2 = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### Static classifier\n", 93 | "\n", 94 | "Most of the `probatus` modules work with the provided classifiers. Whenever one needs to provide a not-fitted classifier, it is enough to set the `random_state`. However, if the classifier needs to be fitted beforehand, you have to make sure that the model training is reproducible as well." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 3, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "from sklearn.ensemble import RandomForestClassifier\n", 104 | "\n", 105 | "model = RandomForestClassifier(random_state=42)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "### Static search CV for hyperparameter tuning\n", 113 | "\n", 114 | "Some of the modules, e.g. `ShapRFECV`, allow you to perform optimization of the model. Whenever you use such functionality, make sure that these classes have set the `random_state`. This way, in every round of optimization, you will explore the same set of parameter permutations. In case the search space is also generated based on randomness, make sure that the `random_state` is set to it as well."
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from sklearn.model_selection import RandomizedSearchCV\n", 124 | "\n", 125 | "param_grid = {\n", 126 | " \"n_estimators\": [5, 7, 10],\n", 127 | " \"max_leaf_nodes\": [3, 5, 7, 10],\n", 128 | "}\n", 129 | "search = RandomizedSearchCV(model, param_grid, n_iter=1, random_state=42)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "### Any other sources of randomness" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "Before running `probatus` modules, think about the inputs, and consider if there is any other type of randomness involved. If there is, one option to possibly solve the issue is setting the random seed at the beginning of the code." 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 5, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "# Optional step\n", 153 | "import numpy as np\n", 154 | "\n", 155 | "np.random.seed(42)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "## Reproducibility in probatus\n", 163 | "\n", 164 | "Most of the modules in `probatus` allow you to set the `random_state`. This setting essentially makes sure that any code that the functions operate on has a static flow. As long as it is set and you ensure all other inputs do not cause additional fluctuations between runs, you can make sure that your results are reproducible." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "from probatus.feature_elimination import ShapRFECV\n", 174 | "\n", 175 | "shap_elimination = ShapRFECV(model=search, step=0.2, cv=cv2, scoring=\"roc_auc\", n_jobs=3, random_state=42)\n", 176 | "report = shap_elimination.fit_compute(X, y)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 7, 182 | "metadata": {}, 183 | "outputs": [ 184 | { 185 | "data": {
\n", 188 | "\n", 201 | "\n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | "
num_featureseliminated_featuresval_metric_mean
110[8, 9]0.983
28[5]0.969
37[7]0.984
46[6]0.979
55[4]0.983
64[1]0.987
73[0]0.991
82[3]0.991
91[]0.969
\n", 267 | "
" 268 | ], 269 | "text/plain": [ 270 | " num_features eliminated_features val_metric_mean\n", 271 | "1 10 [8, 9] 0.983\n", 272 | "2 8 [5] 0.969\n", 273 | "3 7 [7] 0.984\n", 274 | "4 6 [6] 0.979\n", 275 | "5 5 [4] 0.983\n", 276 | "6 4 [1] 0.987\n", 277 | "7 3 [0] 0.991\n", 278 | "8 2 [3] 0.991\n", 279 | "9 1 [] 0.969" 280 | ] 281 | }, 282 | "execution_count": 7, 283 | "metadata": {}, 284 | "output_type": "execute_result" 285 | } 286 | ], 287 | "source": [ 288 | "report[[\"num_features\", \"eliminated_features\", \"val_metric_mean\"]]" 289 | ] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "Python 3", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.10.13" 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 4 313 | } 314 | -------------------------------------------------------------------------------- /docs/img/Probatus_P.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/Probatus_P.png -------------------------------------------------------------------------------- /docs/img/Probatus_P_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/Probatus_P_white.png -------------------------------------------------------------------------------- /docs/img/earlystoppingshaprfecv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/earlystoppingshaprfecv.png -------------------------------------------------------------------------------- /docs/img/logo_large.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/logo_large.png -------------------------------------------------------------------------------- /docs/img/logo_large_white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/logo_large_white.png -------------------------------------------------------------------------------- /docs/img/model_interpret_dep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_dep.png -------------------------------------------------------------------------------- /docs/img/model_interpret_importance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_importance.png -------------------------------------------------------------------------------- /docs/img/model_interpret_sample.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_sample.png -------------------------------------------------------------------------------- /docs/img/model_interpret_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_summary.png -------------------------------------------------------------------------------- /docs/img/resemblance_model_schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/resemblance_model_schema.png -------------------------------------------------------------------------------- /docs/img/sample_similarity_permutation_importance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_permutation_importance.png -------------------------------------------------------------------------------- /docs/img/sample_similarity_shap_importance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_shap_importance.png -------------------------------------------------------------------------------- /docs/img/sample_similarity_shap_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_shap_summary.png -------------------------------------------------------------------------------- /docs/img/shaprfecv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/shaprfecv.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | **Probatus** is a Python library that allows you to analyse binary classification models as well as the data used to develop them. 4 | The main features assess the metric stability and analyse differences between two data samples, e.g. the shift between train and test splits. 5 | 6 | ## Installation 7 | 8 | In order to install Probatus you need to use Python 3.9 or higher. 9 | 10 | Install `probatus` via pip with: 11 | 12 | ```bash 13 | pip install probatus 14 | ``` 15 | 16 | Alternatively you can fork/clone and run: 17 | 18 | ```bash 19 | git clone https://github.com/ing-bank/probatus.git 20 | cd probatus 21 | pip install . 22 | ```
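## Quick start

Below is a minimal usage sketch based on the examples in this documentation; the synthetic dataset and the choice of model are illustrative:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

from probatus.feature_elimination import ShapRFECV

# Illustrative synthetic dataset
X, y = make_classification(n_samples=100, n_features=10, random_state=42)

model = RandomForestClassifier(random_state=42)

# Cross-validated recursive feature elimination based on SHAP feature importance
shap_elimination = ShapRFECV(model=model, step=0.2, cv=5, scoring="roc_auc", n_jobs=3, random_state=42)
report = shap_elimination.fit_compute(X, y)

# Report with the validation metric for each number of features
print(report[["num_features", "eliminated_features", "val_metric_mean"]])
```

See the tutorials section of this documentation for complete, end-to-end examples.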
23 | 24 | ## Licence 25 | 26 | Probatus is created under MIT License, see more in the [LICENCE file](https://github.com/ing-bank/probatus/blob/main/LICENCE). 27 | 28 | 29 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Probatus 2 | 3 | repo_url: https://github.com/ing-bank/probatus/ 4 | site_url: https://ing-bank.github.io/probatus/ 5 | site_description: Validation of regressors and classifiers and data used to develop them 6 | site_author: ING Bank N.V. 7 | 8 | use_directory_urls: false 9 | 10 | watch: 11 | - probatus 12 | 13 | plugins: 14 | - mkdocs-jupyter 15 | - search 16 | - mkdocstrings: 17 | handlers: 18 | python: 19 | options: 20 | selection: 21 | inherited_members: true 22 | filters: 23 | - "!^Base" 24 | - "!^_" # exclude all members starting with _ 25 | - "^__init__$" # but always include __init__ modules and methods 26 | rendering: 27 | show_root_toc_entry: false 28 | 29 | theme: 30 | name: material 31 | logo: img/Probatus_P_white.png 32 | favicon: img/Probatus_P_white.png 33 | font: 34 | text: Ubuntu 35 | code: Ubuntu Mono 36 | features: 37 | - navigation.tabs 38 | palette: 39 | scheme: default 40 | primary: deep orange 41 | accent: indigo 42 | 43 | copyright: Copyright © ING Bank N.V. -------------------------------------------------------------------------------- /probatus/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ING Bank N.V. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | name = "probatus" 21 | -------------------------------------------------------------------------------- /probatus/feature_elimination/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_elimination import ShapRFECV 2 | from .early_stopping_feature_elimination import EarlyStoppingShapRFECV 3 | 4 | __all__ = ["ShapRFECV", "EarlyStoppingShapRFECV"] 5 | -------------------------------------------------------------------------------- /probatus/feature_elimination/early_stopping_feature_elimination.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from probatus.feature_elimination import ShapRFECV 3 | 4 | 5 | class EarlyStoppingShapRFECV(ShapRFECV): 6 | """ 7 | This class performs Backwards Recursive Feature Elimination, using SHAP feature importance.
8 | 9 | This is a child of ShapRFECV which allows early stopping of the training step. This class is compatible with 10 | LightGBM, XGBoost and CatBoost models. If you are not using early stopping, you should use the parent class, 11 | ShapRFECV, instead of EarlyStoppingShapRFECV. 12 | 13 | [Early stopping](https://en.wikipedia.org/wiki/Early_stopping) is a type of 14 | regularization technique in which the model is trained until the scoring metric, measured on a validation set, 15 | stops improving after a number of early_stopping_rounds. In boosted tree models, this technique can increase 16 | the training speed, by skipping the training of trees that do not improve the scoring metric any further, 17 | which is particularly useful when the training dataset is large. 18 | 19 | Note that if a hyperparameter search model is used as the regressor or classifier, the early stopping parameter is passed only 20 | to the fit method of the model during the Shapley values estimation step, and not to the hyperparameter 21 | search step. 22 | Early stopping can be seen as a type of regularization of the optimal number of trees. Therefore you can use 23 | it directly with a LightGBM or XGBoost model, as an alternative to a hyperparameter search model. 24 | 25 | At each round, for a 26 | given feature set, starting from all available features, the following steps are applied: 27 | 28 | 1. (Optional) Tune the hyperparameters of the model using sklearn compatible search CV e.g. 29 | [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html), 30 | [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV), or 31 | [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html). 32 | Note that during this step the model does not use early stopping. 33 | 2. Apply Cross-validation (CV) to estimate the SHAP feature importance on the provided dataset. In each CV 34 | iteration, the model is fitted on the train folds, and applied on the validation fold to estimate 35 | SHAP feature importance. The model is trained until the scoring metric eval_metric, measured on the 36 | validation fold, stops improving after a number of early_stopping_rounds. 37 | 3. Remove the `step` lowest SHAP importance features from the dataset. 38 | 39 | At the end of the process, the user can plot the performance of the model for each iteration, and select the 40 | optimal number of features and the feature set. 41 | 42 | We recommend using [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html), 43 | because by default it handles missing values and categorical features. For other models, make sure to 44 | handle these issues in your dataset and consider the impact this might have on feature importance.
45 | 46 | 47 | Example: 48 | ```python 49 | from lightgbm import LGBMClassifier 50 | import pandas as pd 51 | from probatus.feature_elimination import EarlyStoppingShapRFECV 52 | from sklearn.datasets import make_classification 53 | 54 | feature_names = [ 55 | 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 56 | 'f8', 'f9', 'f10', 'f11', 'f12', 'f13', 57 | 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20'] 58 | 59 | # Prepare two samples 60 | X, y = make_classification(n_samples=200, class_sep=0.05, n_informative=6, n_features=20, 61 | random_state=0, n_redundant=10, n_clusters_per_class=1) 62 | X = pd.DataFrame(X, columns=feature_names) 63 | 64 | # Prepare model 65 | model = LGBMClassifier(n_estimators=200, max_depth=3) 66 | 67 | # Run feature elimination 68 | shap_elimination = EarlyStoppingShapRFECV( 69 | model=model, step=0.2, cv=10, scoring='roc_auc', early_stopping_rounds=10, n_jobs=3) 70 | report = shap_elimination.fit_compute(X, y) 71 | 72 | # Make plots 73 | performance_plot = shap_elimination.plot() 74 | 75 | # Get final feature set 76 | final_features_set = shap_elimination.get_reduced_features_set(num_features=3) 77 | ``` 78 | 79 | 80 | 81 | """ # noqa 82 | 83 | def __init__( 84 | self, 85 | model, 86 | step=1, 87 | min_features_to_select=1, 88 | cv=None, 89 | scoring="roc_auc", 90 | n_jobs=-1, 91 | verbose=0, 92 | random_state=None, 93 | early_stopping_rounds=5, 94 | eval_metric="auc", 95 | ): 96 | """ 97 | This method initializes the class. 98 | 99 | Args: 100 | model (sklearn compatible classifier or regressor, sklearn compatible search CV e.g. GridSearchCV, RandomizedSearchCV or BayesSearchCV): 101 | A model that will be optimized and trained at each round of feature elimination. The model must 102 | support early stopping of training, which is the case for XGBoost and LightGBM, for example. The 103 | recommended model is [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html), 104 | because by default it handles missing values and categorical variables. This parameter also supports 105 | any hyperparameter search schema that is consistent with the sklearn API e.g. 106 | [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html), 107 | [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html) 108 | or [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html#skopt.BayesSearchCV). 109 | Note that if a hyperparameter search model is used, the hyperparameters are tuned without early 110 | stopping. Early stopping is applied only during the Shapley values estimation for feature 111 | elimination. We recommend simply passing the model without hyperparameter optimization, or using 112 | ShapRFECV without early stopping. 113 | 114 | 115 | step (int or float, optional): 116 | Number of lowest importance features removed each round. If it is an int, then that number of 117 | features is discarded each round. If float, that percentage of the remaining features (rounded down) is removed each 118 | iteration. It is recommended to use float, since it is faster for a large number of features, and slows 119 | down and becomes more precise towards fewer features. Note: the last round may remove fewer features in 120 | order to reach min_features_to_select. 121 | If the columns_to_keep parameter is specified in the fit method, step is the number of features to remove after keeping those columns.
122 | 123 | min_features_to_select (int, optional): 124 | Minimum number of features to be kept. This is a stopping criterion of the feature elimination. By 125 | default the process stops when one feature is left. If columns_to_keep is specified in the fit method, 126 | it may override this parameter with the maximum between the length of columns_to_keep and this value. 127 | 128 | cv (int, cross-validation generator or an iterable, optional): 129 | Determines the cross-validation splitting strategy. Compatible with the sklearn 130 | [cv parameter](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFECV.html). 131 | If None, then cv of 5 is used. 132 | 133 | scoring (string or probatus.utils.Scorer, optional): 134 | Metric for which the model performance is calculated. It can be either a metric name aligned with predefined 135 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html). 136 | Another option is using probatus.utils.Scorer to define a custom metric. 137 | 138 | n_jobs (int, optional): 139 | Number of cores to run in parallel while fitting across folds. None means 1 unless in a 140 | `joblib.parallel_backend` context. -1 means using all processors. 141 | 142 | verbose (int, optional): 143 | Controls verbosity of the output: 144 | 145 | - 0 - neither prints nor warnings are shown 146 | - 1 - only most important warnings 147 | - 2 - shows all prints and all warnings. 148 | 149 | random_state (int, optional): 150 | Random state set at each round of feature elimination. If it is None, the results will not be 151 | reproducible, and in random search different hyperparameters might be tested at each iteration. For 152 | reproducible results set it to an integer. 153 | 154 | early_stopping_rounds (int, optional): 155 | Number of rounds without performance improvement after which the model fitting stops. This is passed to the 156 | fit method of the model for Shapley values estimation, but not for hyperparameter search. Only 157 | supported by some models, such as XGBoost and LightGBM. 158 | 159 | eval_metric (str, optional): 160 | Metric for scoring fitting rounds and activating early stopping. This is passed to the 161 | fit method of the model for Shapley values estimation, but not for hyperparameter search. Only 162 | supported by some models, such as [XGBoost](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters) 163 | and [LightGBM](https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters). 164 | Note that `eval_metric` is an argument of the model's fit method and it is different from `scoring`. 165 | """ # noqa 166 | # TODO: This deprecation warning will be removed when it's decided that this class can be deleted. 167 | warnings.warn( 168 | "The separate EarlyStoppingShapRFECV class is going to be deprecated" 169 | " in a later version of Probatus, since it's now part of the" 170 | " ShapRFECV class.
Please adjust your imported class name from" 171 | " 'EarlyStoppingShapRFECV' to 'ShapRFECV'.", 172 | DeprecationWarning, 173 | ) 174 | 175 | super().__init__( 176 | model, 177 | step=step, 178 | min_features_to_select=min_features_to_select, 179 | cv=cv, 180 | scoring=scoring, 181 | n_jobs=n_jobs, 182 | verbose=verbose, 183 | random_state=random_state, 184 | early_stopping_rounds=early_stopping_rounds, 185 | eval_metric=eval_metric, 186 | ) 187 | -------------------------------------------------------------------------------- /probatus/interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from .shap_dependence import DependencePlotter 2 | from .model_interpret import ShapModelInterpreter 3 | 4 | __all__ = ["DependencePlotter", "ShapModelInterpreter"] 5 | -------------------------------------------------------------------------------- /probatus/interpret/model_interpret.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | from shap import summary_plot 5 | from shap.plots._waterfall import waterfall_legacy 6 | 7 | from probatus.interpret import DependencePlotter 8 | from probatus.utils import ( 9 | BaseFitComputePlotClass, 10 | assure_list_of_strings, 11 | calculate_shap_importance, 12 | preprocess_data, 13 | preprocess_labels, 14 | get_single_scorer, 15 | shap_calc, 16 | ) 17 | 18 | 19 | class ShapModelInterpreter(BaseFitComputePlotClass): 20 | """ 21 | This class is a wrapper that allows to easily analyse a model's features. 22 | 23 | It allows us to plot SHAP feature importance, 24 | SHAP summary plot and SHAP dependence plots. 25 | 26 | Example: 27 | ```python 28 | from sklearn.datasets import make_classification 29 | from sklearn.ensemble import RandomForestClassifier 30 | from sklearn.model_selection import train_test_split 31 | from probatus.interpret import ShapModelInterpreter 32 | import numpy as np 33 | import pandas as pd 34 | 35 | feature_names = ['f1', 'f2', 'f3', 'f4'] 36 | 37 | # Prepare two samples 38 | X, y = make_classification(n_samples=5000, n_features=4, random_state=0) 39 | X = pd.DataFrame(X, columns=feature_names) 40 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) 41 | 42 | # Prepare and fit model. Remember about class_weight="balanced" or an equivalent. 43 | model = RandomForestClassifier(class_weight='balanced', n_estimators = 100, max_depth=2, random_state=0) 44 | model.fit(X_train, y_train) 45 | 46 | # Train ShapModelInterpreter 47 | shap_interpreter = ShapModelInterpreter(model) 48 | feature_importance = shap_interpreter.fit_compute(X_train, X_test, y_train, y_test) 49 | 50 | # Make plots 51 | ax1 = shap_interpreter.plot('importance') 52 | ax2 = shap_interpreter.plot('summary') 53 | ax3 = shap_interpreter.plot('dependence', target_columns=['f1', 'f2']) 54 | ax4 = shap_interpreter.plot('sample', samples_index=[X_test.index.tolist()[0]]) 55 | ``` 56 | 57 | 58 | 59 | 60 | 61 | """ 62 | 63 | def __init__(self, model, scoring="roc_auc", verbose=0, random_state=None): 64 | """ 65 | Initializes the class. 66 | 67 | Args: 68 | model (classifier or regressor): 69 | Model fitted on X_train. 70 | 71 | scoring (string or probatus.utils.Scorer, optional): 72 | Metric for which the model performance is calculated. 
It can be either a metric name aligned with 73 | predefined classification scorers names in sklearn 74 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)). 75 | Another option is using probatus.utils.Scorer to define a custom metric. 76 | 77 | verbose (int, optional): 78 | Controls verbosity of the output: 79 | 80 | - 0 - neither prints nor warnings are shown 81 | - 1 - only most important warnings 82 | - 2 - shows all prints and all warnings. 83 | 84 | random_state (int, optional): 85 | Random state set for the nr of samples. If it is None, the results will not be reproducible. For 86 | reproducible results set it to an integer. 87 | """ 88 | self.model = model 89 | self.scorer = get_single_scorer(scoring) 90 | self.verbose = verbose 91 | self.random_state = random_state 92 | 93 | def fit( 94 | self, 95 | X_train, 96 | X_test, 97 | y_train, 98 | y_test, 99 | column_names=None, 100 | class_names=None, 101 | **shap_kwargs, 102 | ): 103 | """ 104 | Fits the object and calculates the shap values for the provided datasets. 105 | 106 | Args: 107 | X_train (pd.DataFrame): 108 | Dataframe containing training data. 109 | 110 | X_test (pd.DataFrame): 111 | Dataframe containing test data. 112 | 113 | y_train (pd.Series): 114 | Series of labels for train data. 115 | 116 | y_test (pd.Series): 117 | Series of labels for test data. 118 | 119 | column_names (None, or list of str, optional): 120 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used. 121 | 122 | class_names (None, or list of str, optional): 123 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are 124 | used. 125 | 126 | **shap_kwargs: 127 | keyword arguments passed to 128 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer). 129 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values. 130 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while 131 | `check_additivity=False` disables the additivity check inside SHAP. 132 | """ 133 | 134 | self.X_train, self.column_names = preprocess_data( 135 | X_train, X_name="X_train", column_names=column_names, verbose=self.verbose 136 | ) 137 | self.X_test, _ = preprocess_data(X_test, X_name="X_test", column_names=column_names, verbose=self.verbose) 138 | self.y_train = preprocess_labels(y_train, y_name="y_train", index=self.X_train.index, verbose=self.verbose) 139 | self.y_test = preprocess_labels(y_test, y_name="y_test", index=self.X_test.index, verbose=self.verbose) 140 | 141 | # Set class names 142 | self.class_names = class_names 143 | if self.class_names is None: 144 | self.class_names = ["Negative Class", "Positive Class"] 145 | 146 | # Calculate Metrics 147 | self.train_score = self.scorer.score(self.model, self.X_train, self.y_train) 148 | self.test_score = self.scorer.score(self.model, self.X_test, self.y_test) 149 | 150 | self.results_text = ( 151 | f"Train {self.scorer.metric_name}: {np.round(self.train_score, 3)},\n" 152 | f"Test {self.scorer.metric_name}: {np.round(self.test_score, 3)}." 
153 | ) 154 | 155 | ( 156 | self.shap_values_train, 157 | self.expected_value_train, 158 | self.tdp_train, 159 | ) = self._prep_shap_related_variables( 160 | model=self.model, 161 | X=self.X_train, 162 | y=self.y_train, 163 | column_names=self.column_names, 164 | class_names=self.class_names, 165 | verbose=self.verbose, 166 | random_state=self.random_state, 167 | **shap_kwargs, 168 | ) 169 | 170 | ( 171 | self.shap_values_test, 172 | self.expected_value_test, 173 | self.tdp_test, 174 | ) = self._prep_shap_related_variables( 175 | model=self.model, 176 | X=self.X_test, 177 | y=self.y_test, 178 | column_names=self.column_names, 179 | class_names=self.class_names, 180 | verbose=self.verbose, 181 | random_state=self.random_state, 182 | **shap_kwargs, 183 | ) 184 | 185 | self.fitted = True 186 | 187 | @staticmethod 188 | def _prep_shap_related_variables( 189 | model, 190 | X, 191 | y, 192 | approximate=False, 193 | verbose=0, 194 | random_state=None, 195 | column_names=None, 196 | class_names=None, 197 | **shap_kwargs, 198 | ): 199 | """ 200 | The function prepares the variables related to shap that are used to interpret the model. 201 | 202 | Returns: 203 | (np.array, int, DependencePlotter): 204 | Shap values, expected value of the explainer, and fitted TreeDependencePlotter for a given dataset. 205 | """ 206 | shap_values, explainer = shap_calc( 207 | model, 208 | X, 209 | approximate=approximate, 210 | verbose=verbose, 211 | random_state=random_state, 212 | return_explainer=True, 213 | **shap_kwargs, 214 | ) 215 | 216 | expected_value = explainer.expected_value 217 | 218 | # For sklearn models the expected values consists of two elements (negative_class and positive_class) 219 | if isinstance(expected_value, list) or isinstance(expected_value, np.ndarray): 220 | expected_value = expected_value[1] 221 | 222 | # Initialize tree dependence plotter 223 | tdp = DependencePlotter(model, verbose=verbose).fit( 224 | X, 225 | y, 226 | column_names=column_names, 227 | class_names=class_names, 228 | precalc_shap=shap_values, 229 | ) 230 | return shap_values, expected_value, tdp 231 | 232 | def compute(self, return_scores=False, shap_variance_penalty_factor=None): 233 | """ 234 | Computes the DataFrame that presents the importance of each feature. 235 | 236 | Args: 237 | return_scores (bool, optional): 238 | Flag indicating whether the method should return the train and test score of the model, together with 239 | the model interpretation report. If true, the output of this method is a tuple of DataFrame, float, 240 | float. 241 | 242 | shap_variance_penalty_factor (int or float, optional): 243 | Apply aggregation penalty when computing average of shap values for a given feature. 244 | Results in a preference for features that have smaller standard deviation of shap 245 | values (more coherent shap importance). Recommend value 0.5 - 1.0. 246 | Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor)) 247 | 248 | Returns: 249 | (pd.DataFrame or tuple(pd.DataFrame, float, float)): 250 | Dataframe with SHAP feature importance, or tuple containing the dataframe, train and test scores of the 251 | model. 
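            For illustration, a minimal sketch of the penalty formula with made-up numbers:

            ```python
            # Hypothetical values: a feature with mean SHAP value of 0.30 and std of 0.20.
            mean_shap, std_shap, shap_variance_penalty_factor = 0.30, 0.20, 0.5
            penalized_shap_mean = mean_shap - std_shap * shap_variance_penalty_factor  # 0.30 - 0.10 = 0.20
            ```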
252 | """ 253 | self._check_if_fitted() 254 | 255 | # Compute SHAP importance 256 | self.importance_df_train = calculate_shap_importance( 257 | self.shap_values_train, 258 | self.column_names, 259 | output_columns_suffix="_train", 260 | shap_variance_penalty_factor=shap_variance_penalty_factor, 261 | ) 262 | 263 | self.importance_df_test = calculate_shap_importance( 264 | self.shap_values_test, 265 | self.column_names, 266 | output_columns_suffix="_test", 267 | shap_variance_penalty_factor=shap_variance_penalty_factor, 268 | ) 269 | 270 | # Concatenate the train and test, sort by test set importance and reorder the columns 271 | self.importance_df = pd.concat([self.importance_df_train, self.importance_df_test], axis=1).sort_values( 272 | "mean_abs_shap_value_test", ascending=False 273 | )[ 274 | [ 275 | "mean_abs_shap_value_test", 276 | "mean_abs_shap_value_train", 277 | "mean_shap_value_test", 278 | "mean_shap_value_train", 279 | ] 280 | ] 281 | 282 | if return_scores: 283 | return self.importance_df, self.train_score, self.test_score 284 | else: 285 | return self.importance_df 286 | 287 | def fit_compute( 288 | self, 289 | X_train, 290 | X_test, 291 | y_train, 292 | y_test, 293 | column_names=None, 294 | class_names=None, 295 | return_scores=False, 296 | shap_variance_penalty_factor=None, 297 | **shap_kwargs, 298 | ): 299 | """ 300 | Fits the object and calculates the shap values for the provided datasets. 301 | 302 | Args: 303 | X_train (pd.DataFrame): 304 | Dataframe containing training data. 305 | 306 | X_test (pd.DataFrame): 307 | Dataframe containing test data. 308 | 309 | y_train (pd.Series): 310 | Series of labels for train data. 311 | 312 | y_test (pd.Series): 313 | Series of labels for test data. 314 | 315 | column_names (None, or list of str, optional): 316 | List of feature names for the dataset. 317 | If None, then column names from the X_train dataframe are used. 318 | 319 | class_names (None, or list of str, optional): 320 | List of class names e.g. ['neg', 'pos']. 321 | If none, the default ['Negative Class', 'Positive Class'] are 322 | used. 323 | 324 | return_scores (bool, optional): 325 | Flag indicating whether the method should return 326 | the train and test score of the model, 327 | together with the model interpretation report. If true, 328 | the output of this method is a tuple of DataFrame, float, 329 | float. 330 | 331 | shap_variance_penalty_factor (int or float, optional): 332 | Apply aggregation penalty when computing average of shap values for a given feature. 333 | Results in a preference for features that have smaller standard deviation of shap 334 | values (more coherent shap importance). Recommend value 0.5 - 1.0. 335 | Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor)) 336 | 337 | **shap_kwargs: 338 | keyword arguments passed to 339 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer). 340 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values. 341 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while 342 | `check_additivity=False` disables the additivity check inside SHAP. 343 | 344 | Returns: 345 | (pd.DataFrame or tuple(pd.DataFrame, float, float)): 346 | Dataframe with SHAP feature importance, or tuple containing the dataframe, train and test scores of the 347 | model. 
348 | """ 349 | self.fit( 350 | X_train=X_train, 351 | X_test=X_test, 352 | y_train=y_train, 353 | y_test=y_test, 354 | column_names=column_names, 355 | class_names=class_names, 356 | **shap_kwargs, 357 | ) 358 | return self.compute(return_scores=return_scores, shap_variance_penalty_factor=shap_variance_penalty_factor) 359 | 360 | def plot(self, plot_type, target_set="test", target_columns=None, samples_index=None, show=True, **plot_kwargs): 361 | """ 362 | Plots the appropriate SHAP plot. 363 | 364 | Args: 365 | plot_type (str): 366 | One of the following: 367 | 368 | - `'importance'`: Feature importance plot, SHAP bar summary plot 369 | - `'summary'`: SHAP Summary plot 370 | - `'dependence'`: Dependence plot for each feature 371 | - `'sample'`: Explanation of a given sample in the test data 372 | 373 | target_set (str, optional): 374 | The set for which the plot should be generated, either `train` or `test` set. We recommend using the test 375 | set, because it is not biased by model training. The train set plots are mainly used to check for 376 | significant differences from the test set plots, which would indicate a shift in the data distribution. 377 | 378 | target_columns (None, str or list of str, optional): 379 | List of feature names for which the plots should be generated. If None, all features will be plotted. 380 | 381 | samples_index (None, int, list or pd.Index, optional): 382 | Index of samples to be explained if the `plot_type=sample`. 383 | 384 | show (bool, optional): 385 | If True, the plots are shown to the user, otherwise they are not shown. Not showing a plot can be useful 386 | when you want to edit the returned axis before showing it. 387 | 388 | **plot_kwargs: 389 | Keyword arguments passed to the plot method. For the 'importance' and 'summary' plot_type, the kwargs are 390 | passed to shap.summary_plot, for the 'dependence' plot_type, they are passed to the 391 | probatus.interpret.DependencePlotter.plot method. 392 | 393 | Returns: 394 | (matplotlib.axes or list(matplotlib.axes)): 395 | An Axes with the plot, or a list of axes when multiple plots are returned.
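            Example (a sketch, assuming a fitted interpreter such as `shap_interpreter` from the class-level
            example above):

            ```python
            ax1 = shap_interpreter.plot('importance', target_set='test')
            ax2 = shap_interpreter.plot('sample', samples_index=[X_test.index.tolist()[0]])
            ```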
396 | """ 397 | # Choose correct columns 398 | if target_columns is None: 399 | target_columns = self.column_names 400 | 401 | target_columns = assure_list_of_strings(target_columns, "target_columns") 402 | target_columns_indices = [self.column_names.index(target_column) for target_column in target_columns] 403 | 404 | # Choose the correct dataset 405 | if target_set == "test": 406 | target_X = self.X_test 407 | target_shap_values = self.shap_values_test 408 | target_tdp = self.tdp_test 409 | target_expected_value = self.expected_value_test 410 | elif target_set == "train": 411 | target_X = self.X_train 412 | target_shap_values = self.shap_values_train 413 | target_tdp = self.tdp_train 414 | target_expected_value = self.expected_value_train 415 | else: 416 | raise (ValueError('The target_set parameter can be either "train" or "test".')) 417 | 418 | if plot_type in ["importance", "summary"]: 419 | target_X = target_X[target_columns] 420 | target_shap_values = target_shap_values[:, target_columns_indices] 421 | # Set summary plot settings 422 | if plot_type == "importance": 423 | plot_type = "bar" 424 | plot_title = f"SHAP Feature Importance for {target_set} set" 425 | else: 426 | plot_type = "dot" 427 | plot_title = f"SHAP Summary plot for {target_set} set" 428 | 429 | summary_plot( 430 | target_shap_values, 431 | target_X, 432 | plot_type=plot_type, 433 | class_names=self.class_names, 434 | show=False, 435 | **plot_kwargs, 436 | ) 437 | ax = plt.gca() 438 | ax.set_title(plot_title) 439 | 440 | ax.annotate( 441 | self.results_text, 442 | (0, 0), 443 | (0, -50), 444 | fontsize=12, 445 | xycoords="axes fraction", 446 | textcoords="offset points", 447 | va="top", 448 | ) 449 | if show: 450 | plt.show() 451 | else: 452 | plt.close() 453 | elif plot_type == "dependence": 454 | ax = [] 455 | for feature_name in target_columns: 456 | ax.append(target_tdp.plot(feature=feature_name, figsize=(10, 7), show=show, **plot_kwargs)) 457 | 458 | elif plot_type == "sample": 459 | # Ensure the correct samples_index type 460 | if samples_index is None: 461 | raise (ValueError("For the 'sample' plot, you need to specify the samples_index of the samples to be plotted")) 462 | elif isinstance(samples_index, int) or isinstance(samples_index, str): 463 | samples_index = [samples_index] 464 | elif not (isinstance(samples_index, list) or isinstance(samples_index, pd.Index)): 465 | raise (TypeError("samples_index must be one of the following: int, str, list or pd.Index")) 466 | 467 | ax = [] 468 | for sample_index in samples_index: 469 | sample_loc = target_X.index.get_loc(sample_index) 470 | 471 | waterfall_legacy( 472 | target_expected_value, 473 | target_shap_values[sample_loc, :], 474 | target_X.loc[sample_index], 475 | show=False, 476 | **plot_kwargs, 477 | ) 478 | 479 | plot_title = f"SHAP Sample Explanation of {target_set} sample for index={sample_index}" 480 | current_ax = plt.gca() 481 | current_ax.set_title(plot_title) 482 | ax.append(current_ax) 483 | if show: 484 | plt.show() 485 | else: 486 | plt.close() 487 | else: 488 | raise ValueError("Wrong plot type, select from 'importance', 'summary', 'dependence', or 'sample'") 489 | 490 | if isinstance(ax, list) and len(ax) == 1: 491 | ax = ax[0] 492 | return ax 493 | -------------------------------------------------------------------------------- /probatus/interpret/shap_dependence.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | from sklearn.preprocessing import
KBinsDiscretizer 5 | 6 | from probatus.utils import BaseFitComputePlotClass, preprocess_data, preprocess_labels, shap_to_df 7 | 8 | 9 | class DependencePlotter(BaseFitComputePlotClass): 10 | """ 11 | Plotter used to plot SHAP dependence plot together with the target rates. 12 | 13 | Currently it supports tree-based and linear models. 14 | 15 | Args: 16 | model: classifier for which interpretation is done. 17 | 18 | Example: 19 | ```python 20 | from sklearn.datasets import make_classification 21 | from sklearn.ensemble import RandomForestClassifier 22 | from probatus.interpret import DependencePlotter 23 | 24 | X, y = make_classification(n_samples=15, n_features=3, n_informative=3, n_redundant=0, random_state=42) 25 | model = RandomForestClassifier().fit(X, y) 26 | bdp = DependencePlotter(model) 27 | shap_values = bdp.fit_compute(X, y) 28 | 29 | bdp.plot(feature=2) 30 | ``` 31 | 32 | 33 | """ 34 | 35 | def __init__(self, model, verbose=0, random_state=None): 36 | """ 37 | Initializes the class. 38 | 39 | Args: 40 | model (model object): 41 | regression or classification model or pipeline. 42 | 43 | verbose (int, optional): 44 | Controls verbosity of the output: 45 | 46 | - 0 - neither prints nor warnings are shown 47 | - 1 - only most important warnings 48 | - 2 - shows all prints and all warnings. 49 | 50 | random_state (int, optional): 51 | Random state set for the nr of samples. If it is None, the results will not be reproducible. For 52 | reproducible results set it to an integer. 53 | """ 54 | self.model = model 55 | self.verbose = verbose 56 | self.random_state = random_state 57 | 58 | def __repr__(self): 59 | """ 60 | Represent string method. 61 | """ 62 | return f"Shap dependence plotter for {self.model.__class__.__name__}" 63 | 64 | def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs): 65 | """ 66 | Fits the plotter to the model and data by computing the shap values. 67 | 68 | If the shap_values are passed, they do not need to be computed. 69 | 70 | Args: 71 | X (pd.DataFrame): input variables. 72 | 73 | y (pd.Series): target variable. 74 | 75 | column_names (None, or list of str, optional): 76 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used. 77 | 78 | class_names (None, or list of str, optional): 79 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are 80 | used. 81 | 82 | precalc_shap (Optional, None or np.array): 83 | Precalculated shap values, If provided they don't need to be computed. 84 | 85 | **shap_kwargs: 86 | keyword arguments passed to 87 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer). 88 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values. 89 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while 90 | `check_additivity=False` disables the additivity check inside SHAP. 
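            Example (a sketch of reusing precalculated SHAP values, assuming the model, X and y from the
            class-level example above; `shap_calc` is the probatus helper used elsewhere in this package):

            ```python
            from probatus.utils import shap_calc

            # Compute the SHAP values once, then pass them in so fit() does not recompute them.
            precalc = shap_calc(model, X, verbose=0)
            bdp = DependencePlotter(model).fit(X, y, precalc_shap=precalc)
            ```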
91 | """ 92 | self.X, self.column_names = preprocess_data(X, X_name="X", column_names=column_names, verbose=self.verbose) 93 | self.y = preprocess_labels(y, y_name="y", index=self.X.index, verbose=self.verbose) 94 | 95 | # Set class names 96 | self.class_names = class_names 97 | if self.class_names is None: 98 | self.class_names = ["Negative Class", "Positive Class"] 99 | 100 | self.shap_vals_df = shap_to_df( 101 | self.model, 102 | self.X, 103 | precalc_shap=precalc_shap, 104 | verbose=self.verbose, 105 | random_state=self.random_state, 106 | **shap_kwargs, 107 | ) 108 | 109 | self.fitted = True 110 | return self 111 | 112 | def compute(self): 113 | """ 114 | Computes the report returned to the user, namely the SHAP values generated on the dataset. 115 | 116 | Returns: 117 | (pd.DataFrame): 118 | SHAP Values for X. 119 | """ 120 | self._check_if_fitted() 121 | return self.shap_vals_df 122 | 123 | def fit_compute(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs): 124 | """ 125 | Fits the plotter to the model and data by computing the shap values. 126 | 127 | If the shap_values are passed, they do not need to be computed 128 | 129 | Args: 130 | X (pd.DataFrame): 131 | Provided dataset. 132 | 133 | y (pd.Series): 134 | Labels for X. 135 | 136 | column_names (None, or list of str, optional): 137 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used. 138 | 139 | class_names (None, or list of str, optional): 140 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are 141 | used. 142 | 143 | precalc_shap (Optional, None or np.array): 144 | Precalculated shap values, If provided they don't need to be computed. 145 | 146 | **shap_kwargs: 147 | keyword arguments passed to 148 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer). 149 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values. 150 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while 151 | `check_additivity=False` disables the additivity check inside SHAP. 152 | 153 | Returns: 154 | (pd.DataFrame): 155 | SHAP Values for X. 156 | """ 157 | self.fit(X, y, column_names=column_names, class_names=class_names, precalc_shap=precalc_shap, **shap_kwargs) 158 | return self.compute() 159 | 160 | def plot( 161 | self, 162 | feature, 163 | figsize=(15, 10), 164 | bins=10, 165 | show=True, 166 | min_q=0, 167 | max_q=1, 168 | alpha=1.0, 169 | ): 170 | """ 171 | Plots the shap values for data points for a given feature, as well as the target rate and values distribution. 172 | 173 | Args: 174 | feature (str or int): 175 | Feature name of the feature to be analyzed. 176 | 177 | figsize ((float, float), optional): 178 | Tuple specifying size (width, height) of resulting figure in inches. 179 | 180 | bins (int or list[float]): 181 | Number of bins or boundaries of bins (supplied in list) for target-rate plot. 182 | 183 | show (bool, optional): 184 | If True, the plots are showed to the user, otherwise they are not shown. Not showing plot can be useful, 185 | when you want to edit the returned axis, before showing it. 186 | 187 | min_q (float, optional): 188 | Optional minimum quantile from which to consider values, used for plotting under outliers. 189 | 190 | max_q (float, optional): 191 | Optional maximum quantile until which data points are considered, used for plotting under outliers. 
192 | 193 | alpha (float, optional): 194 | Optional alpha blending value, between 0 (transparent) and 1 (opaque). 195 | 196 | Returns 197 | (list(matplotlib.axes)): 198 | List of axes that include the plots. 199 | """ 200 | self._check_if_fitted() 201 | if min_q >= max_q: 202 | raise ValueError("min_q must be smaller than max_q") 203 | if feature not in self.X.columns: 204 | raise ValueError("Feature not recognized") 205 | if (alpha < 0) or (alpha > 1): 206 | raise ValueError("alpha must be a float value between 0 and 1") 207 | 208 | self.min_q, self.max_q, self.alpha = min_q, max_q, alpha 209 | 210 | _ = plt.figure(1, figsize=figsize) 211 | ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2) 212 | ax2 = plt.subplot2grid((3, 1), (2, 0)) 213 | 214 | self._dependence_plot(feature=feature, ax=ax1) 215 | self._target_rate_plot(feature=feature, bins=bins, ax=ax2) 216 | 217 | ax2.set_xlim(ax1.get_xlim()) 218 | 219 | if show: 220 | plt.show() 221 | else: 222 | plt.close() 223 | 224 | return [ax1, ax2] 225 | 226 | def _dependence_plot(self, feature, ax=None): 227 | """ 228 | Plots shap values for data points with respect to specified feature. 229 | 230 | Args: 231 | feature (str or int): 232 | Feature for which dependence plot is to be created. 233 | 234 | ax (matplotlib.pyplot.axes, optional): 235 | Optional axis on which to draw plot. 236 | 237 | Returns: 238 | (matplotlib.pyplot.axes): 239 | Axes on which plot is drawn. 240 | """ 241 | if isinstance(feature, int): 242 | feature = self.column_names[feature] 243 | 244 | X, y, shap_val = self._get_X_y_shap_with_q_cut(feature=feature) 245 | 246 | ax.scatter(X[y == 0], shap_val[y == 0], label=self.class_names[0], color="lightblue", alpha=self.alpha) 247 | 248 | ax.scatter(X[y == 1], shap_val[y == 1], label=self.class_names[1], color="darkred", alpha=self.alpha) 249 | 250 | ax.set_ylabel("Shap value") 251 | ax.set_title(f"Dependence plot for {feature} feature") 252 | ax.legend() 253 | 254 | return ax 255 | 256 | def _target_rate_plot(self, feature, bins=10, ax=None): 257 | """ 258 | Plots the distributions of the specific features, as well as the target rate as function of the feature. 259 | 260 | Args: 261 | feature (str or int): 262 | Feature for which to create target rate plot. 263 | 264 | bins (int or list[float]), optional: 265 | Number of bins or boundaries of desired bins in list. 266 | 267 | ax (matplotlib.pyplot.axes, optional): 268 | Optional axis on which to draw plot. 269 | 270 | Returns: 271 | (list[float], matplotlib.pyplot.axes, float): 272 | Tuple of boundaries of bins used, axis on which plot is drawn, total ratio of target (positive over 273 | negative). 
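            For illustration, a sketch of how an integer `bins` is turned into bin edges (this mirrors the
            method body below; `x` is assumed to be a 1-D numpy array of feature values):

            ```python
            import numpy as np
            from sklearn.preprocessing import KBinsDiscretizer

            # Uniform-width edges, with the outermost bins opened up to catch all values.
            edges = KBinsDiscretizer(n_bins=5, encode="ordinal", strategy="uniform").fit(x.reshape(-1, 1)).bin_edges_[0]
            edges[0], edges[-1] = -np.inf, np.inf
            ```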
274 | """ 275 | x, y, shap_val = self._get_X_y_shap_with_q_cut(feature=feature) 276 | 277 | # Create bins if not explicitly supplied 278 | if isinstance(bins, int): 279 | simple_binner = KBinsDiscretizer(n_bins=bins, encode="ordinal", strategy="uniform").fit( 280 | np.array(x).reshape(-1, 1) 281 | ) 282 | bins = simple_binner.bin_edges_[0] 283 | bins[0], bins[-1] = -np.inf, np.inf 284 | 285 | # Determine bin for datapoints 286 | bins[-1] = bins[-1] + 1 287 | indices = np.digitize(x, bins) 288 | # Create dataframe with binned data 289 | dfs = pd.DataFrame({feature: x, "y": y, "bin_index": pd.Series(indices, index=x.index)}).groupby( 290 | "bin_index", as_index=True 291 | ) 292 | 293 | # Extract target ratio and mean feature value 294 | target_ratio = dfs["y"].mean() 295 | x_vals = dfs[feature].mean() 296 | 297 | # Transform the first and last bin to work with plt.hist method 298 | if bins[0] == -np.inf: 299 | bins[0] = x.min() 300 | if bins[-1] == np.inf: 301 | bins[-1] = x.max() 302 | 303 | # Plot target rate 304 | ax.hist(x, bins=bins, lw=2, alpha=0.4) 305 | ax.set_ylabel("Counts") 306 | ax2 = ax.twinx() 307 | ax2.plot(x_vals, target_ratio, color="red") 308 | ax2.set_ylabel("Target rate", color="red", fontsize=12) 309 | ax2.set_xlim(x.min(), x.max()) 310 | ax.set_xlabel(f"{feature} feature values") 311 | 312 | return bins, ax, target_ratio 313 | 314 | def _get_X_y_shap_with_q_cut(self, feature): 315 | """ 316 | Extracts all X, y pairs and shap values that fall within defined quantiles of the feature. 317 | 318 | Args: 319 | feature (str): feature to return values for 320 | 321 | Returns: 322 | x (pd.Series): selected datapoints 323 | y (pd.Series): target values of selected datapoints 324 | shap_val (pd.Series): shap values of selected datapoints 325 | """ 326 | self._check_if_fitted() 327 | if feature not in self.X.columns: 328 | raise ValueError("Feature not found in data") 329 | 330 | # Prepare arrays 331 | x = self.X[feature] 332 | y = self.y 333 | shap_val = self.shap_vals_df[feature] 334 | 335 | # Determine quantile ranges 336 | x_min = x.quantile(self.min_q) 337 | x_max = x.quantile(self.max_q) 338 | 339 | # Create filter 340 | filter = (x >= x_min) & (x <= x_max) 341 | 342 | # Filter and return terms 343 | return x[filter], y[filter], shap_val[filter] 344 | -------------------------------------------------------------------------------- /probatus/sample_similarity/__init__.py: -------------------------------------------------------------------------------- 1 | from .resemblance_model import ( 2 | BaseResemblanceModel, 3 | PermutationImportanceResemblance, 4 | SHAPImportanceResemblance, 5 | ) 6 | 7 | __all__ = [ 8 | "BaseResemblanceModel", 9 | "PermutationImportanceResemblance", 10 | "SHAPImportanceResemblance", 11 | ] 12 | -------------------------------------------------------------------------------- /probatus/sample_similarity/resemblance_model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | from loguru import logger 7 | from shap import summary_plot 8 | from sklearn.inspection import permutation_importance 9 | from sklearn.model_selection import train_test_split 10 | 11 | from probatus.utils import BaseFitComputePlotClass, preprocess_data, preprocess_labels, get_single_scorer 12 | from probatus.utils.shap_helpers import calculate_shap_importance, shap_calc 13 | 14 | 15 | class BaseResemblanceModel(BaseFitComputePlotClass): 16 | """ 
17 | This model checks for the similarity of two samples. 18 | 19 | A possible use case is analysis of whether the train sample differs 20 | from the test sample, due to e.g. non-stationarity. 21 | 22 | This is a base class and needs to be extended by a fit() method, which implements how the data is split, 23 | how the model is trained and evaluated. 24 | Further, inheriting classes need to implement how feature importance should be indicated. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model, 30 | scoring="roc_auc", 31 | test_prc=0.25, 32 | n_jobs=1, 33 | verbose=0, 34 | random_state=None, 35 | ): 36 | """ 37 | Initializes the class. 38 | 39 | Args: 40 | model (model object): 41 | Regression or classification model or pipeline. 42 | 43 | scoring (string or probatus.utils.Scorer, optional): 44 | Metric for which the model performance is calculated. It can be either a metric name aligned with 45 | predefined 46 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html). 47 | Another option is using probatus.utils.Scorer to define a custom metric. The recommended option for this 48 | class is 'roc_auc'. 49 | 50 | test_prc (float, optional): 51 | Percentage of data used to test the model. By default 0.25 is set. 52 | 53 | n_jobs (int, optional): 54 | Number of parallel executions. If -1 use all available cores. By default 1. 55 | 56 | verbose (int, optional): 57 | Controls verbosity of the output: 58 | 59 | - 0 - neither prints nor warnings are shown 60 | - 1 - only most important warnings 61 | - 2 - shows all prints and all warnings. 62 | 63 | random_state (int, optional): 64 | Random state set for the train/test split of the resemblance model. If it is None, the results 65 | will not be reproducible. For 66 | reproducible results set it to an integer. 67 | """ # noqa 68 | self.model = model 69 | self.test_prc = test_prc 70 | self.n_jobs = n_jobs 71 | self.random_state = random_state 72 | self.verbose = verbose 73 | self.scorer = get_single_scorer(scoring) 74 | 75 | def _init_output_variables(self): 76 | """ 77 | Initializes variables that will be filled in during the fit() method, and are used as output. 78 | """ 79 | self.X_train = None 80 | self.X_test = None 81 | self.y_train = None 82 | self.y_test = None 83 | self.train_score = None 84 | self.test_score = None 85 | self.report = None 86 | 87 | def fit(self, X1, X2, column_names=None, class_names=None): 88 | """ 89 | Base fit functionality that should be executed before each fit. 90 | 91 | Args: 92 | X1 (np.ndarray or pd.DataFrame): 93 | First sample to be compared. It needs to have the same number of columns as X2. 94 | 95 | X2 (np.ndarray or pd.DataFrame): 96 | Second sample to be compared. It needs to have the same number of columns as X1. 97 | 98 | column_names (list of str, optional): 99 | List of feature names of the provided samples. If provided, it will be used to overwrite the existing 100 | feature names. If not provided, the existing feature names are used or default feature names are 101 | generated. 102 | 103 | class_names (None, or list of str, optional): 104 | List of class names assigned to the provided samples, e.g. ['sample1', 'sample2']. If None, the 105 | defaults ['First Sample', 'Second Sample'] are used.
106 | 107 | Returns: 108 | (BaseResemblanceModel): 109 | Fitted object. 110 | """ 111 | # Set class names 112 | self.class_names = class_names 113 | if self.class_names is None: 114 | self.class_names = ["First Sample", "Second Sample"] 115 | 116 | # Ensure inputs are correct 117 | self.X1, self.column_names = preprocess_data(X1, X_name="X1", column_names=column_names, verbose=self.verbose) 118 | self.X2, _ = preprocess_data(X2, X_name="X2", column_names=column_names, verbose=self.verbose) 119 | 120 | # Prepare dataset for modelling 121 | self.X = pd.DataFrame(pd.concat([self.X1, self.X2], axis=0), columns=self.column_names).reset_index(drop=True) 122 | 123 | self.y = pd.Series( 124 | np.concatenate( 125 | [ 126 | np.zeros(self.X1.shape[0]), 127 | np.ones(self.X2.shape[0]), 128 | ] 129 | ) 130 | ).reset_index(drop=True) 131 | 132 | # Assure the type and number of classes for the variable 133 | self.X, _ = preprocess_data(self.X, X_name="X", column_names=self.column_names, verbose=self.verbose) 134 | 135 | self.y = preprocess_labels(self.y, y_name="y", index=self.X.index, verbose=self.verbose) 136 | 137 | # Reinitialize variables in case fit is called multiple times 138 | self._init_output_variables() 139 | 140 | self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( 141 | self.X, 142 | self.y, 143 | test_size=self.test_prc, 144 | random_state=self.random_state, 145 | shuffle=True, 146 | stratify=self.y, 147 | ) 148 | self.model.fit(self.X_train, self.y_train) 149 | 150 | self.train_score = np.round(self.scorer.score(self.model, self.X_train, self.y_train), 3) 151 | self.test_score = np.round(self.scorer.score(self.model, self.X_test, self.y_test), 3) 152 | 153 | self.results_text = ( 154 | f"Train {self.scorer.metric_name}: {np.round(self.train_score, 3)},\n" 155 | f"Test {self.scorer.metric_name}: {np.round(self.test_score, 3)}." 156 | ) 157 | if self.verbose > 1: 158 | logger.info(f"Finished model training: \n{self.results_text}") 159 | 160 | if self.verbose > 0: 161 | if self.train_score > self.test_score: 162 | warnings.warn( 163 | f"Train {self.scorer.metric_name} > Test {self.scorer.metric_name}, which might indicate " 164 | f"an overfit. \n Strong overfit might lead to misleading conclusions when analysing " 165 | f"feature importance. Consider retraining with more regularization applied to the model." 166 | ) 167 | self.fitted = True 168 | return self 169 | 170 | def get_data_splits(self): 171 | """ 172 | Returns the data splits used to train the Resemblance model. 173 | 174 | Returns: 175 | (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series): 176 | X_train, X_test, y_train, y_test. 177 | """ 178 | self._check_if_fitted() 179 | return self.X_train, self.X_test, self.y_train, self.y_test 180 | 181 | def compute(self, return_scores=False): 182 | """ 183 | Checks if the fit() method has been run and computes the output variables. 184 | 185 | Args: 186 | return_scores (bool, optional): 187 | Flag indicating whether the method should return a tuple (feature importances, train score, 188 | test score), or feature importances. By default the second option is selected. 189 | 190 | Returns: 191 | (tuple(pd.DataFrame, float, float) or pd.DataFrame): 192 | Depending on the value of return_scores, either returns a tuple (feature importances, train AUC, test AUC), or 193 | feature importances.
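            Example (a sketch, assuming a fitted resemblance model such as the `rm` object from the
            SHAPImportanceResemblance example below):

            ```python
            report, train_score, test_score = rm.compute(return_scores=True)
            ```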
194 | """ 195 | self._check_if_fitted() 196 | 197 | if return_scores: 198 | return self.report, self.train_score, self.test_score 199 | else: 200 | return self.report 201 | 202 | def fit_compute( 203 | self, 204 | X1, 205 | X2, 206 | column_names=None, 207 | class_names=None, 208 | return_scores=False, 209 | **fit_kwargs, 210 | ): 211 | """ 212 | Fits the resemblance model and computes the report regarding feature importance. 213 | 214 | Args: 215 | X1 (np.ndarray or pd.DataFrame): 216 | First sample to be compared. It needs to have the same number of columns as X2. 217 | 218 | X2 (np.ndarray or pd.DataFrame): 219 | Second sample to be compared. It needs to have the same number of columns as X1. 220 | 221 | column_names (list of str, optional): 222 | List of feature names of the provided samples. If provided, it will be used to overwrite the existing 223 | feature names. If not provided, the existing feature names are used or default feature names are 224 | generated. 225 | 226 | class_names (None, or list of str, optional): 227 | List of class names assigned to the provided samples, e.g. ['sample1', 'sample2']. If None, the 228 | defaults ['First Sample', 'Second Sample'] are used. 229 | 230 | return_scores (bool, optional): 231 | Flag indicating whether the method should return a tuple (feature importances, train score, 232 | test score), or feature importances. By default the second option is selected. 233 | 234 | **fit_kwargs: 235 | In case any other arguments are accepted by the fit() method, they can be passed as keyword arguments. 236 | 237 | Returns: 238 | (tuple of (pd.DataFrame, float, float) or pd.DataFrame): 239 | Depending on the value of return_scores, either returns a tuple (feature importances, train AUC, test AUC), or 240 | feature importances. 241 | """ 242 | self.fit(X1, X2, column_names=column_names, class_names=class_names, **fit_kwargs) 243 | return self.compute(return_scores=return_scores) 244 | 245 | def plot(self): 246 | """ 247 | Plot. 248 | """ 249 | raise (NotImplementedError("Plot method has not been implemented.")) 250 | 251 | 252 | class PermutationImportanceResemblance(BaseResemblanceModel): 253 | """ 254 | This model checks the similarity of two samples. 255 | 256 | A possible use case is analysis of whether the train sample differs 257 | from the test sample, due to e.g. non-stationarity. 258 | 259 | It assigns labels to each sample, 0 to the first sample, 1 to the second. Then, it randomly selects a portion of 260 | data to train on. The resulting model tries to distinguish which sample a given test row comes from. This 261 | provides insights on how distinguishable these samples are and which features contribute to that. The feature 262 | importance is calculated using permutation importance. 263 | 264 | If the model achieves a test AUC significantly different from 0.5, it indicates that it is possible to distinguish 265 | between the samples, and therefore, the samples differ. 266 | Features with a high permutation importance contribute to that effect the most. 267 | Thus, their distribution might differ between the two samples.
268 | 269 | Examples: 270 | ```python 271 | from sklearn.datasets import make_classification 272 | from sklearn.ensemble import RandomForestClassifier 273 | from probatus.sample_similarity import PermutationImportanceResemblance 274 | X1, _ = make_classification(n_samples=100, n_features=5) 275 | X2, _ = make_classification(n_samples=100, n_features=5, shift=0.5) 276 | model = RandomForestClassifier(max_depth=2) 277 | perm = PermutationImportanceResemblance(model) 278 | feature_importance = perm.fit_compute(X1, X2) 279 | perm.plot() 280 | ``` 281 | 282 | """ 283 | 284 | def __init__( 285 | self, 286 | model, 287 | iterations=100, 288 | scoring="roc_auc", 289 | test_prc=0.25, 290 | n_jobs=1, 291 | verbose=0, 292 | random_state=None, 293 | ): 294 | """ 295 | Initializes the class. 296 | 297 | Args: 298 | model (model object): 299 | Regression or classification model or pipeline. 300 | 301 | iterations (int, optional): 302 | Number of iterations performed to calculate permutation importance. By default 100 iterations per 303 | feature are done. 304 | 305 | scoring (string or probatus.utils.Scorer, optional): 306 | Metric for which the model performance is calculated. It can be either a metric name aligned with 307 | predefined 308 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html). 309 | Another option is using probatus.utils.Scorer to define a custom metric. Recommended option for this 310 | class is 'roc_auc'. 311 | 312 | test_prc (float, optional): 313 | Percentage of data used to test the model. By default 0.25 is set. 314 | 315 | n_jobs (int, optional): 316 | Number of parallel executions. If -1 use all available cores. By default 1. 317 | 318 | verbose (int, optional): 319 | Controls verbosity of the output: 320 | 321 | - 0 - neither prints nor warnings are shown 322 | - 1 - only most important warnings 323 | - 2 - shows all prints and all warnings. 324 | 325 | random_state (int, optional): 326 | Random state set at each round of feature elimination. If it is None, the results will not be 327 | reproducible and in random search at each iteration a different hyperparameters might be tested. For 328 | reproducible results set it to integer. 329 | """ # noqa 330 | super().__init__( 331 | model=model, 332 | scoring=scoring, 333 | test_prc=test_prc, 334 | n_jobs=n_jobs, 335 | verbose=verbose, 336 | random_state=random_state, 337 | ) 338 | 339 | self.iterations = iterations 340 | 341 | self.iterations_columns = ["feature", "importance"] 342 | self.iterations_results = pd.DataFrame(columns=self.iterations_columns) 343 | 344 | self.plot_x_label = "Permutation Feature Importance" 345 | self.plot_y_label = "Feature Name" 346 | self.plot_title = "Permutation Feature Importance of Resemblance Model" 347 | 348 | def fit(self, X1, X2, column_names=None, class_names=None): 349 | """ 350 | This function assigns labels to each sample, 0 to the first sample, 1 to the second. 351 | 352 | Then, it randomly selects a 353 | portion of data to train on. The resulting model tries to distinguish which sample a given test row 354 | comes from. This provides insights on how distinguishable these samples are and which features contribute to 355 | that. The feature importance is calculated using permutation importance. 356 | 357 | Args: 358 | X1 (np.ndarray or pd.DataFrame): 359 | First sample to be compared. It needs to have the same number of columns as X2. 360 | 361 | X2 (np.ndarray or pd.DataFrame): 362 | Second sample to be compared. 
It needs to have the same number of columns as X1. 363 | 364 | column_names (list of str, optional): 365 | List of feature names of the provided samples. If provided, it will be used to overwrite the existing 366 | feature names. If not provided, the existing feature names are used or default feature names are 367 | generated. 368 | 369 | class_names (None, or list of str, optional): 370 | List of class names assigned to the provided samples, e.g. ['sample1', 'sample2']. If None, the 371 | defaults ['First Sample', 'Second Sample'] are used. 372 | 373 | Returns: 374 | (PermutationImportanceResemblance): 375 | Fitted object. 376 | """ 377 | super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names) 378 | 379 | permutation_result = permutation_importance( 380 | self.model, 381 | self.X_test, 382 | self.y_test, 383 | scoring=self.scorer.scorer, 384 | n_repeats=self.iterations, 385 | n_jobs=self.n_jobs, 386 | ) 387 | 388 | # Prepare report 389 | self.report_columns = ["mean_importance", "std_importance"] 390 | self.report = pd.DataFrame(index=self.column_names, columns=self.report_columns, dtype=float) 391 | 392 | for feature_index, feature_name in enumerate(self.column_names): 393 | # Fill in the report 394 | self.report.loc[feature_name, "mean_importance"] = permutation_result["importances_mean"][feature_index] 395 | self.report.loc[feature_name, "std_importance"] = permutation_result["importances_std"][feature_index] 396 | 397 | # Fill in the iterations 398 | current_iterations = pd.DataFrame( 399 | np.stack( 400 | [ 401 | np.repeat(feature_name, self.iterations), 402 | permutation_result["importances"][feature_index, :].reshape((self.iterations,)), 403 | ], 404 | axis=1, 405 | ), 406 | columns=self.iterations_columns, 407 | ) 408 | 409 | self.iterations_results = pd.concat([self.iterations_results, current_iterations]) 410 | 411 | self.iterations_results["importance"] = self.iterations_results["importance"].astype(float) 412 | 413 | # Sort by mean importance 414 | self.report.sort_values(by="mean_importance", ascending=False, inplace=True) 415 | 416 | return self 417 | 418 | def plot(self, ax=None, top_n=None, show=True, **plot_kwargs): 419 | """ 420 | Plots the resulting AUC of the model as well as the feature importances. 421 | 422 | Args: 423 | ax (matplotlib.axes, optional): 424 | Axes to which the output should be plotted. If not provided, new axes are created. 425 | 426 | top_n (int, optional): 427 | Number of the most important features to be plotted. By default all features are included in the plot. 428 | 429 | show (bool, optional): 430 | If True, the plots are shown to the user, otherwise they are not shown. Not showing a plot can be useful 431 | when you want to edit the returned axis before showing it. 432 | 433 | **plot_kwargs: 434 | Keyword arguments passed to the matplotlib.pyplot.subplots method. 435 | 436 | Returns: 437 | (matplotlib.axes): 438 | Axes that include the plot.
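            Example (a sketch, assuming the fitted `perm` object from the class-level example above;
            `figsize` is forwarded to matplotlib's subplots):

            ```python
            ax = perm.plot(top_n=10, figsize=(10, 8))
            ```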
439 |         """
440 | 
441 |         feature_report = self.compute()
442 |         self.iterations_results["importance"] = self.iterations_results["importance"].astype(float)
443 | 
444 |         sorted_features = feature_report["mean_importance"].sort_values(ascending=True).index.values
445 |         if top_n is not None and top_n > 0:
446 |             sorted_features = sorted_features[-top_n:]
447 | 
448 |         if ax is None:
449 |             fig, ax = plt.subplots(**plot_kwargs)
450 | 
451 |         for position, feature in enumerate(sorted_features):
452 |             ax.boxplot(
453 |                 self.iterations_results[self.iterations_results["feature"] == feature]["importance"],
454 |                 positions=[position],
455 |                 vert=False,
456 |             )
457 | 
458 |         ax.set_yticks(range(position + 1))
459 |         ax.set_yticklabels(sorted_features)
460 |         ax.set_xlabel(self.plot_x_label)
461 |         ax.set_ylabel(self.plot_y_label)
462 |         ax.set_title(self.plot_title)
463 | 
464 |         ax.annotate(
465 |             self.results_text,
466 |             (0, 0),
467 |             (0, -50),
468 |             fontsize=12,
469 |             xycoords="axes fraction",
470 |             textcoords="offset points",
471 |             va="top",
472 |         )
473 | 
474 |         if show:
475 |             plt.show()
476 |         else:
477 |             plt.close()
478 | 
479 |         return ax
480 | 
481 | 
482 | class SHAPImportanceResemblance(BaseResemblanceModel):
483 |     """
484 |     This model checks for similarity of two samples.
485 | 
486 |     A possible use case is analysis of whether the train sample differs
487 |     from the test sample, due to e.g. non-stationarity.
488 | 
489 |     It assigns labels to each sample, 0 to the first sample, 1 to the second. Then, it randomly selects a portion of
490 |     data to train on. The resulting model tries to distinguish which sample a given test row comes from. This
491 |     provides insight into how distinguishable these samples are and which features contribute to that. The feature
492 |     importance is calculated using SHAP feature importance.
493 | 
494 |     If the model achieves a test AUC significantly different from 0.5, it indicates that it is possible to distinguish
495 |     between the samples, and therefore, the samples differ. Features with a high SHAP importance contribute
496 |     to that effect the most. Thus, their distribution might differ between the two samples.
497 | 
498 |     This class currently works only with tree-based models.
499 | 
500 |     Examples:
501 |     ```python
502 |     from sklearn.datasets import make_classification
503 |     from sklearn.ensemble import RandomForestClassifier
504 |     from probatus.sample_similarity import SHAPImportanceResemblance
505 |     X1, _ = make_classification(n_samples=100, n_features=5)
506 |     X2, _ = make_classification(n_samples=100, n_features=5, shift=0.5)
507 |     model = RandomForestClassifier(max_depth=2)
508 |     rm = SHAPImportanceResemblance(model)
509 |     feature_importance = rm.fit_compute(X1, X2)
510 |     rm.plot()
511 |     ```
512 | 
513 | 
514 | 
515 |     """
516 | 
517 |     def __init__(
518 |         self,
519 |         model,
520 |         scoring="roc_auc",
521 |         test_prc=0.25,
522 |         n_jobs=1,
523 |         verbose=0,
524 |         random_state=None,
525 |     ):
526 |         """
527 |         Initializes the class.
528 | 
529 |         Args:
530 |             model (model object):
531 |                 Regression or classification model or pipeline.
532 | 
533 |             scoring (string or probatus.utils.Scorer, optional):
534 |                 Metric for which the model performance is calculated. It can be either a metric name aligned with the
535 |                 predefined
536 |                 [classification scorer names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html).
537 |                 Another option is using probatus.utils.Scorer to define a custom metric. The recommended option for this
538 |                 class is 'roc_auc'.
539 | 
540 |             test_prc (float, optional):
541 |                 Percentage of data used to test the model. By default 0.25 is set.
542 | 
543 |             n_jobs (int, optional):
544 |                 Number of parallel executions. If -1, all available cores are used. By default 1.
545 | 
546 |             verbose (int, optional):
547 |                 Controls verbosity of the output:
548 | 
549 |                 - 0 - neither prints nor warnings are shown
550 |                 - 1 - only most important warnings
551 |                 - 2 - shows all prints and all warnings.
552 | 
553 |             random_state (int, optional):
554 |                 Random state set for the train/test split and the SHAP value calculation. If it is None,
555 |                 the results will not be reproducible. For
556 |                 reproducible results set it to an integer.
557 |         """  # noqa
558 |         super().__init__(
559 |             model=model,
560 |             scoring=scoring,
561 |             test_prc=test_prc,
562 |             n_jobs=n_jobs,
563 |             verbose=verbose,
564 |             random_state=random_state,
565 |         )
566 | 
567 |         self.plot_title = "SHAP summary plot"
568 | 
569 |     def fit(self, X1, X2, column_names=None, class_names=None, **shap_kwargs):
570 |         """
571 |         This function assigns labels to each sample, 0 to the first sample, 1 to the second.
572 | 
573 |         Then, it randomly selects a
574 |         portion of data to train on. The resulting model tries to distinguish which sample a given test row
575 |         comes from. This provides insight into how distinguishable these samples are and which features contribute to
576 |         that. The feature importance is calculated using SHAP feature importance.
577 | 
578 |         Args:
579 |             X1 (np.ndarray or pd.DataFrame):
580 |                 First sample to be compared. It needs to have the same number of columns as X2.
581 | 
582 |             X2 (np.ndarray or pd.DataFrame):
583 |                 Second sample to be compared. It needs to have the same number of columns as X1.
584 | 
585 |             column_names (list of str, optional):
586 |                 List of feature names of the provided samples. If provided it will be used to overwrite the existing
587 |                 feature names. If not provided the existing feature names are used or default feature names are
588 |                 generated.
589 | 
590 |             class_names (None, or list of str, optional):
591 |                 List of class names assigned to the provided samples, e.g. ['sample1', 'sample2']. If None, the
592 |                 defaults ['First Sample', 'Second Sample'] are used.
593 | 
594 |             **shap_kwargs:
595 |                 keyword arguments passed to
596 |                 [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
597 |                 It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
598 |                 Setting `approximate=True` makes the SHAP value calculation less accurate but faster, while
599 |                 `check_additivity=False` disables the additivity check inside SHAP.
600 | 
601 |         Returns:
602 |             (SHAPImportanceResemblance):
603 |                 Fitted object.
604 |         """
605 |         super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names)
606 | 
607 |         self.shap_values_test = shap_calc(
608 |             self.model, self.X_test, verbose=self.verbose, random_state=self.random_state, **shap_kwargs
609 |         )
610 |         self.report = calculate_shap_importance(self.shap_values_test, self.column_names)
611 |         return self
612 | 
613 |     def plot(self, plot_type="bar", show=True, **summary_plot_kwargs):
614 |         """
615 |         Plots the resulting AUC of the model as well as the feature importances.
616 | 
617 |         Args:
618 |             plot_type (str, optional): Type of plot, used to compute shap.summary_plot. By default 'bar'; available ones
619 |                 are "dot", "bar", and "violin".
620 | 
621 |             show (bool, optional):
622 |                 If True, the plots are shown to the user, otherwise they are not shown. Not showing a plot can be useful
623 |                 when you want to edit the returned axis before showing it.
624 | 
625 |             **summary_plot_kwargs:
626 |                 kwargs passed to the shap.summary_plot.
627 | 
628 |         Returns:
629 |             (matplotlib.axes):
630 |                 Axes that include the plot.
631 |         """
632 | 
633 |         # This line serves as a double check if the object has been fitted
634 |         self._check_if_fitted()
635 | 
636 |         summary_plot(
637 |             self.shap_values_test,
638 |             self.X_test,
639 |             plot_type=plot_type,
640 |             class_names=self.class_names,
641 |             show=False,
642 |             **summary_plot_kwargs,
643 |         )
644 |         ax = plt.gca()
645 |         ax.set_title(self.plot_title)
646 | 
647 |         ax.annotate(
648 |             self.results_text,
649 |             (0, 0),
650 |             (0, -50),
651 |             fontsize=12,
652 |             xycoords="axes fraction",
653 |             textcoords="offset points",
654 |             va="top",
655 |         )
656 | 
657 |         if show:
658 |             plt.show()
659 |         else:
660 |             plt.close()
661 | 
662 |         return ax
663 | 
664 |     def get_shap_values(self):
665 |         """
666 |         Gets the SHAP values generated on the test set.
667 | 
668 |         Returns:
669 |             (np.array):
670 |                 SHAP values generated on the test set.
671 |         """
672 |         self._check_if_fitted()
673 |         return self.shap_values_test
674 | 
-------------------------------------------------------------------------------- /probatus/utils/__init__.py: --------------------------------------------------------------------------------
1 | from .exceptions import NotFittedError
2 | from .arrayfuncs import (
3 |     assure_pandas_df,
4 |     assure_pandas_series,
5 |     preprocess_data,
6 |     preprocess_labels,
7 | )
8 | from .scoring import Scorer, get_single_scorer
9 | from .shap_helpers import shap_calc, shap_to_df, calculate_shap_importance
10 | from ._utils import assure_list_of_strings
11 | from .base_class_interface import BaseFitComputeClass, BaseFitComputePlotClass
12 | 
13 | __all__ = [
14 |     "assure_list_of_strings",
15 |     "assure_pandas_df",
16 |     "assure_pandas_series",
17 |     "preprocess_data",
18 |     "preprocess_labels",
19 |     "BaseFitComputeClass",
20 |     "BaseFitComputePlotClass",
21 |     "NotFittedError",
22 |     "get_single_scorer",
23 |     "Scorer",
24 |     "shap_calc",
25 |     "shap_to_df",
26 |     "calculate_shap_importance",
27 | ]
-------------------------------------------------------------------------------- /probatus/utils/_utils.py: --------------------------------------------------------------------------------
1 | def assure_list_of_strings(variable, variable_name):
2 |     """
3 |     Make sure object is a list of strings.
4 |     """
5 |     if isinstance(variable, list):
6 |         return variable
7 |     elif isinstance(variable, str):
8 |         return [variable]
9 |     else:
10 |         raise ValueError("{} needs to be either a string or a list of strings.".format(variable_name))
11 | 
-------------------------------------------------------------------------------- /probatus/utils/arrayfuncs.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | 
6 | 
7 | def assure_pandas_df(x, column_names=None):
8 |     """
9 |     Returns x as pandas DataFrame. X can be a list, list of lists, numpy array, pandas DataFrame or pandas Series.
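    Example (an illustrative sketch):
    ```python
    import numpy as np
    from probatus.utils import assure_pandas_df
    # A 2D numpy array becomes a DataFrame with the given column names
    df = assure_pandas_df(np.array([[1, 2], [3, 4]]), column_names=["a", "b"])
    ```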
10 | 
11 |     Args:
12 |         x (list, numpy array, pandas DataFrame, pandas Series): array to be tested
13 | 
14 |     Returns:
15 |         pandas DataFrame
16 |     """
17 |     if isinstance(x, pd.DataFrame):
18 |         if column_names is not None:
19 |             x.columns = column_names
20 |     elif isinstance(x, (np.ndarray, pd.Series, list)):
21 |         x = pd.DataFrame(x, columns=column_names)
22 |     else:
23 |         raise TypeError("Please supply a list, numpy array, pandas Series or pandas DataFrame")
24 | 
25 |     return x
26 | 
27 | 
28 | def assure_pandas_series(x, index=None):
29 |     """
30 |     Returns x as pandas Series. X can be a list, numpy array, or pandas Series.
31 | 
32 |     Args:
33 |         x (list, numpy array, pandas DataFrame, pandas Series): array to be tested
34 | 
35 |     Returns:
36 |         pandas Series
37 |     """
38 |     if isinstance(x, pd.Series):
39 |         if isinstance(index, (list, np.ndarray)):
40 |             index = pd.Index(index)
41 |         current_x_index = pd.Index(x.index.values)
42 |         if current_x_index.equals(index):
43 |             # If exact match then keep it as it is
44 |             return x
45 |         elif current_x_index.sort_values().equals(index.sort_values()):
46 |             # If both have the same values but in different order, then reorder
47 |             return x[index]
48 |         else:
49 |             # If indexes have different values, overwrite
50 |             x.index = index
51 |             return x
52 |     elif isinstance(x, (np.ndarray, list)):
53 |         return pd.Series(x, index=index)
54 |     else:
55 |         raise TypeError("Please supply a list, numpy array, or pandas Series")
56 | 
57 | 
58 | def preprocess_data(X, X_name=None, column_names=None, verbose=0):
59 |     """
60 |     Preprocess data.
61 | 
62 |     Does basic preprocessing of the data: transforms it to a DataFrame, warns which features have missing values,
63 |     and transforms object dtype features to category type, such that LightGBM handles them by default.
64 | 
65 |     Args:
66 |         X (pd.DataFrame, list of lists, np.array):
67 |             Provided dataset.
68 | 
69 |         X_name (str, optional):
70 |             Name of the X variable, that will be printed in the warnings.
71 | 
72 |         column_names (list of str, optional):
73 |             List of feature names of the provided samples. If provided it will be used to overwrite the existing
74 |             feature names. If not provided the existing feature names are used or default feature names are
75 |             generated.
76 | 
77 |         verbose (int, optional):
78 |             Controls verbosity of the output:
79 | 
80 |             - 0 - neither prints nor warnings are shown
81 |             - 1 - only most important warnings
82 |             - 2 - shows all prints and all warnings.
83 | 
84 | 
85 |     Returns:
86 |         (pd.DataFrame, list of str):
87 |             Preprocessed dataset and the list of its column names.
88 |     """
89 |     X_name = "X" if X_name is None else X_name
90 | 
91 |     # Make sure that X is a pd.DataFrame with correct column names
92 |     X = assure_pandas_df(X, column_names=column_names)
93 | 
94 |     if verbose > 0:
95 |         # Warn if missing
96 |         columns_with_missing = X.columns[X.isnull().any()].tolist()
97 |         if columns_with_missing:
98 |             warnings.warn(
99 |                 f"The following variables in {X_name} contain missing values {columns_with_missing}. "
100 |                 f"Make sure to impute missing values or apply a model that handles them automatically."
101 |             )
102 | 
103 |     # Warn if categorical features and change to category
104 |     categorical_features = X.select_dtypes(include=["category", "object"]).columns.tolist()
105 |     # Set categorical features type to category
106 |     if categorical_features:
107 |         if verbose > 0:
108 |             warnings.warn(
109 |                 f"The following variables in {X_name} contain categorical variables: "
110 |                 f"{categorical_features}. Make sure to use a model that handles them automatically or "
111 |                 f"encode them into numerical variables."
112 |             )
113 | 
114 |     # Ensure category dtype, so that models such as LightGBM can handle them automatically
115 |     object_columns = X.select_dtypes(include=["object"]).columns
116 |     if not object_columns.empty:
117 |         X[object_columns] = X[object_columns].astype("category")
118 | 
119 |     return X, X.columns.tolist()
120 | 
121 | 
122 | def preprocess_labels(y, y_name=None, index=None, verbose=0):
123 |     """
124 |     Does basic preparation of the labels: turns them into a pd.Series with the correct index.
125 | 
126 |     Args:
127 |         y (pd.Series, list, np.array):
128 |             Provided labels.
129 | 
130 |         y_name (str, optional):
131 |             Name of the y variable, that will be printed in the warnings.
132 | 
133 |         index (list of int or pd.Index, optional):
134 |             The correct index that should be used for y. In case y is a list or np.array, the index is set when
135 |             creating pd.Series. In case it is pd.Series already, if the indexes consist of the same values, the y is
136 |             going to be ordered based on the provided index; otherwise, the current index of y is overwritten by the index
137 |             argument.
138 | 
139 |         verbose (int, optional):
140 |             Controls verbosity of the output:
141 | 
142 |             - 0 - neither prints nor warnings are shown
143 |             - 1 - only most important warnings
144 |             - 2 - shows all prints and all warnings.
145 | 
146 |     Returns:
147 |         (pd.Series):
148 |             Labels in the form of pd.Series.
149 |     """
150 |     y_name = "y" if y_name is None else y_name
151 | 
152 |     # Make sure that y is a series with correct index
153 |     y = assure_pandas_series(y, index=index)
154 | 
155 |     return y
156 | 
-------------------------------------------------------------------------------- /probatus/utils/base_class_interface.py: --------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | from probatus.utils import NotFittedError
4 | 
5 | 
6 | class BaseFitComputeClass(ABC):
7 |     """
8 |     Abstract base class for objects that implement fit, compute, and fit_compute.
9 |     """
10 | 
11 |     fitted = False
12 | 
13 |     def _check_if_fitted(self):
14 |         """
15 |         Checks if object has been fitted. If not, NotFittedError is raised.
16 |         """
17 |         if not self.fitted:
18 |             raise NotFittedError("The object has not been fitted. Please run the fit() method first.")
19 | 
20 |     @abstractmethod
21 |     def fit(self, *args, **kwargs):
22 |         """
23 |         Placeholder that must be overwritten by subclass.
24 |         """
25 |         pass
26 | 
27 |     @abstractmethod
28 |     def compute(self, *args, **kwargs):
29 |         """
30 |         Placeholder that must be overwritten by subclass.
31 |         """
32 |         pass
33 | 
34 |     @abstractmethod
35 |     def fit_compute(self, *args, **kwargs):
36 |         """
37 |         Placeholder that must be overwritten by subclass.
38 |         """
39 |         pass
40 | 
41 | 
42 | class BaseFitComputePlotClass(BaseFitComputeClass):
43 |     """
44 |     Abstract base class that additionally requires a plot method.
45 |     """
46 | 
47 |     @abstractmethod
48 |     def plot(self, *args, **kwargs):
49 |         """
50 |         Placeholder method for plotting.
51 |         """
52 |         pass
53 | 
-------------------------------------------------------------------------------- /probatus/utils/exceptions.py: --------------------------------------------------------------------------------
1 | class NotFittedError(Exception):
2 |     """
3 |     Error raised when an object is used before being fitted.
4 |     """
5 | 
6 |     def __init__(self, message):
7 |         """
8 |         Initializes the error with a message.
9 | """ 10 | self.message = message 11 | -------------------------------------------------------------------------------- /probatus/utils/scoring.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import get_scorer 2 | 3 | 4 | def get_single_scorer(scoring): 5 | """ 6 | Returns Scorer, based on provided input in scoring argument. 7 | 8 | Args: 9 | scoring (string or probatus.utils.Scorer, optional): 10 | Metric for which the model performance is calculated. It can be either a metric name aligned with 11 | predefined classification scorers names in sklearn 12 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)). 13 | Another option is using probatus.utils.Scorer to define a custom metric. 14 | 15 | Returns: 16 | (probatus.utils.Scorer): 17 | Scorer that can be used for scoring models 18 | """ 19 | if isinstance(scoring, str): 20 | return Scorer(scoring) 21 | elif isinstance(scoring, Scorer): 22 | return scoring 23 | else: 24 | raise (ValueError("The scoring should contain either strings or probatus.utils.Scorer class")) 25 | 26 | 27 | class Scorer: 28 | """ 29 | Scores a given machine learning model based on the provided metric name and optionally a custom scoring function. 30 | 31 | Examples: 32 | 33 | ```python 34 | from probatus.utils import Scorer 35 | from sklearn.metrics import make_scorer 36 | from sklearn.datasets import make_classification 37 | from sklearn.model_selection import train_test_split 38 | from sklearn.ensemble import RandomForestClassifier 39 | import pandas as pd 40 | 41 | # Make ROC AUC scorer 42 | scorer1 = Scorer('roc_auc') 43 | 44 | # Make custom scorer with following function: 45 | def custom_metric(y_true, y_pred): 46 | return (y_true == y_pred).sum() 47 | scorer2 = Scorer('custom_metric', custom_scorer=make_scorer(custom_metric)) 48 | 49 | # Prepare two samples 50 | feature_names = ['f1', 'f2', 'f3', 'f4'] 51 | X, y = make_classification(n_samples=1000, n_features=4, random_state=0) 52 | X = pd.DataFrame(X, columns=feature_names) 53 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) 54 | 55 | # Prepare and fit model. Remember about class_weight="balanced" or an equivalent. 56 | model = RandomForestClassifier(class_weight='balanced', n_estimators = 100, max_depth=2, random_state=0) 57 | model = model.fit(X_train, y_train) 58 | 59 | # Score model 60 | score_test_scorer1 = scorer1.score(model, X_test, y_test) 61 | score_test_scorer2 = scorer2.score(model, X_test, y_test) 62 | 63 | print(f'Test ROC AUC is {score_test_scorer1}, Test {scorer2.metric_name} is {score_test_scorer2}') 64 | ``` 65 | """ 66 | 67 | def __init__(self, metric_name, custom_scorer=None): 68 | """ 69 | Initializes the class. 70 | 71 | Args: 72 | metric_name (str): Name of the metric used to evaluate the model. 73 | If the custom_scorer is not passed, the 74 | metric name needs to be aligned with classification scorers names in sklearn 75 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)). 76 | custom_scorer (sklearn.metrics Scorer callable, optional): Callable 77 | that can score samples. 78 | """ 79 | self.metric_name = metric_name 80 | if custom_scorer is not None: 81 | self.scorer = custom_scorer 82 | else: 83 | self.scorer = get_scorer(self.metric_name) 84 | 85 | def score(self, model, X, y): 86 | """ 87 | Scores the samples model based on the provided metric name. 88 | 89 | Args 90 | model (model object): 91 | Model to be scored. 
92 | 
93 |             X (array-like of shape (n_samples,n_features)):
94 |                 Samples on which the model is scored.
95 | 
96 |             y (array-like of shape (n_samples,)):
97 |                 Labels on which the model is scored.
98 | 
99 |         Returns:
100 |             (float):
101 |                 Score returned by the model.
102 |         """
103 |         return self.scorer(model, X, y)
104 | 
-------------------------------------------------------------------------------- /probatus/utils/shap_helpers.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | from shap import Explainer
6 | from shap.explainers import TreeExplainer
7 | from shap.utils import sample
8 | from sklearn.pipeline import Pipeline
9 | 
10 | 
11 | def shap_calc(
12 |     model,
13 |     X,
14 |     return_explainer=False,
15 |     verbose=0,
16 |     random_state=None,
17 |     sample_size=100,
18 |     approximate=False,
19 |     check_additivity=True,
20 |     **shap_kwargs,
21 | ):
22 |     """
23 |     Helper function to calculate the shapley values for a given model.
24 | 
25 |     Args:
26 |         model (model):
27 |             Trained model.
28 | 
29 |         X (pd.DataFrame or np.ndarray):
30 |             Feature set.
31 | 
32 |         return_explainer (boolean):
33 |             if True, returns a tuple (shap_values, explainer).
34 | 
35 |         verbose (int, optional):
36 |             Controls verbosity of the output:
37 | 
38 |             - 0 - neither prints nor warnings are shown
39 |             - 1 - only most important warnings
40 |             - 2 - shows all prints and all warnings.
41 | 
42 |         random_state (int, optional):
43 |             Random state used for sampling the background data. If it is None, the results will not be reproducible. For
44 |             reproducible results set it to an integer.
45 | 
46 |         approximate (boolean):
47 |             if True uses shap approximations - less accurate, but very fast. It applies to tree-based explainers only.
48 | 
49 |         check_additivity (boolean):
50 |             if False SHAP will disable the additivity check for tree-based models.
51 | 
52 |         **shap_kwargs: kwargs of the shap.Explainer
53 | 
54 |     Returns:
55 |         (np.ndarray or tuple(np.ndarray, shap.Explainer)):
56 |             shapley_values for the model, optionally also returns the explainer.
57 | 
58 |     """
59 |     if isinstance(model, Pipeline):
60 |         raise TypeError(
61 |             "The provided model is a Pipeline. Unfortunately, the features based on SHAP do not support "
62 |             "pipelines, because they cannot be used in combination with shap.Explainer. Please apply any "
63 |             "data transformations before running the probatus module."
64 |         )
65 | 
66 |     # Suppress warnings regarding XGBoost and LightGBM models.
67 |     with warnings.catch_warnings():
68 |         warnings.simplefilter("ignore" if verbose <= 1 else "default")
69 | 
70 |         # For tree explainers, do not pass masker when feature_perturbation is
71 |         # tree_path_dependent, or when X contains categorical features
72 |         # related to issue:
73 |         # https://github.com/slundberg/shap/issues/480
74 |         if shap_kwargs.get("feature_perturbation") == "tree_path_dependent" or X.select_dtypes("category").shape[1] > 0:
75 |             # Calculate Shap values.
76 |             explainer = Explainer(model, seed=random_state, **shap_kwargs)
77 |         else:
78 |             # Create the background data, required for non-tree-based models.
79 |             # A single datapoint can be passed as the mask
80 |             # (https://github.com/slundberg/shap/issues/955#issuecomment-569837201)
81 |             if X.shape[0] < sample_size:
82 |                 # If there are fewer rows than sample_size, use 20% of the data as background
83 |                 sample_size = int(np.ceil(X.shape[0] * 0.2))
84 | 
85 |             mask = sample(X, sample_size, random_state=random_state)
86 |             explainer = Explainer(model, seed=random_state, masker=mask, **shap_kwargs)
87 | 
88 |         # For tree-explainers allow for using check_additivity and approximate arguments
89 |         if isinstance(explainer, TreeExplainer):
90 |             shap_values = explainer.shap_values(X, check_additivity=check_additivity, approximate=approximate)
91 | 
92 |             # From SHAP version 0.43+ https://github.com/shap/shap/pull/3121 required to
93 |             # get the second dimension of calculated Shap values.
94 |             if not isinstance(shap_values, list) and len(shap_values.shape) == 3:
95 |                 shap_values = shap_values[:, :, 1]
96 |         else:
97 |             # Calculate Shap values
98 |             shap_values = explainer.shap_values(X)
99 | 
100 |         if isinstance(shap_values, list) and len(shap_values) == 2:
101 |             warnings.warn(
102 |                 "Shap values are related to the output probabilities of class 1 for this model, instead of log odds."
103 |             )
104 |             shap_values = shap_values[1]
105 | 
106 |     if return_explainer:
107 |         return shap_values, explainer
108 |     return shap_values
109 | 
110 | 
111 | def shap_to_df(model, X, precalc_shap=None, **kwargs):
112 |     """
113 |     Calculates the shap values and returns a pandas DataFrame with the columns and the index of the original.
114 | 
115 |     Args:
116 |         model (model):
117 |             Pretrained model (Random Forest or XGBoost at the moment).
118 | 
119 |         X (pd.DataFrame or np.ndarray):
120 |             Dataset on which the SHAP importance is calculated.
121 | 
122 |         precalc_shap (np.array):
123 |             Precalculated SHAP values. If None, they are computed.
124 | 
125 |         **kwargs: for the function shap_calc
126 | 
127 |     Returns:
128 |         (pd.DataFrame):
129 |             Dataframe with SHAP feature importance per feature on the X dataset.
130 |     """
131 |     shap_values = precalc_shap if precalc_shap is not None else shap_calc(model, X, **kwargs)
132 | 
133 |     try:
134 |         return pd.DataFrame(shap_values, columns=X.columns, index=X.index)
135 |     except AttributeError:
136 |         if isinstance(X, np.ndarray) and len(X.shape) == 2:
137 |             return pd.DataFrame(shap_values, columns=[f"col_{ix}" for ix in range(X.shape[1])])
138 |         else:
139 |             raise TypeError("X must be a dataframe or a 2d array")
140 | 
141 | 
142 | def calculate_shap_importance(shap_values, columns, output_columns_suffix="", shap_variance_penalty_factor=None):
143 |     """
144 |     Returns the average shapley value for each column of the dataframe, as well as the average absolute shap value.
145 | 
146 |     Args:
147 |         shap_values (np.array):
148 |             Shap values.
149 | 
150 |         columns (list of str):
151 |             Feature names.
152 | 
153 |         output_columns_suffix (str, optional):
154 |             Suffix to be added at the end of column names in the output.
155 | 
156 |         shap_variance_penalty_factor (int or float, optional):
157 |             Apply aggregation penalty when computing average of shap values for a given feature.
158 |             Results in a preference for features that have smaller standard deviation of shap
159 |             values (more coherent shap importance). Recommended value: 0.5 - 1.0.
160 |             Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor))
161 | 
162 |     Returns:
163 |         (pd.DataFrame):
164 |             Mean absolute shap values and mean shap values of features.
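    Example:
        An illustrative sketch with made-up shap values (two features, three rows), showing how the
        variance penalty prefers features with coherent shap values:
        ```python
        import numpy as np
        from probatus.utils import calculate_shap_importance
        shap_values = np.array([[0.1, 0.3], [0.2, -0.3], [0.3, 0.3]])
        report = calculate_shap_importance(shap_values, columns=["f1", "f2"], shap_variance_penalty_factor=0.5)
        # Ranking is done on the penalized mean |shap| (that column is dropped from the output):
        # f1: mean |shap| = 0.2, std(|shap|) ~ 0.082 -> penalized ~ 0.159
        # f2: mean |shap| = 0.3, std(|shap|) = 0.0  -> penalized 0.3, so f2 ranks first
        ```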
165 | 
166 |     """
167 |     if shap_variance_penalty_factor is None:
168 |         shap_variance_penalty_factor = 0
169 |     elif not isinstance(shap_variance_penalty_factor, (float, int)) or shap_variance_penalty_factor < 0:
170 |         warnings.warn(
171 |             "shap_variance_penalty_factor must be None or a non-negative int or float. Setting it to 0"
172 |         )
173 |         shap_variance_penalty_factor = 0
174 | 
175 |     abs_shap_values = np.abs(shap_values)
176 |     if np.ndim(shap_values) > 2:  # multi-class case
177 |         sum_abs_shap = np.sum(abs_shap_values, axis=0)
178 |         sum_shap = np.sum(shap_values, axis=0)
179 |         shap_abs_mean = np.mean(sum_abs_shap, axis=0)
180 |         shap_mean = np.mean(sum_shap, axis=0)
181 |         penalized_shap_abs_mean = shap_abs_mean - (np.std(sum_abs_shap, axis=0) * shap_variance_penalty_factor)
182 |     else:
183 |         # Find average shap importance for neg and pos class
184 |         shap_abs_mean = np.mean(abs_shap_values, axis=0)
185 |         shap_mean = np.mean(shap_values, axis=0)
186 |         penalized_shap_abs_mean = shap_abs_mean - (np.std(abs_shap_values, axis=0) * shap_variance_penalty_factor)
187 | 
188 |     # Prepare the values in a df and set the correct column types
189 |     importance_df = pd.DataFrame(
190 |         {
191 |             f"mean_abs_shap_value{output_columns_suffix}": shap_abs_mean,
192 |             f"mean_shap_value{output_columns_suffix}": shap_mean,
193 |             f"penalized_mean_abs_shap_value{output_columns_suffix}": penalized_shap_abs_mean,
194 |         },
195 |         index=columns,
196 |     ).astype(float)
197 | 
198 |     importance_df = importance_df.sort_values(f"penalized_mean_abs_shap_value{output_columns_suffix}", ascending=False)
199 | 
200 |     # Drop penalized column
201 |     importance_df = importance_df.drop(columns=[f"penalized_mean_abs_shap_value{output_columns_suffix}"])
202 | 
203 |     return importance_df
204 | 
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "probatus"
7 | version = "3.1.3"
8 | requires-python= ">=3.9"
9 | description = "Validation of regression & classification models and the data used to develop them"
10 | readme = { file = "README.md", content-type = "text/markdown" }
11 | authors = [
12 |     { name = "ING Bank N.V.", email = "reinier.koops@ing.com" }
13 | ]
14 | license = { file = "LICENCE" }
15 | classifiers = [
16 |     "Intended Audience :: Developers",
17 |     "Intended Audience :: Science/Research",
18 |     "Programming Language :: Python :: 3",
19 |     "Programming Language :: Python :: 3.9",
20 |     "Programming Language :: Python :: 3.10",
21 |     "Programming Language :: Python :: 3.11",
22 |     "Programming Language :: Python :: 3.12",
23 |     "Topic :: Scientific/Engineering",
24 |     "Topic :: Scientific/Engineering :: Artificial Intelligence",
25 |     "License :: OSI Approved :: MIT License",
26 |     "Operating System :: OS Independent",
27 | ]
28 | dependencies = [
29 |     "scikit-learn>=0.22.2",
30 |     "pandas>=1.0.0",
31 |     "matplotlib>=3.1.1",
32 |     "joblib>=0.13.2",
33 |     "shap>=0.43.0",
34 |     "numpy>=1.23.2,<2.0.0",
35 |     "numba>=0.57.0",
36 |     "loguru>=0.7.2",
37 | ]
38 | 
39 | [project.urls]
40 | Homepage = "https://ing-bank.github.io/probatus/"
41 | Documentation = "https://ing-bank.github.io/probatus/api/feature_elimination.html"
42 | Repository = "https://github.com/ing-bank/probatus.git"
43 | Changelog = "https://github.com/ing-bank/probatus/blob/main/CHANGELOG.md"
44 | 
45 | [project.optional-dependencies]
46 | 
47 | 
48 | dev = [
49 |     "black>=19.10b0",
50 |     "mypy>=0.770",
51 |     "pytest>=6.0.0",
52 |     "pytest-cov>=2.10.0",
53 |     "pyflakes",
54 |     "joblib>=0.13.2",
55 |     "jupyter>=1.0.0",
56 |     "nbconvert>=6.0.7",
57 |     "pre-commit>=2.7.1",
58 |     "isort>=5.12.0",
59 |     "codespell>=2.2.4",
60 |     "ruff>=0.2.2",
61 |     "lightgbm>=3.3.0",
62 |     "catboost>=1.2",
63 |     "xgboost>=1.5.0",
64 | ]
65 | docs = [
66 |     "mkdocs>=1.5.3",
67 |     "mkdocs-jupyter>=0.24.3",
68 |     "mkdocs-material>=9.5.13",
69 |     "mkdocstrings>=0.24.1",
70 |     "mkdocstrings-python>=1.8.0",
71 | ]
72 | 
73 | # Separating these allows for more install flexibility.
74 | all = ["probatus[dev,docs]"]
75 | 
76 | [tool.setuptools.packages.find]
77 | exclude = ["tests", "notebooks", "docs"]
78 | 
79 | [tool.nbqa.addopts]
80 | # E402: Ignores imports not at the top of file for IPYNB since this makes copy-pasting easier.
81 | ruff = ["--fix", "--ignore=E402"]
82 | isort = ["--profile=black"]
83 | black = ["--line-length=120"]
84 | 
85 | [tool.mypy]
86 | python_version = "3.9"
87 | ignore_missing_imports = true
88 | namespace_packages = true
89 | pretty = true
90 | 
91 | [tool.ruff]
92 | line-length = 120
93 | extend-exclude = ["docs", "mkdocs.yml", ".github", "*md", "LICENCE", ".pre-commit-config.yaml", ".gitignore"]
94 | force-exclude = true
95 | 
96 | [tool.ruff.lint]
97 | # D100 requires all Python files (modules) to have a "public" docstring even if all functions within have a docstring.
98 | # D104 requires __init__ files to have a docstring
99 | # D202 No blank lines allowed after function docstring
100 | # D212
101 | # D200
102 | # D411 Missing blank line before section
103 | # D412 No blank lines allowed between a section header and its content
104 | # D417 Missing argument descriptions in the docstring # Only ignored because of false positive when using multiline args.
105 | # E203 106 | # E731 do not assign a lambda expression, use a def 107 | # W293 blank line contains whitespace 108 | ignore = ["D100", "D104", "D202", "D212", "D200", "E203", "E731", "W293", "D412", "D417", "D411", "RUF100"] 109 | 110 | [tool.ruff.lint.pydocstyle] 111 | convention = "google" 112 | 113 | [tool.isort] 114 | line_length = "120" 115 | profile = "black" 116 | filter_files = true 117 | extend_skip = ["__init__.py", "docs"] 118 | 119 | [tool.codespell] 120 | skip = "**.egg-info*" 121 | 122 | [tool.black] 123 | line-length = 120 124 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from sklearn.datasets import make_classification 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.tree import DecisionTreeClassifier 9 | from catboost import CatBoostClassifier 10 | from lightgbm import LGBMClassifier 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.model_selection import RandomizedSearchCV 13 | 14 | 15 | @pytest.fixture(scope="function") 16 | def random_state(): 17 | RANDOM_STATE = 0 18 | 19 | return RANDOM_STATE 20 | 21 | 22 | @pytest.fixture(scope="function") 23 | def random_state_42(): 24 | RANDOM_STATE = 42 25 | 26 | return RANDOM_STATE 27 | 28 | 29 | @pytest.fixture(scope="function") 30 | def random_state_1234(): 31 | RANDOM_STATE = 1234 32 | 33 | return RANDOM_STATE 34 | 35 | 36 | @pytest.fixture(scope="function") 37 | def random_state_1(): 38 | RANDOM_STATE = 1 39 | 40 | return RANDOM_STATE 41 | 42 | 43 | @pytest.fixture(scope="function") 44 | def mock_model(): 45 | return Mock() 46 | 47 | 48 | @pytest.fixture(scope="function") 49 | def complex_data(random_state): 50 | feature_names = ["f1_categorical", "f2_missing", "f3_static", "f4", "f5"] 51 | 52 | # Prepare two samples 53 | X, y = make_classification( 54 | n_samples=50, 55 | class_sep=0.05, 56 | n_informative=2, 57 | n_features=5, 58 | random_state=random_state, 59 | n_redundant=2, 60 | n_clusters_per_class=1, 61 | ) 62 | X = pd.DataFrame(X, columns=feature_names) 63 | X.loc[0:10, "f2_missing"] = np.nan 64 | return X, y 65 | 66 | 67 | @pytest.fixture(scope="function") 68 | def complex_data_with_categorical(complex_data): 69 | X, y = complex_data 70 | X["f1_categorical"] = X["f1_categorical"].astype(str).astype("category") 71 | 72 | return X, y 73 | 74 | 75 | @pytest.fixture(scope="function") 76 | def complex_data_split(complex_data, random_state_42): 77 | X, y = complex_data 78 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=random_state_42) 79 | return X_train, X_test, y_train, y_test 80 | 81 | 82 | @pytest.fixture(scope="function") 83 | def complex_data_split_with_categorical(complex_data_split): 84 | X_train, X_test, y_train, y_test = complex_data_split 85 | X_train["f1_categorical"] = X_train["f1_categorical"].astype(str).astype("category") 86 | X_test["f1_categorical"] = X_test["f1_categorical"].astype(str).astype("category") 87 | 88 | return X_train, X_test, y_train, y_test 89 | 90 | 91 | 
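# Illustrative note (not part of the original conftest): any test module in this suite can request
# the fixtures above by name, without importing conftest, e.g. a hypothetical test:
#
# def test_example(complex_data, random_state):
#     X, y = complex_data
#     assert X.shape == (50, 5) and random_state == 0
#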
@pytest.fixture(scope="function") 92 | def complex_lightgbm(random_state_42): 93 | model = LGBMClassifier(max_depth=5, num_leaves=11, class_weight="balanced", random_state=random_state_42) 94 | return model 95 | 96 | 97 | @pytest.fixture(scope="function") 98 | def complex_fitted_lightgbm(complex_data_split_with_categorical, complex_lightgbm): 99 | X_train, _, y_train, _ = complex_data_split_with_categorical 100 | 101 | return complex_lightgbm.fit(X_train, y_train) 102 | 103 | 104 | @pytest.fixture(scope="function") 105 | def catboost_classifier(random_state): 106 | model = CatBoostClassifier(random_seed=random_state) 107 | return model 108 | 109 | 110 | @pytest.fixture(scope="function") 111 | def decision_tree_classifier(random_state): 112 | model = DecisionTreeClassifier(max_depth=1, random_state=random_state) 113 | return model 114 | 115 | 116 | @pytest.fixture(scope="function") 117 | def randomized_search_decision_tree_classifier(decision_tree_classifier, random_state): 118 | param_grid = {"criterion": ["gini"], "min_samples_split": [1, 2]} 119 | cv = RandomizedSearchCV(decision_tree_classifier, param_grid, cv=2, n_iter=2, random_state=random_state) 120 | return cv 121 | 122 | 123 | @pytest.fixture(scope="function") 124 | def logistic_regression(random_state): 125 | model = LogisticRegression(random_state=random_state) 126 | return model 127 | 128 | 129 | @pytest.fixture(scope="function") 130 | def X_train(): 131 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}, index=[1, 2, 3, 4]) 132 | 133 | 134 | @pytest.fixture(scope="function") 135 | def y_train(): 136 | return pd.Series([1, 0, 1, 0], index=[1, 2, 3, 4]) 137 | 138 | 139 | @pytest.fixture(scope="function") 140 | def X_test(): 141 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}, index=[5, 6, 7, 8]) 142 | 143 | 144 | @pytest.fixture(scope="function") 145 | def y_test(): 146 | return pd.Series([0, 0, 1, 0], index=[5, 6, 7, 8]) 147 | 148 | 149 | @pytest.fixture(scope="function") 150 | def fitted_logistic_regression(X_train, y_train, logistic_regression): 151 | return logistic_regression.fit(X_train, y_train) 152 | 153 | 154 | @pytest.fixture(scope="function") 155 | def fitted_tree(X_train, y_train, decision_tree_classifier): 156 | return decision_tree_classifier.fit(X_train, y_train) 157 | -------------------------------------------------------------------------------- /tests/docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/docs/__init__.py -------------------------------------------------------------------------------- /tests/docs/test_docstring.py: -------------------------------------------------------------------------------- 1 | # This approach is adapted from, and explained in: https://calmcode.io/docs/epic.html 2 | 3 | import os 4 | from typing import List 5 | 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | import pytest 9 | 10 | import probatus.feature_elimination 11 | import probatus.interpret 12 | import probatus.sample_similarity 13 | import probatus.utils 14 | 15 | # Turn off interactive mode in plots 16 | plt.ioff() 17 | matplotlib.use("Agg") 18 | 19 | CLASSES_TO_TEST = [ 20 | probatus.feature_elimination.ShapRFECV, 21 | probatus.interpret.DependencePlotter, 22 | probatus.sample_similarity.SHAPImportanceResemblance, 23 | 
probatus.sample_similarity.PermutationImportanceResemblance,
24 |     probatus.utils.Scorer,
25 | ]
26 | 
27 | CLASSES_TO_TEST_LGBM = [
28 |     probatus.feature_elimination.EarlyStoppingShapRFECV,
29 | ]
30 | 
31 | FUNCTIONS_TO_TEST: List = []
32 | 
33 | 
34 | def handle_docstring(doc, indent):
35 |     """
36 |     Check python code in docstring.
37 | 
38 |     This function will read through the docstring and grab
39 |     the first python code block. It will try to execute it.
40 |     If it fails, the calling test should raise a flag.
41 |     """
42 |     if not doc:
43 |         return
44 |     start = doc.find("```python\n")
45 |     end = doc.find("```\n")
46 |     if start != -1:
47 |         if end != -1:
48 |             code_part = doc[(start + 10) : end].replace(" " * indent, "")
49 |             exec(code_part)
50 | 
51 | 
52 | @pytest.mark.parametrize("c", CLASSES_TO_TEST)
53 | def test_class_docstrings(c):
54 |     """
55 |     Take the docstring of a given class.
56 | 
57 |     The test passes if the usage examples cause no errors.
58 |     """
59 |     handle_docstring(c.__doc__, indent=4)
60 | 
61 | 
62 | @pytest.mark.skipif(os.environ.get("SKIP_LIGHTGBM") == "true", reason="LightGBM tests disabled")
63 | @pytest.mark.parametrize("c", CLASSES_TO_TEST_LGBM)
64 | def test_class_docstrings_lgbm(c):
65 |     """
66 |     Take the docstring of a given class which uses LightGBM.
67 | 
68 |     The test passes if the usage examples cause no errors.
69 | 
70 |     The test is skipped if the environment does not support LightGBM correctly, such as macOS.
71 |     """
72 |     handle_docstring(c.__doc__, indent=4)
73 | 
74 | 
75 | @pytest.mark.parametrize("f", FUNCTIONS_TO_TEST)
76 | def test_function_docstrings(f):
77 |     """
78 |     Take the docstring of every function.
79 | 
80 |     The test passes if the usage examples cause no errors.
81 |     """
82 |     handle_docstring(f.__doc__, indent=4)
83 | 
-------------------------------------------------------------------------------- /tests/docs/test_notebooks.py: --------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | 
4 | import pytest
5 | import nbformat
6 | from nbconvert.preprocessors import ExecutePreprocessor
7 | 
8 | TIMEOUT_SECONDS = 1800
9 | PATH_NOTEBOOKS = [str(path) for path in Path("docs").glob("*/*.ipynb")]
10 | NB_FLAG = os.environ.get("TEST_NOTEBOOKS")  # Turn on tests by setting TEST_NOTEBOOKS = 1
11 | SKIP_NOTEBOOKS = NB_FLAG != "1"  # Skip unless the TEST_NOTEBOOKS flag is set to "1"
12 | 
13 | 
14 | @pytest.mark.parametrize("notebook_path", PATH_NOTEBOOKS)
15 | @pytest.mark.skipif(SKIP_NOTEBOOKS, reason="Skip notebook tests if TEST_NOTEBOOKS isn't set")
16 | def test_notebook(notebook_path: str) -> None:
17 |     """Run a notebook and check no exception is raised."""
18 |     with open(notebook_path) as f:
19 |         nb = nbformat.read(f, as_version=4)
20 | 
21 |     ep = ExecutePreprocessor(timeout=TIMEOUT_SECONDS, kernel_name="python3")
22 |     ep.preprocess(nb, {"metadata": {"path": str(Path(notebook_path).parent)}})
-------------------------------------------------------------------------------- /tests/feature_elimination/test_feature_elimination.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 | from lightgbm import LGBMClassifier
4 | from sklearn.datasets import load_diabetes, make_classification
5 | from sklearn.ensemble import RandomForestClassifier
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.model_selection import RandomizedSearchCV, StratifiedGroupKFold, StratifiedKFold
8 | from sklearn.pipeline import Pipeline
9 | from 
sklearn.preprocessing import StandardScaler 10 | from sklearn.svm import SVC 11 | from xgboost import XGBClassifier, XGBRegressor 12 | 13 | from probatus.feature_elimination import ShapRFECV, EarlyStoppingShapRFECV 14 | from probatus.utils import preprocess_labels 15 | 16 | 17 | @pytest.fixture(scope="function") 18 | def X(): 19 | return pd.DataFrame( 20 | { 21 | "col_1": [1, 1, 1, 1, 1, 1, 1, 0], 22 | "col_2": [0, 0, 0, 0, 0, 0, 0, 1], 23 | "col_3": [1, 0, 1, 0, 1, 0, 1, 0], 24 | }, 25 | index=[1, 2, 3, 4, 5, 6, 7, 8], 26 | ) 27 | 28 | 29 | @pytest.fixture(scope="function") 30 | def y(): 31 | return pd.Series([1, 0, 1, 0, 1, 0, 1, 0], index=[1, 2, 3, 4, 5, 6, 7, 8]) 32 | 33 | 34 | @pytest.fixture(scope="function") 35 | def sample_weight(): 36 | return pd.Series([1, 1, 1, 1, 1, 1, 1, 1], index=[1, 2, 3, 4, 5, 6, 7, 8]) 37 | 38 | 39 | @pytest.fixture(scope="function") 40 | def groups(): 41 | return pd.Series(["grp1", "grp1", "grp1", "grp1", "grp2", "grp2", "grp2", "grp2"], index=[1, 2, 3, 4, 5, 6, 7, 8]) 42 | 43 | 44 | @pytest.fixture(scope="function") 45 | def XGBoost_classifier(random_state): 46 | model = XGBClassifier(n_estimators=200, max_depth=3, random_state=random_state) 47 | return model 48 | 49 | 50 | @pytest.fixture(scope="function") 51 | def XGBoost_regressor(random_state): 52 | model = XGBRegressor(n_estimators=200, max_depth=3, random_state=random_state) 53 | return model 54 | 55 | 56 | def test_shap_rfe_regressor(XGBoost_regressor, random_state): 57 | diabetes = load_diabetes() 58 | X = pd.DataFrame(diabetes.data, columns=diabetes.feature_names) 59 | y = diabetes.target 60 | 61 | shap_elimination = ShapRFECV(XGBoost_regressor, step=0.8, cv=2, scoring="r2", n_jobs=4, random_state=random_state) 62 | report = shap_elimination.fit_compute(X, y) 63 | 64 | assert report.shape[0] == 3 65 | assert shap_elimination.get_reduced_features_set(1) == ["bmi"] 66 | 67 | _ = shap_elimination.plot(show=False) 68 | 69 | 70 | def test_shap_rfe_randomized_search(X, y, randomized_search_decision_tree_classifier, random_state): 71 | search = randomized_search_decision_tree_classifier 72 | shap_elimination = ShapRFECV(search, step=0.8, cv=2, scoring="roc_auc", n_jobs=4, random_state=random_state) 73 | report = shap_elimination.fit_compute(X, y) 74 | 75 | assert report.shape[0] == 2 76 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 77 | 78 | _ = shap_elimination.plot(show=False) 79 | 80 | 81 | def test_shap_rfe_multi_class(X, y, decision_tree_classifier, random_state): 82 | shap_elimination = ShapRFECV( 83 | decision_tree_classifier, 84 | cv=2, 85 | scoring="roc_auc_ovr", 86 | random_state=random_state, 87 | ) 88 | 89 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False) 90 | 91 | assert report.shape[0] == 3 92 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 93 | 94 | 95 | def test_shap_rfe(X, y, sample_weight, decision_tree_classifier, random_state): 96 | shap_elimination = ShapRFECV( 97 | decision_tree_classifier, 98 | random_state=random_state, 99 | step=1, 100 | cv=2, 101 | scoring="roc_auc", 102 | n_jobs=4, 103 | ) 104 | report = shap_elimination.fit_compute(X, y, sample_weight=sample_weight, approximate=True, check_additivity=False) 105 | 106 | assert report.shape[0] == 3 107 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 108 | 109 | 110 | def test_shap_rfe_group_cv(X, y, groups, sample_weight, decision_tree_classifier, random_state): 111 | cv = StratifiedGroupKFold(n_splits=2, shuffle=True, 
random_state=random_state) 112 | shap_elimination = ShapRFECV( 113 | decision_tree_classifier, 114 | random_state=random_state, 115 | step=1, 116 | cv=cv, 117 | scoring="roc_auc", 118 | n_jobs=4, 119 | ) 120 | report = shap_elimination.fit_compute( 121 | X, y, groups=groups, sample_weight=sample_weight, approximate=True, check_additivity=False 122 | ) 123 | 124 | assert report.shape[0] == 3 125 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 126 | 127 | 128 | def test_shap_pipeline_error(X, y, decision_tree_classifier, random_state): 129 | model = Pipeline( 130 | [ 131 | ("scaler", StandardScaler()), 132 | ("dt", decision_tree_classifier), 133 | ] 134 | ) 135 | with pytest.raises(TypeError): 136 | shap_elimination = ShapRFECV( 137 | model, 138 | random_state=random_state, 139 | step=1, 140 | cv=2, 141 | scoring="roc_auc", 142 | n_jobs=4, 143 | ) 144 | shap_elimination = shap_elimination.fit(X, y, approximate=True, check_additivity=False) 145 | 146 | 147 | def test_shap_rfe_linear_model(X, y, random_state): 148 | model = LogisticRegression(C=1, random_state=random_state) 149 | shap_elimination = ShapRFECV(model, random_state=random_state, step=1, cv=2, scoring="roc_auc", n_jobs=4) 150 | report = shap_elimination.fit_compute(X, y) 151 | 152 | assert report.shape[0] == 3 153 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 154 | 155 | 156 | def test_shap_rfe_svm(X, y, random_state): 157 | model = SVC(C=1, kernel="linear", probability=True, random_state=random_state) 158 | shap_elimination = ShapRFECV(model, random_state=random_state, step=1, cv=2, scoring="roc_auc", n_jobs=4) 159 | shap_elimination = shap_elimination.fit(X, y) 160 | report = shap_elimination.compute() 161 | 162 | assert report.shape[0] == 3 163 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"] 164 | 165 | 166 | def test_shap_rfe_cols_to_keep(X, y, decision_tree_classifier, random_state): 167 | shap_elimination = ShapRFECV( 168 | decision_tree_classifier, 169 | random_state=random_state, 170 | step=2, 171 | cv=2, 172 | scoring="roc_auc", 173 | n_jobs=4, 174 | min_features_to_select=1, 175 | ) 176 | report = shap_elimination.fit_compute(X, y, columns_to_keep=["col_2", "col_3"]) 177 | 178 | assert report.shape[0] == 2 179 | reduced_feature_set = set(shap_elimination.get_reduced_features_set(num_features=2)) 180 | assert reduced_feature_set == {"col_2", "col_3"} 181 | 182 | 183 | def test_shap_rfe_randomized_search_cols_to_keep(X, y, randomized_search_decision_tree_classifier, random_state): 184 | search = randomized_search_decision_tree_classifier 185 | shap_elimination = ShapRFECV(search, step=0.8, cv=2, scoring="roc_auc", n_jobs=4, random_state=random_state) 186 | report = shap_elimination.fit_compute(X, y, columns_to_keep=["col_2", "col_3"]) 187 | 188 | assert report.shape[0] == 2 189 | reduced_feature_set = set(shap_elimination.get_reduced_features_set(num_features=2)) 190 | assert reduced_feature_set == {"col_2", "col_3"} 191 | 192 | 193 | def test_calculate_number_of_features_to_remove(): 194 | assert 3 == ShapRFECV._calculate_number_of_features_to_remove( 195 | current_num_of_features=10, num_features_to_remove=3, min_num_features_to_keep=5 196 | ) 197 | assert 3 == ShapRFECV._calculate_number_of_features_to_remove( 198 | current_num_of_features=8, num_features_to_remove=5, min_num_features_to_keep=5 199 | ) 200 | assert 0 == ShapRFECV._calculate_number_of_features_to_remove( 201 | current_num_of_features=5, num_features_to_remove=1, min_num_features_to_keep=5 202 | ) 
203 | assert 4 == ShapRFECV._calculate_number_of_features_to_remove( 204 | current_num_of_features=5, num_features_to_remove=7, min_num_features_to_keep=1 205 | ) 206 | 207 | 208 | def test_shap_automatic_num_feature_selection(decision_tree_classifier, random_state): 209 | X = pd.DataFrame( 210 | { 211 | "col_1": [1, 0, 1, 0, 1, 0, 1, 0], 212 | "col_2": [0, 0, 0, 0, 0, 1, 1, 1], 213 | "col_3": [1, 1, 1, 0, 0, 0, 0, 0], 214 | } 215 | ) 216 | y = pd.Series([0, 0, 0, 0, 1, 1, 1, 1]) 217 | 218 | shap_elimination = ShapRFECV( 219 | decision_tree_classifier, 220 | random_state=random_state, 221 | step=1, 222 | cv=2, 223 | scoring="roc_auc", 224 | n_jobs=1, 225 | ) 226 | _ = shap_elimination.fit_compute(X, y, approximate=True, check_additivity=False) 227 | 228 | best_features = shap_elimination.get_reduced_features_set(num_features="best") 229 | best_coherent_features = shap_elimination.get_reduced_features_set( 230 | num_features="best_coherent", 231 | ) 232 | best_parsimonious_features = shap_elimination.get_reduced_features_set(num_features="best_parsimonious") 233 | 234 | assert best_features == ["col_2"] 235 | assert best_coherent_features == ["col_1", "col_2", "col_3"] 236 | assert best_parsimonious_features == ["col_2"] 237 | 238 | 239 | def test_get_feature_shap_values_per_fold(X, y, decision_tree_classifier, random_state): 240 | shap_elimination = ShapRFECV(decision_tree_classifier, scoring="roc_auc", random_state=random_state) 241 | ( 242 | shap_values, 243 | train_score, 244 | test_score, 245 | ) = shap_elimination._get_feature_shap_values_per_fold( 246 | X, 247 | y, 248 | decision_tree_classifier, 249 | train_index=[2, 3, 4, 5, 6, 7], 250 | val_index=[0, 1], 251 | ) 252 | assert test_score == 1 253 | assert train_score > 0.9 254 | assert shap_values.shape == (2, 3) 255 | 256 | 257 | def test_shap_rfe_same_features_are_kept_after_each_run(random_state_1234): 258 | """ 259 | Test a use case which appears to be flickering with Probatus 1.8.9 and lower. 260 | 261 | Expected result: every run the same outcome. 262 | Probatus <= 1.8.9: A different order every time. 
263 | """ 264 | feature_names = [(f"f{num}") for num in range(1, 21)] 265 | 266 | # Code from tutorial on probatus documentation 267 | X, y = make_classification( 268 | n_samples=100, 269 | class_sep=0.05, 270 | n_informative=6, 271 | n_features=20, 272 | random_state=random_state_1234, 273 | n_redundant=10, 274 | n_clusters_per_class=1, 275 | ) 276 | X = pd.DataFrame(X, columns=feature_names) 277 | 278 | random_forest = RandomForestClassifier( 279 | random_state=random_state_1234, 280 | n_estimators=70, 281 | max_features="log2", 282 | criterion="entropy", 283 | class_weight="balanced", 284 | ) 285 | 286 | shap_elimination = ShapRFECV( 287 | random_forest, 288 | step=0.2, 289 | cv=5, 290 | scoring="f1_macro", 291 | n_jobs=1, 292 | random_state=random_state_1234, 293 | ) 294 | 295 | report = shap_elimination.fit_compute(X, y, check_additivity=True) 296 | # Return the set of features with the best validation accuracy 297 | 298 | kept_features = list(report.iloc[[report["val_metric_mean"].idxmax() - 1]]["features_set"].to_list()[0]) 299 | 300 | # Results from the first run 301 | assert [ 302 | "f1", 303 | "f2", 304 | "f3", 305 | "f5", 306 | "f6", 307 | "f10", 308 | "f11", 309 | "f12", 310 | "f13", 311 | "f14", 312 | "f15", 313 | "f16", 314 | "f17", 315 | "f18", 316 | "f19", 317 | "f20", 318 | ] == kept_features 319 | 320 | 321 | def test_shap_rfe_penalty_factor(X, y, decision_tree_classifier, random_state): 322 | shap_elimination = ShapRFECV( 323 | decision_tree_classifier, 324 | random_state=random_state, 325 | step=1, 326 | cv=2, 327 | scoring="roc_auc", 328 | n_jobs=1, 329 | ) 330 | report = shap_elimination.fit_compute( 331 | X, y, shap_variance_penalty_factor=1.0, approximate=True, check_additivity=False 332 | ) 333 | 334 | assert report.shape[0] == 3 335 | assert shap_elimination.get_reduced_features_set(1) == ["col_1"] 336 | 337 | 338 | def test_complex_dataset(complex_data, complex_lightgbm, random_state_1): 339 | X, y = complex_data 340 | 341 | param_grid = { 342 | "n_estimators": [5, 7, 10], 343 | "num_leaves": [3, 5, 7, 10], 344 | } 345 | search = RandomizedSearchCV(complex_lightgbm, param_grid, n_iter=1, random_state=random_state_1) 346 | 347 | shap_elimination = ShapRFECV( 348 | model=search, step=1, cv=10, scoring="roc_auc", n_jobs=3, verbose=1, random_state=random_state_1 349 | ) 350 | 351 | report = shap_elimination.fit_compute(X, y) 352 | 353 | assert report.shape[0] == X.shape[1] 354 | 355 | 356 | def test_shap_rfe_early_stopping_lightGBM(complex_data, random_state): 357 | model = LGBMClassifier(n_estimators=200, max_depth=3, random_state=random_state) 358 | X, y = complex_data 359 | 360 | shap_elimination = EarlyStoppingShapRFECV( 361 | model, 362 | random_state=random_state, 363 | step=1, 364 | cv=10, 365 | scoring="roc_auc", 366 | n_jobs=4, 367 | early_stopping_rounds=5, 368 | eval_metric="auc", 369 | ) 370 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False) 371 | 372 | assert report.shape[0] == 5 373 | assert shap_elimination.get_reduced_features_set(1) == ["f5"] 374 | 375 | 376 | def test_shap_rfe_early_stopping_XGBoost(XGBoost_classifier, complex_data, random_state): 377 | X, y = complex_data 378 | X["f1_categorical"] = X["f1_categorical"].astype(float) 379 | 380 | shap_elimination = ShapRFECV( 381 | XGBoost_classifier, 382 | random_state=random_state, 383 | step=1, 384 | cv=10, 385 | scoring="roc_auc", 386 | n_jobs=4, 387 | early_stopping_rounds=5, 388 | eval_metric="auc", 389 | ) 390 | report = shap_elimination.fit_compute(X, y, 
approximate=False, check_additivity=False) 391 | 392 | assert report.shape[0] == 5 393 | assert shap_elimination.get_reduced_features_set(1) == ["f4"] 394 | 395 | 396 | # 397 | # 398 | def test_shap_rfe_early_stopping_CatBoost(complex_data_with_categorical, catboost_classifier, random_state): 399 | X, y = complex_data_with_categorical 400 | 401 | shap_elimination = ShapRFECV( 402 | catboost_classifier, 403 | random_state=random_state, 404 | step=1, 405 | cv=10, 406 | scoring="roc_auc", 407 | n_jobs=4, 408 | early_stopping_rounds=5, 409 | eval_metric="auc", 410 | ) 411 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False) 412 | 413 | assert report.shape[0] == 5 414 | assert shap_elimination.get_reduced_features_set(1)[0] in ["f4", "f5"] 415 | 416 | 417 | def test_shap_rfe_randomized_search_early_stopping_lightGBM(complex_data, random_state): 418 | model = LGBMClassifier(n_estimators=200, random_state=random_state) 419 | X, y = complex_data 420 | 421 | param_grid = { 422 | "max_depth": [3, 4, 5], 423 | } 424 | search = RandomizedSearchCV(model, param_grid, cv=2, n_iter=2, random_state=random_state) 425 | shap_elimination = ShapRFECV( 426 | search, 427 | step=1, 428 | cv=10, 429 | scoring="roc_auc", 430 | early_stopping_rounds=5, 431 | eval_metric="auc", 432 | n_jobs=4, 433 | verbose=1, 434 | random_state=random_state, 435 | ) 436 | report = shap_elimination.fit_compute(X, y) 437 | 438 | assert report.shape[0] == X.shape[1] 439 | assert shap_elimination.get_reduced_features_set(1) == ["f5"] 440 | 441 | _ = shap_elimination.plot(show=False) 442 | 443 | 444 | def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data, random_state): 445 | model = LGBMClassifier(n_estimators=200, max_depth=3, random_state=random_state) 446 | X, y = complex_data 447 | y = preprocess_labels(y, y_name="y", index=X.index) 448 | 449 | shap_elimination = ShapRFECV(model, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state) 450 | ( 451 | shap_values, 452 | train_score, 453 | test_score, 454 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping( 455 | X, 456 | y, 457 | model, 458 | train_index=list(range(5, 50)), 459 | val_index=[0, 1, 2, 3, 4], 460 | ) 461 | assert test_score > 0.6 462 | assert train_score > 0.6 463 | assert shap_values.shape == (5, 5) 464 | 465 | 466 | def test_get_feature_shap_values_per_fold_early_stopping_CatBoost( 467 | complex_data_with_categorical, catboost_classifier, random_state 468 | ): 469 | X, y = complex_data_with_categorical 470 | y = preprocess_labels(y, y_name="y", index=X.index) 471 | 472 | shap_elimination = ShapRFECV( 473 | catboost_classifier, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state 474 | ) 475 | ( 476 | shap_values, 477 | train_score, 478 | test_score, 479 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping( 480 | X, 481 | y, 482 | catboost_classifier, 483 | train_index=list(range(5, 50)), 484 | val_index=[0, 1, 2, 3, 4], 485 | ) 486 | assert test_score > 0 487 | assert train_score > 0.6 488 | assert shap_values.shape == (5, 5) 489 | 490 | 491 | def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(XGBoost_classifier, complex_data, random_state): 492 | X, y = complex_data 493 | y = preprocess_labels(y, y_name="y", index=X.index) 494 | 495 | shap_elimination = ShapRFECV( 496 | XGBoost_classifier, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state 497 | ) 498 | ( 499 | shap_values, 500 | train_score, 501 | 
test_score, 502 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping( 503 | X, 504 | y, 505 | XGBoost_classifier, 506 | train_index=list(range(5, 50)), 507 | val_index=[0, 1, 2, 3, 4], 508 | ) 509 | assert test_score > 0 510 | assert train_score > 0.6 511 | assert shap_values.shape == (5, 5) 512 | 513 | 514 | def test_EarlyStoppingShapRFECV_no_categorical(complex_data, random_state): 515 | model = LGBMClassifier(n_estimators=50, max_depth=3, num_leaves=3, random_state=random_state) 516 | 517 | shap_elimination = ShapRFECV( 518 | model=model, 519 | step=0.33, 520 | cv=5, 521 | scoring="accuracy", 522 | eval_metric="logloss", 523 | early_stopping_rounds=5, 524 | random_state=random_state, 525 | ) 526 | X, y = complex_data 527 | X = X.drop(columns=["f1_categorical"]) 528 | report = shap_elimination.fit_compute(X, y, feature_perturbation="tree_path_dependent") 529 | 530 | assert report.shape[0] == X.shape[1] 531 | assert shap_elimination.get_reduced_features_set(1) == ["f5"] 532 | 533 | _ = shap_elimination.plot(show=False) 534 | 535 | 536 | def test_LightGBM_stratified_kfold(random_state): 537 | """ 538 | Test added to check for https://github.com/ing-bank/probatus/issues/170. 539 | """ 540 | X = pd.DataFrame( 541 | [ 542 | [1, 2, 3, 4, 5, 101, 102, 103, 104, 105], 543 | [-1, -2, 2, -5, -7, 1, 2, 5, -1, 3], 544 | ["a", "b"] * 5, # noisy categorical will be dropped first 545 | ] 546 | ).transpose() 547 | X[2] = X[2].astype("category") 548 | X[1] = X[1].astype("float") 549 | X[0] = X[0].astype("float") 550 | y = [0] * 5 + [1] * 5 551 | 552 | model = LGBMClassifier(random_state=random_state) 553 | n_iter = 2 554 | n_folds = 3 555 | 556 | for _ in range(n_iter): 557 | skf = StratifiedKFold(n_folds, shuffle=True, random_state=random_state) 558 | shap_elimination = ShapRFECV( 559 | model=model, 560 | step=1 / (n_iter + 1), 561 | cv=skf, 562 | scoring="accuracy", 563 | eval_metric="logloss", 564 | early_stopping_rounds=5, 565 | random_state=random_state, 566 | ) 567 | report = shap_elimination.fit_compute(X, y, feature_perturbation="tree_path_dependent") 568 | 569 | assert report.shape[0] == X.shape[1] 570 | 571 | shap_elimination.plot(show=False) 572 | -------------------------------------------------------------------------------- /tests/interpret/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/interpret/__init__.py -------------------------------------------------------------------------------- /tests/interpret/test_model_interpret.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from probatus.interpret import ShapModelInterpreter 6 | 7 | 8 | @pytest.fixture(scope="function") 9 | def expected_feature_importance(): 10 | return pd.DataFrame( 11 | { 12 | "mean_abs_shap_value_test": [0.5, 0.0, 0.0], 13 | "mean_abs_shap_value_train": [0.5, 0.0, 0.0], 14 | "mean_shap_value_test": [-0.5, 0.0, 0.0], 15 | "mean_shap_value_train": [-0.5, 0.0, 0.0], 16 | }, 17 | index=["col_3", "col_1", "col_2"], 18 | ) 19 | 20 | 21 | @pytest.fixture(scope="function") 22 | def expected_feature_importance_lin_models(): 23 | return pd.DataFrame( 24 | { 25 | "mean_abs_shap_value_test": [0.4, 0.0, 0.0], 26 | "mean_abs_shap_value_train": [0.4, 0.0, 0.0], 27 | "mean_shap_value_test": [-0.4, 0.0, 0.0], 28 | "mean_shap_value_train": [-0.4, 0.0, 0.0], 29 | }, 30 |
index=["col_3", "col_1", "col_2"], 31 | ) 32 | 33 | 34 | def test_shap_interpret(fitted_tree, X_train, y_train, X_test, y_test, expected_feature_importance, random_state): 35 | class_names = ["neg", "pos"] 36 | 37 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state) 38 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names) 39 | 40 | assert shap_interpret.class_names == class_names 41 | assert shap_interpret.train_score == 1 42 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01) 43 | 44 | # Check expected shap values 45 | assert (np.mean(np.abs(shap_interpret.shap_values_test), axis=0) == [0, 0, 0.5]).all() 46 | assert (np.mean(np.abs(shap_interpret.shap_values_train), axis=0) == [0, 0, 0.5]).all() 47 | 48 | importance_df, train_auc, test_auc = shap_interpret.compute(return_scores=True) 49 | 50 | pd.testing.assert_frame_equal(expected_feature_importance, importance_df) 51 | assert train_auc == 1 52 | assert test_auc == pytest.approx(0.833, 0.01) 53 | 54 | # Check if plots work for such a dataset 55 | ax1 = shap_interpret.plot("importance", target_set="test", show=False) 56 | ax2 = shap_interpret.plot("summary", target_set="test", show=False) 57 | ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False) 58 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False) 59 | ax5 = shap_interpret.plot("importance", target_set="train", show=False) 60 | ax6 = shap_interpret.plot("summary", target_set="train", show=False) 61 | ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False) 62 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False) 63 | assert not (isinstance(ax1, list)) 64 | assert not (isinstance(ax2, list)) 65 | assert isinstance(ax3, list) and len(ax3) == 2 66 | assert isinstance(ax4, list) and len(ax4) == 2 67 | assert not (isinstance(ax5, list)) 68 | assert not (isinstance(ax6, list)) 69 | assert isinstance(ax7, list) and len(ax7) == 2 70 | assert isinstance(ax8, list) and len(ax8) == 2 71 | 72 | 73 | def test_shap_interpret_lin_models( 74 | fitted_logistic_regression, X_train, y_train, X_test, y_test, expected_feature_importance_lin_models, random_state 75 | ): 76 | class_names = ["neg", "pos"] 77 | 78 | shap_interpret = ShapModelInterpreter(fitted_logistic_regression, random_state=random_state) 79 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names) 80 | 81 | assert shap_interpret.class_names == class_names 82 | assert shap_interpret.train_score == 1 83 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01) 84 | 85 | # Check expected shap values 86 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_test), axis=0), 2) == [0, 0, 0.4]).all() 87 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_train), axis=0), 2) == [0, 0, 0.4]).all() 88 | 89 | importance_df, train_auc, test_auc = shap_interpret.compute(return_scores=True) 90 | importance_df = importance_df.round(2) 91 | 92 | pd.testing.assert_frame_equal(expected_feature_importance_lin_models, importance_df) 93 | assert train_auc == 1 94 | assert test_auc == pytest.approx(0.833, 0.01) 95 | 96 | # Check if plots work for such a dataset 97 | ax1 = shap_interpret.plot("importance", target_set="test", show=False) 98 | ax2 = shap_interpret.plot("summary", target_set="test", show=False) 99 | ax3 = shap_interpret.plot("dependence",
target_columns="col_3", target_set="test", show=False) 100 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False) 101 | ax5 = shap_interpret.plot("importance", target_set="train", show=False) 102 | ax6 = shap_interpret.plot("summary", target_set="train", show=False) 103 | ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False) 104 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False) 105 | assert not (isinstance(ax1, list)) 106 | assert not (isinstance(ax2, list)) 107 | assert isinstance(ax3, list) and len(ax3) == 2 108 | assert isinstance(ax4, list) and len(ax4) == 2 109 | assert not (isinstance(ax5, list)) 110 | assert not (isinstance(ax6, list)) 111 | assert isinstance(ax7, list) and len(ax7) == 2 112 | assert isinstance(ax8, list) and len(ax8) == 2 113 | 114 | 115 | def test_shap_interpret_fit_compute_lin_models( 116 | fitted_logistic_regression, X_train, y_train, X_test, y_test, expected_feature_importance_lin_models, random_state 117 | ): 118 | class_names = ["neg", "pos"] 119 | 120 | shap_interpret = ShapModelInterpreter(fitted_logistic_regression, random_state=random_state) 121 | importance_df = shap_interpret.fit_compute(X_train, X_test, y_train, y_test, class_names=class_names) 122 | importance_df = importance_df.round(2) 123 | 124 | assert shap_interpret.class_names == class_names 125 | assert shap_interpret.train_score == 1 126 | 127 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01) 128 | 129 | # Check expected shap values 130 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_test), axis=0), 2) == [0, 0, 0.4]).all() 131 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_train), axis=0), 2) == [0, 0, 0.4]).all() 132 | 133 | pd.testing.assert_frame_equal(expected_feature_importance_lin_models, importance_df) 134 | 135 | 136 | def test_shap_interpret_fit_compute( 137 | fitted_tree, X_train, y_train, X_test, y_test, expected_feature_importance, random_state 138 | ): 139 | class_names = ["neg", "pos"] 140 | 141 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state) 142 | importance_df = shap_interpret.fit_compute(X_train, X_test, y_train, y_test, class_names=class_names) 143 | 144 | assert shap_interpret.class_names == class_names 145 | assert shap_interpret.train_score == 1 146 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01) 147 | 148 | # Check expected shap values 149 | assert (np.mean(np.abs(shap_interpret.shap_values_test), axis=0) == [0, 0, 0.5]).all() 150 | assert (np.mean(np.abs(shap_interpret.shap_values_train), axis=0) == [0, 0, 0.5]).all() 151 | 152 | pd.testing.assert_frame_equal(expected_feature_importance, importance_df) 153 | 154 | 155 | def test_shap_interpret_complex_data(complex_data_split_with_categorical, complex_fitted_lightgbm, random_state): 156 | class_names = ["neg", "pos"] 157 | X_train, X_test, y_train, y_test = complex_data_split_with_categorical 158 | 159 | shap_interpret = ShapModelInterpreter(complex_fitted_lightgbm, verbose=1, random_state=random_state) 160 | importance_df = shap_interpret.fit_compute( 161 | X_train, X_test, y_train, y_test, class_names=class_names, approximate=False, check_additivity=False 162 | ) 163 | 164 | assert shap_interpret.class_names == class_names 165 | assert importance_df.shape[0] == X_train.shape[1] 166 | 167 | # Check if plots work for such a dataset 168 | ax1 =
shap_interpret.plot("importance", target_set="test", show=False) 169 | ax2 = shap_interpret.plot("summary", target_set="test", show=False) 170 | ax3 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="test", show=False) 171 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False) 172 | ax5 = shap_interpret.plot("importance", target_set="train", show=False) 173 | ax6 = shap_interpret.plot("summary", target_set="train", show=False) 174 | ax7 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="train", show=False) 175 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False) 176 | assert not (isinstance(ax1, list)) 177 | assert not (isinstance(ax2, list)) 178 | assert isinstance(ax3, list) and len(ax3) == 2 179 | assert isinstance(ax4, list) and len(ax4) == 2 180 | assert not (isinstance(ax5, list)) 181 | assert not (isinstance(ax6, list)) 182 | assert isinstance(ax7, list) and len(ax7) == 2 183 | assert isinstance(ax8, list) and len(ax8) == 2 184 | -------------------------------------------------------------------------------- /tests/interpret/test_shap_dependence.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from sklearn.ensemble import RandomForestClassifier 7 | 8 | from probatus.interpret.shap_dependence import DependencePlotter 9 | from probatus.utils.exceptions import NotFittedError 10 | 11 | # Turn off interactive mode in plots 12 | plt.ioff() 13 | matplotlib.use("Agg") 14 | 15 | 16 | @pytest.fixture(scope="function") 17 | def X_y(): 18 | return ( 19 | pd.DataFrame( 20 | [ 21 | [1.72568193, 2.21070436, 1.46039061], 22 | [-1.48382902, 2.88364928, 0.22323996], 23 | [-0.44947744, 0.85434638, -2.54486421], 24 | [-1.38101231, 1.77505901, -1.36000132], 25 | [-0.18261804, -0.25829609, 1.46925993], 26 | [0.27514902, 0.09608222, 0.7221381], 27 | [-0.27264455, 1.99366793, -2.62161046], 28 | [-2.81587587, 3.46459717, -0.11740999], 29 | [1.48374489, 0.79662903, 1.18898706], 30 | [-1.27251335, -1.57344342, -0.39540133], 31 | [0.31532891, 0.38299269, 1.29998754], 32 | [-2.10917352, -0.70033132, -0.89922129], 33 | [-2.14396343, -0.44549774, -1.80572922], 34 | [-3.4503348, 3.43476247, -0.74957725], 35 | [-1.25945582, -1.7234203, -0.77435353], 36 | ] 37 | ), 38 | pd.Series([1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0]), 39 | ) 40 | 41 | 42 | @pytest.fixture(scope="function") 43 | def expected_shap_vals(): 44 | return pd.DataFrame( 45 | [ 46 | [0.176667, 0.005833, 0.284167], 47 | [-0.042020, 0.224520, 0.284167], 48 | [-0.092020, -0.135480, -0.205833], 49 | [-0.092020, -0.135480, -0.205833], 50 | [0.002424, 0.000909, 0.263333], 51 | [0.176667, 0.105833, 0.284167], 52 | [-0.092020, -0.135480, -0.205833], 53 | [-0.028687, 0.311187, 0.184167], 54 | [0.176667, 0.005833, 0.284167], 55 | [-0.092020, -0.164646, -0.076667], 56 | [0.176667, 0.105833, 0.284167], 57 | [-0.092020, -0.164646, -0.176667], 58 | [-0.092020, -0.164646, -0.176667], 59 | [-0.108687, 0.081187, -0.205833], 60 | [-0.092020, -0.164646, -0.176667], 61 | ] 62 | ) 63 | 64 | 65 | @pytest.fixture(scope="function") 66 | def model(X_y, random_state): 67 | X, y = X_y 68 | 69 | model = RandomForestClassifier(random_state=random_state, n_estimators=10, max_depth=5) 70 | 71 | model.fit(X, y) 72 | return model 73 | 74 | 75 |
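# (Editorial sketch, not part of the original module.) The hard-coded frames in
# the expected_shap_vals fixture above and the expected_feat_importances fixture
# below can in principle be regenerated with shap directly. The helper below is
# hypothetical, and the output layout of shap.TreeExplainer differs across shap
# versions, so treat it as illustrative only.
def _regenerate_expected_shap_values(model, X):
    import shap  # local import keeps the sketch self-contained

    shap_values = shap.TreeExplainer(model).shap_values(X)
    # Older shap returns a list with one array per class; keep the positive class.
    if isinstance(shap_values, list):
        shap_values = shap_values[1]
    # Newer shap may return a 3D array of shape (n_samples, n_features, n_classes).
    elif getattr(shap_values, "ndim", 2) == 3:
        shap_values = shap_values[:, :, 1]
    return pd.DataFrame(shap_values)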
@pytest.fixture(scope="function") 76 | def expected_feat_importances(): 77 | return pd.DataFrame( 78 | { 79 | "Feature Name": {0: 2, 1: 1, 2: 0}, 80 | "Shap absolute importance": {0: 0.2199, 1: 0.1271, 2: 0.1022}, 81 | "Shap signed importance": {0: 0.0292, 1: -0.0149, 2: -0.0076}, 82 | } 83 | ) 84 | 85 | 86 | def test_not_fitted(model, random_state): 87 | plotter = DependencePlotter(model, random_state) 88 | assert plotter.fitted is False 89 | 90 | 91 | def test_fit_complex(complex_data_split, complex_fitted_lightgbm, random_state): 92 | _, X_test, _, y_test = complex_data_split 93 | 94 | plotter = DependencePlotter(complex_fitted_lightgbm, random_state=random_state) 95 | 96 | plotter.fit(X_test, y_test) 97 | 98 | pd.testing.assert_frame_equal(plotter.X, X_test) 99 | pd.testing.assert_series_equal(plotter.y, pd.Series(y_test, index=X_test.index)) 100 | assert plotter.fitted is True 101 | 102 | # Check if plotting does not cause errors 103 | _ = plotter.plot(feature="f2_missing", show=False) 104 | 105 | 106 | def test_get_X_y_shap_with_q_cut_normal(X_y, model, random_state): 107 | X, y = X_y 108 | 109 | plotter = DependencePlotter(model, random_state).fit(X, y) 110 | plotter.min_q, plotter.max_q = 0, 1 111 | 112 | X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0) 113 | assert np.isclose(X[0], X_cut).all() 114 | assert y.equals(y_cut) 115 | 116 | plotter.min_q = 0.2 117 | plotter.max_q = 0.8 118 | 119 | X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0) 120 | assert np.isclose( 121 | X_cut, 122 | [ 123 | -1.48382902, 124 | -0.44947744, 125 | -1.38101231, 126 | -0.18261804, 127 | 0.27514902, 128 | -0.27264455, 129 | -1.27251335, 130 | -2.10917352, 131 | -1.25945582, 132 | ], 133 | ).all() 134 | assert np.equal(y_cut.values, [1, 0, 0, 1, 1, 0, 0, 0, 0]).all() 135 | 136 | 137 | def test_get_X_y_shap_with_q_cut_unfitted(model, random_state): 138 | plotter = DependencePlotter(model, random_state) 139 | with pytest.raises(NotFittedError): 140 | plotter._get_X_y_shap_with_q_cut(0) 141 | 142 | 143 | def test_get_X_y_shap_with_q_cut_input(X_y, model, random_state): 144 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1]) 145 | with pytest.raises(ValueError): 146 | plotter._get_X_y_shap_with_q_cut("not a feature") 147 | 148 | 149 | def test_plot_normal(X_y, model, random_state): 150 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1]) 151 | _ = plotter.plot(feature=0) 152 | 153 | 154 | def test_plot_class_names(X_y, model, random_state): 155 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1], class_names=["a", "b"]) 156 | _ = plotter.plot(feature=0) 157 | assert plotter.class_names == ["a", "b"] 158 | 159 | 160 | def test_plot_input(X_y, model, random_state): 161 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1]) 162 | with pytest.raises(ValueError): 163 | plotter.plot(feature="not a feature") 164 | with pytest.raises(TypeError): 165 | plotter.plot(feature=0, bins=5.0) 166 | with pytest.raises(ValueError): 167 | plotter.plot(feature=0, min_q=1, max_q=0) 168 | 169 | 170 | def test__repr__(model, random_state): 171 | """ 172 | Test string representation. 
173 | """ 174 | plotter = DependencePlotter(model, random_state) 175 | assert str(plotter) == "Shap dependence plotter for RandomForestClassifier" 176 | -------------------------------------------------------------------------------- /tests/mocks.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock 2 | 3 | # These are shell classes that define the methods of the models that we use. Each of the methods that we use needs 4 | # to be defined inside these shell classes. When we want to write a specific test, we simply mock.patch the desired 5 | # functionality. We can also set a return_value on the patched method. 6 | 7 | 8 | class MockClusterer(Mock): 9 | """ 10 | MockClusterer. 11 | """ 12 | 13 | def __init__(self): 14 | """ 15 | Init. 16 | """ 17 | pass 18 | 19 | def fit(self): 20 | """ 21 | Fit. 22 | """ 23 | pass 24 | 25 | def predict(self): 26 | """ 27 | Predict. 28 | """ 29 | pass 30 | 31 | def fit_predict(self): 32 | """ 33 | Fit and predict. 34 | """ 35 | pass 36 | 37 | 38 | class MockModel(Mock): 39 | """ 40 | MockModel. 41 | """ 42 | 43 | def __init__(self, **kwargs): 44 | """ 45 | Init. 46 | """ 47 | pass 48 | 49 | def fit(self): 50 | """ 51 | Fit. 52 | """ 53 | pass 54 | 55 | def predict(self): 56 | """ 57 | Predict. 58 | """ 59 | pass 60 | 61 | def predict_proba(self): 62 | """ 63 | Predict probabilities. 64 | """ 65 | pass 66 |
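# (Editorial sketch, not part of the original module.) A minimal example of the
# patching pattern described in the comment at the top of this file: patch one
# method of a shell class, pin its return_value, and assert on the recorded
# calls. The function name is hypothetical.
from unittest.mock import patch


def _example_mock_model_usage():
    model = MockModel()
    # Replace MockModel.predict for the duration of the block and stub its output.
    with patch.object(MockModel, "predict", return_value=[0, 1]) as mock_predict:
        assert model.predict() == [0, 1]  # the patched method returns the stub
    mock_predict.assert_called_once()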
-------------------------------------------------------------------------------- /tests/sample_similarity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/sample_similarity/__init__.py -------------------------------------------------------------------------------- /tests/sample_similarity/test_resemblance_model.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from pandas.api.types import is_numeric_dtype 7 | 8 | from probatus.sample_similarity import BaseResemblanceModel, PermutationImportanceResemblance, SHAPImportanceResemblance 9 | 10 | # Turn off interactive mode in plots 11 | plt.ioff() 12 | matplotlib.use("Agg") 13 | 14 | 15 | @pytest.fixture(scope="function") 16 | def X1(): 17 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [0, 0, 0, 0]}, index=[1, 2, 3, 4]) 18 | 19 | 20 | @pytest.fixture(scope="function") 21 | def X2(): 22 | return pd.DataFrame({"col_1": [0, 0, 0, 0], "col_2": [0, 0, 0, 0], "col_3": [0, 0, 0, 0]}, index=[1, 2, 3, 4]) 23 | 24 | 25 | def test_base_class(X1, X2, decision_tree_classifier, random_state): 26 | rm = BaseResemblanceModel(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state) 27 | 28 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True) 29 | 30 | assert train_score == 1 31 | assert test_score == 1 32 | assert actual_report is None 33 | 34 | # Check if the data splits are correct 35 | actual_X_train, actual_X_test, actual_y_train, actual_y_test = rm.get_data_splits() 36 | 37 | assert actual_X_train.shape == (4, 3) 38 | assert actual_X_test.shape == (4, 3) 39 | assert len(actual_y_train) == 4 40 | assert len(actual_y_test) == 4 41 | 42 | # Check if stratified 43 | assert np.sum(actual_y_train) == 2 44 | assert np.sum(actual_y_test) == 2 45 | 46 | # Check if index is correct 47 | assert len(rm.X.index.unique()) == 8 48 | assert list(rm.X.index) == list(rm.y.index) 49 | 50 | with pytest.raises(NotImplementedError) as _: 51 | rm.plot() 52 | 53 | 54 | def test_base_class_lin_models(X1, X2, logistic_regression, random_state): 55 | # Test class BaseResemblanceModel for linear models. 56 | rm = BaseResemblanceModel(logistic_regression, test_prc=0.5, n_jobs=1, random_state=random_state) 57 | 58 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True) 59 | 60 | assert train_score == 1 61 | assert test_score == 1 62 | assert actual_report is None 63 | 64 | # Check if the data splits are correct 65 | actual_X_train, actual_X_test, actual_y_train, actual_y_test = rm.get_data_splits() 66 | 67 | assert actual_X_train.shape == (4, 3) 68 | assert actual_X_test.shape == (4, 3) 69 | assert len(actual_y_train) == 4 70 | assert len(actual_y_test) == 4 71 | 72 | # Check if stratified 73 | assert np.sum(actual_y_train) == 2 74 | assert np.sum(actual_y_test) == 2 75 | 76 | # Check if index is correct 77 | assert len(rm.X.index.unique()) == 8 78 | assert list(rm.X.index) == list(rm.y.index) 79 | 80 | with pytest.raises(NotImplementedError) as _: 81 | rm.plot() 82 | 83 | 84 | def test_shap_resemblance_class(X1, X2, decision_tree_classifier, random_state): 85 | rm = SHAPImportanceResemblance(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state) 86 | 87 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True) 88 | 89 | assert train_score == 1 90 | assert test_score == 1 91 | 92 | # Check report shape 93 | assert actual_report.shape == (3, 2) 94 | # Check if it is sorted by importance 95 | assert actual_report.iloc[0].name == "col_1" 96 | # Check report values 97 | assert actual_report.loc["col_1"]["mean_abs_shap_value"] > 0 98 | assert actual_report.loc["col_1"]["mean_shap_value"] < 0 99 | assert actual_report.loc["col_2"]["mean_abs_shap_value"] == 0 100 | assert actual_report.loc["col_2"]["mean_shap_value"] == 0 101 | assert actual_report.loc["col_3"]["mean_abs_shap_value"] == 0 102 | assert actual_report.loc["col_3"]["mean_shap_value"] == 0 103 | 104 | actual_shap_values_test = rm.get_shap_values() 105 | assert actual_shap_values_test.shape == (4, 3) 106 | 107 | # Run plots 108 | rm.plot(plot_type="bar") 109 | rm.plot(plot_type="dot") 110 | 111 | 112 | def test_shap_resemblance_class_lin_models(X1, X2, logistic_regression, random_state): 113 | # Test SHAP Resemblance Model for linear models.
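    # Note (editorial): the extra kwargs passed to fit_compute below,
    # approximate=True and check_additivity=False, appear to be forwarded to the
    # underlying shap explainer; approximate trades exact SHAP values for speed,
    # and skipping the additivity check avoids spurious failures when the SHAP
    # sum does not exactly match the model output.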
114 | rm = SHAPImportanceResemblance(logistic_regression, test_prc=0.5, n_jobs=1, random_state=random_state) 115 | 116 | actual_report, train_score, test_score = rm.fit_compute( 117 | X1, X2, return_scores=True, approximate=True, check_additivity=False 118 | ) 119 | 120 | assert train_score == 1 121 | assert test_score == 1 122 | 123 | # Check report shape 124 | assert actual_report.shape == (3, 2) 125 | # Check if it is sorted by importance 126 | assert actual_report.iloc[0].name == "col_1" 127 | # Check report values 128 | assert actual_report.loc["col_1"]["mean_abs_shap_value"] > 0 129 | assert actual_report.loc["col_1"]["mean_shap_value"] < 0 130 | assert actual_report.loc["col_2"]["mean_abs_shap_value"] == 0 131 | assert actual_report.loc["col_2"]["mean_shap_value"] == 0 132 | assert actual_report.loc["col_3"]["mean_abs_shap_value"] == 0 133 | assert actual_report.loc["col_3"]["mean_shap_value"] == 0 134 | 135 | actual_shap_values_test = rm.get_shap_values() 136 | assert actual_shap_values_test.shape == (4, 3) 137 | 138 | # Run plots 139 | rm.plot(plot_type="bar") 140 | rm.plot(plot_type="dot") 141 | 142 | 143 | def test_shap_resemblance_class2(complex_data_with_categorical, complex_lightgbm, random_state): 144 | X1, _ = complex_data_with_categorical 145 | X2 = X1.copy() 146 | X2["f4"] = X2["f4"] + 100 147 | 148 | rm = SHAPImportanceResemblance( 149 | complex_lightgbm, scoring="accuracy", test_prc=0.5, n_jobs=1, random_state=random_state 150 | ) 151 | 152 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True, class_names=["a", "b"]) 153 | 154 | # Check if the X and y within the rm have correct types 155 | assert rm.X["f1_categorical"].dtype.name == "category" 156 | for num_column in ["f2_missing", "f3_static", "f4", "f5"]: 157 | assert is_numeric_dtype(rm.X[num_column]) 158 | 159 | assert train_score == pytest.approx(1, 0.05) 160 | assert test_score == pytest.approx(1, 0.05) 161 | 162 | # Check report shape 163 | assert actual_report.shape == (5, 2) 164 | # Check if it is sorted by importance 165 | assert actual_report.iloc[0].name == "f4" 166 | 167 | # Check report values 168 | assert actual_report.loc["f4"]["mean_abs_shap_value"] > 0 169 | 170 | actual_shap_values_test = rm.get_shap_values() 171 | # The test set has as many rows and features as X1 172 | assert actual_shap_values_test.shape == (X1.shape[0], X1.shape[1]) 173 | 174 | # Run plots 175 | rm.plot(plot_type="bar", show=True) 176 | rm.plot(plot_type="dot", show=False) 177 | 178 | 179 | def test_permutation_resemblance_class(X1, X2, decision_tree_classifier, random_state): 180 | rm = PermutationImportanceResemblance( 181 | decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state, iterations=20 182 | ) 183 | 184 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True) 185 | 186 | assert train_score == 1 187 | assert test_score == 1 188 | 189 | # Check report shape 190 | assert actual_report.shape == (3, 2) 191 | # Check if it is sorted by importance 192 | assert actual_report.iloc[0].name == "col_1" 193 | # Check report values 194 | assert actual_report.loc["col_1"]["mean_importance"] > 0 195 | assert actual_report.loc["col_1"]["std_importance"] > 0 196 | assert actual_report.loc["col_2"]["mean_importance"] == 0 197 | assert actual_report.loc["col_2"]["std_importance"] == 0 198 | assert actual_report.loc["col_3"]["mean_importance"] == 0 199 | assert actual_report.loc["col_3"]["std_importance"] == 0 200 | 201 | rm.plot(figsize=(10, 10)) 202 | # Check plot size 203 |
fig = plt.gcf() 204 | size = fig.get_size_inches() 205 | assert size[0] == 10 and size[1] == 10 206 | 207 | 208 | def test_base_class_same_data(X1, decision_tree_classifier, random_state): 209 | rm = BaseResemblanceModel(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state) 210 | 211 | actual_report, train_score, test_score = rm.fit_compute(X1, X1, return_scores=True) 212 | 213 | assert train_score == 0.5 214 | assert test_score == 0.5 215 | assert actual_report is None 216 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_base_class.py: -------------------------------------------------------------------------------- 1 | from probatus.interpret import ShapModelInterpreter 2 | import pytest 3 | from probatus.utils import NotFittedError 4 | 5 | 6 | def test_fitted_exception(fitted_tree, X_train, y_train, X_test, y_test, random_state): 7 | class_names = ["neg", "pos"] 8 | 9 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state) 10 | 11 | # Before fit it should raise an exception 12 | with pytest.raises(NotFittedError) as _: 13 | shap_interpret._check_if_fitted() 14 | 15 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names) 16 | 17 | # After fit, the flag is set and the check should pass without raising 18 | assert shap_interpret.fitted 19 | shap_interpret._check_if_fitted() 20 | 21 | 22 | @pytest.mark.xfail 23 | def test_fitted_exception_is_raised(fitted_tree, random_state): 24 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state) 25 | 26 | shap_interpret._check_if_fitted()  # raises NotFittedError because fit was never called 27 | -------------------------------------------------------------------------------- /tests/utils/test_utils_array_funcs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from probatus.utils import assure_pandas_df, preprocess_data, preprocess_labels 6 | 7 | 8 | @pytest.fixture(scope="function") 9 | def expected_df_2d(): 10 | return pd.DataFrame({0: [1, 2], 1: [2, 3], 2: [3, 4]}) 11 | 12 | 13 | @pytest.fixture(scope="function") 14 | def expected_df(): 15 | return pd.DataFrame({0: [1, 2, 3]}) 16 | 17 | 18 | def test_assure_pandas_df_list(expected_df): 19 | x = [1, 2, 3] 20 | x_df = assure_pandas_df(x) 21 | pd.testing.assert_frame_equal(x_df, expected_df) 22 | 23 | 24 | def test_assure_pandas_df_list_of_lists(expected_df_2d): 25 | x = [[1, 2, 3], [2, 3, 4]] 26 | x_df = assure_pandas_df(x) 27 | pd.testing.assert_frame_equal(x_df, expected_df_2d) 28 | 29 | 30 | def test_assure_pandas_df_series(expected_df): 31 | x = pd.Series([1, 2, 3]) 32 | x_df = assure_pandas_df(x) 33 | pd.testing.assert_frame_equal(x_df, expected_df) 34 | 35 | 36 | def test_assure_pandas_df_array(expected_df, expected_df_2d): 37 | x = np.array([[1, 2, 3], [2, 3, 4]], dtype="int64") 38 | x_df = assure_pandas_df(x) 39 | pd.testing.assert_frame_equal(x_df, expected_df_2d) 40 | 41 | x = np.array([1, 2, 3], dtype="int64") 42 | x_df = assure_pandas_df(x) 43 | pd.testing.assert_frame_equal(x_df, expected_df) 44 | 45 | 46 | def test_assure_pandas_df_df(expected_df_2d): 47 | x = pd.DataFrame([[1, 2, 3], [2, 3, 4]]) 48 | x_df = assure_pandas_df(x) 49 |
pd.testing.assert_frame_equal(x_df, expected_df_2d) 50 | 51 | 52 | def test_assure_pandas_df_types(): 53 | with pytest.raises(TypeError): 54 | assure_pandas_df("Test") 55 | with pytest.raises(TypeError): 56 | assure_pandas_df(5) 57 | 58 | 59 | def test_preprocess_labels(): 60 | y1 = pd.Series([1, 0, 1, 0, 1]) 61 | index_1 = np.array([5, 4, 3, 2, 1]) 62 | 63 | y1_output = preprocess_labels(y1, y_name="y1", index=index_1, verbose=2) 64 | pd.testing.assert_series_equal(y1_output, pd.Series([1, 0, 1, 0, 1], index=index_1)) 65 | 66 | y2 = [False, False, False, False, False] 67 | y2_output = preprocess_labels(y2, y_name="y2", verbose=2) 68 | pd.testing.assert_series_equal(y2_output, pd.Series(y2)) 69 | 70 | y3 = np.array([0, 1, 2, 3, 4]) 71 | y3_output = preprocess_labels(y3, y_name="y3", verbose=2) 72 | pd.testing.assert_series_equal(y3_output, pd.Series(y3)) 73 | 74 | y4 = pd.Series(["2", "1", "3", "2", "1"]) 75 | index4 = pd.Index([0, 2, 1, 3, 4]) 76 | y4_output = preprocess_labels(y4, y_name="y4", index=index4, verbose=0) 77 | pd.testing.assert_series_equal(y4_output, pd.Series(["2", "3", "1", "2", "1"], index=index4)) 78 | 79 | 80 | def test_preprocess_data(): 81 | X1 = pd.DataFrame({"cat": ["a", "b", "c"], "missing": [1, np.nan, 2], "num_1": [1, 2, 3]}) 82 | 83 | target_column_names_X1 = ["1", "2", "3"] 84 | X1_expected_output = pd.DataFrame({"1": ["a", "b", "c"], "2": [1, np.nan, 2], "3": [1, 2, 3]}) 85 | 86 | X1_expected_output["1"] = X1_expected_output["1"].astype("category") 87 | X1_output, output_column_names_X1 = preprocess_data(X1, X_name="X1", column_names=target_column_names_X1, verbose=2) 88 | assert target_column_names_X1 == output_column_names_X1 89 | pd.testing.assert_frame_equal(X1_output, X1_expected_output) 90 | 91 | X2 = np.array([[1, 3, 2], [1, 2, 2], [1, 2, 3]]) 92 | 93 | target_column_names_X2 = [0, 1, 2]  # with column_names=None, default integer column names are expected 94 | X2_expected_output = pd.DataFrame(X2, columns=target_column_names_X2) 95 | X2_output, output_column_names_X2 = preprocess_data(X2, X_name="X2", column_names=None, verbose=2) 96 | 97 | assert target_column_names_X2 == output_column_names_X2 98 | pd.testing.assert_frame_equal(X2_output, X2_expected_output) 99 | --------------------------------------------------------------------------------