├── .github ├── dependabot.yml └── workflows │ ├── docs.yml │ ├── release.yml │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── docs ├── api.md ├── examples.md ├── images │ ├── glm_bar.png │ ├── glm_beeswarm.png │ ├── glm_scatter.png │ ├── glm_waterfall.png │ ├── logo.png │ ├── logo.svg │ ├── tree_bar.png │ ├── tree_beeswarm.png │ ├── tree_scatter.png │ └── tree_waterfall.png └── index.md ├── mkdocs.yml ├── pyproject.toml └── src └── lightshap ├── __init__.py ├── _version.py ├── explainers ├── __init__.py ├── _utils.py ├── explain_any.py ├── explain_tree.py ├── kernel_utils.py ├── parallel.py ├── permutation_utils.py └── tests │ ├── test_explain_any.py │ ├── test_explain_tree.py │ ├── test_explainer_utils.py │ ├── test_kernel_utils.py │ ├── test_parallel.py │ └── test_permutation_utils.py ├── explanation ├── __init__.py ├── _utils.py ├── explanation.py ├── explanationplotter.py └── tests │ ├── test_explanation.py │ ├── test_explanation_utils.py │ └── test_plots.py ├── tests └── test_utils.py └── utils.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | open-pull-requests-limit: 10 8 | 9 | - package-ecosystem: "github-actions" 10 | directory: "/" 11 | schedule: 12 | interval: "weekly" 13 | open-pull-requests-limit: 5 14 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: read 11 | pages: write 12 | id-token: write 13 | 14 | concurrency: 15 | group: "pages" 16 | cancel-in-progress: false 17 | 18 | jobs: 19 | build: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: actions/checkout@v5 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v6 26 | with: 27 | python-version: '3.11' 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | # Use compatible versions to avoid griffe.collections error 33 | pip install mkdocs mkdocs-material mkdocstrings[python] griffe 34 | pip install numpy pandas matplotlib 35 | 36 | - name: Build documentation 37 | env: 38 | PYTHONPATH: ${{ github.workspace }}/src 39 | run: mkdocs build --clean 40 | 41 | - name: Upload artifact 42 | uses: actions/upload-pages-artifact@v4 43 | with: 44 | path: ./site 45 | 46 | deploy: 47 | environment: 48 | name: github-pages 49 | url: ${{ steps.deployment.outputs.page_url }} 50 | runs-on: ubuntu-latest 51 | needs: build 52 | steps: 53 | - name: Deploy to GitHub Pages 54 | id: deployment 55 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | release: 9 | runs-on: ubuntu-latest 10 | environment: release 11 | 12 | permissions: 13 | id-token: write 14 | 15 | steps: 16 | - uses: actions/checkout@v5 17 | with: 18 | ref: ${{ github.event.release.tag_name }} # Checkout the exact tag 19 | fetch-depth: 0 # Full history for hatch-vcs 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v6 23 | with: 24 | python-version: "3.11" 25 | 26 | - name: 
Install Hatch 27 | run: pip install hatch 28 | 29 | - name: Build package 30 | run: hatch build 31 | 32 | - name: Publish to PyPI 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | verbose: true 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: CI - Test 2 | 3 | on: 4 | push: 5 | branches: [ main, develop ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest, windows-latest, macos-latest] 15 | python-version: ["3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v5 19 | 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v6 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | # Install OpenMP for macOS to fix XGBoost/LightGBM issues 26 | - name: Install OpenMP on macOS 27 | if: runner.os == 'macOS' 28 | run: brew install libomp 29 | 30 | - name: Install ruff (lint/format) 31 | run: pip install ruff 32 | 33 | - name: Run linting 34 | run: ruff check src/ 35 | 36 | - name: Run formatting check 37 | run: ruff format --check src/ 38 | 39 | - name: Cache pip 40 | uses: actions/cache@v4 41 | with: 42 | path: ~/.cache/pip 43 | key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} 44 | restore-keys: | 45 | ${{ runner.os }}-pip- 46 | 47 | - name: Install package + dev & optional dependencies 48 | run: | 49 | python -m pip install --upgrade pip 50 | # install package and both dev + optional extras so tests that import polars, pyarrow, matplotlib, ... 51 | # succeed in CI 52 | python -m pip install -e ".[dev,all]" 53 | - name: Run tests 54 | run: python -m pytest --cov=lightshap --cov-branch --cov-report=xml --cov-report=html --cov-report=term-missing 55 | 56 | - name: Upload coverage reports to Codecov 57 | uses: codecov/codecov-action@v5 58 | with: 59 | token: ${{ secrets.CODECOV_TOKEN }} 60 | slug: mayer79/LightSHAP 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Jupyter Notebook 55 | .ipynb_checkpoints 56 | 57 | # IPython 58 | profile_default/ 59 | ipython_config.py 60 | 61 | # pyenv 62 | # For a library or package, you might want to ignore these files since the code is 63 | # intended to run in multiple environments; otherwise, check them in: 64 | # .python-version 65 | 66 | # pipenv 67 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 68 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 69 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 70 | # install all needed dependencies. 71 | #Pipfile.lock 72 | 73 | # poetry 74 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 75 | # This is especially recommended for binary packages to ensure reproducibility, and is more 76 | # commonly ignored for libraries. 77 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 78 | #poetry.lock 79 | 80 | # pdm 81 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 82 | #pdm.lock 83 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 84 | # in version control. 85 | # https://pdm.fming.dev/#use-with-ide 86 | .pdm.toml 87 | 88 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 89 | __pypackages__/ 90 | 91 | # Celery stuff 92 | celerybeat-schedule 93 | celerybeat.pid 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # pytype static type analyzer 126 | .pytype/ 127 | 128 | # Cython debug symbols 129 | cython_debug/ 130 | 131 | # PyCharm 132 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 133 | # be added to the global gitignore or merged into this project gitignore. For a PyCharmAIS... 
134 | # (PyCharm File Watchers) 135 | .idea/ 136 | 137 | # VS Code 138 | .vscode/ 139 | 140 | # macOS 141 | .DS_Store 142 | 143 | # Windows 144 | Thumbs.db 145 | ehthumbs.db 146 | Desktop.ini 147 | 148 | # Temporary files 149 | *.tmp 150 | *.temp 151 | *~ 152 | 153 | # Log files 154 | *.log 155 | 156 | # Documentation build artifacts 157 | docs/_build/ 158 | docs/_static/ 159 | docs/_templates/ 160 | site/ 161 | 162 | # Test notebooks and scratch files 163 | **/scratch_*.ipynb 164 | **/test_*.ipynb 165 | **/permshap.ipynb 166 | 167 | # Backup files 168 | *.bak 169 | *.backup 170 | 171 | # IDE and editor files 172 | *.swp 173 | *.swo 174 | 175 | # Performance profiling 176 | *.prof -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [0.1.12] - 2025-10-14 4 | 5 | ### Fixed 6 | 7 | - Fixed workflow file 8 | 9 | ## [0.1.11] - 2025-10-14 10 | 11 | ### Fixed 12 | 13 | - Removed unnecessary __about__.py file 14 | - Import __version__ in __init__.py 15 | 16 | ## [0.1.10] - 2025-10-05 17 | 18 | ### Fixed 19 | 20 | - Removed small white lines in waterfall plot bars, caused by anti-aliasing issues in matplotlib 21 | - Fixed issue with text labels being cut off in waterfall plots when bars are too small 22 | 23 | ## [0.1.9] - 2025-10-03 24 | 25 | ### Added 26 | - Released package on PyPI 27 | 28 | ### Fixed 29 | - Resolved build issues by switching to static versioning (disabled hatch-vcs) 30 | 31 | ## [0.1.8] - 2025-10-03 32 | 33 | ### Added 34 | - Attempt for PyPI release (build issues encountered) 35 | 36 | ## [0.1.7] - 2025-10-03 37 | 38 | ### Added 39 | - Attempt for PyPI release (build issues encountered) 40 | 41 | ## [0.1.6] - 2025-09-05 42 | 43 | ### Fixed 44 | - Fixed release workflow 45 | 46 | ## [0.1.5] - 2025-09-05 47 | 48 | ### Added 49 | - More and better unit tests 50 | 51 | ### Fixed 52 | - Tests requiring xgboost, lightgbm, and catboost now use `pytest.importorskip()` for safer handling 53 | - Fixed OpenMP dependency issues on macOS in CI by installing `libomp` via homebrew 54 | 55 | ### Changed 56 | - Renamed some test files for consistency 57 | 58 | ## [0.1.4] - 2025-09-02 59 | 60 | ### Changed 61 | - Migrated from setuptools to hatchling build backend 62 | - Implemented dynamic versioning with hatch-vcs (version now reads from git tags) 63 | - Modernized build system configuration 64 | 65 | ## [0.1.3] - 2025-09-02 66 | 67 | ### Fixed 68 | - Fixed TestPyPI trusted publisher configuration error 69 | - Resolved "invalid-publisher" error in release workflow 70 | 71 | ## [0.1.2] - 2025-09-02 72 | 73 | ### Fixed 74 | - Fixed release workflow (removed unnecessary release environment) 75 | 76 | ## [0.1.1] - 2025-09-02 77 | 78 | ### Fixed 79 | - Fixed release workflow to only publish to TestPyPI (removed PyPI publishing) 80 | - Resolved workflow syntax errors and duplicate steps 81 | 82 | ## [0.1.0] - 2025-09-02 83 | 84 | ### Added 85 | - Initial beta release of LightSHAP 86 | - Model-agnostic SHAP via `explain_any()` function 87 | - Support for Permutation SHAP and Kernel SHAP 88 | - Exact and sampling methods with convergence detection 89 | - Hybrid approaches for large feature sets 90 | - TreeSHAP via `explain_tree()` function 91 | - Support for XGBoost, LightGBM, and CatBoost 92 | - Comprehensive visualization suite 93 | - Bar plots for feature importance 94 | - Beeswarm plots for summary visualization 95 | - Scatter plots to describe 
effects 96 | - Waterfall plots for individual explanations 97 | - Multi-output model support 98 | - Background data weighting 99 | - Parallel processing via joblib 100 | - Support for pandas, numpy, and polars DataFrames 101 | - Categorical feature handling 102 | - Standard error estimation for sampling methods 103 | 104 | ### Technical Details 105 | - Python 3.11+ support 106 | - Modern build system with Hatch 107 | - Comprehensive test suite with pytest 108 | - CI/CD pipeline with GitHub Actions 109 | - Code quality enforcement with Ruff 110 | 111 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Michael Mayer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Logo](./docs/images/logo.svg?raw=true) 2 | 3 | # LightSHAP 4 | 5 | | | | 6 | | --- | --- | 7 | | Package | [![PyPI - Python Version](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/) [![PyPI - Version](https://img.shields.io/pypi/v/lightshap)](https://pypi.org/project/lightshap/) [![License - MIT](https://img.shields.io/badge/license-MIT-9400d3.svg)](https://spdx.org/licenses/) [![GitHub release](https://img.shields.io/github/v/release/mayer79/LightSHAP)](https://github.com/mayer79/LightSHAP/releases) [![Development Status](https://img.shields.io/badge/status-beta-orange.svg)](https://github.com/mayer79/LightSHAP) | 8 | | CI/CD | [![CI - Test](https://github.com/mayer79/LightSHAP/actions/workflows/test.yml/badge.svg)](https://github.com/mayer79/LightSHAP/actions/workflows/test.yml) [![GitHub release](https://img.shields.io/github/v/release/mayer79/LightSHAP?label=release)](https://github.com/mayer79/LightSHAP/releases) | 9 | | Quality | [![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![codecov](https://codecov.io/gh/mayer79/LightSHAP/graph/badge.svg)](https://codecov.io/gh/mayer79/LightSHAP) [![GitHub issues](https://img.shields.io/github/issues/mayer79/LightSHAP)](https://github.com/mayer79/LightSHAP/issues) | 10 | | Meta | [![Hatch project](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch) [![GitHub contributors](https://img.shields.io/github/contributors/mayer79/LightSHAP)](https://github.com/mayer79/LightSHAP/graphs/contributors) | 11 | 12 | **Lightweight Python implementation of SHAP (SHapley Additive exPlanations).** 13 | 14 | 📖 **[Documentation](https://mayer79.github.io/LightSHAP/)** | 🚀 **[Examples](https://mayer79.github.io/LightSHAP/examples/)** | 📋 **[API Reference](https://mayer79.github.io/LightSHAP/api/)** 15 | 16 | ## Key Features 17 | 18 | - **Tree Models**: TreeSHAP wrappers for XGBoost, LightGBM, and CatBoost via `explain_tree()` 19 | - **Model-Agnostic**: Permutation SHAP and Kernel SHAP via `explain_any()` 20 | - **Visualization**: Flexible plots 21 | 22 | **Highlights of the agnostic explainer:** 23 | 24 | 1. Exact and sampling versions of permutation SHAP and Kernel SHAP 25 | 2. Sampling versions iterate until convergence, and provide standard errors 26 | 3. Parallel processing via joblib 27 | 4. Supports multi-output models 28 | 5. Supports case weights 29 | 6. Accepts numpy, pandas, and polars input, and categorical features 30 | 31 | **Some methods of the explanation object:** 32 | 33 | - `plot.bar()`: Feature importance bar plot 34 | - `plot.beeswarm()`: Summary beeswarm plot 35 | - `plot.scatter()`: Dependence plots 36 | - `plot.waterfall()`: Waterfall plot for individual explanations 37 | - `importance()`: Returns feature importance values 38 | - `set_X()`: Update explanation data, e.g., to replace a numpy array with a DataFrame 39 | - `set_feature_names()`: Set or update feature names 40 | - `select_output()`: Select a specific output for multi-output models 41 | - `filter()`: Subset explanations by condition or indices 42 | - ... 
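For illustration, a typical post-processing flow could look as follows (a sketch only — the exact signatures of `select_output()` and `filter()` are assumptions, not documented API):

```python
from lightshap import explain_any

# Multi-output model, e.g., a classifier via predict_proba
explanation = explain_any(model.predict_proba, X)

imp = explanation.importance()  # feature importance values per output

# Hypothetical arguments: pick the second output, keep the first 100 rows,
# then draw an individual explanation
single = explanation.select_output(1)
single.filter(range(100)).plot.waterfall(row_id=0)
```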
43 | 44 | ## Usage 45 | 46 | ```python 47 | from lightshap import explain_any, explain_tree 48 | 49 | # For any model 50 | explanation = explain_any(model.predict, X) 51 | 52 | # For tree models (XGBoost, LightGBM, CatBoost) 53 | explanation = explain_tree(model, X) 54 | # explanation.set_X(df) # Optional: replace array with DataFrame for better plots 55 | 56 | # Create plots 57 | explanation.plot.bar() # Feature importance 58 | explanation.plot.beeswarm() # Summary plot 59 | explanation.plot.scatter() # Dependence plots 60 | explanation.plot.waterfall() # Individual explanation 61 | ``` 62 | 63 | ## Gallery 64 | 65 | ![SHAP importance](docs/images/tree_bar.png?raw=true) 66 | 67 | ![SHAP summary](docs/images/tree_beeswarm.png?raw=true) 68 | 69 | ![SHAP dependence](docs/images/tree_scatter.png?raw=true) 70 | 71 | ![SHAP waterfall](docs/images/tree_waterfall.png?raw=true) 72 | 73 | ## Installation 74 | 75 | ```bash 76 | # From PyPI 77 | pip install lightshap 78 | 79 | # With all optional dependencies 80 | pip install lightshap[all] 81 | 82 | # From GitHub 83 | pip install git+https://github.com/mayer79/LightSHAP.git 84 | ``` 85 | 86 | Contributions are highly appreciated! When contributing, you agree that your contributions will be subject to the [MIT License](https://github.com/mayer79/lightshap/blob/main/LICENSE). 87 | 88 | Please feel free to open an issue for bug reports, feature requests, or general discussions. 89 | 90 | ## License 91 | 92 | MIT License - see [LICENSE](LICENSE) file for details. 93 | 94 | ## Acknowledgements 95 | 96 | LightSHAP builds on top of wonderful packages like numpy, pandas, and matplotlib. 97 | 98 | It is heavily influenced by these projects: 99 | 100 | [shap](https://github.com/slundberg/shap) | 101 | [shapley-regression](https://github.com/iancovert/shapley-regression) | 102 | [kernelshap](https://github.com/ModelOriented/kernelshap) | 103 | [shapviz](https://github.com/ModelOriented/shapviz) 104 | 105 | # References 106 | 107 |
108 | 109 | "A Unified Approach to Interpreting Model Predictions" (S. M. Lundberg and S.-I. Lee 2017) 110 | 111 |
112 |
113 | @incollection{lundberglee2017,
114 |   title = {A Unified Approach to Interpreting Model Predictions},
115 |   author = {Lundberg, Scott M and Lee, Su-In},
116 |   booktitle = {Advances in Neural Information Processing Systems 30},
117 |   editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
118 |   pages = {4765--4774},
119 |   year = {2017},
120 |   publisher = {Curran Associates, Inc.},
121 |   url = {https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf}
122 | }
123 | 
124 | Paper link 125 |
126 | 127 |
128 | 129 | "Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression" (I. Covert and S.-I. Lee 2020) 130 | 131 |
132 |
133 | @inproceedings{covertlee2020,
134 |   title = {Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression},
135 |   author = {Covert, Ian and Lee, Su-In},
136 |   booktitle = {International Conference on Artificial Intelligence and Statistics},
137 |   year = {2021},
138 |   url = {https://proceedings.mlr.press/v130/covert21a/covert21a.pdf}
139 | }
140 | 
141 | Paper link 142 |
143 | 144 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | ## Functions 4 | 5 | ::: lightshap.explain_tree 6 | options: 7 | show_root_heading: true 8 | heading_level: 3 9 | 10 | ::: lightshap.explain_any 11 | options: 12 | show_root_heading: true 13 | heading_level: 3 14 | 15 | ## Classes 16 | 17 | ::: lightshap.Explanation 18 | options: 19 | show_root_heading: true 20 | heading_level: 3 21 | members_order: source 22 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | ## Example 1: CatBoost model for diamond prices 4 | 5 | ```python 6 | import catboost 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.datasets import fetch_openml 10 | 11 | # Load and prepare diamond data 12 | diamonds = fetch_openml(data_id=42225, as_frame=True) 13 | 14 | X = diamonds.data.assign( 15 | log_carat=lambda x: np.log(x.carat), # better visualization 16 | clarity=lambda x: pd.Categorical( 17 | x.clarity, categories=["I1", "SI2", "SI1", "VS2", "VS1", "VVS2", "VVS1", "IF"] 18 | ), 19 | cut=lambda x: pd.Categorical( 20 | x.cut, categories=["Fair", "Good", "Very Good", "Premium", "Ideal"] 21 | ), 22 | )[["log_carat", "cut", "color", "clarity"]] 23 | y = np.log(diamonds.target) 24 | 25 | # Fit model 26 | model = catboost.CatBoostRegressor( 27 | iterations=100, depth=4, cat_features=["cut", "color", "clarity"], verbose=0 28 | ) 29 | model.fit(X, y=y) 30 | ``` 31 | 32 | ### TreeSHAP analysis 33 | 34 | ```python 35 | from lightshap import explain_tree 36 | 37 | X_explain = X.sample(1000, random_state=0) 38 | explanation = explain_tree(model, X_explain) 39 | 40 | explanation.plot.bar() 41 | explanation.plot.beeswarm() 42 | explanation.plot.scatter(sharey=False) 43 | explanation.plot.waterfall(row_id=0) 44 | ``` 45 | 46 | #### SHAP importance 47 | 48 | ![SHAP importance](images/tree_bar.png) 49 | 50 | #### SHAP summary 51 | 52 | ![SHAP summary](images/tree_beeswarm.png) 53 | 54 | #### SHAP dependence 55 | 56 | ![SHAP dependence](images/tree_scatter.png) 57 | 58 | #### Individual explanation 59 | 60 | ![SHAP waterfall](images/tree_waterfall.png) 61 | 62 | 63 | ## Example 2: Linear regression with interactions 64 | 65 | We use `X` and `y` from the previous example. 66 | 67 | > **Note:** This example requires `glum`. 
Install with `pip install glum` 68 | 69 | ```python 70 | from glum import GeneralizedLinearRegressor 71 | 72 | # Fit with interactions 73 | glm = GeneralizedLinearRegressor( 74 | family="gaussian", 75 | formula="log_carat * (clarity + cut + color)", 76 | drop_first=True, 77 | ) 78 | glm.fit(X, y=y) 79 | ``` 80 | 81 | ### Model-agnostic SHAP analysis 82 | 83 | ```python 84 | from lightshap import explain_any 85 | 86 | X_explain = X.sample(1000, random_state=0) 87 | explanation = explain_any(glm.predict, X_explain) 88 | 89 | explanation.plot.bar() 90 | explanation.plot.beeswarm() 91 | explanation.plot.scatter(sharey=False) 92 | explanation.plot.waterfall(row_id=0) 93 | ``` 94 | 95 | #### SHAP importance 96 | 97 | ![SHAP importance](images/glm_bar.png) 98 | 99 | #### SHAP summary 100 | 101 | ![SHAP summary](images/glm_beeswarm.png) 102 | 103 | #### SHAP dependence 104 | 105 | ![SHAP dependence](images/glm_scatter.png) 106 | 107 | #### Individual explanation 108 | 109 | ![SHAP waterfall](images/glm_waterfall.png) -------------------------------------------------------------------------------- /docs/images/glm_bar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/glm_bar.png -------------------------------------------------------------------------------- /docs/images/glm_beeswarm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/glm_beeswarm.png -------------------------------------------------------------------------------- /docs/images/glm_scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/glm_scatter.png -------------------------------------------------------------------------------- /docs/images/glm_waterfall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/glm_waterfall.png -------------------------------------------------------------------------------- /docs/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/logo.png -------------------------------------------------------------------------------- /docs/images/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 39 | 41 | 46 | 54 | 62 | 70 | LightSHAP 81 | 82 | 83 | -------------------------------------------------------------------------------- /docs/images/tree_bar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/tree_bar.png -------------------------------------------------------------------------------- /docs/images/tree_beeswarm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/tree_beeswarm.png -------------------------------------------------------------------------------- 
/docs/images/tree_scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/tree_scatter.png -------------------------------------------------------------------------------- /docs/images/tree_waterfall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mayer79/LightSHAP/7d7c16cc2806b448f70e268c19a688d37f8d92b8/docs/images/tree_waterfall.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # LightSHAP 2 | 3 | **Lightweight Python implementation of SHAP (SHapley Additive exPlanations).** 4 | 5 | ## Key Features 6 | 7 | - **Tree Models**: TreeSHAP wrappers for XGBoost, LightGBM, and CatBoost via `explain_tree()` 8 | - **Model-Agnostic**: Permutation SHAP and Kernel SHAP via `explain_any()` 9 | - **Visualization**: Flexible plots 10 | 11 | **Highlights of the agnostic explainer:** 12 | 13 | 1. Exact and sampling versions of permutation SHAP and Kernel SHAP 14 | 2. Sampling versions iterate until convergence, and provide standard errors 15 | 3. Parallel processing via joblib 16 | 4. Supports multi-output models 17 | 5. Supports case weights 18 | 6. Accepts numpy, pandas, and polars input, and categorical features 19 | 20 | **Some methods of the explanation object:** 21 | 22 | - `plot.bar()`: Feature importance bar plot 23 | - `plot.beeswarm()`: Summary beeswarm plot 24 | - `plot.scatter()`: Dependence plots 25 | - `plot.waterfall()`: Waterfall plot for individual explanations 26 | - `importance()`: Returns feature importance values 27 | - `set_X()`: Update explanation data, e.g., to replace a numpy array with a DataFrame 28 | - `set_feature_names()`: Set or update feature names 29 | - `select_output()`: Select a specific output for multi-output models 30 | - `filter()`: Subset explanations by condition or indices 31 | - ... 
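As a small illustration of `set_X()` and `set_feature_names()` (a sketch; `X_np` and `names` are placeholders, and the assumption that both setters return the updated explanation follows the reassignment pattern used elsewhere in these docs):

```python
import pandas as pd
from lightshap import explain_tree

explanation = explain_tree(model, X_np)  # X_np: plain numpy array

# Replace the array with a labeled DataFrame so plots show feature names
explanation = explanation.set_X(pd.DataFrame(X_np, columns=names))

# Alternatively, set only the names on the existing data
explanation = explanation.set_feature_names(names)
```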
32 | 33 | ## Quick Start 34 | 35 | ```python 36 | from lightshap import explain_any, explain_tree 37 | 38 | # For any model 39 | explanation = explain_any(model.predict, X) 40 | 41 | # For tree models (XGBoost, LightGBM, CatBoost) 42 | explanation = explain_tree(model, X) 43 | 44 | # Create plots 45 | explanation.plot.bar() # Feature importance 46 | explanation.plot.beeswarm() # Summary plot 47 | explanation.plot.scatter() # Dependence plots 48 | explanation.plot.waterfall() # Individual explanation 49 | ``` 50 | 51 | ## Documentation 52 | 53 | - [API Reference](api.md) - Detailed API documentation 54 | - [Examples](examples.md) - Usage examples and tutorials 55 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: LightSHAP 2 | site_description: Lightweight Python implementation of SHAP 3 | site_url: https://mayer79.github.io/LightSHAP/ 4 | repo_url: https://github.com/mayer79/LightSHAP 5 | repo_name: mayer79/LightSHAP 6 | 7 | theme: 8 | name: material 9 | palette: 10 | - scheme: default 11 | primary: blue 12 | accent: blue 13 | toggle: 14 | icon: material/brightness-7 15 | name: Switch to dark mode 16 | - scheme: slate 17 | primary: blue 18 | accent: blue 19 | toggle: 20 | icon: material/brightness-4 21 | name: Switch to light mode 22 | features: 23 | - navigation.tabs 24 | - navigation.sections 25 | - navigation.expand 26 | - navigation.top 27 | - navigation.instant 28 | - navigation.tracking 29 | - search.highlight 30 | - content.code.copy 31 | - content.action.edit 32 | - content.action.view 33 | 34 | plugins: 35 | - search 36 | - mkdocstrings: 37 | handlers: 38 | python: 39 | options: 40 | docstring_style: numpy 41 | show_root_heading: true 42 | show_root_toc_entry: false 43 | show_signature_annotations: true 44 | show_source: false 45 | members_order: source 46 | heading_level: 2 47 | show_symbol_type_heading: true 48 | show_symbol_type_toc: true 49 | group_by_category: true 50 | show_submodules: true 51 | inherited_members: true 52 | filters: 53 | - "!^_" # exclude anything starting with _ 54 | 55 | nav: 56 | - Home: index.md 57 | - API Reference: api.md 58 | - Examples: examples.md 59 | 60 | markdown_extensions: 61 | - admonition 62 | - pymdownx.details 63 | - pymdownx.superfences 64 | - pymdownx.highlight: 65 | anchor_linenums: true 66 | - pymdownx.inlinehilite 67 | - pymdownx.snippets 68 | - pymdownx.tabbed: 69 | alternate_style: true 70 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "hatchling==1.27.0", 4 | ] 5 | build-backend = "hatchling.build" 6 | 7 | [project] 8 | name = "lightshap" 9 | version = "0.1.12" # Static version 10 | # dynamic = ["version"] # Commented out for static versioning 11 | description = "A lightweight SHAP library" 12 | readme = "README.md" 13 | license = {text = "MIT"} 14 | authors = [ 15 | {name = "Michael Mayer", email = "mayermichael79@gmail.com"} 16 | ] 17 | maintainers = [ 18 | {name = "Michael Mayer", email = "mayermichael79@gmail.com"} 19 | ] 20 | keywords = ["shap", "explainability", "machine-learning", "interpretability", "xai"] 21 | classifiers = [ 22 | "Development Status :: 4 - Beta", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "License :: OSI Approved :: MIT License", 26 | "Programming 
Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3.13", 30 | "Topic :: Scientific/Engineering", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | "Operating System :: OS Independent", 34 | ] 35 | requires-python = ">=3.11" 36 | dependencies = [ 37 | "numpy>=2.0.0", 38 | "pandas>=2.2.2", # numpy 2 39 | "matplotlib>=3.8.4", # numpy 2 40 | "scipy>=1.13.1", # numpy 2 41 | "joblib>=1.3.0", 42 | "tqdm>=4.67.0", 43 | ] 44 | 45 | [project.optional-dependencies] 46 | # Polars support 47 | polars = ["polars>=1.1.0", "pyarrow>=18.0.0"] 48 | 49 | # Tree model support (these are optional since users might only want one) 50 | xgboost = ["xgboost>=2.1.0"] 51 | lightgbm = ["lightgbm>=4.4.0"] 52 | catboost = ["catboost>=1.2.5"] 53 | 54 | # All tree models 55 | tree = ["xgboost>=2.1.0", "lightgbm>=4.4.0", "catboost>=1.2.5"] 56 | 57 | # Documentation dependencies 58 | docs = [ 59 | "mkdocs>=1.5.0", 60 | "mkdocs-material>=9.0.0", 61 | "mkdocstrings[python]>=0.24.0", 62 | ] 63 | 64 | # Development dependencies 65 | dev = [ 66 | "pytest>=7.0.0", 67 | "pytest-cov>=4.0.0", 68 | "ruff>=0.1.0", 69 | "pre-commit>=2.20.0", 70 | # Documentation 71 | "mkdocs>=1.5.0", 72 | "mkdocs-material>=9.0.0", 73 | "mkdocstrings[python]>=0.24.0", 74 | # Optional dependencies needed for testing 75 | "polars>=1.1.0", 76 | "pyarrow>=18.0.0", 77 | "xgboost>=2.1.0", 78 | "lightgbm>=4.4.0", 79 | "catboost>=1.2.5", 80 | "scikit-learn>=1.5.0", 81 | ] 82 | 83 | # All optional dependencies 84 | all = [ 85 | "polars>=1.1.0", 86 | "pyarrow>=18.0.0", 87 | "xgboost>=2.1.0", 88 | "lightgbm>=4.4.0", 89 | "catboost>=1.2.5", 90 | "scikit-learn>=1.5.0", 91 | "mkdocs>=1.5.0", 92 | "mkdocs-material>=9.0.0", 93 | "mkdocstrings[python]>=0.24.0", 94 | ] 95 | 96 | [project.urls] 97 | Homepage = "https://github.com/mayer79/LightSHAP" 98 | Repository = "https://github.com/mayer79/LightSHAP" 99 | "Bug Tracker" = "https://github.com/mayer79/LightSHAP/issues" 100 | 101 | [tool.hatch.build.targets.wheel] 102 | packages = ["src/lightshap"] 103 | 104 | # Ruff configuration (handles both linting and formatting) 105 | [tool.ruff] 106 | line-length = 88 107 | target-version = "py311" 108 | 109 | [tool.ruff.lint] 110 | select = [ 111 | "E", # pycodestyle errors 112 | "W", # pycodestyle warnings 113 | "F", # pyflakes 114 | "I", # isort 115 | "B", # flake8-bugbear 116 | "C4", # flake8-comprehensions 117 | "UP", # pyupgrade 118 | ] 119 | ignore = [ 120 | "E501", # line too long, handled by formatter 121 | "B008", # do not perform function calls in argument defaults 122 | "C901", # too complex 123 | ] 124 | 125 | [tool.ruff.lint.per-file-ignores] 126 | "__init__.py" = ["F401"] # Allow unused imports in __init__.py 127 | "tests/**/*" = ["B011"] # Allow assert False in tests 128 | 129 | # Format configuration 130 | [tool.ruff.format] 131 | quote-style = "double" 132 | indent-style = "space" 133 | skip-magic-trailing-comma = false 134 | line-ending = "auto" 135 | 136 | # Pytest configuration 137 | [tool.pytest.ini_options] 138 | testpaths = ["src/lightshap/tests", "src/lightshap/explanation/tests", "src/lightshap/explainers/tests"] 139 | python_files = ["test_*.py"] 140 | python_classes = ["Test*"] 141 | python_functions = ["test_*"] 142 | addopts = [ 143 | "--strict-markers", 144 | "--strict-config", 145 | "--verbose", 146 | "--cov=lightshap", 147 | 
"--cov-report=term-missing", 148 | "--cov-report=html", 149 | ] 150 | markers = [ 151 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 152 | "integration: marks tests as integration tests", 153 | ] 154 | 155 | # Coverage configuration 156 | [tool.coverage.run] 157 | source = ["src/lightshap"] 158 | omit = [ 159 | "*/tests/*", 160 | "*/test_*.py", 161 | "src/__about__.py", 162 | "*/__init__.py", # Import-only files 163 | ] 164 | 165 | [tool.coverage.report] 166 | exclude_lines = [ 167 | "pragma: no cover", 168 | "def __repr__", 169 | "if self.debug:", 170 | "if settings.DEBUG", 171 | "raise AssertionError", 172 | "raise NotImplementedError", 173 | "if 0:", 174 | "if __name__ == .__main__.:", 175 | "class .*\\bProtocol\\):", 176 | "@(abc\\.)?abstractmethod", 177 | ] 178 | 179 | [tool.hatch.envs.test] 180 | dependencies = [ 181 | "scikit-learn>=1.5.0", 182 | "pytest>=7.0.0", 183 | "pytest-cov>=4.0.0", 184 | # Optional dependencies needed for testing 185 | "polars>=1.1.0", 186 | "pyarrow>=18.0.0", # for pl.to_pandas() 187 | "xgboost>=2.1.0", 188 | "lightgbm>=4.4.0", 189 | "catboost>=1.2.5", 190 | ] 191 | 192 | [tool.hatch.envs.test.scripts] 193 | test = "pytest {args}" 194 | cov = "pytest --cov=lightshap --cov-report=html --cov-report=term-missing {args}" 195 | 196 | [tool.hatch.publish.index] 197 | disable = false 198 | 199 | [tool.hatch.publish.index.repos] 200 | testpypi = "https://test.pypi.org/legacy/" 201 | 202 | [tool.hatch.envs.release] 203 | dependencies = [ 204 | "build", 205 | "twine", 206 | ] 207 | 208 | [tool.hatch.envs.release.scripts] 209 | check = "python -m build --check" 210 | build = "python -m build" # Use python -m build instead of hatch build 211 | publish-test = "python -m twine upload --repository testpypi dist/*" 212 | publish = "python -m twine upload dist/*" 213 | 214 | # GitHub deployment environment configuration 215 | [tool.hatch.envs.release.env-vars] 216 | ENVIRONMENT = "release" 217 | 218 | [tool.hatch.envs.dev] 219 | dependencies = [ 220 | "mkdocs>=1.6.0", 221 | "mkdocs-material>=9.6.0", 222 | "mkdocstrings[python]>=1.18.0", 223 | ] 224 | 225 | # NOTE: removed the [tool.hatch.envs.dev.scripts] section to avoid hatch wrapping/CLI parsing issues. 
-------------------------------------------------------------------------------- /src/lightshap/__init__.py: -------------------------------------------------------------------------------- 1 | """LightSHAP: Lightweight SHAP implementation.""" 2 | 3 | from ._version import __version__ 4 | from .explainers import explain_any, explain_tree 5 | from .explanation import Explanation 6 | 7 | __all__ = ["__version__", "explain_any", "explain_tree", "Explanation"] 8 | -------------------------------------------------------------------------------- /src/lightshap/_version.py: -------------------------------------------------------------------------------- 1 | """Version information.""" 2 | 3 | __all__ = [ 4 | "__version__", 5 | "__version_tuple__", 6 | "version", 7 | "version_tuple", 8 | "__commit_id__", 9 | "commit_id", 10 | ] 11 | 12 | __version__ = version = "0.1.12" 13 | __version_tuple__ = version_tuple = (0, 1, 12) 14 | 15 | __commit_id__ = commit_id = None 16 | -------------------------------------------------------------------------------- /src/lightshap/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .explain_any import explain_any 2 | from .explain_tree import explain_tree 3 | 4 | __all__ = ["explain_any", "explain_tree"] 5 | -------------------------------------------------------------------------------- /src/lightshap/explainers/_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from itertools import combinations, product 3 | 4 | import numpy as np 5 | 6 | from lightshap.utils import get_dataclass, get_polars 7 | 8 | 9 | def replicate_data(X, m): 10 | """Replicate data X m times. 11 | 12 | If X has rows 0, 1, 2, then the output will have rows 0, 1, 2, 0, 1, 2, ... . 13 | 14 | Parameters 15 | ---------- 16 | X : DataFrame, array 17 | The input data to be replicated. 18 | m : int 19 | The number of times to replicate the data. 20 | 21 | Returns 22 | ------- 23 | DataFrame, array 24 | The replicated data. 25 | """ 26 | 27 | if m <= 0: 28 | msg = "Replication factor m must be positive" 29 | raise ValueError(msg) 30 | 31 | xclass = get_dataclass(X) 32 | 33 | if xclass == "np": 34 | return np.tile(X, (m, 1)) 35 | J = np.tile(np.arange(X.shape[0]), m) 36 | if xclass == "pd": 37 | return X.iloc[J] # no reset index required 38 | 39 | return X[J] # polars 40 | 41 | 42 | def repeat_masks(Z, m, pl_schema=None): 43 | """ 44 | Repeat the masks m times. 45 | 46 | Parameters 47 | ---------- 48 | Z : ndarray 49 | The input masks to be repeated. 50 | m : int 51 | The number of times to repeat the masks. 52 | pl_schema : list, optional 53 | The column names to use if the output is a polars DataFrame. 54 | 55 | Returns 56 | ------- 57 | ndarray or pl.DataFrame 58 | The repeated masks. 59 | """ 60 | out = np.repeat(Z, m, axis=0) 61 | if pl_schema is not None: 62 | pl = get_polars() 63 | out = pl.DataFrame(out, schema=pl_schema) 64 | return out 65 | 66 | 67 | def welford_iteration(new_value, avg, sum_squared, j): 68 | """ 69 | Welford's method for updating the mean and variance incrementally. 70 | 71 | Parameters 72 | ---------- 73 | 74 | new_value: float 75 | The new value to incorporate into the average and sum of squares. 76 | avg: float 77 | The current average. 78 | sum_squared: float 79 | The current sum of squared differences from the average. 80 | j: int 81 | The number of observations so far (1 for the first obs). 
82 | This is used to compute the updated average and variance. 83 | 84 | Returns 85 | ------- 86 | tuple 87 | A tuple containing the updated average and the updated sum of squares. 88 | 89 | """ 90 | delta = new_value - avg 91 | avg += delta / j 92 | sum_squared += delta * (new_value - avg) 93 | return avg, sum_squared 94 | 95 | 96 | def check_convergence(beta_n, sum_squared, n_iter, tol): 97 | """Standard error <= tolerance times range of values? 98 | 99 | Checks if the standard error for each output dimension is smaller or equal 100 | to the tolerance times the range of SHAP values. 101 | 102 | Required only for sampling version of permutation SHAP. 103 | 104 | Parameters 105 | ---------- 106 | beta_n : array-like 107 | (p x K) matrix of SHAP values. 108 | sum_squared : array-like 109 | (p x K) matrix of sum of squares. 110 | n_iter : array-like 111 | Number of iterations. 112 | tol : float 113 | The tolerance level. 114 | 115 | Returns 116 | ------- 117 | bool 118 | True if the convergence criterion is met for all output dimensions, 119 | False otherwise. 120 | """ 121 | 122 | shap_range = np.ptp(beta_n, axis=0) 123 | converged = sum_squared.max(axis=0) <= (tol * n_iter * shap_range) ** 2 124 | return all(converged) 125 | 126 | 127 | def generate_all_masks(p): 128 | """ 129 | Generate a matrix of all possible boolean combinations for p features. 130 | 131 | This creates a 2^p x p boolean matrix where each row is a unique 132 | combination of True/False values, representing all possible subsets of features. 133 | 134 | Required only for exact permutation SHAP. 135 | 136 | Parameters 137 | ---------- 138 | p : int 139 | Number of features 140 | 141 | Returns 142 | ------- 143 | numpy.ndarray 144 | A 2^p x p boolean matrix with all possible combinations 145 | """ 146 | 147 | return np.array(list(product([False, True], repeat=p)), dtype=bool) 148 | 149 | 150 | def generate_partly_exact_masks(p, degree): 151 | """ 152 | List all length p vectors z with sum(z) in {degree, p - degree} and 153 | organize them in a boolean matrix with p columns and either choose(p, degree) or 154 | 2 * choose(p, degree) rows. 155 | 156 | Parameters 157 | ---------- 158 | p : int 159 | Number of features. 160 | degree : int 161 | Degree of the hybrid approach. 162 | 163 | Returns 164 | ------- 165 | np.ndarray 166 | A boolean matrix with partly exact masks. 167 | """ 168 | if degree < 1: 169 | msg = "degree must be at least 1" 170 | raise ValueError(msg) 171 | if 2 * degree > p: 172 | msg = "p must be >= 2 * degree" 173 | raise ValueError(msg) 174 | if degree == 1: 175 | Z = np.eye(p, dtype=bool) 176 | else: 177 | comb = np.array(list(combinations(range(p), r=degree))) 178 | Z = np.zeros((len(comb), p), dtype=bool) 179 | row_indices = np.repeat(np.arange(len(comb)), degree) 180 | col_indices = comb.flatten() 181 | Z[row_indices, col_indices] = True 182 | 183 | if 2 * degree != p: 184 | Z = np.vstack((~Z, Z)) 185 | 186 | return Z 187 | 188 | 189 | def random_permutation_from_start(p, start, rng): 190 | """ 191 | Returns a random permutation of integers from 0 to p-1 starting with value `start`. 192 | 193 | Required only for sampling version of permutation SHAP. 194 | 195 | Parameters 196 | ---------- 197 | p : int 198 | Length of the permutation. 199 | start : int 200 | The first element of the permutation. 201 | rng : np.random.Generator 202 | Random number generator for reproducibility. 
203 | 204 | Returns 205 | ------- 206 | list 207 | A list representing a random permutation of integers from 0 to p-1, 208 | starting with the specified `start` value. 209 | """ 210 | remaining = [i for i in range(p) if i != start] 211 | rng.shuffle(remaining) 212 | return [start, *remaining] 213 | 214 | 215 | def generate_permutation_masks(J, degree): 216 | """ 217 | Creates a (2 * (p - 1 - 2 * degree) x p) on-off-matrix with 218 | antithetic permutation scheme. 219 | 220 | Required only for sampling version of permutation SHAP. 221 | 222 | Parameters 223 | ---------- 224 | J : list 225 | A permutation vector of length p. 226 | degree : int 227 | Row sums of the returned matrix will be within [1 + degree, p - degree - 1]. 228 | 229 | Returns 230 | ------- 231 | A (2 * (p - 1 - 2 * degree) x p) boolean on-off-matrix. 232 | """ 233 | m = len(J) - 1 234 | if m <= 2 * degree: 235 | msg = "J must have at least 2 * degree + 2 elements" 236 | raise ValueError(msg) 237 | Z = np.ones((m, m + 1), dtype=bool) 238 | for i in range(m): 239 | Z[i : (m + 1), J[i]] = False 240 | if degree > 0: 241 | Z = Z[degree:-degree] 242 | return np.vstack((Z, ~Z)) 243 | 244 | 245 | def check_or_derive_background_data(bg_X, bg_w, bg_n, X, random_state): 246 | """ 247 | Checks or derives background data against X. 248 | 249 | Parameters 250 | ---------- 251 | bg_X : DataFrame, array 252 | Background data. 253 | bg_w : array-like 254 | Background weights. 255 | bg_n : int 256 | Maximum number of observations in the background data (if bg_X = None). 257 | X : DataFrame, array 258 | Input data. 259 | random_state : int 260 | Random seed for reproducibility. 261 | 262 | Returns 263 | ------- 264 | tuple 265 | A tuple containing the background data and weights. 266 | """ 267 | n, p = X.shape 268 | xclass = get_dataclass(X) 269 | 270 | # Deal with background weights 271 | if bg_w is not None: 272 | if get_dataclass(bg_w) in ("pd", "pl"): 273 | bg_w = bg_w.to_numpy() # will use exclusively in np.average(...) 274 | if bg_w.ndim != 1: 275 | msg = "bg_w must be a 1D array-like object." 276 | raise ValueError(msg) 277 | if bg_X is None and bg_w.shape[0] != n: 278 | msg = "bg_w must have the same length as X if bg_X is None." 279 | raise ValueError(msg) 280 | elif bg_X is not None and bg_w.shape[0] != bg_X.shape[0]: 281 | msg = "bg_w must have the same length as bg_X." 282 | raise ValueError(msg) 283 | if not any(bg_w > 0): 284 | msg = "bg_w must have at least one positive weight." 285 | raise ValueError(msg) 286 | 287 | if bg_X is None: 288 | if n <= 20: 289 | msg = "Background data must be provided or X must have at least 20 rows." 290 | raise ValueError(msg) 291 | bg_X = X.copy() if xclass in ("np", "pd") else X 292 | if n > bg_n: 293 | rng = np.random.default_rng(random_state) 294 | indices = rng.choice(n, size=bg_n, replace=False) 295 | bg_X = bg_X[indices] if xclass != "pd" else bg_X.iloc[indices] 296 | if bg_w is not None: 297 | bg_w = bg_w[indices] 298 | else: 299 | if get_dataclass(bg_X) != xclass: 300 | msg = f"Background data must be of type {xclass}." 301 | raise TypeError(msg) 302 | if xclass == "np" and p != bg_X.shape[1]: 303 | msg = f"Background data must have {p} columns, but has {bg_X.shape[1]}." 304 | raise ValueError(msg) 305 | if xclass in ("pd", "pl"): 306 | if set(bg_X.columns).issuperset(set(X.columns)): 307 | bg_X = bg_X[X.columns] 308 | else: 309 | msg = "Background data must have at least the same columns as X." 
310 | raise ValueError(msg) 311 | return bg_X, bg_w 312 | 313 | 314 | def safe_predict(func): 315 | """Turns predictions into (n x K) numpy array.""" 316 | if not callable(func): 317 | msg = "predict must be a callable." 318 | raise TypeError(msg) 319 | 320 | @functools.wraps(func) 321 | def wrapped(X): 322 | return np.asarray(func(X)).reshape(X.shape[0], -1) 323 | 324 | return wrapped 325 | 326 | 327 | def collapse_potential(X, bg_X, bg_w): 328 | """ 329 | Collapse potential per row against background data bg_X. 330 | 331 | The idea is as follows: if a value in a row of X is equal to q * 100% rows in bg_X, 332 | then (in exact mode) about q * 100% / 2 of the predictions can be skipped 333 | (potential). This is accumulated over all columns in X in a multiplicative way. 334 | 335 | Note that missing values are not considered in the current code. 336 | 337 | Parameters 338 | ---------- 339 | X : DataFrame, array 340 | The rows to be explained. 341 | bg_X : DataFrame, array 342 | The background data. 343 | bg_w : array 344 | Weights for the background data. 345 | 346 | Returns 347 | ------- 348 | float 349 | The potential collapse value, which is 1 minus the product of 350 | (1 - proportion of equal values per column) divided by 2. 351 | 352 | """ 353 | if not isinstance(X, np.ndarray): # then X must be polars or pandas 354 | X = X.to_numpy() 355 | bg_X = bg_X.to_numpy() 356 | 357 | potential = np.zeros_like(X, dtype=float) 358 | for i in range(X.shape[0]): 359 | potential[i] = np.average(X[i] == bg_X, axis=0, weights=bg_w) / 2 360 | 361 | return 1 - np.prod(1 - potential, axis=1) 362 | 363 | 364 | def collapse_with_index(x, xclass): 365 | """ 366 | Get unique rows of x and indices to reconstruct the original x. 367 | 368 | Parameters 369 | ---------- 370 | x : array-like or DataFrame 371 | Input data to find unique rows in 372 | xclass : str 373 | Type of x: 'numpy', 'pandas', or 'polars' 374 | 375 | Returns 376 | ------- 377 | tuple 378 | (unique_x, indices_to_reconstruct) 379 | - unique_x: x with only unique rows 380 | - indices_to_reconstruct: indices to map from unique_x back to x, or None. 381 | """ 382 | ix_reconstruct = None 383 | 384 | if xclass == "np": 385 | try: 386 | _, ix, ix_reconstruct = np.unique( 387 | x, return_index=True, return_inverse=True, axis=0 388 | ) 389 | ix, ix_reconstruct = ix.squeeze(), ix_reconstruct.squeeze() 390 | unique_x = x[ix] 391 | except TypeError: 392 | # If unique fails (e.g., with mixed dtypes), return original data 393 | unique_x = x 394 | elif xclass == "pd": 395 | unique_x = x.drop_duplicates().reset_index(drop=True) 396 | 397 | if len(unique_x) < len(x): 398 | ix_reconstruct = x.merge( 399 | unique_x.reset_index(names="_unique_idx_"), 400 | on=list(x.columns), 401 | how="left", 402 | )["_unique_idx_"].to_numpy() 403 | else: 404 | unique_x = x 405 | elif xclass == "pl": 406 | pl = get_polars() 407 | unique_x = x.unique(maintain_order=True) 408 | 409 | if len(unique_x) < len(x): 410 | ix_reconstruct = ( 411 | x.join( 412 | unique_x.with_row_index("_unique_idx_"), 413 | on=x.columns, 414 | how="left", 415 | maintain_order="left", 416 | nulls_equal=True, 417 | ) 418 | .select(pl.col("_unique_idx_")) 419 | .to_numpy() 420 | .flatten() 421 | ) 422 | else: 423 | unique_x = x 424 | 425 | return unique_x, ix_reconstruct 426 | 427 | 428 | def masked_predict(predict, masks_rep, x, bg_rep, weights, xclass, collapse, bg_n): 429 | """ 430 | Masked predict function. 
431 | 432 | For each on-off vector (rows in mask), the (weighted) average prediction 433 | is returned from a dataset using x as the "on" values and the background data as 434 | the "off" values. 435 | 436 | Parameters 437 | ---------- 438 | predict : callable 439 | Prediction function that takes data as input and returns a length K vector 440 | for each row. 441 | masks_rep : ndarray or pl.DataFrame 442 | An ((m * n_bg) x p) ndarray or pl.DataFrame with on-off values (boolean mask). 443 | x : ndarray, pd.Series, or pl.DataFrame 444 | Row to be explained. Note that for polars, we expect a DataFrame with one row, 445 | while for pandas, we expect a Series with colnames as index. 446 | bg_rep : DataFrame, array 447 | Background data stacked m times, i.e., having shape ((m * n_bg) x p) 448 | weights : array, optional 449 | A vector with case weights (of the same length as the unstacked background data). 450 | xclass : str 451 | The type of the background data, either "pd" for pandas DataFrame, 452 | "np" for numpy array, or "pl" for polars DataFrame. 453 | collapse : bool 454 | Whether to deduplicate the prediction data. 455 | bg_n : int 456 | How many rows does the (non-replicated) background data have. 457 | 458 | Returns 459 | ------- 460 | array 461 | A (m x K) ndarray with masked predictions. 462 | """ 463 | # Apply the masks 464 | if xclass == "np": 465 | # If x would have been replicated: bg_masked[mask_rep] = x[mask_rep] 466 | bg_masked = bg_rep.copy() 467 | for i in range(masks_rep.shape[1]): 468 | bg_masked[masks_rep[:, i], i] = x[i] 469 | elif xclass == "pd": 470 | bg_masked = bg_rep.copy() 471 | for i, v in enumerate(bg_masked.columns.to_list()): 472 | bg_masked.loc[masks_rep[:, i], v] = x[v] 473 | else: # polars DataFrame 474 | pl = get_polars() 475 | bg_masked = bg_rep.with_columns( 476 | pl.when(masks_rep[v]).then(pl.lit(x[v])).otherwise(pl.col(v)).alias(v) 477 | for v in bg_rep.columns 478 | ) 479 | if collapse: 480 | bg_masked, ix_reconstruct = collapse_with_index(bg_masked, xclass=xclass) 481 | else: 482 | ix_reconstruct = None 483 | 484 | preds = predict(bg_masked) 485 | 486 | if ix_reconstruct is not None: 487 | preds = preds[ix_reconstruct] 488 | 489 | m_masks = masks_rep.shape[0] // bg_n 490 | preds = preds.reshape(m_masks, preds.shape[0] // m_masks, -1) # Avoids splitting 491 | return np.average(preds, axis=1, weights=weights) # (m x K) 492 | -------------------------------------------------------------------------------- /src/lightshap/explainers/explain_any.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import joblib 4 | import numpy as np 5 | 6 | from lightshap.explanation.explanation import Explanation 7 | 8 | from ._utils import check_or_derive_background_data, collapse_potential, safe_predict 9 | from .kernel_utils import one_kernelshap, precalculate_kernelshap 10 | from .parallel import ParallelPbar 11 | from .permutation_utils import one_permshap, precalculate_permshap 12 | 13 | 14 | def explain_any( 15 | predict, 16 | X, 17 | bg_X=None, 18 | bg_w=None, 19 | bg_n=200, 20 | method=None, 21 | how=None, 22 | max_iter=None, 23 | tol=0.01, 24 | random_state=None, 25 | n_jobs=1, 26 | verbose=True, 27 | ): 28 | """ 29 | SHAP values for any model 30 | 31 | Calculate SHAP values for any model using either Kernel SHAP or Permutation SHAP. 32 | By default, it uses Permutation SHAP for p <= 8 features and a hybrid between 33 | exact and sampling Kernel SHAP for p > 8 features. 
34 | 35 | Parameters 36 | ---------- 37 | predict : callable 38 | A callable to get predictions, e.g., `model.predict`, `model.predict_proba`, 39 | or lambda x: scipy.special.logit(model.predict_proba(x)[:, -1]). 40 | 41 | X : pd.DataFrame, pl.DataFrame, np.ndarray 42 | Input data for which explanations are to be generated. Should contain only 43 | the p feature columns. Must be compatible with `predict`. 44 | 45 | bg_X : pd.DataFrame, pl.DataFrame, np.ndarray, or None, default=None 46 | Background data used to integrate out "switched off" features, 47 | typically a representative sample of the training data with 100 to 500 rows. 48 | Should contain the same columns as `X`, and be compatible with `predict`. 49 | If None, up to `bg_n` rows of `X` are randomly selected. 50 | 51 | bg_w : pd.Series, pl.Series, np.ndarray, or None, default=None 52 | Weights for the background data. If None, equal weights are used. 53 | If `bg_X` is None, `bg_w` must have the same length as `X`. 54 | 55 | bg_n : int, default=200 56 | If `bg_X` is None, that many rows are randomly selected from `X` 57 | to use as background data. Values between 50 and 500 are recommended. 58 | 59 | method: str, or None, default=None 60 | Either "kernel", "permutation", or None. 61 | If None, it is set to "permutation" when p <= 8, and to "kernel" otherwise. 62 | 63 | how: str, or None, default=None 64 | If "exact", exact SHAP values are computed. If "sampling", iterative sampling 65 | is used to approximate SHAP values. For Kernel SHAP, hybrid approaches between 66 | "sampling" and "exact" options are available: "h1" uses exact calculations 67 | for coalitions of size 1 and p-1, whereas "h2" uses exact calculations 68 | for coalitions of size 1, 2, p-2, and p-1. 69 | If None, it is set to "exact" when p <= 8. Otherwise, if method=="permutation", 70 | it is set to "sampling". For Kernel SHAP, if 8 < p <= 16, it is set to "h2", 71 | and to "h1" when p > 16. 72 | 73 | max_iter : int or None, default=None 74 | Maximum number of iterations for non-exact algorithms. Each iteration represents 75 | a forward and backward pass through a random permutation. 76 | For permutation SHAP, one iteration allows to evaluate Shapley's formula 77 | 2*p times (twice per feature). 78 | p subsequent iterations are starting with different values for faster 79 | convergence. If None, it is set to 10 * p. 80 | 81 | tol : float, default=0.01 82 | Tolerance for convergence. The algorithm stops when the estimated standard 83 | errors are all smaller or equal to `tol * range(shap_values)` 84 | for each output dimension. Not used when how=="exact". 85 | 86 | random_state : int or None, default=None 87 | Integer random seed to initialize numpy's random generator. Required for 88 | non-exact algorithms, and to subsample the background data if `bg_X` is None. 89 | 90 | n_jobs : int, default=1 91 | Number of parallel jobs to run via joblib. If 1, no parallelization is used. 92 | If -1, all available cores are used. 93 | 94 | verbose : bool, default=True 95 | If True, prints information and the tqdm progress bar. 
96 | 97 | Returns 98 | ------- 99 | Explanation object 100 | 101 | Examples 102 | -------- 103 | **Example 1: Working with Numpy input** 104 | 105 | >>> import numpy as np 106 | >>> from lightshap import explain_any 107 | >>> 108 | >>> # Create synthetic data 109 | >>> rng = np.random.default_rng(0) 110 | >>> X = rng.standard_normal((1000, 4)) 111 | >>> 112 | >>> # In practice, you would use model.predict, model.predict_proba, 113 | >>> # or a function thereof, e.g., 114 | >>> # lambda X: scipy.special.logit(model.predict_proba(X)) 115 | >>> def predict_function(X): 116 | ... linear = X[:, 0] + 2 * X[:, 1] - X[:, 2] + 0.5 * X[:, 3] 117 | ... interactions = X[:, 0] * X[:, 1] - X[:, 1] * X[:, 2] 118 | ... return (linear + interactions).reshape(-1, 1) 119 | >>> 120 | >>> # Explain with numpy array (no feature names initially) 121 | >>> explanation = explain_any( 122 | ... predict=predict_function, 123 | ... X=X[:100], # Explain first 100 rows 124 | ... ) 125 | >>> 126 | >>> # Set meaningful feature names 127 | >>> feature_names = ["temperature", "pressure", "humidity", "wind_speed"] 128 | >>> explanation = explanation.set_feature_names(feature_names) 129 | >>> 130 | >>> # Generate plots 131 | >>> explanation.plot.bar() 132 | >>> explanation.plot.scatter(["temperature", "humidity"]) 133 | >>> explanation.plot.waterfall(row_id=0) 134 | 135 | **Example 2: Polars input with categorical features** 136 | 137 | >>> import numpy as np 138 | >>> import polars as pl 139 | >>> from lightshap import explain_any 140 | >>> 141 | >>> rng = np.random.default_rng(0) 142 | >>> n = 800 143 | >>> 144 | >>> df = pl.DataFrame({ 145 | ... "age": rng.uniform(18, 80, n).round(), 146 | ... "income": rng.exponential(50000, n).round(-3), 147 | ... "education": rng.choice(["high_school", "college", "graduate", "phd"], n), 148 | ... "region": rng.choice(["north", "south", "east", "west"], n), 149 | ... }).with_columns([ 150 | ... pl.col("education").cast(pl.Categorical), 151 | ... pl.col("region").cast(pl.Categorical), 152 | ... ]) 153 | >>> 154 | >>> # Again, in practice you would use a fitted model's predict instead 155 | >>> def predict_function(X): 156 | ... pred = X["age"] / 50 + X["income"] / 100_000 * ( 157 | ... 1 + 0.5 * X["education"].is_in(["graduate", "phd"]) 158 | ... ) 159 | ... return pred 160 | >>> 161 | >>> explanation = explain_any( 162 | ... predict=predict_function, 163 | ... X=df[:200], # Explain first 200 rows 164 | ... bg_X=df[200:400], # Pass background dataset or use (subset) of X 165 | ... ) 166 | >>> 167 | >>> explanation.plot.beeswarm() 168 | >>> explanation.plot.scatter() 169 | """ 170 | n, p = X.shape 171 | 172 | if p < 2: 173 | msg = "At least two features are required." 174 | raise ValueError(msg) 175 | 176 | if method is None: 177 | method = "permutation" if p <= 8 else "kernel" 178 | elif method not in ("permutation", "kernel"): 179 | msg = "method must be 'permutation', 'kernel', or None." 180 | raise ValueError(msg) 181 | 182 | if how is None: 183 | if p <= 8: 184 | how = "exact" 185 | elif method == "permutation": 186 | how = "sampling" 187 | else: # "kernel" 188 | how = "h2" if p <= 16 else "h1" 189 | elif method == "permutation" and how not in ("exact", "sampling"): 190 | msg = "how must be 'exact', 'sampling', or None for permutation SHAP." 191 | raise ValueError(msg) 192 | elif method == "kernel" and how not in ("exact", "sampling", "h1", "h2"): 193 | msg = "how must be 'exact', 'sampling', 'h1', 'h2', or None for kernel SHAP." 
194 |         raise ValueError(msg)
195 |     if method == "permutation" and how == "sampling" and p < 4:
196 |         msg = (
197 |             "Sampling Permutation SHAP is not supported for p < 4. "
198 |             "Use how='exact' instead."
199 |         )
200 |         raise ValueError(msg)
201 |     if method == "kernel" and how == "h1" and p < 4:
202 |         msg = (
203 |             "Degree 1 hybrid Kernel SHAP is not supported for p < 4. "
204 |             "Use how='exact' instead."
205 |         )
206 |         raise ValueError(msg)
207 |     elif method == "kernel" and how == "h2" and p < 6:
208 |         msg = (
209 |             "Degree 2 hybrid Kernel SHAP is not supported for p < 6. "
210 |             "Use how='exact' instead."
211 |         )
212 |         raise ValueError(msg)
213 | 
214 |     if max_iter is None:
215 |         max_iter = 10 * p
216 |     elif not isinstance(max_iter, int) or max_iter < 1:
217 |         msg = "max_iter must be a positive integer or None."
218 |         raise ValueError(msg)
219 | 
220 |     # Get or check background data (and weights)
221 |     bg_X, bg_w = check_or_derive_background_data(
222 |         bg_X=bg_X, bg_w=bg_w, bg_n=bg_n, X=X, random_state=random_state
223 |     )
224 |     bg_n = bg_X.shape[0]
225 | 
226 |     # Ensures predictions are (n, K) numpy arrays
227 |     predict = safe_predict(predict)
228 | 
229 |     # Get base value (v0) and predictions (v1)
230 |     v1 = predict(X)  # (n x K)
231 |     v0 = np.average(predict(bg_X), weights=bg_w, axis=0, keepdims=True)  # (1 x K)
232 | 
233 |     # Precalculation of things that can be reused over rows
234 |     if method == "permutation":
235 |         precalc = precalculate_permshap(p, bg_X, how=how)
236 |     else:  # method == "kernel"
237 |         precalc = precalculate_kernelshap(p, bg_X, how=how)
238 | 
239 |     # Should we try to deduplicate prediction data? Only if we can save 25% of rows.
240 |     if False:  # how in ("exact", "h2"):
241 |         collapse = collapse_potential(X, bg_X=bg_X, bg_w=bg_w) >= 0.25
242 |     else:
243 |         collapse = np.zeros(n, dtype=bool)
244 | 
245 |     if verbose:
246 |         how_text = how
247 |         if how in ("h1", "h2"):
248 |             prop_ex = 100 * precalc["w"].sum()
249 |             how_text = f"hybrid degree {1 if how == 'h1' else 2}, {prop_ex:.0f}% exact"
250 |         print(f"{method.title()} SHAP ({how_text})")
251 | 
252 |     res = ParallelPbar(disable=not verbose)(n_jobs=n_jobs)(
253 |         joblib.delayed(one_permshap if method == "permutation" else one_kernelshap)(
254 |             i,
255 |             predict=predict,
256 |             how=how,
257 |             bg_w=bg_w,
258 |             v0=v0,
259 |             max_iter=max_iter,
260 |             tol=tol,
261 |             random_state=random_state,
262 |             X=X,
263 |             v1=v1,
264 |             precalc=precalc,
265 |             collapse=collapse,
266 |             bg_n=bg_n,
267 |         )
268 |         for i in range(n)
269 |     )
270 | 
271 |     shap_values, se, converged, n_iter = map(np.stack, zip(*res, strict=False))
272 | 
273 |     if converged is not None and not converged.all():
274 |         non_converged = converged.shape[0] - np.count_nonzero(converged)
275 |         warnings.warn(
276 |             f"{non_converged} rows did not converge. "
277 |             f"Consider using a larger tol or higher max_iter.",
278 |             UserWarning,
279 |             stacklevel=2,
280 |         )
281 | 
282 |     return Explanation(
283 |         shap_values,
284 |         X=X,
285 |         baseline=v0,
286 |         standard_errors=se,
287 |         converged=converged,
288 |         n_iter=n_iter,
289 |     )
290 | 
--------------------------------------------------------------------------------
/src/lightshap/explainers/explain_tree.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | from lightshap.explanation.explanation import Explanation
4 | 
5 | 
6 | def explain_tree(model, X):
7 |     """
8 |     Calculate TreeSHAP for XGBoost, LightGBM, and CatBoost models.
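    SHAP values are obtained from the native TreeSHAP implementations of the
    respective libraries (`pred_contribs=True` for XGBoost, `pred_contrib=True`
    for LightGBM, and `fstr_type="ShapValues"` for CatBoost), and are returned
    as an Explanation object.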
9 | 
10 |     The following model types are supported:
11 | 
12 |     - xgboost.Booster
13 |     - xgboost.XGBModel (incl. xgboost.XGBRanker)
14 |     - xgboost.XGBRegressor
15 |     - xgboost.XGBClassifier
16 |     - xgboost.XGBRFClassifier
17 |     - xgboost.XGBRFRegressor
18 |     - lightgbm.Booster
19 |     - lightgbm.LGBMModel
20 |     - lightgbm.LGBMRanker
21 |     - lightgbm.LGBMRegressor
22 |     - lightgbm.LGBMClassifier
23 |     - catboost.CatBoost
24 |     - catboost.CatBoostClassifier
25 |     - catboost.CatBoostRanker
26 |     - catboost.CatBoostRegressor
27 | 
28 |     Parameters
29 |     ----------
30 |     model : XGBoost, LightGBM, or CatBoost model
31 |         A fitted model.
32 |     X : array-like
33 |         The input data for which SHAP values are to be computed.
34 | 
35 |     Returns
36 |     -------
37 |     Explanation
38 |         An Explanation object.
39 | 
40 |     Examples
41 |     --------
42 | 
43 |     >>> # Example 1: XGBoost regression
44 |     >>> import numpy as np
45 |     >>> import pandas as pd
46 |     >>> from lightshap import explain_tree
47 |     >>>
48 |     >>> import xgboost as xgb
49 |     >>>
50 |     >>> rng = np.random.default_rng(seed=42)
51 |     >>> X = pd.DataFrame(
52 |     ...     {
53 |     ...         "X1": rng.normal(0, 1, 100),
54 |     ...         "X2": rng.uniform(-2, 2, 100),
55 |     ...         "X3": rng.choice([0, 1, 2], 100),
56 |     ...     }
57 |     ... )
58 |     >>> y = X["X1"] + X["X2"] ** 2 + X["X3"] + rng.normal(0, 0.1, 100)
59 |     >>> model = xgb.train({"learning_rate": 0.1}, xgb.DMatrix(X, label=y))
60 |     >>>
61 |     >>> explanation = explain_tree(model, X)
62 |     >>> explanation.plot.beeswarm()
63 |     >>> explanation.plot.scatter()
64 | 
65 |     >>> # Example 2: LightGBM Multi-Class Classification
66 |     >>> import numpy as np
67 |     >>> import pandas as pd
68 |     >>> from lightgbm import LGBMClassifier
69 |     >>> from lightshap import explain_tree
70 |     >>>
71 |     >>> rng = np.random.default_rng(seed=42)
72 |     >>> X = pd.DataFrame(
73 |     ...     {
74 |     ...         "X1": rng.normal(0, 1, 100),
75 |     ...         "X2": rng.uniform(-2, 2, 100),
76 |     ...         "X3": rng.choice([0, 1, 2], 100),
77 |     ...     }
78 |     ... )
79 |     >>> y = X["X1"] + X["X2"] ** 2 + X["X3"] + rng.normal(0, 0.1, 100)
80 |     >>> y = pd.cut(y, bins=3, labels=[0, 1, 2])
81 |     >>> model = LGBMClassifier(max_depth=3, verbose=-1)
82 |     >>> model.fit(X, y)
83 |     >>>
84 |     >>> # SHAP analysis
85 |     >>> explanation = explain_tree(model, X)
86 |     >>> explanation.set_output_names(["Class 0", "Class 1", "Class 2"])
87 |     >>> explanation.plot.bar()
88 |     >>> explanation.plot.scatter(which_output=0)  # Class 0
89 |     """
90 |     if _is_xgboost(model):
91 |         shap_values, X, feature_names = _xgb_shap(model, X=X)
92 |     elif _is_lightgbm(model):
93 |         shap_values, X, feature_names = _lgb_shap(model, X=X)
94 |     elif _is_catboost(model):
95 |         shap_values, X, feature_names = _catboost_shap(model, X=X)
96 |     else:
97 |         msg = (
98 |             "Model must be a LightGBM, XGBoost, or CatBoost model. "
99 |             "Note that not all model subtypes are supported."
100 |         )
101 |         raise TypeError(msg)
102 | 
103 |     # Extract baseline
104 |     if shap_values.ndim >= 3:  # (n x K x (p + 1)) multi-output model
105 |         baseline = shap_values[0, :, -1]
106 |         shap_values = shap_values[:, :, :-1].swapaxes(1, 2)  # (n x p x K)
107 |     else:
108 |         baseline = shap_values[0, -1]
109 |         shap_values = shap_values[:, :-1]
110 | 
111 |     # Note that shap_values have shape (n, p) or (n, p, K) at this point, even
112 |     # for single row X.
113 |     return Explanation(shap_values, X=X, baseline=baseline, feature_names=feature_names)
114 | 
115 | 
116 | def _is_lightgbm(x):
117 |     """Returns True if x is a LightGBM model with SHAP support.
118 | 
119 |     The following model types are supported:
120 | 
121 |     - lightgbm.Booster
122 |     - lightgbm.LGBMModel
123 |     - lightgbm.LGBMRanker
124 |     - lightgbm.LGBMRegressor
125 |     - lightgbm.LGBMClassifier
126 | 
127 |     Parameters
128 |     ----------
129 |     x : object
130 |         The object to check.
131 | 
132 |     Returns
133 |     -------
134 |     bool
135 |         True if x is a LightGBM model with SHAP support, and False otherwise.
136 |     """
137 |     try:
138 |         lgb = sys.modules["lightgbm"]
139 |     except KeyError:
140 |         return False
141 |     return isinstance(
142 |         x,
143 |         lgb.Booster
144 |         | lgb.LGBMModel
145 |         | lgb.LGBMRanker
146 |         | lgb.LGBMRegressor
147 |         | lgb.LGBMClassifier,
148 |     )
149 | 
150 | 
151 | def _is_xgboost(x):
152 |     """Returns True if x is an XGBoost model with SHAP support.
153 | 
154 |     The following model types are supported:
155 | 
156 |     - xgboost.Booster
157 |     - xgboost.XGBModel (incl. xgboost.XGBRanker)
158 |     - xgboost.XGBRegressor
159 |     - xgboost.XGBClassifier
160 |     - xgboost.XGBRFClassifier
161 |     - xgboost.XGBRFRegressor
162 | 
163 |     Parameters
164 |     ----------
165 |     x : object
166 |         The object to check.
167 | 
168 |     Returns
169 |     -------
170 |     bool
171 |         True if x is an XGBoost model with SHAP support, and False otherwise.
172 |     """
173 |     try:
174 |         xgb = sys.modules["xgboost"]
175 |     except KeyError:
176 |         return False
177 |     return isinstance(
178 |         x,
179 |         xgb.Booster
180 |         | xgb.XGBRanker
181 |         | xgb.XGBModel
182 |         | xgb.XGBRegressor
183 |         | xgb.XGBClassifier
184 |         | xgb.XGBRFClassifier
185 |         | xgb.XGBRFRegressor,
186 |     )
187 | 
188 | 
189 | def _is_catboost(x):
190 |     """Returns True if x is a CatBoost model with SHAP support.
191 | 
192 |     The following model types are supported:
193 | 
194 |     - catboost.CatBoost
195 |     - catboost.CatBoostClassifier
196 |     - catboost.CatBoostRanker
197 |     - catboost.CatBoostRegressor
198 | 
199 |     Parameters
200 |     ----------
201 |     x : object
202 |         The object to check.
203 | 
204 |     Returns
205 |     -------
206 |     bool
207 |         True if x is a CatBoost model with SHAP support, and False otherwise.
208 |     """
209 |     try:
210 |         catboost = sys.modules["catboost"]
211 |     except KeyError:
212 |         return False
213 |     return isinstance(
214 |         x,
215 |         catboost.CatBoost
216 |         | catboost.CatBoostClassifier
217 |         | catboost.CatBoostRanker
218 |         | catboost.CatBoostRegressor,
219 |     )
220 | 
221 | 
222 | def _lgb_shap(model, X):
223 |     """Calculate SHAP values for LightGBM models.
224 | 
225 |     Parameters
226 |     ----------
227 |     model : lightgbm.Booster or similar
228 |         The LightGBM model to explain.
229 | 
230 |     X : array-like
231 |         The input data for which to compute SHAP values. Passed to model.predict().
232 |         Cannot be a lightgbm.Dataset.
233 | 
234 |     Returns
235 |     -------
236 |     shap_values : np.ndarray
237 |         The computed SHAP values.
238 | 
239 |     X : Same as input X.
240 | 
241 |     feature_names : list
242 |         A list of feature names.
243 |     """
244 | 
245 |     import lightgbm as lgb  # noqa: PLC0415
246 | 
247 |     if isinstance(X, lgb.Dataset):
248 |         msg = "X cannot be a lgb.Dataset."
249 | raise TypeError(msg) 250 | 251 | n, p = X.shape 252 | 253 | shap_values = model.predict(X, pred_contrib=True) 254 | 255 | # Multi-output: Turn (n x (K * (p + 1))) -> (n x K x (p + 1)) 256 | if shap_values.shape[1] != p + 1: 257 | shap_values = shap_values.reshape(n, -1, p + 1) 258 | 259 | # Extract feature names 260 | if isinstance(model, lgb.Booster): 261 | feature_names = model.feature_name() 262 | else: 263 | feature_names = model.feature_name_ 264 | return shap_values, X, feature_names 265 | 266 | 267 | def _xgb_shap(model, X): 268 | """Calculate SHAP values for XGBoost models. 269 | 270 | Parameters 271 | ---------- 272 | model : xgboost.Booster or similar 273 | The XGBoost model to explain. 274 | 275 | X : xgb.DMatrix or array-like 276 | The input data for which to compute SHAP values. 277 | 278 | Returns 279 | ------- 280 | shap_values : np.ndarray 281 | The computed SHAP values. 282 | 283 | X : array-like 284 | If X is a xgb.DMatrix, the result of X.get_data().toarray(). 285 | Otherwise, the input X. 286 | 287 | feature_names : list, or None 288 | A list of feature names, or None. 289 | 290 | """ 291 | 292 | import xgboost as xgb # noqa: PLC0415 293 | 294 | # Sklearn API predict() does not have pred_contribs argument 295 | if not isinstance(model, xgb.Booster): 296 | model = model.get_booster() 297 | 298 | if not isinstance(X, xgb.DMatrix): 299 | X_pred = xgb.DMatrix(X) 300 | else: 301 | X_pred = X 302 | X = X.get_data().toarray() 303 | 304 | shap_values = model.predict(X_pred, pred_contribs=True) 305 | 306 | return shap_values, X, model.feature_names 307 | 308 | 309 | def _catboost_shap(model, X): 310 | """Calculate SHAP values for CatBoost models. 311 | 312 | Parameters 313 | ---------- 314 | model : catboost.CatBoost or similar 315 | The CatBoost model to explain. 316 | 317 | X : catboost.Pool or array-like 318 | The input data for which to compute SHAP values. 319 | 320 | Returns 321 | ------- 322 | shap_values : np.ndarray 323 | The computed SHAP values. 324 | 325 | X : array-like 326 | If X is a catboost.Pool, the result of X.get_features(). Otherwise, the input X. 327 | 328 | feature_names : list 329 | A list of feature names. 330 | """ 331 | 332 | import catboost # noqa: PLC0415 333 | 334 | if not isinstance(X, catboost.Pool): 335 | X_pred = catboost.Pool(X, cat_features=model.get_cat_feature_indices()) 336 | else: 337 | X_pred = X 338 | X = X.get_features() 339 | 340 | shap_values = model.get_feature_importance(data=X_pred, fstr_type="ShapValues") 341 | 342 | return shap_values, X, model.feature_names_ 343 | -------------------------------------------------------------------------------- /src/lightshap/explainers/kernel_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import binom 3 | 4 | from lightshap.utils import get_dataclass 5 | 6 | from ._utils import ( 7 | check_convergence, 8 | generate_all_masks, 9 | generate_partly_exact_masks, 10 | generate_permutation_masks, 11 | masked_predict, 12 | random_permutation_from_start, 13 | repeat_masks, 14 | replicate_data, 15 | welford_iteration, 16 | ) 17 | 18 | 19 | def one_kernelshap( 20 | i, 21 | predict, 22 | how, 23 | bg_w, 24 | v0, 25 | max_iter, 26 | tol, 27 | random_state, 28 | X, 29 | v1, 30 | precalc, 31 | collapse, 32 | bg_n, 33 | ): 34 | """ 35 | Explain a single row of input data. 36 | 37 | Parameters: 38 | 39 | Returns: 40 | - Explanation for the row. 
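    - Specifically, a tuple (beta_n, se, converged, n_iter): the (p x K) SHAP
      values, their standard errors, a convergence flag, and the number of
      iterations used.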
41 |     """
42 |     p = X.shape[1]
43 |     xclass = get_dataclass(X)
44 |     x = X[i] if xclass != "pd" else X.iloc[i]
45 |     v1 = v1[[i]]  # (1 x K)
46 |     degree = 0 + (how == "h1") + 2 * (how == "h2")
47 |     constraint = v1 - v0  # (1 x K)
48 | 
49 |     if how != "sampling":
50 |         vz = masked_predict(
51 |             predict=predict,
52 |             masks_rep=precalc["masks_exact_rep"],
53 |             x=x,
54 |             bg_rep=precalc["bg_exact_rep"],
55 |             weights=bg_w,
56 |             xclass=xclass,
57 |             collapse=collapse[i],
58 |             bg_n=bg_n,
59 |         )
60 | 
61 |         A_exact = precalc["A"]
62 |         b_exact = precalc["Z"].astype(float).T @ (precalc["w"] * (vz - v0))  # (p x K)
63 | 
64 |         # Some of the hybrid cases are exact as well
65 |         if how == "exact" or p // 2 == degree:
66 |             beta_n = kernel_solver(A_exact, b_exact, constraint=constraint)  # (p x K)
67 |             return beta_n, np.zeros_like(beta_n), True, 1
68 | 
69 |     # Sampling part, using A_exact and b_exact to fill up the weights
70 |     # Credits: https://github.com/iancovert/shapley-regression/blob/master/shapreg/shapley.py
71 | 
72 |     # Set up the sampling loop
73 |     rng = np.random.default_rng(random_state)
74 |     pl_schema = None if xclass != "pl" else X.columns
75 |     j = 0
76 | 
77 |     # Container for results
78 |     beta_n = np.zeros((p, v0.shape[1]), dtype=float)
79 |     sum_squared = np.zeros_like(beta_n)
80 |     converged = False
81 | 
82 |     A_sum = np.zeros((p, p), dtype=float)
83 |     b_sum = np.zeros_like(beta_n)
84 |     est_m = np.zeros_like(beta_n)
85 | 
86 |     while not converged and j < max_iter:
87 |         input_sampling = prepare_input_sampling(
88 |             p=p, degree=degree, start=j % p, rng=rng
89 |         )
90 |         j += 1
91 |         Z = input_sampling["Z"]
92 | 
93 |         # Expensive
94 |         vz = masked_predict(
95 |             predict=predict,
96 |             masks_rep=repeat_masks(Z, m=bg_n, pl_schema=pl_schema),
97 |             x=x,
98 |             bg_rep=precalc["bg_sampling_rep"],
99 |             weights=bg_w,
100 |             xclass=xclass,
101 |             collapse=False,
102 |             bg_n=bg_n,
103 |         )
104 |         A_new = input_sampling["A"]
105 |         b_new = Z.astype(float).T @ (input_sampling["w"] * (vz - v0))
106 | 
107 |         # Fill the exact part
108 |         if how != "sampling":
109 |             A_new += A_exact
110 |             b_new += b_exact
111 | 
112 |         # Solve regression on new values to determine standard error and convergence
113 |         est_new = kernel_solver(A_new, b_new, constraint=constraint)
114 |         _, sum_squared = welford_iteration(
115 |             new_value=est_new, avg=est_m, sum_squared=sum_squared, j=j
116 |         )
117 | 
118 |         # Solve regression on accumulated values
119 |         A_sum += A_new
120 |         b_sum += b_new
121 | 
122 |         if j > 1:
123 |             beta_n = kernel_solver(A_sum / j, b_sum / j, constraint=v1 - v0)
124 |             converged = check_convergence(
125 |                 beta_n=beta_n, sum_squared=sum_squared, n_iter=j, tol=tol
126 |             )
127 |         elif max_iter == 1:  # if n_iter == 1 and max_iter == 1
128 |             beta_n = est_new
129 |             converged = False
130 |             sum_squared = np.full_like(sum_squared, fill_value=np.nan, dtype=float)
131 | 
132 |     return beta_n, np.sqrt(sum_squared) / j, converged, j
133 | 
134 | 
135 | def calculate_kernel_weights(p):
136 |     """
137 |     Kernel weights, normalized over non-empty coalition sizes |S| in
138 |     {1, ..., p - 1}.
139 | 
140 |     The weights represent the Kernel weights given that the coalition vector has
141 |     already been generated.
142 | 
143 |     Parameters:
144 |     ----------
145 |     p : int
146 |         Total number of features.
147 |     Returns:
148 |     -------
149 |     np.ndarray
150 |         Normalized weights for Kernel SHAP.
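    Example (values computed from the formula above; for p = 4, sizes 1 and 3
    get twice the weight of size 2):

    >>> calculate_kernel_weights(4)
    array([0.4, 0.2, 0.4])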
151 |     """
152 | 
153 |     S = np.arange(1, p)
154 |     probs = 1.0 / (binom(p, S) * S * (p - S))
155 | 
156 |     return probs / probs.sum()
157 | 
158 | 
159 | def calculate_kernel_weights_per_coalition_size(p, degree=0):
160 |     """
161 |     Kernel SHAP weights, normalized over coalition sizes |S| in
162 |     {degree + 1, ..., p - degree - 1}.
163 | 
164 |     The weights represent the Kernel weights of a given coalition size.
165 | 
166 |     Parameters:
167 |     ----------
168 |     p : int
169 |         Total number of features.
170 |     degree : int
171 |         Degree of the hybrid approach.
172 | 
173 |     Returns:
174 |     -------
175 |     np.ndarray
176 |         Normalized weights for Kernel SHAP.
177 |     """
178 |     if p < 2 * degree + 2:
179 |         msg = "The number of features p must be at least 2 * degree + 2."
180 |         raise ValueError(msg)
181 | 
182 |     S = np.arange(1 + degree, p - degree)
183 |     probs = 1.0 / (S * (p - S))
184 | 
185 |     return probs / probs.sum()
186 | 
187 | 
188 | def calculate_exact_prop(p, degree):
189 |     """Total weight to spend.
190 | 
191 |     How much Kernel SHAP weight do coalitions of size
192 |     {1, ..., deg} and {p - deg, ..., p - 1} have?
193 | 
194 |     Parameters:
195 |     ----------
196 |     p : int
197 |         Total number of features.
198 |     degree : int
199 |         Degree of the hybrid approach, default 0.
200 | 
201 |     Returns:
202 |     -------
203 |     float
204 |         Value between 0 and 1.
205 |     """
206 |     if degree <= 0:
207 |         return 0.0
208 | 
209 |     kw = calculate_kernel_weights_per_coalition_size(p)
210 |     w_total = 2.0 * kw[np.arange(degree)].sum()
211 |     if p == 2 * degree:
212 |         w_total -= kw[degree - 1]
213 |     return w_total
214 | 
215 | 
216 | def prepare_input_exact(p):
217 |     """
218 |     Calculate the input for exact Kernel SHAP.
219 | 
220 |     This function generates the masks, weights, and A matrix needed for exact
221 |     Kernel SHAP.
222 | 
223 |     Parameters:
224 |     ----------
225 |     p : int
226 |         Number of features.
227 |     Returns:
228 |     -------
229 |     dict
230 |         A dictionary containing:
231 |         - Z: A ((2**p - 2) x p) boolean matrix with all non-trivial masks.
232 |         - w: A ((2**p - 2) x 1) array of weights corresponding to the masks.
233 |         - A: A (p x p) matrix used in the SHAP calculations.
234 |     """
235 |     Z = generate_all_masks(p)[1:-1]
236 |     kw = calculate_kernel_weights(p)
237 |     return prepare_Z_w_A(Z, kw=kw, w_total=1.0)
238 | 
239 | 
240 | def prepare_input_hybrid(p, degree):
241 |     """
242 |     Calculate the (partial) input for partly exact Kernel SHAP.
243 | 
244 |     Create Z, w, A for vectors z with sum(z) in {k, p - k}
245 |     for k in {1, ..., degree}.
246 |     The total weights do not sum to one, except in the special (exact)
247 |     case degree == p - degree.
248 |     (The remaining weight will be added in the process by prepare_input_sampling().)
249 |     Note that for a given k, the weights are constant.
250 | 
251 |     Parameters:
252 |     ----------
253 |     p : int
254 |         Number of features.
255 |     degree : int
256 |         Degree of the hybrid approach.
257 | 
258 |     Returns:
259 |     -------
260 |     dict
261 |         A dictionary containing:
262 |         - Z: A boolean matrix with partly exact masks.
263 |         - w: An array of weights corresponding to the masks.
264 |         - A: A (p x p) matrix used in the Kernel SHAP calculations.
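    For example, with p = 6 and degree = 1, the exact part consists of the
    12 masks of coalition size 1 or 5 and covers about half of the total
    kernel weight:

    >>> calculate_exact_prop(6, degree=1)  # doctest: +ELLIPSIS
    0.5255...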
265 |     """
266 |     if degree < 1:
267 |         msg = "degree must be at least 1"
268 |         raise ValueError(msg)
269 |     if 2 * degree > p:
270 |         msg = "p must be >= 2 * degree"
271 |         raise ValueError(msg)
272 | 
273 |     Z_list = []
274 | 
275 |     for k in range(degree):
276 |         Z = generate_partly_exact_masks(p, degree=k + 1)
277 |         Z_list.append(Z)
278 |     Z = np.vstack(Z_list)
279 |     kw = calculate_kernel_weights(p)
280 |     w_total = calculate_exact_prop(p, degree=degree)  # total weight to spend
281 | 
282 |     return prepare_Z_w_A(Z, kw=kw, w_total=w_total)
283 | 
284 | 
285 | def prepare_input_sampling(p, degree, start, rng):
286 |     """
287 |     Calculate input for sampling Kernel SHAP.
288 | 
289 |     Let m = 2 * (p - 1 - 2 * degree) be the number of masks to sample.
290 | 
291 |     Provides random input for paired SHAP sampling:
292 |     - Z: Matrix with m on-off vectors z with sum(z) following
293 |       Kernel weights.
294 |     - w: (m, 1) array of weights corresponding to the masks.
295 |     - A: Matrix A = Z'wZ
296 | 
297 |     If degree > 0, vectors z with sum(z) restricted to [degree+1, p-degree-1] are drawn.
298 |     This case is used in combination with prepare_input_hybrid(). Then,
299 |     sum(w) < 1.
300 | 
301 |     Parameters:
302 |     ----------
303 |     p : int
304 |         Number of features.
305 |     degree : int
306 |         Degree of the hybrid approach.
307 |     start : int
308 |         Starting index for the random permutation.
309 |     rng : np.random.Generator
310 |         Random number generator for reproducibility.
311 | 
312 |     Returns:
313 |     -------
314 |     dict
315 |         A dictionary containing:
316 |         - Z: A (m x p) boolean matrix with sampled masks.
317 |         - w: A (m, 1) array of weights corresponding to the masks.
318 |         - A: A (p x p) matrix used in the Kernel SHAP calculations.
319 |     """
320 |     if p < 2 * degree + 2:
321 |         msg = "The number of features p must be at least 2 * degree + 2."
322 |         raise ValueError(msg)
323 | 
324 |     J = random_permutation_from_start(p, start, rng=rng)
325 |     Z = generate_permutation_masks(J, degree=degree)
326 | 
327 |     # How much of the total weight do we need to cover?
328 |     w_total = 1.0 if degree == 0 else 1.0 - calculate_exact_prop(p, degree)
329 |     kw = calculate_kernel_weights_per_coalition_size(p)
330 | 
331 |     return prepare_Z_w_A(Z, kw=kw, w_total=w_total)
332 | 
333 | 
334 | def prepare_Z_w_A(Z, kw, w_total=1.0):
335 |     """
336 |     Prepare Z, w, and A for Kernel SHAP.
337 | 
338 |     Parameters:
339 |     ----------
340 |     Z : np.ndarray
341 |         A boolean matrix with masks.
342 |     kw : np.ndarray
343 |         Kernel weights for each row sum of Z.
344 |     w_total : float, default=1.0
345 |         Total weight to be distributed among the masks.
346 | 
347 |     Returns:
348 |     -------
349 |     dict
350 |         A dictionary containing:
351 |         - Z: The input mask matrix.
352 |         - w: The weights for the masks, normalized to w_total.
353 |         - A: The (p x p) matrix used in Kernel SHAP calculations.
354 |     """
355 |     w = kw[np.count_nonzero(Z, axis=1) - 1].reshape(-1, 1)
356 |     w *= w_total / w.sum()
357 |     Zf = Z.astype(float)
358 |     A = Zf.T @ (w * Zf)
359 | 
360 |     return {"Z": Z, "w": w, "A": A}
361 | 
362 | 
363 | # Precalculation of things that can be reused over rows
364 | def precalculate_kernelshap(p, bg_X, how):
365 |     """
366 |     Precalculate objects that can be reused over rows for Kernel SHAP.
367 | 
368 |     Parameters:
369 |     ----------
370 |     p : int
371 |         Number of features.
372 |     bg_X : DataFrame, array
373 |         Background data.
374 |     how : str
375 |         Either "exact", "h2", "h1", or "sampling".
376 | 
377 |     Returns:
378 |     -------
379 |     dict
380 |         Precalculated objects for Kernel SHAP.
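        Depending on `how`, the dictionary holds "Z", "w", and "A" for the
        exact part, "masks_exact_rep" and "bg_exact_rep" (all modes except
        "sampling"), and "bg_sampling_rep" (all modes except "exact").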
381 |     """
382 |     pl_schema = None if get_dataclass(bg_X) != "pl" else bg_X.columns
383 |     bg_n = bg_X.shape[0]
384 |     degree = 0 + (how == "h1") + 2 * (how == "h2")
385 | 
386 |     if how == "exact":
387 |         precalc = prepare_input_exact(p)
388 |     elif how in ("h1", "h2"):
389 |         precalc = prepare_input_hybrid(p, degree=degree)
390 |     else:
391 |         precalc = {}
392 | 
393 |     # Add replicated versions of bg_X and, for the exact part, of the masks
394 |     if how != "sampling":
395 |         Z = precalc["Z"]
396 |         precalc["masks_exact_rep"] = repeat_masks(Z, m=bg_n, pl_schema=pl_schema)
397 |         precalc["bg_exact_rep"] = replicate_data(bg_X, Z.shape[0])
398 |     if how != "exact":
399 |         precalc["bg_sampling_rep"] = replicate_data(bg_X, 2 * (p - 1 - 2 * degree))
400 | 
401 |     return precalc
402 | 
403 | 
404 | def kernel_solver(A, b, constraint):
405 |     """
406 |     Solve the kernel SHAP constrained optimization.
407 | 
408 |     We are following Ian Covert's approach in
409 |     https://github.com/iancovert/shapley-regression/blob/master/shapreg/shapley.py
410 | 
411 |     Alternatively, to avoid any singular matrix issues, we could use the following:
412 | 
413 |     Ainv = np.linalg.pinv(A)
414 |     s = (Ainv @ b).sum(axis=0) - constraint
415 |     s /= Ainv.sum()
416 |     return Ainv @ (b - s[np.newaxis, :])
417 | 
418 |     The current implementation could be improved by gluing b and 1 together
419 |     to decompose A only once (idea by Christian Lorentzen).
420 | 
421 |     Parameters:
422 |     ----------
423 |     A : np.ndarray
424 |         (p x p) matrix.
425 |     b : np.ndarray
426 |         (p x K) matrix.
427 |     constraint : np.ndarray
428 |         (1 x K) array equal to v1 - v0.
429 | 
430 |     Returns:
431 |     -------
432 |     np.ndarray
433 |         (p x K) matrix with the solution to the optimization problem.
434 | 
435 |     Example:
436 |     >>> A = np.array([[0.5, 0.1, 0.1], [0.1, 0.5, 0.1], [0.1, 0.1, 0.5]])
437 |     >>> b = np.arange(6).reshape(-1, 2)
438 |     >>> constraint = np.arange(2).reshape(1, -1)
439 |     >>> kernel_solver(A, b, constraint)  # approx. [[-5.0, -4.67], [0.0, 0.33], [5.0, 5.33]]
440 |     """
441 |     try:
442 |         Ainv1 = np.linalg.solve(A, np.ones((A.shape[1], 1)))
443 |         Ainvb = np.linalg.solve(A, b)
444 |     except np.linalg.LinAlgError as err:
445 |         msg = "Matrix A is singular; try a hybrid approach or a higher max_iter."
446 | raise ValueError(msg) from err 447 | num = np.sum(Ainvb, axis=0, keepdims=True) - constraint 448 | return Ainvb - Ainv1 @ num / Ainv1.sum() 449 | -------------------------------------------------------------------------------- /src/lightshap/explainers/parallel.py: -------------------------------------------------------------------------------- 1 | # Code copied without modification from https://github.com/louisabraham/tqdm_joblib 2 | # Original work by Louis Abraham 3 | # Licensed under CC BY-SA 4.0: https://creativecommons.org/licenses/by-sa/4.0/ 4 | 5 | import contextlib 6 | 7 | import joblib 8 | from tqdm.autonotebook import tqdm 9 | 10 | 11 | @contextlib.contextmanager 12 | def tqdm_joblib(*args, **kwargs): 13 | """Context manager to patch joblib to report into tqdm progress bar 14 | given as argument""" 15 | 16 | tqdm_object = tqdm(*args, **kwargs) 17 | 18 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | 22 | def __call__(self, *args, **kwargs): 23 | tqdm_object.update(n=self.batch_size) 24 | return super().__call__(*args, **kwargs) 25 | 26 | old_batch_callback = joblib.parallel.BatchCompletionCallBack 27 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback 28 | try: 29 | yield tqdm_object 30 | finally: 31 | joblib.parallel.BatchCompletionCallBack = old_batch_callback 32 | tqdm_object.close() 33 | 34 | 35 | def ParallelPbar(desc=None, **tqdm_kwargs): 36 | class Parallel(joblib.Parallel): 37 | def __call__(self, it): 38 | it = list(it) 39 | if self.n_jobs == 1: 40 | it = tqdm(it, desc=desc, **tqdm_kwargs) 41 | return super().__call__(it) 42 | with tqdm_joblib(total=len(it), desc=desc, **tqdm_kwargs): 43 | return super().__call__(it) 44 | 45 | return Parallel 46 | -------------------------------------------------------------------------------- /src/lightshap/explainers/permutation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import binom 3 | 4 | from lightshap.utils import get_dataclass 5 | 6 | from ._utils import ( 7 | check_convergence, 8 | generate_all_masks, 9 | generate_partly_exact_masks, 10 | generate_permutation_masks, 11 | masked_predict, 12 | random_permutation_from_start, 13 | repeat_masks, 14 | replicate_data, 15 | welford_iteration, 16 | ) 17 | 18 | 19 | def one_permshap( 20 | i, 21 | predict, 22 | how, 23 | bg_w, 24 | v0, 25 | max_iter, 26 | tol, 27 | random_state, 28 | X, 29 | v1, 30 | precalc, 31 | collapse, 32 | bg_n, 33 | ): 34 | """ 35 | Explain a single row of input data. 36 | 37 | Parameters: 38 | 39 | Returns: 40 | - Explanation for the row. 
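    - Specifically, a tuple (beta_n, se, converged, n_iter): the (p x K) SHAP
      values, their standard errors, a convergence flag, and the number of
      iterations used.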
41 | """ 42 | p = X.shape[1] 43 | K = v0.shape[1] 44 | xclass = get_dataclass(X) 45 | x = X[i] if xclass != "pd" else X.iloc[i] 46 | v1 = v1[[i]] # (1 x K) 47 | 48 | # Container for results 49 | beta_n = np.zeros((p, K), dtype=float) 50 | sum_squared = np.zeros_like(beta_n) 51 | converged = False 52 | j = 0 53 | 54 | if how == "exact": 55 | vz = np.zeros((2**p, K), dtype=float) 56 | vz[0] = v0 57 | vz[-1] = v1 58 | 59 | vz[1:-1] = masked_predict( 60 | predict=predict, 61 | masks_rep=precalc["masks_exact_rep"], 62 | x=x, 63 | bg_rep=precalc["bg_exact_rep"], 64 | weights=bg_w, 65 | xclass=xclass, 66 | collapse=collapse[i], 67 | bg_n=bg_n, 68 | ) 69 | 70 | for k in range(p): 71 | on, off = precalc["positions"][k] 72 | # Remember that shapley_weights have been computed without the first row 73 | beta_n[k] = np.average( 74 | vz[on] - vz[off], axis=0, weights=precalc["shapley_weights"][on - 1] 75 | ) 76 | return beta_n, sum_squared, True, 1 77 | 78 | # Sampling mode 79 | rng = np.random.default_rng(random_state) 80 | 81 | vz_balanced = masked_predict( 82 | predict=predict, 83 | masks_rep=precalc["masks_balanced_rep"], 84 | x=x, 85 | bg_rep=precalc["bg_balanced_rep"], 86 | weights=bg_w, 87 | xclass=xclass, 88 | collapse=False, 89 | bg_n=bg_n, 90 | ) 91 | 92 | pl_schema = None if xclass != "pl" else X.columns 93 | 94 | # vz has constant first, middle, and last row 95 | vz = np.zeros((2 * p + 1, K), dtype=float) 96 | vz[[0, -1]] = v1 97 | vz[p] = v0 98 | 99 | # Important positions to be filled in vz 100 | from_balanced = [1, 1 + p, p - 1, 2 * p - 1] 101 | from_iter = np.r_[2 : (p - 1), (p + 2) : (2 * p - 1)] 102 | 103 | while not converged and j < max_iter: 104 | # Cycle through p 105 | k = j % p 106 | chain = random_permutation_from_start(p, start=k, rng=rng) 107 | masks = generate_permutation_masks(chain, degree=1) 108 | j += 1 109 | 110 | vzj = masked_predict( 111 | predict=predict, 112 | masks_rep=repeat_masks(masks, m=bg_n, pl_schema=pl_schema), 113 | x=x, 114 | bg_rep=precalc["bg_sampling_rep"], 115 | weights=bg_w, 116 | xclass=xclass, 117 | collapse=False, 118 | bg_n=bg_n, 119 | ) 120 | 121 | # Fill vz first by pre-calculated masks, then by current iteration 122 | vz[from_balanced] = vz_balanced[[k, k + p, chain[p - 1] + p, chain[p - 1]]] 123 | vz[from_iter] = vzj 124 | 125 | # Evaluate Shapley's formula 2p times 126 | J = np.argsort(chain) 127 | forward = vz[J] - vz[J + 1] 128 | backward = vz[p + J + 1] - vz[p + J] 129 | new_value = (forward + backward) / 2 130 | 131 | beta_n, sum_squared = welford_iteration( 132 | new_value=new_value, avg=beta_n, sum_squared=sum_squared, j=j 133 | ) 134 | 135 | if j > 1: # otherwise, sum_squared is still 0 136 | converged = check_convergence( 137 | beta_n=beta_n, sum_squared=sum_squared, n_iter=j, tol=tol 138 | ) 139 | 140 | return beta_n, np.sqrt(sum_squared) / j, converged, j 141 | 142 | 143 | def precalculate_permshap(p, bg_X, how): 144 | """ 145 | Precalculate objects needed for sampling version of permutation SHAP. 146 | 147 | Parameters: 148 | ---------- 149 | p : int 150 | Number of features. 151 | bg_X : DataFrame, array 152 | Background data. 153 | how : str 154 | Either "exact" or "sampling". 155 | 156 | Returns: 157 | ------- 158 | dict 159 | Precalculated objects for permutation SHAP. 
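        For how == "exact": "masks_exact_rep", "bg_exact_rep",
        "shapley_weights", "positions", and "bg_n". For how == "sampling":
        "masks_balanced_rep", "bg_balanced_rep", "bg_sampling_rep", and "bg_n".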
160 |     """
161 |     pl_schema = None if get_dataclass(bg_X) != "pl" else bg_X.columns
162 |     bg_n = bg_X.shape[0]
163 | 
164 |     if how == "exact":
165 |         M = generate_all_masks(p)
166 |         other_players = M[1:].sum(axis=1) - 1  # first row cannot be "on"
167 | 
168 |         precalc = {
169 |             "masks_exact_rep": repeat_masks(M[1:-1], m=bg_n, pl_schema=pl_schema),
170 |             "bg_exact_rep": replicate_data(bg_X, m=2**p - 2),  # masks_rep.shape[0]
171 |             "shapley_weights": calculate_shapley_weights(p, other_players),
172 |             "positions": positions_for_exact(M),
173 |         }
174 |     elif p >= 4:  # how == "sampling"
175 |         M = generate_partly_exact_masks(p, degree=1)
176 |         precalc = {
177 |             "masks_balanced_rep": repeat_masks(M, m=bg_n, pl_schema=pl_schema),
178 |             "bg_balanced_rep": replicate_data(bg_X, 2 * p),
179 |             "bg_sampling_rep": replicate_data(bg_X, 2 * (p - 3)),
180 |         }
181 |     else:
182 |         msg = "sampling method not implemented for p < 4."
183 |         raise ValueError(msg)
184 | 
185 |     precalc["bg_n"] = bg_n
186 | 
187 |     return precalc
188 | 
189 | 
190 | def calculate_shapley_weights(p, ell):
191 |     """Calculate Shapley weights for a given number of features and coalition sizes.
192 | 
193 |     The function is vectorized over ell.
194 | 
195 |     Parameters:
196 |     ----------
197 |     p : int
198 |         Total number of features.
199 |     ell : array-like
200 |         Coalition sizes, i.e., the number of features in the subset besides the
201 |         feature being evaluated. (By symmetry of the formula, passing the number
202 |         of "off" features instead gives the same weights.)
203 | 
204 |     Returns:
205 |     -------
206 |     float or np.ndarray
207 |         The Shapley weight(s) for the given coalition size(s).
208 |     """
209 |     return 1.0 / binom(p, ell) / (p - ell)
210 | 
211 | 
212 | def positions_for_exact(mask):
213 |     """
214 |     Precomputes positions for exact permutation SHAP.
215 | 
216 |     For each feature j, this function calculates the indices of the rows in the full
217 |     mask with column j = True ("on"), and the indices of *corresponding* off rows.
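    For example, with p = 2, the rows of `mask` are ordered by binary code
    (00, 01, 10, 11): feature 0 has on-rows [2, 3] with off-rows [0, 1], and
    feature 1 has on-rows [1, 3] with off-rows [0, 2].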
216 | 217 | Parameters: 218 | ---------- 219 | mask : (2**p, p) boolean matrix 220 | Matrix representing on-off info 221 | 222 | Returns: 223 | ------- 224 | list of length p 225 | Each element represents a tuple with 226 | - Row indices in `mask` of "on" positions for feature j 227 | - Row indices in `mask` with corresponding "off" positions for feature j 228 | """ 229 | p = mask.shape[1] 230 | codes = np.arange(mask.shape[0]) # Row index = binary code of the row 231 | 232 | positions = [] 233 | for j in range(p): 234 | on = codes[mask[:, j]] 235 | off = on - 2 ** (p - 1 - j) # trick to turn "bit" off 236 | positions.append((on, off)) 237 | 238 | return positions 239 | -------------------------------------------------------------------------------- /src/lightshap/explainers/tests/test_explain_any.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from lightshap import explain_any 6 | 7 | ATOL = 1e-6 8 | 9 | 10 | def data_regression(): 11 | rng = np.random.default_rng(1) 12 | 13 | n = 100 14 | X = pd.DataFrame( 15 | { 16 | "x1": rng.uniform(0, 1, size=n), 17 | "x2": rng.uniform(0, 1, size=n), 18 | "x3": rng.choice(["A", "B", "C"], size=n), 19 | "x4": pd.Categorical(rng.choice(["a", "b", "c", "d"], size=n)), 20 | "x5": rng.uniform(0, 1, size=n), 21 | "x6": pd.Categorical(rng.choice(["e", "f", "g", "h"], size=n)), 22 | } 23 | ) 24 | return X 25 | 26 | 27 | @pytest.mark.parametrize("use_sample_weights", [False, True]) 28 | def test_exact_permutation_vs_kernel_shap_identical(use_sample_weights): 29 | """Test that exact methods return identical results.""" 30 | 31 | X = data_regression() 32 | 33 | # Predict with interactions of order 4 34 | def predict(X): 35 | return ( 36 | X["x1"] * X["x2"] * (X["x3"].isin(["A", "C"]) + 1) * (X["x4"].cat.codes + 1) 37 | + X["x5"] 38 | + X["x6"].cat.codes 39 | ) 40 | 41 | if use_sample_weights: 42 | rng = np.random.default_rng(1) 43 | sample_weights = rng.uniform(0.0, 1.0, size=X.shape[0]) 44 | else: 45 | sample_weights = None 46 | 47 | X_small = X.head(10) 48 | 49 | # Get explanations using exact permutation SHAP 50 | explanation_perm = explain_any( 51 | predict=predict, 52 | X=X_small, 53 | bg_X=X, 54 | bg_w=sample_weights, 55 | method="permutation", 56 | how="exact", 57 | verbose=False, 58 | ) 59 | 60 | # Get explanations using exact kernel SHAP 61 | explanation_kernel = explain_any( 62 | predict=predict, 63 | X=X_small, 64 | bg_X=X, 65 | bg_w=sample_weights, 66 | method="kernel", 67 | how="exact", 68 | verbose=False, 69 | ) 70 | 71 | np.testing.assert_allclose( 72 | explanation_perm.shap_values, explanation_kernel.shap_values, atol=ATOL 73 | ) 74 | np.testing.assert_allclose( 75 | explanation_perm.baseline, explanation_kernel.baseline, atol=ATOL 76 | ) 77 | 78 | 79 | @pytest.mark.parametrize("use_sample_weights", [False, True]) 80 | @pytest.mark.parametrize( 81 | ("method", "how"), 82 | [ 83 | ("kernel", "exact"), 84 | ("kernel", "sampling"), 85 | ("kernel", "h1"), 86 | ("kernel", "h2"), 87 | ("permutation", "sampling"), 88 | ], 89 | ) 90 | def test_permutation_vs_kernel_shap_with_interactions(use_sample_weights, method, how): 91 | """Test that algorithms agree for models with interactions of order up to two.""" 92 | 93 | X = data_regression() 94 | 95 | # Predict with interactions of order 2 96 | def predict(X): 97 | return ( 98 | X["x1"] * X["x2"] 99 | + (X["x3"].isin(["A", "C"]) + 1) * (X["x4"].cat.codes + 1) 100 | + X["x5"] 101 | + 
X["x6"].cat.codes 102 | ) 103 | 104 | if use_sample_weights: 105 | rng = np.random.default_rng(1) 106 | sample_weights = rng.uniform(0.0, 1.0, size=X.shape[0]) 107 | else: 108 | sample_weights = None 109 | 110 | X_small = X.head(5) 111 | 112 | # Get explanations using permutation SHAP 113 | reference = explain_any( 114 | predict=predict, 115 | X=X_small, 116 | bg_X=X, 117 | bg_w=sample_weights, 118 | method="permutation", 119 | how="exact", 120 | verbose=False, 121 | ) 122 | 123 | explanation = explain_any( 124 | predict=predict, 125 | X=X_small, 126 | bg_X=X, 127 | bg_w=sample_weights, 128 | method=method, 129 | how=how, 130 | verbose=False, 131 | ) 132 | 133 | np.testing.assert_allclose( 134 | reference.shap_values, explanation.shap_values, atol=ATOL 135 | ) 136 | np.testing.assert_allclose(explanation.baseline, reference.baseline, atol=ATOL) 137 | 138 | 139 | @pytest.mark.parametrize("method", ["kernel", "permutation"]) 140 | def test_against_shap_library_reference(method): 141 | """Test against known SHAP values from the Python shap library.""" 142 | # Expected values from shap library (Exact explainers) 143 | # Note that sampling methods do not perform very well here as the features 144 | # are extremely highly correlated. 145 | 146 | expected = np.array( 147 | [ 148 | [-1.19621609, -1.24184808, -0.9567848, 3.87942037, -0.33825, 0.54562519], 149 | [-1.64922699, -1.20770105, -1.18388581, 4.54321217, -0.33795, -0.41082395], 150 | ] 151 | ) 152 | 153 | n = 100 154 | X = pd.DataFrame( 155 | { 156 | "x1": np.arange(1, n + 1) / 100, 157 | "x2": np.log(np.arange(1, n + 1)), 158 | "x3": np.sqrt(np.arange(1, n + 1)), 159 | "x4": np.sin(np.arange(1, n + 1)), 160 | "x5": (np.arange(1, n + 1) / 100) ** 2, 161 | "x6": np.cos(np.arange(1, n + 1)), 162 | } 163 | ) 164 | 165 | def predict(X): 166 | return X["x1"] * X["x2"] * X["x3"] * X["x4"] + X["x5"] + X["x6"] 167 | 168 | X_test = X.head(2) 169 | 170 | explanation = explain_any( 171 | predict=predict, 172 | X=X_test, 173 | bg_X=X, 174 | method=method, 175 | how="exact", 176 | verbose=False, 177 | ) 178 | 179 | # Reference via shap.explainers.ExactExplainer(predict, X)(X_test) (shap 0.47.2) 180 | 181 | np.testing.assert_allclose(explanation.shap_values, expected, atol=ATOL) 182 | 183 | 184 | class TestWeights: 185 | """Test class for weight-related functionality.""" 186 | 187 | @pytest.mark.parametrize( 188 | ("method", "how"), 189 | [ 190 | ("kernel", "exact"), 191 | ("kernel", "sampling"), 192 | ("kernel", "h1"), 193 | ("kernel", "h2"), 194 | ("permutation", "exact"), 195 | ("permutation", "sampling"), 196 | ], 197 | ) 198 | def test_constant_weights_equal_no_weights(self, method, how): 199 | """Test that constant weights equal no weights.""" 200 | 201 | X = data_regression() 202 | 203 | def predict(X): 204 | return ( 205 | X["x1"] 206 | * X["x2"] 207 | * (X["x3"].isin(["A", "C"]) + 1) 208 | * (X["x4"].cat.codes + 1) 209 | + X["x5"] 210 | + X["x6"].cat.codes 211 | ) 212 | 213 | bg_w = np.full(X.shape[0], 2.0) 214 | X_small = X.head(10) 215 | 216 | # Get explanations without weights 217 | explanation_no_weights = explain_any( 218 | predict=predict, 219 | X=X_small, 220 | bg_X=X, 221 | bg_w=None, 222 | method=method, 223 | how=how, 224 | verbose=False, 225 | random_state=1, 226 | ) 227 | 228 | # Get explanations with constant weights 229 | explanation_constant_weights = explain_any( 230 | predict=predict, 231 | X=X_small, 232 | bg_X=X, 233 | bg_w=bg_w, 234 | method=method, 235 | how=how, 236 | verbose=False, 237 | random_state=1, 238 | ) 239 | 240 | 
np.testing.assert_allclose( 241 | explanation_no_weights.shap_values, 242 | explanation_constant_weights.shap_values, 243 | atol=ATOL, 244 | ) 245 | np.testing.assert_allclose( 246 | explanation_no_weights.baseline, 247 | explanation_constant_weights.baseline, 248 | atol=ATOL, 249 | ) 250 | 251 | @pytest.mark.parametrize( 252 | ("method", "how"), 253 | [ 254 | ("kernel", "exact"), 255 | ("kernel", "sampling"), 256 | ("kernel", "h1"), 257 | ("kernel", "h2"), 258 | ("permutation", "exact"), 259 | ("permutation", "sampling"), 260 | ], 261 | ) 262 | def test_non_constant_weights_differ_from_no_weights(self, method, how): 263 | """Test that non-constant weights give different results than no weights.""" 264 | 265 | X = data_regression() 266 | 267 | def predict(X): 268 | return ( 269 | X["x1"] * X["x2"] 270 | + (X["x3"].isin(["A", "C"]) + 1) * (X["x4"].cat.codes + 1) 271 | + X["x5"] 272 | + X["x6"].cat.codes 273 | ) 274 | 275 | # Create non-constant weights 276 | rng = np.random.default_rng(1) 277 | bg_w = rng.uniform(0.1, 2.0, size=X.shape[0]) 278 | X_small = X.head(20) 279 | 280 | # Get explanations without weights 281 | explanation_no_weights = explain_any( 282 | predict=predict, 283 | X=X_small, 284 | bg_X=X, 285 | bg_w=None, 286 | method=method, 287 | how=how, 288 | verbose=False, 289 | random_state=1, 290 | ) 291 | 292 | # Get explanations with non-constant weights 293 | explanation_weighted = explain_any( 294 | predict=predict, 295 | X=X_small, 296 | bg_X=X, 297 | bg_w=bg_w, 298 | method=method, 299 | how=how, 300 | verbose=False, 301 | random_state=1, 302 | ) 303 | 304 | # Results should be different (not allclose) 305 | with pytest.raises(AssertionError): 306 | np.testing.assert_allclose( 307 | explanation_no_weights.shap_values, 308 | explanation_weighted.shap_values, 309 | atol=ATOL, 310 | ) 311 | 312 | 313 | @pytest.mark.parametrize( 314 | ("method", "how"), 315 | [ 316 | ("permutation", "sampling"), 317 | ("kernel", "sampling"), 318 | ("kernel", "h1"), 319 | ("kernel", "h2"), 320 | ], 321 | ) 322 | def test_sampling_methods_approximate_exact(method, how): 323 | """Test that sampling methods approximate exact results within tolerance.""" 324 | # Note that we are using a model with interactions of order > 2 to see 325 | # differences between the methods. 
326 | n = 100 327 | rng = np.random.default_rng(1) 328 | 329 | X = pd.DataFrame(rng.uniform(0, 2, (n, 6)), columns=[f"x{i}" for i in range(6)]) 330 | 331 | def predict(X): 332 | return X["x0"] * X["x1"] * X["x2"] + X["x3"] * X["x4"] * X["x5"] 333 | 334 | X_test = X.head(5) 335 | 336 | # Exact reference 337 | exact = explain_any( 338 | predict=predict, 339 | X=X_test, 340 | bg_X=X, 341 | method="permutation", 342 | how="exact", 343 | verbose=False, 344 | ) 345 | 346 | # Approximations of increasing quality 347 | approx = [] 348 | for tol in [0.01, 0.005, 0.0025]: 349 | approx.append( 350 | explain_any( 351 | predict=predict, 352 | X=X_test, 353 | bg_X=X, 354 | method=method, 355 | how=how, 356 | tol=tol, 357 | max_iter=500, # to avoid convergence warnings 358 | verbose=False, 359 | random_state=1, 360 | ) 361 | ) 362 | mae = [np.abs(apr.shap_values - exact.shap_values).mean() for apr in approx] 363 | 364 | # Approximations get better with smaller tolerance (but not necessarily strictly) 365 | assert mae[0] >= mae[1] >= mae[2] 366 | 367 | # These tolerances differ by factor of 4, so that equality is very unlikely 368 | assert mae[0] > mae[2] 369 | 370 | # Threshold somewhat arbitrary, but small given that average prediction is 2 371 | assert mae[2] <= 0.005 372 | 373 | 374 | class TestErrorConditions: 375 | """Test class for error conditions and bad inputs.""" 376 | 377 | def test_too_few_features(self): 378 | """Test that p < 2 raises ValueError.""" 379 | X = pd.DataFrame({"x1": [1, 2, 3]}) 380 | 381 | def predict(X): 382 | return X["x1"] 383 | 384 | with pytest.raises(ValueError, match="At least two features are required"): 385 | explain_any(predict, X) 386 | 387 | def test_invalid_method(self): 388 | """Test that invalid method raises ValueError.""" 389 | X = data_regression() 390 | 391 | def predict(X): 392 | return X["x1"] + X["x2"] 393 | 394 | with pytest.raises( 395 | ValueError, match="method must be 'permutation', 'kernel', or None" 396 | ): 397 | explain_any(predict, X, method="invalid") 398 | 399 | @pytest.mark.parametrize("how", ["invalid", "h1", "h2"]) 400 | def test_invalid_how_for_permutation(self, how): 401 | """Test that invalid how for permutation SHAP raises ValueError.""" 402 | X = data_regression() 403 | 404 | def predict(X): 405 | return X["x1"] + X["x2"] 406 | 407 | with pytest.raises( 408 | ValueError, 409 | match="how must be 'exact', 'sampling', or None for permutation SHAP", 410 | ): 411 | explain_any(predict, X, method="permutation", how=how) 412 | 413 | @pytest.mark.parametrize("how", ["invalid", "h3"]) 414 | def test_invalid_how_for_kernel(self, how): 415 | """Test that invalid how for kernel SHAP raises ValueError.""" 416 | X = data_regression() 417 | 418 | def predict(X): 419 | return X["x1"] + X["x2"] 420 | 421 | with pytest.raises( 422 | ValueError, 423 | match="how must be 'exact', 'sampling', 'h1', 'h2', or None for kernel SHAP", 424 | ): 425 | explain_any(predict, X, method="kernel", how=how) 426 | 427 | def test_sampling_permutation_too_few_features(self): 428 | """Test that sampling permutation SHAP with p < 4 raises ValueError.""" 429 | X = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "x3": [7, 8, 9]}) 430 | 431 | def predict(X): 432 | return X["x1"] + X["x2"] + X["x3"] 433 | 434 | with pytest.raises( 435 | ValueError, match="Sampling Permutation SHAP is not supported for p < 4" 436 | ): 437 | explain_any(predict, X, bg_X=X, method="permutation", how="sampling") 438 | 439 | def test_h1_kernel_too_few_features(self): 440 | """Test that h1 kernel SHAP 
with p < 4 raises ValueError.""" 441 | X = pd.DataFrame({"x1": [1, 2, 3], "x2": [4, 5, 6], "x3": [7, 8, 9]}) 442 | 443 | def predict(X): 444 | return X["x1"] + X["x2"] + X["x3"] 445 | 446 | with pytest.raises( 447 | ValueError, match="Degree 1 hybrid Kernel SHAP is not supported for p < 4" 448 | ): 449 | explain_any(predict, X, bg_X=X, method="kernel", how="h1") 450 | 451 | def test_h2_kernel_too_few_features(self): 452 | """Test that h2 kernel SHAP with p < 6 raises ValueError.""" 453 | X = pd.DataFrame( 454 | { 455 | "x1": [1, 2, 3, 4, 5], 456 | "x2": [4, 5, 6, 7, 8], 457 | "x3": [7, 8, 9, 10, 11], 458 | "x4": [10, 11, 12, 13, 14], 459 | "x5": [13, 14, 15, 16, 17], 460 | } 461 | ) 462 | 463 | def predict(X): 464 | return X["x1"] + X["x2"] + X["x3"] + X["x4"] + X["x5"] 465 | 466 | with pytest.raises( 467 | ValueError, match="Degree 2 hybrid Kernel SHAP is not supported for p < 6" 468 | ): 469 | explain_any(predict, X, bg_X=X, method="kernel", how="h2") 470 | 471 | @pytest.mark.parametrize("max_iter", [0, -1, -10, 0.5, "invalid"]) 472 | def test_invalid_max_iter(self, max_iter): 473 | """Test that invalid max_iter raises ValueError.""" 474 | X = data_regression() 475 | 476 | def predict(X): 477 | return X["x1"] + X["x2"] 478 | 479 | with pytest.raises( 480 | ValueError, match="max_iter must be a positive integer or None" 481 | ): 482 | explain_any(predict, X, max_iter=max_iter) 483 | -------------------------------------------------------------------------------- /src/lightshap/explainers/tests/test_explain_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from sklearn.datasets import make_classification, make_regression 5 | from sklearn.ensemble import RandomForestRegressor 6 | 7 | from lightshap import explain_tree 8 | 9 | 10 | def classification_data(): 11 | X, y = make_classification( 12 | n_samples=100, 13 | n_features=4, 14 | n_classes=3, 15 | n_clusters_per_class=1, 16 | random_state=1, 17 | ) 18 | return X, y 19 | 20 | 21 | def regression_data(): 22 | return make_regression(n_samples=100, n_features=4, random_state=1) 23 | 24 | 25 | class TestXGBoost: 26 | """Test XGBoost models with explain_tree.""" 27 | 28 | def test_xgboost_booster_regression_dmatrix(self): 29 | """Test XGBoost Booster with DMatrix for regression.""" 30 | xgb = pytest.importorskip("xgboost") 31 | 32 | X, y = regression_data() 33 | feature_names = [f"f{i}" for i in range(X.shape[1])] 34 | dtrain = xgb.DMatrix(X, label=y, feature_names=feature_names) 35 | 36 | model = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10) 37 | 38 | # Get SHAP values directly from XGBoost 39 | expected_shap = model.predict(dtrain, pred_contribs=True) 40 | 41 | # Test explain_tree 42 | expl = explain_tree(model, dtrain) 43 | 44 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 45 | np.testing.assert_allclose(expl.X, X) 46 | assert expl.baseline == expected_shap[0, -1] 47 | assert expl.feature_names == feature_names 48 | 49 | def test_xgboost_booster_regression_numpy(self): 50 | """Test XGBoost Booster with numpy array for regression.""" 51 | xgb = pytest.importorskip("xgboost") 52 | 53 | X, y = regression_data() 54 | dtrain = xgb.DMatrix(X, label=y) 55 | 56 | model = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10) 57 | 58 | # Get SHAP values directly from XGBoost 59 | expected_shap = model.predict(dtrain, pred_contribs=True) 60 | 61 | # Test explain_tree 62 | expl 
= explain_tree(model, X) 63 | 64 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 65 | np.testing.assert_allclose(expl.X, X) 66 | assert expl.baseline == expected_shap[0, -1] 67 | 68 | def test_xgboost_regressor_pandas(self): 69 | """Test XGBRegressor with pandas DataFrame.""" 70 | xgb = pytest.importorskip("xgboost") 71 | 72 | X, y = regression_data() 73 | feature_names = [f"f{i}" for i in range(X.shape[1])] 74 | X_df = pd.DataFrame(X, columns=feature_names) 75 | 76 | model = xgb.XGBRegressor(n_estimators=10, random_state=0) 77 | model.fit(X_df, y) 78 | 79 | # Get SHAP values directly from XGBoost 80 | booster = model.get_booster() 81 | dtest = xgb.DMatrix(X_df) 82 | expected_shap = booster.predict(dtest, pred_contribs=True) 83 | 84 | # Test explain_tree 85 | expl = explain_tree(model, X_df) 86 | 87 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 88 | pd.testing.assert_frame_equal(expl.X, X_df) 89 | assert expl.baseline == expected_shap[0, -1] 90 | assert expl.feature_names == feature_names 91 | 92 | def test_xgboost_classifier_multiclass(self): 93 | """Test XGBClassifier with 3 classes.""" 94 | xgb = pytest.importorskip("xgboost") 95 | 96 | X, y = classification_data() 97 | feature_names = [f"f{i}" for i in range(X.shape[1])] 98 | X_df = pd.DataFrame(X, columns=feature_names) 99 | 100 | model = xgb.XGBClassifier(n_estimators=10, random_state=0) 101 | model.fit(X_df, y) 102 | 103 | # Get SHAP values directly from XGBoost 104 | booster = model.get_booster() 105 | dtest = xgb.DMatrix(X_df) 106 | expected_shap = booster.predict(dtest, pred_contribs=True) 107 | 108 | # Test explain_tree 109 | expl = explain_tree(model, X_df) 110 | 111 | # For multiclass, expected shape is (n, K, p+1) -> (n, p, K) 112 | expected_values = expected_shap[:, :, :-1].swapaxes(1, 2) 113 | expected_baseline = expected_shap[0, :, -1] 114 | 115 | np.testing.assert_allclose(expl.shap_values, expected_values) 116 | pd.testing.assert_frame_equal(expl.X, X_df) 117 | np.testing.assert_allclose(expl.baseline, expected_baseline) 118 | assert expl.feature_names == feature_names 119 | 120 | def test_xgboost_rf_regressor(self): 121 | """Test XGBRFRegressor.""" 122 | xgb = pytest.importorskip("xgboost") 123 | 124 | X, y = regression_data() 125 | 126 | model = xgb.XGBRFRegressor(n_estimators=10, random_state=0) 127 | model.fit(X, y) 128 | 129 | # Get SHAP values directly from XGBoost 130 | booster = model.get_booster() 131 | dtest = xgb.DMatrix(X) 132 | expected_shap = booster.predict(dtest, pred_contribs=True) 133 | 134 | # Test explain_tree 135 | expl = explain_tree(model, X) 136 | 137 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 138 | np.testing.assert_allclose(expl.X, X) 139 | assert expl.baseline == expected_shap[0, -1] 140 | 141 | def test_xgboost_rf_classifier(self): 142 | """Test XGBRFClassifier with 3 classes.""" 143 | xgb = pytest.importorskip("xgboost") 144 | 145 | X, y = classification_data() 146 | 147 | model = xgb.XGBRFClassifier(n_estimators=10, random_state=0) 148 | model.fit(X, y) 149 | 150 | # Get SHAP values directly from XGBoost 151 | booster = model.get_booster() 152 | dtest = xgb.DMatrix(X) 153 | expected_shap = booster.predict(dtest, pred_contribs=True) 154 | 155 | # Test explain_tree 156 | expl = explain_tree(model, X) 157 | 158 | # For multiclass, expected shape is (n, K, p+1) -> (n, p, K) 159 | expected_values = expected_shap[:, :, :-1].swapaxes(1, 2) 160 | expected_baseline = expected_shap[0, :, -1] 161 | 162 | 
np.testing.assert_allclose(expl.shap_values, expected_values) 163 | np.testing.assert_allclose(expl.X, X) 164 | np.testing.assert_allclose(expl.baseline, expected_baseline) 165 | 166 | 167 | class TestLightGBM: 168 | """Test LightGBM models with explain_tree.""" 169 | 170 | def test_lightgbm_booster_numpy(self): 171 | """Test LightGBM Booster with numpy array.""" 172 | lgb = pytest.importorskip("lightgbm") 173 | 174 | X, y = regression_data() 175 | train_data = lgb.Dataset(X, label=y) 176 | 177 | model = lgb.train( 178 | {"objective": "regression", "verbose": -1}, train_data, num_boost_round=10 179 | ) 180 | 181 | # Get SHAP values directly from LightGBM 182 | expected_shap = model.predict(X, pred_contrib=True) 183 | 184 | # Test explain_tree 185 | expl = explain_tree(model, X) 186 | 187 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 188 | np.testing.assert_allclose(expl.X, X) 189 | assert expl.baseline == expected_shap[0, -1] 190 | assert expl.feature_names == model.feature_name() 191 | 192 | def test_lightgbm_regressor_pandas(self): 193 | """Test LGBMRegressor with pandas DataFrame.""" 194 | lgb = pytest.importorskip("lightgbm") 195 | 196 | X, y = regression_data() 197 | feature_names = [f"f{i}" for i in range(X.shape[1])] 198 | X_df = pd.DataFrame(X, columns=feature_names) 199 | 200 | model = lgb.LGBMRegressor(n_estimators=10, verbose=-1, random_state=0) 201 | model.fit(X_df, y) 202 | 203 | # Get SHAP values directly from LightGBM 204 | expected_shap = model.predict(X_df, pred_contrib=True) 205 | 206 | # Test explain_tree 207 | expl = explain_tree(model, X_df) 208 | 209 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 210 | pd.testing.assert_frame_equal(expl.X, X_df) 211 | assert expl.baseline == expected_shap[0, -1] 212 | assert expl.feature_names == feature_names 213 | 214 | def test_lightgbm_regressor_numpy(self): 215 | """Test LGBMRegressor with numpy array.""" 216 | lgb = pytest.importorskip("lightgbm") 217 | 218 | X, y = regression_data() 219 | 220 | model = lgb.LGBMRegressor(n_estimators=10, verbose=-1, random_state=0) 221 | model.fit(X, y) 222 | 223 | # Get SHAP values directly from LightGBM 224 | expected_shap = model.predict(X, pred_contrib=True) 225 | 226 | # Test explain_tree 227 | expl = explain_tree(model, X) 228 | 229 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 230 | np.testing.assert_allclose(expl.X, X) 231 | assert expl.baseline == expected_shap[0, -1] 232 | 233 | def test_lightgbm_classifier_multiclass(self): 234 | """Test LGBMClassifier with 3 classes.""" 235 | lgb = pytest.importorskip("lightgbm") 236 | 237 | X, y = classification_data() 238 | feature_names = [f"f{i}" for i in range(X.shape[1])] 239 | X_df = pd.DataFrame(X, columns=feature_names) 240 | 241 | model = lgb.LGBMClassifier(n_estimators=10, verbose=-1, random_state=0) 242 | model.fit(X_df, y) 243 | 244 | # Get SHAP values directly from LightGBM 245 | expected_shap = model.predict(X_df, pred_contrib=True) 246 | 247 | # Test explain_tree 248 | expl = explain_tree(model, X_df) 249 | 250 | # For multiclass, reshape from (n, K*(p+1)) to (n, K, p+1) then (n, p, K) 251 | n, p = X_df.shape 252 | expected_shap_reshaped = expected_shap.reshape(n, -1, p + 1) 253 | expected_values = expected_shap_reshaped[:, :, :-1].swapaxes(1, 2) 254 | expected_baseline = expected_shap_reshaped[0, :, -1] 255 | 256 | np.testing.assert_allclose(expl.shap_values, expected_values) 257 | pd.testing.assert_frame_equal(expl.X, X_df) 258 | 
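        # A hedged sketch (illustration only, with hypothetical sizes n=2, K=3,
        # p=4) of the reshape above: unlike XGBoost, LightGBM flattens
        # multiclass contributions to (n, K*(p+1)), class blocks side by side,
        # each block ending in that class's baseline.
        #   flat = np.zeros((2, 15))                      # (n, K*(p+1))
        #   flat.reshape(2, -1, 5).shape                  # -> (2, 3, 5) == (n, K, p+1)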
np.testing.assert_allclose(expl.baseline, expected_baseline) 259 | assert expl.feature_names == feature_names 260 | 261 | 262 | class TestCatBoost: 263 | """Test CatBoost models with explain_tree.""" 264 | 265 | def test_catboost_regressor_pandas(self): 266 | """Test CatBoostRegressor with pandas DataFrame.""" 267 | catboost = pytest.importorskip("catboost") 268 | 269 | X, y = regression_data() 270 | feature_names = [f"f{i}" for i in range(X.shape[1])] 271 | X_df = pd.DataFrame(X, columns=feature_names) 272 | 273 | model = catboost.CatBoostRegressor(iterations=10, verbose=False, random_state=0) 274 | model.fit(X_df, y) 275 | 276 | # Get SHAP values directly from CatBoost 277 | pool = catboost.Pool(X_df) 278 | expected_shap = model.get_feature_importance(data=pool, fstr_type="ShapValues") 279 | 280 | # Test explain_tree 281 | expl = explain_tree(model, X_df) 282 | 283 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 284 | pd.testing.assert_frame_equal(expl.X, X_df) 285 | assert expl.baseline == expected_shap[0, -1] 286 | assert expl.feature_names == model.feature_names_ 287 | 288 | def test_catboost_regressor_numpy(self): 289 | """Test CatBoostRegressor with numpy array.""" 290 | catboost = pytest.importorskip("catboost") 291 | 292 | X, y = regression_data() 293 | 294 | model = catboost.CatBoostRegressor(iterations=10, verbose=False, random_state=0) 295 | model.fit(X, y) 296 | 297 | # Get SHAP values directly from CatBoost 298 | pool = catboost.Pool(X, cat_features=model.get_cat_feature_indices()) 299 | expected_shap = model.get_feature_importance(data=pool, fstr_type="ShapValues") 300 | 301 | # Test explain_tree 302 | expl = explain_tree(model, X) 303 | 304 | np.testing.assert_allclose(expl.shap_values, expected_shap[:, :-1]) 305 | np.testing.assert_allclose(expl.X, X) 306 | assert expl.baseline == expected_shap[0, -1] 307 | assert expl.feature_names == model.feature_names_ 308 | 309 | def test_catboost_classifier_multiclass_pandas(self): 310 | """Test CatBoostClassifier with 3 classes and pandas DataFrame.""" 311 | catboost = pytest.importorskip("catboost") 312 | 313 | X, y = classification_data() 314 | feature_names = [f"f{i}" for i in range(X.shape[1])] 315 | X_df = pd.DataFrame(X, columns=feature_names) 316 | 317 | model = catboost.CatBoostClassifier( 318 | iterations=10, verbose=False, random_state=0 319 | ) 320 | model.fit(X_df, y) 321 | 322 | # Get SHAP values directly from CatBoost 323 | pool = catboost.Pool(X_df) 324 | expected_shap = model.get_feature_importance(data=pool, fstr_type="ShapValues") 325 | 326 | # Test explain_tree 327 | expl = explain_tree(model, X_df) 328 | 329 | # For multiclass, expected shape is (n, K, p+1) -> (n, p, K) 330 | expected_values = expected_shap[:, :, :-1].swapaxes(1, 2) 331 | expected_baseline = expected_shap[0, :, -1] 332 | 333 | np.testing.assert_allclose(expl.shap_values, expected_values) 334 | pd.testing.assert_frame_equal(expl.X, X_df) 335 | np.testing.assert_allclose(expl.baseline, expected_baseline) 336 | assert expl.feature_names == feature_names 337 | 338 | def test_catboost_classifier_multiclass_numpy(self): 339 | """Test CatBoostClassifier with 3 classes and numpy array.""" 340 | catboost = pytest.importorskip("catboost") 341 | 342 | X, y = classification_data() 343 | model = catboost.CatBoostClassifier( 344 | iterations=10, verbose=False, random_state=0 345 | ) 346 | model.fit(X, y) 347 | 348 | # Get SHAP values directly from CatBoost 349 | pool = catboost.Pool(X, cat_features=model.get_cat_feature_indices()) 350 | 
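        # Unlike the DataFrame variant above, a bare numpy array carries no
        # dtype information from which CatBoost could infer categoricals, so the
        # Pool mirrors the model's own cat feature indices (presumably an empty
        # list here, since the model was fit on purely numeric data).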
expected_shap = model.get_feature_importance(data=pool, fstr_type="ShapValues") 351 | 352 | # Test explain_tree 353 | expl = explain_tree(model, X) 354 | 355 | # For multiclass, expected shape is (n, K, p+1) -> (n, p, K) 356 | expected_values = expected_shap[:, :, :-1].swapaxes(1, 2) 357 | expected_baseline = expected_shap[0, :, -1] 358 | 359 | np.testing.assert_allclose(expl.shap_values, expected_values) 360 | np.testing.assert_allclose(expl.X, X) 361 | np.testing.assert_allclose(expl.baseline, expected_baseline) 362 | assert expl.feature_names == model.feature_names_ 363 | 364 | 365 | class TestErrorHandling: 366 | """Test error handling for unsupported models.""" 367 | 368 | def test_unsupported_model_raises_error(self): 369 | """Test that unsupported models raise TypeError.""" 370 | X, y = make_regression(n_samples=100, n_features=4, random_state=1) 371 | model = RandomForestRegressor(n_estimators=10, random_state=0) 372 | model.fit(X, y) 373 | 374 | with pytest.raises( 375 | TypeError, match="Model must be a LightGBM, XGBoost, or CatBoost model" 376 | ): 377 | explain_tree(model, X) 378 | 379 | def test_lgb_dataset_raises_error(self): 380 | """Test that LightGBM Dataset as X raises TypeError.""" 381 | lgb = pytest.importorskip("lightgbm") 382 | 383 | X, y = make_regression(n_samples=100, n_features=4, random_state=1) 384 | train_data = lgb.Dataset(X, label=y) 385 | 386 | model = lgb.train( 387 | {"objective": "regression", "verbose": -1}, train_data, num_boost_round=10 388 | ) 389 | 390 | with pytest.raises(TypeError, match="X cannot be a lgb.Dataset"): 391 | explain_tree(model, train_data) 392 | -------------------------------------------------------------------------------- /src/lightshap/explainers/tests/test_kernel_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from lightshap.explainers.kernel_utils import ( 6 | calculate_exact_prop, 7 | calculate_kernel_weights, 8 | calculate_kernel_weights_per_coalition_size, 9 | kernel_solver, 10 | one_kernelshap, 11 | precalculate_kernelshap, 12 | prepare_input_exact, 13 | prepare_input_hybrid, 14 | prepare_input_sampling, 15 | ) 16 | 17 | 18 | class TestCalculateKernelWeights: 19 | """Test Kernel SHAP weight calculation.""" 20 | 21 | def test_kernel_weights_basic(self): 22 | """Test kernel weights for basic case.""" 23 | p = 4 24 | weights = calculate_kernel_weights(p) 25 | 26 | # Test against R: kernelshap:::kernel_weights(4) # v 0.9.0 27 | expected = [0.4, 0.2, 0.4] 28 | np.testing.assert_array_almost_equal(weights, expected) 29 | 30 | def test_kernel_weights_per_coalition_size_degree_0(self): 31 | """Test kernel weights per coalition size for degree 0.""" 32 | p = 5 33 | degree = 0 34 | weights = calculate_kernel_weights_per_coalition_size(p, degree) 35 | 36 | # Test against R: kernelshap:::kernel_weights_per_coalition_size(5) # v 0.9.0 37 | expected = [0.3, 0.2, 0.2, 0.3] 38 | np.testing.assert_array_almost_equal(weights, expected) 39 | 40 | def test_kernel_weights_per_coalition_size_degree_1(self): 41 | """Test kernel weights per coalition size for degree 1.""" 42 | p = 6 43 | degree = 1 44 | weights = calculate_kernel_weights_per_coalition_size(p, degree) 45 | 46 | # Test against R: kernelshap::kernel_weights_per_coalition_size(6, 2:4) 47 | # v 0.9.0 48 | expected = [0.3461538, 0.3076923, 0.3461538] 49 | np.testing.assert_array_almost_equal(weights, expected, decimal=5) 50 | 51 | def 
test_kernel_weights_per_coalition_size_error(self): 52 | """Test error when p is too small for degree.""" 53 | with pytest.raises( 54 | ValueError, match="The number of features p must be at least" 55 | ): 56 | calculate_kernel_weights_per_coalition_size(p=3, degree=1) 57 | 58 | 59 | class TestCalculateExactProp: 60 | """Test exact proportion calculation.""" 61 | 62 | def test_exact_prop_zero_degree(self): 63 | """Test exact proportion for degree 0.""" 64 | result = calculate_exact_prop(p=5, degree=0) 65 | assert result == 0.0 66 | 67 | def test_exact_prop_positive_degree(self): 68 | """Test exact proportion for positive degree.""" 69 | p = 6 70 | degree = 2 71 | result = calculate_exact_prop(p, degree) 72 | 73 | # test against R kernelshap:::prop_exact(6, 2) # v 0.9.0 74 | assert np.isclose(result, 0.8540146) 75 | 76 | def test_exact_prop_half_features(self): 77 | """Test exact proportion when degree = p/2.""" 78 | p = 6 79 | degree = 3 80 | result = calculate_exact_prop(p, degree) 81 | 82 | # test against R kernelshap:::prop_exact(6, 3) # v 0.9.0 83 | assert np.isclose(result, 1.0) 84 | 85 | 86 | class TestPrepareInputs: 87 | """Test input preparation functions.""" 88 | 89 | def test_prepare_input_exact(self): 90 | """Test exact input preparation.""" 91 | p = 3 92 | result = prepare_input_exact(p) 93 | 94 | # Check required keys 95 | assert "Z" in result 96 | assert "w" in result 97 | assert "A" in result 98 | 99 | # Check dimensions 100 | assert result["Z"].shape == (2**p - 2, p) # All masks except empty and full 101 | assert result["w"].shape == (2**p - 2, 1) 102 | assert result["A"].shape == (p, p) 103 | 104 | # Weights 105 | assert np.isclose(result["w"].sum(), 1.0) 106 | assert (result["w"] > 0).all() 107 | 108 | # A should be symmetric 109 | np.testing.assert_array_almost_equal(result["A"], result["A"].T) 110 | 111 | def test_prepare_input_hybrid(self): 112 | """Test hybrid input preparation.""" 113 | p = 6 114 | degree = 2 115 | result = prepare_input_hybrid(p, degree) 116 | 117 | # Check required keys 118 | assert "Z" in result 119 | assert "w" in result 120 | assert "A" in result 121 | 122 | # Check dimensions 123 | expected_rows = 2 * (6 + 15) # 2 * (choose(6, 1) + choose(6, 2)) 124 | assert result["Z"].shape == (expected_rows, p) 125 | assert result["w"].shape == (expected_rows, 1) 126 | assert result["A"].shape == (p, p) 127 | 128 | # Weights 129 | assert (result["w"] > 0).all() 130 | 131 | # A should be symmetric 132 | np.testing.assert_array_almost_equal(result["A"], result["A"].T) 133 | 134 | def test_prepare_input_hybrid_errors(self): 135 | """Test hybrid input preparation error cases.""" 136 | with pytest.raises(ValueError, match="degree must be at least 1"): 137 | prepare_input_hybrid(p=5, degree=0) 138 | 139 | with pytest.raises(ValueError, match="p must be >= 2 \\* degree"): 140 | prepare_input_hybrid(p=3, degree=2) 141 | 142 | def test_prepare_input_sampling(self): 143 | """Test sampling input preparation.""" 144 | p = 5 145 | degree = 1 146 | start = 0 147 | rng = np.random.default_rng(0) 148 | 149 | result = prepare_input_sampling(p, degree, start, rng) 150 | 151 | # Check required keys 152 | assert "Z" in result 153 | assert "w" in result 154 | assert "A" in result 155 | 156 | # Check dimensions 157 | expected_rows = 2 * (p - 1 - 2 * degree) 158 | assert result["Z"].shape == (expected_rows, p) 159 | assert result["w"].shape == (expected_rows, 1) 160 | assert result["A"].shape == (p, p) 161 | 162 | # A should be symmetric 163 | 
np.testing.assert_array_almost_equal(result["A"], result["A"].T) 164 | 165 | @pytest.mark.parametrize("p", [4, 5, 6]) 166 | def test_prepare_input_sampling_approximately_exact(self, p): 167 | """Test that sampling A approximates exact A when repeating many times.""" 168 | 169 | rng = np.random.default_rng(0) 170 | nsim = 1000 171 | 172 | A_samp = np.zeros((p, p)) 173 | for j in range(nsim): 174 | A_samp += prepare_input_sampling(p, degree=0, start=j % p, rng=rng)["A"] 175 | A_samp /= nsim 176 | A_exact = prepare_input_exact(p)["A"] 177 | assert np.abs(A_exact - A_samp).max() < 0.01 178 | 179 | @pytest.mark.parametrize("p,degree", [(4, 1), (5, 1), (6, 1), (6, 2), (7, 2)]) 180 | def test_prepare_input_hybrid_approximately_exact(self, p, degree): 181 | """Test that hybrid A approximates exact A for different degrees when repeating 182 | many times. 183 | """ 184 | 185 | rng = np.random.default_rng(0) 186 | nsim = 1000 187 | 188 | A_sampling = np.zeros((p, p)) 189 | for j in range(nsim): 190 | A_sampling += prepare_input_sampling( 191 | p, degree=degree, start=j % p, rng=rng 192 | )["A"] 193 | A_hybrid_exact = prepare_input_hybrid(p, degree=degree)["A"] 194 | A_hybrid_sampling = A_sampling / nsim 195 | A_hybrid = A_hybrid_sampling + A_hybrid_exact 196 | A_exact = prepare_input_exact(p)["A"] 197 | assert np.abs(A_exact - A_hybrid).max() < 0.01 198 | 199 | @pytest.mark.parametrize("p,degree", [(4, 1), (5, 1), (6, 1), (6, 2), (7, 2)]) 200 | def test_prepare_input_hybrid_sampling_give_weight_one(self, p, degree): 201 | """Test hybrid and sampling weights sum to 1.""" 202 | start = 0 203 | rng = np.random.default_rng(0) 204 | 205 | sampling = prepare_input_sampling(p, degree, start, rng)["w"] 206 | hybrid = prepare_input_hybrid(p, degree)["w"] 207 | 208 | assert np.isclose(sampling.sum() + hybrid.sum(), 1.0) 209 | 210 | def test_prepare_input_sampling_error(self): 211 | """Test sampling input preparation error case.""" 212 | rng = np.random.default_rng(42) 213 | with pytest.raises( 214 | ValueError, match="The number of features p must be at least" 215 | ): 216 | prepare_input_sampling(p=3, degree=1, start=0, rng=rng) 217 | 218 | 219 | class TestKernelSolver: 220 | """Test kernel solver function.""" 221 | 222 | def test_kernel_solver_basic(self): 223 | """Test basic kernel solver functionality.""" 224 | # Simple well-conditioned system 225 | A = np.array([[1.0, 0.1], [0.1, 1.0]]) 226 | b = np.array([[1.0, 2.0], [3.0, 4.0]]) 227 | constraint = np.array([[4.0, 6.0]]) # Sum constraint 228 | 229 | result = kernel_solver(A, b, constraint) 230 | 231 | # Check against R 232 | # A = rbind(c(1.0, 0.1), c(0.1, 1.0)) 233 | # b = rbind(c(1.0, 2.0), c(3.0, 4.0)) 234 | # constraint = c(4.0, 6.0) 235 | # kernelshap:::solver(A, b, constraint) # v 0.9.0 236 | 237 | expected = [[0.8888889, 1.888889], [3.1111111, 4.1111111]] 238 | np.testing.assert_array_almost_equal(result, expected) 239 | 240 | def test_kernel_solver_singular_matrix(self): 241 | """Test kernel solver with singular matrix.""" 242 | # Singular matrix 243 | A = np.array([[1.0, 1.0], [1.0, 1.0]]) 244 | b = np.array([[1.0], [1.0]]) 245 | constraint = np.array([[2.0]]) 246 | 247 | with pytest.raises(ValueError, match="Matrix A is singular"): 248 | kernel_solver(A, b, constraint) 249 | 250 | 251 | class TestPrecalculateKernelShap: 252 | """Test precalculation for Kernel SHAP.""" 253 | 254 | def test_precalculate_exact(self): 255 | """Test precalculation for exact method.""" 256 | p = 3 257 | rng = np.random.default_rng(0) 258 | X = 
pd.DataFrame(rng.standard_normal((10, p)), columns=["A", "B", "C"]) 259 | 260 | result = precalculate_kernelshap(p, X, how="exact") 261 | 262 | # Check required keys 263 | required_keys = ["Z", "w", "A", "masks_exact_rep", "bg_exact_rep"] 264 | for key in required_keys: 265 | assert key in result 266 | 267 | # Check dimensions 268 | assert result["Z"].shape == (2**p - 2, p) 269 | assert result["masks_exact_rep"].shape == ((2**p - 2) * 10, p) 270 | assert result["bg_exact_rep"].shape == ((2**p - 2) * 10, p) 271 | 272 | @pytest.mark.parametrize("how", ["h1", "h2"]) 273 | def test_precalculate_hybrid(self, how): 274 | """Test precalculation for hybrid methods.""" 275 | p = 6 276 | rng = np.random.default_rng(0) 277 | X = pd.DataFrame(rng.standard_normal((10, p)), columns=list("ABCDEF")) 278 | 279 | result = precalculate_kernelshap(p, X, how=how) 280 | 281 | # Check required keys 282 | required_keys = [ 283 | "Z", 284 | "w", 285 | "A", 286 | "masks_exact_rep", 287 | "bg_exact_rep", 288 | "bg_sampling_rep", 289 | ] 290 | for key in required_keys: 291 | assert key in result 292 | 293 | def test_precalculate_sampling(self): 294 | """Test precalculation for sampling method.""" 295 | p = 5 296 | rng = np.random.default_rng(0) 297 | X = pd.DataFrame(rng.standard_normal((10, p)), columns=list("ABCDE")) 298 | 299 | result = precalculate_kernelshap(p, X, how="sampling") 300 | 301 | # Check required keys - sampling doesn't have exact parts 302 | assert "bg_sampling_rep" in result 303 | assert "masks_exact_rep" not in result 304 | assert "bg_exact_rep" not in result 305 | 306 | 307 | class TestOneKernelShap: 308 | """Test single row explanation with Kernel SHAP.""" 309 | 310 | def test_exact_kernelshap_single_row(self): 311 | """Test exact Kernel SHAP for a single row.""" 312 | # Set up test data 313 | rng = np.random.default_rng(0) 314 | X = pd.DataFrame(rng.standard_normal((20, 3)), columns=["A", "B", "C"]) 315 | bg_X = X.iloc[:10] 316 | bg_w = None 317 | 318 | # Simple linear model 319 | weights = np.array([1.0, 2.0, -1.0]) 320 | 321 | def predict_fn(X): 322 | return (X.values @ weights).reshape(-1, 1) 323 | 324 | v0 = predict_fn(bg_X).mean(keepdims=True) 325 | v1 = predict_fn(X) 326 | 327 | precalc = precalculate_kernelshap(p=3, bg_X=bg_X, how="exact") 328 | 329 | shap_values, se, converged, n_iter = one_kernelshap( 330 | i=0, 331 | predict=predict_fn, 332 | how="exact", 333 | bg_w=bg_w, 334 | v0=v0, 335 | max_iter=1, 336 | tol=0.01, 337 | random_state=0, 338 | X=X, 339 | v1=v1, 340 | precalc=precalc, 341 | collapse=np.array([False]), 342 | bg_n=10, 343 | ) 344 | 345 | # Check output shapes 346 | assert shap_values.shape == (3, 1) 347 | assert se.shape == (3, 1) 348 | assert converged 349 | assert n_iter == 1 350 | 351 | # Check efficiency property 352 | prediction_diff = v1[0] - v0[0] 353 | shap_sum = shap_values.sum(axis=0) 354 | np.testing.assert_array_almost_equal(shap_sum, prediction_diff) 355 | 356 | def test_sampling_kernelshap_single_row(self): 357 | """Test sampling Kernel SHAP for a single row.""" 358 | X_large = pd.DataFrame(np.random.randn(20, 5), columns=list("ABCDE")) 359 | bg_X_large = X_large.iloc[:10] 360 | 361 | # Extend the linear model 362 | weights_large = np.array([1.0, 2.0, -1.0, 0.5, 0.3]) 363 | 364 | def predict_fn_large(X): 365 | return (X.values @ weights_large).reshape(-1, 1) 366 | 367 | v0_large = predict_fn_large(bg_X_large).mean(keepdims=True) 368 | v1_large = predict_fn_large(X_large) 369 | 370 | precalc = precalculate_kernelshap(p=5, bg_X=bg_X_large, how="sampling") 
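        # A hedged aside on why the efficiency check below can stay tight even
        # for the sampling estimator: kernel_solver is fed an explicit sum
        # constraint (see TestKernelSolver above), so the estimated SHAP values
        # add up to v1 - v0 by construction; sampling noise only affects how
        # that total is split across features.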
371 | 372 | shap_values, se, converged, n_iter = one_kernelshap( 373 | i=0, 374 | predict=predict_fn_large, 375 | how="sampling", 376 | bg_w=None, 377 | v0=v0_large, 378 | max_iter=20, 379 | tol=0.01, 380 | random_state=0, 381 | X=X_large, 382 | v1=v1_large, 383 | precalc=precalc, 384 | collapse=np.array([False]), 385 | bg_n=10, 386 | ) 387 | 388 | # Check output shapes 389 | assert shap_values.shape == (5, 1) 390 | assert se.shape == (5, 1) 391 | assert isinstance(converged, bool) 392 | assert n_iter == 2 393 | 394 | # Check approximate efficiency (sampling might not be perfect) 395 | prediction_diff = v1_large[0] - v0_large[0] 396 | shap_sum = shap_values.sum(axis=0) 397 | np.testing.assert_array_almost_equal(shap_sum, prediction_diff) 398 | 399 | @pytest.mark.parametrize("how", ["h1", "h2"]) 400 | def test_hybrid_kernelshap_single_row(self, how): 401 | """Test hybrid Kernel SHAP for a single row.""" 402 | X_medium = pd.DataFrame(np.random.randn(20, 7), columns=list("ABCDEFG")) 403 | bg_X_medium = X_medium.iloc[:10] 404 | 405 | # Linear model for 7 features 406 | weights_medium = np.array([1.0, 2.0, -1.0, 0.5, 0.3, 0.1, 0.2]) 407 | 408 | def predict_fn_medium(X): 409 | return (X.values @ weights_medium).reshape(-1, 1) 410 | 411 | v0_medium = predict_fn_medium(bg_X_medium).mean(keepdims=True) 412 | v1_medium = predict_fn_medium(X_medium) 413 | 414 | precalc = precalculate_kernelshap(p=7, bg_X=bg_X_medium, how=how) 415 | 416 | shap_values, se, _, _ = one_kernelshap( 417 | i=0, 418 | predict=predict_fn_medium, 419 | how=how, 420 | bg_w=None, 421 | v0=v0_medium, 422 | max_iter=10, 423 | tol=0.01, 424 | random_state=0, 425 | X=X_medium, 426 | v1=v1_medium, 427 | precalc=precalc, 428 | collapse=np.array([False]), 429 | bg_n=10, 430 | ) 431 | 432 | # Check output shapes 433 | assert shap_values.shape == (7, 1) 434 | assert se.shape == (7, 1) 435 | 436 | # Check efficiency property (hybrid should be exact for linear models) 437 | prediction_diff = v1_medium[0] - v0_medium[0] 438 | shap_sum = shap_values.sum(axis=0) 439 | np.testing.assert_array_almost_equal(shap_sum, prediction_diff) 440 | -------------------------------------------------------------------------------- /src/lightshap/explainers/tests/test_parallel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.datasets import make_classification 4 | from sklearn.ensemble import RandomForestClassifier 5 | 6 | from lightshap import explain_any 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ("method", "how"), 11 | [ 12 | ("kernel", "exact"), 13 | ("kernel", "sampling"), 14 | ("kernel", "h1"), 15 | ("kernel", "h2"), 16 | ("permutation", "exact"), 17 | ("permutation", "sampling"), 18 | ], 19 | ) 20 | def test_parallel_vs_serial_methods(method, how): 21 | """Test that methods give consistent results in parallel mode.""" 22 | 23 | X, y = make_classification(n_samples=100, n_features=6, random_state=1) 24 | model = RandomForestClassifier(n_estimators=10, random_state=1) 25 | model.fit(X, y) 26 | 27 | X_small = X[0:5] 28 | 29 | # Serial execution 30 | result_serial = explain_any( 31 | model.predict_proba, 32 | X_small, 33 | bg_X=X, 34 | method=method, 35 | how=how, 36 | n_jobs=1, 37 | verbose=False, 38 | random_state=1, 39 | ) 40 | 41 | # Parallel execution 42 | result_parallel = explain_any( 43 | model.predict_proba, 44 | X_small, 45 | bg_X=X, 46 | method=method, 47 | how=how, 48 | n_jobs=2, 49 | verbose=False, 50 | random_state=1, 51 | ) 52 | 53 | 
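    # This equality presumes that per-row random streams are derived from
    # random_state independently of worker scheduling (an assumption about
    # lightshap internals, not verified here); if seeds were drawn from a
    # shared global generator instead, n_jobs=1 and n_jobs=2 could diverge.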
np.testing.assert_allclose( 54 | result_serial.shap_values, 55 | result_parallel.shap_values, 56 | err_msg=f"Results too different for method={method}, how={how}", 57 | ) 58 | -------------------------------------------------------------------------------- /src/lightshap/explainers/tests/test_permutation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from scipy.special import binom 5 | 6 | from lightshap.explainers._utils import generate_all_masks 7 | from lightshap.explainers.permutation_utils import ( 8 | calculate_shapley_weights, 9 | one_permshap, 10 | positions_for_exact, 11 | precalculate_permshap, 12 | ) 13 | 14 | 15 | class TestCalculateShapleyWeights: 16 | """Test Shapley weight calculation.""" 17 | 18 | def test_shapley_weights_single_value(self): 19 | """Test Shapley weights for single values.""" 20 | # For p=3, ell=1: weight = 1 / (binom(3,1) * (3-1)) = 1 / (3 * 2) = 1/6 21 | result = calculate_shapley_weights(p=3, ell=1) 22 | expected = 1.0 / (binom(3, 1) * (3 - 1)) 23 | assert np.isclose(result, expected) 24 | 25 | def test_shapley_weights_array(self): 26 | """Test Shapley weights for array input.""" 27 | p = 4 28 | ell = np.array([1, 2, 3]) 29 | result = calculate_shapley_weights(p, ell) 30 | expected = np.array( 31 | [ 32 | 1.0 / (binom(4, 1) * (4 - 1)), 33 | 1.0 / (binom(4, 2) * (4 - 2)), 34 | 1.0 / (binom(4, 3) * (4 - 3)), 35 | ] 36 | ) 37 | np.testing.assert_array_almost_equal(result, expected) 38 | 39 | 40 | class TestPositionsForExact: 41 | """Test position calculation for exact permutation SHAP.""" 42 | 43 | def test_positions_simple_case(self): 44 | """Test positions for a simple 3-feature case.""" 45 | mask = generate_all_masks(3) 46 | positions = positions_for_exact(mask) 47 | 48 | assert len(positions) == 3 49 | 50 | # Check first feature positions 51 | on_indices, off_indices = positions[0] 52 | assert len(on_indices) == len(off_indices) 53 | 54 | # Verify on/off relationship 55 | for on_idx, off_idx in zip(on_indices, off_indices, strict=False): 56 | assert mask[on_idx, 0] 57 | assert not mask[off_idx, 0] 58 | # Other features should be the same 59 | np.testing.assert_array_equal(mask[on_idx, 1:], mask[off_idx, 1:]) 60 | 61 | def test_positions_binary_relationship(self): 62 | """Test that positions maintain correct binary relationships.""" 63 | p = 4 64 | mask = generate_all_masks(p) 65 | positions = positions_for_exact(mask) 66 | 67 | for j in range(p): 68 | on_indices, off_indices = positions[j] 69 | power_of_two = 2 ** (p - 1 - j) 70 | 71 | for on_idx, off_idx in zip(on_indices, off_indices, strict=False): 72 | # The difference should be exactly 2^(p-1-j) 73 | assert on_idx - off_idx == power_of_two 74 | 75 | 76 | class TestPrecalculatePermshap: 77 | """Test precalculation for permutation SHAP.""" 78 | 79 | def test_exact_precalculation(self): 80 | """Test precalculation for exact method.""" 81 | n = 10 82 | p = 3 83 | rng = np.random.default_rng(seed=0) 84 | X = pd.DataFrame(rng.standard_normal(size=(n, p)), columns=["A", "B", "C"]) 85 | precalc = precalculate_permshap(p=p, bg_X=X, how="exact") 86 | 87 | # Check required keys 88 | required_keys = [ 89 | "masks_exact_rep", 90 | "bg_exact_rep", 91 | "shapley_weights", 92 | "positions", 93 | "bg_n", 94 | ] 95 | for key in required_keys: 96 | assert key in precalc 97 | 98 | # Check dimensions 99 | assert precalc["masks_exact_rep"].shape == ((2**p - 2) * n, p) 100 | assert precalc["bg_exact_rep"].shape 
== ((2**p - 2) * n, p) 101 | assert len(precalc["shapley_weights"]) == 2**p - 1 102 | assert len(precalc["positions"]) == p 103 | assert precalc["bg_n"] == n 104 | 105 | def test_sampling_precalculation(self): 106 | """Test precalculation for sampling method.""" 107 | n = 10 108 | p = 4 109 | rng = np.random.default_rng(seed=0) 110 | X = pd.DataFrame(rng.standard_normal(size=(n, p)), columns=["A", "B", "C", "D"]) 111 | precalc = precalculate_permshap(p=p, bg_X=X, how="sampling") 112 | 113 | # Check required keys 114 | required_keys = [ 115 | "masks_balanced_rep", 116 | "bg_balanced_rep", 117 | "bg_sampling_rep", 118 | "bg_n", 119 | ] 120 | for key in required_keys: 121 | assert key in precalc 122 | 123 | # Check dimensions 124 | assert precalc["masks_balanced_rep"].shape == (8 * n, p) # 2*p * bg_n 125 | assert precalc["bg_balanced_rep"].shape == (8 * n, p) 126 | assert precalc["bg_sampling_rep"].shape == (2 * n, p) # 2*(p-3) * bg_n 127 | 128 | def test_sampling_error_small_p(self): 129 | """Test that sampling raises error for p < 4.""" 130 | rng = np.random.default_rng(seed=0) 131 | X = pd.DataFrame(rng.standard_normal(size=(10, 3)), columns=["A", "B", "C"]) 132 | 133 | with pytest.raises( 134 | ValueError, match="sampling method not implemented for p < 4" 135 | ): 136 | precalculate_permshap(p=3, bg_X=X, how="sampling") 137 | 138 | def test_precalculation_with_numpy(self): 139 | """Test precalculation works with numpy arrays.""" 140 | rng = np.random.default_rng(seed=0) 141 | X = rng.standard_normal(size=(10, 3)) 142 | precalc = precalculate_permshap(p=3, bg_X=X, how="exact") 143 | 144 | assert "masks_exact_rep" in precalc 145 | assert isinstance(precalc["masks_exact_rep"], np.ndarray) 146 | 147 | 148 | class TestOnePermshap: 149 | """Test single row explanation.""" 150 | 151 | def test_exact_permshap_single_row(self): 152 | """Test exact permutation SHAP for a single row.""" 153 | # Set up test data 154 | rng = np.random.default_rng(seed=0) 155 | X = pd.DataFrame(rng.standard_normal(size=(20, 3)), columns=["A", "B", "C"]) 156 | bg_X = X.iloc[:10] 157 | bg_w = None 158 | 159 | # Simple linear model 160 | weights = np.array([1.0, 2.0, -1.0]) 161 | 162 | def predict_fn(X): 163 | return (X.values @ weights).reshape(-1, 1) 164 | 165 | v0 = predict_fn(bg_X).mean(keepdims=True) 166 | v1 = predict_fn(X) 167 | 168 | precalc = precalculate_permshap(p=3, bg_X=bg_X, how="exact") 169 | 170 | shap_values, se, converged, n_iter = one_permshap( 171 | i=0, 172 | predict=predict_fn, 173 | how="exact", 174 | bg_w=bg_w, 175 | v0=v0, 176 | max_iter=1, 177 | tol=0.01, 178 | random_state=0, 179 | X=X, 180 | v1=v1, 181 | precalc=precalc, 182 | collapse=np.array([False]), 183 | bg_n=10, 184 | ) 185 | 186 | # Check output shapes 187 | assert shap_values.shape == (3, 1) 188 | assert se.shape == (3, 1) 189 | assert converged 190 | assert n_iter == 1 191 | 192 | # Check efficiency property: sum of SHAP values = prediction - baseline 193 | prediction_diff = v1[0] - v0[0] 194 | shap_sum = shap_values.sum(axis=0) 195 | np.testing.assert_array_almost_equal(shap_sum, prediction_diff) 196 | 197 | def test_sampling_permshap_single_row(self): 198 | """Test sampling permutation SHAP for a single row.""" 199 | X_large = pd.DataFrame(np.random.randn(20, 4), columns=["A", "B", "C", "D"]) 200 | bg_X_large = X_large.iloc[:10] 201 | 202 | # Linear model 203 | weights_large = np.array([1.0, 2.0, -1.0, 0.5]) 204 | 205 | def predict_fn_large(X): 206 | return (X.values @ weights_large).reshape(-1, 1) 207 | 208 | v0_large = 
predict_fn_large(bg_X_large).mean(keepdims=True) 209 | v1_large = predict_fn_large(X_large) 210 | 211 | precalc = precalculate_permshap(p=4, bg_X=bg_X_large, how="sampling") 212 | 213 | shap_values, se, converged, n_iter = one_permshap( 214 | i=0, 215 | predict=predict_fn_large, 216 | how="sampling", 217 | bg_w=None, 218 | v0=v0_large, 219 | max_iter=10, 220 | tol=0.01, 221 | random_state=0, 222 | X=X_large, 223 | v1=v1_large, 224 | precalc=precalc, 225 | collapse=np.array([False]), 226 | bg_n=10, 227 | ) 228 | 229 | # Check output shapes 230 | assert shap_values.shape == (4, 1) 231 | assert se.shape == (4, 1) 232 | assert isinstance(converged, bool) 233 | assert n_iter == 2 # the first two iterations return identical values 234 | 235 | # Check approximate efficiency (sampling might not be perfect) 236 | prediction_diff = v1_large[0] - v0_large[0] 237 | shap_sum = shap_values.sum(axis=0) 238 | np.testing.assert_array_almost_equal(shap_sum, prediction_diff, decimal=1) 239 | 240 | def test_output_shapes_multioutput(self): 241 | """Test output shapes for multi-output model.""" 242 | # Set up test data for this specific test 243 | rng = np.random.default_rng(seed=0) 244 | X = pd.DataFrame(rng.standard_normal(size=(20, 3)), columns=["A", "B", "C"]) 245 | bg_X = X.iloc[:10] 246 | weights = np.array([1.0, 2.0, -1.0]) 247 | 248 | def predict_multioutput(X): 249 | X = X.values 250 | # Two outputs: linear combination and squared sum 251 | out1 = (X @ weights).reshape(-1, 1) 252 | out2 = (X**2).sum(axis=1, keepdims=True) 253 | return np.hstack([out1, out2]) 254 | 255 | v0_multi = predict_multioutput(bg_X).mean(axis=0, keepdims=True) 256 | v1_multi = predict_multioutput(X) 257 | 258 | precalc = precalculate_permshap(p=3, bg_X=bg_X, how="exact") 259 | 260 | shap_values, se, _, _ = one_permshap( 261 | i=0, 262 | predict=predict_multioutput, 263 | how="exact", 264 | bg_w=None, 265 | v0=v0_multi, 266 | max_iter=10, 267 | tol=0.01, 268 | random_state=0, 269 | X=X, 270 | v1=v1_multi, 271 | precalc=precalc, 272 | collapse=np.array([False]), 273 | bg_n=10, 274 | ) 275 | 276 | # Check output shapes for 2 outputs 277 | assert shap_values.shape == (3, 2) 278 | assert se.shape == (3, 2) 279 | -------------------------------------------------------------------------------- /src/lightshap/explanation/__init__.py: -------------------------------------------------------------------------------- 1 | from .explanation import Explanation 2 | 3 | __all__ = ["Explanation"] 4 | -------------------------------------------------------------------------------- /src/lightshap/explanation/_utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | from matplotlib.colors import BoundaryNorm, ListedColormap 7 | from scipy.stats import gaussian_kde 8 | 9 | 10 | def safe_to_float(x): 11 | """ 12 | Convert a pandas Series to float, handling categorical items. 13 | The function preserves missing values. 14 | 15 | Parameters 16 | ---------- 17 | x : pd.Series 18 | Input data to convert. 19 | 20 | Returns 21 | ------- 22 | pd.Series 23 | Converted Series with float values. 24 | """ 25 | is_numeric = pd.api.types.is_numeric_dtype(x) 26 | is_categorical = isinstance(x.dtype, pd.CategoricalDtype) 27 | 28 | if not is_numeric and not is_categorical: 29 | msg = f"Unsupported dtype {x.dtype} for Series."
30 | raise TypeError(msg) 31 | elif is_categorical: 32 | x = x.cat.codes.replace(-1, np.nan) 33 | 34 | return x.astype(float) 35 | 36 | 37 | def min_max_scale(x): 38 | """ 39 | Scale the input data to the range [0, 1] using min-max scaling. 40 | 41 | Parameters 42 | ---------- 43 | x : pd.Series 44 | Input data to scale. 45 | 46 | Returns 47 | ------- 48 | pd.Series 49 | Scaled data in the range [0, 1]. 50 | """ 51 | 52 | if x.isna().all(): 53 | return x 54 | 55 | xmin, xmax = x.min(), x.max() 56 | 57 | # Constant column (retaining missing values) 58 | if xmax == xmin: 59 | return x * 0.0 + 0.5 60 | 61 | return (x - xmin) / (xmax - xmin) 62 | 63 | 64 | def halton(i, base=2): 65 | """ 66 | Generate the i-th element of the Halton sequence. 67 | 68 | Source: https://en.wikipedia.org/wiki/Halton_sequence 69 | 70 | Parameters 71 | ---------- 72 | i : int 73 | Index (1-based) 74 | base : int, optional 75 | Base for the sequence, default is 2 76 | 77 | Returns 78 | ------- 79 | float 80 | The i-th value in the sequence 81 | """ 82 | result = 0 83 | f = 1 84 | while i > 0: 85 | f /= base 86 | result += f * (i % base) 87 | i = i // base 88 | return result 89 | 90 | 91 | def halton_sequence(n, base=2): 92 | """ 93 | Generate the first n elements of the Halton sequence. 94 | 95 | Parameters 96 | ---------- 97 | n : int 98 | Number of elements to generate 99 | base : int, optional 100 | Base for the sequence, default is 2 101 | 102 | Returns 103 | ------- 104 | numpy.ndarray 105 | Array of the first n elements in the Halton sequence 106 | """ 107 | return np.array([halton(i + 1, base) for i in range(n)]) 108 | 109 | 110 | def beeswarm_jitter(values, halton_vals=None): 111 | """ 112 | Compute jitter values for beeswarm plot based on density. 113 | 114 | Parameters 115 | ---------- 116 | values : array-like 117 | Values to create jitter for 118 | halton_vals : array-like, optional 119 | Precomputed Halton sequence for jittering 120 | 121 | Returns 122 | ------- 123 | numpy.ndarray 124 | Jitter values for each point 125 | """ 126 | if len(values) == 1: 127 | return np.zeros(1, dtype=float) 128 | 129 | # Density at each point 130 | try: 131 | kde = gaussian_kde(values) 132 | density = kde(values) 133 | density_normalized = density / density.max() 134 | except ValueError: 135 | # Uniform if KDE fails 136 | density_normalized = np.ones_like(values, dtype=float) 137 | 138 | # Quasi-random values based on ranks 139 | if halton_vals is None: 140 | halton_vals = halton_sequence(len(values)) 141 | ranks = np.argsort(np.argsort(values)) 142 | shifts = halton_vals[ranks] - 0.5 143 | 144 | # Scale shifts by density 145 | return 2 * shifts * density_normalized 146 | 147 | 148 | def plot_layout(p): 149 | """ 150 | Determine plot layout based on the number of plots 151 | 152 | Parameters 153 | ---------- 154 | p : int 155 | Number of plots 156 | 157 | Returns 158 | ------- 159 | tuple 160 | Number of rows and columns for the plot layout 161 | """ 162 | if p <= 3: 163 | return 1, p 164 | elif p <= 6: 165 | return (p + 1) // 2, 2 166 | elif p <= 12: 167 | return (p + 2) // 3, 3 168 | else: 169 | return (p + 3) // 4, 4 170 | 171 | 172 | def _check_features(features, all_features, name="features"): 173 | """ 174 | Check and validate feature names. 175 | 176 | Parameters 177 | ---------- 178 | features : iterable 179 | Feature names to check. 180 | all_features : iterable 181 | All available feature names. 182 | name : str, optional 183 | Name of the feature set (for error messages).
184 | 185 | Returns 186 | ------- 187 | iterable 188 | Validated feature names. 189 | """ 190 | if features is None: 191 | return all_features 192 | elif isinstance(features, Iterable) and not isinstance(features, str): 193 | if not set(features).issubset(all_features): 194 | msg = f"Some {name} are not present in the data." 195 | raise ValueError(msg) 196 | else: 197 | msg = f"{name} must be an iterable of names, or None." 198 | raise TypeError(msg) 199 | 200 | return features 201 | 202 | 203 | def _safe_cor(x, y): 204 | """ 205 | Compute Pearson correlation coefficient between two arrays. 206 | 207 | Parameters 208 | ---------- 209 | x : array-like 210 | First input array. 211 | y : array-like 212 | Second input array. 213 | 214 | Returns 215 | ------- 216 | float 217 | The Pearson correlation coefficient, or 0 if not computable. 218 | """ 219 | ok = np.isfinite(x) & np.isfinite(y) 220 | if np.count_nonzero(ok) < 2: 221 | return 0.0 222 | x, y = x[ok], y[ok] 223 | x_sd, y_sd = x.std(ddof=1), y.std(ddof=1) 224 | 225 | if x_sd <= 1e-7 or y_sd <= 1e-7: 226 | return 0.0 227 | 228 | return np.corrcoef(x, y)[0, 1] 229 | 230 | 231 | def get_text_bbox(ax): 232 | """Get the bounding box of the text labels in the plot in the order 233 | x left, x right, y bottom, y top. 234 | 235 | Parameters 236 | ---------- 237 | ax : matplotlib.axes.Axes 238 | The axes object containing the text labels. 239 | 240 | Returns 241 | ------- 242 | tuple 243 | The bounding box coordinates of the text labels (x left, x right, y bottom, y top). 244 | """ 245 | renderer = ax.get_figure().canvas.get_renderer() 246 | left, right, bottom, top = [], [], [], [] 247 | for text in ax.texts: 248 | text_bbox = text.get_window_extent(renderer=renderer) 249 | text_bbox_data = text_bbox.transformed(ax.transData.inverted()) 250 | 251 | # Might be simplified 252 | left.append(text_bbox_data.x0) 253 | right.append(text_bbox_data.x1) 254 | bottom.append(text_bbox_data.y0) 255 | top.append(text_bbox_data.y1) 256 | return min(left), max(right), min(bottom), max(top) 257 | 258 | 259 | def color_axis_info(z, cmap, max_color_labels, max_color_label_length, **kwargs): 260 | """ 261 | Prepare color axis information for a given color feature. 262 | 263 | Helper function of plot.scatter(). 264 | 265 | Parameters 266 | ---------- 267 | z : pd.Series 268 | The color feature values. 269 | cmap : str or matplotlib colormap 270 | The colormap to use. 271 | max_color_labels : int 272 | The maximum number of color labels to display. 273 | max_color_label_length : int 274 | The maximum length of color labels. 275 | 276 | Returns 277 | ------- 278 | dict 279 | A dictionary containing color axis information.
280 | """ 281 | out = {} 282 | if isinstance(z.dtype, pd.CategoricalDtype): 283 | out["categorical"] = True 284 | out["mapping"] = dict(enumerate(z.cat.categories)) 285 | z = z.cat.codes.replace(-1, np.nan) 286 | n = out["n_colors"] = len(out["mapping"]) 287 | base_colors = plt.get_cmap(cmap, n)(np.linspace(0, 1, n)) 288 | out["cmap"] = ListedColormap(base_colors) 289 | out["norm"] = BoundaryNorm(np.arange(-0.5, n + 0.5), n) 290 | 291 | # Reduce number of labels on color bar 292 | if n > max_color_labels: 293 | step = int(np.ceil(n / max_color_labels)) 294 | for i, key in enumerate(out["mapping"]): 295 | if 0 < i < n - 1 and i % step > 0: 296 | out["mapping"][key] = "" 297 | 298 | # Truncate long labels 299 | for key, value in out["mapping"].items(): 300 | if len(value) > max_color_label_length: 301 | out["mapping"][key] = value[:max_color_label_length] 302 | else: 303 | out["cmap"] = plt.get_cmap(cmap) 304 | out["categorical"] = False 305 | 306 | out["values"] = z 307 | out["cmap"].set_bad("gray", alpha=kwargs.get("alpha", 1.0)) 308 | 309 | return out 310 | -------------------------------------------------------------------------------- /src/lightshap/explanation/explanation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from lightshap.utils import get_dataclass 7 | 8 | from ._utils import _check_features, _safe_cor, safe_to_float 9 | from .explanationplotter import ExplanationPlotter 10 | 11 | 12 | class Explanation: 13 | """SHAP Explanation object that encapsulates model explanations. 14 | 15 | The Explanation class provides a comprehensive framework for storing, analyzing, 16 | and visualizing SHAP (SHapley Additive exPlanations) values, which help interpret 17 | machine learning model predictions. This class supports both single-output and 18 | multi-output models, handles feature importance analysis, and offers various 19 | visualization methods. 20 | 21 | The class stores SHAP values along with the associated data points, baseline 22 | values, and optionally includes standard errors, convergence indicators, and 23 | iteration counts for approximation methods. It provides methods to select subsets 24 | of the data, calculate feature importance, and create various visualizations 25 | including waterfall plots, dependence plots, summary plots, and importance plots. 26 | 27 | Parameters 28 | ---------- 29 | shap_values : numpy.ndarray 30 | numpy.ndarray of shape (n_obs, n_features) for single-output models, and 31 | of shape (n_obs, n_features, n_outputs) for multi-output models. 32 | 33 | X : pandas.DataFrame, polars.DataFrame, numpy.ndarray 34 | Feature values corresponding to `shap_values`. The columns must be in the 35 | same order. 36 | 37 | baseline : float or numpy.ndarray, default=0.0 38 | The baseline value(s) representing the expected model output when all 39 | features are missing. For single-output models, either a scalar or a 40 | numpy.ndarray of shape (1, ). 41 | For multi-output models, an array of shape (n_outputs,). 42 | 43 | feature_names : list or None, default=None 44 | Feature names. If None and X is a pandas DataFrame, column names 45 | are used. If None and X is not a DataFrame, default names are generated. 46 | 47 | output_names : list or None, default=None 48 | Names of the outputs for multi-output models. If None, default names are 49 | generated. 
50 | 51 | standard_errors : numpy.ndarray or None, default=None 52 | Standard errors of the SHAP values. Must have the same shape as shap_values, 53 | or None. Only relevant for approximate methods. 54 | 55 | converged : numpy.ndarray or None, default=None 56 | Boolean array indicating the convergence status per observation. Only 57 | relevant for approximate methods. 58 | 59 | n_iter : numpy.ndarray or None, default=None 60 | Number of iterations per observation. Only relevant for approximate methods. 61 | 62 | Attributes 63 | ---------- 64 | shap_values : numpy.ndarray 65 | numpy.ndarray of shape (n_obs, n_features) for single-output models, and 66 | of shape (n_obs, n_features, n_outputs) for multi-output models. 67 | 68 | X : pandas.DataFrame 69 | The feature values corresponding to `shap_values`. Note that the index 70 | is reset to the values 0 to n_obs - 1. 71 | 72 | baseline : numpy.ndarray 73 | Baseline value(s). Has shape (1, ) for single-output models, and 74 | shape (n_outputs, ) for multi-output models. 75 | 76 | standard_errors : numpy.ndarray or None 77 | Standard errors of the SHAP values of the same shape as `shap_values` 78 | (if available). 79 | 80 | converged : numpy.ndarray or None 81 | Convergence indicators of shape (n_obs, ) (if available). 82 | 83 | n_iter : numpy.ndarray or None 84 | Iteration counts of shape (n_obs, ) (if available). 85 | 86 | shape : tuple 87 | Shape of `shap_values`. 88 | 89 | ndim : int 90 | Number of dimensions of the SHAP values (2 or 3). 91 | 92 | feature_names : list 93 | Feature names. 94 | 95 | output_names : list or None 96 | Output names for multi-output models. None for single-output models. 97 | 98 | Examples 99 | -------- 100 | >>> import numpy as np 101 | >>> import pandas as pd 102 | >>> from lightshap import Explanation 103 | >>> 104 | >>> # Example data 105 | >>> X = pd.DataFrame({'feature1': [1, 2, 3], 'feature2': [4, 5, 6]}) 106 | >>> shap_values = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) 107 | >>> 108 | >>> explanation = Explanation(shap_values, X, baseline=0.5) 109 | >>> 110 | >>> # Waterfall plot of first observation 111 | >>> explanation.plot.waterfall(row_id=0) 112 | """ 113 | 114 | def __init__( 115 | self, 116 | shap_values, 117 | X, 118 | baseline=0.0, 119 | feature_names=None, 120 | output_names=None, 121 | standard_errors=None, 122 | converged=None, 123 | n_iter=None, 124 | ): 125 | if not isinstance(shap_values, np.ndarray) or shap_values.shape[0] < 1: 126 | msg = "SHAP values must be a numpy array with at least one row." 127 | raise TypeError(msg) 128 | 129 | n = shap_values.shape[0] 130 | 131 | if standard_errors is not None and standard_errors.shape != shap_values.shape: 132 | msg = ( 133 | f"Shape {standard_errors.shape} of standard_errors does not match " 134 | f"shape {shap_values.shape} of SHAP values." 135 | ) 136 | raise ValueError(msg) 137 | if converged is not None and converged.shape[0] != n: 138 | msg = ( 139 | f"Length {converged.shape[0]} of converged does not match " 140 | f"number {n} of rows of SHAP values." 141 | ) 142 | raise ValueError(msg) 143 | if n_iter is not None and n_iter.shape[0] != n: 144 | msg = ( 145 | f"Length {n_iter.shape[0]} of n_iter does not match " 146 | f"number {n} of rows of SHAP values." 
147 | ) 148 | raise ValueError(msg) 149 | 150 | # Drop third dimension of shap_values and standard_errors if unnecessary 151 | if shap_values.ndim == 3 and shap_values.shape[2] == 1: 152 | shap_values = shap_values.reshape(n, -1) 153 | if standard_errors is not None: 154 | standard_errors = standard_errors.reshape(n, -1) 155 | elif not 2 <= shap_values.ndim <= 3: 156 | msg = "SHAP values must be 2D or 3D." 157 | raise ValueError(msg) 158 | 159 | # Baseline should have shape (K, ) 160 | if not isinstance(baseline, np.ndarray): 161 | baseline = np.asarray(baseline) 162 | baseline = baseline.flatten() # turn into 1D array 163 | K = 1 if shap_values.ndim == 2 else shap_values.shape[2] 164 | if baseline.shape[0] != K: 165 | msg = ( 166 | f"Length {len(baseline)} of baseline does not match " 167 | f"number {K} of output dimensions." 168 | ) 169 | raise ValueError(msg) 170 | 171 | self.shap_values = shap_values 172 | self.baseline = baseline 173 | self.standard_errors = standard_errors 174 | self.converged = converged 175 | self.n_iter = n_iter 176 | 177 | # Some attributes for convenience 178 | self.shape = shap_values.shape 179 | self.ndim = shap_values.ndim 180 | 181 | self.set_output_names(output_names) 182 | 183 | # Setting X also sets feature names 184 | self.set_X(X) 185 | if feature_names is not None: 186 | self.set_feature_names(feature_names) 187 | 188 | @property 189 | def plot(self): 190 | """ 191 | Accessor for plotting methods. 192 | 193 | Examples 194 | -------- 195 | >>> explanation.plot.bar() 196 | >>> explanation.plot.waterfall(row_id=0) 197 | >>> explanation.plot.beeswarm() 198 | >>> explanation.plot.scatter(features=["feature1", "feature2"]) 199 | """ 200 | return ExplanationPlotter(self) 201 | 202 | def __repr__(self): 203 | # Get shapes and sample sizes for display 204 | n = self.shape[0] 205 | n_display = min(2, n) 206 | 207 | out = "SHAP Explanation\n\n" 208 | 209 | # SHAP values section 210 | out += f"SHAP values {self.shape}, first {n_display}:\n" 211 | out += f"{self.shap_values[:n_display]!r}\n\n" 212 | 213 | # Data section 214 | out += f"X, first {n_display} rows:\n" 215 | out += str(self.X.head(2)) 216 | 217 | return out 218 | 219 | def __len__(self): 220 | return self.shape[0] 221 | 222 | def filter(self, indices): 223 | """ 224 | Filter the SHAP values by array-like. 225 | 226 | Parameters 227 | ---------- 228 | indices : array-like 229 | Integer or boolean array-like to filter the SHAP values and data. 230 | 231 | Returns 232 | ------- 233 | Explanation 234 | A new Explanation object with filtered SHAP values and data. 235 | """ 236 | if not isinstance(indices, np.ndarray): 237 | indices = np.asarray(indices) 238 | 239 | if not (np.issubdtype(indices.dtype, np.integer) or indices.dtype == np.bool_): 240 | msg = "indices must be an integer or boolean array-like." 241 | raise TypeError(msg) 242 | 243 | values = self.shap_values[indices] 244 | se = self.standard_errors[indices] if self.standard_errors is not None else None 245 | X = self.X[indices] if indices.dtype == np.bool_ else self.X.iloc[indices] 246 | 247 | return Explanation( 248 | shap_values=values, 249 | X=X, 250 | baseline=self.baseline, 251 | output_names=self.output_names, 252 | standard_errors=se, 253 | converged=self.converged[indices] if self.converged is not None else None, 254 | n_iter=self.n_iter[indices] if self.n_iter is not None else None, 255 | ) 256 | 257 | def select_output(self, index): 258 | """ 259 | Select specific output dimension from the SHAP values. 
Useful if 260 | predictions are multi-output. 261 | 262 | Parameters 263 | ---------- 264 | index : Int or str 265 | Index or name of the output dimension to select. 266 | 267 | Returns 268 | ------- 269 | Explanation 270 | A new Explanation object with only the selected output. 271 | """ 272 | if self.ndim != 3: 273 | return self 274 | 275 | if self.output_names is not None and isinstance(index, str): 276 | index = self.output_names.index(index) 277 | elif not isinstance(index, int): 278 | msg = "index must be an integer or string." 279 | raise TypeError(msg) 280 | 281 | if self.standard_errors is not None: 282 | se = self.standard_errors[:, :, index] 283 | else: 284 | se = None 285 | 286 | return Explanation( 287 | shap_values=self.shap_values[:, :, index], 288 | X=self.X, 289 | baseline=self.baseline[[index]], # need to keep np.array 290 | output_names=None, 291 | standard_errors=se, 292 | converged=self.converged, 293 | n_iter=self.n_iter, 294 | ) 295 | 296 | def set_feature_names(self, feature_names): 297 | """ 298 | Set feature names of 'X'. 299 | 300 | Parameters 301 | ---------- 302 | feature_names : list or array-like 303 | Feature names to set. 304 | """ 305 | p = self.X.shape[1] 306 | if len(feature_names) != p: 307 | msg = ( 308 | f"Length {len(feature_names)} of feature_names does not match " 309 | f"number {p} of columns in X." 310 | ) 311 | raise ValueError(msg) 312 | if not isinstance(feature_names, list): 313 | feature_names = list(feature_names) 314 | 315 | self.X.columns = self.feature_names = feature_names 316 | 317 | return self 318 | 319 | def set_output_names(self, output_names=None): 320 | """ 321 | If predictions are multi-output, set names of the additional dimension. 322 | 323 | Parameters 324 | ---------- 325 | output_names : list or array-like, optional 326 | Output names to set. 327 | """ 328 | if self.ndim == 3: 329 | K = self.shap_values.shape[2] 330 | if output_names is None: 331 | output_names = list(range(K)) 332 | elif len(output_names) != K: 333 | msg = ( 334 | f"Length {len(output_names)} of output_names does not match " 335 | f"number {K} of outputs in SHAP values." 336 | ) 337 | raise ValueError(msg) 338 | else: 339 | output_names = None 340 | 341 | if output_names is not None and not isinstance(output_names, list): 342 | output_names = list(output_names) 343 | 344 | self.output_names = output_names 345 | 346 | return self 347 | 348 | def set_X(self, X): 349 | """Set X and self.feature_names. 350 | 351 | `X` is converted to pandas. String and object columns are converted to 352 | categoricals, while numeric columns are left unchanged. Other column types 353 | will raise a TypeError. 354 | 355 | Parameters 356 | ---------- 357 | X : numpy.ndarray, pandas.DataFrame or polars.DataFrame 358 | New data to set. Columns must match the order of SHAP values. 359 | """ 360 | if X.shape != self.shap_values.shape[:2]: 361 | msg = ( 362 | f"Shape {X.shape} of X does not match shape " 363 | f"{self.shap_values.shape[:2]} of SHAP values." 364 | ) 365 | raise ValueError(msg) 366 | 367 | xclass = get_dataclass(X) 368 | if xclass == "np": 369 | if hasattr(self, "feature_names") and self.feature_names is not None: 370 | X = pd.DataFrame(X, columns=self.feature_names) 371 | else: 372 | X = pd.DataFrame(X) 373 | elif xclass == "pl": 374 | try: 375 | X = X.to_pandas() 376 | except Exception as e: 377 | msg = ( 378 | "Failed to convert polars DataFrame to pandas. 
" 379 | "Make sure polars is properly installed: pip install polars" 380 | ) 381 | raise ImportError(msg) from e 382 | else: # pd 383 | X = X.reset_index(drop=True) 384 | 385 | # Columns will stay numeric/boolean or become categorical 386 | for v in X.columns: 387 | is_numeric = pd.api.types.is_numeric_dtype(X[v]) 388 | is_categorical = isinstance(X[v].dtype, pd.CategoricalDtype) 389 | if not is_numeric and not is_categorical: 390 | is_string = pd.api.types.is_string_dtype(X[v]) 391 | is_object = pd.api.types.is_object_dtype(X[v]) 392 | 393 | if is_string or is_object: 394 | X[v] = X[v].astype("category") 395 | else: 396 | msg = f"Column {v} has unsupported dtype {X[v].dtype}." 397 | raise TypeError(msg) 398 | 399 | self.X = X 400 | self.feature_names = self.X.columns.to_list() 401 | 402 | return self 403 | 404 | def importance(self, which_output=None): 405 | """ 406 | Calculate mean absolute SHAP values for each feature (and output dimension). 407 | 408 | Parameters 409 | ---------- 410 | which_output : int or string, optional 411 | Index or name of the output dimension to calculate importance for. 412 | If None, all outputs are considered. Only relevant for multi-output models. 413 | 414 | Returns 415 | ------- 416 | pd.Series or pd.DataFrame 417 | Series containing mean absolute SHAP values sorted by importance. 418 | In case of multi-output models, it returns a DataFrame, and the sort 419 | order is determined by the average importance across all outputs. 420 | """ 421 | if self.ndim == 3 and which_output is not None: 422 | self = self.select_output(which_output) # noqa: PLW0642 423 | 424 | imp = np.abs(self.shap_values).mean(axis=0) 425 | 426 | if self.ndim == 2: 427 | imp = pd.Series(imp, index=self.feature_names).sort_values(ascending=False) 428 | else: # ndim == 3 -> we sort by average importance across outputs 429 | imp = pd.DataFrame(imp, index=self.feature_names, columns=self.output_names) 430 | imp = imp.loc[imp.mean(axis=1).sort_values(ascending=False).index] 431 | return imp 432 | 433 | def interaction_heuristic(self, features=None, color_features=None): 434 | """Interaction heuristic. 435 | 436 | For each feature/color_feature combination, the weighted average absolute 437 | Pearson correlation coefficient between the SHAP values of the feature 438 | and the values of the color_feature is calculated. The larger the value, 439 | the higher the potential interaction. 440 | 441 | Notes: 442 | 443 | - Non-numeric color features are converted to numeric, which does not always 444 | make sense. 445 | - Missing values in the color feature are currently discarded. 446 | - The number of non-missing color values in the bins are used as weight to 447 | compute the weighted average. 448 | 449 | Parameters 450 | ---------- 451 | features : list, optional 452 | List of feature names. If None, all features are used. 453 | color_features : list, optional 454 | List of color feature names. If None, all features are used. 455 | 456 | Returns 457 | ------- 458 | pd.DataFrame 459 | DataFrame with interaction heuristics. `feature_names` serve as index, 460 | and `color_features` as columns. 
461 | """ 462 | features = _check_features(features, self.feature_names) 463 | color_features = _check_features( 464 | color_features, self.feature_names, name="color features" 465 | ) 466 | 467 | idx = [self.feature_names.index(f) for f in features] 468 | 469 | df = self.X[features] 470 | df_color = self.X[color_features].apply(safe_to_float) # to numeric 471 | df_shap = pd.DataFrame(self.shap_values[:, idx], columns=df.columns) 472 | 473 | nbins = math.ceil(min(np.sqrt(df.shape[0]), df.shape[0] / 20)) 474 | 475 | out = pd.DataFrame(0.0, index=df.columns, columns=df_color.columns) 476 | 477 | for xname in df.columns: 478 | xgroups = df[xname] 479 | 480 | if pd.api.types.is_numeric_dtype(xgroups) and xgroups.nunique() > nbins: 481 | xgroups = pd.qcut(xgroups, nbins + 1, duplicates="drop", labels=False) 482 | 483 | pick = [column for column in df_color.columns if column != xname] 484 | grouped = df_color[pick].groupby(xgroups, dropna=False, observed=True) 485 | corr = grouped.corrwith(df_shap[xname], method=_safe_cor) 486 | out.loc[xname, pick] = np.average( 487 | corr.abs(), weights=grouped.count(), axis=0 488 | ) 489 | 490 | return out 491 | -------------------------------------------------------------------------------- /src/lightshap/explanation/tests/test_explanation_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | 6 | from lightshap.explanation._utils import ( 7 | _check_features, 8 | _safe_cor, 9 | beeswarm_jitter, 10 | color_axis_info, 11 | get_text_bbox, 12 | halton, 13 | halton_sequence, 14 | min_max_scale, 15 | plot_layout, 16 | safe_to_float, 17 | ) 18 | 19 | 20 | class TestSafeToFloat: 21 | """Test the safe_to_float utility function.""" 22 | 23 | def test_numeric_series(self): 24 | """Test with numeric pandas Series.""" 25 | s = pd.Series([1, 2, 3, 4, 5]) 26 | result = safe_to_float(s) 27 | expected = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0]) 28 | pd.testing.assert_series_equal(result, expected) 29 | 30 | def test_integer_series(self): 31 | """Test with integer pandas Series.""" 32 | s = pd.Series([1, 2, 3]) 33 | result = safe_to_float(s) 34 | expected = pd.Series([1.0, 2.0, 3.0]) 35 | pd.testing.assert_series_equal(result, expected) 36 | 37 | def test_categorical_series(self): 38 | """Test with categorical pandas Series.""" 39 | s = pd.Series(["a", "b", "c", "a"], dtype="category") 40 | result = safe_to_float(s) 41 | expected = pd.Series([0.0, 1.0, 2.0, 0.0]) 42 | pd.testing.assert_series_equal(result, expected) 43 | 44 | def test_categorical_with_nan(self): 45 | """Test categorical series with missing values.""" 46 | s = pd.Series(["a", "b", None, "a"], dtype="category") 47 | result = safe_to_float(s) 48 | expected = pd.Series([0.0, 1.0, np.nan, 0.0]) 49 | pd.testing.assert_series_equal(result, expected) 50 | 51 | def test_float_with_nan(self): 52 | """Test float series with missing values.""" 53 | s = pd.Series([1.0, 2.0, np.nan, 4.0]) 54 | result = safe_to_float(s) 55 | pd.testing.assert_series_equal(result, s) 56 | 57 | def test_unsupported_dtype_error(self): 58 | """Test error with unsupported data type.""" 59 | s = pd.Series(["a", "b", "c"]) # String without category 60 | with pytest.raises(TypeError, match="Unsupported dtype"): 61 | safe_to_float(s) 62 | 63 | 64 | class TestMinMaxScale: 65 | """Test the min_max_scale utility function.""" 66 | 67 | def test_basic_scaling(self): 68 | """Test basic min-max scaling.""" 69 | s = 
pd.Series([1, 2, 3, 4, 5]) 70 | result = min_max_scale(s) 71 | expected = pd.Series([0.0, 0.25, 0.5, 0.75, 1.0]) 72 | pd.testing.assert_series_equal(result, expected) 73 | 74 | def test_constant_values(self): 75 | """Test scaling with constant values.""" 76 | s = pd.Series([5, 5, 5, 5]) 77 | result = min_max_scale(s) 78 | expected = pd.Series([0.5, 0.5, 0.5, 0.5]) 79 | pd.testing.assert_series_equal(result, expected) 80 | 81 | def test_with_missing_values(self): 82 | """Test scaling with missing values.""" 83 | s = pd.Series([1, np.nan, 3, 4, 5]) 84 | result = min_max_scale(s) 85 | expected = pd.Series([0.0, np.nan, 0.5, 0.75, 1.0]) 86 | pd.testing.assert_series_equal(result, expected) 87 | 88 | def test_all_missing_values(self): 89 | """Test scaling when all values are missing.""" 90 | s = pd.Series([np.nan, np.nan, np.nan]) 91 | result = min_max_scale(s) 92 | pd.testing.assert_series_equal(result, s) 93 | 94 | 95 | class TestHaltonSequence: 96 | """Test Halton sequence generation.""" 97 | 98 | def test_halton_base_2(self): 99 | """Test Halton sequence with base 2.""" 100 | result = halton(1, base=2) 101 | assert result == 0.5 102 | 103 | result = halton(2, base=2) 104 | assert result == 0.25 105 | 106 | result = halton(3, base=2) 107 | assert result == 0.75 108 | 109 | def test_halton_sequence_generation(self): 110 | """Test generation of multiple Halton sequence values.""" 111 | result = halton_sequence(4, base=2) 112 | expected = np.array([0.5, 0.25, 0.75, 0.125]) 113 | np.testing.assert_array_almost_equal(result, expected) 114 | 115 | def test_halton_sequence_length(self): 116 | """Test that halton_sequence returns correct length.""" 117 | n = 10 118 | result = halton_sequence(n) 119 | assert len(result) == n 120 | 121 | 122 | class TestBeeswarmJitter: 123 | """Test beeswarm jitter calculation.""" 124 | 125 | def test_single_value(self): 126 | """Test jitter with single value.""" 127 | values = np.array([1.0]) 128 | result = beeswarm_jitter(values) 129 | expected = np.array([0.0]) 130 | np.testing.assert_array_equal(result, expected) 131 | 132 | def test_multiple_values(self): 133 | """Test jitter with multiple values.""" 134 | values = np.array([1, 2, 3, 4, 5]) 135 | result = beeswarm_jitter(values) 136 | assert len(result) == len(values) 137 | assert result.dtype == float 138 | 139 | def test_identical_values(self): 140 | """Test jitter with identical values.""" 141 | values = np.array([2, 2, 2, 2]) 142 | result = beeswarm_jitter(values) 143 | assert len(result) == len(values) 144 | # Should not raise an error even with identical values 145 | 146 | def test_jitter_range(self): 147 | """Test that jitter values are reasonable.""" 148 | values = np.random.randn(50) 149 | result = beeswarm_jitter(values) 150 | # Jitter should generally be within reasonable bounds 151 | assert np.abs(result).max() <= 1.0 152 | 153 | 154 | class TestPlotLayout: 155 | """Test plot layout determination.""" 156 | 157 | def test_small_numbers(self): 158 | """Test layout for small numbers of plots.""" 159 | assert plot_layout(1) == (1, 1) 160 | assert plot_layout(2) == (1, 2) 161 | assert plot_layout(3) == (1, 3) 162 | 163 | def test_medium_numbers(self): 164 | """Test layout for medium numbers of plots.""" 165 | assert plot_layout(4) == (2, 2) 166 | assert plot_layout(5) == (3, 2) 167 | assert plot_layout(6) == (3, 2) 168 | 169 | def test_larger_numbers(self): 170 | """Test layout for larger numbers of plots.""" 171 | assert plot_layout(7) == (3, 3) 172 | assert plot_layout(9) == (3, 3) 173 | assert 
plot_layout(12) == (4, 3) 174 | 175 | def test_very_large_numbers(self): 176 | """Test layout for very large numbers of plots.""" 177 | assert plot_layout(13) == (4, 4) 178 | assert plot_layout(16) == (4, 4) 179 | assert plot_layout(20) == (5, 4) 180 | 181 | 182 | class TestCheckFeatures: 183 | """Test feature checking and validation.""" 184 | 185 | def test_none_features(self): 186 | """Test with None features (should return all).""" 187 | all_features = ["a", "b", "c"] 188 | result = _check_features(None, all_features) 189 | assert result == all_features 190 | 191 | def test_valid_features(self): 192 | """Test with valid feature subset.""" 193 | features = ["a", "c"] 194 | all_features = ["a", "b", "c"] 195 | result = _check_features(features, all_features) 196 | assert result == features 197 | 198 | def test_invalid_features_error(self): 199 | """Test error with invalid features.""" 200 | features = ["a", "d"] # "d" not in all_features 201 | all_features = ["a", "b", "c"] 202 | with pytest.raises(ValueError, match="Some .* are not present"): 203 | _check_features(features, all_features) 204 | 205 | def test_non_iterable_error(self): 206 | """Test error with non-iterable features.""" 207 | features = "a" # String is iterable but should be treated as single item 208 | all_features = ["a", "b", "c"] 209 | with pytest.raises(TypeError, match="must be an iterable"): 210 | _check_features(features, all_features) 211 | 212 | 213 | class TestSafeCor: 214 | """Test safe correlation calculation.""" 215 | 216 | def test_perfect_correlation(self): 217 | """Test perfect positive correlation.""" 218 | x = np.array([1, 2, 3, 4, 5]) 219 | y = np.array([2, 4, 6, 8, 10]) 220 | result = _safe_cor(x, y) 221 | assert abs(result - 1.0) < 1e-10 222 | 223 | def test_no_correlation(self): 224 | """Test no correlation.""" 225 | x = np.array([1, 2, 3, 4, 5]) 226 | y = np.array([1, 1, 1, 1, 1]) # Constant 227 | result = _safe_cor(x, y) 228 | assert result == 0.0 229 | 230 | def test_with_nan_values(self): 231 | """Test correlation with NaN values.""" 232 | x = np.array([1, 2, np.nan, 4, 5]) 233 | y = np.array([2, 4, 6, 8, 10]) 234 | result = _safe_cor(x, y) 235 | # Should compute correlation on valid pairs only 236 | assert abs(result - 1.0) < 1e-10 237 | 238 | def test_insufficient_data(self): 239 | """Test correlation with insufficient valid data.""" 240 | x = np.array([1, np.nan]) 241 | y = np.array([2, np.nan]) 242 | result = _safe_cor(x, y) 243 | assert result == 0.0 244 | 245 | def test_zero_variance(self): 246 | """Test correlation when one variable has zero variance.""" 247 | x = np.array([1, 1, 1, 1]) 248 | y = np.array([1, 2, 3, 4]) 249 | result = _safe_cor(x, y) 250 | assert result == 0.0 251 | 252 | 253 | class TestGetTextBbox: 254 | """Test text bounding box calculation.""" 255 | 256 | def test_text_bbox_basic(self): 257 | """Test basic text bounding box calculation.""" 258 | fig, ax = plt.subplots() 259 | ax.text(0.5, 0.5, "Test text") 260 | ax.text(0.2, 0.8, "Another text") 261 | 262 | bbox = get_text_bbox(ax) 263 | assert len(bbox) == 4 # (left, right, bottom, top) 264 | assert all(isinstance(coord, int | float) for coord in bbox) 265 | 266 | plt.close(fig) 267 | 268 | 269 | class TestColorAxisInfo: 270 | """Test color axis information preparation.""" 271 | 272 | def test_categorical_color_feature(self): 273 | """Test with categorical color feature.""" 274 | z = pd.Series(["A", "B", "C", "A"], dtype="category") 275 | result = color_axis_info( 276 | z, "viridis", max_color_labels=10, 
max_color_label_length=20 277 | ) 278 | 279 | assert result["categorical"] is True 280 | assert "mapping" in result 281 | assert "cmap" in result 282 | assert "norm" in result 283 | assert "values" in result 284 | assert result["n_colors"] == 3 285 | 286 | def test_numeric_color_feature(self): 287 | """Test with numeric color feature.""" 288 | z = pd.Series([1.0, 2.0, 3.0, 4.0]) 289 | result = color_axis_info( 290 | z, "viridis", max_color_labels=10, max_color_label_length=20 291 | ) 292 | 293 | assert result["categorical"] is False 294 | assert "cmap" in result 295 | assert "values" in result 296 | pd.testing.assert_series_equal(result["values"], z) 297 | 298 | def test_categorical_label_truncation(self): 299 | """Test label truncation for long categorical labels.""" 300 | categories = ["very_long_category_name", "short"] 301 | z = pd.Series(categories, dtype="category") 302 | result = color_axis_info( 303 | z, "viridis", max_color_labels=10, max_color_label_length=5 304 | ) 305 | 306 | # Check that long labels are truncated 307 | for value in result["mapping"].values(): 308 | assert len(value) <= 5 309 | 310 | def test_too_many_categorical_labels(self): 311 | """Test behavior with too many categorical labels.""" 312 | categories = [f"cat_{i}" for i in range(20)] 313 | z = pd.Series(categories, dtype="category") 314 | result = color_axis_info( 315 | z, "viridis", max_color_labels=5, max_color_label_length=20 316 | ) 317 | 318 | # Should reduce number of labels shown 319 | non_empty_labels = sum(1 for v in result["mapping"].values() if v != "") 320 | assert non_empty_labels <= 7 # Some labels should be empty 321 | 322 | def test_categorical_with_missing(self): 323 | """Test categorical feature with missing values.""" 324 | z = pd.Series(["A", "B", None, "A"], dtype="category") 325 | result = color_axis_info( 326 | z, "viridis", max_color_labels=10, max_color_label_length=20 327 | ) 328 | 329 | assert result["categorical"] is True 330 | # Missing values should be handled as NaN in the values 331 | assert pd.isna(result["values"]).any() 332 | 333 | def test_colormap_bad_value_setting(self): 334 | """Test that colormap handles bad values (NaN) correctly.""" 335 | z = pd.Series([1.0, 2.0, np.nan, 4.0]) 336 | result = color_axis_info( 337 | z, "viridis", max_color_labels=10, max_color_label_length=20, alpha=0.5 338 | ) 339 | 340 | # The colormap should have bad color set 341 | assert hasattr(result["cmap"], "_rgba_bad") 342 | -------------------------------------------------------------------------------- /src/lightshap/explanation/tests/test_plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | 3 | mpl.use("Agg") # Use non-interactive backend for testing 4 | 5 | import matplotlib.colors as mcolors 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | import pytest 10 | 11 | from lightshap.explanation.explanation import Explanation 12 | from lightshap.explanation.explanationplotter import ExplanationPlotter 13 | 14 | 15 | def create_explanation(n_samples=10, n_features=5, n_outputs=1): 16 | """Create an Explanation object for testing.""" 17 | rng = np.random.default_rng(0) 18 | feature_names = [f"feature_{i}" for i in range(n_features)] 19 | 20 | # Mixed-type features 21 | X_data = {} 22 | for i, name in enumerate(feature_names): 23 | if i == 0: # First feature is categorical 24 | categories = ["A", "B", "C", "D"] 25 | values = rng.choice(categories, size=n_samples) 26 | X_data[name] = 
pd.Categorical(values) 27 | else: # Other features are numeric 28 | X_data[name] = rng.random(n_samples) 29 | 30 | X = pd.DataFrame(X_data) 31 | 32 | # Add 10% missing values randomly across all features 33 | n_missing = int(0.1 * n_samples * n_features) 34 | missing_indices = rng.choice(n_samples * n_features, size=n_missing, replace=False) 35 | 36 | for idx in missing_indices: 37 | row_idx = idx // n_features 38 | col_idx = idx % n_features 39 | X.iloc[row_idx, col_idx] = np.nan 40 | 41 | if n_outputs == 1: 42 | shap_values = rng.random((n_samples, n_features)) 43 | baseline = 0.5 44 | else: 45 | shap_values = rng.random((n_samples, n_features, n_outputs)) 46 | baseline = rng.random(n_outputs) 47 | 48 | return Explanation( 49 | shap_values=shap_values, X=X, baseline=baseline, feature_names=feature_names 50 | ) 51 | 52 | 53 | class TestBar: 54 | """Test suite for the bar method.""" 55 | 56 | def setup_method(self): 57 | """Set up test fixtures.""" 58 | self.explanation = create_explanation(n_samples=20, n_features=8) 59 | self.plotter = ExplanationPlotter(self.explanation) 60 | 61 | def test_returns_matplotlib_axis(self): 62 | """Test that bar returns a matplotlib Axes object.""" 63 | ax = self.plotter.bar() 64 | assert isinstance(ax, plt.Axes) 65 | plt.close() 66 | 67 | def test_max_display_limits_features(self): 68 | """Test that max_display parameter limits the number of displayed features.""" 69 | max_display = 3 70 | ax = self.plotter.bar(max_display=max_display) 71 | 72 | # Check that the number of y-tick labels is at most max_display 73 | y_labels = ax.get_yticklabels() 74 | assert len(y_labels) == max_display 75 | plt.close() 76 | 77 | def test_max_display_none_shows_all_features(self): 78 | """Test that max_display=None shows all features.""" 79 | ax = self.plotter.bar(max_display=None) 80 | 81 | y_labels = ax.get_yticklabels() 82 | assert len(y_labels) == self.explanation.shape[1] # n_features 83 | plt.close() 84 | 85 | def test_custom_axis(self): 86 | """Test that custom axis can be provided.""" 87 | fig, ax = plt.subplots() 88 | fig.canvas.draw() # Force canvas initialization 89 | result_ax = self.plotter.bar(ax=ax) 90 | 91 | assert result_ax is ax 92 | plt.close() 93 | 94 | def test_multi_output_explanation(self): 95 | """Test bar with multi-output explanation.""" 96 | M = 3 97 | multi_explanation = create_explanation(n_samples=10, n_features=5, n_outputs=M) 98 | plotter = ExplanationPlotter(multi_explanation) 99 | 100 | ax = plotter.bar() 101 | assert isinstance(ax, plt.Axes) 102 | plt.close() 103 | 104 | def test_bar_containers_exist(self): 105 | """Test that bar containers are created.""" 106 | ax = self.plotter.bar() 107 | 108 | # Check that there are bar containers 109 | assert len(ax.containers) > 0 110 | plt.close() 111 | 112 | def test_axis_labels_and_grid(self): 113 | """Test that proper axis labels and grid are set.""" 114 | ax = self.plotter.bar() 115 | 116 | assert ax.get_xlabel() == "Mean Absolute SHAP Value" 117 | # `ax.grid` is a method and always truthy, so check gridline visibility 118 | assert any(gl.get_visible() for gl in [*ax.get_xgridlines(), *ax.get_ygridlines()]) 119 | plt.close() 120 | 121 | def test_no_bar_labels_with_zero_fontsize(self): 122 | """Test that no bar labels are added when label_fontsize=0.""" 123 | ax = self.plotter.bar(label_fontsize=0.0) 124 | 125 | # Should still have bars but no labels 126 | assert len(ax.containers) > 0 127 | plt.close() 128 | 129 | def test_inverted_yaxis(self): 130 | """Test that y-axis is properly inverted for feature importance ranking.""" 131 | ax = self.plotter.bar() 132 | 133 | # Y-axis should be inverted (higher importance 
features at top) 134 | ylim = ax.get_ylim() 135 | assert ylim[0] > ylim[1] # Inverted axis 136 | plt.close() 137 | 138 | def test_invalid_ax_type_raises_error(self): 139 | """Test that invalid ax parameter raises TypeError.""" 140 | with pytest.raises(TypeError, match="ax must be a matplotlib Axes"): 141 | self.plotter.bar(ax="invalid") 142 | 143 | def test_custom_color(self): 144 | """Test that custom color can be set.""" 145 | custom_color = "#ff0000" 146 | ax = self.plotter.bar(color=custom_color) 147 | 148 | # Check that bars exist (actual color testing is complex with matplotlib) 149 | assert len(ax.containers) > 0 150 | plt.close() 151 | 152 | 153 | class TestBeeswarm: 154 | """Test suite for the beeswarm method.""" 155 | 156 | def setup_method(self): 157 | """Set up test fixtures.""" 158 | self.explanation = create_explanation(n_samples=20, n_features=8) 159 | self.plotter = ExplanationPlotter(self.explanation) 160 | 161 | def test_returns_matplotlib_axis(self): 162 | """Test that beeswarm returns a matplotlib Axes object.""" 163 | ax = self.plotter.beeswarm() 164 | assert isinstance(ax, plt.Axes) 165 | plt.close() 166 | 167 | def test_max_display_limits_features(self): 168 | """Test that max_display parameter limits the number of displayed features.""" 169 | max_display = 3 170 | ax = self.plotter.beeswarm(max_display=max_display) 171 | 172 | # Check that the number of y-tick labels is at most max_display 173 | y_labels = ax.get_yticklabels() 174 | assert len(y_labels) == max_display 175 | plt.close() 176 | 177 | def test_max_display_none_shows_all_features(self): 178 | """Test that max_display=None shows all features.""" 179 | ax = self.plotter.beeswarm(max_display=None) 180 | 181 | y_labels = ax.get_yticklabels() 182 | assert len(y_labels) == self.explanation.shape[1] # n_features 183 | plt.close() 184 | 185 | def test_custom_axis(self): 186 | """Test that custom axis can be provided.""" 187 | fig, ax = plt.subplots() 188 | fig.canvas.draw() # Force canvas initialization 189 | result_ax = self.plotter.beeswarm(ax=ax) 190 | 191 | assert result_ax is ax 192 | plt.close() 193 | 194 | def test_inverted_yaxis(self): 195 | """Test that y-axis is properly inverted for feature importance ranking.""" 196 | ax = self.plotter.beeswarm() 197 | 198 | # Y-axis should be inverted (higher importance features at top) 199 | ylim = ax.get_ylim() 200 | assert ylim[0] > ylim[1] # Inverted axis 201 | plt.close() 202 | 203 | def test_multi_output_explanation(self): 204 | """Test beeswarm with multi-output explanation.""" 205 | multi_explanation = create_explanation(n_samples=10, n_features=5, n_outputs=3) 206 | plotter = ExplanationPlotter(multi_explanation) 207 | 208 | ax = plotter.beeswarm(which_output=1) 209 | assert isinstance(ax, plt.Axes) 210 | plt.close() 211 | 212 | def test_colorbar_exists(self): 213 | """Test that a colorbar is created in the plot.""" 214 | ax = self.plotter.beeswarm() 215 | fig = ax.get_figure() 216 | 217 | # Check that colorbar was added to the figure 218 | assert len(fig.axes) == 2 # Main axis + colorbar axis 219 | plt.close() 220 | 221 | def test_jitter_width_zero(self): 222 | """Test that jitter_width=0 works without errors.""" 223 | ax = self.plotter.beeswarm(jitter_width=0) 224 | assert isinstance(ax, plt.Axes) 225 | plt.close() 226 | 227 | def test_scatter_points_exist(self): 228 | """Test that scatter points are actually plotted.""" 229 | ax = self.plotter.beeswarm() 230 | 231 | # Check that there are scatter plot collections 232 | collections = [c for c in 
ax.collections if hasattr(c, "get_offsets")] 233 | assert len(collections) > 0 234 | plt.close() 235 | 236 | def test_axis_labels_and_grid(self): 237 | """Test that proper axis labels and grid are set.""" 238 | ax = self.plotter.beeswarm() 239 | 240 | assert ax.get_xlabel() == "SHAP Value" 241 | # `ax.grid` is a method and always truthy, so check gridline visibility 242 | assert any(gl.get_visible() for gl in [*ax.get_xgridlines(), *ax.get_ygridlines()]) 243 | plt.close() 244 | 245 | def test_invalid_ax_type_raises_error(self): 246 | """Test that invalid ax parameter raises TypeError.""" 247 | with pytest.raises(TypeError, match="ax must be a matplotlib Axes"): 248 | self.plotter.beeswarm(ax="invalid") 249 | 250 | 251 | class TestWaterfall: 252 | """Test suite for the waterfall method.""" 253 | 254 | def setup_method(self): 255 | """Set up test fixtures.""" 256 | self.explanation = create_explanation(n_samples=20, n_features=8) 257 | self.plotter = ExplanationPlotter(self.explanation) 258 | 259 | def test_returns_matplotlib_axis(self): 260 | """Test that waterfall returns a matplotlib Axes object.""" 261 | ax = self.plotter.waterfall(row_id=0) 262 | assert isinstance(ax, plt.Axes) 263 | plt.close() 264 | 265 | def test_max_display_limits_features(self): 266 | """Test that max_display parameter limits the number of displayed features.""" 267 | max_display = 2 268 | ax = self.plotter.waterfall(row_id=0, max_display=max_display) 269 | 270 | # Check that the number of y-tick labels is at most max_display 271 | y_labels = ax.get_yticklabels() 272 | assert len(y_labels) == max_display 273 | plt.close() 274 | 275 | def test_max_display_none_shows_all_features(self): 276 | """Test that max_display=None shows all features.""" 277 | ax = self.plotter.waterfall(row_id=0, max_display=None) 278 | 279 | y_labels = ax.get_yticklabels() 280 | assert len(y_labels) == self.explanation.shape[1] # n_features 281 | plt.close() 282 | 283 | def test_row_id_validation(self): 284 | """Test that invalid row_id raises ValueError.""" 285 | with pytest.raises(ValueError, match="row_id .* must be integer"): 286 | self.plotter.waterfall(row_id=len(self.explanation)) 287 | 288 | with pytest.raises(ValueError, match="row_id .* must be integer"): 289 | self.plotter.waterfall(row_id=-1) 290 | 291 | def test_fill_colors_validation(self): 292 | """Test that invalid fill_colors raises ValueError.""" 293 | with pytest.raises( 294 | ValueError, match="fill_colors must be a list or tuple of length 2" 295 | ): 296 | self.plotter.waterfall(row_id=0, fill_colors=["red"]) 297 | 298 | with pytest.raises( 299 | ValueError, match="fill_colors must be a list or tuple of length 2" 300 | ): 301 | self.plotter.waterfall(row_id=0, fill_colors="red") 302 | 303 | def test_annotation_validation(self): 304 | """Test that invalid annotation parameter raises ValueError.""" 305 | with pytest.raises( 306 | ValueError, match="annotation must be a list or tuple of length 2" 307 | ): 308 | self.plotter.waterfall(row_id=0, annotation=["E[f(x)]"]) 309 | 310 | def test_custom_axis(self): 311 | """Test that custom axis can be provided.""" 312 | fig, ax = plt.subplots() 313 | fig.canvas.draw() # Force canvas initialization 314 | result_ax = self.plotter.waterfall(row_id=0, ax=ax) 315 | 316 | assert result_ax is ax 317 | plt.close() 318 | 319 | def test_multi_output_explanation(self): 320 | """Test waterfall with multi-output explanation.""" 321 | multi_explanation = create_explanation(n_samples=10, n_features=5, n_outputs=3) 322 | plotter = ExplanationPlotter(multi_explanation) 323 | 324 | ax = plotter.waterfall(which_output=1) 325 | assert isinstance(ax, plt.Axes) 326 | 
plt.close() 325 | 326 | def test_max_display_creates_other_features_label(self): 327 | """Test that max_display creates 'other features' label when needed.""" 328 | max_display = 3 329 | ax = self.plotter.waterfall(row_id=0, max_display=max_display) 330 | 331 | y_labels = [label.get_text() for label in ax.get_yticklabels()] 332 | 333 | # Should have max_display labels 334 | assert len(y_labels) == max_display 335 | 336 | # Last label should mention "other features" if we collapsed features 337 | if self.explanation.shape[1] > max_display: 338 | assert ( 339 | f"{self.explanation.shape[1] - max_display + 1} other features" 340 | in y_labels[-1] 341 | ) 342 | 343 | plt.close() 344 | 345 | 346 | class TestScatter: 347 | """Test suite for the scatter method.""" 348 | 349 | def setup_method(self): 350 | """Set up test fixtures.""" 351 | self.explanation = create_explanation(n_samples=20, n_features=8) 352 | self.plotter = ExplanationPlotter(self.explanation) 353 | 354 | def test_returns_matplotlib_axis(self): 355 | """Test that scatter returns a matplotlib Axes object.""" 356 | ax = self.plotter.scatter(features=["feature_1", "feature_2"]) 357 | assert isinstance(ax, np.ndarray) # Returns array of axes for subplots 358 | assert isinstance(ax.flatten()[0], plt.Axes) 359 | assert len(ax.flatten()) == 2 360 | plt.close() 361 | 362 | def test_single_feature(self): 363 | """Test scatter plot with a single feature.""" 364 | ax = self.plotter.scatter(features=["feature_1"]) 365 | assert isinstance(ax, np.ndarray) 366 | assert len(ax.flatten()) == 1 367 | plt.close() 368 | 369 | def test_custom_axis_single(self): 370 | """Test that custom single axis can be provided.""" 371 | fig, ax = plt.subplots() 372 | result_ax = self.plotter.scatter(features=["feature_1"], ax=ax) 373 | 374 | assert result_ax is ax 375 | plt.close() 376 | 377 | def test_custom_axis_array(self): 378 | """Test that custom axis array can be provided.""" 379 | fig, axes = plt.subplots(1, 2) 380 | result_ax = self.plotter.scatter(features=["feature_1", "feature_2"], ax=axes) 381 | 382 | assert np.array_equal(result_ax, axes) 383 | plt.close() 384 | 385 | def test_multi_output_explanation(self): 386 | """Test scatter with multi-output explanation.""" 387 | multi_explanation = create_explanation(n_samples=10, n_features=5, n_outputs=3) 388 | plotter = ExplanationPlotter(multi_explanation) 389 | 390 | ax = plotter.scatter(features=["feature_1"], which_output=1) 391 | assert isinstance(ax, plt.Axes | np.ndarray) 392 | plt.close() 393 | 394 | def test_color_features_empty_list(self): 395 | """Test scatter with color_features=[] uses the specified color.""" 396 | custom_color = "#ff5733" 397 | ax = self.plotter.scatter( 398 | features=["feature_1"], color_features=[], color=custom_color 399 | ) 400 | 401 | first_ax = ax.flatten()[0] if isinstance(ax, np.ndarray) else ax 402 | 403 | scatter_collections = [ 404 | c for c in first_ax.collections if hasattr(c, "get_facecolors") 405 | ] 406 | assert len(scatter_collections) > 0 407 | 408 | # Get the colors of the scatter points 409 | face_colors = scatter_collections[0].get_facecolors() 410 | assert len(face_colors) > 0 411 | 412 | # Convert custom_color to RGBA for comparison 413 | expected_rgba = mcolors.to_rgba(custom_color) 414 | 415 | # Check if all points have the expected color (within tolerance) 416 | # Note: matplotlib sometimes adds alpha channel, so we compare RGB components 417 | for color in face_colors: 418 | assert np.allclose(color[:3], expected_rgba[:3], atol=0.01) 419 | 420 | 
plt.close() 421 | 422 | def test_axis_labels(self): 423 | """Test that proper axis labels are set.""" 424 | ax = self.plotter.scatter(features=["feature_1"]) 425 | 426 | # Get the first (or only) subplot 427 | first_ax = ax.flatten()[0] if isinstance(ax, np.ndarray) else ax 428 | assert first_ax.get_xlabel() == "feature_1" 429 | plt.close() 430 | 431 | def test_no_shared_y_axis(self): 432 | """Test sharey=False parameter functionality.""" 433 | ax = self.plotter.scatter(features=["feature_1", "feature_2"], sharey=False) 434 | assert isinstance(ax, np.ndarray) 435 | plt.close() 436 | 437 | def test_invalid_ax_type_raises_error(self): 438 | """Test that invalid ax parameter raises TypeError.""" 439 | with pytest.raises( 440 | TypeError, match="ax must be a matplotlib Axes or an array of Axes" 441 | ): 442 | self.plotter.scatter(features=["feature_1"], ax="invalid") 443 | 444 | def test_mismatched_axes_count_raises_error(self): 445 | """Test that mismatched number of axes and features raises ValueError.""" 446 | # Create 2 axes but try to plot 3 features 447 | fig, axes = plt.subplots(1, 2) 448 | 449 | with pytest.raises(ValueError, match="Expected 3 axes, got 2"): 450 | self.plotter.scatter( 451 | features=["feature_1", "feature_2", "feature_3"], 452 | ax=axes, 453 | ) 454 | 455 | plt.close() 456 | -------------------------------------------------------------------------------- /src/lightshap/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import polars as pl 4 | import pytest 5 | 6 | from lightshap.utils import get_dataclass, get_polars 7 | 8 | 9 | class TestGetDataclass: 10 | """Test the get_dataclass utility function.""" 11 | 12 | def test_numpy_array(self): 13 | """Test that numpy arrays are correctly identified.""" 14 | data = np.array([[1, 2], [3, 4]]) 15 | result = get_dataclass(data) 16 | assert result == "np" 17 | 18 | def test_pandas_dataframe(self): 19 | """Test that pandas DataFrames are correctly identified.""" 20 | data = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) 21 | result = get_dataclass(data) 22 | assert result == "pd" 23 | 24 | def test_polars_dataframe(self): 25 | """Test that polars DataFrames are correctly identified.""" 26 | data = pl.DataFrame({"A": [1, 2], "B": [3, 4]}) 27 | result = get_dataclass(data) 28 | assert result == "pl" 29 | 30 | def test_unknown_type_error(self): 31 | """Test that unknown data types raise KeyError.""" 32 | data = {"A": [1, 2], "B": [3, 4]} # Plain dict 33 | with pytest.raises(KeyError, match="Unknown data class"): 34 | get_dataclass(data) 35 | 36 | 37 | def test_get_polars_success(): 38 | """Test that get_polars returns polars module when available.""" 39 | result = get_polars() 40 | assert result is pl 41 | -------------------------------------------------------------------------------- /src/lightshap/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Handle polars import at module level 4 | try: 5 | import polars as pl 6 | 7 | HAS_POLARS = True 8 | except ImportError: 9 | HAS_POLARS = False 10 | pl = None 11 | 12 | 13 | def get_dataclass(data): 14 | """Determine the type of the input. 15 | 16 | Returns a string indicating whether the input is pandas, 17 | numpy, or polars, or raises a KeyError otherwise. 18 | 19 | Both Series-like and DataFrame-like objects are accepted. 
20 | 21 | Parameters 22 | ---------- 23 | data : DataFrame-like or Series-like 24 | The input data to determine the type of. 25 | 26 | Returns 27 | ------- 28 | str 29 | A string indicating the type of the data: "pd" for pandas, 30 | "np" for numpy array, or "pl" for polars DataFrame. 31 | """ 32 | if isinstance(data, np.ndarray): 33 | return "np" 34 | if hasattr(data, "iloc"): 35 | return "pd" 36 | if hasattr(data, "with_columns"): 37 | return "pl" 38 | 39 | msg = "Unknown data class. Expected 'numpy', 'pandas', or 'polars'" 40 | raise KeyError(msg) 41 | 42 | 43 | def get_polars(): 44 | """Get polars module or raise error if not available.""" 45 | if not HAS_POLARS: 46 | raise ImportError( 47 | "polars is required but is not installed. " 48 | "Install it with: pip install polars" 49 | ) 50 | return pl 51 | --------------------------------------------------------------------------------
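
A minimal usage sketch for the two helpers in src/lightshap/utils.py above (illustrative only, not part of the repository; it assumes lightshap is importable and that polars and pandas are installed):

    import pandas as pd

    from lightshap.utils import get_dataclass, get_polars

    # Dispatch on the container type without importing polars eagerly.
    X = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
    assert get_dataclass(X) == "pd"             # pandas objects expose .iloc
    assert get_dataclass(X.to_numpy()) == "np"  # plain numpy ndarray

    # get_polars() returns the module when installed, else raises ImportError.
    pl = get_polars()
    assert get_dataclass(pl.DataFrame({"a": [1, 2]})) == "pl"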