├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── feature_request.md
│ └── other-issue.md
├── PULL_REQUEST_TEMPLATE
│ └── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
└── workflows
│ ├── cronjob_unit_tests.yml
│ ├── publish_to_pypi.yml
│ └── unit_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENCE
├── README.md
├── VISION.md
├── docs
├── api
│ ├── feature_elimination.md
│ ├── model_interpret.md
│ ├── sample_similarity.md
│ └── utils.md
├── discussion
│ └── nb_rfecv_vs_shaprfecv.ipynb
├── howto
│ ├── grouped_data.ipynb
│ └── reproducibility.ipynb
├── img
│ ├── Probatus_P.png
│ ├── Probatus_P_white.png
│ ├── earlystoppingshaprfecv.png
│ ├── logo_large.png
│ ├── logo_large_white.png
│ ├── model_interpret_dep.png
│ ├── model_interpret_importance.png
│ ├── model_interpret_sample.png
│ ├── model_interpret_summary.png
│ ├── resemblance_model_schema.png
│ ├── sample_similarity_permutation_importance.png
│ ├── sample_similarity_shap_importance.png
│ ├── sample_similarity_shap_summary.png
│ └── shaprfecv.png
├── index.md
└── tutorials
│ ├── nb_automatic_best_num_features.ipynb
│ ├── nb_custom_scoring.ipynb
│ ├── nb_sample_similarity.ipynb
│ ├── nb_shap_dependence.ipynb
│ ├── nb_shap_feature_elimination.ipynb
│ ├── nb_shap_model_interpreter.ipynb
│ └── nb_shap_variance_penalty_and_results_comparison.ipynb
├── mkdocs.yml
├── probatus
├── __init__.py
├── feature_elimination
│ ├── __init__.py
│ ├── early_stopping_feature_elimination.py
│ └── feature_elimination.py
├── interpret
│ ├── __init__.py
│ ├── model_interpret.py
│ └── shap_dependence.py
├── sample_similarity
│ ├── __init__.py
│ └── resemblance_model.py
└── utils
│ ├── __init__.py
│ ├── _utils.py
│ ├── arrayfuncs.py
│ ├── base_class_interface.py
│ ├── exceptions.py
│ ├── scoring.py
│ └── shap_helpers.py
├── pyproject.toml
└── tests
├── __init__.py
├── conftest.py
├── docs
├── __init__.py
├── test_docstring.py
└── test_notebooks.py
├── feature_elimination
└── test_feature_elimination.py
├── interpret
├── __init__.py
├── test_model_interpret.py
└── test_shap_dependence.py
├── mocks.py
├── sample_similarity
├── __init__.py
└── test_resemblance_model.py
└── utils
├── __init__.py
├── test_base_class.py
└── test_utils_array_funcs.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: bug
6 | assignees: ''
7 | ---
8 |
9 | **Describe the bug**
10 |
11 | A clear and concise description of what the bug is.
12 |
13 | **Environment (please complete the following information):**
14 |
15 | - probatus version [e.g. 1.5.0] (`pip list | grep probatus`)
16 | - python version (`python -V`)
17 | - OS: [e.g. macOS, windows, linux]
18 |
19 | **To Reproduce**
20 |
21 | Code or steps to reproduce the error
22 |
23 | ```python
24 | # Put your code here
25 | ```
26 |
27 | **Error traceback**
28 |
29 | If applicable please provide full traceback of the error.
30 |
31 | ```
32 | # Put traceback here
33 | ```
34 |
35 | **Expected behavior**
36 |
37 | A clear and concise description of what you expected to happen.
38 |
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Problem Description**
11 | A clear and concise description of what the problem is, e.g. probatus currently does not allow to use this model.
12 |
13 | **Desired Outcome**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Solution Outline**
17 | If you have an idea how the solution should look like, please describe it.
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/other-issue.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Other Issue
3 | about: Other kind of issue
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Issue Description**
11 | Please describe your issue.
12 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Pull Request
3 | about: Propose changes to the codebase
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## Description
11 | Please provide a summary of the changes you are proposing with this Pull Request. Include the motivation and context. Highlight any key changes. If this PR fixes an open issue, mention the #issue_number" in this PR.
12 |
13 | ## Type of change
14 | Please delete options that are not relevant.
15 |
16 | - [ ] Bug fix (non-breaking change which fixes an issue)
17 | - [ ] New feature (non-breaking change which adds functionality)
18 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
19 | - [ ] Documentation update
20 |
21 | ## Checklist:
22 | Go over all the following points, and put an "x" in all the boxes that apply. If you're unsure about any of these, don't hesitate to ask. We're here to help!
23 |
24 | - [ ] My code follows the code style of this project.
25 | - [ ] My change requires a change to the documentation/notebooks.
26 | - [ ] I have updated the documentation or added a notebook accordingly.
27 | - [ ] I have tested the code locally according to the **CONTRIBUTING.md** document.
28 | - [ ] I have added tests to cover my changes.
29 | - [ ] All new and existing tests passed.
30 |
31 | ## How Has This Been Tested?
32 | Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. List any relevant details for your test configuration.
33 |
34 | ## Screenshots (if applicable):
35 | Include screenshots or gifs of the changes you have made. This is especially useful for UI changes.
36 |
37 | ## Additional Context:
38 | Add any other context or screenshots about the pull request here. This could include reasons for changes, decisions made, or anything you want the reviewer to know.
39 |
40 | ---
41 |
42 | **For more information on contributing, please see the [CONTRIBUTING.md](../../CONTRIBUTING.md) document.**
43 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "pip" # See documentation for possible values
9 | directory: "/" # pyproject.toml
10 | schedule:
11 | interval: "weekly"
12 |
--------------------------------------------------------------------------------
/.github/workflows/cronjob_unit_tests.yml:
--------------------------------------------------------------------------------
1 | name: Cron Test Dependencies
2 |
3 | # Controls when the action will run.
4 | # Every sunday at 4:05
5 | # See https://crontab.guru/#5 4 * * 0
6 | on:
7 | schedule:
8 | - cron: "5 4 * * 0"
9 |
10 | jobs:
11 | run:
12 | name: Run unit tests
13 | runs-on: ${{ matrix.os }}
14 | strategy:
15 | matrix:
16 | build: [macos, ubuntu, windows]
17 | include:
18 | - build: macos
19 | os: macos-latest
20 | - build: ubuntu
21 | os: ubuntu-latest
22 | - build: windows
23 | os: windows-latest
24 | python-version: [3.9, "3.10", "3.11", "3.12"]
25 | steps:
26 | - uses: actions/checkout@master
27 |
28 | - name: Get latest CMake and Ninja
29 | uses: lukka/get-cmake@latest
30 | with:
31 | cmakeVersion: latest
32 | ninjaVersion: latest
33 |
34 | - name: Install LIBOMP on Macos runners
35 | if: runner.os == 'macOS'
36 | run: |
37 | brew install libomp
38 |
39 | - name: Setup Python
40 | uses: actions/setup-python@master
41 | with:
42 | python-version: ${{ matrix.python-version }}
43 |
44 | - name: Install Python dependencies
45 | run: |
46 | pip3 install --upgrade setuptools pip
47 | pip3 install ".[all]"
48 |
49 | - name: Run linters
50 | run: |
51 | pre-commit install
52 | pre-commit run --all-files
53 |
54 | - name: Run (unit) tests
55 | env:
56 | TEST_NOTEBOOKS: 0
57 | run: |
58 | pytest --cov=probatus/binning --cov=probatus/metric_volatility --cov=probatus/missing_values --cov=probatus/sample_similarity --cov=probatus/stat_tests --cov=probatus/utils --cov=probatus/interpret/ --ignore==tests/interpret/test_inspector.py --cov-report=xml
59 | pyflakes probatus
60 |
--------------------------------------------------------------------------------
/.github/workflows/publish_to_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@master
12 | - name: Set up Python
13 | uses: actions/setup-python@master
14 | with:
15 | python-version: "3.10"
16 | - name: Install dependencies
17 | run: |
18 | pip3 install --upgrade setuptools pip
19 | pip3 install ".[all]"
20 | - name: Run unit tests and linters
21 | run: |
22 | pytest
23 | - name: Build package & publish to PyPi
24 | env:
25 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
26 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
27 | run: |
28 | pip3 install --upgrade wheel twine build
29 | python -m build
30 | twine upload dist/*
31 | - name: Deploy mkdocs site
32 | run: |
33 | mkdocs gh-deploy --force
34 |
--------------------------------------------------------------------------------
/.github/workflows/unit_tests.yml:
--------------------------------------------------------------------------------
1 | name: Development
2 | on:
3 | # Trigger the workflow on push or pull request,
4 | # but only for the main branch
5 | push:
6 | branches:
7 | - main
8 | pull_request:
9 | jobs:
10 | run:
11 | name: Run unit tests
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | build: [macos, ubuntu, windows]
16 | include:
17 | - build: macos
18 | os: macos-latest
19 | - build: ubuntu
20 | os: ubuntu-latest
21 | - build: windows
22 | os: windows-latest
23 | python-version: [3.9, "3.10", "3.11", "3.12"]
24 | steps:
25 | - uses: actions/checkout@master
26 |
27 | - name: Get latest CMake and Ninja
28 | uses: lukka/get-cmake@latest
29 | with:
30 | cmakeVersion: latest
31 | ninjaVersion: latest
32 |
33 | - name: Install LIBOMP on Macos runners
34 | if: runner.os == 'macOS'
35 | run: |
36 | brew install libomp
37 |
38 | - name: Setup Python
39 | uses: actions/setup-python@master
40 | with:
41 | python-version: ${{ matrix.python-version }}
42 |
43 | - name: Install Python dependencies
44 | run: |
45 | pip3 install --upgrade setuptools pip
46 | pip3 install ".[all]"
47 |
48 | - name: Run linters
49 | run: |
50 | pre-commit install
51 | pre-commit run --all-files
52 |
53 | - name: Run (unit) tests
54 | env:
55 | TEST_NOTEBOOKS: 0
56 | run: |
57 | pytest --cov=probatus/binning --cov=probatus/metric_volatility --cov=probatus/missing_values --cov=probatus/sample_similarity --cov=probatus/stat_tests --cov=probatus/utils --cov=probatus/interpret/ --ignore==tests/interpret/test_inspector.py --cov-report=xml
58 | pyflakes probatus
59 |
60 | - name: Upload coverage to Codecov
61 | if: github.ref == 'refs/heads/main'
62 | uses: codecov/codecov-action@v1
63 | with:
64 | token: ${{ secrets.CODECOV_TOKEN }}
65 | file: ./coverage.xml
66 | flags: unittests
67 | fail_ci_if_error: false
68 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | notebooks/private
2 | notebooks/explore
3 | .vscode/
4 | .idea/
5 | docs/source/.ipynb_checkpoints/
6 | notebooks/.ipynb_checkpoints/
7 |
8 |
9 | # Created by https://www.gitignore.io/api/macos,python,pycharm,jupyternotebooks,visualstudiocode
10 | # Edit at https://www.gitignore.io/?templates=macos,python,pycharm,jupyternotebooks,visualstudiocode
11 |
12 | ### JupyterNotebooks ###
13 | # gitignore template for Jupyter Notebooks
14 | # website: http://jupyter.org/
15 |
16 | .ipynb_checkpoints
17 | */.ipynb_checkpoints/*
18 |
19 | # IPython
20 | profile_default/
21 | ipython_config.py
22 |
23 | # Remove previous ipynb_checkpoints
24 | # git rm -r .ipynb_checkpoints/
25 |
26 | ### macOS ###
27 | # General
28 | .DS_Store
29 | .AppleDouble
30 | .LSOverride
31 |
32 | # Icon must end with two \r
33 | Icon
34 |
35 | # Thumbnails
36 | ._*
37 |
38 | # Files that might appear in the root of a volume
39 | .DocumentRevisions-V100
40 | .fseventsd
41 | .Spotlight-V100
42 | .TemporaryItems
43 | .Trashes
44 | .VolumeIcon.icns
45 | .com.apple.timemachine.donotpresent
46 |
47 | # Directories potentially created on remote AFP share
48 | .AppleDB
49 | .AppleDesktop
50 | Network Trash Folder
51 | Temporary Items
52 | .apdisk
53 |
54 | ### PyCharm ###
55 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
56 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
57 |
58 | # User-specific stuff
59 | .idea/**/workspace.xml
60 | .idea/**/tasks.xml
61 | .idea/**/usage.statistics.xml
62 | .idea/**/dictionaries
63 | .idea/**/shelf
64 |
65 | # Generated files
66 | .idea/**/contentModel.xml
67 |
68 | # Sensitive or high-churn files
69 | .idea/**/dataSources/
70 | .idea/**/dataSources.ids
71 | .idea/**/dataSources.local.xml
72 | .idea/**/sqlDataSources.xml
73 | .idea/**/dynamic.xml
74 | .idea/**/uiDesigner.xml
75 | .idea/**/dbnavigator.xml
76 |
77 | # Gradle
78 | .idea/**/gradle.xml
79 | .idea/**/libraries
80 |
81 | # Gradle and Maven with auto-import
82 | # When using Gradle or Maven with auto-import, you should exclude module files,
83 | # since they will be recreated, and may cause churn. Uncomment if using
84 | # auto-import.
85 | # .idea/modules.xml
86 | # .idea/*.iml
87 | # .idea/modules
88 | # *.iml
89 | # *.ipr
90 |
91 | # CMake
92 | cmake-build-*/
93 |
94 | # Mongo Explorer plugin
95 | .idea/**/mongoSettings.xml
96 |
97 | # File-based project format
98 | *.iws
99 |
100 | # IntelliJ
101 | out/
102 |
103 | # mpeltonen/sbt-idea plugin
104 | .idea_modules/
105 |
106 | # JIRA plugin
107 | atlassian-ide-plugin.xml
108 |
109 | # Cursive Clojure plugin
110 | .idea/replstate.xml
111 |
112 | # Crashlytics plugin (for Android Studio and IntelliJ)
113 | com_crashlytics_export_strings.xml
114 | crashlytics.properties
115 | crashlytics-build.properties
116 | fabric.properties
117 |
118 | # Editor-based Rest Client
119 | .idea/httpRequests
120 |
121 | # Android studio 3.1+ serialized cache file
122 | .idea/caches/build_file_checksums.ser
123 |
124 | # Add after merge review issue #79
125 | .vscode/
126 |
127 | ### PyCharm Patch ###
128 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
129 |
130 | # *.iml
131 | # modules.xml
132 | # .idea/misc.xml
133 | # *.ipr
134 |
135 | # Sonarlint plugin
136 | .idea/**/sonarlint/
137 |
138 | # SonarQube Plugin
139 | .idea/**/sonarIssues.xml
140 |
141 | # Markdown Navigator plugin
142 | .idea/**/markdown-navigator.xml
143 | .idea/**/markdown-navigator/
144 |
145 | ### Python ###
146 | # Byte-compiled / optimized / DLL files
147 | __pycache__/
148 | *.py[cod]
149 | *$py.class
150 |
151 | # C extensions
152 | *.so
153 |
154 | # Distribution / packaging
155 | .Python
156 | build/
157 | develop-eggs/
158 | dist/
159 | downloads/
160 | eggs/
161 | .eggs/
162 | lib/
163 | lib64/
164 | parts/
165 | sdist/
166 | var/
167 | wheels/
168 | pip-wheel-metadata/
169 | share/python-wheels/
170 | *.egg-info/
171 | .installed.cfg
172 | *.egg
173 | MANIFEST
174 |
175 | # PyInstaller
176 | # Usually these files are written by a python script from a template
177 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
178 | *.manifest
179 | *.spec
180 |
181 | # Installer logs
182 | pip-log.txt
183 | pip-delete-this-directory.txt
184 |
185 | # Unit test / coverage reports
186 | htmlcov/
187 | .tox/
188 | .nox/
189 | .coverage
190 | .coverage.*
191 | .cache
192 | nosetests.xml
193 | coverage.xml
194 | *.cover
195 | .hypothesis/
196 | .pytest_cache/
197 |
198 | # Translations
199 | *.mo
200 | *.pot
201 |
202 | # Scrapy stuff:
203 | .scrapy
204 |
205 | # Sphinx documentation
206 | docs/_build/
207 |
208 | # PyBuilder
209 | target/
210 |
211 | # pyenv
212 | .python-version
213 |
214 | # pipenv
215 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
216 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
217 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
218 | # install all needed dependencies.
219 | #Pipfile.lock
220 |
221 | # virtualenv
222 | env/
223 | venv/
224 |
225 | # celery beat schedule file
226 | celerybeat-schedule
227 |
228 | # SageMath parsed files
229 | *.sage.py
230 |
231 | # Spyder project settings
232 | .spyderproject
233 | .spyproject
234 |
235 | # Rope project settings
236 | .ropeproject
237 |
238 | # Mr Developer
239 | .mr.developer.cfg
240 | .project
241 | .pydevproject
242 |
243 | # mkdocs documentation
244 | /site
245 |
246 | # mypy
247 | .mypy_cache/
248 | .dmypy.json
249 | dmypy.json
250 |
251 | # Pyre type checker
252 | .pyre/
253 |
254 | ### VisualStudioCode ###
255 | .vscode/*
256 | !.vscode/settings.json
257 | !.vscode/tasks.json
258 | !.vscode/launch.json
259 | !.vscode/extensions.json
260 |
261 | ### VisualStudioCode Patch ###
262 | # Ignore all local history of files
263 | .history
264 |
265 | # End of https://www.gitignore.io/api/macos,python,pycharm,jupyternotebooks,visualstudiocode
266 |
267 | # Catboost-related files
268 | catboost*
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v3.2.0
4 | hooks:
5 | - id: check-case-conflict # Different OSes
6 | name: "Check case conflict: Naming of files is compatible with all OSes"
7 | - id: check-docstring-first
8 | name: "Check docstring first: Ensures Docstring present and first"
9 | - id: detect-private-key
10 | name: "Detect private key: Prevent commit of env related keys"
11 | - id: trailing-whitespace
12 | name: "Trailing whitespace: Remove empty spaces"
13 | - repo: https://github.com/nbQA-dev/nbQA
14 | rev: 1.8.5
15 | hooks:
16 | - id: nbqa-ruff
17 | name: "ruff nb: Check for errors, styling issues and complexity"
18 | - id: nbqa-mypy
19 | name: "mypy nb: Static type checking"
20 | - id: nbqa-isort
21 | name: "isort nb: Sort file imports"
22 | - id: nbqa-pyupgrade
23 | name: "pyupgrade nb: Updates code to Python 3.9+ code convention"
24 | args: [&py_version --py38-plus]
25 | - id: nbqa-black
26 | name: "black nb: PEP8 compliant code formatter"
27 | - repo: local
28 | hooks:
29 | - id: mypy
30 | name: "mypy: Static type checking"
31 | entry: mypy
32 | language: system
33 | types: [python]
34 | - repo: local
35 | hooks:
36 | - id: ruff-check
37 | name: "Ruff: Check for errors, styling issues and complexity, and fixes issues if possible (including import order)"
38 | entry: ruff check
39 | language: system
40 | args: [--fix, --no-cache]
41 | - id: ruff-format
42 | name: "Ruff: format code in line with PEP8"
43 | entry: ruff format
44 | language: system
45 | args: [--no-cache]
46 | - repo: local
47 | hooks:
48 | - id: codespell
49 | name: "codespell: Check for grammar"
50 | entry: codespell
51 | language: system
52 | types: [python]
53 | args: [-L mot] # Skip the word "mot"
54 | - repo: https://github.com/asottile/pyupgrade
55 | rev: v3.4.0
56 | hooks:
57 | - id: pyupgrade
58 | name: "pyupgrade: Updates code to Python 3.9+ code convention"
59 | args: [*py_version]
60 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guide
2 |
3 | `Probatus` aims to provide a set of tools that can speed up common workflows around validating regressors & classifiers and the data used to train them.
4 | We're very much open to contributions but there are some things to keep in mind:
5 |
6 | - Discuss the feature and implementation you want to add on Github before you write a PR for it. On disagreements, maintainer(s) will have the final word.
7 | - Features need a somewhat general use case. If the use case is very niche it will be hard for us to consider maintaining it.
8 | - If you’re going to add a feature, consider if you could help out in the maintenance of it.
9 | - When issues or pull requests are not going to be resolved or merged, they should be closed as soon as possible. This is kinder than deciding this after a long period. Our issue tracker should reflect work to be done.
10 |
11 | That said, there are many ways to contribute to Probatus, including:
12 |
13 | - Contribution to code
14 | - Improving the documentation
15 | - Reviewing merge requests
16 | - Investigating bugs
17 | - Reporting issues
18 |
19 | Starting out with open source? See the guide [How to Contribute to Open Source](https://opensource.guide/how-to-contribute/) and have a look at [our issues labelled *good first issue*](https://github.com/ing-bank/probatus/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22).
20 |
21 | ## Setup
22 |
23 | Development install:
24 |
25 | ```shell
26 | pip install -e '.[all]'
27 | ```
28 |
29 | Unit testing:
30 |
31 | ```shell
32 | pytest
33 | ```
34 |
35 | We use [pre-commit](https://pre-commit.com/) hooks to ensure code styling. Install with:
36 |
37 | ```shell
38 | pre-commit install
39 | ```
40 |
41 | Now if you install it (which you are encouraged to do), you are encouraged to do the following command before committing your work:
42 |
43 | ```shell
44 | pre-commit run --all-files
45 | ```
46 |
47 | This will allow you to quickly see if the work you made contains some adaptions that you still might need to make before a pull request is accepted.
48 |
49 | ## Standards
50 |
51 | - Python 3.9+
52 | - Follow [PEP8](http://pep8.org/) as closely as possible (except line length)
53 | - [google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/)
54 | - Git: Include a short description of *what* and *why* was done, *how* can be seen in the code. Use present tense, imperative mood
55 | - Git: limit the length of the first line to 72 chars. You can use multiple messages to specify a second (longer) line: `git commit -m "Patch load function" -m "This is a much longer explanation of what was done"`
56 |
57 |
58 | ### Code structure
59 |
60 | * Model validation modules assume that trained models passed for validation are developed in a scikit-learn framework (i.e. have predict_proba and other standard functions), or follow a scikit-learn API e.g. XGBoost.
61 | * Every python file used for model validation needs to be in `/probatus/`
62 | * Class structure for a given module should have a base class and specific functionality classes that inherit from base. If a given module implements only a single way of computing the output, the base class is not required.
63 | * Functions should not be as short as possible in terms of lines of code. If a lot of code is needed, try to put together snippets of code into other functions. This make the code more readable, and easier to test.
64 | * Classes follow the probatus API structure:
65 | * Each class implements `fit()`, `compute()` and `fit_compute()` methods. `fit()` is used to fit an object with provided data (unless no fit is required), and `compute()` calculates the output e.g. DataFrame with a report for the user. Lastly, `fit_compute()` applies one after the other.
66 | * If applicable, the `plot()` method presents the user with the appropriate graphs.
67 | * For `compute()` and `plot()`, check if the object is fitted first.
68 |
69 |
70 | ### Documentation
71 |
72 | Documentation is a very crucial part of the project because it ensures usability of the package. We develop the docs in the following way:
73 |
74 | * We use [mkdocs](https://www.mkdocs.org/) with [mkdocs-material](https://squidfunk.github.io/mkdocs-material/) theme. The `docs/` folder contains all the relevant documentation.
75 | * We use `mkdocs serve` to view the documentation locally. Use it to test the documentation everytime you make any changes.
76 | * Maintainers can deploy the docs using `mkdocs gh-deploy`. The documentation is deployed to `https://ing-bank.github.io/probatus/`.
77 |
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | Copyright (c) ING Bank N.V.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | [](https://github.com/ing-bank/probatus/actions?query=workflow%3A%22Development%22)
4 | [](#)
5 | [](#)
6 | [](#)
7 | 
8 |
9 | # Probatus
10 |
11 | ## Overview
12 |
13 | **Probatus** is a python package that helps validate regression & (multiclass) classification models and the data used to develop them. Main features:
14 |
15 | - [probatus.interpret](https://ing-bank.github.io/probatus/api/model_interpret.html) provides shap-based model interpretation tools
16 | - [probatus.sample_similarity](https://ing-bank.github.io/probatus/api/sample_similarity.html) to compare two datasets using resemblance modelling, f.e. `train` with out-of-time `test`.
17 | - [probatus.feature_elimination.ShapRFECV](https://ing-bank.github.io/probatus/api/feature_elimination.html) provides cross-validated Recursive Feature Elimination using shap feature importance.
18 |
19 | ## Installation
20 |
21 | ```bash
22 | pip install probatus
23 | ```
24 |
25 | ## Documentation
26 |
27 | Documentation at [ing-bank.github.io/probatus/](https://ing-bank.github.io/probatus/).
28 |
29 | You can also check out blog posts about Probatus:
30 |
31 | - [Open-sourcing ShapRFECV — Improved feature selection powered by SHAP.](https://medium.com/ing-blog/open-sourcing-shaprfecv-improved-feature-selection-powered-by-shap-994fe7861560)
32 | - [Model Explainability — How to choose the right tool?](https://medium.com/ing-blog/model-explainability-how-to-choose-the-right-tool-6c5eabd1a46a)
33 |
34 | ## Contributing
35 |
36 | To learn more about making a contribution to Probatus, please see [`CONTRIBUTING.md`](CONTRIBUTING.md).
37 |
--------------------------------------------------------------------------------
/VISION.md:
--------------------------------------------------------------------------------
1 | # The Vision
2 |
3 | This page describes the main principles that drive the development of `Probatus` as well as the general directions, in which the development of the package will be heading.
4 |
5 | ## The Purpose
6 |
7 | `Probatus` has started as a side project of Data Scientists at ING Bank.
8 | Later, we have decided to open-source it, in order to share the tools and enable collaboration with the Data Science community.
9 |
10 | We mainly focus on analyzing the following aspects of building classification models:
11 | - Model input: the quality of the dataset and how to prepare it for modelling,
12 | - Model performance: the quality of the model and stability of the results.
13 | - Model interpretation: understanding the model decision making,
14 |
15 | Our main goals are:
16 | - Continue maintaining the tools that we have built, and make sure that they are well documented and tested
17 | - Continuously extend functionality available in the package
18 | - Build a community of users, which use the package in day-to-day work and learn from each other, while contributing to Probatus
19 |
20 | ## The Principles
21 |
22 | The main principles that drive development of `Probatus` are the following
23 |
24 | - Usefulness - any tool that we build should be useful for a broad range of users,
25 | - Simplicity - simple to understand and analyze steps over state-of-the-art,
26 | - Usability - the developed functionality must be have good documentation, consistent API and work flawlessly with scikit-learn compatible models,
27 | - Reliability - the code that is available for the users should be well tested and reliable, and bugs should be fixed as soon as they are detected.
28 |
29 | ## The Roadmap
30 |
31 | We are open to new ideas, so if you can think of a feature that fits the vision, make an [issue](https://github.com/ing-bank/Probatus/issues) and help us further develop this package.
--------------------------------------------------------------------------------
/docs/api/feature_elimination.md:
--------------------------------------------------------------------------------
1 | # Features Elimination
2 |
3 | This module focuses on feature elimination and it contains two classes:
4 |
5 | - [ShapRFECV][probatus.feature_elimination.feature_elimination.ShapRFECV]: Perform Backwards Recursive Feature Elimination, using SHAP feature importance. It supports binary classification, regression models and hyperparameter optimization at every feature elimination step. Also for LGBM, XGBoost and CatBoost it support early stopping of the model fitting process. It can be an alternative regularization technique to hyperparameter optimization of the number of base trees in gradient boosted tree models. Particularly useful when dealing with large datasets.
6 |
7 | ::: probatus.feature_elimination.feature_elimination
8 |
--------------------------------------------------------------------------------
/docs/api/model_interpret.md:
--------------------------------------------------------------------------------
1 | # Model Interpretation using SHAP
2 |
3 | The aim of this module is to provide tools for model interpretation using the [SHAP](https://shap.readthedocs.io/en/latest/) library.
4 | The class below is a convenience wrapper that implements multiple plots for tree-based & linear models.
5 |
6 | ::: probatus.interpret.model_interpret
7 | ::: probatus.interpret.shap_dependence
8 |
--------------------------------------------------------------------------------
/docs/api/sample_similarity.md:
--------------------------------------------------------------------------------
1 | # Sample Similarity
2 |
3 | The goal of sample similarity module is understanding how different two samples are from a multivariate perspective.
4 |
5 | One of the ways to indicate this is Resemblance Model. Having two datasets - say X1 and X2 - one can analyse how easy it is to recognize which dataset a randomly selected row comes from. The Resemblance model assigns label 0 to the dataset X1, and label 1 to X2 and trains a binary classification model to predict which sample a given row comes from.
6 | By looking at the test AUC, one can conclude that the samples have a different distribution if the AUC is significantly higher than 0.5. Furthermore, by analysing feature importance one can understand which of the features have predictive power.
7 |
8 |
9 |
10 |
11 | The following features are implemented:
12 |
13 | - [SHAPImportanceResemblance (Recommended)][probatus.sample_similarity.resemblance_model.SHAPImportanceResemblance]:
14 | The class applies SHAP library, in order to interpret the tree based resemblance model.
15 | - [PermutationImportanceResemblance][probatus.sample_similarity.resemblance_model.PermutationImportanceResemblance]:
16 | The class applies permutation feature importance in order to understand which features the current model relies on the most. The higher the importance of the feature, the more a given feature possibly differs in X2 compared to X1. The importance indicates how much the test AUC drops if a given feature is permuted.
17 |
18 |
19 | ::: probatus.sample_similarity.resemblance_model
20 |
21 |
--------------------------------------------------------------------------------
/docs/api/utils.md:
--------------------------------------------------------------------------------
1 | # Utility Functions
2 |
3 | This module contains various smaller functionalities that can be used across the `probatus` package.
4 |
5 | ::: probatus.utils.scoring
6 |
--------------------------------------------------------------------------------
/docs/howto/grouped_data.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# How to work with grouped data\n",
8 | "\n",
9 | "[](https://colab.research.google.com/github/ing-bank/probatus/blob/master/docs/howto/grouped_data.ipynb)\n",
10 | "\n",
11 | "One of the often appearing properties of the Data Science problems is the natural grouping of the data. You could for instance have multiple samples for the same customer. In such case, you need to make sure that all samples from a given group are in the same fold e.g. in Cross-Validation.\n",
12 | "\n",
13 | "Let's prepare a dataset with groups."
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "%%capture\n",
23 | "!pip install probatus"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 1,
29 | "metadata": {},
30 | "outputs": [
31 | {
32 | "data": {
33 | "text/plain": [
34 | "[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]"
35 | ]
36 | },
37 | "execution_count": 1,
38 | "metadata": {},
39 | "output_type": "execute_result"
40 | }
41 | ],
42 | "source": [
43 | "from sklearn.datasets import make_classification\n",
44 | "\n",
45 | "X, y = make_classification(n_samples=100, n_features=10, random_state=42)\n",
46 | "groups = [i % 5 for i in range(100)]\n",
47 | "groups[:10]"
48 | ]
49 | },
50 | {
51 | "cell_type": "markdown",
52 | "metadata": {},
53 | "source": [
54 | "The integers in `groups` variable indicate the group id, to which a given sample belongs.\n",
55 | "\n",
56 | "One of the easiest ways to ensure that the data is split using the information about groups is using `from sklearn.model_selection import GroupKFold`. You can also read more about other ways of splitting data with groups in sklearn [here](https://scikit-learn.org/stable/modules/cross_validation.html#cross-validation-iterators-for-grouped-data)."
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": 2,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from sklearn.model_selection import GroupKFold\n",
66 | "\n",
67 | "cv = list(GroupKFold(n_splits=5).split(X, y, groups=groups))"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "Such variable can be passed to the `cv` parameter in `probatus` as well as to hyperparameter optimization e.g. `RandomizedSearchCV` classes."
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 3,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "from sklearn.ensemble import RandomForestClassifier\n",
84 | "from sklearn.model_selection import RandomizedSearchCV\n",
85 | "\n",
86 | "from probatus.feature_elimination import ShapRFECV\n",
87 | "\n",
88 | "model = RandomForestClassifier(random_state=42)\n",
89 | "\n",
90 | "param_grid = {\n",
91 | " \"n_estimators\": [5, 7, 10],\n",
92 | " \"max_leaf_nodes\": [3, 5, 7, 10],\n",
93 | "}\n",
94 | "search = RandomizedSearchCV(model, param_grid, cv=cv, n_iter=1, random_state=42)\n",
95 | "\n",
96 | "shap_elimination = ShapRFECV(model=search, step=0.2, cv=cv, scoring=\"roc_auc\", n_jobs=3, random_state=42)\n",
97 | "report = shap_elimination.fit_compute(X, y)"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 4,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "text/html": [
108 | "
\n",
109 | "\n",
122 | "
\n",
123 | " \n",
124 | " \n",
125 | " | \n",
126 | " num_features | \n",
127 | " features_set | \n",
128 | " eliminated_features | \n",
129 | " train_metric_mean | \n",
130 | " train_metric_std | \n",
131 | " val_metric_mean | \n",
132 | " val_metric_std | \n",
133 | "
\n",
134 | " \n",
135 | " \n",
136 | " \n",
137 | " 1 | \n",
138 | " 10 | \n",
139 | " [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | \n",
140 | " [6, 7] | \n",
141 | " 0.999562 | \n",
142 | " 0.000876 | \n",
143 | " 0.954945 | \n",
144 | " 0.090110 | \n",
145 | "
\n",
146 | " \n",
147 | " 2 | \n",
148 | " 8 | \n",
149 | " [0, 1, 2, 3, 4, 5, 8, 9] | \n",
150 | " [5] | \n",
151 | " 0.999118 | \n",
152 | " 0.001081 | \n",
153 | " 0.945513 | \n",
154 | " 0.089606 | \n",
155 | "
\n",
156 | " \n",
157 | " 3 | \n",
158 | " 7 | \n",
159 | " [0, 1, 2, 3, 4, 8, 9] | \n",
160 | " [4] | \n",
161 | " 0.999559 | \n",
162 | " 0.000548 | \n",
163 | " 0.928749 | \n",
164 | " 0.137507 | \n",
165 | "
\n",
166 | " \n",
167 | " 4 | \n",
168 | " 6 | \n",
169 | " [0, 1, 2, 3, 8, 9] | \n",
170 | " [8] | \n",
171 | " 0.999179 | \n",
172 | " 0.001051 | \n",
173 | " 0.969288 | \n",
174 | " 0.058854 | \n",
175 | "
\n",
176 | " \n",
177 | " 5 | \n",
178 | " 5 | \n",
179 | " [0, 1, 2, 3, 9] | \n",
180 | " [9] | \n",
181 | " 0.999748 | \n",
182 | " 0.000237 | \n",
183 | " 0.961767 | \n",
184 | " 0.066540 | \n",
185 | "
\n",
186 | " \n",
187 | " 6 | \n",
188 | " 4 | \n",
189 | " [0, 1, 2, 3] | \n",
190 | " [1] | \n",
191 | " 0.999433 | \n",
192 | " 0.000700 | \n",
193 | " 0.950816 | \n",
194 | " 0.090982 | \n",
195 | "
\n",
196 | " \n",
197 | " 7 | \n",
198 | " 3 | \n",
199 | " [0, 2, 3] | \n",
200 | " [0] | \n",
201 | " 0.999120 | \n",
202 | " 0.000729 | \n",
203 | " 0.970596 | \n",
204 | " 0.051567 | \n",
205 | "
\n",
206 | " \n",
207 | " 8 | \n",
208 | " 2 | \n",
209 | " [2, 3] | \n",
210 | " [3] | \n",
211 | " 0.999496 | \n",
212 | " 0.000617 | \n",
213 | " 0.938639 | \n",
214 | " 0.117736 | \n",
215 | "
\n",
216 | " \n",
217 | " 9 | \n",
218 | " 1 | \n",
219 | " [2] | \n",
220 | " [] | \n",
221 | " 0.998424 | \n",
222 | " 0.001819 | \n",
223 | " 0.938339 | \n",
224 | " 0.097936 | \n",
225 | "
\n",
226 | " \n",
227 | "
\n",
228 | "
"
229 | ],
230 | "text/plain": [
231 | " num_features features_set eliminated_features \\\n",
232 | "1 10 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [6, 7] \n",
233 | "2 8 [0, 1, 2, 3, 4, 5, 8, 9] [5] \n",
234 | "3 7 [0, 1, 2, 3, 4, 8, 9] [4] \n",
235 | "4 6 [0, 1, 2, 3, 8, 9] [8] \n",
236 | "5 5 [0, 1, 2, 3, 9] [9] \n",
237 | "6 4 [0, 1, 2, 3] [1] \n",
238 | "7 3 [0, 2, 3] [0] \n",
239 | "8 2 [2, 3] [3] \n",
240 | "9 1 [2] [] \n",
241 | "\n",
242 | " train_metric_mean train_metric_std val_metric_mean val_metric_std \n",
243 | "1 0.999562 0.000876 0.954945 0.090110 \n",
244 | "2 0.999118 0.001081 0.945513 0.089606 \n",
245 | "3 0.999559 0.000548 0.928749 0.137507 \n",
246 | "4 0.999179 0.001051 0.969288 0.058854 \n",
247 | "5 0.999748 0.000237 0.961767 0.066540 \n",
248 | "6 0.999433 0.000700 0.950816 0.090982 \n",
249 | "7 0.999120 0.000729 0.970596 0.051567 \n",
250 | "8 0.999496 0.000617 0.938639 0.117736 \n",
251 | "9 0.998424 0.001819 0.938339 0.097936 "
252 | ]
253 | },
254 | "execution_count": 4,
255 | "metadata": {},
256 | "output_type": "execute_result"
257 | }
258 | ],
259 | "source": [
260 | "report"
261 | ]
262 | }
263 | ],
264 | "metadata": {
265 | "kernelspec": {
266 | "display_name": "Python 3",
267 | "language": "python",
268 | "name": "python3"
269 | },
270 | "language_info": {
271 | "codemirror_mode": {
272 | "name": "ipython",
273 | "version": 3
274 | },
275 | "file_extension": ".py",
276 | "mimetype": "text/x-python",
277 | "name": "python",
278 | "nbconvert_exporter": "python",
279 | "pygments_lexer": "ipython3",
280 | "version": "3.10.13"
281 | }
282 | },
283 | "nbformat": 4,
284 | "nbformat_minor": 4
285 | }
286 |
--------------------------------------------------------------------------------
/docs/howto/reproducibility.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# How to ensure reproducibility of the results"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "[](https://colab.research.google.com/github/ing-bank/probatus/blob/master/docs/howto/reproducibility.ipynb)\n",
15 | "\n",
16 | "This page describes how to make sure that the analysis that you perform using `probatus` is fully reproducible.\n",
17 | "\n",
18 | "There are two factors that influence reproducibility of the results:\n",
19 | "\n",
20 | "- Inputs of `probatus` modules,\n",
21 | "- The `random_state` of `probatus` modules.\n",
22 | "\n",
23 | "The below sections cover how to ensure reproducibility of the results by controling these aspects.\n",
24 | "\n",
25 | "## Inputs of probatus modules\n",
26 | "\n",
27 | "There are various parameters that modules of probatus take as input. Below we will cover the most often occurring ones.\n",
28 | "\n",
29 | "### Static dataset\n",
30 | "\n",
31 | "When using `probatus`, one of the most crucial aspects is the provided dataset. Therefore, the first thing to do is to ensure that the passed dataset does not change along the way. \n",
32 | "\n",
33 | "Below is a code snipped of random data preparation. In sklearn, you can ensure this by setting the `random_state` parameter. You will probably use a different dataset in your projects, but always make sure that the input data is static."
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "%%capture\n",
43 | "!pip install probatus"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 1,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "from sklearn.datasets import make_classification\n",
53 | "\n",
54 | "X, y = make_classification(n_samples=100, n_features=10, random_state=42)"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "### Static data splits\n",
62 | "\n",
63 | "Whenever you split the data in any way, you need to make sure that the splits are always the same. \n",
64 | "\n",
65 | "If you use the `train_test_split` functionality from sklearn, this can be enforced by setting the `random_state` parameter. \n",
66 | "\n",
67 | "Another crucial aspect is how you use the `cv` parameter, which defines the folds settings that you will use in the experiments. If the `cv` is set to an integer, you don't need to worry about it - the `random_state` of `probatus` will take care of it. However, if you want to pass a custom cv generator object, you have to set the `random_state` there as well.\n",
68 | "\n",
69 | "Below are some examples of static splits:"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 2,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "from sklearn.model_selection import StratifiedKFold, train_test_split\n",
79 | "\n",
80 | "# Static train/test split\n",
81 | "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n",
82 | "\n",
83 | "# Static CV settings\n",
84 | "cv1 = 5\n",
85 | "cv2 = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "### Static classifier\n",
93 | "\n",
94 | "Most of `probatus` modules work with the provided classifiers. Whenever one needs to provide a not-fitted classifier, it is enough to set the `random_state`. However, if the classifier needs to be fitted beforehand, you have to make sure that the model training is reproducible as well."
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 3,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "from sklearn.ensemble import RandomForestClassifier\n",
104 | "\n",
105 | "model = RandomForestClassifier(random_state=42)"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "### Static search CV for hyperparameter tuning\n",
113 | "\n",
114 | "Some of the modules e.g. `ShapRFECV`, allow you to perform optimization of the model. Whenever, you use such functionality, make sure that these classes have set the `random_state`. This way, in every round of optimization, you will explore the same set of parameter permutations. In case the search space is also generated based on randomness, make sure that the `random_state` is set to it as well."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 4,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "from sklearn.model_selection import RandomizedSearchCV\n",
124 | "\n",
125 | "param_grid = {\n",
126 | " \"n_estimators\": [5, 7, 10],\n",
127 | " \"max_leaf_nodes\": [3, 5, 7, 10],\n",
128 | "}\n",
129 | "search = RandomizedSearchCV(model, param_grid, n_iter=1, random_state=42)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "### Any other sources of randomness"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "Before running `probatus` modules think about the inputs, and consider if there is any other type of randomness involved. If there is, one option to possibly solve the issue is setting the random seed at the beginning of the code."
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 5,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "# Optional step\n",
153 | "import numpy as np\n",
154 | "\n",
155 | "np.random.seed(42)"
156 | ]
157 | },
158 | {
159 | "cell_type": "markdown",
160 | "metadata": {},
161 | "source": [
162 | "## Reproducibility in probatus\n",
163 | "\n",
164 | "Most of the modules in `probatus` allow you to set the `random_state`. This setting essentially makes sure that any code that the functions operate on has a static flow. As long as it is seet and you ensure all other inputs do not cause additional fluctuations between runs, you can make sure that your results are reproducible."
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 6,
170 | "metadata": {},
171 | "outputs": [],
172 | "source": [
173 | "from probatus.feature_elimination import ShapRFECV\n",
174 | "\n",
175 | "shap_elimination = ShapRFECV(model=search, step=0.2, cv=cv2, scoring=\"roc_auc\", n_jobs=3, random_state=42)\n",
176 | "report = shap_elimination.fit_compute(X, y)"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 7,
182 | "metadata": {},
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/html": [
187 | "\n",
188 | "\n",
201 | "
\n",
202 | " \n",
203 | " \n",
204 | " | \n",
205 | " num_features | \n",
206 | " eliminated_features | \n",
207 | " val_metric_mean | \n",
208 | "
\n",
209 | " \n",
210 | " \n",
211 | " \n",
212 | " 1 | \n",
213 | " 10 | \n",
214 | " [8, 9] | \n",
215 | " 0.983 | \n",
216 | "
\n",
217 | " \n",
218 | " 2 | \n",
219 | " 8 | \n",
220 | " [5] | \n",
221 | " 0.969 | \n",
222 | "
\n",
223 | " \n",
224 | " 3 | \n",
225 | " 7 | \n",
226 | " [7] | \n",
227 | " 0.984 | \n",
228 | "
\n",
229 | " \n",
230 | " 4 | \n",
231 | " 6 | \n",
232 | " [6] | \n",
233 | " 0.979 | \n",
234 | "
\n",
235 | " \n",
236 | " 5 | \n",
237 | " 5 | \n",
238 | " [4] | \n",
239 | " 0.983 | \n",
240 | "
\n",
241 | " \n",
242 | " 6 | \n",
243 | " 4 | \n",
244 | " [1] | \n",
245 | " 0.987 | \n",
246 | "
\n",
247 | " \n",
248 | " 7 | \n",
249 | " 3 | \n",
250 | " [0] | \n",
251 | " 0.991 | \n",
252 | "
\n",
253 | " \n",
254 | " 8 | \n",
255 | " 2 | \n",
256 | " [3] | \n",
257 | " 0.991 | \n",
258 | "
\n",
259 | " \n",
260 | " 9 | \n",
261 | " 1 | \n",
262 | " [] | \n",
263 | " 0.969 | \n",
264 | "
\n",
265 | " \n",
266 | "
\n",
267 | "
"
268 | ],
269 | "text/plain": [
270 | " num_features eliminated_features val_metric_mean\n",
271 | "1 10 [8, 9] 0.983\n",
272 | "2 8 [5] 0.969\n",
273 | "3 7 [7] 0.984\n",
274 | "4 6 [6] 0.979\n",
275 | "5 5 [4] 0.983\n",
276 | "6 4 [1] 0.987\n",
277 | "7 3 [0] 0.991\n",
278 | "8 2 [3] 0.991\n",
279 | "9 1 [] 0.969"
280 | ]
281 | },
282 | "execution_count": 7,
283 | "metadata": {},
284 | "output_type": "execute_result"
285 | }
286 | ],
287 | "source": [
288 | "report[[\"num_features\", \"eliminated_features\", \"val_metric_mean\"]]"
289 | ]
290 | }
291 | ],
292 | "metadata": {
293 | "kernelspec": {
294 | "display_name": "Python 3",
295 | "language": "python",
296 | "name": "python3"
297 | },
298 | "language_info": {
299 | "codemirror_mode": {
300 | "name": "ipython",
301 | "version": 3
302 | },
303 | "file_extension": ".py",
304 | "mimetype": "text/x-python",
305 | "name": "python",
306 | "nbconvert_exporter": "python",
307 | "pygments_lexer": "ipython3",
308 | "version": "3.10.13"
309 | }
310 | },
311 | "nbformat": 4,
312 | "nbformat_minor": 4
313 | }
314 |
--------------------------------------------------------------------------------
/docs/img/Probatus_P.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/Probatus_P.png
--------------------------------------------------------------------------------
/docs/img/Probatus_P_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/Probatus_P_white.png
--------------------------------------------------------------------------------
/docs/img/earlystoppingshaprfecv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/earlystoppingshaprfecv.png
--------------------------------------------------------------------------------
/docs/img/logo_large.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/logo_large.png
--------------------------------------------------------------------------------
/docs/img/logo_large_white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/logo_large_white.png
--------------------------------------------------------------------------------
/docs/img/model_interpret_dep.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_dep.png
--------------------------------------------------------------------------------
/docs/img/model_interpret_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_importance.png
--------------------------------------------------------------------------------
/docs/img/model_interpret_sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_sample.png
--------------------------------------------------------------------------------
/docs/img/model_interpret_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/model_interpret_summary.png
--------------------------------------------------------------------------------
/docs/img/resemblance_model_schema.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/resemblance_model_schema.png
--------------------------------------------------------------------------------
/docs/img/sample_similarity_permutation_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_permutation_importance.png
--------------------------------------------------------------------------------
/docs/img/sample_similarity_shap_importance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_shap_importance.png
--------------------------------------------------------------------------------
/docs/img/sample_similarity_shap_summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/sample_similarity_shap_summary.png
--------------------------------------------------------------------------------
/docs/img/shaprfecv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/docs/img/shaprfecv.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | **Probatus** is a Python library that allows to analyse binary classification models as well as the data used to develop them.
4 | The main features assess the metric stability and analyse differences between two data samples e.g. shift between train and test splits.
5 |
6 | ## Installation
7 |
8 | In order to install Probatus you need to use Python 3.9 or higher.
9 |
10 | Install `probatus` via pip with:
11 |
12 | ```bash
13 | pip install probatus
14 | ```
15 |
16 | Alternatively you can fork/clone and run:
17 |
18 | ```bash
19 | git clone https://gitlab.com/ing_rpaa/probatus.git
20 | cd probatus
21 | pip install .
22 | ```
23 |
24 | ## Licence
25 |
26 | Probatus is created under MIT License, see more in [LICENCE file](https://github.com/ing-bank/probatus/blob/main/LICENCE).
27 |
28 |
29 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Probatus
2 |
3 | repo_url: https://github.com/ing-bank/probatus/
4 | site_url: https://ing-bank.github.io/probatus/
5 | site_description: Validation of regressors and classifiers and data used to develop them
6 | site_author: ING Bank N. V.
7 |
8 | use_directory_urls: false
9 |
10 | watch:
11 | - probatus
12 |
13 | plugins:
14 | - mkdocs-jupyter
15 | - search
16 | - mkdocstrings:
17 | handlers:
18 | python:
19 | options:
20 | selection:
21 | inherited_members: true
22 | filters:
23 | - "!^Base"
24 | - "!^_" # exlude all members starting with _
25 | - "^__init__$" # but always include __init__ modules and methods
26 | rendering:
27 | show_root_toc_entry: false
28 |
29 | theme:
30 | name: material
31 | logo: img/Probatus_P_white.png
32 | favicon: img/Probatus_P_white.png
33 | font:
34 | text: Ubuntu
35 | code: Ubuntu Mono
36 | features:
37 | - navigation.tabs
38 | palette:
39 | scheme: default
40 | primary: deep orange
41 | accent: indigo
42 |
43 | copyright: Copyright © ING Bank N.V.
--------------------------------------------------------------------------------
/probatus/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ING Bank N.V.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | # the Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
20 | name = "probatus"
21 |
--------------------------------------------------------------------------------
/probatus/feature_elimination/__init__.py:
--------------------------------------------------------------------------------
1 | from .feature_elimination import ShapRFECV
2 | from .early_stopping_feature_elimination import EarlyStoppingShapRFECV
3 |
4 | __all__ = ["ShapRFECV", "EarlyStoppingShapRFECV"]
5 |
--------------------------------------------------------------------------------
/probatus/feature_elimination/early_stopping_feature_elimination.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from probatus.feature_elimination import ShapRFECV
3 |
4 |
5 | class EarlyStoppingShapRFECV(ShapRFECV):
6 | """
7 | This class performs Backwards Recursive Feature Elimination, using SHAP feature importance.
8 |
9 | This is a child of ShapRFECV which allows early stopping of the training step, this class is compatible with
10 | LightGBM, XGBoost and CatBoost models. If you are not using early stopping, you should use the parent class,
11 | ShapRFECV, instead of EarlyStoppingShapRFECV.
12 |
13 | [Early stopping](https://en.wikipedia.org/wiki/Early_stopping) is a type of
14 | regularization technique in which the model is trained until the scoring metric, measured on a validation set,
15 | stops improving after a number of early_stopping_rounds. In boosted tree models, this technique can increase
16 | the training speed, by skipping the training of trees that do not improve the scoring metric any further,
17 | which is particularly useful when the training dataset is large.
18 |
19 | Note that if the regressor or classifier is a hyperparameter search model is used, the early stopping parameter is passed only
20 | to the fit method of the model duiring the Shapley values estimation step, and not for the hyperparameter
21 | search step.
22 | Early stopping can be seen as a type of regularization of the optimal number of trees. Therefore you can use
23 | it directly with a LightGBM or XGBoost model, as an alternative to a hyperparameter search model.
24 |
25 | At each round, for a
26 | given feature set, starting from all available features, the following steps are applied:
27 |
28 | 1. (Optional) Tune the hyperparameters of the model using sklearn compatible search CV e.g.
29 | [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html),
30 | [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html?highlight=randomized#sklearn.model_selection.RandomizedSearchCV), or
31 | [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html).
32 | Note that during this step the model does not use early stopping.
33 | 2. Apply Cross-validation (CV) to estimate the SHAP feature importance on the provided dataset. In each CV
34 | iteration, the model is fitted on the train folds, and applied on the validation fold to estimate
35 | SHAP feature importance. The model is trained until the scoring metric eval_metric, measured on the
36 | validation fold, stops improving after a number of early_stopping_rounds.
37 | 3. Remove `step` lowest SHAP importance features from the dataset.
38 |
39 | At the end of the process, the user can plot the performance of the model for each iteration, and select the
40 | optimal number of features and the features set.
41 |
42 | We recommend using [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
43 | because by default it handles missing values and categorical features. In case of other models, make sure to
44 | handle these issues for your dataset and consider impact it might have on features importance.
45 |
46 |
47 | Example:
48 | ```python
49 | from lightgbm import LGBMClassifier
50 | import pandas as pd
51 | from probatus.feature_elimination import EarlyStoppingShapRFECV
52 | from sklearn.datasets import make_classification
53 |
54 | feature_names = [
55 | 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7',
56 | 'f8', 'f9', 'f10', 'f11', 'f12', 'f13',
57 | 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20']
58 |
59 | # Prepare two samples
60 | X, y = make_classification(n_samples=200, class_sep=0.05, n_informative=6, n_features=20,
61 | random_state=0, n_redundant=10, n_clusters_per_class=1)
62 | X = pd.DataFrame(X, columns=feature_names)
63 |
64 | # Prepare model
65 | model = LGBMClassifier(n_estimators=200, max_depth=3)
66 |
67 | # Run feature elimination
68 | shap_elimination = EarlyStoppingShapRFECV(
69 | model=model, step=0.2, cv=10, scoring='roc_auc', early_stopping_rounds=10, n_jobs=3)
70 | report = shap_elimination.fit_compute(X, y)
71 |
72 | # Make plots
73 | performance_plot = shap_elimination.plot()
74 |
75 | # Get final feature set
76 | final_features_set = shap_elimination.get_reduced_features_set(num_features=3)
77 | ```
78 |
79 |
80 | """ # noqa
81 |
82 | def __init__(
83 | self,
84 | model,
85 | step=1,
86 | min_features_to_select=1,
87 | cv=None,
88 | scoring="roc_auc",
89 | n_jobs=-1,
90 | verbose=0,
91 | random_state=None,
92 | early_stopping_rounds=5,
93 | eval_metric="auc",
94 | ):
95 | """
96 | This method initializes the class.
97 |
98 | Args:
99 | model (sklearn compatible classifier or regressor, sklearn compatible search CV e.g. GridSearchCV, RandomizedSearchCV or BayesSearchCV):
100 | A model that will be optimized and trained at each round of features elimination. The model must
101 | support early stopping of training, which is the case for XGBoost and LightGBM, for example. The
102 | recommended model is [LGBMClassifier](https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.LGBMClassifier.html),
103 | because it by default handles the missing values and categorical variables. This parameter also supports
104 | any hyperparameter search schema that is consistent with the sklearn API e.g.
105 | [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html),
106 | [RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
107 | or [BayesSearchCV](https://scikit-optimize.github.io/stable/modules/generated/skopt.BayesSearchCV.html#skopt.BayesSearchCV).
108 | Note that if a hyperparemeter search model is used, the hyperparameters are tuned without early
109 | stopping. Early stopping is applied only during the Shapley values estimation for feature
110 | elimination. We recommend simply passing the model without hyperparameter optimization, or using
111 | ShapRFECV without early stopping.
112 |
113 |
114 | step (int or float, optional):
115 | Number of lowest importance features removed each round. If it is an int, then each round such number of
116 | features is discarded. If float, such percentage of remaining features (rounded down) is removed each
117 | iteration. It is recommended to use float, since it is faster for a large number of features, and slows
118 | down and becomes more precise towards less features. Note: the last round may remove fewer features in
119 | order to reach min_features_to_select.
120 | If columns_to_keep parameter is specified in the fit method, step is the number of features to remove after
121 | keeping those columns.
122 |
123 | min_features_to_select (int, optional):
124 | Minimum number of features to be kept. This is a stopping criterion of the feature elimination. By
125 | default the process stops when one feature is left. If columns_to_keep is specified in the fit method,
126 | it may override this parameter to the maximum between length of columns_to_keep the two.
127 |
128 | cv (int, cross-validation generator or an iterable, optional):
129 | Determines the cross-validation splitting strategy. Compatible with sklearn
130 | [cv parameter](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.RFECV.html).
131 | If None, then cv of 5 is used.
132 |
133 | scoring (string or probatus.utils.Scorer, optional):
134 | Metric for which the model performance is calculated. It can be either a metric name aligned with predefined
135 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html).
136 | Another option is using probatus.utils.Scorer to define a custom metric.
137 |
138 | n_jobs (int, optional):
139 | Number of cores to run in parallel while fitting across folds. None means 1 unless in a
140 | `joblib.parallel_backend` context. -1 means using all processors.
141 |
142 | verbose (int, optional):
143 | Controls verbosity of the output:
144 |
145 | - 0 - neither prints nor warnings are shown
146 | - 1 - only most important warnings
147 | - 2 - shows all prints and all warnings.
148 |
149 | random_state (int, optional):
150 | Random state set at each round of feature elimination. If it is None, the results will not be
151 | reproducible and in random search at each iteration a different hyperparameters might be tested. For
152 | reproducible results set it to integer.
153 |
154 | early_stopping_rounds (int, optional):
155 | Number of rounds with constant performance after which the model fitting stops. This is passed to the
156 | fit method of the model for Shapley values estimation, but not for hyperparameter search. Only
157 | supported by some models, such as XGBoost and LightGBM.
158 |
159 | eval_metric (str, optional):
160 | Metric for scoring fitting rounds and activating early stopping. This is passed to the
161 | fit method of the model for Shapley values estimation, but not for hyperparameter search. Only
162 | supported by some models, such as [XGBoost](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters)
163 | and [LightGBM](https://lightgbm.readthedocs.io/en/latest/Parameters.html#metric-parameters).
164 | Note that `eval_metric` is an argument of the model's fit method and it is different from `scoring`.
165 | """ # noqa
166 | # TODO: This deprecation warning will removed when it's decided that this class can be deleted.
167 | warnings.warn(
168 | "The separate EarlyStoppingShapRFECV class is going to be deprecated"
169 | " in a later version of Probatus, since its now part of the"
170 | " ShapRFECV class. Please adjust your imported class name from"
171 | " 'EarlyStoppingShapRFECV' to 'ShapRFECV'.",
172 | DeprecationWarning,
173 | )
174 |
175 | super().__init__(
176 | model,
177 | step=step,
178 | min_features_to_select=min_features_to_select,
179 | cv=cv,
180 | scoring=scoring,
181 | n_jobs=n_jobs,
182 | verbose=verbose,
183 | random_state=random_state,
184 | early_stopping_rounds=early_stopping_rounds,
185 | eval_metric=eval_metric,
186 | )
187 |
--------------------------------------------------------------------------------
/probatus/interpret/__init__.py:
--------------------------------------------------------------------------------
1 | from .shap_dependence import DependencePlotter
2 | from .model_interpret import ShapModelInterpreter
3 |
4 | __all__ = ["DependencePlotter", "ShapModelInterpreter"]
5 |
--------------------------------------------------------------------------------
/probatus/interpret/model_interpret.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import pandas as pd
4 | from shap import summary_plot
5 | from shap.plots._waterfall import waterfall_legacy
6 |
7 | from probatus.interpret import DependencePlotter
8 | from probatus.utils import (
9 | BaseFitComputePlotClass,
10 | assure_list_of_strings,
11 | calculate_shap_importance,
12 | preprocess_data,
13 | preprocess_labels,
14 | get_single_scorer,
15 | shap_calc,
16 | )
17 |
18 |
19 | class ShapModelInterpreter(BaseFitComputePlotClass):
20 | """
21 | This class is a wrapper that allows to easily analyse a model's features.
22 |
23 | It allows us to plot SHAP feature importance,
24 | SHAP summary plot and SHAP dependence plots.
25 |
26 | Example:
27 | ```python
28 | from sklearn.datasets import make_classification
29 | from sklearn.ensemble import RandomForestClassifier
30 | from sklearn.model_selection import train_test_split
31 | from probatus.interpret import ShapModelInterpreter
32 | import numpy as np
33 | import pandas as pd
34 |
35 | feature_names = ['f1', 'f2', 'f3', 'f4']
36 |
37 | # Prepare two samples
38 | X, y = make_classification(n_samples=5000, n_features=4, random_state=0)
39 | X = pd.DataFrame(X, columns=feature_names)
40 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
41 |
42 | # Prepare and fit model. Remember about class_weight="balanced" or an equivalent.
43 | model = RandomForestClassifier(class_weight='balanced', n_estimators = 100, max_depth=2, random_state=0)
44 | model.fit(X_train, y_train)
45 |
46 | # Train ShapModelInterpreter
47 | shap_interpreter = ShapModelInterpreter(model)
48 | feature_importance = shap_interpreter.fit_compute(X_train, X_test, y_train, y_test)
49 |
50 | # Make plots
51 | ax1 = shap_interpreter.plot('importance')
52 | ax2 = shap_interpreter.plot('summary')
53 | ax3 = shap_interpreter.plot('dependence', target_columns=['f1', 'f2'])
54 | ax4 = shap_interpreter.plot('sample', samples_index=[X_test.index.tolist()[0]])
55 | ```
56 |
57 |
58 |
59 |
60 |
61 | """
62 |
63 | def __init__(self, model, scoring="roc_auc", verbose=0, random_state=None):
64 | """
65 | Initializes the class.
66 |
67 | Args:
68 | model (classifier or regressor):
69 | Model fitted on X_train.
70 |
71 | scoring (string or probatus.utils.Scorer, optional):
72 | Metric for which the model performance is calculated. It can be either a metric name aligned with
73 | predefined classification scorers names in sklearn
74 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)).
75 | Another option is using probatus.utils.Scorer to define a custom metric.
76 |
77 | verbose (int, optional):
78 | Controls verbosity of the output:
79 |
80 | - 0 - neither prints nor warnings are shown
81 | - 1 - only most important warnings
82 | - 2 - shows all prints and all warnings.
83 |
84 | random_state (int, optional):
85 | Random state set for the nr of samples. If it is None, the results will not be reproducible. For
86 | reproducible results set it to an integer.
87 | """
88 | self.model = model
89 | self.scorer = get_single_scorer(scoring)
90 | self.verbose = verbose
91 | self.random_state = random_state
92 |
93 | def fit(
94 | self,
95 | X_train,
96 | X_test,
97 | y_train,
98 | y_test,
99 | column_names=None,
100 | class_names=None,
101 | **shap_kwargs,
102 | ):
103 | """
104 | Fits the object and calculates the shap values for the provided datasets.
105 |
106 | Args:
107 | X_train (pd.DataFrame):
108 | Dataframe containing training data.
109 |
110 | X_test (pd.DataFrame):
111 | Dataframe containing test data.
112 |
113 | y_train (pd.Series):
114 | Series of labels for train data.
115 |
116 | y_test (pd.Series):
117 | Series of labels for test data.
118 |
119 | column_names (None, or list of str, optional):
120 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used.
121 |
122 | class_names (None, or list of str, optional):
123 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are
124 | used.
125 |
126 | **shap_kwargs:
127 | keyword arguments passed to
128 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
129 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
130 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while
131 | `check_additivity=False` disables the additivity check inside SHAP.
132 | """
133 |
134 | self.X_train, self.column_names = preprocess_data(
135 | X_train, X_name="X_train", column_names=column_names, verbose=self.verbose
136 | )
137 | self.X_test, _ = preprocess_data(X_test, X_name="X_test", column_names=column_names, verbose=self.verbose)
138 | self.y_train = preprocess_labels(y_train, y_name="y_train", index=self.X_train.index, verbose=self.verbose)
139 | self.y_test = preprocess_labels(y_test, y_name="y_test", index=self.X_test.index, verbose=self.verbose)
140 |
141 | # Set class names
142 | self.class_names = class_names
143 | if self.class_names is None:
144 | self.class_names = ["Negative Class", "Positive Class"]
145 |
146 | # Calculate Metrics
147 | self.train_score = self.scorer.score(self.model, self.X_train, self.y_train)
148 | self.test_score = self.scorer.score(self.model, self.X_test, self.y_test)
149 |
150 | self.results_text = (
151 | f"Train {self.scorer.metric_name}: {np.round(self.train_score, 3)},\n"
152 | f"Test {self.scorer.metric_name}: {np.round(self.test_score, 3)}."
153 | )
154 |
155 | (
156 | self.shap_values_train,
157 | self.expected_value_train,
158 | self.tdp_train,
159 | ) = self._prep_shap_related_variables(
160 | model=self.model,
161 | X=self.X_train,
162 | y=self.y_train,
163 | column_names=self.column_names,
164 | class_names=self.class_names,
165 | verbose=self.verbose,
166 | random_state=self.random_state,
167 | **shap_kwargs,
168 | )
169 |
170 | (
171 | self.shap_values_test,
172 | self.expected_value_test,
173 | self.tdp_test,
174 | ) = self._prep_shap_related_variables(
175 | model=self.model,
176 | X=self.X_test,
177 | y=self.y_test,
178 | column_names=self.column_names,
179 | class_names=self.class_names,
180 | verbose=self.verbose,
181 | random_state=self.random_state,
182 | **shap_kwargs,
183 | )
184 |
185 | self.fitted = True
186 |
187 | @staticmethod
188 | def _prep_shap_related_variables(
189 | model,
190 | X,
191 | y,
192 | approximate=False,
193 | verbose=0,
194 | random_state=None,
195 | column_names=None,
196 | class_names=None,
197 | **shap_kwargs,
198 | ):
199 | """
200 | The function prepares the variables related to shap that are used to interpret the model.
201 |
202 | Returns:
203 | (np.array, int, DependencePlotter):
204 | Shap values, expected value of the explainer, and fitted TreeDependencePlotter for a given dataset.
205 | """
206 | shap_values, explainer = shap_calc(
207 | model,
208 | X,
209 | approximate=approximate,
210 | verbose=verbose,
211 | random_state=random_state,
212 | return_explainer=True,
213 | **shap_kwargs,
214 | )
215 |
216 | expected_value = explainer.expected_value
217 |
218 | # For sklearn models the expected values consists of two elements (negative_class and positive_class)
219 | if isinstance(expected_value, list) or isinstance(expected_value, np.ndarray):
220 | expected_value = expected_value[1]
221 |
222 | # Initialize tree dependence plotter
223 | tdp = DependencePlotter(model, verbose=verbose).fit(
224 | X,
225 | y,
226 | column_names=column_names,
227 | class_names=class_names,
228 | precalc_shap=shap_values,
229 | )
230 | return shap_values, expected_value, tdp
231 |
232 | def compute(self, return_scores=False, shap_variance_penalty_factor=None):
233 | """
234 | Computes the DataFrame that presents the importance of each feature.
235 |
236 | Args:
237 | return_scores (bool, optional):
238 | Flag indicating whether the method should return the train and test score of the model, together with
239 | the model interpretation report. If true, the output of this method is a tuple of DataFrame, float,
240 | float.
241 |
242 | shap_variance_penalty_factor (int or float, optional):
243 | Apply aggregation penalty when computing average of shap values for a given feature.
244 | Results in a preference for features that have smaller standard deviation of shap
245 | values (more coherent shap importance). Recommend value 0.5 - 1.0.
246 | Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor))
247 |
248 | Returns:
249 | (pd.DataFrame or tuple(pd.DataFrame, float, float)):
250 | Dataframe with SHAP feature importance, or tuple containing the dataframe, train and test scores of the
251 | model.
252 | """
253 | self._check_if_fitted()
254 |
255 | # Compute SHAP importance
256 | self.importance_df_train = calculate_shap_importance(
257 | self.shap_values_train,
258 | self.column_names,
259 | output_columns_suffix="_train",
260 | shap_variance_penalty_factor=shap_variance_penalty_factor,
261 | )
262 |
263 | self.importance_df_test = calculate_shap_importance(
264 | self.shap_values_test,
265 | self.column_names,
266 | output_columns_suffix="_test",
267 | shap_variance_penalty_factor=shap_variance_penalty_factor,
268 | )
269 |
270 | # Concatenate the train and test, sort by test set importance and reorder the columns
271 | self.importance_df = pd.concat([self.importance_df_train, self.importance_df_test], axis=1).sort_values(
272 | "mean_abs_shap_value_test", ascending=False
273 | )[
274 | [
275 | "mean_abs_shap_value_test",
276 | "mean_abs_shap_value_train",
277 | "mean_shap_value_test",
278 | "mean_shap_value_train",
279 | ]
280 | ]
281 |
282 | if return_scores:
283 | return self.importance_df, self.train_score, self.test_score
284 | else:
285 | return self.importance_df
286 |
287 | def fit_compute(
288 | self,
289 | X_train,
290 | X_test,
291 | y_train,
292 | y_test,
293 | column_names=None,
294 | class_names=None,
295 | return_scores=False,
296 | shap_variance_penalty_factor=None,
297 | **shap_kwargs,
298 | ):
299 | """
300 | Fits the object and calculates the shap values for the provided datasets.
301 |
302 | Args:
303 | X_train (pd.DataFrame):
304 | Dataframe containing training data.
305 |
306 | X_test (pd.DataFrame):
307 | Dataframe containing test data.
308 |
309 | y_train (pd.Series):
310 | Series of labels for train data.
311 |
312 | y_test (pd.Series):
313 | Series of labels for test data.
314 |
315 | column_names (None, or list of str, optional):
316 | List of feature names for the dataset.
317 | If None, then column names from the X_train dataframe are used.
318 |
319 | class_names (None, or list of str, optional):
320 | List of class names e.g. ['neg', 'pos'].
321 | If none, the default ['Negative Class', 'Positive Class'] are
322 | used.
323 |
324 | return_scores (bool, optional):
325 | Flag indicating whether the method should return
326 | the train and test score of the model,
327 | together with the model interpretation report. If true,
328 | the output of this method is a tuple of DataFrame, float,
329 | float.
330 |
331 | shap_variance_penalty_factor (int or float, optional):
332 | Apply aggregation penalty when computing average of shap values for a given feature.
333 | Results in a preference for features that have smaller standard deviation of shap
334 | values (more coherent shap importance). Recommend value 0.5 - 1.0.
335 | Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor))
336 |
337 | **shap_kwargs:
338 | keyword arguments passed to
339 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
340 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
341 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while
342 | `check_additivity=False` disables the additivity check inside SHAP.
343 |
344 | Returns:
345 | (pd.DataFrame or tuple(pd.DataFrame, float, float)):
346 | Dataframe with SHAP feature importance, or tuple containing the dataframe, train and test scores of the
347 | model.
348 | """
349 | self.fit(
350 | X_train=X_train,
351 | X_test=X_test,
352 | y_train=y_train,
353 | y_test=y_test,
354 | column_names=column_names,
355 | class_names=class_names,
356 | **shap_kwargs,
357 | )
358 | return self.compute(return_scores=return_scores, shap_variance_penalty_factor=shap_variance_penalty_factor)
359 |
360 | def plot(self, plot_type, target_set="test", target_columns=None, samples_index=None, show=True, **plot_kwargs):
361 | """
362 | Plots the appropriate SHAP plot.
363 |
364 | Args:
365 | plot_type (str):
366 | One of the following:
367 |
368 | - `'importance'`: Feature importance plot, SHAP bar summary plot
369 | - `'summary'`: SHAP Summary plot
370 | - `'dependence'`: Dependence plot for each feature
371 | - `'sample'`: Explanation of a given sample in the test data
372 |
373 | target_set (str, optional):
374 | The set for which the plot should be generated, either `train` or `test` set. We recommend using test
375 | set, because it is not biased by model training. The train set plots are mainly used to compare with the
376 | test set plots, whether there are significant differences, which indicate shift in data distribution.
377 |
378 | target_columns (None, str or list of str, optional):
379 | List of features names, for which the plots should be generated. If None, all features will be plotted.
380 |
381 | samples_index (None, int, list or pd.Index, optional):
382 | Index of samples to be explained if the `plot_type=sample`.
383 |
384 | show (bool, optional):
385 | If True, the plots are showed to the user, otherwise they are not shown. Not showing plot can be useful,
386 | when you want to edit the returned axis, before showing it.
387 |
388 | **plot_kwargs:
389 | Keyword arguments passed to the plot method. For 'importance' and 'summary' plot_type, the kwargs are
390 | passed to shap.summary_plot, for 'dependence' plot_type, they are passed to
391 | probatus.interpret.DependencePlotter.plot method.
392 |
393 | Returns:
394 | (matplotlib.axes or list(matplotlib.axes)):
395 | An Axes with the plot, or list of axes when multiple plots are returned.
396 | """
397 | # Choose correct columns
398 | if target_columns is None:
399 | target_columns = self.column_names
400 |
401 | target_columns = assure_list_of_strings(target_columns, "target_columns")
402 | target_columns_indices = [self.column_names.index(target_column) for target_column in target_columns]
403 |
404 | # Choose the correct dataset
405 | if target_set == "test":
406 | target_X = self.X_test
407 | target_shap_values = self.shap_values_test
408 | target_tdp = self.tdp_test
409 | target_expected_value = self.expected_value_test
410 | elif target_set == "train":
411 | target_X = self.X_train
412 | target_shap_values = self.shap_values_train
413 | target_tdp = self.tdp_train
414 | target_expected_value = self.expected_value_train
415 | else:
416 | raise (ValueError('The target_set parameter can be either "train" or "test".'))
417 |
418 | if plot_type in ["importance", "summary"]:
419 | target_X = target_X[target_columns]
420 | target_shap_values = target_shap_values[:, target_columns_indices]
421 | # Set summary plot settings
422 | if plot_type == "importance":
423 | plot_type = "bar"
424 | plot_title = f"SHAP Feature Importance for {target_set} set"
425 | else:
426 | plot_type = "dot"
427 | plot_title = f"SHAP Summary plot for {target_set} set"
428 |
429 | summary_plot(
430 | target_shap_values,
431 | target_X,
432 | plot_type=plot_type,
433 | class_names=self.class_names,
434 | show=False,
435 | **plot_kwargs,
436 | )
437 | ax = plt.gca()
438 | ax.set_title(plot_title)
439 |
440 | ax.annotate(
441 | self.results_text,
442 | (0, 0),
443 | (0, -50),
444 | fontsize=12,
445 | xycoords="axes fraction",
446 | textcoords="offset points",
447 | va="top",
448 | )
449 | if show:
450 | plt.show()
451 | else:
452 | plt.close()
453 | elif plot_type == "dependence":
454 | ax = []
455 | for feature_name in target_columns:
456 | ax.append(target_tdp.plot(feature=feature_name, figsize=(10, 7), show=show, **plot_kwargs))
457 |
458 | elif plot_type == "sample":
459 | # Ensure the correct samples_index type
460 | if samples_index is None:
461 | raise (ValueError("For sample plot, you need to specify the samples_index be plotted plot"))
462 | elif isinstance(samples_index, int) or isinstance(samples_index, str):
463 | samples_index = [samples_index]
464 | elif not (isinstance(samples_index, list) or isinstance(samples_index, pd.Index)):
465 | raise (TypeError("sample_index must be one of the following: int, str, list or pd.Index"))
466 |
467 | ax = []
468 | for sample_index in samples_index:
469 | sample_loc = target_X.index.get_loc(sample_index)
470 |
471 | waterfall_legacy(
472 | target_expected_value,
473 | target_shap_values[sample_loc, :],
474 | target_X.loc[sample_index],
475 | show=False,
476 | **plot_kwargs,
477 | )
478 |
479 | plot_title = f"SHAP Sample Explanation of {target_set} sample for index={sample_index}"
480 | current_ax = plt.gca()
481 | current_ax.set_title(plot_title)
482 | ax.append(current_ax)
483 | if show:
484 | plt.show()
485 | else:
486 | plt.close()
487 | else:
488 | raise ValueError("Wrong plot type, select from 'importance', 'summary', or 'dependence'")
489 |
490 | if isinstance(ax, list) and len(ax) == 1:
491 | ax = ax[0]
492 | return ax
493 |
--------------------------------------------------------------------------------
/probatus/interpret/shap_dependence.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import pandas as pd
4 | from sklearn.preprocessing import KBinsDiscretizer
5 |
6 | from probatus.utils import BaseFitComputePlotClass, preprocess_data, preprocess_labels, shap_to_df
7 |
8 |
9 | class DependencePlotter(BaseFitComputePlotClass):
10 | """
11 | Plotter used to plot SHAP dependence plot together with the target rates.
12 |
13 | Currently it supports tree-based and linear models.
14 |
15 | Args:
16 | model: classifier for which interpretation is done.
17 |
18 | Example:
19 | ```python
20 | from sklearn.datasets import make_classification
21 | from sklearn.ensemble import RandomForestClassifier
22 | from probatus.interpret import DependencePlotter
23 |
24 | X, y = make_classification(n_samples=15, n_features=3, n_informative=3, n_redundant=0, random_state=42)
25 | model = RandomForestClassifier().fit(X, y)
26 | bdp = DependencePlotter(model)
27 | shap_values = bdp.fit_compute(X, y)
28 |
29 | bdp.plot(feature=2)
30 | ```
31 |
32 |
33 | """
34 |
35 | def __init__(self, model, verbose=0, random_state=None):
36 | """
37 | Initializes the class.
38 |
39 | Args:
40 | model (model object):
41 | regression or classification model or pipeline.
42 |
43 | verbose (int, optional):
44 | Controls verbosity of the output:
45 |
46 | - 0 - neither prints nor warnings are shown
47 | - 1 - only most important warnings
48 | - 2 - shows all prints and all warnings.
49 |
50 | random_state (int, optional):
51 | Random state set for the nr of samples. If it is None, the results will not be reproducible. For
52 | reproducible results set it to an integer.
53 | """
54 | self.model = model
55 | self.verbose = verbose
56 | self.random_state = random_state
57 |
58 | def __repr__(self):
59 | """
60 | Represent string method.
61 | """
62 | return f"Shap dependence plotter for {self.model.__class__.__name__}"
63 |
64 | def fit(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs):
65 | """
66 | Fits the plotter to the model and data by computing the shap values.
67 |
68 | If the shap_values are passed, they do not need to be computed.
69 |
70 | Args:
71 | X (pd.DataFrame): input variables.
72 |
73 | y (pd.Series): target variable.
74 |
75 | column_names (None, or list of str, optional):
76 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used.
77 |
78 | class_names (None, or list of str, optional):
79 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are
80 | used.
81 |
82 | precalc_shap (Optional, None or np.array):
83 | Precalculated shap values, If provided they don't need to be computed.
84 |
85 | **shap_kwargs:
86 | keyword arguments passed to
87 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
88 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
89 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while
90 | `check_additivity=False` disables the additivity check inside SHAP.
91 | """
92 | self.X, self.column_names = preprocess_data(X, X_name="X", column_names=column_names, verbose=self.verbose)
93 | self.y = preprocess_labels(y, y_name="y", index=self.X.index, verbose=self.verbose)
94 |
95 | # Set class names
96 | self.class_names = class_names
97 | if self.class_names is None:
98 | self.class_names = ["Negative Class", "Positive Class"]
99 |
100 | self.shap_vals_df = shap_to_df(
101 | self.model,
102 | self.X,
103 | precalc_shap=precalc_shap,
104 | verbose=self.verbose,
105 | random_state=self.random_state,
106 | **shap_kwargs,
107 | )
108 |
109 | self.fitted = True
110 | return self
111 |
112 | def compute(self):
113 | """
114 | Computes the report returned to the user, namely the SHAP values generated on the dataset.
115 |
116 | Returns:
117 | (pd.DataFrame):
118 | SHAP Values for X.
119 | """
120 | self._check_if_fitted()
121 | return self.shap_vals_df
122 |
123 | def fit_compute(self, X, y, column_names=None, class_names=None, precalc_shap=None, **shap_kwargs):
124 | """
125 | Fits the plotter to the model and data by computing the shap values.
126 |
127 | If the shap_values are passed, they do not need to be computed
128 |
129 | Args:
130 | X (pd.DataFrame):
131 | Provided dataset.
132 |
133 | y (pd.Series):
134 | Labels for X.
135 |
136 | column_names (None, or list of str, optional):
137 | List of feature names for the dataset. If None, then column names from the X_train dataframe are used.
138 |
139 | class_names (None, or list of str, optional):
140 | List of class names e.g. ['neg', 'pos']. If none, the default ['Negative Class', 'Positive Class'] are
141 | used.
142 |
143 | precalc_shap (Optional, None or np.array):
144 | Precalculated shap values, If provided they don't need to be computed.
145 |
146 | **shap_kwargs:
147 | keyword arguments passed to
148 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
149 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
150 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while
151 | `check_additivity=False` disables the additivity check inside SHAP.
152 |
153 | Returns:
154 | (pd.DataFrame):
155 | SHAP Values for X.
156 | """
157 | self.fit(X, y, column_names=column_names, class_names=class_names, precalc_shap=precalc_shap, **shap_kwargs)
158 | return self.compute()
159 |
160 | def plot(
161 | self,
162 | feature,
163 | figsize=(15, 10),
164 | bins=10,
165 | show=True,
166 | min_q=0,
167 | max_q=1,
168 | alpha=1.0,
169 | ):
170 | """
171 | Plots the shap values for data points for a given feature, as well as the target rate and values distribution.
172 |
173 | Args:
174 | feature (str or int):
175 | Feature name of the feature to be analyzed.
176 |
177 | figsize ((float, float), optional):
178 | Tuple specifying size (width, height) of resulting figure in inches.
179 |
180 | bins (int or list[float]):
181 | Number of bins or boundaries of bins (supplied in list) for target-rate plot.
182 |
183 | show (bool, optional):
184 | If True, the plots are showed to the user, otherwise they are not shown. Not showing plot can be useful,
185 | when you want to edit the returned axis, before showing it.
186 |
187 | min_q (float, optional):
188 | Optional minimum quantile from which to consider values, used for plotting under outliers.
189 |
190 | max_q (float, optional):
191 | Optional maximum quantile until which data points are considered, used for plotting under outliers.
192 |
193 | alpha (float, optional):
194 | Optional alpha blending value, between 0 (transparent) and 1 (opaque).
195 |
196 | Returns
197 | (list(matplotlib.axes)):
198 | List of axes that include the plots.
199 | """
200 | self._check_if_fitted()
201 | if min_q >= max_q:
202 | raise ValueError("min_q must be smaller than max_q")
203 | if feature not in self.X.columns:
204 | raise ValueError("Feature not recognized")
205 | if (alpha < 0) or (alpha > 1):
206 | raise ValueError("alpha must be a float value between 0 and 1")
207 |
208 | self.min_q, self.max_q, self.alpha = min_q, max_q, alpha
209 |
210 | _ = plt.figure(1, figsize=figsize)
211 | ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
212 | ax2 = plt.subplot2grid((3, 1), (2, 0))
213 |
214 | self._dependence_plot(feature=feature, ax=ax1)
215 | self._target_rate_plot(feature=feature, bins=bins, ax=ax2)
216 |
217 | ax2.set_xlim(ax1.get_xlim())
218 |
219 | if show:
220 | plt.show()
221 | else:
222 | plt.close()
223 |
224 | return [ax1, ax2]
225 |
226 | def _dependence_plot(self, feature, ax=None):
227 | """
228 | Plots shap values for data points with respect to specified feature.
229 |
230 | Args:
231 | feature (str or int):
232 | Feature for which dependence plot is to be created.
233 |
234 | ax (matplotlib.pyplot.axes, optional):
235 | Optional axis on which to draw plot.
236 |
237 | Returns:
238 | (matplotlib.pyplot.axes):
239 | Axes on which plot is drawn.
240 | """
241 | if isinstance(feature, int):
242 | feature = self.column_names[feature]
243 |
244 | X, y, shap_val = self._get_X_y_shap_with_q_cut(feature=feature)
245 |
246 | ax.scatter(X[y == 0], shap_val[y == 0], label=self.class_names[0], color="lightblue", alpha=self.alpha)
247 |
248 | ax.scatter(X[y == 1], shap_val[y == 1], label=self.class_names[1], color="darkred", alpha=self.alpha)
249 |
250 | ax.set_ylabel("Shap value")
251 | ax.set_title(f"Dependence plot for {feature} feature")
252 | ax.legend()
253 |
254 | return ax
255 |
256 | def _target_rate_plot(self, feature, bins=10, ax=None):
257 | """
258 | Plots the distributions of the specific features, as well as the target rate as function of the feature.
259 |
260 | Args:
261 | feature (str or int):
262 | Feature for which to create target rate plot.
263 |
264 | bins (int or list[float]), optional:
265 | Number of bins or boundaries of desired bins in list.
266 |
267 | ax (matplotlib.pyplot.axes, optional):
268 | Optional axis on which to draw plot.
269 |
270 | Returns:
271 | (list[float], matplotlib.pyplot.axes, float):
272 | Tuple of boundaries of bins used, axis on which plot is drawn, total ratio of target (positive over
273 | negative).
274 | """
275 | x, y, shap_val = self._get_X_y_shap_with_q_cut(feature=feature)
276 |
277 | # Create bins if not explicitly supplied
278 | if isinstance(bins, int):
279 | simple_binner = KBinsDiscretizer(n_bins=bins, encode="ordinal", strategy="uniform").fit(
280 | np.array(x).reshape(-1, 1)
281 | )
282 | bins = simple_binner.bin_edges_[0]
283 | bins[0], bins[-1] = -np.inf, np.inf
284 |
285 | # Determine bin for datapoints
286 | bins[-1] = bins[-1] + 1
287 | indices = np.digitize(x, bins)
288 | # Create dataframe with binned data
289 | dfs = pd.DataFrame({feature: x, "y": y, "bin_index": pd.Series(indices, index=x.index)}).groupby(
290 | "bin_index", as_index=True
291 | )
292 |
293 | # Extract target ratio and mean feature value
294 | target_ratio = dfs["y"].mean()
295 | x_vals = dfs[feature].mean()
296 |
297 | # Transform the first and last bin to work with plt.hist method
298 | if bins[0] == -np.inf:
299 | bins[0] = x.min()
300 | if bins[-1] == np.inf:
301 | bins[-1] = x.max()
302 |
303 | # Plot target rate
304 | ax.hist(x, bins=bins, lw=2, alpha=0.4)
305 | ax.set_ylabel("Counts")
306 | ax2 = ax.twinx()
307 | ax2.plot(x_vals, target_ratio, color="red")
308 | ax2.set_ylabel("Target rate", color="red", fontsize=12)
309 | ax2.set_xlim(x.min(), x.max())
310 | ax.set_xlabel(f"{feature} feature values")
311 |
312 | return bins, ax, target_ratio
313 |
314 | def _get_X_y_shap_with_q_cut(self, feature):
315 | """
316 | Extracts all X, y pairs and shap values that fall within defined quantiles of the feature.
317 |
318 | Args:
319 | feature (str): feature to return values for
320 |
321 | Returns:
322 | x (pd.Series): selected datapoints
323 | y (pd.Series): target values of selected datapoints
324 | shap_val (pd.Series): shap values of selected datapoints
325 | """
326 | self._check_if_fitted()
327 | if feature not in self.X.columns:
328 | raise ValueError("Feature not found in data")
329 |
330 | # Prepare arrays
331 | x = self.X[feature]
332 | y = self.y
333 | shap_val = self.shap_vals_df[feature]
334 |
335 | # Determine quantile ranges
336 | x_min = x.quantile(self.min_q)
337 | x_max = x.quantile(self.max_q)
338 |
339 | # Create filter
340 | filter = (x >= x_min) & (x <= x_max)
341 |
342 | # Filter and return terms
343 | return x[filter], y[filter], shap_val[filter]
344 |
--------------------------------------------------------------------------------
/probatus/sample_similarity/__init__.py:
--------------------------------------------------------------------------------
1 | from .resemblance_model import (
2 | BaseResemblanceModel,
3 | PermutationImportanceResemblance,
4 | SHAPImportanceResemblance,
5 | )
6 |
7 | __all__ = [
8 | "BaseResemblanceModel",
9 | "PermutationImportanceResemblance",
10 | "SHAPImportanceResemblance",
11 | ]
12 |
--------------------------------------------------------------------------------
/probatus/sample_similarity/resemblance_model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import pandas as pd
6 | from loguru import logger
7 | from shap import summary_plot
8 | from sklearn.inspection import permutation_importance
9 | from sklearn.model_selection import train_test_split
10 |
11 | from probatus.utils import BaseFitComputePlotClass, preprocess_data, preprocess_labels, get_single_scorer
12 | from probatus.utils.shap_helpers import calculate_shap_importance, shap_calc
13 |
14 |
15 | class BaseResemblanceModel(BaseFitComputePlotClass):
16 | """
17 | This model checks for the similarity of two samples.
18 |
19 | A possible use case is analysis of whether th train sample differs
20 | from the test sample, due to e.g. non-stationarity.
21 |
22 | This is a base class and needs to be extended by a fit() method, which implements how the data is split,
23 | how the model is trained and evaluated.
24 | Further, inheriting classes need to implement how feature importance should be indicated.
25 | """
26 |
27 | def __init__(
28 | self,
29 | model,
30 | scoring="roc_auc",
31 | test_prc=0.25,
32 | n_jobs=1,
33 | verbose=0,
34 | random_state=None,
35 | ):
36 | """
37 | Initializes the class.
38 |
39 | Args:
40 | model (model object):
41 | Regression or classification model or pipeline.
42 |
43 | scoring (string or probatus.utils.Scorer, optional):
44 | Metric for which the model performance is calculated. It can be either a metric name aligned with
45 | predefined
46 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html).
47 | Another option is using probatus.utils.Scorer to define a custom metric. The recommended option for this
48 | class is 'roc_auc'.
49 |
50 | test_prc (float, optional):
51 | Percentage of data used to test the model. By default 0.25 is set.
52 |
53 | n_jobs (int, optional):
54 | Number of parallel executions. If -1 use all available cores. By default 1.
55 |
56 | verbose (int, optional):
57 | Controls verbosity of the output:
58 |
59 | - 0 - neither prints nor warnings are shown
60 | - 1 - only most important warnings
61 | - 2 - shows all prints and all warnings.
62 |
63 | random_state (int, optional):
64 | Random state set at each round of feature elimination. If it is None, the results will not be
65 | reproducible and in random search at each iteration a different hyperparameters might be tested. For
66 | reproducible results set it to an integer.
67 | """ # noqa
68 | self.model = model
69 | self.test_prc = test_prc
70 | self.n_jobs = n_jobs
71 | self.random_state = random_state
72 | self.verbose = verbose
73 | self.scorer = get_single_scorer(scoring)
74 |
75 | def _init_output_variables(self):
76 | """
77 | Initializes variables that will be filled in during fit() method, and are used as output.
78 | """
79 | self.X_train = None
80 | self.X_test = None
81 | self.y_train = None
82 | self.y_test = None
83 | self.train_score = None
84 | self.test_score = None
85 | self.report = None
86 |
87 | def fit(self, X1, X2, column_names=None, class_names=None):
88 | """
89 | Base fit functionality that should be executed before each fit.
90 |
91 | Args:
92 | X1 (np.ndarray or pd.DataFrame):
93 | First sample to be compared. It needs to have the same number of columns as X2.
94 |
95 | X2 (np.ndarray or pd.DataFrame):
96 | Second sample to be compared. It needs to have the same number of columns as X1.
97 |
98 | column_names (list of str, optional):
99 | List of feature names of the provided samples. If provided it will be used to overwrite the existing
100 | feature names. If not provided the existing feature names are used or default feature names are
101 | generated.
102 |
103 | class_names (None, or list of str, optional):
104 | List of class names assigned, in this case provided samples e.g. ['sample1', 'sample2']. If none, the
105 | default ['First Sample', 'Second Sample'] are used.
106 |
107 | Returns:
108 | (BaseResemblanceModel):
109 | Fitted object
110 | """
111 | # Set class names
112 | self.class_names = class_names
113 | if self.class_names is None:
114 | self.class_names = ["First Sample", "Second Sample"]
115 |
116 | # Ensure inputs are correct
117 | self.X1, self.column_names = preprocess_data(X1, X_name="X1", column_names=column_names, verbose=self.verbose)
118 | self.X2, _ = preprocess_data(X2, X_name="X2", column_names=column_names, verbose=self.verbose)
119 |
120 | # Prepare dataset for modelling
121 | self.X = pd.DataFrame(pd.concat([self.X1, self.X2], axis=0), columns=self.column_names).reset_index(drop=True)
122 |
123 | self.y = pd.Series(
124 | np.concatenate(
125 | [
126 | np.zeros(self.X1.shape[0]),
127 | np.ones(self.X2.shape[0]),
128 | ]
129 | )
130 | ).reset_index(drop=True)
131 |
132 | # Assure the type and number of classes for the variable
133 | self.X, _ = preprocess_data(self.X, X_name="X", column_names=self.column_names, verbose=self.verbose)
134 |
135 | self.y = preprocess_labels(self.y, y_name="y", index=self.X.index, verbose=self.verbose)
136 |
137 | # Reinitialize variables in case of multiple times being fit
138 | self._init_output_variables()
139 |
140 | self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
141 | self.X,
142 | self.y,
143 | test_size=self.test_prc,
144 | random_state=self.random_state,
145 | shuffle=True,
146 | stratify=self.y,
147 | )
148 | self.model.fit(self.X_train, self.y_train)
149 |
150 | self.train_score = np.round(self.scorer.score(self.model, self.X_train, self.y_train), 3)
151 | self.test_score = np.round(self.scorer.score(self.model, self.X_test, self.y_test), 3)
152 |
153 | self.results_text = (
154 | f"Train {self.scorer.metric_name}: {np.round(self.train_score, 3)},\n"
155 | f"Test {self.scorer.metric_name}: {np.round(self.test_score, 3)}."
156 | )
157 | if self.verbose > 1:
158 | logger.info(f"Finished model training: \n{self.results_text}")
159 |
160 | if self.verbose > 0:
161 | if self.train_score > self.test_score:
162 | warnings.warn(
163 | f"Train {self.scorer.metric_name} > Test {self.scorer.metric_name}, which might indicate "
164 | f"an overfit. \n Strong overfit might lead to misleading conclusions when analysing "
165 | f"feature importance. Consider retraining with more regularization applied to the model."
166 | )
167 | self.fitted = True
168 | return self
169 |
170 | def get_data_splits(self):
171 | """
172 | Returns the data splits used to train the Resemblance model.
173 |
174 | Returns:
175 | (pd.DataFrame, pd.DataFrame, pd.Series, pd.Series):
176 | X_train, X_test, y_train, y_test.
177 | """
178 | self._check_if_fitted()
179 | return self.X_train, self.X_test, self.y_train, self.y_test
180 |
181 | def compute(self, return_scores=False):
182 | """
183 | Checks if fit() method has been run and computes the output variables.
184 |
185 | Args:
186 | return_scores (bool, optional):
187 | Flag indicating whether the method should return a tuple (feature importances, train score,
188 | test score), or feature importances. By default the second option is selected.
189 |
190 | Returns:
191 | (tuple(pd.DataFrame, float, float) or pd.DataFrame):
192 | Depending on value of return_tuple either returns a tuple (feature importances, train AUC, test AUC), or
193 | feature importances.
194 | """
195 | self._check_if_fitted()
196 |
197 | if return_scores:
198 | return self.report, self.train_score, self.test_score
199 | else:
200 | return self.report
201 |
202 | def fit_compute(
203 | self,
204 | X1,
205 | X2,
206 | column_names=None,
207 | class_names=None,
208 | return_scores=False,
209 | **fit_kwargs,
210 | ):
211 | """
212 | Fits the resemblance model and computes the report regarding feature importance.
213 |
214 | Args:
215 | X1 (np.ndarray or pd.DataFrame):
216 | First sample to be compared. It needs to have the same number of columns as X2.
217 |
218 | X2 (np.ndarray or pd.DataFrame):
219 | Second sample to be compared. It needs to have the same number of columns as X1.
220 |
221 | column_names (list of str, optional):
222 | List of feature names of the provided samples. If provided it will be used to overwrite the existing
223 | feature names. If not provided the existing feature names are used or default feature names are
224 | generated.
225 |
226 | class_names (None, or list of str, optional):
227 | List of class names assigned, in this case provided samples e.g. ['sample1', 'sample2']. If none, the
228 | default ['First Sample', 'Second Sample'] are used.
229 |
230 | return_scores (bool, optional):
231 | Flag indicating whether the method should return a tuple (feature importances, train score,
232 | test score), or feature importances. By default the second option is selected.
233 |
234 | **fit_kwargs:
235 | In case any other arguments are accepted by fit() method, they can be passed as keyword arguments.
236 |
237 | Returns:
238 | (tuple of (pd.DataFrame, float, float) or pd.DataFrame):
239 | Depending on value of return_tuple either returns a tuple (feature importances, train AUC, test AUC), or
240 | feature importances.
241 | """
242 | self.fit(X1, X2, column_names=column_names, class_names=class_names, **fit_kwargs)
243 | return self.compute(return_scores=return_scores)
244 |
245 | def plot(self):
246 | """
247 | Plot.
248 | """
249 | raise (NotImplementedError("Plot method has not been implemented."))
250 |
251 |
252 | class PermutationImportanceResemblance(BaseResemblanceModel):
253 | """
254 | This model checks the similarity of two samples.
255 |
256 | A possible use case is analysis of whether the train sample differs
257 | from the test sample, due to e.g. non-stationarity.
258 |
259 | It assigns labels to each sample, 0 to the first sample, 1 to the second. Then, it randomly selects a portion of
260 | data to train on. The resulting model tries to distinguish which sample a given test row comes from. This
261 | provides insights on how distinguishable these samples are and which features contribute to that. The feature
262 | importance is calculated using permutation importance.
263 |
264 | If the model achieves a test AUC significantly different than 0.5, it indicates that it is possible to distinguish
265 | between the samples, and therefore, the samples differ.
266 | Features with a high permutation importance contribute to that effect the most.
267 | Thus, their distribution might differ between two samples.
268 |
269 | Examples:
270 | ```python
271 | from sklearn.datasets import make_classification
272 | from sklearn.ensemble import RandomForestClassifier
273 | from probatus.sample_similarity import PermutationImportanceResemblance
274 | X1, _ = make_classification(n_samples=100, n_features=5)
275 | X2, _ = make_classification(n_samples=100, n_features=5, shift=0.5)
276 | model = RandomForestClassifier(max_depth=2)
277 | perm = PermutationImportanceResemblance(model)
278 | feature_importance = perm.fit_compute(X1, X2)
279 | perm.plot()
280 | ```
281 |
282 | """
283 |
284 | def __init__(
285 | self,
286 | model,
287 | iterations=100,
288 | scoring="roc_auc",
289 | test_prc=0.25,
290 | n_jobs=1,
291 | verbose=0,
292 | random_state=None,
293 | ):
294 | """
295 | Initializes the class.
296 |
297 | Args:
298 | model (model object):
299 | Regression or classification model or pipeline.
300 |
301 | iterations (int, optional):
302 | Number of iterations performed to calculate permutation importance. By default 100 iterations per
303 | feature are done.
304 |
305 | scoring (string or probatus.utils.Scorer, optional):
306 | Metric for which the model performance is calculated. It can be either a metric name aligned with
307 | predefined
308 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html).
309 | Another option is using probatus.utils.Scorer to define a custom metric. Recommended option for this
310 | class is 'roc_auc'.
311 |
312 | test_prc (float, optional):
313 | Percentage of data used to test the model. By default 0.25 is set.
314 |
315 | n_jobs (int, optional):
316 | Number of parallel executions. If -1 use all available cores. By default 1.
317 |
318 | verbose (int, optional):
319 | Controls verbosity of the output:
320 |
321 | - 0 - neither prints nor warnings are shown
322 | - 1 - only most important warnings
323 | - 2 - shows all prints and all warnings.
324 |
325 | random_state (int, optional):
326 | Random state set at each round of feature elimination. If it is None, the results will not be
327 | reproducible and in random search at each iteration a different hyperparameters might be tested. For
328 | reproducible results set it to integer.
329 | """ # noqa
330 | super().__init__(
331 | model=model,
332 | scoring=scoring,
333 | test_prc=test_prc,
334 | n_jobs=n_jobs,
335 | verbose=verbose,
336 | random_state=random_state,
337 | )
338 |
339 | self.iterations = iterations
340 |
341 | self.iterations_columns = ["feature", "importance"]
342 | self.iterations_results = pd.DataFrame(columns=self.iterations_columns)
343 |
344 | self.plot_x_label = "Permutation Feature Importance"
345 | self.plot_y_label = "Feature Name"
346 | self.plot_title = "Permutation Feature Importance of Resemblance Model"
347 |
348 | def fit(self, X1, X2, column_names=None, class_names=None):
349 | """
350 | This function assigns labels to each sample, 0 to the first sample, 1 to the second.
351 |
352 | Then, it randomly selects a
353 | portion of data to train on. The resulting model tries to distinguish which sample a given test row
354 | comes from. This provides insights on how distinguishable these samples are and which features contribute to
355 | that. The feature importance is calculated using permutation importance.
356 |
357 | Args:
358 | X1 (np.ndarray or pd.DataFrame):
359 | First sample to be compared. It needs to have the same number of columns as X2.
360 |
361 | X2 (np.ndarray or pd.DataFrame):
362 | Second sample to be compared. It needs to have the same number of columns as X1.
363 |
364 | column_names (list of str, optional):
365 | List of feature names of the provided samples. If provided it will be used to overwrite the existing
366 | feature names. If not provided the existing feature names are used or default feature names are
367 | generated.
368 |
369 | class_names (None, or list of str, optional):
370 | List of class names assigned, in this case provided samples e.g. ['sample1', 'sample2']. If none, the
371 | default ['First Sample', 'Second Sample'] are used.
372 |
373 | Returns:
374 | (PermutationImportanceResemblance):
375 | Fitted object.
376 | """
377 | super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names)
378 |
379 | permutation_result = permutation_importance(
380 | self.model,
381 | self.X_test,
382 | self.y_test,
383 | scoring=self.scorer.scorer,
384 | n_repeats=self.iterations,
385 | n_jobs=self.n_jobs,
386 | )
387 |
388 | # Prepare report
389 | self.report_columns = ["mean_importance", "std_importance"]
390 | self.report = pd.DataFrame(index=self.column_names, columns=self.report_columns, dtype=float)
391 |
392 | for feature_index, feature_name in enumerate(self.column_names):
393 | # Fill in the report
394 | self.report.loc[feature_name, "mean_importance"] = permutation_result["importances_mean"][feature_index]
395 | self.report.loc[feature_name, "std_importance"] = permutation_result["importances_std"][feature_index]
396 |
397 | # Fill in the iterations
398 | current_iterations = pd.DataFrame(
399 | np.stack(
400 | [
401 | np.repeat(feature_name, self.iterations),
402 | permutation_result["importances"][feature_index, :].reshape((self.iterations,)),
403 | ],
404 | axis=1,
405 | ),
406 | columns=self.iterations_columns,
407 | )
408 |
409 | self.iterations_results = pd.concat([self.iterations_results, current_iterations])
410 |
411 | self.iterations_results["importance"] = self.iterations_results["importance"].astype(float)
412 |
413 | # Sort by mean test score of first metric
414 | self.report.sort_values(by="mean_importance", ascending=False, inplace=True)
415 |
416 | return self
417 |
418 | def plot(self, ax=None, top_n=None, show=True, **plot_kwargs):
419 | """
420 | Plots the resulting AUC of the model as well as the feature importances.
421 |
422 | Args:
423 | ax (matplotlib.axes, optional):
424 | Axes to which the output should be plotted. If not provided new axes are created.
425 |
426 | top_n (int, optional):
427 | Number of the most important features to be plotted. By default features are included in the plot.
428 |
429 | show (bool, optional):
430 | If True, the plots are shown to the user, otherwise they are not shown. Not showing a plot can be useful
431 | when you want to edit the returned axis before showing it.
432 |
433 | **plot_kwargs:
434 | Keyword arguments passed to the matplotlib.plotly.subplots method.
435 |
436 | Returns:
437 | (matplotlib.axes):
438 | Axes that include the plot.
439 | """
440 |
441 | feature_report = self.compute()
442 | self.iterations_results["importance"] = self.iterations_results["importance"].astype(float)
443 |
444 | sorted_features = feature_report["mean_importance"].sort_values(ascending=True).index.values
445 | if top_n is not None and top_n > 0:
446 | sorted_features = sorted_features[-top_n:]
447 |
448 | if ax is None:
449 | fig, ax = plt.subplots(**plot_kwargs)
450 |
451 | for position, feature in enumerate(sorted_features):
452 | ax.boxplot(
453 | self.iterations_results[self.iterations_results["feature"] == feature]["importance"],
454 | positions=[position],
455 | vert=False,
456 | )
457 |
458 | ax.set_yticks(range(position + 1))
459 | ax.set_yticklabels(sorted_features)
460 | ax.set_xlabel(self.plot_x_label)
461 | ax.set_ylabel(self.plot_y_label)
462 | ax.set_title(self.plot_title)
463 |
464 | ax.annotate(
465 | self.results_text,
466 | (0, 0),
467 | (0, -50),
468 | fontsize=12,
469 | xycoords="axes fraction",
470 | textcoords="offset points",
471 | va="top",
472 | )
473 |
474 | if show:
475 | plt.show()
476 | else:
477 | plt.close()
478 |
479 | return ax
480 |
481 |
482 | class SHAPImportanceResemblance(BaseResemblanceModel):
483 | """
484 | This model checks for similarity of two samples.
485 |
486 | A possible use case is analysis of whether the train sample differs
487 | from the test sample, due to e.g. non-stationarity.
488 |
489 | It assigns labels to each sample, 0 to the first sample, 1 to the second. Then, it randomly selects a portion of
490 | data to train on. The resulting model tries to distinguish which sample a given test row comes from. This
491 | provides insights on how distinguishable these samples are and which features contribute to that. The feature
492 | importance is calculated using SHAP feature importance.
493 |
494 | If the model achieves test AUC significantly different than 0.5, it indicates that it is possible to distinguish
495 | between the samples, and therefore, the samples differ. Features with a high permutation importance contribute
496 | to that effect the most. Thus, their distribution might differ between two samples.
497 |
498 | This class currently works only with the Tree based models.
499 |
500 | Examples:
501 | ```python
502 | from sklearn.datasets import make_classification
503 | from sklearn.ensemble import RandomForestClassifier
504 | from probatus.sample_similarity import SHAPImportanceResemblance
505 | X1, _ = make_classification(n_samples=100, n_features=5)
506 | X2, _ = make_classification(n_samples=100, n_features=5, shift=0.5)
507 | model = RandomForestClassifier(max_depth=2)
508 | rm = SHAPImportanceResemblance(model)
509 | feature_importance = rm.fit_compute(X1, X2)
510 | rm.plot()
511 | ```
512 |
513 |
514 |
515 | """
516 |
517 | def __init__(
518 | self,
519 | model,
520 | scoring="roc_auc",
521 | test_prc=0.25,
522 | n_jobs=1,
523 | verbose=0,
524 | random_state=None,
525 | ):
526 | """
527 | Initializes the class.
528 |
529 | Args:
530 | model (model object):
531 | Regression or classification model or pipeline.
532 |
533 | scoring (string or probatus.utils.Scorer, optional):
534 | Metric for which the model performance is calculated. It can be either a metric name aligned with
535 | predefined
536 | [classification scorers names in sklearn](https://scikit-learn.org/stable/modules/model_evaluation.html).
537 | Another option is using probatus.utils.Scorer to define a custom metric. Recommended option for this
538 | class is 'roc_auc'.
539 |
540 | test_prc (float, optional):
541 | Percentage of data used to test the model. By default 0.25 is set.
542 |
543 | n_jobs (int, optional):
544 | Number of parallel executions. If -1 use all available cores. By default 1.
545 |
546 | verbose (int, optional):
547 | Controls verbosity of the output:
548 |
549 | - 0 - neither prints nor warnings are shown
550 | - 1 - only most important warnings
551 | - 2 - shows all prints and all warnings.
552 |
553 | random_state (int, optional):
554 | Random state set at each round of feature elimination. If it is None, the results will not be
555 | reproducible and in random search at each iteration a different hyperparameters might be tested. For
556 | reproducible results set it to integer.
557 | """ # noqa
558 | super().__init__(
559 | model=model,
560 | scoring=scoring,
561 | test_prc=test_prc,
562 | n_jobs=n_jobs,
563 | verbose=verbose,
564 | random_state=random_state,
565 | )
566 |
567 | self.plot_title = "SHAP summary plot"
568 |
569 | def fit(self, X1, X2, column_names=None, class_names=None, **shap_kwargs):
570 | """
571 | This function assigns labels to each sample, 0 to the first sample, 1 to the second.
572 |
573 | Then, it randomly selects a
574 | portion of data to train on. The resulting model tries to distinguish which sample a given test row
575 | comes from. This provides insights on how distinguishable these samples are and which features contribute to
576 | that. The feature importance is calculated using SHAP feature importance.
577 |
578 | Args:
579 | X1 (np.ndarray or pd.DataFrame):
580 | First sample to be compared. It needs to have the same number of columns as X2.
581 |
582 | X2 (np.ndarray or pd.DataFrame):
583 | Second sample to be compared. It needs to have the same number of columns as X1.
584 |
585 | column_names (list of str, optional):
586 | List of feature names of the provided samples. If provided it will be used to overwrite the existing
587 | feature names. If not provided the existing feature names are used or default feature names are
588 | generated.
589 |
590 | class_names (None, or list of str, optional):
591 | List of class names assigned, in this case provided samples e.g. ['sample1', 'sample2']. If none, the
592 | default ['First Sample', 'Second Sample'] are used.
593 |
594 | **shap_kwargs:
595 | keyword arguments passed to
596 | [shap.Explainer](https://shap.readthedocs.io/en/latest/generated/shap.Explainer.html#shap.Explainer).
597 | It also enables `approximate` and `check_additivity` parameters, passed while calculating SHAP values.
598 | The `approximate=True` causes less accurate, but faster SHAP values calculation, while
599 | `check_additivity=False` disables the additivity check inside SHAP.
600 |
601 | Returns:
602 | (SHAPImportanceResemblance):
603 | Fitted object.
604 | """
605 | super().fit(X1=X1, X2=X2, column_names=column_names, class_names=class_names)
606 |
607 | self.shap_values_test = shap_calc(
608 | self.model, self.X_test, verbose=self.verbose, random_state=self.random_state, **shap_kwargs
609 | )
610 | self.report = calculate_shap_importance(self.shap_values_test, self.column_names)
611 | return self
612 |
613 | def plot(self, plot_type="bar", show=True, **summary_plot_kwargs):
614 | """
615 | Plots the resulting AUC of the model as well as the feature importances.
616 |
617 | Args:
618 | plot_type (str, optional): Type of plot, used to compute shap.summary_plot. By default 'bar', available ones
619 | are "dot", "bar", "violin",
620 |
621 | show (bool, optional):
622 | If True, the plots are showed to the user, otherwise they are not shown. Not showing plot can be useful,
623 | when you want to edit the returned axis, before showing it.
624 |
625 | **summary_plot_kwargs:
626 | kwargs passed to the shap.summary_plot.
627 |
628 | Returns:
629 | (matplotlib.axes):
630 | Axes that include the plot.
631 | """
632 |
633 | # This line serves as a double check if the object has been fitted
634 | self._check_if_fitted()
635 |
636 | summary_plot(
637 | self.shap_values_test,
638 | self.X_test,
639 | plot_type=plot_type,
640 | class_names=self.class_names,
641 | show=False,
642 | **summary_plot_kwargs,
643 | )
644 | ax = plt.gca()
645 | ax.set_title(self.plot_title)
646 |
647 | ax.annotate(
648 | self.results_text,
649 | (0, 0),
650 | (0, -50),
651 | fontsize=12,
652 | xycoords="axes fraction",
653 | textcoords="offset points",
654 | va="top",
655 | )
656 |
657 | if show:
658 | plt.show()
659 | else:
660 | plt.close()
661 |
662 | return ax
663 |
664 | def get_shap_values(self):
665 | """
666 | Gets the SHAP values generated on the test set.
667 |
668 | Returns:
669 | (np.array):
670 | SHAP values generated on the test set.
671 | """
672 | self._check_if_fitted()
673 | return self.shap_values_test
674 |
--------------------------------------------------------------------------------
/probatus/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .exceptions import NotFittedError
2 | from .arrayfuncs import (
3 | assure_pandas_df,
4 | assure_pandas_series,
5 | preprocess_data,
6 | preprocess_labels,
7 | )
8 | from .scoring import Scorer, get_single_scorer
9 | from .shap_helpers import shap_calc, shap_to_df, calculate_shap_importance
10 | from ._utils import assure_list_of_strings
11 | from .base_class_interface import BaseFitComputeClass, BaseFitComputePlotClass
12 |
13 | __all__ = [
14 | "assure_list_of_strings",
15 | "assure_pandas_df",
16 | "assure_pandas_series",
17 | "preprocess_data",
18 | "preprocess_labels",
19 | "BaseFitComputeClass",
20 | "BaseFitComputePlotClass",
21 | "NotFittedError",
22 | "get_single_scorer",
23 | "Scorer",
24 | "shap_calc",
25 | "shap_to_df",
26 | "calculate_shap_importance",
27 | ]
28 |
--------------------------------------------------------------------------------
/probatus/utils/_utils.py:
--------------------------------------------------------------------------------
1 | def assure_list_of_strings(variable, variable_name):
2 | """
3 | Make sure object is a list of strings.
4 | """
5 | if isinstance(variable, list):
6 | return variable
7 | elif isinstance(variable, str):
8 | return [variable]
9 | else:
10 | raise (ValueError("{} needs to be either a string or list of strings.").format(variable_name))
11 |
--------------------------------------------------------------------------------
/probatus/utils/arrayfuncs.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | def assure_pandas_df(x, column_names=None):
8 | """
9 | Returns x as pandas DataFrame. X can be a list, list of lists, numpy array, pandas DataFrame or pandas Series.
10 |
11 | Args:
12 | x (list, numpy array, pandas DataFrame, pandas Series): array to be tested
13 |
14 | Returns:
15 | pandas DataFrame
16 | """
17 | if isinstance(x, pd.DataFrame):
18 | if column_names is not None:
19 | x.columns = column_names
20 | elif isinstance(x, (np.ndarray, pd.Series, list)):
21 | x = pd.DataFrame(x, columns=column_names)
22 | else:
23 | raise TypeError("Please supply a list, numpy array, pandas Series or pandas DataFrame")
24 |
25 | return x
26 |
27 |
28 | def assure_pandas_series(x, index=None):
29 | """
30 | Returns x as pandas Series. X can be a list, numpy array, or pandas Series.
31 |
32 | Args:
33 | x (list, numpy array, pandas DataFrame, pandas Series): array to be tested
34 |
35 | Returns:
36 | pandas Series
37 | """
38 | if isinstance(x, pd.Series):
39 | if isinstance(index, (list, np.ndarray)):
40 | index = pd.Index(index)
41 | current_x_index = pd.Index(x.index.values)
42 | if current_x_index.equals(index):
43 | # If exact match then keep it as it is
44 | return x
45 | elif current_x_index.sort_values().equals(index.sort_values()):
46 | # If both have the same values but in different order, then reorder
47 | return x[index]
48 | else:
49 | # If indexes have different values, overwrite
50 | x.index = index
51 | return x
52 | elif any([isinstance(x, (np.ndarray, list))]):
53 | return pd.Series(x, index=index)
54 | else:
55 | raise TypeError("Please supply a list, numpy array, pandas Series")
56 |
57 |
58 | def preprocess_data(X, X_name=None, column_names=None, verbose=0):
59 | """
60 | Preprocess data.
61 |
62 | Does basic preprocessing of the data: Transforms to DataFrame, Warns which features have missing variables,
63 | and transforms object dtype features to category type, such that LightGBM handles them by default.
64 |
65 | Args:
66 | X (pd.DataFrame, list of lists, np.array):
67 | Provided dataset.
68 |
69 | X_name (str, optional):
70 | Name of the X variable, that will be printed in the warnings.
71 |
72 | column_names (list of str, optional):
73 | List of feature names of the provided samples. If provided it will be used to overwrite the existing
74 | feature names. If not provided the existing feature names are used or default feature names are
75 | generated.
76 |
77 | verbose (int, optional):
78 | Controls verbosity of the output:
79 |
80 | - 0 - neither prints nor warnings are shown
81 | - 1 - only most important warnings
82 | - 2 - shows all prints and all warnings.
83 |
84 |
85 | Returns:
86 | (pd.DataFrame):
87 | Preprocessed dataset.
88 | """
89 | X_name = "X" if X_name is None else X_name
90 |
91 | # Make sure that X is a pd.DataFrame with correct column names
92 | X = assure_pandas_df(X, column_names=column_names)
93 |
94 | if verbose > 0:
95 | # Warn if missing
96 | columns_with_missing = X.columns[X.isnull().any()].tolist()
97 | if columns_with_missing:
98 | warnings.warn(
99 | f"The following variables in {X_name} contains missing values {columns_with_missing}. "
100 | f"Make sure to impute missing or apply a model that handles them automatically."
101 | )
102 |
103 | # Warn if categorical features and change to category
104 | categorical_features = X.select_dtypes(include=["category", "object"]).columns.tolist()
105 | # Set categorical features type to category
106 | if categorical_features:
107 | if verbose > 0:
108 | warnings.warn(
109 | f"The following variables in {X_name} contains categorical variables: "
110 | f"{categorical_features}. Make sure to use a model that handles them automatically or "
111 | f"encode them into numerical variables."
112 | )
113 |
114 | # Ensure category dtype, to enable models e.g. LighGBM, handle them automatically
115 | object_columns = X.select_dtypes(include=["object"]).columns
116 | if not object_columns.empty:
117 | X[object_columns] = X[object_columns].astype("category")
118 |
119 | return X, X.columns.tolist()
120 |
121 |
122 | def preprocess_labels(y, y_name=None, index=None, verbose=0):
123 | """
124 | Does basic preparation of the labels. Turns them into Series, and WARS in case the target is not binary.
125 |
126 | Args:
127 | y (pd.Series, list, np.array):
128 | Provided labels.
129 |
130 | y_name (str, optional):
131 | Name of the y variable, that will be printed in the warnings.
132 |
133 | index (list of int or pd.Index, optional):
134 | The index correct index that should be used for y. In case y is a list or np.array, the index is set when
135 | creating pd.Series. In case it is pd.Series already, if the indexes consist of the same values, the y is
136 | going to be ordered based on provided index, otherwise, the current index of y is overwritten by index
137 | argument.
138 |
139 | verbose (int, optional):
140 | Controls verbosity of the output:
141 |
142 | - 0 - neither prints nor warnings are shown
143 | - 1 - only most important warnings
144 | - 2 - shows all prints and all warnings.
145 |
146 | Returns:
147 | (pd.Series):
148 | Labels in the form of pd.Series.
149 | """
150 | y_name = "y" if y_name is None else y_name
151 |
152 | # Make sure that y is a series with correct index
153 | y = assure_pandas_series(y, index=index)
154 |
155 | return y
156 |
--------------------------------------------------------------------------------
/probatus/utils/base_class_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from probatus.utils import NotFittedError
4 |
5 |
6 | class BaseFitComputeClass(ABC):
7 | """
8 | Placeholder that must be overwritten by subclass.
9 | """
10 |
11 | fitted = False
12 |
13 | def _check_if_fitted(self):
14 | """
15 | Checks if object has been fitted. If not, NotFittedError is raised.
16 | """
17 | if not self.fitted:
18 | raise (NotFittedError("The object has not been fitted. Please run fit() method first"))
19 |
20 | @abstractmethod
21 | def fit(self, *args, **kwargs):
22 | """
23 | Placeholder that must be overwritten by subclass.
24 | """
25 | pass
26 |
27 | @abstractmethod
28 | def compute(self, *args, **kwargs):
29 | """
30 | Placeholder that must be overwritten by subclass.
31 | """
32 | pass
33 |
34 | @abstractmethod
35 | def fit_compute(self, *args, **kwargs):
36 | """
37 | Placeholder that must be overwritten by subclass.
38 | """
39 | pass
40 |
41 |
42 | class BaseFitComputePlotClass(BaseFitComputeClass):
43 | """
44 | Base class.
45 | """
46 |
47 | @abstractmethod
48 | def plot(self, *args, **kwargs):
49 | """
50 | Placeholder method for plotting.
51 | """
52 | pass
53 |
--------------------------------------------------------------------------------
/probatus/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | class NotFittedError(Exception):
2 | """
3 | Error.
4 | """
5 |
6 | def __init__(self, message):
7 | """
8 | Init error.
9 | """
10 | self.message = message
11 |
--------------------------------------------------------------------------------
/probatus/utils/scoring.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import get_scorer
2 |
3 |
4 | def get_single_scorer(scoring):
5 | """
6 | Returns Scorer, based on provided input in scoring argument.
7 |
8 | Args:
9 | scoring (string or probatus.utils.Scorer, optional):
10 | Metric for which the model performance is calculated. It can be either a metric name aligned with
11 | predefined classification scorers names in sklearn
12 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)).
13 | Another option is using probatus.utils.Scorer to define a custom metric.
14 |
15 | Returns:
16 | (probatus.utils.Scorer):
17 | Scorer that can be used for scoring models
18 | """
19 | if isinstance(scoring, str):
20 | return Scorer(scoring)
21 | elif isinstance(scoring, Scorer):
22 | return scoring
23 | else:
24 | raise (ValueError("The scoring should contain either strings or probatus.utils.Scorer class"))
25 |
26 |
27 | class Scorer:
28 | """
29 | Scores a given machine learning model based on the provided metric name and optionally a custom scoring function.
30 |
31 | Examples:
32 |
33 | ```python
34 | from probatus.utils import Scorer
35 | from sklearn.metrics import make_scorer
36 | from sklearn.datasets import make_classification
37 | from sklearn.model_selection import train_test_split
38 | from sklearn.ensemble import RandomForestClassifier
39 | import pandas as pd
40 |
41 | # Make ROC AUC scorer
42 | scorer1 = Scorer('roc_auc')
43 |
44 | # Make custom scorer with following function:
45 | def custom_metric(y_true, y_pred):
46 | return (y_true == y_pred).sum()
47 | scorer2 = Scorer('custom_metric', custom_scorer=make_scorer(custom_metric))
48 |
49 | # Prepare two samples
50 | feature_names = ['f1', 'f2', 'f3', 'f4']
51 | X, y = make_classification(n_samples=1000, n_features=4, random_state=0)
52 | X = pd.DataFrame(X, columns=feature_names)
53 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
54 |
55 | # Prepare and fit model. Remember about class_weight="balanced" or an equivalent.
56 | model = RandomForestClassifier(class_weight='balanced', n_estimators = 100, max_depth=2, random_state=0)
57 | model = model.fit(X_train, y_train)
58 |
59 | # Score model
60 | score_test_scorer1 = scorer1.score(model, X_test, y_test)
61 | score_test_scorer2 = scorer2.score(model, X_test, y_test)
62 |
63 | print(f'Test ROC AUC is {score_test_scorer1}, Test {scorer2.metric_name} is {score_test_scorer2}')
64 | ```
65 | """
66 |
67 | def __init__(self, metric_name, custom_scorer=None):
68 | """
69 | Initializes the class.
70 |
71 | Args:
72 | metric_name (str): Name of the metric used to evaluate the model.
73 | If the custom_scorer is not passed, the
74 | metric name needs to be aligned with classification scorers names in sklearn
75 | ([link](https://scikit-learn.org/stable/modules/model_evaluation.html)).
76 | custom_scorer (sklearn.metrics Scorer callable, optional): Callable
77 | that can score samples.
78 | """
79 | self.metric_name = metric_name
80 | if custom_scorer is not None:
81 | self.scorer = custom_scorer
82 | else:
83 | self.scorer = get_scorer(self.metric_name)
84 |
85 | def score(self, model, X, y):
86 | """
87 | Scores the samples model based on the provided metric name.
88 |
89 | Args
90 | model (model object):
91 | Model to be scored.
92 |
93 | X (array-like of shape (n_samples,n_features)):
94 | Samples on which the model is scored.
95 |
96 | y (array-like of shape (n_samples,)):
97 | Labels on which the model is scored.
98 |
99 | Returns:
100 | (float):
101 | Score returned by the model
102 | """
103 | return self.scorer(model, X, y)
104 |
--------------------------------------------------------------------------------
/probatus/utils/shap_helpers.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from shap import Explainer
6 | from shap.explainers import TreeExplainer
7 | from shap.utils import sample
8 | from sklearn.pipeline import Pipeline
9 |
10 |
11 | def shap_calc(
12 | model,
13 | X,
14 | return_explainer=False,
15 | verbose=0,
16 | random_state=None,
17 | sample_size=100,
18 | approximate=False,
19 | check_additivity=True,
20 | **shap_kwargs,
21 | ):
22 | """
23 | Helper function to calculate the shapley values for a given model.
24 |
25 | Args:
26 | model (model):
27 | Trained model.
28 |
29 | X (pd.DataFrame or np.ndarray):
30 | features set.
31 |
32 | return_explainer (boolean):
33 | if True, returns a a tuple (shap_values, explainer).
34 |
35 | verbose (int, optional):
36 | Controls verbosity of the output:
37 |
38 | - 0 - neither prints nor warnings are shown
39 | - 1 - only most important warnings
40 | - 2 - shows all prints and all warnings.
41 |
42 | random_state (int, optional):
43 | Random state set for the nr of samples. If it is None, the results will not be reproducible. For
44 | reproducible results set it to an integer.
45 |
46 | approximate (boolean):
47 | if True uses shap approximations - less accurate, but very fast. It applies to tree-based explainers only.
48 |
49 | check_additivity (boolean):
50 | if False SHAP will disable the additivity check for tree-based models.
51 |
52 | **shap_kwargs: kwargs of the shap.Explainer
53 |
54 | Returns:
55 | (np.ndarray or tuple(np.ndarray, shap.Explainer)):
56 | shapley_values for the model, optionally also returns the explainer.
57 |
58 | """
59 | if isinstance(model, Pipeline):
60 | raise TypeError(
61 | "The provided model is a Pipeline. Unfortunately, the features based on SHAP do not support "
62 | "pipelines, because they cannot be used in combination with shap.Explainer. Please apply any "
63 | "data transformations before running the probatus module."
64 | )
65 |
66 | # Suppress warnings regarding XGboost and Lightgbm models.
67 | with warnings.catch_warnings():
68 | warnings.simplefilter("ignore" if verbose <= 1 else "default")
69 |
70 | # For tree explainers, do not pass masker when feature_perturbation is
71 | # tree_path_dependent, or when X contains categorical features
72 | # related to issue:
73 | # https://github.com/slundberg/shap/issues/480
74 | if shap_kwargs.get("feature_perturbation") == "tree_path_dependent" or X.select_dtypes("category").shape[1] > 0:
75 | # Calculate Shap values.
76 | explainer = Explainer(model, seed=random_state, **shap_kwargs)
77 | else:
78 | # Create the background data,required for non tree based models.
79 | # A single datapoint can passed as mask
80 | # (https://github.com/slundberg/shap/issues/955#issuecomment-569837201)
81 | if X.shape[0] < sample_size:
82 | sample_size = int(np.ceil(X.shape[0] * 0.2))
83 | else:
84 | pass
85 | mask = sample(X, sample_size, random_state=random_state)
86 | explainer = Explainer(model, seed=random_state, masker=mask, **shap_kwargs)
87 |
88 | # For tree-explainers allow for using check_additivity and approximate arguments
89 | if isinstance(explainer, TreeExplainer):
90 | shap_values = explainer.shap_values(X, check_additivity=check_additivity, approximate=approximate)
91 |
92 | # From SHAP version 0.43+ https://github.com/shap/shap/pull/3121 required to
93 | # get the second dimension of calculated Shap values.
94 | if not isinstance(shap_values, list) and len(shap_values.shape) == 3:
95 | shap_values = shap_values[:, :, 1]
96 | else:
97 | # Calculate Shap values
98 | shap_values = explainer.shap_values(X)
99 |
100 | if isinstance(shap_values, list) and len(shap_values) == 2:
101 | warnings.warn(
102 | "Shap values are related to the output probabilities of class 1 for this model, instead of log odds."
103 | )
104 | shap_values = shap_values[1]
105 |
106 | if return_explainer:
107 | return shap_values, explainer
108 | return shap_values
109 |
110 |
111 | def shap_to_df(model, X, precalc_shap=None, **kwargs):
112 | """
113 | Calculates the shap values and return the pandas DataFrame with the columns and the index of the original.
114 |
115 | Args:
116 | model (model):
117 | Pretrained model (Random Forest of XGBoost at the moment).
118 |
119 | X (pd.DataFrame or np.ndarray):
120 | Dataset on which the SHAP importance is calculated.
121 |
122 | precalc_shap (np.array):
123 | Precalculated SHAP values. If None, they are computed.
124 |
125 | **kwargs: for the function shap_calc
126 |
127 | Returns:
128 | (pd.DataFrame):
129 | Dataframe with SHAP feature importance per features on X dataset.
130 | """
131 | shap_values = precalc_shap if precalc_shap is not None else shap_calc(model, X, **kwargs)
132 |
133 | try:
134 | return pd.DataFrame(shap_values, columns=X.columns, index=X.index)
135 | except AttributeError:
136 | if isinstance(X, np.ndarray) and len(X.shape) == 2:
137 | return pd.DataFrame(shap_values, columns=[f"col_{ix}" for ix in range(X.shape[1])])
138 | else:
139 | raise TypeError("X must be a dataframe or a 2d array")
140 |
141 |
142 | def calculate_shap_importance(shap_values, columns, output_columns_suffix="", shap_variance_penalty_factor=None):
143 | """
144 | Returns the average shapley value for each column of the dataframe, as well as the average absolute shap value.
145 |
146 | Args:
147 | shap_values (np.array):
148 | Shap values.
149 |
150 | columns (list of str):
151 | Feature names.
152 |
153 | output_columns_suffix (str, optional):
154 | Suffix to be added at the end of column names in the output.
155 |
156 | shap_variance_penalty_factor (int or float, optional):
157 | Apply aggregation penalty when computing average of shap values for a given feature.
158 | Results in a preference for features that have smaller standard deviation of shap
159 | values (more coherent shap importance). Recommend value 0.5 - 1.0.
160 | Formula: penalized_shap_mean = (mean_shap - (std_shap * shap_variance_penalty_factor))
161 |
162 | Returns:
163 | (pd.DataFrame):
164 | Mean absolute shap values and Mean shap values of features.
165 |
166 | """
167 | if shap_variance_penalty_factor is None or shap_variance_penalty_factor < 0:
168 | shap_variance_penalty_factor = 0
169 | elif not isinstance(shap_variance_penalty_factor, (float, int)):
170 | warnings.warn(
171 | "shap_variance_penalty_factor must be None, int, or float. Setting shap_variance_penalty_factor = 0"
172 | )
173 | shap_variance_penalty_factor = 0
174 |
175 | abs_shap_values = np.abs(shap_values)
176 | if np.ndim(shap_values) > 2: # multi-class case
177 | sum_abs_shap = np.sum(abs_shap_values, axis=0)
178 | sum_shap = np.sum(shap_values, axis=0)
179 | shap_abs_mean = np.mean(sum_abs_shap, axis=0)
180 | shap_mean = np.mean(sum_shap, axis=0)
181 | penalized_shap_abs_mean = shap_abs_mean - (np.std(sum_abs_shap, axis=0) * shap_variance_penalty_factor)
182 | else:
183 | # Find average shap importance for neg and pos class
184 | shap_abs_mean = np.mean(abs_shap_values, axis=0)
185 | shap_mean = np.mean(shap_values, axis=0)
186 | penalized_shap_abs_mean = shap_abs_mean - (np.std(abs_shap_values, axis=0) * shap_variance_penalty_factor)
187 |
188 | # Prepare the values in a df and set the correct column types
189 | importance_df = pd.DataFrame(
190 | {
191 | f"mean_abs_shap_value{output_columns_suffix}": shap_abs_mean,
192 | f"mean_shap_value{output_columns_suffix}": shap_mean,
193 | f"penalized_mean_abs_shap_value{output_columns_suffix}": penalized_shap_abs_mean,
194 | },
195 | index=columns,
196 | ).astype(float)
197 |
198 | importance_df = importance_df.sort_values(f"penalized_mean_abs_shap_value{output_columns_suffix}", ascending=False)
199 |
200 | # Drop penalized column
201 | importance_df = importance_df.drop(columns=[f"penalized_mean_abs_shap_value{output_columns_suffix}"])
202 |
203 | return importance_df
204 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "probatus"
7 | version = "3.1.3"
8 | requires-python= ">=3.9"
9 | description = "Validation of regression & classifiers and data used to develop them"
10 | readme = { file = "README.md", content-type = "text/markdown" }
11 | authors = [
12 | { name = "ING Bank N.V.", email = "reinier.koops@ing.com" }
13 | ]
14 | license = { file = "LICENCE" }
15 | classifiers = [
16 | "Intended Audience :: Developers",
17 | "Intended Audience :: Science/Research",
18 | "Programming Language :: Python :: 3",
19 | "Programming Language :: Python :: 3.9",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: 3.12",
23 | "Topic :: Scientific/Engineering",
24 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
25 | "License :: OSI Approved :: MIT License",
26 | "Operating System :: OS Independent",
27 | ]
28 | dependencies = [
29 | "scikit-learn>=0.22.2",
30 | "pandas>=1.0.0",
31 | "matplotlib>=3.1.1",
32 | "joblib>=0.13.2",
33 | "shap>=0.43.0",
34 | "numpy>=1.23.2,<2.0.0",
35 | "numba>=0.57.0",
36 | "loguru>=0.7.2",
37 | ]
38 |
39 | [project.urls]
40 | Homepage = "https://ing-bank.github.io/probatus/"
41 | Documentation = "https://ing-bank.github.io/probatus/api/feature_elimination.html"
42 | Repository = "https://github.com/ing-bank/probatus.git"
43 | Changelog = "https://github.com/ing-bank/probatus/blob/main/CHANGELOG.md"
44 |
45 | [project.optional-dependencies]
46 |
47 |
48 | dev = [
49 | "black>=19.10b0",
50 | "mypy>=0.770",
51 | "pytest>=6.0.0",
52 | "pytest-cov>=2.10.0",
53 | "pyflakes",
54 | "joblib>=0.13.2",
55 | "jupyter>=1.0.0",
56 | "nbconvert>=6.0.7",
57 | "pre-commit>=2.7.1",
58 | "isort>=5.12.0",
59 | "codespell>=2.2.4",
60 | "ruff>=0.2.2",
61 | "lightgbm>=3.3.0",
62 | "catboost>=1.2",
63 | "xgboost>=1.5.0",
64 | ]
65 | docs = [
66 | "mkdocs>=1.5.3",
67 | "mkdocs-jupyter>=0.24.3",
68 | "mkdocs-material>=9.5.13",
69 | "mkdocstrings>=0.24.1",
70 | "mkdocstrings-python>=1.8.0",
71 | ]
72 |
73 | # Separating these allow for more install flexibility.
74 | all = ["probatus[dev,docs]"]
75 |
76 | [tool.setuptools.packages.find]
77 | exclude = ["tests", "notebooks", "docs"]
78 |
79 | [tool.nbqa.addopts]
80 | # E402: Ignores imports not at the top of file for IPYNB since this makes copy-pasting easier.
81 | ruff = ["--fix", "--ignore=E402"]
82 | isort = ["--profile=black"]
83 | black = ["--line-length=120"]
84 |
85 | [tool.mypy]
86 | python_version = "3.9"
87 | ignore_missing_imports = true
88 | namespace_packages = true
89 | pretty = true
90 |
91 | [tool.ruff]
92 | line-length = 120
93 | extend-exclude = ["docs", "mkdocs.yml", ".github", "*md", "LICENCE", ".pre-commit-config.yaml", ".gitignore"]
94 | force-exclude = true
95 |
96 | [tool.ruff.lint]
97 | # D100 requires all Python files (modules) to have a "public" docstring even if all functions within have a docstring.
98 | # D104 requires __init__ files to have a docstring
99 | # D202 No blank lines allowed after function docstring
100 | # D212
101 | # D200
102 | # D411 Missing blank line before section
103 | # D412 No blank lines allowed between a section header and its content
104 | # D417 Missing argument descriptions in the docstring # Only ignored because of false positve when using multiline args.
105 | # E203
106 | # E731 do not assign a lambda expression, use a def
107 | # W293 blank line contains whitespace
108 | ignore = ["D100", "D104", "D202", "D212", "D200", "E203", "E731", "W293", "D412", "D417", "D411", "RUF100"]
109 |
110 | [tool.ruff.lint.pydocstyle]
111 | convention = "google"
112 |
113 | [tool.isort]
114 | line_length = "120"
115 | profile = "black"
116 | filter_files = true
117 | extend_skip = ["__init__.py", "docs"]
118 |
119 | [tool.codespell]
120 | skip = "**.egg-info*"
121 |
122 | [tool.black]
123 | line-length = 120
124 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from sklearn.datasets import make_classification
7 | from sklearn.model_selection import train_test_split
8 | from sklearn.tree import DecisionTreeClassifier
9 | from catboost import CatBoostClassifier
10 | from lightgbm import LGBMClassifier
11 | from sklearn.linear_model import LogisticRegression
12 | from sklearn.model_selection import RandomizedSearchCV
13 |
14 |
15 | @pytest.fixture(scope="function")
16 | def random_state():
17 | RANDOM_STATE = 0
18 |
19 | return RANDOM_STATE
20 |
21 |
22 | @pytest.fixture(scope="function")
23 | def random_state_42():
24 | RANDOM_STATE = 42
25 |
26 | return RANDOM_STATE
27 |
28 |
29 | @pytest.fixture(scope="function")
30 | def random_state_1234():
31 | RANDOM_STATE = 1234
32 |
33 | return RANDOM_STATE
34 |
35 |
36 | @pytest.fixture(scope="function")
37 | def random_state_1():
38 | RANDOM_STATE = 1
39 |
40 | return RANDOM_STATE
41 |
42 |
43 | @pytest.fixture(scope="function")
44 | def mock_model():
45 | return Mock()
46 |
47 |
48 | @pytest.fixture(scope="function")
49 | def complex_data(random_state):
50 | feature_names = ["f1_categorical", "f2_missing", "f3_static", "f4", "f5"]
51 |
52 | # Prepare two samples
53 | X, y = make_classification(
54 | n_samples=50,
55 | class_sep=0.05,
56 | n_informative=2,
57 | n_features=5,
58 | random_state=random_state,
59 | n_redundant=2,
60 | n_clusters_per_class=1,
61 | )
62 | X = pd.DataFrame(X, columns=feature_names)
63 | X.loc[0:10, "f2_missing"] = np.nan
64 | return X, y
65 |
66 |
67 | @pytest.fixture(scope="function")
68 | def complex_data_with_categorical(complex_data):
69 | X, y = complex_data
70 | X["f1_categorical"] = X["f1_categorical"].astype(str).astype("category")
71 |
72 | return X, y
73 |
74 |
75 | @pytest.fixture(scope="function")
76 | def complex_data_split(complex_data, random_state_42):
77 | X, y = complex_data
78 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=random_state_42)
79 | return X_train, X_test, y_train, y_test
80 |
81 |
82 | @pytest.fixture(scope="function")
83 | def complex_data_split_with_categorical(complex_data_split):
84 | X_train, X_test, y_train, y_test = complex_data_split
85 | X_train["f1_categorical"] = X_train["f1_categorical"].astype(str).astype("category")
86 | X_test["f1_categorical"] = X_test["f1_categorical"].astype(str).astype("category")
87 |
88 | return X_train, X_test, y_train, y_test
89 |
90 |
91 | @pytest.fixture(scope="function")
92 | def complex_lightgbm(random_state_42):
93 | model = LGBMClassifier(max_depth=5, num_leaves=11, class_weight="balanced", random_state=random_state_42)
94 | return model
95 |
96 |
97 | @pytest.fixture(scope="function")
98 | def complex_fitted_lightgbm(complex_data_split_with_categorical, complex_lightgbm):
99 | X_train, _, y_train, _ = complex_data_split_with_categorical
100 |
101 | return complex_lightgbm.fit(X_train, y_train)
102 |
103 |
104 | @pytest.fixture(scope="function")
105 | def catboost_classifier(random_state):
106 | model = CatBoostClassifier(random_seed=random_state)
107 | return model
108 |
109 |
110 | @pytest.fixture(scope="function")
111 | def decision_tree_classifier(random_state):
112 | model = DecisionTreeClassifier(max_depth=1, random_state=random_state)
113 | return model
114 |
115 |
116 | @pytest.fixture(scope="function")
117 | def randomized_search_decision_tree_classifier(decision_tree_classifier, random_state):
118 | param_grid = {"criterion": ["gini"], "min_samples_split": [1, 2]}
119 | cv = RandomizedSearchCV(decision_tree_classifier, param_grid, cv=2, n_iter=2, random_state=random_state)
120 | return cv
121 |
122 |
123 | @pytest.fixture(scope="function")
124 | def logistic_regression(random_state):
125 | model = LogisticRegression(random_state=random_state)
126 | return model
127 |
128 |
129 | @pytest.fixture(scope="function")
130 | def X_train():
131 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}, index=[1, 2, 3, 4])
132 |
133 |
134 | @pytest.fixture(scope="function")
135 | def y_train():
136 | return pd.Series([1, 0, 1, 0], index=[1, 2, 3, 4])
137 |
138 |
139 | @pytest.fixture(scope="function")
140 | def X_test():
141 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}, index=[5, 6, 7, 8])
142 |
143 |
144 | @pytest.fixture(scope="function")
145 | def y_test():
146 | return pd.Series([0, 0, 1, 0], index=[5, 6, 7, 8])
147 |
148 |
149 | @pytest.fixture(scope="function")
150 | def fitted_logistic_regression(X_train, y_train, logistic_regression):
151 | return logistic_regression.fit(X_train, y_train)
152 |
153 |
154 | @pytest.fixture(scope="function")
155 | def fitted_tree(X_train, y_train, decision_tree_classifier):
156 | return decision_tree_classifier.fit(X_train, y_train)
157 |
--------------------------------------------------------------------------------
/tests/docs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/docs/__init__.py
--------------------------------------------------------------------------------
/tests/docs/test_docstring.py:
--------------------------------------------------------------------------------
1 | # This approach is adapted from, and explained in: https://calmcode.io/docs/epic.html
2 |
3 | import os
4 | from typing import List
5 |
6 | import matplotlib
7 | import matplotlib.pyplot as plt
8 | import pytest
9 |
10 | import probatus.feature_elimination
11 | import probatus.interpret
12 | import probatus.sample_similarity
13 | import probatus.utils
14 |
15 | # Turn off interactive mode in plots
16 | plt.ioff()
17 | matplotlib.use("Agg")
18 |
19 | CLASSES_TO_TEST = [
20 | probatus.feature_elimination.ShapRFECV,
21 | probatus.interpret.DependencePlotter,
22 | probatus.sample_similarity.SHAPImportanceResemblance,
23 | probatus.sample_similarity.PermutationImportanceResemblance,
24 | probatus.utils.Scorer,
25 | ]
26 |
27 | CLASSES_TO_TEST_LGBM = [
28 | probatus.feature_elimination.EarlyStoppingShapRFECV,
29 | ]
30 |
31 | FUNCTIONS_TO_TEST: List = []
32 |
33 |
34 | def handle_docstring(doc, indent):
35 | """
36 | Check python code in docstring.
37 |
38 | This function will read through the docstring and grab
39 | the first python code block. It will try to execute it.
40 | If it fails, the calling test should raise a flag.
41 | """
42 | if not doc:
43 | return
44 | start = doc.find("```python\n")
45 | end = doc.find("```\n")
46 | if start != -1:
47 | if end != -1:
48 | code_part = doc[(start + 10) : end].replace(" " * indent, "")
49 | exec(code_part)
50 |
51 |
52 | @pytest.mark.parametrize("c", CLASSES_TO_TEST)
53 | def test_class_docstrings(c):
54 | """
55 | Take the docstring of a given class.
56 |
57 | The test passes if the usage examples causes no errors.
58 | """
59 | handle_docstring(c.__doc__, indent=4)
60 |
61 |
62 | @pytest.mark.skipif(os.environ.get("SKIP_LIGHTGBM") == "true", reason="LightGBM tests disabled")
63 | @pytest.mark.parametrize("c", CLASSES_TO_TEST_LGBM)
64 | def test_class_docstrings_lgbm(c):
65 | """
66 | Take the docstring of a given class which uses LightGBM.
67 |
68 | The test passes if the usage examples causes no errors.
69 |
70 | The test is skipped if the environment does not support LightGBM correctly, such as macos.
71 | """
72 | handle_docstring(c.__doc__, indent=4)
73 |
74 |
75 | @pytest.mark.parametrize("f", FUNCTIONS_TO_TEST)
76 | def test_function_docstrings(f):
77 | """
78 | Take the docstring of every function.
79 |
80 | The test passes if the usage examples causes no errors.
81 | """
82 | handle_docstring(f.__doc__, indent=4)
83 |
--------------------------------------------------------------------------------
/tests/docs/test_notebooks.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import pytest
5 | import nbformat
6 | from nbconvert.preprocessors import ExecutePreprocessor
7 |
8 | TIMEOUT_SECONDS = 1800
9 | PATH_NOTEBOOKS = [str(path) for path in Path("docs").glob("*/*.ipynb")]
10 | NB_FLAG = os.environ.get("TEST_NOTEBOOKS") # Turn on tests by setting TEST_NOTEBOOKS = 1
11 | TEST_NOTEBOOKS = False if NB_FLAG == "1" else True
12 |
13 |
14 | @pytest.mark.parametrize("notebook_path", PATH_NOTEBOOKS)
15 | @pytest.mark.skipif(TEST_NOTEBOOKS, reason="Skip notebook tests if TEST_NOTEBOOK isn't set")
16 | def test_notebook(notebook_path: str) -> None:
17 | """Run a notebook and check no exception is raised."""
18 | with open(notebook_path) as f:
19 | nb = nbformat.read(f, as_version=4)
20 |
21 | ep = ExecutePreprocessor(timeout=TIMEOUT_SECONDS, kernel_name="python3")
22 | ep.preprocess(nb, {"metadata": {"path": str(Path(notebook_path).parent)}})
23 |
--------------------------------------------------------------------------------
/tests/feature_elimination/test_feature_elimination.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pytest
3 | from lightgbm import LGBMClassifier
4 | from sklearn.datasets import load_diabetes, make_classification
5 | from sklearn.ensemble import RandomForestClassifier
6 | from sklearn.linear_model import LogisticRegression
7 | from sklearn.model_selection import RandomizedSearchCV, StratifiedGroupKFold, StratifiedKFold
8 | from sklearn.pipeline import Pipeline
9 | from sklearn.preprocessing import StandardScaler
10 | from sklearn.svm import SVC
11 | from xgboost import XGBClassifier, XGBRegressor
12 |
13 | from probatus.feature_elimination import ShapRFECV, EarlyStoppingShapRFECV
14 | from probatus.utils import preprocess_labels
15 |
16 |
17 | @pytest.fixture(scope="function")
18 | def X():
19 | return pd.DataFrame(
20 | {
21 | "col_1": [1, 1, 1, 1, 1, 1, 1, 0],
22 | "col_2": [0, 0, 0, 0, 0, 0, 0, 1],
23 | "col_3": [1, 0, 1, 0, 1, 0, 1, 0],
24 | },
25 | index=[1, 2, 3, 4, 5, 6, 7, 8],
26 | )
27 |
28 |
29 | @pytest.fixture(scope="function")
30 | def y():
31 | return pd.Series([1, 0, 1, 0, 1, 0, 1, 0], index=[1, 2, 3, 4, 5, 6, 7, 8])
32 |
33 |
34 | @pytest.fixture(scope="function")
35 | def sample_weight():
36 | return pd.Series([1, 1, 1, 1, 1, 1, 1, 1], index=[1, 2, 3, 4, 5, 6, 7, 8])
37 |
38 |
39 | @pytest.fixture(scope="function")
40 | def groups():
41 | return pd.Series(["grp1", "grp1", "grp1", "grp1", "grp2", "grp2", "grp2", "grp2"], index=[1, 2, 3, 4, 5, 6, 7, 8])
42 |
43 |
44 | @pytest.fixture(scope="function")
45 | def XGBoost_classifier(random_state):
46 | model = XGBClassifier(n_estimators=200, max_depth=3, random_state=random_state)
47 | return model
48 |
49 |
50 | @pytest.fixture(scope="function")
51 | def XGBoost_regressor(random_state):
52 | model = XGBRegressor(n_estimators=200, max_depth=3, random_state=random_state)
53 | return model
54 |
55 |
56 | def test_shap_rfe_regressor(XGBoost_regressor, random_state):
57 | diabetes = load_diabetes()
58 | X = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
59 | y = diabetes.target
60 |
61 | shap_elimination = ShapRFECV(XGBoost_regressor, step=0.8, cv=2, scoring="r2", n_jobs=4, random_state=random_state)
62 | report = shap_elimination.fit_compute(X, y)
63 |
64 | assert report.shape[0] == 3
65 | assert shap_elimination.get_reduced_features_set(1) == ["bmi"]
66 |
67 | _ = shap_elimination.plot(show=False)
68 |
69 |
70 | def test_shap_rfe_randomized_search(X, y, randomized_search_decision_tree_classifier, random_state):
71 | search = randomized_search_decision_tree_classifier
72 | shap_elimination = ShapRFECV(search, step=0.8, cv=2, scoring="roc_auc", n_jobs=4, random_state=random_state)
73 | report = shap_elimination.fit_compute(X, y)
74 |
75 | assert report.shape[0] == 2
76 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
77 |
78 | _ = shap_elimination.plot(show=False)
79 |
80 |
81 | def test_shap_rfe_multi_class(X, y, decision_tree_classifier, random_state):
82 | shap_elimination = ShapRFECV(
83 | decision_tree_classifier,
84 | cv=2,
85 | scoring="roc_auc_ovr",
86 | random_state=random_state,
87 | )
88 |
89 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False)
90 |
91 | assert report.shape[0] == 3
92 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
93 |
94 |
95 | def test_shap_rfe(X, y, sample_weight, decision_tree_classifier, random_state):
96 | shap_elimination = ShapRFECV(
97 | decision_tree_classifier,
98 | random_state=random_state,
99 | step=1,
100 | cv=2,
101 | scoring="roc_auc",
102 | n_jobs=4,
103 | )
104 | report = shap_elimination.fit_compute(X, y, sample_weight=sample_weight, approximate=True, check_additivity=False)
105 |
106 | assert report.shape[0] == 3
107 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
108 |
109 |
110 | def test_shap_rfe_group_cv(X, y, groups, sample_weight, decision_tree_classifier, random_state):
111 | cv = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=random_state)
112 | shap_elimination = ShapRFECV(
113 | decision_tree_classifier,
114 | random_state=random_state,
115 | step=1,
116 | cv=cv,
117 | scoring="roc_auc",
118 | n_jobs=4,
119 | )
120 | report = shap_elimination.fit_compute(
121 | X, y, groups=groups, sample_weight=sample_weight, approximate=True, check_additivity=False
122 | )
123 |
124 | assert report.shape[0] == 3
125 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
126 |
127 |
128 | def test_shap_pipeline_error(X, y, decision_tree_classifier, random_state):
129 | model = Pipeline(
130 | [
131 | ("scaler", StandardScaler()),
132 | ("dt", decision_tree_classifier),
133 | ]
134 | )
135 | with pytest.raises(TypeError):
136 | shap_elimination = ShapRFECV(
137 | model,
138 | random_state=random_state,
139 | step=1,
140 | cv=2,
141 | scoring="roc_auc",
142 | n_jobs=4,
143 | )
144 | shap_elimination = shap_elimination.fit(X, y, approximate=True, check_additivity=False)
145 |
146 |
147 | def test_shap_rfe_linear_model(X, y, random_state):
148 | model = LogisticRegression(C=1, random_state=random_state)
149 | shap_elimination = ShapRFECV(model, random_state=random_state, step=1, cv=2, scoring="roc_auc", n_jobs=4)
150 | report = shap_elimination.fit_compute(X, y)
151 |
152 | assert report.shape[0] == 3
153 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
154 |
155 |
156 | def test_shap_rfe_svm(X, y, random_state):
157 | model = SVC(C=1, kernel="linear", probability=True, random_state=random_state)
158 | shap_elimination = ShapRFECV(model, random_state=random_state, step=1, cv=2, scoring="roc_auc", n_jobs=4)
159 | shap_elimination = shap_elimination.fit(X, y)
160 | report = shap_elimination.compute()
161 |
162 | assert report.shape[0] == 3
163 | assert shap_elimination.get_reduced_features_set(1) == ["col_3"]
164 |
165 |
166 | def test_shap_rfe_cols_to_keep(X, y, decision_tree_classifier, random_state):
167 | shap_elimination = ShapRFECV(
168 | decision_tree_classifier,
169 | random_state=random_state,
170 | step=2,
171 | cv=2,
172 | scoring="roc_auc",
173 | n_jobs=4,
174 | min_features_to_select=1,
175 | )
176 | report = shap_elimination.fit_compute(X, y, columns_to_keep=["col_2", "col_3"])
177 |
178 | assert report.shape[0] == 2
179 | reduced_feature_set = set(shap_elimination.get_reduced_features_set(num_features=2))
180 | assert reduced_feature_set == {"col_2", "col_3"}
181 |
182 |
183 | def test_shap_rfe_randomized_search_cols_to_keep(X, y, randomized_search_decision_tree_classifier, random_state):
184 | search = randomized_search_decision_tree_classifier
185 | shap_elimination = ShapRFECV(search, step=0.8, cv=2, scoring="roc_auc", n_jobs=4, random_state=random_state)
186 | report = shap_elimination.fit_compute(X, y, columns_to_keep=["col_2", "col_3"])
187 |
188 | assert report.shape[0] == 2
189 | reduced_feature_set = set(shap_elimination.get_reduced_features_set(num_features=2))
190 | assert reduced_feature_set == {"col_2", "col_3"}
191 |
192 |
193 | def test_calculate_number_of_features_to_remove():
194 | assert 3 == ShapRFECV._calculate_number_of_features_to_remove(
195 | current_num_of_features=10, num_features_to_remove=3, min_num_features_to_keep=5
196 | )
197 | assert 3 == ShapRFECV._calculate_number_of_features_to_remove(
198 | current_num_of_features=8, num_features_to_remove=5, min_num_features_to_keep=5
199 | )
200 | assert 0 == ShapRFECV._calculate_number_of_features_to_remove(
201 | current_num_of_features=5, num_features_to_remove=1, min_num_features_to_keep=5
202 | )
203 | assert 4 == ShapRFECV._calculate_number_of_features_to_remove(
204 | current_num_of_features=5, num_features_to_remove=7, min_num_features_to_keep=1
205 | )
206 |
207 |
208 | def test_shap_automatic_num_feature_selection(decision_tree_classifier, random_state):
209 | X = pd.DataFrame(
210 | {
211 | "col_1": [1, 0, 1, 0, 1, 0, 1, 0],
212 | "col_2": [0, 0, 0, 0, 0, 1, 1, 1],
213 | "col_3": [1, 1, 1, 0, 0, 0, 0, 0],
214 | }
215 | )
216 | y = pd.Series([0, 0, 0, 0, 1, 1, 1, 1])
217 |
218 | shap_elimination = ShapRFECV(
219 | decision_tree_classifier,
220 | random_state=random_state,
221 | step=1,
222 | cv=2,
223 | scoring="roc_auc",
224 | n_jobs=1,
225 | )
226 | _ = shap_elimination.fit_compute(X, y, approximate=True, check_additivity=False)
227 |
228 | best_features = shap_elimination.get_reduced_features_set(num_features="best")
229 | best_coherent_features = shap_elimination.get_reduced_features_set(
230 | num_features="best_coherent",
231 | )
232 | best_parsimonious_features = shap_elimination.get_reduced_features_set(num_features="best_parsimonious")
233 |
234 | assert best_features == ["col_2"]
235 | assert best_coherent_features == ["col_1", "col_2", "col_3"]
236 | assert best_parsimonious_features == ["col_2"]
237 |
238 |
239 | def test_get_feature_shap_values_per_fold(X, y, decision_tree_classifier, random_state):
240 | shap_elimination = ShapRFECV(decision_tree_classifier, scoring="roc_auc", random_state=random_state)
241 | (
242 | shap_values,
243 | train_score,
244 | test_score,
245 | ) = shap_elimination._get_feature_shap_values_per_fold(
246 | X,
247 | y,
248 | decision_tree_classifier,
249 | train_index=[2, 3, 4, 5, 6, 7],
250 | val_index=[0, 1],
251 | )
252 | assert test_score == 1
253 | assert train_score > 0.9
254 | assert shap_values.shape == (2, 3)
255 |
256 |
257 | def test_shap_rfe_same_features_are_kept_after_each_run(random_state_1234):
258 | """
259 | Test a use case which appears to be flickering with Probatus 1.8.9 and lower.
260 |
261 | Expected result: every run the same outcome.
262 | Probatus <= 1.8.9: A different order every time.
263 | """
264 | feature_names = [(f"f{num}") for num in range(1, 21)]
265 |
266 | # Code from tutorial on probatus documentation
267 | X, y = make_classification(
268 | n_samples=100,
269 | class_sep=0.05,
270 | n_informative=6,
271 | n_features=20,
272 | random_state=random_state_1234,
273 | n_redundant=10,
274 | n_clusters_per_class=1,
275 | )
276 | X = pd.DataFrame(X, columns=feature_names)
277 |
278 | random_forest = RandomForestClassifier(
279 | random_state=random_state_1234,
280 | n_estimators=70,
281 | max_features="log2",
282 | criterion="entropy",
283 | class_weight="balanced",
284 | )
285 |
286 | shap_elimination = ShapRFECV(
287 | random_forest,
288 | step=0.2,
289 | cv=5,
290 | scoring="f1_macro",
291 | n_jobs=1,
292 | random_state=random_state_1234,
293 | )
294 |
295 | report = shap_elimination.fit_compute(X, y, check_additivity=True)
296 | # Return the set of features with the best validation accuracy
297 |
298 | kept_features = list(report.iloc[[report["val_metric_mean"].idxmax() - 1]]["features_set"].to_list()[0])
299 |
300 | # Results from the first run
301 | assert [
302 | "f1",
303 | "f2",
304 | "f3",
305 | "f5",
306 | "f6",
307 | "f10",
308 | "f11",
309 | "f12",
310 | "f13",
311 | "f14",
312 | "f15",
313 | "f16",
314 | "f17",
315 | "f18",
316 | "f19",
317 | "f20",
318 | ] == kept_features
319 |
320 |
321 | def test_shap_rfe_penalty_factor(X, y, decision_tree_classifier, random_state):
322 | shap_elimination = ShapRFECV(
323 | decision_tree_classifier,
324 | random_state=random_state,
325 | step=1,
326 | cv=2,
327 | scoring="roc_auc",
328 | n_jobs=1,
329 | )
330 | report = shap_elimination.fit_compute(
331 | X, y, shap_variance_penalty_factor=1.0, approximate=True, check_additivity=False
332 | )
333 |
334 | assert report.shape[0] == 3
335 | assert shap_elimination.get_reduced_features_set(1) == ["col_1"]
336 |
337 |
338 | def test_complex_dataset(complex_data, complex_lightgbm, random_state_1):
339 | X, y = complex_data
340 |
341 | param_grid = {
342 | "n_estimators": [5, 7, 10],
343 | "num_leaves": [3, 5, 7, 10],
344 | }
345 | search = RandomizedSearchCV(complex_lightgbm, param_grid, n_iter=1, random_state=random_state_1)
346 |
347 | shap_elimination = ShapRFECV(
348 | model=search, step=1, cv=10, scoring="roc_auc", n_jobs=3, verbose=1, random_state=random_state_1
349 | )
350 |
351 | report = shap_elimination.fit_compute(X, y)
352 |
353 | assert report.shape[0] == X.shape[1]
354 |
355 |
356 | def test_shap_rfe_early_stopping_lightGBM(complex_data, random_state):
357 | model = LGBMClassifier(n_estimators=200, max_depth=3, random_state=random_state)
358 | X, y = complex_data
359 |
360 | shap_elimination = EarlyStoppingShapRFECV(
361 | model,
362 | random_state=random_state,
363 | step=1,
364 | cv=10,
365 | scoring="roc_auc",
366 | n_jobs=4,
367 | early_stopping_rounds=5,
368 | eval_metric="auc",
369 | )
370 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False)
371 |
372 | assert report.shape[0] == 5
373 | assert shap_elimination.get_reduced_features_set(1) == ["f5"]
374 |
375 |
376 | def test_shap_rfe_early_stopping_XGBoost(XGBoost_classifier, complex_data, random_state):
377 | X, y = complex_data
378 | X["f1_categorical"] = X["f1_categorical"].astype(float)
379 |
380 | shap_elimination = ShapRFECV(
381 | XGBoost_classifier,
382 | random_state=random_state,
383 | step=1,
384 | cv=10,
385 | scoring="roc_auc",
386 | n_jobs=4,
387 | early_stopping_rounds=5,
388 | eval_metric="auc",
389 | )
390 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False)
391 |
392 | assert report.shape[0] == 5
393 | assert shap_elimination.get_reduced_features_set(1) == ["f4"]
394 |
395 |
396 | #
397 | #
398 | def test_shap_rfe_early_stopping_CatBoost(complex_data_with_categorical, catboost_classifier, random_state):
399 | X, y = complex_data_with_categorical
400 |
401 | shap_elimination = ShapRFECV(
402 | catboost_classifier,
403 | random_state=random_state,
404 | step=1,
405 | cv=10,
406 | scoring="roc_auc",
407 | n_jobs=4,
408 | early_stopping_rounds=5,
409 | eval_metric="auc",
410 | )
411 | report = shap_elimination.fit_compute(X, y, approximate=False, check_additivity=False)
412 |
413 | assert report.shape[0] == 5
414 | assert shap_elimination.get_reduced_features_set(1)[0] in ["f4", "f5"]
415 |
416 |
417 | def test_shap_rfe_randomized_search_early_stopping_lightGBM(complex_data, random_state):
418 | model = LGBMClassifier(n_estimators=200, random_state=random_state)
419 | X, y = complex_data
420 |
421 | param_grid = {
422 | "max_depth": [3, 4, 5],
423 | }
424 | search = RandomizedSearchCV(model, param_grid, cv=2, n_iter=2, random_state=random_state)
425 | shap_elimination = ShapRFECV(
426 | search,
427 | step=1,
428 | cv=10,
429 | scoring="roc_auc",
430 | early_stopping_rounds=5,
431 | eval_metric="auc",
432 | n_jobs=4,
433 | verbose=1,
434 | random_state=random_state,
435 | )
436 | report = shap_elimination.fit_compute(X, y)
437 |
438 | assert report.shape[0] == X.shape[1]
439 | assert shap_elimination.get_reduced_features_set(1) == ["f5"]
440 |
441 | _ = shap_elimination.plot(show=False)
442 |
443 |
444 | def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data, random_state):
445 | model = LGBMClassifier(n_estimators=200, max_depth=3, random_state=random_state)
446 | X, y = complex_data
447 | y = preprocess_labels(y, y_name="y", index=X.index)
448 |
449 | shap_elimination = ShapRFECV(model, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state)
450 | (
451 | shap_values,
452 | train_score,
453 | test_score,
454 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping(
455 | X,
456 | y,
457 | model,
458 | train_index=list(range(5, 50)),
459 | val_index=[0, 1, 2, 3, 4],
460 | )
461 | assert test_score > 0.6
462 | assert train_score > 0.6
463 | assert shap_values.shape == (5, 5)
464 |
465 |
466 | def test_get_feature_shap_values_per_fold_early_stopping_CatBoost(
467 | complex_data_with_categorical, catboost_classifier, random_state
468 | ):
469 | X, y = complex_data_with_categorical
470 | y = preprocess_labels(y, y_name="y", index=X.index)
471 |
472 | shap_elimination = ShapRFECV(
473 | catboost_classifier, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state
474 | )
475 | (
476 | shap_values,
477 | train_score,
478 | test_score,
479 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping(
480 | X,
481 | y,
482 | catboost_classifier,
483 | train_index=list(range(5, 50)),
484 | val_index=[0, 1, 2, 3, 4],
485 | )
486 | assert test_score > 0
487 | assert train_score > 0.6
488 | assert shap_values.shape == (5, 5)
489 |
490 |
491 | def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(XGBoost_classifier, complex_data, random_state):
492 | X, y = complex_data
493 | y = preprocess_labels(y, y_name="y", index=X.index)
494 |
495 | shap_elimination = ShapRFECV(
496 | XGBoost_classifier, early_stopping_rounds=5, scoring="roc_auc", random_state=random_state
497 | )
498 | (
499 | shap_values,
500 | train_score,
501 | test_score,
502 | ) = shap_elimination._get_feature_shap_values_per_fold_early_stopping(
503 | X,
504 | y,
505 | XGBoost_classifier,
506 | train_index=list(range(5, 50)),
507 | val_index=[0, 1, 2, 3, 4],
508 | )
509 | assert test_score > 0
510 | assert train_score > 0.6
511 | assert shap_values.shape == (5, 5)
512 |
513 |
514 | def test_EarlyStoppingShapRFECV_no_categorical(complex_data, random_state):
515 | model = LGBMClassifier(n_estimators=50, max_depth=3, num_leaves=3, random_state=random_state)
516 |
517 | shap_elimination = ShapRFECV(
518 | model=model,
519 | step=0.33,
520 | cv=5,
521 | scoring="accuracy",
522 | eval_metric="logloss",
523 | early_stopping_rounds=5,
524 | random_state=random_state,
525 | )
526 | X, y = complex_data
527 | X = X.drop(columns=["f1_categorical"])
528 | report = shap_elimination.fit_compute(X, y, feature_perturbation="tree_path_dependent")
529 |
530 | assert report.shape[0] == X.shape[1]
531 | assert shap_elimination.get_reduced_features_set(1) == ["f5"]
532 |
533 | _ = shap_elimination.plot(show=False)
534 |
535 |
536 | def test_LightGBM_stratified_kfold(random_state):
537 | """
538 | Test added to check for https://github.com/ing-bank/probatus/issues/170.
539 | """
540 | X = pd.DataFrame(
541 | [
542 | [1, 2, 3, 4, 5, 101, 102, 103, 104, 105],
543 | [-1, -2, 2, -5, -7, 1, 2, 5, -1, 3],
544 | ["a", "b"] * 5, # noisy categorical will dropped first
545 | ]
546 | ).transpose()
547 | X[2] = X[2].astype("category")
548 | X[1] = X[1].astype("float")
549 | X[0] = X[0].astype("float")
550 | y = [0] * 5 + [1] * 5
551 |
552 | model = LGBMClassifier(random_state=random_state)
553 | n_iter = 2
554 | n_folds = 3
555 |
556 | for _ in range(n_iter):
557 | skf = StratifiedKFold(n_folds, shuffle=True, random_state=random_state)
558 | shap_elimination = ShapRFECV(
559 | model=model,
560 | step=1 / (n_iter + 1),
561 | cv=skf,
562 | scoring="accuracy",
563 | eval_metric="logloss",
564 | early_stopping_rounds=5,
565 | random_state=random_state,
566 | )
567 | report = shap_elimination.fit_compute(X, y, feature_perturbation="tree_path_dependent")
568 |
569 | assert report.shape[0] == X.shape[1]
570 |
571 | shap_elimination.plot(show=False)
572 |
--------------------------------------------------------------------------------
/tests/interpret/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/interpret/__init__.py
--------------------------------------------------------------------------------
/tests/interpret/test_model_interpret.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from probatus.interpret import ShapModelInterpreter
6 |
7 |
8 | @pytest.fixture(scope="function")
9 | def expected_feature_importance():
10 | return pd.DataFrame(
11 | {
12 | "mean_abs_shap_value_test": [0.5, 0.0, 0.0],
13 | "mean_abs_shap_value_train": [0.5, 0.0, 0.0],
14 | "mean_shap_value_test": [-0.5, 0.0, 0.0],
15 | "mean_shap_value_train": [-0.5, 0.0, 0.0],
16 | },
17 | index=["col_3", "col_1", "col_2"],
18 | )
19 |
20 |
21 | @pytest.fixture(scope="function")
22 | def expected_feature_importance_lin_models():
23 | return pd.DataFrame(
24 | {
25 | "mean_abs_shap_value_test": [0.4, 0.0, 0.0],
26 | "mean_abs_shap_value_train": [0.4, 0.0, 0.0],
27 | "mean_shap_value_test": [-0.4, 0.0, 0.0],
28 | "mean_shap_value_train": [-0.4, 0.0, 0.0],
29 | },
30 | index=["col_3", "col_1", "col_2"],
31 | )
32 |
33 |
34 | def test_shap_interpret(fitted_tree, X_train, y_train, X_test, y_test, expected_feature_importance, random_state):
35 | class_names = ["neg", "pos"]
36 |
37 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state)
38 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names)
39 |
40 | assert shap_interpret.class_names == class_names
41 | assert shap_interpret.train_score == 1
42 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01)
43 |
44 | # Check expected shap values
45 | assert (np.mean(np.abs(shap_interpret.shap_values_test), axis=0) == [0, 0, 0.5]).all()
46 | assert (np.mean(np.abs(shap_interpret.shap_values_train), axis=0) == [0, 0, 0.5]).all()
47 |
48 | importance_df, train_auc, test_auc = shap_interpret.compute(return_scores=True)
49 |
50 | pd.testing.assert_frame_equal(expected_feature_importance, importance_df)
51 | assert train_auc == 1
52 | assert test_auc == pytest.approx(0.833, 0.01)
53 |
54 | # Check if plots work for such dataset
55 | ax1 = shap_interpret.plot("importance", target_set="test", show=False)
56 | ax2 = shap_interpret.plot("summary", target_set="test", show=False)
57 | ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
58 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
59 | ax5 = shap_interpret.plot("importance", target_set="train", show=False)
60 | ax6 = shap_interpret.plot("summary", target_set="train", show=False)
61 | ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False)
62 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
63 | assert not (isinstance(ax1, list))
64 | assert not (isinstance(ax2, list))
65 | assert isinstance(ax3, list) and len(ax4) == 2
66 | assert isinstance(ax4, list) and len(ax4) == 2
67 | assert not (isinstance(ax5, list))
68 | assert not (isinstance(ax6, list))
69 | assert isinstance(ax7, list) and len(ax7) == 2
70 | assert isinstance(ax8, list) and len(ax8) == 2
71 |
72 |
73 | def test_shap_interpret_lin_models(
74 | fitted_logistic_regression, X_train, y_train, X_test, y_test, expected_feature_importance_lin_models, random_state
75 | ):
76 | class_names = ["neg", "pos"]
77 |
78 | shap_interpret = ShapModelInterpreter(fitted_logistic_regression, random_state=random_state)
79 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names)
80 |
81 | assert shap_interpret.class_names == class_names
82 | assert shap_interpret.train_score == 1
83 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01)
84 |
85 | # Check expected shap values
86 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_test), axis=0), 2) == [0, 0, 0.4]).all()
87 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_train), axis=0), 2) == [0, 0, 0.4]).all()
88 |
89 | importance_df, train_auc, test_auc = shap_interpret.compute(return_scores=True)
90 | importance_df = importance_df.round(2)
91 |
92 | pd.testing.assert_frame_equal(expected_feature_importance_lin_models, importance_df)
93 | assert train_auc == 1
94 | assert test_auc == pytest.approx(0.833, 0.01)
95 |
96 | # Check if plots work for such dataset
97 | ax1 = shap_interpret.plot("importance", target_set="test", show=False)
98 | ax2 = shap_interpret.plot("summary", target_set="test", show=False)
99 | ax3 = shap_interpret.plot("dependence", target_columns="col_3", target_set="test", show=False)
100 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
101 | ax5 = shap_interpret.plot("importance", target_set="train", show=False)
102 | ax6 = shap_interpret.plot("summary", target_set="train", show=False)
103 | ax7 = shap_interpret.plot("dependence", target_columns="col_3", target_set="train", show=False)
104 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
105 | assert not (isinstance(ax1, list))
106 | assert not (isinstance(ax2, list))
107 | assert isinstance(ax3, list) and len(ax4) == 2
108 | assert isinstance(ax4, list) and len(ax4) == 2
109 | assert not (isinstance(ax5, list))
110 | assert not (isinstance(ax6, list))
111 | assert isinstance(ax7, list) and len(ax7) == 2
112 | assert isinstance(ax8, list) and len(ax8) == 2
113 |
114 |
115 | def test_shap_interpret_fit_compute_lin_models(
116 | fitted_logistic_regression, X_train, y_train, X_test, y_test, expected_feature_importance_lin_models, random_state
117 | ):
118 | class_names = ["neg", "pos"]
119 |
120 | shap_interpret = ShapModelInterpreter(fitted_logistic_regression, random_state=random_state)
121 | importance_df = shap_interpret.fit_compute(X_train, X_test, y_train, y_test, class_names=class_names)
122 | importance_df = importance_df.round(2)
123 |
124 | assert shap_interpret.class_names == class_names
125 | assert shap_interpret.train_score == 1
126 |
127 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01)
128 |
129 | # Check expected shap values
130 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_test), axis=0), 2) == [0, 0, 0.4]).all()
131 | assert (np.round(np.mean(np.abs(shap_interpret.shap_values_train), axis=0), 2) == [0, 0, 0.4]).all()
132 |
133 | pd.testing.assert_frame_equal(expected_feature_importance_lin_models, importance_df)
134 |
135 |
136 | def test_shap_interpret_fit_compute(
137 | fitted_tree, X_train, y_train, X_test, y_test, expected_feature_importance, random_state
138 | ):
139 | class_names = ["neg", "pos"]
140 |
141 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state)
142 | importance_df = shap_interpret.fit_compute(X_train, X_test, y_train, y_test, class_names=class_names)
143 |
144 | assert shap_interpret.class_names == class_names
145 | assert shap_interpret.train_score == 1
146 | assert shap_interpret.test_score == pytest.approx(0.833, 0.01)
147 |
148 | # Check expected shap values
149 | assert (np.mean(np.abs(shap_interpret.shap_values_test), axis=0) == [0, 0, 0.5]).all()
150 | assert (np.mean(np.abs(shap_interpret.shap_values_train), axis=0) == [0, 0, 0.5]).all()
151 |
152 | pd.testing.assert_frame_equal(expected_feature_importance, importance_df)
153 |
154 |
155 | def test_shap_interpret_complex_data(complex_data_split_with_categorical, complex_fitted_lightgbm, random_state):
156 | class_names = ["neg", "pos"]
157 | X_train, X_test, y_train, y_test = complex_data_split_with_categorical
158 |
159 | shap_interpret = ShapModelInterpreter(complex_fitted_lightgbm, verbose=1, random_state=random_state)
160 | importance_df = shap_interpret.fit_compute(
161 | X_train, X_test, y_train, y_test, class_names=class_names, approximate=False, check_additivity=False
162 | )
163 |
164 | assert shap_interpret.class_names == class_names
165 | assert importance_df.shape[0] == X_train.shape[1]
166 |
167 | # Check if plots work for such dataset
168 | ax1 = shap_interpret.plot("importance", target_set="test", show=False)
169 | ax2 = shap_interpret.plot("summary", target_set="test", show=False)
170 | ax3 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="test", show=False)
171 | ax4 = shap_interpret.plot("sample", samples_index=X_test.index.tolist()[0:2], target_set="test", show=False)
172 | ax5 = shap_interpret.plot("importance", target_set="train", show=False)
173 | ax6 = shap_interpret.plot("summary", target_set="train", show=False)
174 | ax7 = shap_interpret.plot("dependence", target_columns="f2_missing", target_set="train", show=False)
175 | ax8 = shap_interpret.plot("sample", samples_index=X_train.index.tolist()[0:2], target_set="train", show=False)
176 | assert not (isinstance(ax1, list))
177 | assert not (isinstance(ax2, list))
178 | assert isinstance(ax3, list) and len(ax4) == 2
179 | assert isinstance(ax4, list) and len(ax4) == 2
180 | assert not (isinstance(ax5, list))
181 | assert not (isinstance(ax6, list))
182 | assert isinstance(ax7, list) and len(ax7) == 2
183 | assert isinstance(ax8, list) and len(ax8) == 2
184 |
--------------------------------------------------------------------------------
/tests/interpret/test_shap_dependence.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from sklearn.ensemble import RandomForestClassifier
7 |
8 | from probatus.interpret.shap_dependence import DependencePlotter
9 | from probatus.utils.exceptions import NotFittedError
10 |
11 | # Turn off interactive mode in plots
12 | plt.ioff()
13 | matplotlib.use("Agg")
14 |
15 |
16 | @pytest.fixture(scope="function")
17 | def X_y():
18 | return (
19 | pd.DataFrame(
20 | [
21 | [1.72568193, 2.21070436, 1.46039061],
22 | [-1.48382902, 2.88364928, 0.22323996],
23 | [-0.44947744, 0.85434638, -2.54486421],
24 | [-1.38101231, 1.77505901, -1.36000132],
25 | [-0.18261804, -0.25829609, 1.46925993],
26 | [0.27514902, 0.09608222, 0.7221381],
27 | [-0.27264455, 1.99366793, -2.62161046],
28 | [-2.81587587, 3.46459717, -0.11740999],
29 | [1.48374489, 0.79662903, 1.18898706],
30 | [-1.27251335, -1.57344342, -0.39540133],
31 | [0.31532891, 0.38299269, 1.29998754],
32 | [-2.10917352, -0.70033132, -0.89922129],
33 | [-2.14396343, -0.44549774, -1.80572922],
34 | [-3.4503348, 3.43476247, -0.74957725],
35 | [-1.25945582, -1.7234203, -0.77435353],
36 | ]
37 | ),
38 | pd.Series([1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0]),
39 | )
40 |
41 |
42 | @pytest.fixture(scope="function")
43 | def expected_shap_vals():
44 | return pd.DataFrame(
45 | [
46 | [0.176667, 0.005833, 0.284167],
47 | [-0.042020, 0.224520, 0.284167],
48 | [-0.092020, -0.135480, -0.205833],
49 | [-0.092020, -0.135480, -0.205833],
50 | [0.002424, 0.000909, 0.263333],
51 | [0.176667, 0.105833, 0.284167],
52 | [-0.092020, -0.135480, -0.205833],
53 | [-0.028687, 0.311187, 0.184167],
54 | [0.176667, 0.005833, 0.284167],
55 | [-0.092020, -0.164646, -0.076667],
56 | [0.176667, 0.105833, 0.284167],
57 | [-0.092020, -0.164646, -0.176667],
58 | [-0.092020, -0.164646, -0.176667],
59 | [-0.108687, 0.081187, -0.205833],
60 | [-0.092020, -0.164646, -0.176667],
61 | ]
62 | )
63 |
64 |
65 | @pytest.fixture(scope="function")
66 | def model(X_y, random_state):
67 | X, y = X_y
68 |
69 | model = RandomForestClassifier(random_state=random_state, n_estimators=10, max_depth=5)
70 |
71 | model.fit(X, y)
72 | return model
73 |
74 |
75 | @pytest.fixture(scope="function")
76 | def expected_feat_importances():
77 | return pd.DataFrame(
78 | {
79 | "Feature Name": {0: 2, 1: 1, 2: 0},
80 | "Shap absolute importance": {0: 0.2199, 1: 0.1271, 2: 0.1022},
81 | "Shap signed importance": {0: 0.0292, 1: -0.0149, 2: -0.0076},
82 | }
83 | )
84 |
85 |
86 | def test_not_fitted(model, random_state):
87 | plotter = DependencePlotter(model, random_state)
88 | assert plotter.fitted is False
89 |
90 |
91 | def test_fit_complex(complex_data_split, complex_fitted_lightgbm, random_state):
92 | _, X_test, _, y_test = complex_data_split
93 |
94 | plotter = DependencePlotter(complex_fitted_lightgbm, random_state=random_state)
95 |
96 | plotter.fit(X_test, y_test)
97 |
98 | pd.testing.assert_frame_equal(plotter.X, X_test)
99 | pd.testing.assert_series_equal(plotter.y, pd.Series(y_test, index=X_test.index))
100 | assert plotter.fitted is True
101 |
102 | # Check if plotting does not cause errors
103 | _ = plotter.plot(feature="f2_missing", show=False)
104 |
105 |
106 | def test_get_X_y_shap_with_q_cut_normal(X_y, model, random_state):
107 | X, y = X_y
108 |
109 | plotter = DependencePlotter(model, random_state).fit(X, y)
110 | plotter.min_q, plotter.max_q = 0, 1
111 |
112 | X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0)
113 | assert np.isclose(X[0], X_cut).all()
114 | assert y.equals(y_cut)
115 |
116 | plotter.min_q = 0.2
117 | plotter.max_q = 0.8
118 |
119 | X_cut, y_cut, _ = plotter._get_X_y_shap_with_q_cut(0)
120 | assert np.isclose(
121 | X_cut,
122 | [
123 | -1.48382902,
124 | -0.44947744,
125 | -1.38101231,
126 | -0.18261804,
127 | 0.27514902,
128 | -0.27264455,
129 | -1.27251335,
130 | -2.10917352,
131 | -1.25945582,
132 | ],
133 | ).all()
134 | assert np.equal(y_cut.values, [1, 0, 0, 1, 1, 0, 0, 0, 0]).all()
135 |
136 |
137 | def test_get_X_y_shap_with_q_cut_unfitted(model, random_state):
138 | plotter = DependencePlotter(model, random_state)
139 | with pytest.raises(NotFittedError):
140 | plotter._get_X_y_shap_with_q_cut(0)
141 |
142 |
143 | def test_get_X_y_shap_with_q_cut_input(X_y, model, random_state):
144 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1])
145 | with pytest.raises(ValueError):
146 | plotter._get_X_y_shap_with_q_cut("not a feature")
147 |
148 |
149 | def test_plot_normal(X_y, model, random_state):
150 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1])
151 | _ = plotter.plot(feature=0)
152 |
153 |
154 | def test_plot_class_names(X_y, model, random_state):
155 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1], class_names=["a", "b"])
156 | _ = plotter.plot(feature=0)
157 | assert plotter.class_names == ["a", "b"]
158 |
159 |
160 | def test_plot_input(X_y, model, random_state):
161 | plotter = DependencePlotter(model, random_state).fit(X_y[0], X_y[1])
162 | with pytest.raises(ValueError):
163 | plotter.plot(feature="not a feature")
164 | with pytest.raises(TypeError):
165 | plotter.plot(feature=0, bins=5.0)
166 | with pytest.raises(ValueError):
167 | plotter.plot(feature=0, min_q=1, max_q=0)
168 |
169 |
170 | def test__repr__(model, random_state):
171 | """
172 | Test string representation.
173 | """
174 | plotter = DependencePlotter(model, random_state)
175 | assert str(plotter) == "Shap dependence plotter for RandomForestClassifier"
176 |
--------------------------------------------------------------------------------
/tests/mocks.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock
2 |
3 | # These are shell classes that define the methods of the models that we use. Each of the functions that we use needs
4 | # To be defined inside these shell classes. Then when we want to write a specific test you need to simply mock.patch
5 | # the desired functionality. You can also set the return_value to the patched method.
6 |
7 |
8 | class MockClusterer(Mock):
9 | """
10 | MockCluster.
11 | """
12 |
13 | def __init__(self):
14 | """
15 | Init.
16 | """
17 | pass
18 |
19 | def fit(self):
20 | """
21 | Fit.
22 | """
23 | pass
24 |
25 | def predict(self):
26 | """
27 | Predict.
28 | """
29 | pass
30 |
31 | def fit_predict(self):
32 | """
33 | Both.
34 | """
35 | pass
36 |
37 |
38 | class MockModel(Mock):
39 | """
40 | Mockmodel.
41 | """
42 |
43 | def __init__(self, **kwargs):
44 | """
45 | Init.
46 | """
47 | pass
48 |
49 | def fit(self):
50 | """
51 | Fit.
52 | """
53 | pass
54 |
55 | def predict(self):
56 | """
57 | Predict.
58 | """
59 | pass
60 |
61 | def predict_proba(self):
62 | """
63 | Both.
64 | """
65 | pass
66 |
--------------------------------------------------------------------------------
/tests/sample_similarity/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/sample_similarity/__init__.py
--------------------------------------------------------------------------------
/tests/sample_similarity/test_resemblance_model.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from pandas.api.types import is_numeric_dtype
7 |
8 | from probatus.sample_similarity import BaseResemblanceModel, PermutationImportanceResemblance, SHAPImportanceResemblance
9 |
10 | # Turn off interactive mode in plots
11 | plt.ioff()
12 | matplotlib.use("Agg")
13 |
14 |
15 | @pytest.fixture(scope="function")
16 | def X1():
17 | return pd.DataFrame({"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [0, 0, 0, 0]}, index=[1, 2, 3, 4])
18 |
19 |
20 | @pytest.fixture(scope="function")
21 | def X2():
22 | return pd.DataFrame({"col_1": [0, 0, 0, 0], "col_2": [0, 0, 0, 0], "col_3": [0, 0, 0, 0]}, index=[1, 2, 3, 4])
23 |
24 |
25 | def test_base_class(X1, X2, decision_tree_classifier, random_state):
26 | rm = BaseResemblanceModel(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state)
27 |
28 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True)
29 |
30 | assert train_score == 1
31 | assert test_score == 1
32 | assert actual_report is None
33 |
34 | # Check data splits if correct
35 | actual_X_train, actual_X_test, actual_y_train, actual_y_test = rm.get_data_splits()
36 |
37 | assert actual_X_train.shape == (4, 3)
38 | assert actual_X_test.shape == (4, 3)
39 | assert len(actual_y_train) == 4
40 | assert len(actual_y_test) == 4
41 |
42 | # Check if stratified
43 | assert np.sum(actual_y_train) == 2
44 | assert np.sum(actual_y_test) == 2
45 |
46 | # Check if index is correct
47 | assert len(rm.X.index.unique()) == 8
48 | assert list(rm.X.index) == list(rm.y.index)
49 |
50 | with pytest.raises(NotImplementedError) as _:
51 | rm.plot()
52 |
53 |
54 | def test_base_class_lin_models(X1, X2, logistic_regression, random_state):
55 | # Test class BaseResemblanceModel for linear models.
56 | rm = BaseResemblanceModel(logistic_regression, test_prc=0.5, n_jobs=1, random_state=random_state)
57 |
58 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True)
59 |
60 | assert train_score == 1
61 | assert test_score == 1
62 | assert actual_report is None
63 |
64 | # Check data splits if correct
65 | actual_X_train, actual_X_test, actual_y_train, actual_y_test = rm.get_data_splits()
66 |
67 | assert actual_X_train.shape == (4, 3)
68 | assert actual_X_test.shape == (4, 3)
69 | assert len(actual_y_train) == 4
70 | assert len(actual_y_test) == 4
71 |
72 | # Check if stratified
73 | assert np.sum(actual_y_train) == 2
74 | assert np.sum(actual_y_test) == 2
75 |
76 | # Check if index is correct
77 | assert len(rm.X.index.unique()) == 8
78 | assert list(rm.X.index) == list(rm.y.index)
79 |
80 | with pytest.raises(NotImplementedError) as _:
81 | rm.plot()
82 |
83 |
84 | def test_shap_resemblance_class(X1, X2, decision_tree_classifier, random_state):
85 | rm = SHAPImportanceResemblance(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state)
86 |
87 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True)
88 |
89 | assert train_score == 1
90 | assert test_score == 1
91 |
92 | # Check report shape
93 | assert actual_report.shape == (3, 2)
94 | # Check if it is sorted by importance
95 | assert actual_report.iloc[0].name == "col_1"
96 | # Check report values
97 | assert actual_report.loc["col_1"]["mean_abs_shap_value"] > 0
98 | assert actual_report.loc["col_1"]["mean_shap_value"] < 0
99 | assert actual_report.loc["col_2"]["mean_abs_shap_value"] == 0
100 | assert actual_report.loc["col_2"]["mean_shap_value"] == 0
101 | assert actual_report.loc["col_3"]["mean_abs_shap_value"] == 0
102 | assert actual_report.loc["col_3"]["mean_shap_value"] == 0
103 |
104 | actual_shap_values_test = rm.get_shap_values()
105 | assert actual_shap_values_test.shape == (4, 3)
106 |
107 | # Run plots
108 | rm.plot(plot_type="bar")
109 | rm.plot(plot_type="dot")
110 |
111 |
112 | def test_shap_resemblance_class_lin_models(X1, X2, logistic_regression, random_state):
113 | # Test SHAP Resemblance Model for linear models.
114 | rm = SHAPImportanceResemblance(logistic_regression, test_prc=0.5, n_jobs=1, random_state=random_state)
115 |
116 | actual_report, train_score, test_score = rm.fit_compute(
117 | X1, X2, return_scores=True, approximate=True, check_additivity=False
118 | )
119 |
120 | assert train_score == 1
121 | assert test_score == 1
122 |
123 | # Check report shape
124 | assert actual_report.shape == (3, 2)
125 | # Check if it is sorted by importance
126 | assert actual_report.iloc[0].name == "col_1"
127 | # Check report values
128 | assert actual_report.loc["col_1"]["mean_abs_shap_value"] > 0
129 | assert actual_report.loc["col_1"]["mean_shap_value"] < 0
130 | assert actual_report.loc["col_2"]["mean_abs_shap_value"] == 0
131 | assert actual_report.loc["col_2"]["mean_shap_value"] == 0
132 | assert actual_report.loc["col_3"]["mean_abs_shap_value"] == 0
133 | assert actual_report.loc["col_3"]["mean_shap_value"] == 0
134 |
135 | actual_shap_values_test = rm.get_shap_values()
136 | assert actual_shap_values_test.shape == (4, 3)
137 |
138 | # Run plots
139 | rm.plot(plot_type="bar")
140 | rm.plot(plot_type="dot")
141 |
142 |
143 | def test_shap_resemblance_class2(complex_data_with_categorical, complex_lightgbm, random_state):
144 | X1, _ = complex_data_with_categorical
145 | X2 = X1.copy()
146 | X2["f4"] = X2["f4"] + 100
147 |
148 | rm = SHAPImportanceResemblance(
149 | complex_lightgbm, scoring="accuracy", test_prc=0.5, n_jobs=1, random_state=random_state
150 | )
151 |
152 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True, class_names=["a", "b"])
153 |
154 | # Check if the X and y within the rm have correct types
155 | assert rm.X["f1_categorical"].dtype.name == "category"
156 | for num_column in ["f2_missing", "f3_static", "f4", "f5"]:
157 | assert is_numeric_dtype(rm.X[num_column])
158 |
159 | assert train_score == pytest.approx(1, 0.05)
160 | assert test_score == pytest.approx(1, 0.05)
161 |
162 | # Check report shape
163 | assert actual_report.shape == (5, 2)
164 | # Check if it is sorted by importance
165 | assert actual_report.iloc[0].name == "f4"
166 |
167 | # Check report values
168 | assert actual_report.loc["f4"]["mean_abs_shap_value"] > 0
169 |
170 | actual_shap_values_test = rm.get_shap_values()
171 | # 50 test samples and 5 features
172 | assert actual_shap_values_test.shape == (X1.shape[0], X1.shape[1])
173 |
174 | # Run plots
175 | rm.plot(plot_type="bar", show=True)
176 | rm.plot(plot_type="dot", show=False)
177 |
178 |
179 | def test_permutation_resemblance_class(X1, X2, decision_tree_classifier, random_state):
180 | rm = PermutationImportanceResemblance(
181 | decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state, iterations=20
182 | )
183 |
184 | actual_report, train_score, test_score = rm.fit_compute(X1, X2, return_scores=True)
185 |
186 | assert train_score == 1
187 | assert test_score == 1
188 |
189 | # Check report shape
190 | assert actual_report.shape == (3, 2)
191 | # Check if it is sorted by importance
192 | assert actual_report.iloc[0].name == "col_1"
193 | # Check report values
194 | assert actual_report.loc["col_1"]["mean_importance"] > 0
195 | assert actual_report.loc["col_1"]["std_importance"] > 0
196 | assert actual_report.loc["col_2"]["mean_importance"] == 0
197 | assert actual_report.loc["col_2"]["std_importance"] == 0
198 | assert actual_report.loc["col_3"]["mean_importance"] == 0
199 | assert actual_report.loc["col_3"]["std_importance"] == 0
200 |
201 | rm.plot(figsize=(10, 10))
202 | # Check plot size
203 | fig = plt.gcf()
204 | size = fig.get_size_inches()
205 | assert size[0] == 10 and size[1] == 10
206 |
207 |
208 | def test_base_class_same_data(X1, decision_tree_classifier, random_state):
209 | rm = BaseResemblanceModel(decision_tree_classifier, test_prc=0.5, n_jobs=1, random_state=random_state)
210 |
211 | actual_report, train_score, test_score = rm.fit_compute(X1, X1, return_scores=True)
212 |
213 | assert train_score == 0.5
214 | assert test_score == 0.5
215 | assert actual_report is None
216 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/probatus/04e8393b6c5416a3e01da05416aac4fdc6d94892/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_base_class.py:
--------------------------------------------------------------------------------
1 | from probatus.interpret import ShapModelInterpreter
2 | import pytest
3 | from probatus.utils import NotFittedError
4 |
5 |
6 | def test_fitted_exception(fitted_tree, X_train, y_train, X_test, y_test, random_state):
7 | class_names = ["neg", "pos"]
8 |
9 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state)
10 |
11 | # Before fit it should raise an exception
12 | with pytest.raises(NotFittedError) as _:
13 | shap_interpret._check_if_fitted()
14 |
15 | shap_interpret.fit(X_train, X_test, y_train, y_test, class_names=class_names)
16 |
17 | # Check parameters
18 | assert shap_interpret.fitted
19 | shap_interpret._check_if_fitted
20 |
21 |
22 | @pytest.mark.xfail
23 | def test_fitted_exception_is_raised(fitted_tree, random_state):
24 | shap_interpret = ShapModelInterpreter(fitted_tree, random_state=random_state)
25 |
26 | shap_interpret._check_if_fitted
27 |
--------------------------------------------------------------------------------
/tests/utils/test_utils_array_funcs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 |
5 | from probatus.utils import assure_pandas_df, preprocess_data, preprocess_labels
6 |
7 |
8 | @pytest.fixture(scope="function")
9 | def expected_df_2d():
10 | return pd.DataFrame({0: [1, 2], 1: [2, 3], 2: [3, 4]})
11 |
12 |
13 | @pytest.fixture(scope="function")
14 | def expected_df():
15 | return pd.DataFrame({0: [1, 2, 3]})
16 |
17 |
18 | def test_assure_pandas_df_list(expected_df):
19 | x = [1, 2, 3]
20 | x_df = assure_pandas_df(x)
21 | pd.testing.assert_frame_equal(x_df, expected_df)
22 |
23 |
24 | def test_assure_pandas_df_list_of_lists(expected_df_2d):
25 | x = [[1, 2, 3], [2, 3, 4]]
26 | x_df = assure_pandas_df(x)
27 | pd.testing.assert_frame_equal(x_df, expected_df_2d)
28 |
29 |
30 | def test_assure_pandas_df_series(expected_df):
31 | x = pd.Series([1, 2, 3])
32 | x_df = assure_pandas_df(x)
33 | pd.testing.assert_frame_equal(x_df, expected_df)
34 |
35 |
36 | def test_assure_pandas_df_array(expected_df, expected_df_2d):
37 | x = np.array([[1, 2, 3], [2, 3, 4]], dtype="int64")
38 | x_df = assure_pandas_df(x)
39 | pd.testing.assert_frame_equal(x_df, expected_df_2d)
40 |
41 | x = np.array([1, 2, 3], dtype="int64")
42 | x_df = assure_pandas_df(x)
43 | pd.testing.assert_frame_equal(x_df, expected_df)
44 |
45 |
46 | def test_assure_pandas_df_df(expected_df_2d):
47 | x = pd.DataFrame([[1, 2, 3], [2, 3, 4]])
48 | x_df = assure_pandas_df(x)
49 | pd.testing.assert_frame_equal(x_df, expected_df_2d)
50 |
51 |
52 | def test_assure_pandas_df_types():
53 | with pytest.raises(TypeError):
54 | assure_pandas_df("Test")
55 | with pytest.raises(TypeError):
56 | assure_pandas_df(5)
57 |
58 |
59 | def test_preprocess_labels():
60 | y1 = pd.Series([1, 0, 1, 0, 1])
61 | index_1 = np.array([5, 4, 3, 2, 1])
62 |
63 | y1_output = preprocess_labels(y1, y_name="y1", index=index_1, verbose=2)
64 | pd.testing.assert_series_equal(y1_output, pd.Series([1, 0, 1, 0, 1], index=index_1))
65 |
66 | y2 = [False, False, False, False, False]
67 | y2_output = preprocess_labels(y2, y_name="y2", verbose=2)
68 | pd.testing.assert_series_equal(y2_output, pd.Series(y2))
69 |
70 | y3 = np.array([0, 1, 2, 3, 4])
71 | y3_output = preprocess_labels(y3, y_name="y3", verbose=2)
72 | pd.testing.assert_series_equal(y3_output, pd.Series(y3))
73 |
74 | y4 = pd.Series(["2", "1", "3", "2", "1"])
75 | index4 = pd.Index([0, 2, 1, 3, 4])
76 | y4_output = preprocess_labels(y4, y_name="y4", index=index4, verbose=0)
77 | pd.testing.assert_series_equal(y4_output, pd.Series(["2", "3", "1", "2", "1"], index=index4))
78 |
79 |
80 | def test_preprocess_data():
81 | X1 = pd.DataFrame({"cat": ["a", "b", "c"], "missing": [1, np.nan, 2], "num_1": [1, 2, 3]})
82 |
83 | target_column_names_X1 = ["1", "2", "3"]
84 | X1_expected_output = pd.DataFrame({"1": ["a", "b", "c"], "2": [1, np.nan, 2], "3": [1, 2, 3]})
85 |
86 | X1_expected_output["1"] = X1_expected_output["1"].astype("category")
87 | X1_output, output_column_names_X1 = preprocess_data(X1, X_name="X1", column_names=target_column_names_X1, verbose=2)
88 | assert target_column_names_X1 == output_column_names_X1
89 | pd.testing.assert_frame_equal(X1_output, X1_expected_output)
90 |
91 | X2 = np.array([[1, 3, 2], [1, 2, 2], [1, 2, 3]])
92 |
93 | target_column_names_X1 = [0, 1, 2]
94 | X2_expected_output = pd.DataFrame(X2, columns=target_column_names_X1)
95 | X2_output, output_column_names_X2 = preprocess_data(X2, X_name="X2", column_names=None, verbose=2)
96 |
97 | assert target_column_names_X1 == output_column_names_X2
98 | pd.testing.assert_frame_equal(X2_output, X2_expected_output)
99 |
--------------------------------------------------------------------------------