├── .gitattributes ├── .github └── workflows │ ├── docs_pages.yml │ └── python-app.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── make.bat ├── nshap.rst └── usage │ ├── examples.rst │ ├── installation.rst │ └── notebooks │ ├── Example1.ipynb │ ├── Example2.ipynb │ └── Example3.ipynb ├── images ├── img1.png ├── img2.png ├── img3.png ├── img4.png ├── img5.png └── img6.png ├── notebooks ├── Example1.ipynb ├── Example2.ipynb ├── Example3.ipynb └── replicate-paper │ ├── checkerboard-compute-million.ipynb │ ├── checkerboard-compute.ipynb │ ├── checkerboard-figures.ipynb │ ├── compute-vfunc.ipynb │ ├── compute.ipynb │ ├── datasets.py │ ├── estimation.ipynb │ ├── figures.ipynb │ ├── hyperparameters.ipynb │ └── paperutil.py ├── pyproject.toml ├── setup.py ├── src └── nshap │ ├── InteractionIndex.py │ ├── __init__.py │ ├── functions.py │ ├── plot.py │ ├── util.py │ └── vfunc.py └── tests ├── test_util.py └── tests.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.github/workflows/docs_pages.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | # execute this workflow automatically when we push to main 4 | on: 5 | push: 6 | branches: [ main ] 7 | 8 | jobs: 9 | 10 | build_docs_job: 11 | runs-on: ubuntu-latest 12 | env: 13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} 14 | 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v3 18 | 19 | - name: Set up Python 3.10 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: "3.10" 23 | 24 | - name: Install package and dependencies for building docs 25 | run: | 26 | sudo apt-get install pandoc 27 | python -m pip install --upgrade pip 28 | python -m pip install -U sphinx 29 | python -m pip install sphinx-rtd-theme 30 | # python -m pip install sphinxcontrib-apidoc 31 | python -m pip install sphinx-autoapi 32 | python -m pip install nbsphinx 33 | python -m pip install pypandoc 34 | pip install . 35 | - name: make the sphinx docs 36 | run: | 37 | make -C docs clean 38 | # sphinx-apidoc -f -o docs/source .
-H Test -e -t docs/source/_templates 39 | make -C docs html 40 | - name: Init new repo in dist folder and commit generated files 41 | run: | 42 | cd docs/_build/html/ 43 | git init 44 | touch .nojekyll 45 | git add -A 46 | git config --local user.email "action@github.com" 47 | git config --local user.name "GitHub Action" 48 | git commit -m 'deploy' 49 | - name: Force push to destination branch 50 | uses: ad-m/github-push-action@master 51 | with: 52 | github_token: ${{ secrets.GITHUB_TOKEN }} 53 | branch: docs-build 54 | force: true 55 | directory: ./docs/_build/html 56 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: pytesting 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | tests: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.10 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.10" 26 | - name: Install package & dependencies required for testing 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | pip install . 32 | pip install folktables xgboost 33 | pip install -U scikit-learn 34 | - name: Lint with flake8 35 | run: | 36 | # stop the build if there are Python syntax errors or undefined names 37 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 38 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 39 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 40 | - name: Test with pytest 41 | run: | 42 | pytest tests/tests.py 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | figures/ 3 | results/ 4 | output_notebooks/ 5 | 6 | # saved interaction indices 7 | *.json 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | .DS_Store 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 TML Tübingen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to the nshap Package! 2 | 3 | 4 |

5 | *(figure: Shapley Values)* 6 |

7 | 8 | This is a Python package to compute interaction indices that extend the Shapley Value. It accompanies the AISTATS'23 paper [From Shapley Values to Generalized Additive Models and back](http://arxiv.org/abs/2209.04012) by Sebastian Bordt and Ulrike von Luxburg. 9 | 10 | [![PyPi package version](https://img.shields.io/pypi/v/nshap.svg)](https://pypi.org/project/nshap/) 11 | [![sphinx documentation for latest release](https://github.com/tml-tuebingen/nshap/workflows/docs/badge.svg)](https://tml-tuebingen.github.io/nshap/) 12 | ![tests](https://github.com/tml-tuebingen/nshap/workflows/pytesting/badge.svg) 13 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg?color=g&style=plastic)](https://opensource.org/licenses/MIT) 14 | 15 | This package supports, among others, 16 | 17 | - [n-Shapley Values](http://arxiv.org/abs/2209.04012), introduced in our paper 18 | - [SHAP Interaction Values](https://www.nature.com/articles/s42256-019-0138-9), a popular interaction index that can also be computed with the [shap](https://github.com/slundberg/shap/) package 19 | - the [Shapley Taylor](https://arxiv.org/abs/1902.05622) interaction index 20 | - the [Faith-Shap](https://arxiv.org/abs/2203.00870) interaction index 21 | - the [Faith-Banzhaf](https://arxiv.org/abs/2203.00870) interaction index. 22 | 23 | The package works with arbitrary user-defined value functions. It also provides a model-agnostic implementation of the interventional SHAP value function. 24 | 25 | Note that the computed interaction indices are an estimate [that can be inaccurate](#estimation), especially if the order of the interaction is large. 26 | 27 | Documentation is available at [https://tml-tuebingen.github.io/nshap](https://tml-tuebingen.github.io/nshap/). 28 | 29 | ⚠️ Disclaimer 30 | 31 | This package does not provide an efficient way to compute Shapley Values. For this you should refer to the [shap](https://github.com/slundberg/shap/) package or approaches like [FastSHAP](https://arxiv.org/abs/2107.07436). In practice, the current implementation works for arbitrary functions of up to ~10 variables. 32 | 33 | ## Setup 34 | 35 | To install the package, run 36 | 37 | ``` 38 | pip install nshap 39 | ``` 40 | 41 | ## Computing Interaction Indices 42 | 43 | Let's assume that we have trained a Gradient Boosted Tree on the [Folktables](https://github.com/zykls/folktables) Income data set. 44 | 45 | ```python 46 | gbtree = xgboost.XGBClassifier() 47 | gbtree.fit(X_train, Y_train) 48 | print(f'Accuracy: {accuracy_score(Y_test, gbtree.predict(X_test)):0.3f}') 49 | ``` 50 | ```Accuracy: 0.830``` 51 | 52 | Now we want to compute an interaction index. This package supports interaction indices that extend the Shapley Value, which means that every such index is based on a value function, just like the Shapley Value itself. So we first need to define a value function. We can use the function ```nshap.vfunc.interventional_shap```, which approximates the interventional SHAP value function. 53 | 54 | ```python 55 | import nshap 56 | 57 | vfunc = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0, num_samples=1000) 58 | ``` 59 | The function takes 4 arguments 60 | 61 | - The function that we want to explain 62 | - The training data or another sample from the data distribution 63 | - The target class (required here since 'predict_proba' has 2 outputs).
64 | - The number of samples that should be used to estimate the expectation (Default: 1000) 65 | 66 | Equipped with a value function, we can compute different kinds of interaction indices. We can compute n-Shapley Values 67 | 68 | ```python 69 | n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc, n=8) 70 | ``` 71 | 72 | the Shapley-Taylor interaction index 73 | 74 | ```python 75 | shapley_taylor = nshap.shapley_taylor(X_test[0, :], vfunc, n=8) 76 | ``` 77 | 78 | or the Faith-Shap interaction index of order 3 79 | 80 | ```python 81 | faith_shap = nshap.faith_shap(X_test[0, :], vfunc, n=3) 82 | ``` 83 | 84 | The functions that compute interaction indices have a common interface. They take 3 arguments 85 | 86 | - ```x```: The data point for which to compute the explanation ([numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html)) 87 | 88 | - ```v_func```: The value function. 89 | 90 | - ```n```: The order of the interaction index. Defaults to the number of features. 91 | 92 | All functions return an object of type ```InteractionIndex```. To get the interaction between features 2 and 3, simply call 93 | 94 | ```python 95 | n_shapley_values[(2,3)] 96 | ``` 97 | 98 | ```-0.0054``` 99 | 100 | To visualize an interaction index, call 101 | 102 | ```python 103 | n_shapley_values.plot(feature_names = feature_names) 104 | ``` 105 | 106 |

107 | *(figure: 10-Shapley Values)* 108 |

109 | 110 | This works for all interaction indices 111 | 112 | ```python 113 | faith_shap.plot(feature_names = feature_names) 114 | ``` 115 | 116 |

117 | *(figure: 10-Shapley Values)* 118 |
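Since an ```InteractionIndex``` behaves like a Python ```dict```, we can also inspect the estimated effects programmatically. A minimal sketch, re-using the ```n_shapley_values``` object from above:

```python
# rank all estimated effects by absolute magnitude and print the strongest ones
strongest = sorted(n_shapley_values.items(), key=lambda kv: abs(kv[1]), reverse=True)
for subset, effect in strongest[:5]:
    print(subset, effect)
```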

119 | 120 | For n-Shapley Values, we can compute interaction indices of lower order from those of higher order 121 | 122 | ```python 123 | n_shapley_values.k_shapley_values(2).plot(feature_names = feature_names) 124 | ``` 125 | 126 |

127 | *(figure: 2-Shapley Values)* 128 |
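Since the $1$-Shapley Values are exactly the original Shapley Values, the recursion also gives a quick consistency check. A sketch with the objects from above:

```python
# reducing the explanation all the way down to order 1
# recovers the original Shapley Values
shapley_from_recursion = n_shapley_values.k_shapley_values(1)
print(shapley_from_recursion[(0,)])  # main effect of the first feature
```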

129 | 130 | We can also obtain the original Shapley Values and plot them with the plotting functions from the [shap](https://github.com/slundberg/shap/) package. 131 | 132 | ```python 133 | import shap 134 | 135 | shap.force_plot(vfunc(X_test[0,:], []), n_shapley_values.shapley_values()) 136 | ``` 137 | 138 |

139 | *(figure: Shapley Values)* 140 |
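Shapley Values satisfy the efficiency axiom, so their sum should match the value of the full coalition minus the value of the empty coalition, up to estimation error. A small sanity check (a sketch, re-using the objects from above; the problem is 8-dimensional):

```python
import numpy as np

# efficiency: sum of Shapley Values ≈ v(full coalition) - v(empty coalition)
phi = np.ravel(n_shapley_values.shapley_values())
gap = vfunc(X_test[0, :], list(range(8))) - vfunc(X_test[0, :], [])
print(np.sum(phi), gap)  # the two numbers should be close
```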

141 | 142 | Let us compare our result to the Shapley Values obtained from the KernelSHAP algorithm. 143 | 144 | ```python 145 | explainer = shap.KernelExplainer(gbtree.predict_proba, shap.kmeans(X_train, 25)) 146 | shap_values = explainer.shap_values(X_test[0, :])  # estimate Shapley Values with KernelSHAP 147 | shap.force_plot(explainer.expected_value[0], shap_values[0]) 148 | ``` 149 |

150 | *(figure: Shapley Values)* 151 |
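Since both approaches only approximate the true Shapley Values, the two force plots will not agree exactly. A quick numeric comparison (a sketch; ```shap_values``` is the KernelSHAP output from the previous snippet):

```python
import numpy as np

ours = np.ravel(n_shapley_values.shapley_values())
kernelshap = np.ravel(shap_values[0])  # attributions for class 0
print(np.max(np.abs(ours - kernelshap)))  # largest absolute disagreement
```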

152 | 153 | ## The ```InteractionIndex``` class 154 | 155 | The ```InteractionIndex``` class is a Python ```dict``` with some added functionality. It supports the following operations. 156 | 157 | - The individual attributions can be indexed with tuples of integers. For example, indexing with ```(0,)``` returns the main effect of the first feature. Indexing with ```(0,1,2)``` returns the interaction effect between features 0, 1 and 2. 158 | 159 | - ```plot()``` generates the plots described in the paper. 160 | 161 | - ```sum()``` sums the individual attributions (this usually sums to the function value minus the value of the empty coalition). 162 | 163 | - ```save(fname)``` serializes the object to json. It can be loaded from there with ```nshap.load(fname)```. This can be useful since computing interaction indices takes time, so you might want to compute them in parallel and then aggregate the results for analysis. 164 | 165 | Some functions can only be called on certain interaction indices: 166 | 167 | - ```k_shapley_values(k)``` computes the $k$-Shapley Values using the recursive relationship among $n$-Shapley Values of different order (requires $k\leq n$). Can only be called for $n$-Shapley Values. 168 | 169 | - ```shapley_values()``` returns the associated original Shapley Values as a list. Useful for compatibility with the [shap](https://github.com/slundberg/shap/) package. 170 | 171 | ## Defining Value Functions 172 | 173 | A value function has to follow the interface ```v_func(x, S)``` where ```x``` is a single data point (a [numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html)) and ```S``` is a Python ```list``` with the indices of the coordinates that belong to the coalition. 174 | 175 | In the introductory example with the Gradient Boosted Tree, 176 | 177 | ```python 178 | vfunc(x, []) 179 | ``` 180 | 181 | returns the expected predicted probability that an observation belongs to class 0, and 182 | 183 | ```python 184 | vfunc(x, [0,1,2,3,4,5,6,7]) 185 | ``` 186 | 187 | returns the predicted probability that the observation ```x``` belongs to class 0 (note that the problem is 8-dimensional). 188 | 189 | ## Implementation Details 190 | 191 | At the moment, all functions compute interaction indices directly via their definition. Independent of the order ```n``` of the $n$-Shapley Values, this requires calling the value function ```v_func``` once for each of the $2^d$ subsets of coordinates. Thus, the current implementation provides no essential speedup for the computation of $n$-Shapley Values of lower order. 192 | 193 | The function ```nshap.vfunc.interventional_shap``` approximates the interventional SHAP value function by intervening on the coordinates of randomly sampled points from the data distribution. 194 | 195 | ## Accuracy of the computed interaction indices 196 | 197 | The computed interaction indices are estimates that can be inaccurate. 198 | 199 | The estimation error depends on the precision of the value function. With the provided implementation of the interventional SHAP value function, the precision depends on the number of samples used to estimate the expectation. 200 | 201 | A simple way to test whether your result is precisely estimated is to increase the number of samples (the ```num_samples``` parameter of ```nshap.vfunc.interventional_shap```) and see whether the result changes. 202 | 203 | For more details, check out the discussion in [Section 8 of our paper](http://arxiv.org/abs/2209.04012).
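For instance, one might re-estimate the explanation with a more precise value function and compare the attributions. A minimal sketch, re-using the objects from the introductory example (the value 5000 is an arbitrary choice, not a recommendation):

```python
# re-estimate with a larger num_samples and compare the attributions
vfunc_precise = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0, num_samples=5000)
precise = nshap.n_shapley_values(X_test[0, :], vfunc_precise, n=8)
max_change = max(abs(n_shapley_values[S] - precise[S]) for S in n_shapley_values.keys())
print(f'largest change in any attribution: {max_change:.4f}')
```

If ```max_change``` is small compared to the attributions of interest, the estimate is probably sufficiently precise.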
204 | 205 | ## Replicating the Results in our Paper 206 | 207 | The folder ```notebooks/replicate-paper``` contains Jupyter Notebooks that allow you to replicate the results in our [paper](http://arxiv.org/abs/2209.04012). 208 | 209 | - The notebooks ```figures.ipynb``` and ```checkerboard-figures.ipynb``` generate all the figures in the paper. 210 | - The notebook ```estimation.ipynb``` provides the estimation example with the kNN classifier on the Folktables Travel data set that we discuss in Appendix Section B. 211 | - The notebook ```hyperparameters.ipynb``` cross-validates the parameter $k$ of the kNN classifier. 212 | - The notebooks ```compute.ipynb```, ```compute-vfunc.ipynb```, ```checkerboard-compute.ipynb``` and ```checkerboard-compute-million.ipynb``` compute the different $n$-Shapley Values. You do not have to run these notebooks; the pre-computed results can be downloaded [here](https://nextcloud.tuebingen.mpg.de/index.php/s/SsowoR7SAibQYE7). 213 | 214 | ⚠️ Important 215 | 216 | You have to use version 0.1.0 of this package in order to run the notebooks that replicate the results in the paper. 217 | 218 | ``` 219 | pip install nshap==0.1.0 220 | ``` 221 | 222 | ## Citing nshap 223 | 224 | If you use this software in your research, we encourage you to cite our paper. 225 | 226 | ```bib 227 | @inproceedings{bordtlux2023, 228 | author = {Bordt, Sebastian and von Luxburg, Ulrike}, 229 | title = {From Shapley Values to Generalized Additive Models and back}, 230 | booktitle = {AISTATS}, 231 | year = {2023} 232 | } 233 | ``` 234 | 235 | If you use interaction indices that were introduced in other works, such as [Shapley Taylor](https://arxiv.org/abs/1902.05622) or [Faith-Shap](https://arxiv.org/abs/2203.00870), you should also consider citing the respective papers. 236 | 237 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../src/nshap/')) 16 | sys.path.insert(0, os.path.abspath('../src/')) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = 'nshap' 22 | copyright = '2023, Sebastian Bordt' 23 | author = 'Sebastian Bordt' 24 | release = '0.2.0' 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | 'sphinx.ext.duration', 33 | 'sphinx.ext.doctest', 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.napoleon', 36 | 'nbsphinx' 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | 50 | # The theme to use for HTML and HTML Help pages. See the documentation for 51 | # a list of builtin themes. 52 | # 53 | html_theme = 'sphinx_rtd_theme' 54 | 55 | # Add any paths that contain custom static files (such as style sheets) here, 56 | # relative to this directory. They are copied after the builtin static files, 57 | # so a file named "default.css" will overwrite the builtin "default.css". 58 | #html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. nshap documentation master file, created by 2 | sphinx-quickstart on Fri Oct 28 14:04:18 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to the Documentation of the nshap Package! 7 | ================================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | usage/installation 14 | usage/examples 15 | nshap 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd -------------------------------------------------------------------------------- /docs/nshap.rst: -------------------------------------------------------------------------------- 1 | Documentation 2 | ============= 3 | 4 | Functions that compute interaction indices 5 | ------------------------------------------ 6 | 7 | .. automodule:: nshap.functions 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | Value Functions 13 | --------------- 14 | 15 | .. automodule:: nshap.vfunc 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | The InteractionIndex class 21 | -------------------------- 22 | 23 | .. autoclass:: nshap.InteractionIndex.InteractionIndex 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | Plotting and Utilities 29 | ---------------------- 30 | 31 | .. automodule:: nshap.plot 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | .. automodule:: nshap.util 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /docs/usage/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | -------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | notebooks/Example1 9 | notebooks/Example2 10 | notebooks/Example3 -------------------------------------------------------------------------------- /docs/usage/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | To install the package via pip run 5 | 6 | :code:`pip install nshap` 7 | 8 | then you can use 9 | 10 | :code:`import nshap` 11 | 12 | to import the package.
-------------------------------------------------------------------------------- /docs/usage/notebooks/Example3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Example 3: Custom value functions\n", 9 | "#### These examples are from section 5.1 of the Faith-Shap paper: https://arxiv.org/abs/2203.00870" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "id": "c10cfd76", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import seaborn as sns\n", 20 | "\n", 21 | "sns.set_style(\"whitegrid\")\n", 22 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n", 23 | "\n", 24 | "import numpy as np\n", 25 | "import math\n", 26 | "\n", 27 | "import nshap" 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "id": "f0b5f063", 34 | "metadata": {}, 35 | "source": [ 36 | "### Example 1\n", 37 | "##### A value function takes two arguments: A single data point x (a numpy.ndarray) and a python list S with the indices the the coordinates that belong to the coaltion.\n", 38 | "##### In this example, the value does not actually depend on the point x" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 7, 44 | "id": "44978e93", 45 | "metadata": { 46 | "tags": [] 47 | }, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "3.4" 53 | ] 54 | }, 55 | "execution_count": 7, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "p = 0.1\n", 62 | "def v_func(x, S):\n", 63 | " \"\"\" The value function from Example 1 in the Faith-Shap paper.\n", 64 | " \"\"\"\n", 65 | " if len(S) <= 1:\n", 66 | " return 0\n", 67 | " return len(S) - p * math.comb(len(S), 2)\n", 68 | "\n", 69 | "\n", 70 | "v_func(np.zeros((1,11)), [1,2,3,4])" 71 | ] 72 | }, 73 | { 74 | "attachments": {}, 75 | "cell_type": "markdown", 76 | "id": "86a0f974", 77 | "metadata": {}, 78 | "source": [ 79 | "##### Equipped with the value function, we can compute different kinds of interaction indices" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "fbbbbcf5", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, 1)\n", 90 | "faith_shap[(0,1)], faith_shap[(1,2)]" 91 | ] 92 | }, 93 | { 94 | "attachments": {}, 95 | "cell_type": "markdown", 96 | "id": "08657d28", 97 | "metadata": {}, 98 | "source": [ 99 | "##### We can replicate Table 1 in the Faith-Shap paper" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 13, 105 | "id": "d4bcee91", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "Table 1 in the Faith-Shap paper:\n", 113 | "\n", 114 | "p=0.1 l=1\n", 115 | "\n", 116 | "Faith-Shap: 0.5000000000001867\n", 117 | "Shapley Taylor: 0.500000000000185\n", 118 | "Interaction Shapley: 0.4999999999999939\n", 119 | "Banzhaf Interaction: 0.5087890624999979\n", 120 | "Faith-Banzhaf: 0.5087890624999989\n", 121 | "\n", 122 | "p=0.1 l=2\n", 123 | "\n", 124 | "Faith-Shap: 0.9545454545463214 -0.09090909090923938\n", 125 | "Shapley Taylor: 0 0.10000000000002415\n", 126 | "Interaction Shapley: 0.4999999999999939 -5.759281940243e-16\n", 127 | "Banzhaf Interaction: 0.5087890624999979 -0.11367187500000073\n", 128 | "Faith-Banzhaf: 
1.0771484374999503 -0.11367187499999737\n", 129 | "\n", 130 | "p=0.2 l=1\n", 131 | "\n", 132 | "Faith-Shap: 1.6486811915683575e-13\n", 133 | "Shapley Taylor: 1.7552626019323725e-13\n", 134 | "Interaction Shapley: 2.7755575615628914e-17\n", 135 | "Banzhaf Interaction: 0.0087890625\n", 136 | "Faith-Banzhaf: 0.008789062500008604\n", 137 | "\n", 138 | "p=0.2 l=2\n", 139 | "\n", 140 | "Faith-Shap: 0.9545454545460967 -0.19090909090925617\n", 141 | "Shapley Taylor: 0 -3.7941871866564725e-14\n", 142 | "Interaction Shapley: 2.7755575615628914e-17 -0.10000000000000021\n", 143 | "Banzhaf Interaction: 0.0087890625 -0.21367187500000198\n", 144 | "Faith-Banzhaf: 1.0771484374999574 -0.21367187499998952\n", 145 | "\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "print('Table 1 in the Faith-Shap paper:\\n')\n", 151 | "for p in [0.1, 0.2]:\n", 152 | " for l in [1, 2]:\n", 153 | " print(f'p={p} l={l}\\n')\n", 154 | "\n", 155 | " # define the value function\n", 156 | " def v_func(x, S):\n", 157 | " if len(S) <= 1:\n", 158 | " return 0\n", 159 | " return len(S) - p * math.comb(len(S), 2)\n", 160 | "\n", 161 | " # compute interaction indices\n", 162 | " faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, l)\n", 163 | " shapley_taylor = nshap.shapley_taylor(np.zeros((1,11)), v_func, l)\n", 164 | " shapley_interaction = nshap.shapley_interaction_index(np.zeros((1,11)), v_func, l)\n", 165 | " banzhaf_interaction = nshap.banzhaf_interaction_index(np.zeros((1,11)), v_func, l)\n", 166 | " faith_banzhaf = nshap.faith_banzhaf(np.zeros((1,11)), v_func, l)\n", 167 | "\n", 168 | " # print result\n", 169 | " if l == 1:\n", 170 | " print('Faith-Shap: ', faith_shap[(0,)])\n", 171 | " print('Shapley Taylor: ', shapley_taylor[(0,)])\n", 172 | " print('Interaction Shapley: ', shapley_interaction[(0,)])\n", 173 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)])\n", 174 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)])\n", 175 | " else:\n", 176 | " print('Faith-Shap: ', faith_shap[(0,)], faith_shap[(0,1)])\n", 177 | " print('Shapley Taylor: ', shapley_taylor[(0,)], shapley_taylor[(0,1)])\n", 178 | " print('Interaction Shapley: ', shapley_interaction[(0,)], shapley_interaction[(0,1)])\n", 179 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)], banzhaf_interaction[(0,1)])\n", 180 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)], faith_banzhaf[(0,1)])\n", 181 | " print('')" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3.9.12 ('base')", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.9.12" 202 | }, 203 | "vscode": { 204 | "interpreter": { 205 | "hash": "62406f3c2942480da828869ab3f3f95d1c0177b3689d5bc770f3ddfd7b9b3df5" 206 | } 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 5 211 | } 212 | -------------------------------------------------------------------------------- /images/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img1.png -------------------------------------------------------------------------------- /images/img2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img2.png -------------------------------------------------------------------------------- /images/img3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img3.png -------------------------------------------------------------------------------- /images/img4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img4.png -------------------------------------------------------------------------------- /images/img5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img5.png -------------------------------------------------------------------------------- /images/img6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img6.png -------------------------------------------------------------------------------- /notebooks/Example3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Example 3: Custom value functions\n", 9 | "#### These examples are from section 5.1 of the Faith-Shap paper: https://arxiv.org/abs/2203.00870" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "id": "c10cfd76", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import seaborn as sns\n", 20 | "\n", 21 | "sns.set_style(\"whitegrid\")\n", 22 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n", 23 | "\n", 24 | "import numpy as np\n", 25 | "import math\n", 26 | "\n", 27 | "import nshap" 28 | ] 29 | }, 30 | { 31 | "attachments": {}, 32 | "cell_type": "markdown", 33 | "id": "f0b5f063", 34 | "metadata": {}, 35 | "source": [ 36 | "### Example 1\n", 37 | "##### A value function takes two arguments: A single data point x (a numpy.ndarray) and a python list S with the indices the the coordinates that belong to the coaltion.\n", 38 | "##### In this example, the value does not actually depend on the point x" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 7, 44 | "id": "44978e93", 45 | "metadata": { 46 | "tags": [] 47 | }, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/plain": [ 52 | "3.4" 53 | ] 54 | }, 55 | "execution_count": 7, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "p = 0.1\n", 62 | "def v_func(x, S):\n", 63 | " \"\"\" The value function from Example 1 in the Faith-Shap paper.\n", 64 | " \"\"\"\n", 65 | " if len(S) <= 1:\n", 66 | " return 0\n", 67 | " return len(S) - p * math.comb(len(S), 2)\n", 68 | "\n", 69 | "\n", 70 | "v_func(np.zeros((1,11)), [1,2,3,4])" 71 | ] 72 | }, 73 | { 74 | "attachments": {}, 75 | "cell_type": "markdown", 76 | "id": "86a0f974", 77 | "metadata": {}, 78 | "source": [ 79 | "##### Equipped with the value function, we can compute different kinds of interaction indices" 
80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "fbbbbcf5", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, 1)\n", 90 | "faith_shap[(0,1)], faith_shap[(1,2)]" 91 | ] 92 | }, 93 | { 94 | "attachments": {}, 95 | "cell_type": "markdown", 96 | "id": "08657d28", 97 | "metadata": {}, 98 | "source": [ 99 | "##### We can replicate Table 1 in the Faith-Shap paper" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 13, 105 | "id": "d4bcee91", 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "Table 1 in the Faith-Shap paper:\n", 113 | "\n", 114 | "p=0.1 l=1\n", 115 | "\n", 116 | "Faith-Shap: 0.5000000000001867\n", 117 | "Shapley Taylor: 0.500000000000185\n", 118 | "Interaction Shapley: 0.4999999999999939\n", 119 | "Banzhaf Interaction: 0.5087890624999979\n", 120 | "Faith-Banzhaf: 0.5087890624999989\n", 121 | "\n", 122 | "p=0.1 l=2\n", 123 | "\n", 124 | "Faith-Shap: 0.9545454545463214 -0.09090909090923938\n", 125 | "Shapley Taylor: 0 0.10000000000002415\n", 126 | "Interaction Shapley: 0.4999999999999939 -5.759281940243e-16\n", 127 | "Banzhaf Interaction: 0.5087890624999979 -0.11367187500000073\n", 128 | "Faith-Banzhaf: 1.0771484374999503 -0.11367187499999737\n", 129 | "\n", 130 | "p=0.2 l=1\n", 131 | "\n", 132 | "Faith-Shap: 1.6486811915683575e-13\n", 133 | "Shapley Taylor: 1.7552626019323725e-13\n", 134 | "Interaction Shapley: 2.7755575615628914e-17\n", 135 | "Banzhaf Interaction: 0.0087890625\n", 136 | "Faith-Banzhaf: 0.008789062500008604\n", 137 | "\n", 138 | "p=0.2 l=2\n", 139 | "\n", 140 | "Faith-Shap: 0.9545454545460967 -0.19090909090925617\n", 141 | "Shapley Taylor: 0 -3.7941871866564725e-14\n", 142 | "Interaction Shapley: 2.7755575615628914e-17 -0.10000000000000021\n", 143 | "Banzhaf Interaction: 0.0087890625 -0.21367187500000198\n", 144 | "Faith-Banzhaf: 1.0771484374999574 -0.21367187499998952\n", 145 | "\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "print('Table 1 in the Faith-Shap paper:\\n')\n", 151 | "for p in [0.1, 0.2]:\n", 152 | " for l in [1, 2]:\n", 153 | " print(f'p={p} l={l}\\n')\n", 154 | "\n", 155 | " # define the value function\n", 156 | " def v_func(x, S):\n", 157 | " if len(S) <= 1:\n", 158 | " return 0\n", 159 | " return len(S) - p * math.comb(len(S), 2)\n", 160 | "\n", 161 | " # compute interaction indices\n", 162 | " faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, l)\n", 163 | " shapley_taylor = nshap.shapley_taylor(np.zeros((1,11)), v_func, l)\n", 164 | " shapley_interaction = nshap.shapley_interaction_index(np.zeros((1,11)), v_func, l)\n", 165 | " banzhaf_interaction = nshap.banzhaf_interaction_index(np.zeros((1,11)), v_func, l)\n", 166 | " faith_banzhaf = nshap.faith_banzhaf(np.zeros((1,11)), v_func, l)\n", 167 | "\n", 168 | " # print result\n", 169 | " if l == 1:\n", 170 | " print('Faith-Shap: ', faith_shap[(0,)])\n", 171 | " print('Shapley Taylor: ', shapley_taylor[(0,)])\n", 172 | " print('Interaction Shapley: ', shapley_interaction[(0,)])\n", 173 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)])\n", 174 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)])\n", 175 | " else:\n", 176 | " print('Faith-Shap: ', faith_shap[(0,)], faith_shap[(0,1)])\n", 177 | " print('Shapley Taylor: ', shapley_taylor[(0,)], shapley_taylor[(0,1)])\n", 178 | " print('Interaction Shapley: ', shapley_interaction[(0,)], shapley_interaction[(0,1)])\n", 179 | " 
print('Banzhaf Interaction: ', banzhaf_interaction[(0,)], banzhaf_interaction[(0,1)])\n", 180 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)], faith_banzhaf[(0,1)])\n", 181 | " print('')" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3.9.12 ('base')", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.9.12" 202 | }, 203 | "vscode": { 204 | "interpreter": { 205 | "hash": "62406f3c2942480da828869ab3f3f95d1c0177b3689d5bc770f3ddfd7b9b3df5" 206 | } 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 5 211 | } 212 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/checkerboard-compute-million.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Compute for the checkerboard function" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "tags": [ 15 | "parameters" 16 | ] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# papermill parameter: notebook id\n", 21 | "aid = 0" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import matplotlib.pyplot as plt\n", 31 | "import seaborn as sns\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "import os\n", 35 | "\n", 36 | "sns.set_style(\"whitegrid\")\n", 37 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n", 38 | "\n", 39 | "from itertools import product\n", 40 | "\n", 41 | "import nshap\n", 42 | "\n", 43 | "from paperutil import checkerboard_function" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "### The different compute jobs, and the current job" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "interaction_orders = [2, 3, 4, 5, 6, 7, 8, 9, 10]\n", 60 | "replications = list(range(10))\n", 61 | "\n", 62 | "all_jobs = list(product(interaction_orders, replications))\n", 63 | "print(len(all_jobs), 'different compute jobs')" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "interaction_order = all_jobs[aid][0]\n", 73 | "replication = all_jobs[aid][1]" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "### Create output dir structure, if it does not already exist" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "paths = ['../../results/', \n", 90 | " '../../results/n_shapley_values/', \n", 91 | " '../../results/n_shapley_values/checkerboard/']\n", 92 | "for p in paths:\n", 93 | " if not os.path.exists( p ):\n", 94 | " os.mkdir( p )" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### Compute n-Shapley Values for a k-dimensional checkerboard in a 10-dimensional space" 102 | ] 103 | }, 104 | { 105 | "cell_type": 
"code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "d = 10\n", 111 | "f = checkerboard_function(interaction_order, num_checkers=100)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "tags": [] 119 | }, 120 | "outputs": [], 121 | "source": [ 122 | "np.random.seed(replication)\n", 123 | "\n", 124 | "for num_samples in [1000000]:\n", 125 | " X_train = np.random.uniform(0, 1, size=(1000000, d))\n", 126 | " X_test = np.random.uniform(0, 1, size=(1, d))\n", 127 | " vfunc = vfunc = nshap.vfunc.interventional_shap(f, X_train, num_samples = num_samples, random_state=replication)\n", 128 | " n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)\n", 129 | " n_shapley_values.save(f'../../results/n_shapley_values/checkerboard/{interaction_order}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON')\n", 130 | " n_shapley_values.plot()" 131 | ] 132 | } 133 | ], 134 | "metadata": { 135 | "kernelspec": { 136 | "display_name": "Python 3", 137 | "language": "python", 138 | "name": "python3" 139 | }, 140 | "language_info": { 141 | "codemirror_mode": { 142 | "name": "ipython", 143 | "version": 3 144 | }, 145 | "file_extension": ".py", 146 | "mimetype": "text/x-python", 147 | "name": "python", 148 | "nbconvert_exporter": "python", 149 | "pygments_lexer": "ipython3", 150 | "version": "3.8.3" 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 5 155 | } 156 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/checkerboard-compute.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Compute for the checkerboard function" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "tags": [ 15 | "parameters" 16 | ] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# papermill parameter: notebook id\n", 21 | "aid = 0" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import matplotlib.pyplot as plt\n", 31 | "import seaborn as sns\n", 32 | "\n", 33 | "import numpy as np\n", 34 | "import os\n", 35 | "\n", 36 | "sns.set_style(\"whitegrid\")\n", 37 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n", 38 | "\n", 39 | "import nshap\n", 40 | "\n", 41 | "from paperutil import checkerboard_function" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "interaction_order = [2, 3, 4, 5, 6, 7, 8, 9, 10]\n", 51 | "\n", 52 | "all_jobs = list(interaction_order)\n", 53 | "print(len(all_jobs), 'different compute jobs')" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "### Create output dir structure, if it does not already exist" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "paths = ['../../results/', \n", 70 | " '../../results/n_shapley_values/', \n", 71 | " '../../results/n_shapley_values/checkerboard/']\n", 72 | "for p in paths:\n", 73 | " if not os.path.exists( p ):\n", 74 | " os.mkdir( p )" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "### 
Compute n-Shapley Values for a k-dimensional checkerboard in a 10-dimensional space" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "d = 10\n", 91 | "f = checkerboard_function(interaction_order[aid], num_checkers=100)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": { 98 | "tags": [] 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "for num_samples in [100, 1000, 10000, 100000]:\n", 103 | " result = []\n", 104 | " for replication in range(10): \n", 105 | " X_train = np.random.uniform(0, 1, size=(1000000, d))\n", 106 | " X_test = np.random.uniform(0, 1, size=(1, d))\n", 107 | " vfunc = vfunc = nshap.vfunc.interventional_shap(f, X_train, num_samples = num_samples)\n", 108 | " n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)\n", 109 | " n_shapley_values.save(f'../../results/n_shapley_values/checkerboard/{interaction_order[aid]}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON')\n", 110 | " n_shapley_values.plot()\n", 111 | " true_order = np.sum([np.abs(v) for k, v in n_shapley_values.items() if len(k) == interaction_order[aid]]) \n", 112 | " all_orders = np.sum([np.abs(v) for k, v in n_shapley_values.items()])\n", 113 | " result.append(true_order / all_orders)\n", 114 | " print(f'{interaction_order[aid]}, {num_samples}, {np.mean(result)}, {result}')" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.8.3" 135 | } 136 | }, 137 | "nbformat": 4, 138 | "nbformat_minor": 5 139 | } 140 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/checkerboard-figures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Figures for the checkerboard function" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import matplotlib.pyplot as plt\n", 17 | "import seaborn as sns\n", 18 | "\n", 19 | "sns.set_style(\"whitegrid\")\n", 20 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import nshap\n", 24 | "import os" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### Visualization in 2d" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from paperutil import checkerboard_function\n", 41 | "\n", 42 | "d = 2\n", 43 | "num_points = 1000\n", 44 | "f = checkerboard_function(2, 4)\n", 45 | "X = np.random.uniform(0, 1, size=(num_points, d))\n", 46 | "Y = f(X)\n", 47 | "\n", 48 | "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=Y)\n", 49 | "plt.show()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "### Load the pre-computed n-Shpaley Values" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | 
"execution_count": null, 62 | "metadata": { 63 | "tags": [] 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "results = {}\n", 68 | "for degree in range(2, 11):\n", 69 | " results[degree] = {}\n", 70 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n", 71 | " results[degree][num_samples] = []\n", 72 | " for replication in range(10):\n", 73 | " fname = f'../../results/n_shapley_values/checkerboard/{degree}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON'\n", 74 | " if os.path.exists(fname):\n", 75 | " results[degree][num_samples].append( nshap.load(fname) )" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "### Correctly estimated fractions" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "tags": [] 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "sns.set_style(\"whitegrid\")\n", 94 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 2}, font_scale=2)\n", 95 | "\n", 96 | "plt.figure(figsize=(12.5, 8))\n", 97 | "for checkerboard_dim in range(2, 11):\n", 98 | " y = []\n", 99 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n", 100 | " yy = []\n", 101 | " for x in results[checkerboard_dim][num_samples]:\n", 102 | " true_order = np.sum([np.abs(v) for k, v in x.items() if len(k) == checkerboard_dim]) \n", 103 | " all_orders = np.sum([np.abs(v) for k, v in x.items()])\n", 104 | " yy.append(true_order / all_orders)\n", 105 | " y.append(np.mean(yy))\n", 106 | " sns.scatterplot(x=[1,2,3,4,5], y=y, color=sns.color_palette(\"tab10\")[checkerboard_dim-1], s=200)\n", 107 | " plt.plot([1,2,3,4,5], y, c=sns.color_palette(\"colorblind\")[checkerboard_dim-1], ls='--', lw=1.5)\n", 108 | "plt.ylim([-0.04, 1.04])\n", 109 | "plt.yticks(np.arange(0, 1.1, 0.1))\n", 110 | "plt.xticks([1,2,3,4,5], ['100', '1000', '10 000', '100 000', '1 000 000'])\n", 111 | "plt.xlabel('Number of points sampled to estimate the value function')\n", 112 | "plt.ylabel('Fraction of checkerboard function\\nthat is correctly estimated')\n", 113 | "plt.savefig('../../figures/checkerboard_estimation.pdf')\n", 114 | "plt.show()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "### Plots for individual degrees" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "sns.set_style(\"whitegrid\")\n", 131 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 2}, font_scale=2)\n", 132 | "\n", 133 | "for checkerboard_dim in range(2, 11): \n", 134 | " plt.figure(figsize=(6, 6))\n", 135 | " plt.ylim([-0.04, 1.04])\n", 136 | " for interaction_order in range(1,11):\n", 137 | " y, y_min, y_max = [], [], [] \n", 138 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n", 139 | " yy = []\n", 140 | " for x in results[checkerboard_dim][num_samples]:\n", 141 | " order_sum = np.sum([np.abs(v) for k, v in x.items() if len(k) == interaction_order]) \n", 142 | " all_sum = np.sum([np.abs(v) for k, v in x.items()])\n", 143 | " yy.append(order_sum / all_sum)\n", 144 | " y.append(np.mean(yy))\n", 145 | " y_min.append(np.min(yy))\n", 146 | " y_max.append(np.max(yy))\n", 147 | " if np.max(y) > 0.01:\n", 148 | " sns.scatterplot(x=[1,2,3,4,5], y=y, color=sns.color_palette(\"tab10\")[interaction_order-1], s=200)\n", 149 | " plt.plot([1,2,3,4,5], y, c=sns.color_palette(\"colorblind\")[interaction_order-1], ls='--', lw=1.5)\n", 
150 | " # convert y_min and y_max from y coordinates to plot range\n", 151 | " y_min = [x/1.08+0.04 for x in y_min]\n", 152 | " y_max = [x/1.08+0.04 for x in y_max]\n", 153 | " for x_pos in [1,2,3,4,5]:\n", 154 | " plt.axvline(x=x_pos, ymin=y_min[x_pos-1], ymax=y_max[x_pos-1], c=sns.color_palette(\"colorblind\")[interaction_order-1], lw=4)\n", 155 | " plt.yticks(np.arange(0, 1.1, 0.1))\n", 156 | " plt.xticks([1,2,3,4,5], ['100', '1000', '10 000', '100 000', '1 000 000'], rotation=20)\n", 157 | " plt.title(f'{checkerboard_dim}d Checkerboard Function')\n", 158 | " plt.xlabel('Number of points sampled')\n", 159 | " plt.ylabel('Order of estimated effects')\n", 160 | " plt.tight_layout()\n", 161 | " plt.savefig(f'../../figures/{checkerboard_dim}_checkerboard_estimation.pdf')\n", 162 | " plt.show()" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "Python 3", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.8.3" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 5 187 | } 188 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/compute-vfunc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for parallel evaluation of the value function" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "tags": [ 15 | "parameters" 16 | ] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# papermill parameter: notebook id\n", 21 | "aid = 0" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy as np\n", 31 | "import os\n", 32 | "\n", 33 | "import nshap\n", 34 | "\n", 35 | "import paperutil\n", 36 | "\n", 37 | "from itertools import product\n", 38 | "\n", 39 | "import datasets\n", 40 | "\n", 41 | "%load_ext autoreload\n", 42 | "%autoreload 2" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### The different compute jobs" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "subsets = list(nshap.powerset(list(range(10))))\n", 59 | "\n", 60 | "all_jobs = list(subsets)\n", 61 | "print(len(all_jobs), 'different compute jobs')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "### The current job" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "job_id = aid\n", 78 | "dataset = 'folk_travel'\n", 79 | "classifier = 'knn'\n", 80 | "i_datapoint = 0\n", 81 | "random_seed = i_datapoint\n", 82 | "\n", 83 | "print(job_id, dataset, classifier, i_datapoint, random_seed)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### Load the dataset" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "X_train, X_test, Y_train, Y_test, 
feature_names = datasets.load_dataset(dataset)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "is_classification = datasets.is_classification(dataset)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "### Predict, proba or decision" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "method = 'predict'\n", 125 | "if is_classification:\n", 126 | "    method = 'proba'\n", 127 | "if classifier == 'gam':\n", 128 | "    method = 'decision'" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "### The number of samples is limited by the size of the data set" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "max_samples = 1000000\n", 145 | "num_samples = min(max_samples, X_train.shape[0])" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Create output dir structure, if it does not already exist" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "froot = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_samples}/'" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "paths = ['../../results/', \n", 171 | "         '../../results/n_shapley_values/',\n", 172 | "         f'../../results/n_shapley_values/{dataset}/', \n", 173 | "         f'../../results/n_shapley_values/{dataset}/{classifier}/',\n", 174 | "         froot]\n", 175 | "for p in paths:\n", 176 | "    if not os.path.exists( p ):\n", 177 | "        os.mkdir( p )" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "### Train the classifier" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "clf = paperutil.train_classifier(dataset, classifier)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "### The value function" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "if method == 'predict':\n", 210 | "    vfunc = nshap.vfunc.interventional_shap(clf.predict, X_train, num_samples=num_samples, random_state=0)\n", 211 | "elif method == 'proba':\n", 212 | "    prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n", 213 | "    vfunc = nshap.vfunc.interventional_shap(clf.predict_proba, X_train, num_samples=num_samples, random_state=0, target=prediction)\n", 214 | "elif method == 'decision':\n", 215 | "    vfunc = nshap.vfunc.interventional_shap(clf.decision_function, X_train, num_samples=num_samples, random_state=0)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "### Evaluate the value function" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "tags": [] 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "for idx in range(10):\n", 234 | "    S = subsets[10*job_id + idx] # 10
jobs per notebook\n", 235 | " if len(S) > 0 and np.max(S) >= X_train.shape[1]:\n", 236 | " continue\n", 237 | " fname = froot + f'v{S}.txt' \n", 238 | " # evaluate the value function and save the result\n", 239 | " if not os.path.exists(fname):\n", 240 | " result = vfunc(X_test[i_datapoint, :].reshape((1,-1)), S)\n", 241 | " with open(fname, 'w+') as f:\n", 242 | " f.write(f'{result}')" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": "Python 3", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.8.3" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 5 267 | } 268 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/compute.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for parallel computation of n-Shapley Values" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "tags": [ 15 | "parameters" 16 | ] 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "# papermill parameter: notebook id\n", 21 | "aid = 0" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "# second compute wave\n", 31 | "#aid = aid + 1000" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import numpy as np\n", 41 | "\n", 42 | "import os\n", 43 | "\n", 44 | "import datasets\n", 45 | "import nshap\n", 46 | "\n", 47 | "from itertools import product\n", 48 | "\n", 49 | "import paperutil\n", 50 | "\n", 51 | "%load_ext autoreload\n", 52 | "%autoreload 2" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "### The different compute jobs" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']\n", 69 | "classifiers = ['rf', 'knn', 'gam', 'gbtree']\n", 70 | "examples = list(range(0, 100))\n", 71 | "\n", 72 | "all_jobs = list(product(data_sets, classifiers, examples))\n", 73 | "print(len(all_jobs), 'different compute jobs')" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "for data_set in data_sets:\n", 83 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(data_set)\n", 84 | " print(data_set, X_train.shape[0])" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "### The current job" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "job_id = aid\n", 101 | "dataset = all_jobs[job_id][0]\n", 102 | "classifier = all_jobs[job_id][1]\n", 103 | "example = all_jobs[job_id][2]\n", 104 | "random_seed = example\n", 105 | "\n", 106 | "print(job_id, dataset, classifier, example, random_seed)" 107 | ] 108 | }, 109 | { 110 | 
"cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### Create output dir structure, if it does not already exist" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "if not os.path.exists( f'../../results/n_shapley_values/{dataset}' ):\n", 123 | " os.mkdir( f'../../results/n_shapley_values/{dataset}' )\n", 124 | "if not os.path.exists( f'../../results/n_shapley_values/{dataset}/{classifier}' ):\n", 125 | " os.mkdir( f'../../results/n_shapley_values/{dataset}/{classifier}' )" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "### Load the dataset" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "is_classification = datasets.is_classification(dataset)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "metadata": {}, 156 | "source": [ 157 | "### Train the classifier" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "clf = paperutil.train_classifier(dataset, classifier)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "if is_classification:\n", 176 | " print( sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test)) )\n", 177 | "else:\n", 178 | " print( sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test)) )" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "### n-Shapley Values" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": { 192 | "tags": [] 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "i_datapoint = example\n", 197 | "froot = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}'\n", 198 | "for max_samples in [500, 5000]:\n", 199 | " num_samples = min(max_samples, X_train.shape[0])\n", 200 | " # the value function\n", 201 | " vfunc = nshap.vfunc.interventional_shap(clf.predict, X_train, num_samples=num_samples, random_state=0)\n", 202 | " fname = froot + f'_predict_{num_samples}.JSON'\n", 203 | " if is_classification:\n", 204 | " prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n", 205 | " vfunc = nshap.vfunc.interventional_shap(clf.predict_proba, X_train, num_samples=num_samples, random_state=0, target=prediction)\n", 206 | " fname = froot + f'_proba_{num_samples}.JSON'\n", 207 | " if classifier == 'gam':\n", 208 | " vfunc = nshap.vfunc.interventional_shap(clf.decision_function, X_train, num_samples=num_samples, random_state=0)\n", 209 | " fname = froot + f'_decision_{num_samples}.JSON'\n", 210 | " # compute and save n-Shapley Values\n", 211 | " if not os.path.exists(fname):\n", 212 | " n_shapley_values = nshap.n_shapley_values(X_test[i_datapoint, :].reshape((1,-1)), vfunc)\n", 213 | " n_shapley_values.save(fname)" 214 | ] 215 | } 216 | ], 217 | "metadata": { 218 | "kernelspec": { 219 | "display_name": "Python 3", 220 | "language": "python", 221 | "name": "python3" 222 | }, 223 | "language_info": { 224 
| "codemirror_mode": { 225 | "name": "ipython", 226 | "version": 3 227 | }, 228 | "file_extension": ".py", 229 | "mimetype": "text/x-python", 230 | "name": "python", 231 | "nbconvert_exporter": "python", 232 | "pygments_lexer": "ipython3", 233 | "version": "3.8.3" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 5 238 | } 239 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/datasets.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################################## 2 | # Load the different datasets. 3 | # 4 | # The features are scaled to have mean zero and unit variance. 5 | # 6 | # All functions return: X_train, X_test, Y_train, Y_test, feature_names 7 | ######################################################################################################################## 8 | 9 | import numpy as np 10 | 11 | import sklearn 12 | import sklearn.datasets 13 | from sklearn.model_selection import train_test_split 14 | from sklearn.preprocessing import StandardScaler 15 | 16 | import pandas as pd 17 | 18 | from folktables import ACSDataSource, ACSIncome, ACSTravelTime 19 | 20 | import os 21 | 22 | data_root_dir = "../../data/" 23 | 24 | 25 | def german_credit(seed=0): 26 | """ The german credit dataset 27 | """ 28 | feature_names = [ 29 | "checking account status", 30 | "Duration", 31 | "Credit history", 32 | "Purpose", 33 | "Credit amount", 34 | "Savings account/bonds", 35 | "Present employment since", 36 | "Installment rate in percentage of disposable income", 37 | "Personal status and sex", 38 | "Other debtors / guarantors", 39 | "Present residence since", 40 | "Property", 41 | "Age in years", 42 | "Other installment plans", 43 | "Housing", 44 | "Number of existing credits at this bank", 45 | "Job", 46 | " Number of people being liable to provide maintenance for", 47 | "Telephone", 48 | "foreign worker", 49 | ] 50 | columns = [*feature_names, "target"] 51 | 52 | data = pd.read_csv(os.path.join(data_root_dir, "german.data"), sep=" ", header=None) 53 | data.columns = columns 54 | Y = data["target"] - 1 55 | X = data 56 | X = X.drop("target", axis=1) 57 | cat_columns = X.select_dtypes(["object"]).columns 58 | X[cat_columns] = X[cat_columns].apply(lambda x: x.astype("category").cat.codes) 59 | 60 | # zero mean and unit variance for all features 61 | X = StandardScaler().fit_transform(X) 62 | 63 | # train-test split 64 | X_train, X_test, Y_train, Y_test = train_test_split( 65 | X, Y, train_size=0.8, random_state=seed 66 | ) 67 | 68 | return X_train, X_test, Y_train, Y_test, feature_names 69 | 70 | 71 | def iris(seed=0): 72 | """ The iris dataset, class 1 vs. the rest. 73 | """ 74 | # load the dataset 75 | iris = sklearn.datasets.load_iris() 76 | X = iris.data 77 | Y = iris.target 78 | 79 | # feature names 80 | feature_names = iris.feature_names 81 | 82 | # create a binary outcome 83 | Y = Y == 1 84 | 85 | # zero mean and unit variance for all features 86 | X = StandardScaler().fit_transform(X) 87 | 88 | # train-test split 89 | X_train, X_test, Y_train, Y_test = train_test_split( 90 | X, Y, train_size=0.8, random_state=seed 91 | ) 92 | 93 | return X_train, X_test, Y_train, Y_test, feature_names 94 | 95 | 96 | def california_housing(seed=0, classification=False): 97 | """ The california housing dataset. 
98 | """ 99 | # load the dataset 100 | housing = sklearn.datasets.fetch_california_housing() 101 | X = housing.data 102 | Y = housing.target 103 | 104 | # feature names 105 | feature_names = housing.feature_names 106 | 107 | # create a binary outcome 108 | if classification: 109 | Y = Y > np.median(Y) 110 | 111 | # zero mean and unit variance for all features 112 | X = StandardScaler().fit_transform(X) 113 | 114 | # train-test split 115 | X_train, X_test, Y_train, Y_test = train_test_split( 116 | X, Y, train_size=0.8, random_state=seed 117 | ) 118 | 119 | return X_train, X_test, Y_train, Y_test, feature_names 120 | 121 | 122 | def folktables_acs_income(seed=0, survey_year="2016", states=["CA"]): 123 | # (down-)load the dataset 124 | data_source = ACSDataSource( 125 | survey_year=survey_year, 126 | horizon="1-Year", 127 | survey="person", 128 | root_dir=data_root_dir, 129 | ) 130 | data = data_source.get_data(states=states, download=True) 131 | X, Y, _ = ACSIncome.df_to_numpy(data) 132 | 133 | # feature names 134 | feature_names = ACSIncome.features 135 | 136 | # zero mean and unit variance for all features 137 | X = StandardScaler().fit_transform(X) 138 | 139 | # train-test split 140 | X_train, X_test, Y_train, Y_test = train_test_split( 141 | X, Y, train_size=0.8, random_state=seed 142 | ) 143 | 144 | return X_train, X_test, Y_train, Y_test, feature_names 145 | 146 | 147 | def folktables_acs_travel_time(seed=0, survey_year="2016", states=["CA"]): 148 | # (down-)load the dataset 149 | data_source = ACSDataSource( 150 | survey_year=survey_year, 151 | horizon="1-Year", 152 | survey="person", 153 | root_dir=data_root_dir, 154 | ) 155 | data = data_source.get_data(states=states, download=True) 156 | X, Y, _ = ACSTravelTime.df_to_numpy(data) 157 | 158 | # feature names 159 | feature_names = ACSTravelTime.features 160 | 161 | # zero mean and unit variance for all features 162 | X = StandardScaler().fit_transform(X) 163 | 164 | # train-test split 165 | X_train, X_test, Y_train, Y_test = train_test_split( 166 | X, Y, train_size=0.8, random_state=seed 167 | ) 168 | 169 | return X_train, X_test, Y_train, Y_test, feature_names 170 | 171 | 172 | ######################################################################################################################## 173 | # Functions to ease access to all the different datasets 174 | ######################################################################################################################## 175 | 176 | dataset_dict = { 177 | "iris": iris, 178 | "folk_income": folktables_acs_income, 179 | "folk_travel": folktables_acs_travel_time, 180 | "housing": california_housing, 181 | "credit": german_credit, 182 | } 183 | 184 | 185 | def get_datasets(): 186 | """ Returns the names of the available datasets. 
187 | """ 188 | return dataset_dict 189 | 190 | 191 | def is_classification(dataset): 192 | if dataset == "housing": 193 | return False 194 | return True 195 | 196 | 197 | def load_dataset(dataset): 198 | if dataset == "folk_income": 199 | X_train, X_test, Y_train, Y_test, feature_names = folktables_acs_income(0) 200 | elif dataset == "folk_travel": 201 | X_train, X_test, Y_train, Y_test, feature_names = folktables_acs_travel_time(0) 202 | # subset the dataset to 10 features to ease computation 203 | feature_subset = [13, 14, 9, 0, 12, 15, 1, 3, 7, 11] 204 | feature_names = [feature_names[i] for i in feature_subset] 205 | X_train = X_train[:, feature_subset] 206 | X_test = X_test[:, feature_subset] 207 | elif dataset == "housing": 208 | X_train, X_test, Y_train, Y_test, feature_names = california_housing(0) 209 | elif dataset == "credit": 210 | X_train, X_test, Y_train, Y_test, feature_names = german_credit(0) 211 | # subset the dataset to 10 features to ease computation 212 | feature_subset = [0, 1, 2, 3, 4, 5, 6, 7, 14, 11] 213 | feature_names = [feature_names[i] for i in feature_subset] 214 | X_train = X_train[:, feature_subset] 215 | X_test = X_test[:, feature_subset] 216 | elif dataset == "iris": 217 | X_train, X_test, Y_train, Y_test, feature_names = iris(0) 218 | return X_train, X_test, Y_train, Y_test, feature_names 219 | 220 | 221 | def get_feature_names(dataset): 222 | """ Shortened for better plotting. 223 | """ 224 | if dataset == "folk_income": 225 | return folktables_acs_income()[4] 226 | elif dataset == "folk_travel": 227 | feature_names = folktables_acs_travel_time()[4] 228 | feature_names = [feature_names[i] for i in [13, 14, 9, 0, 12, 15, 1, 3, 7, 11]] 229 | feature_names[1] = "POWP" 230 | return feature_names 231 | elif dataset == "housing": 232 | return california_housing()[4] 233 | elif dataset == "credit": 234 | return [ 235 | "Account", 236 | "Duration", 237 | "History", 238 | "Purpose", 239 | "Amount", 240 | "Savings", 241 | "Employ", 242 | "Rate", 243 | "Housing", 244 | "Property", 245 | ] 246 | elif dataset == "iris": 247 | return ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"] 248 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/estimation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Estimating n-Shapley Values for the kNN classifier on the Folktables Travel data set" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import matplotlib.pyplot as plt\n", 17 | "import seaborn as sns\n", 18 | "\n", 19 | "sns.set_style(\"whitegrid\")\n", 20 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.1)\n", 21 | "\n", 22 | "import datasets\n", 23 | "import nshap" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### Load the data" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "_, X_test, _, _, feature_names = datasets.load_dataset('folk_travel')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### Load the pre-computed n-Shapley Values" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": 
[], 54 | "source": [ 55 | "n_shapley_values = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_500.JSON')\n", 56 | "n_shapley_values_5000 = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_5000.JSON')" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ "# value function that reads the precomputed evaluations from disk\n", 65 | "def vfunc(x, S):\n", 66 | "    S = tuple(S)\n", 67 | "    fname = f'../../results/n_shapley_values/folk_travel/knn/observation_0_proba_133549/v{S}.txt' \n", 68 | "    with open(fname, 'r') as f:\n", 69 | "        result = float( f.read() )\n", 70 | "    return result" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "n_shapley_values_133549 = nshap.n_shapley_values(X_test[0, :], vfunc)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "### Plots" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "for idx, v in enumerate([n_shapley_values, n_shapley_values_5000, n_shapley_values_133549]):\n", 96 | "    fig, ax = plt.subplots(1, 1, figsize=(7, 7.9))\n", 97 | "    v.plot(axis=ax, legend=False, feature_names=feature_names, rotation=60)\n", 98 | "    plt.ylim([-0.295, 0.29])\n", 99 | "    plt.title(f'Shapley-GAM, {[\"500\", \"5000\", \"133549\"][idx]} Samples')\n", 100 | "    plt.tight_layout()\n", 101 | "    plt.savefig(f'../../figures/knn_estimation_{idx}.pdf')\n", 102 | "    plt.show()" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "### Latex code for Table in Appendix" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "tags": [] 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "for S in nshap.powerset(list(range(10))):\n", 121 | "    if len(S) == 0:\n", 122 | "        continue\n", 123 | "    print(f'{S} & {n_shapley_values[S]:0.4f} & {n_shapley_values_5000[S]:0.4f} & {n_shapley_values_133549[S]:0.4f} \\\\\\\\')" 124 | ] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.8.3" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 5 148 | } 149 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/figures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generate the Figures in the Paper" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "\n", 18 | "import os\n", 19 | "import sklearn.metrics\n", 20 | "from interpret.glassbox import ExplainableBoostingClassifier\n", 21 | "\n", 22 | "from itertools import product\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import seaborn as sns\n", 26 | "\n", 27 | "import datasets\n", 28 | "import paperutil\n", 29 | "import nshap\n", 30 | "\n", 31 | "%load_ext
autoreload\n", 32 | "%autoreload 2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "### Load the computed n-Shapley Values" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']\n", 49 | "classifiers = ['gam', 'rf', 'gbtree', 'knn']\n", 50 | "methods = ['predict', 'proba', 'decision']" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "tags": [] 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "shapley_values = {}\n", 62 | "for dataset in data_sets:\n", 63 | " X_train, X_test, _, _, _ = datasets.load_dataset(dataset)\n", 64 | " shapley_values[dataset] = {}\n", 65 | " num_datapoints = min(5000, X_train.shape[0]) \n", 66 | " for classifier in classifiers:\n", 67 | " shapley_values[dataset][classifier] = {}\n", 68 | " for method in methods:\n", 69 | " if os.path.exists(f'../../results/n_shapley_values/{dataset}/{classifier}/observation_0_{method}_{num_datapoints}.JSON'):\n", 70 | " shapley_values[dataset][classifier][method] = []\n", 71 | " for i_datapoint in range(min(X_test.shape[0], 100)):\n", 72 | " fname = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_datapoints}.JSON'\n", 73 | " if os.path.exists(fname):\n", 74 | " n_shapley_values = nshap.load(fname)\n", 75 | " shapley_values[dataset][classifier][method].append(n_shapley_values)\n", 76 | " else:\n", 77 | " print(f'File {fname} not found.')" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "### Create directory structure" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "paths = ['../../figures/', '../../figures/partial_dependence/', '../../figures/shapley_gam/', '../../figures/n_shapley_values/']\n", 94 | "for p in paths:\n", 95 | " if not os.path.exists( p ):\n", 96 | " os.mkdir( p )" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Plot Settings" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# avoid type-3 fonts\n", 113 | "import matplotlib\n", 114 | "matplotlib.rcParams['pdf.fonttype'] = 42\n", 115 | "matplotlib.rcParams['ps.fonttype'] = 42" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "rotation = {'folk_income': 60, 'folk_travel': 60, 'housing': 60, 'credit': 60, 'iris': 0}" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "sns.set_style(\"whitegrid\")\n", 134 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "def asdjust_ylim(axlist):\n", 144 | " ymin = min([x.get_ylim()[0] for x in axlist])\n", 145 | " ymax = max([x.get_ylim()[1] for x in axlist])\n", 146 | " for ax in axlist:\n", 147 | " ax.set_ylim((ymin, ymax))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "### Plots of n-Shapley 
Values" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "sns.set_style(\"whitegrid\")\n", 164 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.1)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "tags": [] 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "for dataset in data_sets:\n", 176 | " feature_names = datasets.get_feature_names(dataset)\n", 177 | " for classifier in classifiers:\n", 178 | " for method in methods:\n", 179 | " # different methods for different classifiers\n", 180 | " if not method in shapley_values[dataset][classifier]: \n", 181 | " continue\n", 182 | " print(dataset, classifier, method)\n", 183 | " for i_datapoint in range(5):\n", 184 | " n_shapley_values = shapley_values[dataset][classifier][method][i_datapoint]\n", 185 | " fig, ax = plt.subplots(1, 4, figsize=(22.5, 6.75))\n", 186 | " ax0 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(1), axis=ax[0], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n", 187 | " ax0.set_ylabel('Feature Attribution')\n", 188 | " ax0.set_title('Shapley Values')\n", 189 | " ax1 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(2), axis=ax[1], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n", 190 | " ax1.set(yticklabels= []) \n", 191 | " ax1.set_title('Shapley Interaction Values')\n", 192 | " ax2 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(4), axis=ax[2], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n", 193 | " ax2.set(yticklabels= [])\n", 194 | " ax2.set_title('4-Shapley Values')\n", 195 | " ax3 = nshap.plot_n_shapley(n_shapley_values, axis=ax[3], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n", 196 | " ax3.set(yticklabels= []) \n", 197 | " ax3.set_title('Shapley-GAM')\n", 198 | " axes = [ax0, ax1, ax2, ax3]\n", 199 | " ymin = min([x.get_ylim()[0] for x in axes])\n", 200 | " ymax = max([x.get_ylim()[1] for x in axes])\n", 201 | " for x in axes:\n", 202 | " x.set_ylim((ymin, ymax))\n", 203 | " plt.tight_layout()\n", 204 | " plt.savefig(f'../../figures/n_shapley_values/{dataset}_{classifier}_{method}_{i_datapoint}.pdf')\n", 205 | " if i_datapoint == 0:\n", 206 | " plt.show()\n", 207 | " plt.close(fig)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "### Plots of n-Shapley Values in the Appendix" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": { 221 | "tags": [] 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "sns.set_style(\"whitegrid\")\n", 226 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.25)\n", 227 | "\n", 228 | "from itertools import product\n", 229 | "\n", 230 | "for dataset in data_sets:\n", 231 | " feature_names = datasets.get_feature_names(dataset)\n", 232 | " if dataset == 'iris':\n", 233 | " continue\n", 234 | " for classifier in classifiers:\n", 235 | " method = 'proba'\n", 236 | " if dataset == 'housing':\n", 237 | " method = 'predict'\n", 238 | " if classifier == 'gam':\n", 239 | " method = 'decision'\n", 240 | " print(dataset, classifier, method)\n", 241 | " for i_datapoint in range(1):\n", 242 | " n_shapley_values = shapley_values[dataset][classifier][method][i_datapoint]\n", 243 | " if dataset == 'housing': # 8 features\n", 244 | 
" ncols = 4\n", 245 | " fig, ax = plt.subplots(2, ncols, figsize=(26, 14.75))\n", 246 | " else: # 10 features\n", 247 | " ncols = 5\n", 248 | " fig, ax = plt.subplots(2, ncols, figsize=(32, 14.75))\n", 249 | " for i in range(2):\n", 250 | " for j in range(ncols):\n", 251 | " k = 1 + ncols*i + j\n", 252 | " nshap.plot_n_shapley(n_shapley_values.k_shapley_values(k), axis=ax[i, j], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n", 253 | " ax[i, j].set_title(f'{k}-Shapley Values')\n", 254 | " ax[0, 0].set_ylabel('Feature Attribution')\n", 255 | " ax[1, 0].set_ylabel('Feature Attribution')\n", 256 | " ax[0, 0].set_title('Shapley Values')\n", 257 | " ax[0, 1].set_title('Shapley Interaction Values')\n", 258 | " ax[1, ncols-1].set_title('Shapley-GAM')\n", 259 | " axes = [ax[i,j] for (i,j) in product(range(2), range(ncols))]\n", 260 | " asdjust_ylim(axes)\n", 261 | " for j in range(ncols):\n", 262 | " ax[0, j].set(xticklabels= [])\n", 263 | " for i in range(2):\n", 264 | " for j in range(1,ncols):\n", 265 | " ax[i, j].set(yticklabels= [])\n", 266 | " plt.tight_layout()\n", 267 | " plt.savefig(f'../../figures/n_shapley_values/apx_{dataset}_{classifier}_{method}_{i_datapoint}_full.pdf')\n", 268 | " if i_datapoint == 0:\n", 269 | " plt.show()\n", 270 | " plt.close(fig)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "### Example Visualizations" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": { 284 | "tags": [] 285 | }, 286 | "outputs": [], 287 | "source": [ 288 | "values = {(i,):0 for i in range(4)}\n", 289 | "values[(2,)] = 0.2\n", 290 | "values[(3,)] = -0.1\n", 291 | "n_shapley_values = nshap.nShapleyValues(values)\n", 292 | " \n", 293 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 294 | "nshap.plot_n_shapley(n_shapley_values.k_shapley_values(1), legend=False, axis=ax)\n", 295 | "plt.tight_layout()\n", 296 | "plt.savefig('../../figures/example1.pdf')\n", 297 | "print(values)\n", 298 | "plt.show()\n", 299 | "\n", 300 | "values[(1,2)] = 0.1 \n", 301 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 302 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n", 303 | "plt.tight_layout()\n", 304 | "plt.savefig('../../figures/example2.pdf')\n", 305 | "print(values)\n", 306 | "plt.show()\n", 307 | "\n", 308 | "values[(2,3)] = -0.1 \n", 309 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 310 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n", 311 | "plt.tight_layout()\n", 312 | "plt.savefig('../../figures/example3.pdf')\n", 313 | "print(values)\n", 314 | "plt.show()\n", 315 | "\n", 316 | "values[(1,2,3)] = 0.1\n", 317 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 318 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n", 319 | "plt.tight_layout()\n", 320 | "plt.savefig('../../figures/example4.pdf')\n", 321 | "print(values)\n", 322 | "plt.show()\n", 323 | "\n", 324 | "values[(0,1,2,3)] = -0.1 \n", 325 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 326 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n", 327 | "plt.tight_layout()\n", 328 | "plt.savefig('../../figures/example5.pdf')\n", 329 | "print(values)\n", 330 | "plt.show()" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "### The legend" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | 
"metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "import matplotlib.patches as mpatches\n", 347 | "from matplotlib.patches import Rectangle\n", 348 | "from matplotlib.transforms import Bbox\n", 349 | "\n", 350 | "fig, ax = plt.subplots(1, 1, figsize=(12, 1))\n", 351 | "\n", 352 | "# legend\n", 353 | "plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', \n", 354 | " '#17becf', 'black', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']\n", 355 | "\n", 356 | "color_patches = [mpatches.Patch(color=color) for color in plot_colors]\n", 357 | "lables = ['Main']\n", 358 | "lables.append('2nd order')\n", 359 | "lables.append('3rd order')\n", 360 | "for i in range(4, 10):\n", 361 | " lables.append(f'{i}th')\n", 362 | "lables.append('10th order')\n", 363 | "ax.legend(color_patches, lables, ncol=11, fontsize=30, handletextpad=0.5, handlelength=1, handleheight=1)\n", 364 | "plt.axis('off')\n", 365 | "plt.savefig(f'../../figures/legend.pdf', bbox_inches=Bbox([[-13.5, -0.2], [11, 0.9]]))\n", 366 | "plt.show()" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "import matplotlib.patches as mpatches\n", 376 | "from matplotlib.patches import Rectangle\n", 377 | "from matplotlib.transforms import Bbox\n", 378 | "\n", 379 | "fig, ax = plt.subplots(1, 1, figsize=(12, 1))\n", 380 | "\n", 381 | "# legend\n", 382 | "plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', \n", 383 | " '#17becf', 'black', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']\n", 384 | "\n", 385 | "color_patches = [mpatches.Patch(color=color) for color in plot_colors]\n", 386 | "lables = ['Main']\n", 387 | "lables.append('2nd order')\n", 388 | "lables.append('3rd order')\n", 389 | "for i in range(4, 8):\n", 390 | " lables.append(f'{i}th')\n", 391 | "ax.legend(color_patches, lables, ncol=11, fontsize=30, handletextpad=0.5, handlelength=1, handleheight=1)\n", 392 | "plt.axis('off')\n", 393 | "plt.savefig(f'../../figures/legend7.svg', bbox_inches=Bbox([[-13.5, -0.2], [11, 0.9]]))\n", 394 | "plt.show()" 395 | ] 396 | }, 397 | { 398 | "cell_type": "markdown", 399 | "metadata": {}, 400 | "source": [ 401 | "### Shapley-GAM Figure in the paper" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "sns.set_style(\"whitegrid\")\n", 411 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.9)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 421 | "ax = nshap.plot_n_shapley(shapley_values['credit']['gam']['decision'][0], axis=ax, legend=False, feature_names=datasets.get_feature_names('credit'), rotation=60)\n", 422 | "ax.set_ylabel('Feature Attribution')\n", 423 | "ax.set_title('Glassbox-GAM')\n", 424 | "plt.tight_layout()\n", 425 | "plt.savefig(f'../../figures/A.svg')\n", 426 | "plt.show()" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "fig, ax = plt.subplots(1, 1, figsize=(4.7, 6))\n", 436 | "ax = 
nshap.plot_n_shapley(shapley_values['housing']['gbtree']['predict'][0], axis=ax, legend=False, feature_names=datasets.get_feature_names('housing'), rotation=60)\n", 437 | "ax.set_title('Gradient Boosted Tree')\n", 438 | "print(ax.get_ylim())\n", 439 | "plt.tight_layout()\n", 440 | "plt.savefig(f'../../figures/B.svg')\n", 441 | "plt.show()" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n", 451 | "ax = nshap.plot_n_shapley(shapley_values['folk_travel']['knn']['proba'][1], axis=ax, legend=False, feature_names=datasets.get_feature_names('folk_travel'), rotation=60)\n", 452 | "ax.set_title('k-Nearest Neighbor')\n", 453 | "print(ax.get_ylim())\n", 454 | "ax.set_ylim((-0.32, 0.34))\n", 455 | "plt.tight_layout()\n", 456 | "plt.savefig(f'../../figures/C.svg')\n", 457 | "plt.show()" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": {}, 464 | "outputs": [], 465 | "source": [ 466 | "n_shapley_values = {}\n", 467 | "for S in nshap.powerset(range(8)):\n", 468 | " if len(S) == 0:\n", 469 | " continue\n", 470 | " elif len(S) == 8:\n", 471 | " n_shapley_values[S] = 1\n", 472 | " else:\n", 473 | " n_shapley_values[S] = 0 \n", 474 | "n_shapley_values = nshap.nShapleyValues(n_shapley_values)\n", 475 | "\n", 476 | "fig, ax = plt.subplots(1, 1, figsize=(4.7, 6))\n", 477 | "ax = nshap.plot_n_shapley(n_shapley_values, axis=ax, legend=False, rotation=60)\n", 478 | "ax.set_title('8d Checkerboard Function')\n", 479 | "print(ax.get_ylim())\n", 480 | "ax.set_ylim((-0.32, 0.34))\n", 481 | "plt.yticks([]) \n", 482 | "plt.tight_layout()\n", 483 | "plt.savefig(f'../../figures/D.svg')\n", 484 | "plt.show()" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "### Partial dependence plots" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "sns.set_style(\"whitegrid\")\n", 501 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.6)" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "metadata": { 508 | "tags": [] 509 | }, 510 | "outputs": [], 511 | "source": [ 512 | "for dataset in data_sets:\n", 513 | " feature_names = datasets.get_feature_names(dataset)\n", 514 | " X_train, X_test, _, _, _ = datasets.load_dataset(dataset)\n", 515 | " num_datapoints = min(5000, X_train.shape[0]) \n", 516 | " for classifier in classifiers:\n", 517 | " method = 'proba'\n", 518 | " if dataset == 'housing':\n", 519 | " method = 'predict'\n", 520 | " if classifier == 'gam':\n", 521 | " method = 'decision'\n", 522 | " print(dataset, classifier, method)\n", 523 | " clf = paperutil.train_classifier(dataset, classifier)\n", 524 | " for i_feature in range(len(feature_names)):\n", 525 | " # collect data\n", 526 | " x = []\n", 527 | " nsv_list = []\n", 528 | " for i_datapoint in range(100):\n", 529 | " fname = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_datapoints}.JSON'\n", 530 | " if os.path.exists(fname):\n", 531 | " n_shapley_values = nshap.load(fname)\n", 532 | " if method == 'proba': \n", 533 | " # we computed the shapley values for the probablity of the predicted class\n", 534 | " # but here we want to explain the probability of class 1, for all data points\n", 535 | " 
prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n", 536 | " x.append(X_test[i_datapoint, i_feature])\n", 537 | " if prediction == 0: \n", 538 | " n_shapley_values = nshap.nShapleyValues({k:-v for k,v in n_shapley_values.items()})\n", 539 | " nsv_list.append(n_shapley_values) \n", 540 | " else:\n", 541 | " x.append(X_test[i_datapoint, i_feature])\n", 542 | " nsv_list.append(n_shapley_values)\n", 543 | " else:\n", 544 | " print(f'File {fname} not found')\n", 545 | " # plot\n", 546 | " fig, ax = plt.subplots(1, 4, figsize=(30, 5.5)) # appendix\n", 547 | " #fig, ax = plt.subplots(1, 4, figsize=(30, 6)) # paper\n", 548 | " y = [n_shapley_values.k_shapley_values(1)[(i_feature,)] for n_shapley_values in nsv_list]\n", 549 | " ax0 = sns.scatterplot(x=x, y=y, ax=ax[0], s=150)\n", 550 | " ax0.set_ylabel('Feature Attribution')\n", 551 | " ax0.set_title('Shapley Values')\n", 552 | " y = [n_shapley_values.k_shapley_values(2)[(i_feature,)] for n_shapley_values in nsv_list]\n", 553 | " ax1 = sns.scatterplot(x=x, y=y, ax=ax[1], s=150)\n", 554 | " ax1.set(yticklabels= []) \n", 555 | " ax1.set_title('Shapley Interaction Values')\n", 556 | " y = [n_shapley_values.k_shapley_values(4)[(i_feature,)] for n_shapley_values in nsv_list]\n", 557 | " ax2 = sns.scatterplot(x=x, y=y, ax=ax[2], s=150)\n", 558 | " ax2.set(yticklabels= [])\n", 559 | " ax2.set_title('4-Shapley Values')\n", 560 | " y = [n_shapley_values[(i_feature,)] for n_shapley_values in nsv_list]\n", 561 | " ax3 = sns.scatterplot(x=x, y=y, ax=ax[3], s=150)\n", 562 | " ax3.set(yticklabels= []) \n", 563 | " ax3.set_title('Shapley-GAM')\n", 564 | " axes = [ax0, ax1, ax2, ax3]\n", 565 | " asdjust_ylim(axes)\n", 566 | " for ax in axes:\n", 567 | " ax.set_xlabel(f'Value of Feature {feature_names[i_feature]}')\n", 568 | " plt.tight_layout()\n", 569 | " plt.savefig(f'../../figures/partial_dependence/{dataset}_{classifier}_{i_feature}_{method}.pdf')\n", 570 | " plt.show()\n", 571 | " plt.close(fig)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | "### Recovery of GAM without interaction terms" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "from interpret.glassbox import ExplainableBoostingClassifier" 588 | ] 589 | }, 590 | { 591 | "cell_type": "code", 592 | "execution_count": null, 593 | "metadata": {}, 594 | "outputs": [], 595 | "source": [ 596 | "X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset('folk_travel')" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [ 605 | "ebm = ExplainableBoostingClassifier(feature_names=feature_names, interactions=0, random_state=0)\n", 606 | "ebm.fit(X_train[:50000], Y_train[:50000])\n", 607 | "(ebm.predict(X_test) == Y_test).mean()" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": null, 613 | "metadata": { 614 | "tags": [] 615 | }, 616 | "outputs": [], 617 | "source": [ 618 | "from interpret.provider import InlineProvider\n", 619 | "from interpret import set_visualize_provider\n", 620 | "\n", 621 | "set_visualize_provider(InlineProvider())\n", 622 | "\n", 623 | "from interpret import show\n", 624 | "\n", 625 | "ebm_global = ebm.explain_global()\n", 626 | "show(ebm_global)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "markdown", 631 | "metadata": {}, 632 | "source": [ 633 | "#### Compute KernelShap 
explanations" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": null, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "import shap \n", 643 | "\n", 644 | "X_train_summary = shap.kmeans(X_train, 25)\n", 645 | "kernel_explainer = shap.KernelExplainer(ebm.decision_function, X_train_summary)" 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": null, 651 | "metadata": { 652 | "tags": [] 653 | }, 654 | "outputs": [], 655 | "source": [ 656 | "kernel_shap_values = []\n", 657 | "for i in range(100):\n", 658 | " kernel_shap_values.append( kernel_explainer.shap_values(X_test[i, :]) )" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": null, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "for ifeature in [1]:\n", 668 | " print(f'--------------------------- {ifeature} ---------------------')\n", 669 | " # partial influence of ifeature in the gam\n", 670 | " x_ifeature = []\n", 671 | " gam_v = []\n", 672 | " for i in range(100):\n", 673 | " x_hat = np.zeros((1,10))\n", 674 | " x_hat[0, ifeature] = X_test[i, ifeature]\n", 675 | " x_ifeature.append(X_test[i, ifeature])\n", 676 | " gam_v.append(ebm.decision_function(x_hat)[0])\n", 677 | " fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", 678 | " sns.scatterplot(x_ifeature, gam_v-np.mean(gam_v))\n", 679 | " plt.title('Explainable Boosting')\n", 680 | " plt.xlabel(f'{feature_names[ifeature]}')\n", 681 | " plt.ylabel(f'Score')\n", 682 | " plt.tight_layout()\n", 683 | " plt.savefig(f'../../figures/recovery_ebm.pdf')\n", 684 | " plt.show()\n", 685 | "\n", 686 | " # shapley value of feature i\n", 687 | " shapley_v = []\n", 688 | " for i in range(100):\n", 689 | " shapley_v.append(kernel_shap_values[i][ifeature])\n", 690 | " fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n", 691 | " sns.scatterplot(x_ifeature, shapley_v-np.mean(shapley_v))\n", 692 | " plt.title('kernel SHAP')\n", 693 | " plt.xlabel(f'{feature_names[ifeature]}')\n", 694 | " plt.ylabel(f'Score')\n", 695 | " plt.tight_layout()\n", 696 | " plt.savefig(f'../../figures/recovery_kernel_shap.pdf')\n", 697 | " plt.show()" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": null, 703 | "metadata": {}, 704 | "outputs": [], 705 | "source": [ 706 | "sns.set_style(\"whitegrid\")\n", 707 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.4)\n", 708 | "\n", 709 | "img = plt.imread(\"../../figures/gam_curve.png\")\n", 710 | "fig, ax = plt.subplots(figsize=(7,4))\n", 711 | "ax.imshow(img, extent=[-2.1, 2.15, -0.99, 0.83], aspect='auto')\n", 712 | "sns.scatterplot(x_ifeature, shapley_v-np.mean(shapley_v), color='r', s=50)\n", 713 | "plt.xlabel(f'Value of Feature POWPUMA')\n", 714 | "plt.ylabel(f'Kernel SHAP Attribution')\n", 715 | "plt.tight_layout()\n", 716 | "plt.savefig(f'../../figures/recovery.svg')\n", 717 | "plt.show()" 718 | ] 719 | }, 720 | { 721 | "cell_type": "markdown", 722 | "metadata": {}, 723 | "source": [ 724 | "### Accuracy vs Average Degree of Variable Interaction in the Shapley-GAM" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": null, 730 | "metadata": {}, 731 | "outputs": [], 732 | "source": [ 733 | "accuracies = {}\n", 734 | "for dataset in data_sets:\n", 735 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)\n", 736 | " is_classification = datasets.is_classification(dataset)\n", 737 | " accuracies[dataset] = {}\n", 738 | " for classifier in 
classifiers:\n", 739 | " clf = paperutil.train_classifier(dataset, classifier)\n", 740 | " # accuracy / mse\n", 741 | " if is_classification:\n", 742 | " accuracies[dataset][classifier] = sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test))\n", 743 | " else:\n", 744 | " accuracies[dataset][classifier] = sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test))\n", 745 | " print(dataset, classifier, accuracies[dataset][classifier])" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [ 754 | "complexities = {}\n", 755 | "for dataset in data_sets:\n", 756 | " complexities[dataset] = {}\n", 757 | " for classifier in classifiers:\n", 758 | " method = 'proba'\n", 759 | " if dataset == 'housing':\n", 760 | " method = 'predict'\n", 761 | " if method == 'proba' and classifier == 'svm':\n", 762 | " continue\n", 763 | " if classifier == 'gam':\n", 764 | " method = 'decision'\n", 765 | " v = []\n", 766 | " for n_shapley_values in shapley_values[dataset][classifier][method]: \n", 767 | " degree_contributions = n_shapley_values.get_degree_contributions()\n", 768 | " integral = np.sum(degree_contributions*list(range(1, len(degree_contributions)+1))) / np.sum(degree_contributions)\n", 769 | " v.append(integral)\n", 770 | " complexities[dataset][classifier] = np.mean(v)\n", 771 | " print(dataset, classifier, np.mean(v))" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": null, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "sns.set_theme(font_scale=1.3)\n", 781 | "\n", 782 | "x = []\n", 783 | "y = []\n", 784 | "hue = []\n", 785 | "style = []\n", 786 | "for dataset in data_sets:\n", 787 | " if dataset == 'housing':\n", 788 | " continue\n", 789 | " for classifier in classifiers:\n", 790 | " x.append(complexities[dataset][classifier])\n", 791 | " y.append(accuracies[dataset][classifier])\n", 792 | " hue.append(dataset)\n", 793 | " style.append(classifier)\n", 794 | " \n", 795 | "fig, ax = plt.subplots(1, 1, figsize=(6, 3))\n", 796 | "ax = sns.scatterplot(x, y, hue=hue, style=style, s=200)\n", 797 | "\n", 798 | "handles, labels = ax.get_legend_handles_labels()\n", 799 | "plt.legend(ncol=1, \n", 800 | " bbox_to_anchor = (1., 1.03),\n", 801 | " handles=[handles[i] for i in [3, 0, 2, 1, 4, 6, 5, 7]], \n", 802 | " labels=['Iris', 'Income', 'Credit', 'Travel', 'GAM', 'GBTree', 'RF', 'KNN'],\n", 803 | " frameon=True,\n", 804 | " fontsize=12,\n", 805 | " markerscale = 1.8)\n", 806 | "\n", 807 | "plt.ylabel('Accuracy')\n", 808 | "plt.xlabel('Average Degree of Variable Interaction in Shapley-GAM')\n", 809 | "ax.set_xticks([1,2,3,4,5])\n", 810 | "ax.set_xlim([0.85,5])\n", 811 | "ax.set_ylim([0.6, 1.025])\n", 812 | "ax.set_yticks([0.6, 0.7, 0.8, 0.9, 1.0])\n", 813 | "plt.tight_layout()\n", 814 | "plt.savefig(f'../../figures/accuracy_interaction.svg')\n", 815 | "plt.show()" 816 | ] 817 | } 818 | ], 819 | "metadata": { 820 | "kernelspec": { 821 | "display_name": "Python 3", 822 | "language": "python", 823 | "name": "python3" 824 | }, 825 | "language_info": { 826 | "codemirror_mode": { 827 | "name": "ipython", 828 | "version": 3 829 | }, 830 | "file_extension": ".py", 831 | "mimetype": "text/x-python", 832 | "name": "python", 833 | "nbconvert_exporter": "python", 834 | "pygments_lexer": "ipython3", 835 | "version": "3.8.3" 836 | } 837 | }, 838 | "nbformat": 4, 839 | "nbformat_minor": 4 840 | } 841 | 
-------------------------------------------------------------------------------- /notebooks/replicate-paper/hyperparameters.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Hyperparameter tuning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import datasets\n", 17 | "import nshap\n", 18 | "\n", 19 | "import xgboost\n", 20 | "import sklearn\n", 21 | "from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor\n", 22 | "from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n", 23 | "from sklearn.svm import SVC, SVR\n", 24 | "from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor\n", 25 | "\n", 26 | "from sklearn.model_selection import GridSearchCV\n", 27 | "\n", 28 | "%load_ext autoreload\n", 29 | "%autoreload 2" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### kNN" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "param_grid = {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 10, 15, 20, 25, 30, 50, 80]}\n", 55 | " \n", 56 | "for data_set in data_sets:\n", 57 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(data_set)\n", 58 | " is_classification = datasets.is_classification(data_set)\n", 59 | " print(data_set, is_classification, feature_names)\n", 60 | " \n", 61 | " if is_classification:\n", 62 | " clf = GridSearchCV(KNeighborsClassifier(), param_grid) \n", 63 | " clf.fit(X_train, Y_train)\n", 64 | " else: \n", 65 | " clf = GridSearchCV(KNeighborsRegressor(), param_grid)\n", 66 | " clf.fit(X_train, Y_train)\n", 67 | " \n", 68 | " print(clf.best_params_)\n", 69 | " if is_classification:\n", 70 | " print( sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test)) )\n", 71 | " else:\n", 72 | " print( sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test)) )" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "knn_k = {'folk_income': 30, \n", 82 | " 'folk_travel': 80, \n", 83 | " 'housing': 10, \n", 84 | " 'diabetes': 15, \n", 85 | " 'credit': 25,\n", 86 | " 'iris': 1} " 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "knn_k" 96 | ] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.8.3" 116 | } 117 | }, 118 | "nbformat": 4, 119 | "nbformat_minor": 5 120 | } 121 | -------------------------------------------------------------------------------- /notebooks/replicate-paper/paperutil.py: -------------------------------------------------------------------------------- 1 
| import numpy as np 2 | 3 | import datasets 4 | 5 | import xgboost 6 | from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor 7 | from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor 8 | from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor 9 | 10 | def checkerboard_function(k, num_checkers=8): 11 | """ The k-dimensional checkerboard function, delivered within the unit cube. 12 | This is pure interaction. Enjoy! 13 | 14 | k: dimension of the checkerboard (k >= 2) 15 | num_checkers: number of checkers along an axis (num_checkers >= 2) 16 | """ 17 | def f_checkerboard(X): 18 | if X.ndim == 1: 19 | return np.sum([int(num_checkers * X[i]) for i in range(k)]) % 2 20 | # X.ndin == 2 21 | result = np.zeros(X.shape[0]) 22 | for i_point, x in enumerate(X): 23 | result[i_point] = np.sum([int(num_checkers * x[i]) for i in range(k)]) % 2 24 | return result 25 | return f_checkerboard 26 | 27 | def train_classifier(dataset, classifier): 28 | """ Train the different classifiers that we use in the paper. 29 | """ 30 | X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset) 31 | is_classification = datasets.is_classification(dataset) 32 | if classifier == 'gam': 33 | if is_classification: 34 | clf = ExplainableBoostingClassifier(feature_names=feature_names, interactions=0, random_state=0) 35 | clf.fit(X_train, Y_train) 36 | else: 37 | clf = ExplainableBoostingRegressor(feature_names=feature_names, interactions=0, random_state=0) 38 | clf.fit(X_train, Y_train) 39 | elif classifier == 'rf': 40 | if is_classification: 41 | clf = RandomForestClassifier(n_estimators=100, random_state=0) 42 | clf.fit(X_train, Y_train) 43 | else: 44 | clf = RandomForestRegressor(n_estimators=100, random_state=0) 45 | clf.fit(X_train, Y_train) 46 | elif classifier == 'gbtree': 47 | if is_classification: 48 | clf = xgboost.XGBClassifier(n_estimators=100, use_label_encoder=False, random_state=0) 49 | clf.fit(X_train, Y_train) 50 | else: 51 | clf = xgboost.XGBRegressor(n_estimators=100, use_label_encoder=False, random_state=0) 52 | clf.fit(X_train, Y_train) 53 | elif classifier == 'knn': 54 | # determined with cross-validation 55 | knn_k = {'folk_income': 30, 56 | 'folk_travel': 80, 57 | 'housing': 10, 58 | 'diabetes': 15, 59 | 'credit': 25, 60 | 'iris': 1} 61 | if is_classification: 62 | clf = KNeighborsClassifier(n_neighbors = knn_k[dataset]) 63 | clf.fit(X_train, Y_train) 64 | else: 65 | clf = KNeighborsRegressor(n_neighbors = knn_k[dataset]) 66 | clf.fit(X_train, Y_train) 67 | return clf 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | 3 | [build-system] 4 | requires = ["setuptools", "wheel"] 5 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="nshap", 8 | version="0.2.0", 9 | author="Sebastian Bordt", 10 | author_email="sbordt@posteo.de", 11 | description="Python package to compute n-Shapley Values.", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/tml-tuebingen/nshap", 15 | classifiers=[ 16 
| "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], 20 | package_dir={"": "src"}, 21 | packages=["nshap"], 22 | python_requires=">=3.6", 23 | install_requires=["numpy", "matplotlib", "seaborn",], 24 | ) 25 | -------------------------------------------------------------------------------- /src/nshap/InteractionIndex.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | 5 | import nshap 6 | 7 | ############################################################################################################################# 8 | # Unique string identifiers for all supported interaction indices 9 | ############################################################################################################################# 10 | 11 | SHAPLEY_VALUES = "Shapley Values" 12 | MOEBIUS_TRANSFORM = "Moebius Transform" 13 | SHAPLEY_INTERACTION = "Shapley Interaction" 14 | N_SHAPLEY_VALUES = "n-Shapley Values" 15 | SHAPLEY_TAYLOR = "Shapley Taylor" 16 | FAITH_SHAP = "Faith-Shap" 17 | BANZHAF = "Banzhaf" 18 | BANZHAF_INTERACTION = "Banzhaf Interaction" 19 | FAITH_BANZHAF = "Faith-Banzhaf" 20 | 21 | ALL_INDICES = [ 22 | SHAPLEY_VALUES, 23 | MOEBIUS_TRANSFORM, 24 | SHAPLEY_INTERACTION, 25 | N_SHAPLEY_VALUES, 26 | SHAPLEY_TAYLOR, 27 | FAITH_SHAP, 28 | BANZHAF, 29 | BANZHAF_INTERACTION, 30 | FAITH_BANZHAF 31 | ] 32 | 33 | ############################################################################################################################# 34 | # A single class for all interaction indices 35 | ############################################################################################################################# 36 | 37 | 38 | class InteractionIndex(collections.UserDict): 39 | """Class for different interaction indices (n-Shapley Values, Shapley Taylor, Faith-Shap). 40 | Interaction indices are a Python dict with added functionality. 41 | """ 42 | 43 | def __init__(self, index_type: str, values, n=None, d=None): 44 | """Initialize an interaction index from a dict of values. 45 | 46 | Args: 47 | type (str): The type of the interaction index. 48 | values (_type_): The underlying dict of values. 49 | n (_type_, optional): _description_. Defaults to None. 50 | d (_type_, optional): _description_. Defaults to None. 51 | """ 52 | super().__init__(values) 53 | assert index_type in ALL_INDICES, f"{index_type} is not a supported interaction index" 54 | self.index_type = index_type 55 | self.n = n 56 | if ( 57 | n is None 58 | ): # if n or d are not given as aruments, infer them from the values dict. 59 | self.n = max([len(x) for x in values.keys()]) 60 | self.d = d 61 | if d is None: 62 | self.d = len([x[0] for x in values.keys() if len(x) == 1]) 63 | 64 | def get_index_type(self): 65 | """Return the type of the interaction index (for example, "Shapley Taylor") 66 | 67 | Returns: 68 | str: See function description. 69 | """ 70 | return self.index_type 71 | 72 | def get_input_dimension(self): 73 | """Return the input dimension of the function for which we computed the interaction index ('d'). 74 | 75 | Returns: 76 | integer: See function description. 77 | """ 78 | return self.d 79 | 80 | def get_order(self): 81 | """Return the order of the interaction index ('n'). 82 | 83 | Returns: 84 | integer: See function description. 85 | """ 86 | return self.n 87 | 88 | def sum(self): 89 | """Sum all the terms that are invovled in the interaction index. 
90 | 
91 |         For many interaction indices, the result is equal to the value of the function that we attempt to explain, minus the value of the empty coalition.
92 | 
93 |         Returns:
94 |             float: The sum.
95 |         """
96 |         return np.sum([x for x in self.data.values()])
97 | 
98 |     def copy(self):
99 |         """Return a copy of the current object.
100 | 
101 |         Returns:
102 |             nshap.InteractionIndex: The copy.
103 |         """
104 |         return InteractionIndex(self.index_type, self.data.copy(), self.n, self.d)
105 | 
106 |     def save(self, fname):
107 |         """Save the interaction index to a JSON file.
108 | 
109 |         Args:
110 |             fname (str): Filename.
111 |         """
112 |         nshap.save(self, fname)
113 | 
114 |     def plot(self, *args, **kwargs):
115 |         """Generate a plot of the interaction index.
116 | 
117 |         This function simply calls nshap.plot_interaction_index.
118 | 
119 |         Returns:
120 |             The axis of the matplotlib plot.
121 |         """
122 |         return nshap.plot_interaction_index(self, *args, **kwargs)
123 | 
124 |     def shapley_values(self):
125 |         """Compute the original Shapley Values.
126 | 
127 |         Returns:
128 |             numpy.ndarray: Shapley Values. If you prefer an nshap.InteractionIndex, call k_shapley_values(1) instead.
129 |         """
130 |         assert (
131 |             self.index_type == N_SHAPLEY_VALUES or self.index_type == MOEBIUS_TRANSFORM
132 |         ), f"shapley_values only supports {N_SHAPLEY_VALUES} and {MOEBIUS_TRANSFORM}"
133 |         shapley_values = self.k_shapley_values(1)
134 |         shapley_values = np.array(list(shapley_values.values())).reshape((1, -1))
135 |         return shapley_values
136 | 
137 |     def k_shapley_values(self, k):
138 |         """Compute k-Shapley Values of lower order. Requires k <= n.
139 | 
140 |         Args:
141 |             k (int): The desired order.
142 | 
143 |         Returns:
144 |             nshap.InteractionIndex: k-Shapley Values.
145 |         """
146 |         assert (
147 |             self.index_type == N_SHAPLEY_VALUES or self.index_type == MOEBIUS_TRANSFORM
148 |         ), f"k_shapley_values only supports {N_SHAPLEY_VALUES} and {MOEBIUS_TRANSFORM}"
149 |         assert k <= self.n, "k_shapley_values requires k <= n"
150 |         result = self.copy()
151 |         # repeatedly reduce the order by one until we arrive at order k
152 |         while result.n > k:
153 |             result = result._n_shapley_values_one_dim_lower()
154 |         return result
155 | 
156 |     def _n_shapley_values_one_dim_lower(self):
157 |         """Compute the n-Shapley Values of order n-1 from those of order n.
158 | 
159 |         Returns:
160 |             nshap.InteractionIndex: The n-Shapley Values of order n-1.
161 |         """
162 |         result = {}
163 |         # consider all subsets S with 1 <= |S| <= n-1
164 |         for S in nshap.powerset(range(self.d)):
165 |             if (len(S) == 0) or (len(S) > self.n - 1):
166 |                 continue
167 |             # we have the n-normalized effect
168 |             S_effect = self.data.get(S, 0)
169 |             # go over all subsets T of length n that contain S
170 |             for T in nshap.powerset(range(self.d)):
171 |                 if (len(T) != self.n) or (not set(S).issubset(T)):
172 |                     continue
173 |                 # add the effect of T to S
174 |                 T_effect = self.data.get(T, 0)  # default to zero in case the dict is sparse
175 |                 # normalization
176 |                 S_effect = S_effect - (nshap.bernoulli_numbers[len(T) - len(S)]) * T_effect
177 |             # now we have the normalized effect
178 |             result[S] = S_effect
179 |         return InteractionIndex(N_SHAPLEY_VALUES, result)
180 | 
--------------------------------------------------------------------------------
/src/nshap/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.0"
2 | __author__ = "Sebastian Bordt"
3 | 
4 | from nshap.InteractionIndex import InteractionIndex
5 | from nshap.InteractionIndex import (
6 |     SHAPLEY_VALUES,
7 |     MOEBIUS_TRANSFORM,
8 |     SHAPLEY_INTERACTION,
9 |     N_SHAPLEY_VALUES,
10 |     SHAPLEY_TAYLOR,
11 |     FAITH_SHAP,
12 |     BANZHAF,
13 |     BANZHAF_INTERACTION,
14 |     FAITH_BANZHAF,
15 | )
16 | 
17 | from nshap.functions import (
18 |     shapley_interaction_index,
19 |     n_shapley_values,
20 |     moebius_transform,
21 |     shapley_values,
22 |     faith_shap,
23 |     shapley_taylor,
24 |     faith_banzhaf,
25 |     banzhaf_interaction_index,
26 |     bernoulli_numbers,
27 | )
28 | from nshap.plot import plot_interaction_index
29 | from nshap.util import allclose, save, load, powerset
30 | 
31 | from nshap.vfunc import memoized_vfunc
32 | 
33 | 
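With the public API re-exported above, the end-to-end workflow is short. The following sketch mirrors the calls in tests/tests.py, but on an invented toy model and data set (the model, data, and the num_samples choice are illustrative, not from the repository):

import numpy as np
import nshap

from sklearn.ensemble import GradientBoostingClassifier

# toy data and model (illustrative only)
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
Y = (X[:, 0] * X[:, 1] + X[:, 2] > 0).astype(int)
clf = GradientBoostingClassifier().fit(X, Y)

# value function for interventional SHAP, then n-Shapley Values at a single point
vfunc = nshap.vfunc.interventional_shap(clf.predict_proba, X, target=0, num_samples=200)
n_shapley_values = nshap.n_shapley_values(X[0, :], vfunc)

print(n_shapley_values.get_order())          # n == d == 4
print(n_shapley_values.k_shapley_values(2))  # reduce to order 2
print(n_shapley_values.shapley_values())     # ordinary Shapley Values, shape (1, 4)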
--------------------------------------------------------------------------------
/src/nshap/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | 
4 | import nshap
5 | 
6 | 
7 | #############################################################################################################################
8 | # Bernoulli numbers
9 | #############################################################################################################################
10 | 
11 | bernoulli_numbers = np.array(
12 |     [
13 |         1,
14 |         -1 / 2,
15 |         1 / 6,
16 |         0,
17 |         -1 / 30,
18 |         0,
19 |         1 / 42,
20 |         0,
21 |         -1 / 30,
22 |         0,
23 |         5 / 66,
24 |         0,
25 |         -691 / 2730,
26 |         0,
27 |         7 / 6,
28 |         0,
29 |         -3617 / 510,
30 |         0,
31 |         43867 / 798,
32 |         0,
33 |     ]
34 | )
35 | 
36 | 
37 | #############################################################################################################################
38 | # Input Validation (Used by all that follows)
39 | #############################################################################################################################
40 | 
41 | 
42 | def validate_inputs(x, v_func):
43 |     if x.ndim == 1:
44 |         x = x.reshape((1, -1))
45 |     assert (
46 |         x.shape[0] == 1
47 |     ), "The nshap package only accepts single data points as input."
48 |     dim = x.shape[1]
49 |     if not isinstance(v_func, nshap.memoized_vfunc):  # memoization
50 |         v_func = nshap.memoized_vfunc(v_func)
51 |     # for d>20, we would have to consider the numerics of the problem more carefully
52 |     assert dim <= 20, "The nshap package only supports d<=20."
53 |     return x, v_func, dim
54 | 
55 | 
56 | #############################################################################################################################
57 | # The Moebius Transform
58 | #############################################################################################################################
59 | 
60 | 
61 | def moebius_transform(x, v_func):
62 |     """Compute the Moebius Transform of the value function v_func at the data point x.
63 | 
64 |     Args:
65 |         x (numpy.ndarray): A data point.
66 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
67 | 
68 |     Returns:
69 |         nshap.InteractionIndex: The interaction index.
70 |     """
71 |     # validate input parameters
72 |     x, v_func, dim = validate_inputs(x, v_func)
73 |     # go over all subsets S of N with 1<=|S|<=d
74 |     result = {}
75 |     for S in nshap.powerset(set(range(dim))):
76 |         if len(S) == 0:
77 |             continue
78 |         summands = []
79 |         # go over all subsets T of S
80 |         for T in nshap.powerset(S):
81 |             summands.append(v_func(x, list(T)) * (-1) ** (len(S) - len(T)))
82 |         result[S] = np.sum(summands)
83 |     # return result
84 |     return nshap.InteractionIndex(nshap.MOEBIUS_TRANSFORM, result)
85 | 
86 | 
87 | #############################################################################################################################
88 | # Shapley Values
89 | #############################################################################################################################
90 | 
91 | 
92 | def shapley_values(x, v_func):
93 |     """Compute the original Shapley Values, according to the Shapley formula.
94 | 
95 |     Args:
96 |         x (numpy.ndarray): A data point.
97 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
98 | 
99 |     Returns:
100 |         nshap.InteractionIndex: The interaction index.
101 |     """
102 |     # validate input parameters
103 |     x, v_func, dim = validate_inputs(x, v_func)
104 |     # go over all features
105 |     result = {}
106 |     for i_feature in range(dim):
107 |         phi = 0
108 |         S = set(range(dim))
109 |         S.remove(i_feature)
110 |         for subset in nshap.powerset(S):
111 |             v = v_func(x, list(subset))
112 |             subset_i = list(subset)
113 |             subset_i.append(i_feature)
114 |             subset_i.sort()
115 |             v_i = v_func(x, subset_i)
116 |             phi = phi + math.factorial(len(subset)) * math.factorial(
117 |                 dim - len(subset) - 1
118 |             ) / math.factorial(dim) * (v_i - v)
119 |         result[(i_feature,)] = phi
120 |     # return result
121 |     return nshap.InteractionIndex(nshap.SHAPLEY_VALUES, result)
122 | 
123 | 
124 | #############################################################################################################################
125 | # Shapley Interaction Index
126 | #############################################################################################################################
127 | 
128 | 
129 | def shapley_interaction_index(x, v_func, n=-1):
130 |     """Compute the Shapley Interaction Index (https://link.springer.com/article/10.1007/s001820050125) at the data point x, for all S such that |S|<=n, given a coalition value function.
131 | 
132 |     Args:
133 |         x (numpy.ndarray): A data point.
134 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
135 |         n (int): Order up to which the Shapley Interaction Index should be computed.
136 | 
137 |     Returns:
138 |         nshap.InteractionIndex: The interaction index.
139 |     """
140 |     # validate input parameters
141 |     x, v_func, dim = validate_inputs(x, v_func)
142 |     if n == -1:
143 |         n = dim
144 |     # go over all subsets S of N with |S|<=n
145 |     result = {}
146 |     for S in nshap.powerset(set(range(dim))):
147 |         if len(S) > n:
148 |             continue
149 |         # go over all subsets T of N\S
150 |         phi = 0
151 |         N_minus_S = set(range(dim)) - set(S)
152 |         for T in nshap.powerset(N_minus_S):
153 |             # go over all subsets L of S
154 |             delta = 0
155 |             for L in nshap.powerset(S):
156 |                 coalition = list(L)
157 |                 coalition.extend(list(T))
158 |                 coalition.sort()
159 |                 delta = delta + np.power(-1, len(S) - len(L)) * v_func(x, coalition)
160 |             phi = phi + delta * math.factorial(len(T)) * math.factorial(
161 |                 dim - len(T) - len(S)
162 |             ) / math.factorial(dim - len(S) + 1)
163 |         result[S] = phi
164 |     # return result
165 |     return nshap.InteractionIndex(nshap.SHAPLEY_INTERACTION, result)
166 | 
167 | 
168 | #############################################################################################################################
169 | # n-Shapley Values
170 | #############################################################################################################################
171 | 
172 | 
173 | def n_shapley_values(x, v_func, n=-1):
174 |     """This function provides an exact computation of n-Shapley Values (https://arxiv.org/abs/2209.04012) via their definition.
175 | 
176 |     Args:
177 |         x (numpy.ndarray): A data point.
178 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
179 |         n (int, optional): Order of n-Shapley Values or -1 for n=d. Defaults to -1.
180 | 
181 |     Returns:
182 |         nshap.InteractionIndex: The interaction index.
183 |     """
184 |     # validate input parameters
185 |     x, v_func, dim = validate_inputs(x, v_func)
186 |     if n == -1:
187 |         n = dim
188 |     # first compute the shapley interaction index
189 |     shapley_int_idx = shapley_interaction_index(x, v_func, n)
190 |     # dict that maps subsets S to their effects
191 |     result = {}
192 |     # consider all subsets S with 1<=|S|<=n
193 |     for S in nshap.powerset(range(dim)):
194 |         if (len(S) == 0) or (len(S) > n):
195 |             continue
196 |         # obtain the unnormalized effect (that is, delta_S(x))
197 |         S_effect = shapley_int_idx[S]
198 |         # go over all subsets T of length k+1, ..., n that contain S
199 |         for T in nshap.powerset(range(dim)):
200 |             if (len(T) <= len(S)) or (len(T) > n) or (not set(S).issubset(T)):
201 |                 continue
202 |             # get the effect of T, and subtract it from the effect of S
203 |             T_effect = shapley_int_idx[T]
204 |             # normalization with bernoulli_numbers
205 |             S_effect = S_effect + (bernoulli_numbers[len(T) - len(S)]) * T_effect
206 |         # now we have the normalized effect
207 |         result[S] = S_effect
208 |     # return result
209 |     return nshap.InteractionIndex(nshap.N_SHAPLEY_VALUES, result)
210 | 
211 | 
212 | #############################################################################################################################
213 | # Shapley Taylor Interaction Index
214 | #############################################################################################################################
215 | 
216 | 
217 | def shapley_taylor(x, v_func, n=-1):
218 |     """Compute the Shapley Taylor Interaction Index (https://arxiv.org/abs/1902.05622) of the value function v_func at the data point x.
219 | 
220 |     Args:
221 |         x (numpy.ndarray): A data point.
222 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
223 |         n (int, optional): Order of the Shapley Taylor Interaction Index or -1 for n=d. Defaults to -1.
224 | 
225 |     Returns:
226 |         nshap.InteractionIndex: The interaction index.
227 |     """
228 |     # validate input parameters
229 |     x, v_func, dim = validate_inputs(x, v_func)
230 |     if n == -1:
231 |         n = dim
232 |     # we first compute the moebius transform
233 |     moebius = moebius_transform(x, v_func)
234 |     # then compute the Shapley Taylor Interaction Index
235 |     result = {}
236 |     # consider all subsets S with 1<=|S|<=n
237 |     for S in nshap.powerset(range(dim)):
238 |         if (len(S) == 0) or (len(S) > n):
239 |             continue
240 |         result[S] = moebius[S]
241 |         # for |S|=n, average the higher-order effects
242 |         if len(S) == n:
243 |             # go over all larger subsets of [d] that contain S
244 |             for T in nshap.powerset(range(dim)):
245 |                 if (len(T) <= len(S)) or (not set(S).issubset(T)):
246 |                     continue
247 |                 result[S] += moebius[T] / math.comb(len(T), len(S))
248 |     # return result
249 |     return nshap.InteractionIndex(nshap.SHAPLEY_TAYLOR, result)
250 | 
251 | 
252 | #############################################################################################################################
253 | # Faith-Shap Interaction Index
254 | #############################################################################################################################
255 | 
256 | 
257 | def faith_shap(x, v_func, n=-1):
258 |     """Compute the Faith-Shap Interaction Index (https://arxiv.org/abs/2203.00870) of the value function v_func at the data point x.
259 | 
260 |     Args:
261 |         x (numpy.ndarray): A data point.
262 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
263 |         n (int, optional): Order of the Interaction Index or -1 for n=d. Defaults to -1.
264 | 
265 |     Returns:
266 |         nshap.InteractionIndex: The interaction index.
267 |     """
268 |     # validate input parameters
269 |     x, v_func, dim = validate_inputs(x, v_func)
270 |     if n == -1:
271 |         n = dim
272 |     # we first compute the moebius transform
273 |     moebius = moebius_transform(x, v_func)
274 |     # then compute the Faith-Shap Interaction Index
275 |     result = {}
276 |     # consider all subsets S with 1<=|S|<=n
277 |     for S in nshap.powerset(range(dim)):
278 |         if (len(S) == 0) or (len(S) > n):
279 |             continue
280 |         result[S] = moebius[S]
281 |         # go over all subsets T of [d] with |T|>n that contain S
282 |         for T in nshap.powerset(range(dim)):
283 |             if (len(T) <= n) or (not set(S).issubset(T)):
284 |                 continue
285 |             # compare Theorem 19 in the Faith-Shap paper. In our notation, l=n.
286 |             result[S] += (
287 |                 (-1) ** (n - len(S))
288 |                 * len(S)
289 |                 / (n + len(S))
290 |                 * math.comb(n, len(S))
291 |                 * math.comb(len(T) - 1, n)
292 |                 / math.comb(len(T) + n - 1, n + len(S))
293 |                 * moebius[T]
294 |             )
295 |     # return result
296 |     return nshap.InteractionIndex(nshap.FAITH_SHAP, result)
297 | 
298 | 
299 | #############################################################################################################################
300 | # Banzhaf Values
301 | #############################################################################################################################
302 | 
303 | 
304 | #############################################################################################################################
305 | # Banzhaf Interaction Index
306 | #############################################################################################################################
307 | 
308 | 
309 | def banzhaf_interaction_index(x, v_func, n=-1):
310 |     """Compute the Banzhaf Interaction Index (https://link.springer.com/article/10.1007/s001820050125) at the data point x, for all S such that |S|<=n, given a coalition value function.
311 | 
312 |     Args:
313 |         x (numpy.ndarray): A data point.
314 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
315 |         n (int): Order up to which the Banzhaf Interaction Index should be computed.
316 | 
317 |     Returns:
318 |         nshap.InteractionIndex: The interaction index.
319 |     """
320 |     # validate input parameters
321 |     x, v_func, dim = validate_inputs(x, v_func)
322 |     if n == -1:
323 |         n = dim
324 |     # go over all subsets S of N with |S|<=n
325 |     result = {}
326 |     for S in nshap.powerset(set(range(dim))):
327 |         if len(S) > n:
328 |             continue
329 |         # go over all subsets T of N\S
330 |         phi = 0
331 |         N_minus_S = set(range(dim)) - set(S)
332 |         for T in nshap.powerset(N_minus_S):
333 |             # go over all subsets L of S
334 |             delta = 0
335 |             for L in nshap.powerset(S):
336 |                 coalition = list(L)
337 |                 coalition.extend(list(T))
338 |                 coalition.sort()
339 |                 delta = delta + np.power(-1, len(S) - len(L)) * v_func(x, coalition)
340 |             phi = phi + delta * (1 / np.power(2, dim - len(S)))
341 |         result[S] = phi
342 |     # return result
343 |     return nshap.InteractionIndex(nshap.BANZHAF_INTERACTION, result)
344 | 
345 | 
346 | #############################################################################################################################
347 | # Faith-Banzhaf Interaction Index
348 | #############################################################################################################################
349 | 
350 | 
351 | def faith_banzhaf(x, v_func, n=-1):
352 |     """Compute the Faith-Banzhaf Interaction Index (https://arxiv.org/abs/2203.00870) of the value function v_func at the data point x.
353 | 
354 |     Args:
355 |         x (numpy.ndarray): A data point.
356 |         v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
357 |         n (int, optional): Order of the Interaction Index or -1 for n=d. Defaults to -1.
358 | 
359 |     Returns:
360 |         nshap.InteractionIndex: The interaction index.
361 |     """
362 |     # validate input parameters
363 |     x, v_func, dim = validate_inputs(x, v_func)
364 |     if n == -1:
365 |         n = dim
366 |     # we first compute the moebius transform
367 |     moebius = moebius_transform(x, v_func)
368 |     # then compute the Faith-Banzhaf Interaction Index
369 |     result = {}
370 |     # consider all subsets S with 1<=|S|<=n
371 |     for S in nshap.powerset(range(dim)):
372 |         if (len(S) == 0) or (len(S) > n):
373 |             continue
374 |         result[S] = moebius[S]
375 |         # go over all subsets T of [d] with |T|>n that contain S
376 |         for T in nshap.powerset(range(dim)):
377 |             if (len(T) <= n) or (not set(S).issubset(T)):
378 |                 continue
379 |             # compare Theorem 17 in the Faith-Shap paper. In our notation, l=n.
380 |             result[S] += (
381 |                 (-1) ** (n - len(S))
382 |                 * np.power(0.5, len(T) - len(S))
383 |                 * math.comb(len(T) - len(S) - 1, n - len(S))
384 |                 * moebius[T]
385 |             )
386 |     # return result
387 |     return nshap.InteractionIndex(nshap.FAITH_BANZHAF, result)
388 | 
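The sum() method of InteractionIndex claims that many of these indices add up to v(N) - v(empty set). The following sketch checks this for three of the functions above on a hand-made value function (the value function and the point x are invented for illustration):

import numpy as np
import nshap

# a hand-made value function on d = 3 features
x = np.array([1.0, 2.0, 3.0])

def v_func(x, S):
    # value of a coalition: sum of its features, plus a pairwise interaction
    x = x.flatten()
    v = np.sum([x[i] for i in S])
    if 0 in S and 1 in S:
        v += 5.0
    return v

moebius = nshap.moebius_transform(x, v_func)
n_shap = nshap.n_shapley_values(x, v_func)
taylor = nshap.shapley_taylor(x, v_func)

# all three indices should sum to v(N) - v(empty set) = 11.0
for index in (moebius, n_shap, taylor):
    print(index.get_index_type(), index.sum())

# the Moebius transform recovers the pairwise interaction exactly
print(moebius[(0, 1)])  # 5.0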
--------------------------------------------------------------------------------
/src/nshap/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import matplotlib.patches as mpatches
6 | 
7 | import seaborn as sns
8 | 
9 | import nshap
10 | 
11 | # avoid type-3 fonts
12 | matplotlib.rcParams["pdf.fonttype"] = 42
13 | matplotlib.rcParams["ps.fonttype"] = 42
14 | 
15 | #############################################################################################################################
16 | # Plots
17 | #############################################################################################################################
18 | 
19 | plot_colors = [
20 |     "#1f77b4",
21 |     "#ff7f0e",
22 |     "#2ca02c",
23 |     "#d62728",
24 |     "#9467bd",
25 |     "#8c564b",
26 |     "#e377c2",
27 |     "#7f7f7f",
28 |     "#bcbd22",
29 |     "#17becf",
30 |     "black",
31 |     "#1f77b4",
32 |     "#ff7f0e",
33 |     "#2ca02c",
34 |     "#d62728",
35 |     "#9467bd",
36 |     "#8c564b",
37 |     "#e377c2",
38 |     "#7f7f7f",
39 |     "#bcbd22",
40 |     "#17becf",
41 | ]
42 | 
43 | 
44 | def plot_interaction_index(
45 |     I: nshap.InteractionIndex,
46 |     max_degree=None,
47 |     axis=None,
48 |     feature_names=None,
49 |     rotation=70,
50 |     legend=True,
51 |     fig_kwargs={"figsize": (6, 6)},
52 |     barwidth=0.5,
53 | ):
54 |     """Generate the plots in the paper.
55 | 
56 |     Args:
57 |         I (nshap.InteractionIndex): The interaction index that we want to plot.
58 |         max_degree (int, optional): Plot all effects of order max_degree and higher with a single color. Defaults to None.
59 |         axis (optional): Matplotlib axis on which to plot. Defaults to None.
60 |         feature_names (list, optional): Used to label the x-axis. Defaults to None.
61 |         rotation (int, optional): Rotation for x-axis labels. Defaults to 70.
62 |         legend (bool, optional): Plot legend. Defaults to True.
63 |         fig_kwargs (dict, optional): fig_kwargs, handed down to the matplotlib figure. Defaults to {"figsize": (6, 6)}.
64 |         barwidth (float, optional): Width of the bars. Defaults to 0.5.
65 | 
66 |     Returns:
67 |         Matplotlib axis: The plot axis.
68 |     """
69 |     if max_degree == 1:
70 |         I = I.k_shapley_values(1)  # reduce to an order-1 interaction index
71 |     num_features = I.d
72 |     vmax, vmin = 0, 0
73 |     ax = axis
74 |     if axis is None:
75 |         _, ax = plt.subplots(**fig_kwargs)
76 |     if max_degree is None or max_degree >= I.n:
77 |         max_degree = I.n
78 |     ax.axhline(y=0, color="black", linestyle="-")  # line at 0
79 |     for i_feature in range(num_features):
80 |         bmin, bmax = 0, 0
81 |         v = I[(i_feature,)]
82 |         ax.bar(
83 |             x=i_feature,
84 |             height=v,
85 |             width=barwidth,
86 |             bottom=0,
87 |             align="center",
88 |             label=f"Feature {i_feature}",
89 |             color=plot_colors[0],
90 |         )
91 |         bmin = min(bmin, v)
92 |         bmax = max(bmax, v)
93 |         # higher-order effects, up to max_degree
94 |         for n_k in range(2, I.n + 1):
95 |             v_pos = np.sum(
96 |                 [
97 |                     I[k] / len(k)
98 |                     for k in I.data.keys()
99 |                     if (len(k) == n_k and i_feature in k and I[k] > 0)
100 |                 ]
101 |             )
102 |             v_neg = np.sum(
103 |                 [
104 |                     I[k] / len(k)
105 |                     for k in I.data.keys()
106 |                     if (len(k) == n_k and i_feature in k and I[k] < 0)
107 |                 ]
108 |             )
109 |             # 'max_degree or higher'
110 |             if n_k == max_degree:
111 |                 v_pos = np.sum(
112 |                     [
113 |                         I[k] / len(k)
114 |                         for k in I.data.keys()
115 |                         if (len(k) >= n_k and i_feature in k and I[k] > 0)
116 |                     ]
117 |                 )
118 |                 v_neg = np.sum(
119 |                     [
120 |                         I[k] / len(k)
121 |                         for k in I.data.keys()
122 |                         if (len(k) >= n_k and i_feature in k and I[k] < 0)
123 |                     ]
124 |                 )
125 |             if v_pos > 0:
126 |                 ax.bar(
127 |                     x=i_feature,
128 |                     height=v_pos,
129 |                     width=barwidth,
130 |                     bottom=bmax,
131 |                     align="center",
132 |                     color=plot_colors[n_k - 1],
133 |                 )
134 |                 bmax = bmax + v_pos
135 |             if v_neg < 0:
136 |                 ax.bar(
137 |                     x=i_feature,
138 |                     height=v_neg,
139 |                     width=barwidth,
140 |                     bottom=bmin,
141 |                     align="center",
142 |                     color=plot_colors[n_k - 1],
143 |                 )
144 |                 bmin = bmin + v_neg
145 |             if n_k == max_degree:  # no higher orders
146 |                 break
147 |         vmin = min(vmin, bmin)
148 |         vmax = max(vmax, bmax)
149 |     # axes
150 |     if feature_names is None:
151 |         feature_names = [f"Feature {i+1}" for i in range(num_features)]
152 |     ax.set_ylim([1.1 * vmin, 1.1 * vmax])
153 |     ax.set_xticks(np.arange(num_features))
154 |     ax.set_xticklabels(feature_names, rotation=rotation)
155 |     # legend with custom labels
156 |     color_patches = [mpatches.Patch(color=color) for color in plot_colors]
157 |     labels = ["Main"]
158 |     if max_degree > 1:
159 |         labels.append("2nd order")
160 |     if max_degree > 2:
161 |         labels.append("3rd order")
162 |     if max_degree > 3:
163 |         for i_degree in range(4, I.n + 1):
164 |             if i_degree == max_degree and max_degree < I.n:
165 |                 labels.append(f"{i_degree}-{I.n}th order")
166 |                 break
167 |             else:
168 |                 if i_degree == I.n:
169 |                     labels.append(f"{i_degree}th order")
170 |                 else:
171 |                     labels.append(f"{i_degree}th")
172 |     ax.legend(
173 |         color_patches,
174 |         labels,
175 |         bbox_to_anchor=(1.02, 1),
176 |         loc="upper left",
177 |         borderaxespad=0,
178 |         handletextpad=0.5,
179 |         handlelength=1,
180 |         handleheight=1,
181 |     )
182 |     ax.get_legend().set_visible(legend)
183 |     ax.set_title(I.index_type)
184 |     if axis is None:
185 |         plt.show()
186 |     return ax
187 | 
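Continuing the earlier sketch, the computed index can be drawn with the plot() method, which forwards to plot_interaction_index. The feature names below are made up for illustration:

import matplotlib.pyplot as plt

# 'n_shapley_values' as computed in the earlier sketch
fig, ax = plt.subplots(figsize=(6, 6))
n_shapley_values.plot(
    axis=ax,
    feature_names=["age", "income", "education", "hours"],  # hypothetical names
    max_degree=3,  # fold all effects of order 3 and higher into one color
    rotation=45,
)
plt.tight_layout()
plt.show()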
--------------------------------------------------------------------------------
/src/nshap/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from itertools import chain, combinations
4 | 
5 | import json
6 | import re
7 | 
8 | import nshap
9 | 
10 | 
11 | def allclose(dict1, dict2, rtol=1e-05, atol=1e-08):
12 |     """Compare whether two dicts of interaction index values are close according to numpy.allclose.
13 | 
14 |     Useful for testing purposes.
15 | 
16 |     Args:
17 |         dict1 (dict): The first dict.
18 |         dict2 (dict): The second dict.
19 |         rtol (float, optional): passed to numpy.allclose. Defaults to 1e-05.
20 |         atol (float, optional): passed to numpy.allclose. Defaults to 1e-08.
21 | 
22 |     Returns:
23 |         bool: Result of numpy.allclose.
24 |     """
25 |     if dict1.keys() != dict2.keys():  # both dictionaries need to have the same keys
26 |         return False
27 |     for key in dict1.keys():
28 |         if not np.allclose(dict1[key], dict2[key], rtol=rtol, atol=atol):
29 |             return False
30 |     return True
31 | 
32 | 
33 | def save(values, fname):
34 |     """Save an interaction index to a JSON file.
35 | 
36 |     Args:
37 |         values (nshap.InteractionIndex): The interaction index.
38 |         fname (str): Filename.
39 |     """
40 |     json_dict = {
41 |         str(k): v for k, v in values.data.items()
42 |     }  # convert the integer tuples to strings
43 |     json_dict["index_type"] = values.index_type
44 |     json_dict["n"] = values.n
45 |     json_dict["d"] = values.d
46 | 
47 |     with open(fname, "w+") as fp:
48 |         fp.write(json.dumps(json_dict, indent=2))
49 | 
50 | 
51 | def to_int_tuple(str_tuple):
52 |     """Convert string representations of integer tuples back to python integer tuples.
53 | 
54 |     This utility function is used to load interaction indices from JSON.
55 | 
56 |     Args:
57 |         str_tuple (str): String representation of an integer tuple, for example "(1,2,3)".
58 | 
59 |     Returns:
60 |         tuple: Tuple of integers, for example (1,2,3).
61 |     """
62 |     start = str_tuple.find("(") + 1
63 |     first_comma = str_tuple.find(",")
64 |     end = str_tuple.rfind(")")
65 |     # is the string of the form "(1,)", i.e. with a trailing comma?
66 |     if len(re.findall("[0-9]", str_tuple[first_comma:])) == 0:
67 |         end = first_comma
68 |     str_tuple = str_tuple[start:end]
69 |     return tuple(map(int, str_tuple.split(",")))
70 | 
71 | 
72 | def read_remove(d, key):
73 |     """Read and remove a key from a dict d."""
74 |     r = d[key]
75 |     del d[key]
76 |     return r
77 | 
78 | 
79 | def load(fname):
80 |     """Load an interaction index from a JSON file.
81 | 
82 |     Args:
83 |         fname (str): Filename.
84 | 
85 |     Returns:
86 |         nshap.InteractionIndex: The loaded interaction index.
87 |     """
88 |     with open(fname, "r") as fp:
89 |         str_dump = json.load(fp)
90 | 
91 |     index_type = read_remove(str_dump, "index_type")
92 |     n = read_remove(str_dump, "n")
93 |     d = read_remove(str_dump, "d")
94 | 
95 |     # convert the remaining string tuples to integer tuples
96 |     python_dict = {to_int_tuple(k): v for k, v in str_dump.items()}
97 |     return nshap.InteractionIndex(index_type, python_dict, n, d)
98 | 
99 | 
100 | def powerset(iterable):
101 |     "The powerset function from https://docs.python.org/3/library/itertools.html#itertools-recipes."
102 |     s = list(iterable)
103 |     return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
104 | 
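The save/load pair above round-trips an interaction index through JSON. A short sketch, reusing n_shapley_values from the earlier example:

import nshap

n_shapley_values.save("n_shapley_values.json")  # equivalent to nshap.save(...)
reloaded = nshap.load("n_shapley_values.json")

assert reloaded.get_index_type() == n_shapley_values.get_index_type()
assert nshap.allclose(n_shapley_values, reloaded)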
--------------------------------------------------------------------------------
/src/nshap/vfunc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import functools
4 | 
5 | import numbers
6 | 
7 | #############################################################################################################################
8 | # Interventional SHAP
9 | #############################################################################################################################
10 | 
11 | 
12 | def interventional_shap(
13 |     f, X, target=None, num_samples=1000, random_state=None, memoized=True
14 | ):
15 |     """Approximate the value function of interventional SHAP.
16 | 
17 |     In order to compute n-Shapley Values, a data set is sampled once and then fixed for all evaluations of the value function.
18 | 
19 |     Args:
20 |         f (function): The function to be explained. Will be called as f(x) where x has shape (1,d).
21 |         X (numpy.ndarray): Sample from the data distribution.
22 |         target (int, optional): Target class. Required if the output of f(x) is multi-dimensional. Defaults to None.
23 |         num_samples (int, optional): The number of samples that should be drawn from X in order to estimate the value function. Defaults to 1000.
24 |         random_state (optional): Random state that is passed to np.random.default_rng. Used for reproducibility. Defaults to None.
25 |         memoized (bool, optional): Whether the returned value function should be memoized. Defaults to True.
26 | 
27 |     Returns:
28 |         function: The value function.
29 |     """
30 |     # sample background data set
31 |     rng = np.random.default_rng(random_state)
32 |     if num_samples < X.shape[0]:
33 |         indices = rng.integers(low=0, high=X.shape[0], size=num_samples)
34 |         X = X[indices, :]
35 |     # the function
36 |     def fn(x, subset):
37 |         subset = list(subset)
38 |         x = x.flatten()
39 |         values = []
40 |         for idx in range(X.shape[0]):
41 |             x_sample = X[idx, :].copy()  # copy here is important, otherwise we modify X
42 |             x_sample[subset] = x[subset]
43 |             v = f(x_sample.reshape((1, -1)))
44 |             # handle different return types
45 |             if isinstance(v, numbers.Number):
46 |                 values.append(v)
47 |             elif v.ndim == 1:
48 |                 if target is None:
49 |                     values.append(v[0])
50 |                 else:
51 |                     values.append(v[target])
52 |             else:
53 |                 assert (
54 |                     target is not None
55 |                 ), "f returns a multi-dimensional array, but target is not specified"
56 |                 values.append(v[0, target])
57 |         return np.mean(values)
58 | 
59 |     if memoized:  # memoization
60 |         fn = memoized_vfunc(fn)
61 |     return fn
62 | 
63 | 
64 | #############################################################################################################################
65 | # Memoization for value functions
66 | #############################################################################################################################
67 | 
68 | 
69 | class memoized_vfunc(object):
70 |     """Decorator to memoize a vfunc. Handles hashability of the numpy.ndarray and the list.
71 | 
72 |     The hash depends only on the values of x that are in the subset of coordinates.
73 | 
74 |     This function is able to handle parameters x of shape (d,) and (1,d).
75 | 
76 |     Memoization helps to avoid duplicate evaluations of the value function, even if we perform
77 |     computations that involve the same term multiple times.
78 | 
79 |     https://wiki.python.org/moin/PythonDecoratorLibrary
80 |     """
81 | 
82 |     def __init__(self, func):
83 |         """Wrap the given value function.
84 | 
85 |         Args:
86 |             func (function): The value function to be memoized.
87 |         """
88 |         self.func = func
89 |         self.cache = {}
90 | 
91 |     def __call__(self, *args):
92 |         x = args[0]
93 |         if x.ndim == 2:
94 |             assert (
95 |                 x.shape[0] == 1
96 |             ), "Parameter x of value function has to be of shape (d,) or (1,d)."
97 |             x = x[0]
98 |         subset = tuple(args[1])
99 |         x = tuple((x[i] for i in subset))
100 |         hashable_args = (x, subset)
101 |         if hashable_args in self.cache:  # in cache
102 |             return self.cache[hashable_args]
103 |         else:
104 |             value = self.func(*args)
105 |             self.cache[hashable_args] = value
106 |             return value
107 | 
108 |     def __repr__(self):
109 |         """Return the function's docstring."""
110 |         return f"Memoized value function: {self.func.__doc__}"
111 | 
112 |     def __get__(self, obj, objtype):
113 |         """Support instance methods."""
114 |         return functools.partial(self.__call__, obj)
115 | 
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | from folktables import ACSDataSource, ACSIncome
2 | 
3 | from sklearn.preprocessing import StandardScaler
4 | from sklearn.model_selection import train_test_split
5 | 
6 | import os
7 | 
8 | data_root_dir = "../data/"
9 | 
10 | paths = [data_root_dir]
11 | for p in paths:
12 |     if not os.path.exists(p):
13 |         os.mkdir(p)
14 | 
15 | 
16 | def folktables_income():
17 |     # (down-)load the dataset
18 |     data_source = ACSDataSource(
19 |         survey_year="2016", horizon="1-Year", survey="person", root_dir=data_root_dir
20 |     )
21 |     data = data_source.get_data(states=["CA"], download=True)
22 |     X, Y, _ = ACSIncome.df_to_numpy(data)
23 | 
24 |     # feature names
25 |     feature_names = ACSIncome.features
26 | 
27 |     # zero mean and unit variance for all features
28 |     X = StandardScaler().fit_transform(X)
29 | 
30 |     # train-test split
31 |     X_train, X_test, Y_train, Y_test = train_test_split(
32 |         X, Y, train_size=0.8, random_state=0
33 |     )
34 | 
35 |     return X_train, X_test, Y_train, Y_test, feature_names
--------------------------------------------------------------------------------
/tests/tests.py:
--------------------------------------------------------------------------------
1 | import xgboost
2 | 
3 | from test_util import folktables_income
4 | 
5 | import nshap
6 | 
7 | 
8 | def test_n_shapley():
9 |     """Compare different formulas for computing n-Shapley Values.
10 | 
11 |     (1) via delta_S
12 |     (2) via the component functions of the Shapley GAM
13 | 
14 |     Then compare the Shapley Values resulting from (1) and (2) to the Shapley Values computed with the Shapley formula.
15 |     """
16 |     X_train, X_test, Y_train, Y_test, _ = folktables_income()
17 |     X_train = X_train[:, 0:5]
18 |     X_test = X_test[:, 0:5]
19 |     gbtree = xgboost.XGBClassifier()
20 |     gbtree.fit(X_train, Y_train)
21 | 
22 |     vfunc = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0)
23 | 
24 |     n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)
25 |     moebius = nshap.moebius_transform(X_test[0, :], vfunc)
26 |     shapley_values = nshap.shapley_values(X_test[0, :], vfunc)
27 | 
28 |     assert nshap.allclose(n_shapley_values, moebius)
29 |     for k in range(1, X_train.shape[1]):
30 |         k_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc, k)
31 |         assert nshap.allclose(n_shapley_values.k_shapley_values(k), k_shapley_values)
32 |         assert nshap.allclose(
33 |             n_shapley_values.k_shapley_values(k), moebius.k_shapley_values(k)
34 |         )
35 |     assert nshap.allclose(n_shapley_values.k_shapley_values(1), shapley_values)
36 |     assert nshap.allclose(moebius.k_shapley_values(1), shapley_values)
37 | 
38 | 
--------------------------------------------------------------------------------
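A natural extension of this test suite — a sketch, not part of the repository — checks that the order-1 Faith-Shap index reproduces the ordinary Shapley Values, which follows from Theorem 19 of the Faith-Shap paper with l=1:

import xgboost

from test_util import folktables_income

import nshap


def test_faith_shap_order_1():
    """The order-1 Faith-Shap index coincides with the ordinary Shapley Values."""
    X_train, X_test, Y_train, _, _ = folktables_income()
    X_train, X_test = X_train[:, 0:4], X_test[:, 0:4]
    gbtree = xgboost.XGBClassifier().fit(X_train, Y_train)

    vfunc = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0)
    faith_shap = nshap.faith_shap(X_test[0, :], vfunc, n=1)
    shapley_values = nshap.shapley_values(X_test[0, :], vfunc)

    assert nshap.allclose(faith_shap, shapley_values)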