├── .gitattributes
├── .github
│   └── workflows
│       ├── docs_pages.yml
│       └── python-app.yml
├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── nshap.rst
│   └── usage
│       ├── examples.rst
│       ├── installation.rst
│       └── notebooks
│           ├── Example1.ipynb
│           ├── Example2.ipynb
│           └── Example3.ipynb
├── images
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   ├── img4.png
│   ├── img5.png
│   └── img6.png
├── notebooks
│   ├── Example1.ipynb
│   ├── Example2.ipynb
│   ├── Example3.ipynb
│   └── replicate-paper
│       ├── checkerboard-compute-million.ipynb
│       ├── checkerboard-compute.ipynb
│       ├── checkerboard-figures.ipynb
│       ├── compute-vfunc.ipynb
│       ├── compute.ipynb
│       ├── datasets.py
│       ├── estimation.ipynb
│       ├── figures.ipynb
│       ├── hyperparameters.ipynb
│       └── paperutil.py
├── pyproject.toml
├── setup.py
├── src
│   └── nshap
│       ├── InteractionIndex.py
│       ├── __init__.py
│       ├── functions.py
│       ├── plot.py
│       ├── util.py
│       └── vfunc.py
└── tests
    ├── test_util.py
    └── tests.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.github/workflows/docs_pages.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | # execute this workflow automatically when we push to main
4 | on:
5 | push:
6 | branches: [ main ]
7 |
8 | jobs:
9 |
10 | build_docs_job:
11 | runs-on: ubuntu-latest
12 | env:
13 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
14 |
15 | steps:
16 | - name: Checkout
17 | uses: actions/checkout@v3
18 |
19 | - name: Set up Python 3.10
20 | uses: actions/setup-python@v3
21 | with:
22 | python-version: "3.10"
23 |
24 | - name: Install package and dependencies for building docs
25 | run: |
26 | sudo apt-get install pandoc
27 | python -m pip install --upgrade pip
28 | python -m pip install -U sphinx
29 | python -m pip install sphinx-rtd-theme
30 | # python -m pip install sphinxcontrib-apidoc
31 | python -m pip install sphinx-autoapi
32 | python -m pip install nbsphinx
33 | python -m pip install pypandoc
34 | pip install .
35 | - name: make the sphinx docs
36 | run: |
37 | make -C docs clean
38 | # sphinx-apidoc -f -o docs/source . -H Test -e -t docs/source/_templates
39 | make -C docs html
40 | - name: Init new repo in dist folder and commit generated files
41 | run: |
42 | cd docs/_build/html/
43 | git init
44 | touch .nojekyll
45 | git add -A
46 | git config --local user.email "action@github.com"
47 | git config --local user.name "GitHub Action"
48 | git commit -m 'deploy'
49 | - name: Force push to destination branch
50 | uses: ad-m/github-push-action@master
51 | with:
52 | github_token: ${{ secrets.GITHUB_TOKEN }}
53 | branch: docs-build
54 | force: true
55 | directory: ./docs/_build/html
56 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: pytesting
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | tests:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v3
22 | - name: Set up Python 3.10
23 | uses: actions/setup-python@v3
24 | with:
25 | python-version: "3.10"
26 | - name: Install package & dependencies required for testing
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install flake8 pytest
30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31 | pip install .
32 | pip install folktables xgboost
33 | pip install -U scikit-learn
34 | - name: Lint with flake8
35 | run: |
36 | # stop the build if there are Python syntax errors or undefined names
37 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
38 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
39 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
40 | - name: Test with pytest
41 | run: |
42 | pytest tests/tests.py
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | figures/
3 | results/
4 | output_notebooks/
5 |
6 | # saved interaction indices
7 | *.json
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | .DS_Store
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 TML Tübingen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to the nshap Package!
2 |
3 |
4 |
5 |
6 |
7 |
8 | This is a python package to compute interaction indices that extend the Shapley Value. It accompanies the AISTATS'23 paper [From Shapley Values to Generalized Additive Models and back](http://arxiv.org/abs/2209.04012) by Sebastian Bordt and Ulrike von Luxburg.
9 |
10 | [](https://pypi.org/project/nshap/)
11 | [](https://tml-tuebingen.github.io/nshap/)
12 | 
13 | [](https://opensource.org/licenses/MIT)
14 |
15 | This package supports, among others,
16 |
17 | - [n-Shapley Values](http://arxiv.org/abs/2209.04012), introduced in our paper
18 | - [SHAP Interaction Values](https://www.nature.com/articles/s42256-019-0138-9), a popular interaction index that can also be computed with the [shap](https://github.com/slundberg/shap/) package
19 | - the [Shapley Taylor](https://arxiv.org/abs/1902.05622) interaction index
20 | - the [Faith-Shap](https://arxiv.org/abs/2203.00870) interaction index
21 | - the [Faith-Banzhaf](https://arxiv.org/abs/2203.00870) interaction index.
22 |
23 | The package works with arbitrary user-defined value functions. It also provides a model-agnostic implementation of the interventional SHAP value function.
24 |
25 | Note that the computed interaction indices are an estimate [that can be inaccurate](#accuracy-of-the-computed-interaction-indices), especially if the order of the interaction is large.
26 |
27 | Documentation is available at [https://tml-tuebingen.github.io/nshap](https://tml-tuebingen.github.io/nshap/).
28 |
29 | ⚠️ Disclaimer
30 |
31 | This package does not provide an efficient way to compute Shapley Values. For this you should refer to the [shap](https://github.com/slundberg/shap/) package or approaches like [FastSHAP](https://arxiv.org/abs/2107.07436). In practice, the current implementation works for arbitrary functions of up to ~10 variables.
32 |
33 | ## Setup
34 |
35 | To install the package run
36 |
37 | ```
38 | pip install nshap
39 | ```
40 |
41 | ## Computing Interaction Indices
42 |
43 | Let's assume that we have trained a Gradient Boosted Tree on the [Folktables](https://github.com/zykls/folktables) Income data set.
44 |
45 | ```python
46 | gbtree = xgboost.XGBClassifier()
47 | gbtree.fit(X_train, Y_train)
48 | print(f'Accuracy: {accuracy_score(Y_test, gbtree.predict(X_test)):0.3f}')
49 | ```
50 | ```Accuracy: 0.830```
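
The snippet above assumes that the data has already been loaded and split. A minimal sketch of such a setup, using the ```folktables``` package (our experiments use a reduced, 8-dimensional feature set, so the details may differ):

```python
import xgboost
from folktables import ACSDataSource, ACSIncome
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# download the 2018 ACS Income data for a single state
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
acs_data = data_source.get_data(states=['CA'], download=True)
X, Y, _ = ACSIncome.df_to_numpy(acs_data)

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
```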
51 |
52 | Now we want to compute an interaction index. This package supports interaction indices that extend the Shapley Value, which means that each interaction index is based on a value function, just like the Shapley Value itself. So we first need to define a value function. We can use ```nshap.vfunc.interventional_shap```, which approximates the interventional SHAP value function.
53 |
54 | ```python
55 | import nshap
56 |
57 | vfunc = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0, num_samples=1000)
58 | ```
59 | The function takes 4 arguments
60 |
61 | - The function that we want to explain
62 | - The training data or another sample from the data distribution
63 | - The target class (required here since 'predict_proba' has 2 outputs).
64 | - The number of samples that should be used to estimate the expectation (Default: 1000)
65 |
66 | Equipped with a value function, we can compute different kinds of interaction indices. We can compute n-Shapley Values
67 |
68 | ```python
69 | n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc, n=8)
70 | ```
71 |
72 | the Shapley-Taylor interaction index
73 |
74 | ```python
75 | shapley_taylor = nshap.shapley_taylor(X_test[0, :], vfunc, n=8)
76 | ```
77 |
78 | or the Faith-Shap interaction index of order 3
79 |
80 | ```python
81 | faith_shap = nshap.faith_shap(X_test[0, :], vfunc, n=3)
82 | ```
83 |
84 | The functions that compute interaction indices have a common interface. They take 3 arguments
85 |
86 | - ```x```: The data point for which to compute the explanation ([numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html))
87 |
88 | - ```v_func```: The value function.
89 |
90 | - ```n```: The order of the interaction index. Defaults to the number of features.
91 |
92 | All functions return an object of type ```InteractionIndex```. To get the interaction between features 2 and 3, simply call
93 |
94 | ```python
95 | n_shapley_values[(2,3)]
96 | ```
97 |
98 | ```-0.0054```
99 |
100 | To visualize an interaction index, call
101 |
102 | ```python
103 | n_shapley_values.plot(feature_names = feature_names)
104 | ```
105 |
106 |
107 |
108 |
109 |
110 | This works for all interaction indices
111 |
112 | ```python
113 | faith_shap.plot(feature_names = feature_names)
114 | ```
115 |
116 |
117 |
118 |
119 |
120 | For n-Shapley Values, we can compute interaction indices of lower order from those of higher order
121 |
122 | ```python
123 | n_shapley_values.k_shapley_values(2).plot(feature_names = feature_names)
124 | ```
125 |
126 |
127 |
128 |
129 |
130 | We can also obtain the original Shapley Values and plot them with the plotting functions from the [shap](https://github.com/slundberg/shap/) package.
131 |
132 | ```python
133 | import shap
134 |
135 | shap.force_plot(vfunc(X_test[0,:], []), n_shapley_values.shapley_values())
136 | ```
137 |
138 |
139 |
140 |
141 |
142 | Let us compare our result to the Shapley Values obtained from the KernelSHAP Algorithm.
143 |
144 | ```python
145 | explainer = shap.KernelExplainer(gbtree.predict_proba, shap.kmeans(X_train, 25))
146 | shap.force_plot(explainer.expected_value[0], explainer.shap_values(X_test[0, :])[0])
147 | ```
148 |
149 |
150 |
151 |
152 |
153 | ## The ```InteractionIndex``` class
154 |
155 | The ```InteractionIndex``` class is a python ```dict``` with some added functionality. It supports the following operations.
156 |
157 | - The individual attributions can be indexed with tuples of integers. For example, indexing with ```(0,)``` returns the main effect of the first feature. Indexing with ```(0,1,2)``` returns the interaction effect between features 0, 1 and 2.
158 |
159 | - ```plot()``` generates the plots described in the paper.
160 |
161 | - ```sum()``` sums the individual attributions (this usually sums to the function value minus the value of the empty coalition)
162 |
163 | - ```save(fname)``` serializes the object to JSON. It can be loaded again with ```nshap.load(fname)```. This is useful since computing interaction indices takes time, so you might want to compute them in parallel and then aggregate the results for analysis (see the sketch below).
164 |
165 | Some functions can only be called for certain interaction indices:
166 |
167 | - ```k_shapley_values(k)``` computes the $k$-Shapley Values using the recursive relationship among $n$-Shapley Values of different order (requires $k\leq n$). Can only be called for $n$-Shapley Values.
168 |
169 | - ```shapley_values()``` returns the associated original Shapley Values as a list. Useful for compatibility with the [shap](https://github.com/slundberg/shap/) package.
170 |
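A minimal sketch of these operations, assuming ```x``` and ```vfunc``` are defined as in the introductory example above (the file name ```result.json``` is just a placeholder):

```python
import nshap

n_shapley_values = nshap.n_shapley_values(x, vfunc, n=8)

# the attributions sum (approximately) to v(x, all features) - v(x, empty coalition)
print(n_shapley_values.sum())

# serialize the result, e.g. at the end of a parallel compute job ...
n_shapley_values.save('result.json')

# ... and load it again for analysis
loaded = nshap.load('result.json')

# reduce to 2-Shapley Values and plot
loaded.k_shapley_values(2).plot()
```
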
171 | ## Defining Value Functions
172 |
173 | A value function has to follow the interface ```v_func(x, S)``` where ```x``` is a single data point (a [numpy.ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html)) and ```S``` is a python ```list``` with the indices of the coordinates that belong to the coalition.
174 |
175 | In the introductory example with the Gradient Boosted Tree,
176 |
177 | ```python
178 | vfunc(x, [])
179 | ```
180 |
181 | returns the expected predicted probability that an observation belongs to class 0, and
182 |
183 | ```python
184 | vfunc(x, [0,1,2,3,4,5,6,7])
185 | ```
186 |
187 | returns the predicted probability that the observation ```x``` belongs to class 0 (note that the problem is 8-dimensional).
188 |
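As a toy illustration (not part of the package), here is a custom value function where the value of a coalition is simply the sum of the corresponding coordinates of ```x```:

```python
import numpy as np
import nshap

def my_vfunc(x, S):
    """Toy value function: the value of coalition S is the sum of the coordinates in S."""
    return float(np.sum([x[i] for i in S]))

x = np.arange(5, dtype=float)
result = nshap.n_shapley_values(x, my_vfunc, n=2)
```

Since this toy value function is purely additive, all order-2 terms of ```result``` should be numerically zero.
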
189 | ## Implementation Details
190 |
191 | At the moment, all functions compute interaction indices directly via their definition. Independent of the order ```n``` of the $n$-Shapley Values, this requires calling the value function ```v_func``` once for each of the $2^d$ subsets of coordinates. Thus, the current implementation provides no essential speedup for the computation of $n$-Shapley Values of lower order.
192 |
193 | The function ```nshap.vfunc.interventional_shap``` approximates the interventional SHAP value function by intervening on the coordinates of randomly sampled points from the data distribution.
194 |
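To make this concrete, here is a sketch of the idea (an illustration under simplified assumptions, not the package's actual implementation):

```python
import numpy as np

def interventional_vfunc(f, X_background, num_samples=1000, rng=None):
    """Estimate v(x, S) by fixing the coordinates in S at x and replacing
    the remaining coordinates with randomly sampled background points."""
    if rng is None:
        rng = np.random.default_rng()

    def v(x, S):
        idx = rng.integers(0, X_background.shape[0], size=num_samples)
        Z = X_background[idx].copy()  # random points from the data distribution
        Z[:, S] = x[S]                # intervene: fix the coalition S at x
        return float(np.mean(f(Z)))

    return v
```
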
195 | ## Accuracy of the computed interaction indices
196 |
197 | The computed interaction indices are an estimate which can be inaccurate.
198 |
199 | The estimation error depends on the precision of the value function. With the provided implementation of the interventional SHAP value function, the precision depends on the number of samples used to estimate the expectation.
200 |
201 | A simple way to test whether your result is precisely estimated is to increase the number of samples (the ```num_samples``` parameter of ```nshap.vfunc.interventional_shap```) and see whether the result changes.
202 |
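For example, a simple check along these lines (assuming the setup from the introductory example):

```python
v_coarse = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0, num_samples=1000)
v_fine = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0, num_samples=5000)

r_coarse = nshap.n_shapley_values(X_test[0, :], v_coarse, n=2)
r_fine = nshap.n_shapley_values(X_test[0, :], v_fine, n=2)

# if the attributions still change noticeably, increase num_samples further
max_diff = max(abs(r_coarse[k] - r_fine[k]) for k in r_coarse.keys())
print(f'largest difference between the two estimates: {max_diff:.4f}')
```
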
203 | For more details, check out the discussion in [Section 8 of our paper](http://arxiv.org/abs/2209.04012).
204 |
205 | ## Replicating the Results in our Paper
206 |
207 | The folder ```notebooks/replicate-paper``` contains Jupyter Notebooks that allow you to replicate the results in our [paper](http://arxiv.org/abs/2209.04012).
208 |
209 | - The notebooks ```figures.ipynb``` and ```checkerboard-figures.ipynb``` generate all the figures in the paper.
210 | - The notebook ```estimation.ipynb``` provides the estimation example with the kNN classifier on the Folktables Travel data set that we discuss in Appendix Section B.
211 | - The notebook ```hyperparameters.ipynb``` cross-validates the parameter $k$ of the kNN classifier.
212 | - The notebooks ```compute.ipynb```, ```compute-vfunc.ipynb```, ```checkerboard-compute.ipynb``` and ```checkerboard-compute-million.ipynb``` compute the different $n$-Shapley Values. You do not have to run these notebooks; the pre-computed results can be downloaded [here](https://nextcloud.tuebingen.mpg.de/index.php/s/SsowoR7SAibQYE7).
213 |
214 | ⚠️ Important
215 |
216 | You have to use version 0.1.0 of this package in order to run the notebooks that replicate the results in the paper.
217 |
218 | ```
219 | pip install nshap==0.1.0
220 | ```
221 |
222 | ## Citing nshap
223 |
224 | If you use this software in your research, we encourage you to cite our paper.
225 |
226 | ```bib
227 | @inproceedings{bordtlux2023,
228 | author = {Bordt, Sebastian and von Luxburg, Ulrike},
229 | title = {From Shapley Values to Generalized Additive Models and back},
230 | booktitle = {AISTATS},
231 | year = {2023}
232 | }
233 | ```
234 |
235 | If you use interaction indices that were introduced in other works, such as [Shapley Taylor](https://arxiv.org/abs/1902.05622) or [Faith-Shap](https://arxiv.org/abs/2203.00870), you should also consider citing the respective papers.
236 |
237 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('../src/nshap/'))
16 | sys.path.insert(0, os.path.abspath('../src/'))
17 |
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = 'nshap'
22 | copyright = '2023, Sebastian Bordt'
23 | author = 'Sebastian Bordt'
24 | release = '0.2.0'
25 |
26 | # -- General configuration ---------------------------------------------------
27 |
28 | # Add any Sphinx extension module names here, as strings. They can be
29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
30 | # ones.
31 | extensions = [
32 | 'sphinx.ext.duration',
33 | 'sphinx.ext.doctest',
34 | 'sphinx.ext.autodoc',
35 | 'sphinx.ext.napoleon',
36 | 'nbsphinx'
37 | ]
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ['_templates']
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
46 |
47 |
48 | # -- Options for HTML output -------------------------------------------------
49 |
50 | # The theme to use for HTML and HTML Help pages. See the documentation for
51 | # a list of builtin themes.
52 | #
53 | html_theme = 'sphinx_rtd_theme'
54 |
55 | # Add any paths that contain custom static files (such as style sheets) here,
56 | # relative to this directory. They are copied after the builtin static files,
57 | # so a file named "default.css" will overwrite the builtin "default.css".
58 | #html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. nshap documentation master file, created by
2 | sphinx-quickstart on Fri Oct 28 14:04:18 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to the Documentation of the nshap Package!
7 | ==================================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | usage/installation
14 | usage/examples
15 | nshap
16 |
17 |
18 | Indices and tables
19 | ==================
20 |
21 | * :ref:`genindex`
22 | * :ref:`modindex`
23 | * :ref:`search`
24 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/nshap.rst:
--------------------------------------------------------------------------------
1 | Documentation
2 | =============
3 |
4 | Functions that compute interaction indices
5 | ------------------------------------------
6 |
7 | .. automodule:: nshap.functions
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | Value Functions
13 | ---------------
14 |
15 | .. automodule:: nshap.vfunc
16 | :members:
17 | :undoc-members:
18 | :show-inheritance:
19 |
20 | The InteractionIndex class
21 | --------------------------
22 |
23 | .. autoclass:: nshap.InteractionIndex.InteractionIndex
24 | :members:
25 | :undoc-members:
26 | :show-inheritance:
27 |
28 | Plotting and Utilities
29 | ----------------------
30 |
31 | .. automodule:: nshap.plot
32 | :members:
33 | :undoc-members:
34 | :show-inheritance:
35 |
36 | .. automodule:: nshap.util
37 | :members:
38 | :undoc-members:
39 | :show-inheritance:
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/docs/usage/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | --------
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Contents:
7 |
8 | notebooks/Example1
9 | notebooks/Example2
10 | notebooks/Example3
--------------------------------------------------------------------------------
/docs/usage/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | To install the package via pip run
5 |
6 | :code:`pip install nshap`
7 |
8 | then you can use
9 |
10 | :code:`import nshap`
11 |
12 | to import the package.
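
For a quick sanity check that the installation works, the following minimal sketch should run without errors:

.. code-block:: python

    import nshap

    # the main entry points documented in this package
    print(nshap.n_shapley_values)
    print(nshap.vfunc.interventional_shap)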
--------------------------------------------------------------------------------
/docs/usage/notebooks/Example3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Example 3: Custom value functions\n",
9 | "#### These examples are from section 5.1 of the Faith-Shap paper: https://arxiv.org/abs/2203.00870"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "id": "c10cfd76",
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import seaborn as sns\n",
20 | "\n",
21 | "sns.set_style(\"whitegrid\")\n",
22 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
23 | "\n",
24 | "import numpy as np\n",
25 | "import math\n",
26 | "\n",
27 | "import nshap"
28 | ]
29 | },
30 | {
31 | "attachments": {},
32 | "cell_type": "markdown",
33 | "id": "f0b5f063",
34 | "metadata": {},
35 | "source": [
36 | "### Example 1\n",
37 | "##### A value function takes two arguments: A single data point x (a numpy.ndarray) and a python list S with the indices the the coordinates that belong to the coaltion.\n",
38 | "##### In this example, the value does not actually depend on the point x"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 7,
44 | "id": "44978e93",
45 | "metadata": {
46 | "tags": []
47 | },
48 | "outputs": [
49 | {
50 | "data": {
51 | "text/plain": [
52 | "3.4"
53 | ]
54 | },
55 | "execution_count": 7,
56 | "metadata": {},
57 | "output_type": "execute_result"
58 | }
59 | ],
60 | "source": [
61 | "p = 0.1\n",
62 | "def v_func(x, S):\n",
63 | " \"\"\" The value function from Example 1 in the Faith-Shap paper.\n",
64 | " \"\"\"\n",
65 | " if len(S) <= 1:\n",
66 | " return 0\n",
67 | " return len(S) - p * math.comb(len(S), 2)\n",
68 | "\n",
69 | "\n",
70 | "v_func(np.zeros((1,11)), [1,2,3,4])"
71 | ]
72 | },
73 | {
74 | "attachments": {},
75 | "cell_type": "markdown",
76 | "id": "86a0f974",
77 | "metadata": {},
78 | "source": [
79 | "##### Equipped with the value function, we can compute different kinds of interaction indices"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "fbbbbcf5",
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, 1)\n",
90 | "faith_shap[(0,1)], faith_shap[(1,2)]"
91 | ]
92 | },
93 | {
94 | "attachments": {},
95 | "cell_type": "markdown",
96 | "id": "08657d28",
97 | "metadata": {},
98 | "source": [
99 | "##### We can replicate Table 1 in the Faith-Shap paper"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 13,
105 | "id": "d4bcee91",
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "Table 1 in the Faith-Shap paper:\n",
113 | "\n",
114 | "p=0.1 l=1\n",
115 | "\n",
116 | "Faith-Shap: 0.5000000000001867\n",
117 | "Shapley Taylor: 0.500000000000185\n",
118 | "Interaction Shapley: 0.4999999999999939\n",
119 | "Banzhaf Interaction: 0.5087890624999979\n",
120 | "Faith-Banzhaf: 0.5087890624999989\n",
121 | "\n",
122 | "p=0.1 l=2\n",
123 | "\n",
124 | "Faith-Shap: 0.9545454545463214 -0.09090909090923938\n",
125 | "Shapley Taylor: 0 0.10000000000002415\n",
126 | "Interaction Shapley: 0.4999999999999939 -5.759281940243e-16\n",
127 | "Banzhaf Interaction: 0.5087890624999979 -0.11367187500000073\n",
128 | "Faith-Banzhaf: 1.0771484374999503 -0.11367187499999737\n",
129 | "\n",
130 | "p=0.2 l=1\n",
131 | "\n",
132 | "Faith-Shap: 1.6486811915683575e-13\n",
133 | "Shapley Taylor: 1.7552626019323725e-13\n",
134 | "Interaction Shapley: 2.7755575615628914e-17\n",
135 | "Banzhaf Interaction: 0.0087890625\n",
136 | "Faith-Banzhaf: 0.008789062500008604\n",
137 | "\n",
138 | "p=0.2 l=2\n",
139 | "\n",
140 | "Faith-Shap: 0.9545454545460967 -0.19090909090925617\n",
141 | "Shapley Taylor: 0 -3.7941871866564725e-14\n",
142 | "Interaction Shapley: 2.7755575615628914e-17 -0.10000000000000021\n",
143 | "Banzhaf Interaction: 0.0087890625 -0.21367187500000198\n",
144 | "Faith-Banzhaf: 1.0771484374999574 -0.21367187499998952\n",
145 | "\n"
146 | ]
147 | }
148 | ],
149 | "source": [
150 | "print('Table 1 in the Faith-Shap paper:\\n')\n",
151 | "for p in [0.1, 0.2]:\n",
152 | " for l in [1, 2]:\n",
153 | " print(f'p={p} l={l}\\n')\n",
154 | "\n",
155 | " # define the value function\n",
156 | " def v_func(x, S):\n",
157 | " if len(S) <= 1:\n",
158 | " return 0\n",
159 | " return len(S) - p * math.comb(len(S), 2)\n",
160 | "\n",
161 | " # compute interaction indices\n",
162 | " faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, l)\n",
163 | " shapley_taylor = nshap.shapley_taylor(np.zeros((1,11)), v_func, l)\n",
164 | " shapley_interaction = nshap.shapley_interaction_index(np.zeros((1,11)), v_func, l)\n",
165 | " banzhaf_interaction = nshap.banzhaf_interaction_index(np.zeros((1,11)), v_func, l)\n",
166 | " faith_banzhaf = nshap.faith_banzhaf(np.zeros((1,11)), v_func, l)\n",
167 | "\n",
168 | " # print result\n",
169 | " if l == 1:\n",
170 | " print('Faith-Shap: ', faith_shap[(0,)])\n",
171 | " print('Shapley Taylor: ', shapley_taylor[(0,)])\n",
172 | " print('Interaction Shapley: ', shapley_interaction[(0,)])\n",
173 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)])\n",
174 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)])\n",
175 | " else:\n",
176 | " print('Faith-Shap: ', faith_shap[(0,)], faith_shap[(0,1)])\n",
177 | " print('Shapley Taylor: ', shapley_taylor[(0,)], shapley_taylor[(0,1)])\n",
178 | " print('Interaction Shapley: ', shapley_interaction[(0,)], shapley_interaction[(0,1)])\n",
179 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)], banzhaf_interaction[(0,1)])\n",
180 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)], faith_banzhaf[(0,1)])\n",
181 | " print('')"
182 | ]
183 | }
184 | ],
185 | "metadata": {
186 | "kernelspec": {
187 | "display_name": "Python 3.9.12 ('base')",
188 | "language": "python",
189 | "name": "python3"
190 | },
191 | "language_info": {
192 | "codemirror_mode": {
193 | "name": "ipython",
194 | "version": 3
195 | },
196 | "file_extension": ".py",
197 | "mimetype": "text/x-python",
198 | "name": "python",
199 | "nbconvert_exporter": "python",
200 | "pygments_lexer": "ipython3",
201 | "version": "3.9.12"
202 | },
203 | "vscode": {
204 | "interpreter": {
205 | "hash": "62406f3c2942480da828869ab3f3f95d1c0177b3689d5bc770f3ddfd7b9b3df5"
206 | }
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 5
211 | }
212 |
--------------------------------------------------------------------------------
/images/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img1.png
--------------------------------------------------------------------------------
/images/img2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img2.png
--------------------------------------------------------------------------------
/images/img3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img3.png
--------------------------------------------------------------------------------
/images/img4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img4.png
--------------------------------------------------------------------------------
/images/img5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img5.png
--------------------------------------------------------------------------------
/images/img6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tml-tuebingen/nshap/d2feb1328911c66dd9027e692a5d3d02c2c919ad/images/img6.png
--------------------------------------------------------------------------------
/notebooks/Example3.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Example 3: Custom value functions\n",
9 | "#### These examples are from section 5.1 of the Faith-Shap paper: https://arxiv.org/abs/2203.00870"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "id": "c10cfd76",
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import seaborn as sns\n",
20 | "\n",
21 | "sns.set_style(\"whitegrid\")\n",
22 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
23 | "\n",
24 | "import numpy as np\n",
25 | "import math\n",
26 | "\n",
27 | "import nshap"
28 | ]
29 | },
30 | {
31 | "attachments": {},
32 | "cell_type": "markdown",
33 | "id": "f0b5f063",
34 | "metadata": {},
35 | "source": [
36 | "### Example 1\n",
37 | "##### A value function takes two arguments: A single data point x (a numpy.ndarray) and a python list S with the indices the the coordinates that belong to the coaltion.\n",
38 | "##### In this example, the value does not actually depend on the point x"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 7,
44 | "id": "44978e93",
45 | "metadata": {
46 | "tags": []
47 | },
48 | "outputs": [
49 | {
50 | "data": {
51 | "text/plain": [
52 | "3.4"
53 | ]
54 | },
55 | "execution_count": 7,
56 | "metadata": {},
57 | "output_type": "execute_result"
58 | }
59 | ],
60 | "source": [
61 | "p = 0.1\n",
62 | "def v_func(x, S):\n",
63 | " \"\"\" The value function from Example 1 in the Faith-Shap paper.\n",
64 | " \"\"\"\n",
65 | " if len(S) <= 1:\n",
66 | " return 0\n",
67 | " return len(S) - p * math.comb(len(S), 2)\n",
68 | "\n",
69 | "\n",
70 | "v_func(np.zeros((1,11)), [1,2,3,4])"
71 | ]
72 | },
73 | {
74 | "attachments": {},
75 | "cell_type": "markdown",
76 | "id": "86a0f974",
77 | "metadata": {},
78 | "source": [
79 | "##### Equipped with the value function, we can compute different kinds of interaction indices"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "fbbbbcf5",
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, 1)\n",
90 | "faith_shap[(0,1)], faith_shap[(1,2)]"
91 | ]
92 | },
93 | {
94 | "attachments": {},
95 | "cell_type": "markdown",
96 | "id": "08657d28",
97 | "metadata": {},
98 | "source": [
99 | "##### We can replicate Table 1 in the Faith-Shap paper"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 13,
105 | "id": "d4bcee91",
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "Table 1 in the Faith-Shap paper:\n",
113 | "\n",
114 | "p=0.1 l=1\n",
115 | "\n",
116 | "Faith-Shap: 0.5000000000001867\n",
117 | "Shapley Taylor: 0.500000000000185\n",
118 | "Interaction Shapley: 0.4999999999999939\n",
119 | "Banzhaf Interaction: 0.5087890624999979\n",
120 | "Faith-Banzhaf: 0.5087890624999989\n",
121 | "\n",
122 | "p=0.1 l=2\n",
123 | "\n",
124 | "Faith-Shap: 0.9545454545463214 -0.09090909090923938\n",
125 | "Shapley Taylor: 0 0.10000000000002415\n",
126 | "Interaction Shapley: 0.4999999999999939 -5.759281940243e-16\n",
127 | "Banzhaf Interaction: 0.5087890624999979 -0.11367187500000073\n",
128 | "Faith-Banzhaf: 1.0771484374999503 -0.11367187499999737\n",
129 | "\n",
130 | "p=0.2 l=1\n",
131 | "\n",
132 | "Faith-Shap: 1.6486811915683575e-13\n",
133 | "Shapley Taylor: 1.7552626019323725e-13\n",
134 | "Interaction Shapley: 2.7755575615628914e-17\n",
135 | "Banzhaf Interaction: 0.0087890625\n",
136 | "Faith-Banzhaf: 0.008789062500008604\n",
137 | "\n",
138 | "p=0.2 l=2\n",
139 | "\n",
140 | "Faith-Shap: 0.9545454545460967 -0.19090909090925617\n",
141 | "Shapley Taylor: 0 -3.7941871866564725e-14\n",
142 | "Interaction Shapley: 2.7755575615628914e-17 -0.10000000000000021\n",
143 | "Banzhaf Interaction: 0.0087890625 -0.21367187500000198\n",
144 | "Faith-Banzhaf: 1.0771484374999574 -0.21367187499998952\n",
145 | "\n"
146 | ]
147 | }
148 | ],
149 | "source": [
150 | "print('Table 1 in the Faith-Shap paper:\\n')\n",
151 | "for p in [0.1, 0.2]:\n",
152 | " for l in [1, 2]:\n",
153 | " print(f'p={p} l={l}\\n')\n",
154 | "\n",
155 | " # define the value function\n",
156 | " def v_func(x, S):\n",
157 | " if len(S) <= 1:\n",
158 | " return 0\n",
159 | " return len(S) - p * math.comb(len(S), 2)\n",
160 | "\n",
161 | " # compute interaction indices\n",
162 | " faith_shap = nshap.faith_shap(np.zeros((1,11)), v_func, l)\n",
163 | " shapley_taylor = nshap.shapley_taylor(np.zeros((1,11)), v_func, l)\n",
164 | " shapley_interaction = nshap.shapley_interaction_index(np.zeros((1,11)), v_func, l)\n",
165 | " banzhaf_interaction = nshap.banzhaf_interaction_index(np.zeros((1,11)), v_func, l)\n",
166 | " faith_banzhaf = nshap.faith_banzhaf(np.zeros((1,11)), v_func, l)\n",
167 | "\n",
168 | " # print result\n",
169 | " if l == 1:\n",
170 | " print('Faith-Shap: ', faith_shap[(0,)])\n",
171 | " print('Shapley Taylor: ', shapley_taylor[(0,)])\n",
172 | " print('Interaction Shapley: ', shapley_interaction[(0,)])\n",
173 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)])\n",
174 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)])\n",
175 | " else:\n",
176 | " print('Faith-Shap: ', faith_shap[(0,)], faith_shap[(0,1)])\n",
177 | " print('Shapley Taylor: ', shapley_taylor[(0,)], shapley_taylor[(0,1)])\n",
178 | " print('Interaction Shapley: ', shapley_interaction[(0,)], shapley_interaction[(0,1)])\n",
179 | " print('Banzhaf Interaction: ', banzhaf_interaction[(0,)], banzhaf_interaction[(0,1)])\n",
180 | " print('Faith-Banzhaf: ', faith_banzhaf[(0,)], faith_banzhaf[(0,1)])\n",
181 | " print('')"
182 | ]
183 | }
184 | ],
185 | "metadata": {
186 | "kernelspec": {
187 | "display_name": "Python 3.9.12 ('base')",
188 | "language": "python",
189 | "name": "python3"
190 | },
191 | "language_info": {
192 | "codemirror_mode": {
193 | "name": "ipython",
194 | "version": 3
195 | },
196 | "file_extension": ".py",
197 | "mimetype": "text/x-python",
198 | "name": "python",
199 | "nbconvert_exporter": "python",
200 | "pygments_lexer": "ipython3",
201 | "version": "3.9.12"
202 | },
203 | "vscode": {
204 | "interpreter": {
205 | "hash": "62406f3c2942480da828869ab3f3f95d1c0177b3689d5bc770f3ddfd7b9b3df5"
206 | }
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 5
211 | }
212 |
--------------------------------------------------------------------------------
/notebooks/replicate-paper/checkerboard-compute-million.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Compute for the checkerboard function"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "tags": [
15 | "parameters"
16 | ]
17 | },
18 | "outputs": [],
19 | "source": [
20 | "# papermill parameter: notebook id\n",
21 | "aid = 0"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import matplotlib.pyplot as plt\n",
31 | "import seaborn as sns\n",
32 | "\n",
33 | "import numpy as np\n",
34 | "import os\n",
35 | "\n",
36 | "sns.set_style(\"whitegrid\")\n",
37 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
38 | "\n",
39 | "from itertools import product\n",
40 | "\n",
41 | "import nshap\n",
42 | "\n",
43 | "from paperutil import checkerboard_function"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "### The different compute jobs, and the current job"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "interaction_orders = [2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
60 | "replications = list(range(10))\n",
61 | "\n",
62 | "all_jobs = list(product(interaction_orders, replications))\n",
63 | "print(len(all_jobs), 'different compute jobs')"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "interaction_order = all_jobs[aid][0]\n",
73 | "replication = all_jobs[aid][1]"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "### Create output dir structure, if it does not already exist"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "paths = ['../../results/', \n",
90 | " '../../results/n_shapley_values/', \n",
91 | " '../../results/n_shapley_values/checkerboard/']\n",
92 | "for p in paths:\n",
93 | " if not os.path.exists( p ):\n",
94 | " os.mkdir( p )"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {},
100 | "source": [
101 | "### Compute n-Shapley Values for a k-dimensional checkerboard in a 10-dimensional space"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "d = 10\n",
111 | "f = checkerboard_function(interaction_order, num_checkers=100)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "tags": []
119 | },
120 | "outputs": [],
121 | "source": [
122 | "np.random.seed(replication)\n",
123 | "\n",
124 | "for num_samples in [1000000]:\n",
125 | " X_train = np.random.uniform(0, 1, size=(1000000, d))\n",
126 | " X_test = np.random.uniform(0, 1, size=(1, d))\n",
127 | " vfunc = vfunc = nshap.vfunc.interventional_shap(f, X_train, num_samples = num_samples, random_state=replication)\n",
128 | " n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)\n",
129 | " n_shapley_values.save(f'../../results/n_shapley_values/checkerboard/{interaction_order}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON')\n",
130 | " n_shapley_values.plot()"
131 | ]
132 | }
133 | ],
134 | "metadata": {
135 | "kernelspec": {
136 | "display_name": "Python 3",
137 | "language": "python",
138 | "name": "python3"
139 | },
140 | "language_info": {
141 | "codemirror_mode": {
142 | "name": "ipython",
143 | "version": 3
144 | },
145 | "file_extension": ".py",
146 | "mimetype": "text/x-python",
147 | "name": "python",
148 | "nbconvert_exporter": "python",
149 | "pygments_lexer": "ipython3",
150 | "version": "3.8.3"
151 | }
152 | },
153 | "nbformat": 4,
154 | "nbformat_minor": 5
155 | }
156 |
--------------------------------------------------------------------------------
/notebooks/replicate-paper/checkerboard-compute.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Compute for the checkerboard function"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "tags": [
15 | "parameters"
16 | ]
17 | },
18 | "outputs": [],
19 | "source": [
20 | "# papermill parameter: notebook id\n",
21 | "aid = 0"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import matplotlib.pyplot as plt\n",
31 | "import seaborn as sns\n",
32 | "\n",
33 | "import numpy as np\n",
34 | "import os\n",
35 | "\n",
36 | "sns.set_style(\"whitegrid\")\n",
37 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
38 | "\n",
39 | "import nshap\n",
40 | "\n",
41 | "from paperutil import checkerboard_function"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "interaction_order = [2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
51 | "\n",
52 | "all_jobs = list(interaction_order)\n",
53 | "print(len(all_jobs), 'different compute jobs')"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "### Create output dir structure, if it does not already exist"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "paths = ['../../results/', \n",
70 | " '../../results/n_shapley_values/', \n",
71 | " '../../results/n_shapley_values/checkerboard/']\n",
72 | "for p in paths:\n",
73 | " if not os.path.exists( p ):\n",
74 | " os.mkdir( p )"
75 | ]
76 | },
77 | {
78 | "cell_type": "markdown",
79 | "metadata": {},
80 | "source": [
81 | "### Compute n-Shapley Values for a k-dimensional checkerboard in a 10-dimensional space"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "d = 10\n",
91 | "f = checkerboard_function(interaction_order[aid], num_checkers=100)"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {
98 | "tags": []
99 | },
100 | "outputs": [],
101 | "source": [
102 | "for num_samples in [100, 1000, 10000, 100000]:\n",
103 | " result = []\n",
104 | " for replication in range(10): \n",
105 | " X_train = np.random.uniform(0, 1, size=(1000000, d))\n",
106 | " X_test = np.random.uniform(0, 1, size=(1, d))\n",
107 | " vfunc = vfunc = nshap.vfunc.interventional_shap(f, X_train, num_samples = num_samples)\n",
108 | " n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)\n",
109 | " n_shapley_values.save(f'../../results/n_shapley_values/checkerboard/{interaction_order[aid]}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON')\n",
110 | " n_shapley_values.plot()\n",
111 | " true_order = np.sum([np.abs(v) for k, v in n_shapley_values.items() if len(k) == interaction_order[aid]]) \n",
112 | " all_orders = np.sum([np.abs(v) for k, v in n_shapley_values.items()])\n",
113 | " result.append(true_order / all_orders)\n",
114 | " print(f'{interaction_order[aid]}, {num_samples}, {np.mean(result)}, {result}')"
115 | ]
116 | }
117 | ],
118 | "metadata": {
119 | "kernelspec": {
120 | "display_name": "Python 3",
121 | "language": "python",
122 | "name": "python3"
123 | },
124 | "language_info": {
125 | "codemirror_mode": {
126 | "name": "ipython",
127 | "version": 3
128 | },
129 | "file_extension": ".py",
130 | "mimetype": "text/x-python",
131 | "name": "python",
132 | "nbconvert_exporter": "python",
133 | "pygments_lexer": "ipython3",
134 | "version": "3.8.3"
135 | }
136 | },
137 | "nbformat": 4,
138 | "nbformat_minor": 5
139 | }
140 |
--------------------------------------------------------------------------------
/notebooks/replicate-paper/checkerboard-figures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Figures for the checkerboard function"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import matplotlib.pyplot as plt\n",
17 | "import seaborn as sns\n",
18 | "\n",
19 | "sns.set_style(\"whitegrid\")\n",
20 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.5)\n",
21 | "\n",
22 | "import numpy as np\n",
23 | "import nshap\n",
24 | "import os"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "### Visualization in 2d"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from paperutil import checkerboard_function\n",
41 | "\n",
42 | "d = 2\n",
43 | "num_points = 1000\n",
44 | "f = checkerboard_function(2, 4)\n",
45 | "X = np.random.uniform(0, 1, size=(num_points, d))\n",
46 | "Y = f(X)\n",
47 | "\n",
48 | "sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=Y)\n",
49 | "plt.show()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "### Load the pre-computed n-Shpaley Values"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {
63 | "tags": []
64 | },
65 | "outputs": [],
66 | "source": [
67 | "results = {}\n",
68 | "for degree in range(2, 11):\n",
69 | " results[degree] = {}\n",
70 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n",
71 | " results[degree][num_samples] = []\n",
72 | " for replication in range(10):\n",
73 | " fname = f'../../results/n_shapley_values/checkerboard/{degree}d_checkerboard_{num_samples}_samples_replication_{replication}.JSON'\n",
74 | " if os.path.exists(fname):\n",
75 | " results[degree][num_samples].append( nshap.load(fname) )"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {},
81 | "source": [
82 | "### Correctly estimated fractions"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "tags": []
90 | },
91 | "outputs": [],
92 | "source": [
93 | "sns.set_style(\"whitegrid\")\n",
94 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 2}, font_scale=2)\n",
95 | "\n",
96 | "plt.figure(figsize=(12.5, 8))\n",
97 | "for checkerboard_dim in range(2, 11):\n",
98 | " y = []\n",
99 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n",
100 | " yy = []\n",
101 | " for x in results[checkerboard_dim][num_samples]:\n",
102 | " true_order = np.sum([np.abs(v) for k, v in x.items() if len(k) == checkerboard_dim]) \n",
103 | " all_orders = np.sum([np.abs(v) for k, v in x.items()])\n",
104 | " yy.append(true_order / all_orders)\n",
105 | " y.append(np.mean(yy))\n",
106 | " sns.scatterplot(x=[1,2,3,4,5], y=y, color=sns.color_palette(\"tab10\")[checkerboard_dim-1], s=200)\n",
107 | " plt.plot([1,2,3,4,5], y, c=sns.color_palette(\"colorblind\")[checkerboard_dim-1], ls='--', lw=1.5)\n",
108 | "plt.ylim([-0.04, 1.04])\n",
109 | "plt.yticks(np.arange(0, 1.1, 0.1))\n",
110 | "plt.xticks([1,2,3,4,5], ['100', '1000', '10 000', '100 000', '1 000 000'])\n",
111 | "plt.xlabel('Number of points sampled to estimate the value function')\n",
112 | "plt.ylabel('Fraction of checkerboard function\\nthat is correctly estimated')\n",
113 | "plt.savefig('../../figures/checkerboard_estimation.pdf')\n",
114 | "plt.show()"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "### Plots for individual degrees"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "sns.set_style(\"whitegrid\")\n",
131 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 2}, font_scale=2)\n",
132 | "\n",
133 | "for checkerboard_dim in range(2, 11): \n",
134 | " plt.figure(figsize=(6, 6))\n",
135 | " plt.ylim([-0.04, 1.04])\n",
136 | " for interaction_order in range(1,11):\n",
137 | " y, y_min, y_max = [], [], [] \n",
138 | " for num_samples in [100, 1000, 10000, 100000, 1000000]:\n",
139 | " yy = []\n",
140 | " for x in results[checkerboard_dim][num_samples]:\n",
141 | " order_sum = np.sum([np.abs(v) for k, v in x.items() if len(k) == interaction_order]) \n",
142 | " all_sum = np.sum([np.abs(v) for k, v in x.items()])\n",
143 | " yy.append(order_sum / all_sum)\n",
144 | " y.append(np.mean(yy))\n",
145 | " y_min.append(np.min(yy))\n",
146 | " y_max.append(np.max(yy))\n",
147 | " if np.max(y) > 0.01:\n",
148 | " sns.scatterplot(x=[1,2,3,4,5], y=y, color=sns.color_palette(\"tab10\")[interaction_order-1], s=200)\n",
149 | " plt.plot([1,2,3,4,5], y, c=sns.color_palette(\"colorblind\")[interaction_order-1], ls='--', lw=1.5)\n",
150 | " # convert y_min and y_max from y coordinates to plot range\n",
151 | " y_min = [x/1.08+0.04 for x in y_min]\n",
152 | " y_max = [x/1.08+0.04 for x in y_max]\n",
153 | " for x_pos in [1,2,3,4,5]:\n",
154 | " plt.axvline(x=x_pos, ymin=y_min[x_pos-1], ymax=y_max[x_pos-1], c=sns.color_palette(\"colorblind\")[interaction_order-1], lw=4)\n",
155 | " plt.yticks(np.arange(0, 1.1, 0.1))\n",
156 | " plt.xticks([1,2,3,4,5], ['100', '1000', '10 000', '100 000', '1 000 000'], rotation=20)\n",
157 | " plt.title(f'{checkerboard_dim}d Checkerboard Function')\n",
158 | " plt.xlabel('Number of points sampled')\n",
159 | " plt.ylabel('Order of estimated effects')\n",
160 | " plt.tight_layout()\n",
161 | " plt.savefig(f'../../figures/{checkerboard_dim}_checkerboard_estimation.pdf')\n",
162 | " plt.show()"
163 | ]
164 | }
165 | ],
166 | "metadata": {
167 | "kernelspec": {
168 | "display_name": "Python 3",
169 | "language": "python",
170 | "name": "python3"
171 | },
172 | "language_info": {
173 | "codemirror_mode": {
174 | "name": "ipython",
175 | "version": 3
176 | },
177 | "file_extension": ".py",
178 | "mimetype": "text/x-python",
179 | "name": "python",
180 | "nbconvert_exporter": "python",
181 | "pygments_lexer": "ipython3",
182 | "version": "3.8.3"
183 | }
184 | },
185 | "nbformat": 4,
186 | "nbformat_minor": 5
187 | }
188 |
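
The plots above reduce each estimated Shapley-GAM to one statistic: the share of total attribution mass that sits on subsets of a given interaction order. A minimal self-contained sketch of that statistic (the helper name is ours; `values` is assumed to be a dict mapping feature-index tuples to attribution values, as in the result objects iterated over above):

    import numpy as np

    def fraction_on_order(values, order):
        """Share of total |attribution| assigned to subsets of size `order`."""
        on_order = np.sum([np.abs(v) for k, v in values.items() if len(k) == order])
        total = np.sum([np.abs(v) for k, v in values.items()])
        return on_order / total

    # toy example: all attribution mass sits on the pairwise term
    print(fraction_on_order({(0,): 0.0, (1,): 0.0, (0, 1): 0.5}, order=2))  # 1.0
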
--------------------------------------------------------------------------------
/notebooks/replicate-paper/compute-vfunc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Notebook for parallel evaluation of the value function"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "tags": [
15 | "parameters"
16 | ]
17 | },
18 | "outputs": [],
19 | "source": [
20 | "# papermill parameter: notebook id\n",
21 | "aid = 0"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "import numpy as np\n",
31 | "import os\n",
32 | "\n",
33 | "import nshap\n",
34 | "\n",
35 | "import paperutil\n",
36 | "\n",
37 | "from itertools import product\n",
38 | "\n",
39 | "import datasets\n",
40 | "\n",
41 | "%load_ext autoreload\n",
42 | "%autoreload 2"
43 | ]
44 | },
45 | {
46 | "cell_type": "markdown",
47 | "metadata": {},
48 | "source": [
49 | "### The different compute jobs"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "subsets = list(nshap.powerset(list(range(10))))\n",
59 | "\n",
60 | "all_jobs = list(subsets)\n",
61 | "print(len(all_jobs), 'different compute jobs')"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "### The current job"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "job_id = aid\n",
78 | "dataset = 'folk_travel'\n",
79 | "classifier = 'knn'\n",
80 | "i_datapoint = 0\n",
81 | "random_seed = i_datapoint\n",
82 | "\n",
83 | "print(job_id, dataset, classifier, i_datapoint, random_seed)"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "### Load the dataset"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "is_classification = datasets.is_classification(dataset)"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "metadata": {},
114 | "source": [
115 | "### Predict, proba or decision"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "method = 'predict'\n",
125 | "if is_classification:\n",
126 | " method = 'proba'\n",
127 | "if classifier == 'gam':\n",
128 | " method = 'decision'"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "### The number of samples is limited by the size of the data set"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "max_samples = 1000000\n",
145 | "num_samples = min(max_samples, X_train.shape[0])"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "### Create output dir structure, if it does not already exist"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "froot = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_samples}/'"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "paths = ['../../results/', \n",
171 |     "              '../../results/n_shapley_values/',\n",
172 | " f'../../results/n_shapley_values/{dataset}/', \n",
173 | " f'../../results/n_shapley_values/{dataset}/{classifier}/',\n",
174 | " froot]\n",
175 | "for p in paths:\n",
176 | " if not os.path.exists( p ):\n",
177 | " os.mkdir( p )"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "### Train the classifier"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "clf = paperutil.train_classifier(dataset, classifier)"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {},
199 | "source": [
200 | "### The value function"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": [
209 | "if method == 'predict':\n",
210 | " vfunc = nshap.vfunc.interventional_shap(clf.predict, X_train, num_samples=num_samples, random_state=0)\n",
211 | "elif method == 'proba':\n",
212 | " prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n",
213 | " vfunc = nshap.vfunc.interventional_shap(clf.predict_proba, X_train, num_samples=num_samples, random_state=0, target=prediction)\n",
214 | "elif method == 'decision':\n",
215 | " vfunc = nshap.vfunc.interventional_shap(clf.decision_function, X_train, num_samples=num_samples, random_state=0)"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "metadata": {},
221 | "source": [
222 | "### Evaluate the value function"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {
229 | "tags": []
230 | },
231 | "outputs": [],
232 | "source": [
233 | "for idx in range(10):\n",
234 |     "    S = subsets[10*job_id + idx] # each notebook job evaluates 10 subsets\n",
235 | " if len(S) > 0 and np.max(S) >= X_train.shape[1]:\n",
236 | " continue\n",
237 | " fname = froot + f'v{S}.txt' \n",
238 | " # evaluate the value function and save the result\n",
239 | " if not os.path.exists(fname):\n",
240 | " result = vfunc(X_test[i_datapoint, :].reshape((1,-1)), S)\n",
241 | " with open(fname, 'w+') as f:\n",
242 | " f.write(f'{result}')"
243 | ]
244 | }
245 | ],
246 | "metadata": {
247 | "kernelspec": {
248 | "display_name": "Python 3",
249 | "language": "python",
250 | "name": "python3"
251 | },
252 | "language_info": {
253 | "codemirror_mode": {
254 | "name": "ipython",
255 | "version": 3
256 | },
257 | "file_extension": ".py",
258 | "mimetype": "text/x-python",
259 | "name": "python",
260 | "nbconvert_exporter": "python",
261 | "pygments_lexer": "ipython3",
262 | "version": "3.8.3"
263 | }
264 | },
265 | "nbformat": 4,
266 | "nbformat_minor": 5
267 | }
268 |
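
The first cell of this notebook is tagged "parameters", which is the hook papermill uses to inject a different `aid` into each copy of the notebook. A hedged sketch of a driver script, assuming papermill is installed and using an illustrative output directory that is not part of the repository: with 2**10 = 1024 subsets and 10 subsets per run, 103 runs cover all of them.

    import os
    import papermill as pm

    os.makedirs("runs", exist_ok=True)  # hypothetical output location
    n_jobs = 103  # ceil(1024 subsets / 10 subsets per run)
    for aid in range(n_jobs):
        pm.execute_notebook(
            "compute-vfunc.ipynb",
            f"runs/compute-vfunc-{aid}.ipynb",
            parameters={"aid": aid},
        )
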
--------------------------------------------------------------------------------
/notebooks/replicate-paper/compute.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Notebook for parallel computation of n-Shapley Values"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {
14 | "tags": [
15 | "parameters"
16 | ]
17 | },
18 | "outputs": [],
19 | "source": [
20 | "# papermill parameter: notebook id\n",
21 | "aid = 0"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "# second compute wave\n",
31 | "#aid = aid + 1000"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import numpy as np\n",
41 | "\n",
42 | "import os\n",
43 | "\n",
44 | "import datasets\n",
45 | "import nshap\n",
46 | "\n",
47 | "from itertools import product\n",
48 | "\n",
49 | "import paperutil\n",
50 | "\n",
51 | "%load_ext autoreload\n",
52 | "%autoreload 2"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 | "### The different compute jobs"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']\n",
69 | "classifiers = ['rf', 'knn', 'gam', 'gbtree']\n",
70 | "examples = list(range(0, 100))\n",
71 | "\n",
72 | "all_jobs = list(product(data_sets, classifiers, examples))\n",
73 | "print(len(all_jobs), 'different compute jobs')"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "for data_set in data_sets:\n",
83 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(data_set)\n",
84 | " print(data_set, X_train.shape[0])"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "### The current job"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "job_id = aid\n",
101 | "dataset = all_jobs[job_id][0]\n",
102 | "classifier = all_jobs[job_id][1]\n",
103 | "example = all_jobs[job_id][2]\n",
104 | "random_seed = example\n",
105 | "\n",
106 | "print(job_id, dataset, classifier, example, random_seed)"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {},
112 | "source": [
113 | "### Create output dir structure, if it does not already exist"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "# os.makedirs also creates any missing parent directories\n",
123 | "os.makedirs( f'../../results/n_shapley_values/{dataset}/{classifier}', exist_ok=True )"
126 | ]
127 | },
128 | {
129 | "cell_type": "markdown",
130 | "metadata": {},
131 | "source": [
132 | "### Load the dataset"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "is_classification = datasets.is_classification(dataset)"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "metadata": {},
156 | "source": [
157 | "### Train the classifier"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": null,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "clf = paperutil.train_classifier(dataset, classifier)"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "import sklearn.metrics\n\nif is_classification:\n",
176 | " print( sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test)) )\n",
177 | "else:\n",
178 | " print( sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test)) )"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "### n-Shapley Values"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {
192 | "tags": []
193 | },
194 | "outputs": [],
195 | "source": [
196 | "i_datapoint = example\n",
197 | "froot = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}'\n",
198 | "for max_samples in [500, 5000]:\n",
199 | " num_samples = min(max_samples, X_train.shape[0])\n",
200 | " # the value function\n",
201 | " vfunc = nshap.vfunc.interventional_shap(clf.predict, X_train, num_samples=num_samples, random_state=0)\n",
202 | " fname = froot + f'_predict_{num_samples}.JSON'\n",
203 | " if is_classification:\n",
204 | " prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n",
205 | " vfunc = nshap.vfunc.interventional_shap(clf.predict_proba, X_train, num_samples=num_samples, random_state=0, target=prediction)\n",
206 | " fname = froot + f'_proba_{num_samples}.JSON'\n",
207 | " if classifier == 'gam':\n",
208 | " vfunc = nshap.vfunc.interventional_shap(clf.decision_function, X_train, num_samples=num_samples, random_state=0)\n",
209 | " fname = froot + f'_decision_{num_samples}.JSON'\n",
210 | " # compute and save n-Shapley Values\n",
211 | " if not os.path.exists(fname):\n",
212 | " n_shapley_values = nshap.n_shapley_values(X_test[i_datapoint, :].reshape((1,-1)), vfunc)\n",
213 | " n_shapley_values.save(fname)"
214 | ]
215 | }
216 | ],
217 | "metadata": {
218 | "kernelspec": {
219 | "display_name": "Python 3",
220 | "language": "python",
221 | "name": "python3"
222 | },
223 | "language_info": {
224 | "codemirror_mode": {
225 | "name": "ipython",
226 | "version": 3
227 | },
228 | "file_extension": ".py",
229 | "mimetype": "text/x-python",
230 | "name": "python",
231 | "nbconvert_exporter": "python",
232 | "pygments_lexer": "ipython3",
233 | "version": "3.8.3"
234 | }
235 | },
236 | "nbformat": 4,
237 | "nbformat_minor": 5
238 | }
239 |
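
Since `all_jobs` enumerates the Cartesian product of datasets, classifiers and example indices, a single integer `aid` addresses exactly one (dataset, classifier, example) triple. A quick self-contained check of that indexing:

    from itertools import product

    data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']
    classifiers = ['rf', 'knn', 'gam', 'gbtree']
    examples = list(range(100))

    all_jobs = list(product(data_sets, classifiers, examples))
    assert len(all_jobs) == 5 * 4 * 100  # 2000 jobs
    print(all_jobs[0])   # ('folk_income', 'rf', 0)
    print(all_jobs[-1])  # ('iris', 'gbtree', 99)
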
--------------------------------------------------------------------------------
/notebooks/replicate-paper/datasets.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################################
2 | # Load the different datasets.
3 | #
4 | # The features are scaled to have mean zero and unit variance.
5 | #
6 | # All functions return: X_train, X_test, Y_train, Y_test, feature_names
7 | ########################################################################################################################
8 |
9 | import numpy as np
10 |
11 | import sklearn
12 | import sklearn.datasets
13 | from sklearn.model_selection import train_test_split
14 | from sklearn.preprocessing import StandardScaler
15 |
16 | import pandas as pd
17 |
18 | from folktables import ACSDataSource, ACSIncome, ACSTravelTime
19 |
20 | import os
21 |
22 | data_root_dir = "../../data/"
23 |
24 |
25 | def german_credit(seed=0):
26 |     """ The German credit dataset.
27 | """
28 | feature_names = [
29 | "checking account status",
30 | "Duration",
31 | "Credit history",
32 | "Purpose",
33 | "Credit amount",
34 | "Savings account/bonds",
35 | "Present employment since",
36 | "Installment rate in percentage of disposable income",
37 | "Personal status and sex",
38 | "Other debtors / guarantors",
39 | "Present residence since",
40 | "Property",
41 | "Age in years",
42 | "Other installment plans",
43 | "Housing",
44 | "Number of existing credits at this bank",
45 | "Job",
46 |         "Number of people being liable to provide maintenance for",
47 | "Telephone",
48 | "foreign worker",
49 | ]
50 | columns = [*feature_names, "target"]
51 |
52 | data = pd.read_csv(os.path.join(data_root_dir, "german.data"), sep=" ", header=None)
53 | data.columns = columns
54 | Y = data["target"] - 1
55 | X = data
56 | X = X.drop("target", axis=1)
57 | cat_columns = X.select_dtypes(["object"]).columns
58 | X[cat_columns] = X[cat_columns].apply(lambda x: x.astype("category").cat.codes)
59 |
60 | # zero mean and unit variance for all features
61 | X = StandardScaler().fit_transform(X)
62 |
63 | # train-test split
64 | X_train, X_test, Y_train, Y_test = train_test_split(
65 | X, Y, train_size=0.8, random_state=seed
66 | )
67 |
68 | return X_train, X_test, Y_train, Y_test, feature_names
69 |
70 |
71 | def iris(seed=0):
72 | """ The iris dataset, class 1 vs. the rest.
73 | """
74 | # load the dataset
75 | iris = sklearn.datasets.load_iris()
76 | X = iris.data
77 | Y = iris.target
78 |
79 | # feature names
80 | feature_names = iris.feature_names
81 |
82 | # create a binary outcome
83 | Y = Y == 1
84 |
85 | # zero mean and unit variance for all features
86 | X = StandardScaler().fit_transform(X)
87 |
88 | # train-test split
89 | X_train, X_test, Y_train, Y_test = train_test_split(
90 | X, Y, train_size=0.8, random_state=seed
91 | )
92 |
93 | return X_train, X_test, Y_train, Y_test, feature_names
94 |
95 |
96 | def california_housing(seed=0, classification=False):
97 | """ The california housing dataset.
98 | """
99 | # load the dataset
100 | housing = sklearn.datasets.fetch_california_housing()
101 | X = housing.data
102 | Y = housing.target
103 |
104 | # feature names
105 | feature_names = housing.feature_names
106 |
107 | # create a binary outcome
108 | if classification:
109 | Y = Y > np.median(Y)
110 |
111 | # zero mean and unit variance for all features
112 | X = StandardScaler().fit_transform(X)
113 |
114 | # train-test split
115 | X_train, X_test, Y_train, Y_test = train_test_split(
116 | X, Y, train_size=0.8, random_state=seed
117 | )
118 |
119 | return X_train, X_test, Y_train, Y_test, feature_names
120 |
121 |
122 | def folktables_acs_income(seed=0, survey_year="2016", states=["CA"]):
123 | # (down-)load the dataset
124 | data_source = ACSDataSource(
125 | survey_year=survey_year,
126 | horizon="1-Year",
127 | survey="person",
128 | root_dir=data_root_dir,
129 | )
130 | data = data_source.get_data(states=states, download=True)
131 | X, Y, _ = ACSIncome.df_to_numpy(data)
132 |
133 | # feature names
134 | feature_names = ACSIncome.features
135 |
136 | # zero mean and unit variance for all features
137 | X = StandardScaler().fit_transform(X)
138 |
139 | # train-test split
140 | X_train, X_test, Y_train, Y_test = train_test_split(
141 | X, Y, train_size=0.8, random_state=seed
142 | )
143 |
144 | return X_train, X_test, Y_train, Y_test, feature_names
145 |
146 |
147 | def folktables_acs_travel_time(seed=0, survey_year="2016", states=["CA"]):
148 | # (down-)load the dataset
149 | data_source = ACSDataSource(
150 | survey_year=survey_year,
151 | horizon="1-Year",
152 | survey="person",
153 | root_dir=data_root_dir,
154 | )
155 | data = data_source.get_data(states=states, download=True)
156 | X, Y, _ = ACSTravelTime.df_to_numpy(data)
157 |
158 | # feature names
159 | feature_names = ACSTravelTime.features
160 |
161 | # zero mean and unit variance for all features
162 | X = StandardScaler().fit_transform(X)
163 |
164 | # train-test split
165 | X_train, X_test, Y_train, Y_test = train_test_split(
166 | X, Y, train_size=0.8, random_state=seed
167 | )
168 |
169 | return X_train, X_test, Y_train, Y_test, feature_names
170 |
171 |
172 | ########################################################################################################################
173 | # Functions to ease access to all the different datasets
174 | ########################################################################################################################
175 |
176 | dataset_dict = {
177 | "iris": iris,
178 | "folk_income": folktables_acs_income,
179 | "folk_travel": folktables_acs_travel_time,
180 | "housing": california_housing,
181 | "credit": german_credit,
182 | }
183 |
184 |
185 | def get_datasets():
186 |     """ Returns a dict that maps the names of the available datasets to their loader functions.
187 |     """
188 | return dataset_dict
189 |
190 |
191 | def is_classification(dataset):
192 | if dataset == "housing":
193 | return False
194 | return True
195 |
196 |
197 | def load_dataset(dataset):
198 | if dataset == "folk_income":
199 | X_train, X_test, Y_train, Y_test, feature_names = folktables_acs_income(0)
200 | elif dataset == "folk_travel":
201 | X_train, X_test, Y_train, Y_test, feature_names = folktables_acs_travel_time(0)
202 | # subset the dataset to 10 features to ease computation
203 | feature_subset = [13, 14, 9, 0, 12, 15, 1, 3, 7, 11]
204 | feature_names = [feature_names[i] for i in feature_subset]
205 | X_train = X_train[:, feature_subset]
206 | X_test = X_test[:, feature_subset]
207 | elif dataset == "housing":
208 | X_train, X_test, Y_train, Y_test, feature_names = california_housing(0)
209 | elif dataset == "credit":
210 | X_train, X_test, Y_train, Y_test, feature_names = german_credit(0)
211 | # subset the dataset to 10 features to ease computation
212 | feature_subset = [0, 1, 2, 3, 4, 5, 6, 7, 14, 11]
213 | feature_names = [feature_names[i] for i in feature_subset]
214 | X_train = X_train[:, feature_subset]
215 | X_test = X_test[:, feature_subset]
216 |     elif dataset == "iris":
217 |         X_train, X_test, Y_train, Y_test, feature_names = iris(0)
218 |     else:
219 |         raise ValueError(f"Unknown dataset: {dataset}")
220 |     return X_train, X_test, Y_train, Y_test, feature_names
219 |
220 |
221 | def get_feature_names(dataset):
222 |     """ Feature names of a dataset, shortened for better plotting.
223 | """
224 | if dataset == "folk_income":
225 | return folktables_acs_income()[4]
226 | elif dataset == "folk_travel":
227 | feature_names = folktables_acs_travel_time()[4]
228 | feature_names = [feature_names[i] for i in [13, 14, 9, 0, 12, 15, 1, 3, 7, 11]]
229 | feature_names[1] = "POWP"
230 | return feature_names
231 | elif dataset == "housing":
232 | return california_housing()[4]
233 | elif dataset == "credit":
234 | return [
235 | "Account",
236 | "Duration",
237 | "History",
238 | "Purpose",
239 | "Amount",
240 | "Savings",
241 | "Employ",
242 | "Rate",
243 | "Housing",
244 | "Property",
245 | ]
246 | elif dataset == "iris":
247 | return ["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"]
248 |
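
All loaders share one contract: features standardized to zero mean and unit variance (note that, as the code shows, scaling is applied before the train-test split) and a `(X_train, X_test, Y_train, Y_test, feature_names)` return tuple. A minimal usage sketch with the one loader that needs no local data files, assuming the module's folktables dependency is installed and the script runs from notebooks/replicate-paper/ so the relative paths resolve:

    import datasets

    X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset("iris")
    print(X_train.shape, X_test.shape)         # 80/20 split of the 150 rows
    print(feature_names)                       # the four iris measurements
    print(datasets.is_classification("iris"))  # True
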
--------------------------------------------------------------------------------
/notebooks/replicate-paper/estimation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Estimating n-Shapley Values for the kNN classifier on the Folktables Travel data set"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import matplotlib.pyplot as plt\n",
17 | "import seaborn as sns\n",
18 | "\n",
19 | "sns.set_style(\"whitegrid\")\n",
20 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.1)\n",
21 | "\n",
22 | "import datasets\n",
23 | "import nshap"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "### Load the data"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "_, X_test, _, _, feature_names = datasets.load_dataset('folk_travel')"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "### Load the pre-computed n-Shapley Values"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "n_shapley_values = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_500.JSON')\n",
56 | "n_shapley_values_5000 = nshap.load('../../results/n_shapley_values/folk_travel/knn/observation_0_proba_5000.JSON')"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "def vfunc(x, S):\n",
66 | " S = tuple(S)\n",
67 | " fname = f'../../results/n_shapley_values/folk_travel/knn/observation_0_proba_133549/v{S}.txt' \n",
68 | " with open(fname, 'r') as f:\n",
69 | " result = float( f.read() )\n",
70 | " return result"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "n_shapley_values_133549 = nshap.n_shapley_values(X_test[0, :], vfunc)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "### Plots"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "for idx, v in enumerate([n_shapley_values, n_shapley_values_5000, n_shapley_values_133549]):\n",
96 | " fig, ax = plt.subplots(1, 1, figsize=(7, 7.9))\n",
97 | " v.plot(axis=ax, legend=False, feature_names=feature_names, rotation=60)\n",
98 | " plt.ylim([-0.295, 0.29])\n",
99 | "    plt.title(f'Shapley-GAM, {[\"500\", \"5000\", \"133549\"][idx]} Samples')\n",
100 | " plt.tight_layout()\n",
101 | " plt.savefig(f'../../figures/knn_estimation_{idx}.pdf')\n",
102 | " plt.show()"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "metadata": {},
108 | "source": [
109 | "### Latex code for Table in Appendix"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "tags": []
117 | },
118 | "outputs": [],
119 | "source": [
120 | "for S in nshap.powerset(list(range(10))):\n",
121 | " if len(S) == 0:\n",
122 | " continue\n",
123 | " print(f'{S} & {n_shapley_values[S]:0.4f} & {n_shapley_values_5000[S]:0.4f} & {n_shapley_values_133549[S]:0.4f} \\\\\\\\')"
124 | ]
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "Python 3",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.8.3"
144 | }
145 | },
146 | "nbformat": 4,
147 | "nbformat_minor": 5
148 | }
149 |
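
The `vfunc` defined above replays values that compute-vfunc.ipynb wrote to disk: one text file per subset, keyed by the subset tuple. A self-contained sketch of that read/write round trip (the `toy_cache` directory and the helper name are illustrative):

    import os

    def make_cached_vfunc(root):
        """Value function that reads a pre-computed v(S) from '{root}/v{S}.txt'."""
        def vfunc(x, S):
            with open(os.path.join(root, f"v{tuple(S)}.txt")) as f:
                return float(f.read())
        return vfunc

    os.makedirs("toy_cache", exist_ok=True)
    with open(os.path.join("toy_cache", f"v{(0, 1)}.txt"), "w") as f:
        f.write("0.25")
    print(make_cached_vfunc("toy_cache")(None, (0, 1)))  # 0.25
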
--------------------------------------------------------------------------------
/notebooks/replicate-paper/figures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Generate the Figures in the Paper"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import numpy as np\n",
17 | "\n",
18 | "import os\n",
19 | "import sklearn\n",
20 | "from interpret.glassbox import ExplainableBoostingClassifier\n",
21 | "\n",
22 | "from itertools import product\n",
23 | "\n",
24 | "import matplotlib.pyplot as plt\n",
25 | "import seaborn as sns\n",
26 | "\n",
27 | "import datasets\n",
28 | "import paperutil\n",
29 | "import nshap\n",
30 | "\n",
31 | "%load_ext autoreload\n",
32 | "%autoreload 2"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "metadata": {},
38 | "source": [
39 | "### Load the computed n-Shapley Values"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']\n",
49 | "classifiers = ['gam', 'rf', 'gbtree', 'knn']\n",
50 | "methods = ['predict', 'proba', 'decision']"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "tags": []
58 | },
59 | "outputs": [],
60 | "source": [
61 | "shapley_values = {}\n",
62 | "for dataset in data_sets:\n",
63 | " X_train, X_test, _, _, _ = datasets.load_dataset(dataset)\n",
64 | " shapley_values[dataset] = {}\n",
65 | " num_datapoints = min(5000, X_train.shape[0]) \n",
66 | " for classifier in classifiers:\n",
67 | " shapley_values[dataset][classifier] = {}\n",
68 | " for method in methods:\n",
69 | " if os.path.exists(f'../../results/n_shapley_values/{dataset}/{classifier}/observation_0_{method}_{num_datapoints}.JSON'):\n",
70 | " shapley_values[dataset][classifier][method] = []\n",
71 | " for i_datapoint in range(min(X_test.shape[0], 100)):\n",
72 | " fname = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_datapoints}.JSON'\n",
73 | " if os.path.exists(fname):\n",
74 | " n_shapley_values = nshap.load(fname)\n",
75 | " shapley_values[dataset][classifier][method].append(n_shapley_values)\n",
76 | " else:\n",
77 | " print(f'File {fname} not found.')"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "### Create directory structure"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "paths = ['../../figures/', '../../figures/partial_dependence/', '../../figures/shapley_gam/', '../../figures/n_shapley_values/']\n",
94 | "for p in paths:\n",
95 | " if not os.path.exists( p ):\n",
96 | " os.mkdir( p )"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "### Plot Settings"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# avoid type-3 fonts\n",
113 | "import matplotlib\n",
114 | "matplotlib.rcParams['pdf.fonttype'] = 42\n",
115 | "matplotlib.rcParams['ps.fonttype'] = 42"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "rotation = {'folk_income': 60, 'folk_travel': 60, 'housing': 60, 'credit': 60, 'iris': 0}"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "sns.set_style(\"whitegrid\")\n",
134 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "def adjust_ylim(axlist):\n",
144 | " ymin = min([x.get_ylim()[0] for x in axlist])\n",
145 | " ymax = max([x.get_ylim()[1] for x in axlist])\n",
146 | " for ax in axlist:\n",
147 | " ax.set_ylim((ymin, ymax))"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "### Plots of n-Shapley Values"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "sns.set_style(\"whitegrid\")\n",
164 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.1)"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "metadata": {
171 | "tags": []
172 | },
173 | "outputs": [],
174 | "source": [
175 | "for dataset in data_sets:\n",
176 | " feature_names = datasets.get_feature_names(dataset)\n",
177 | " for classifier in classifiers:\n",
178 | " for method in methods:\n",
179 | " # different methods for different classifiers\n",
180 | " if not method in shapley_values[dataset][classifier]: \n",
181 | " continue\n",
182 | " print(dataset, classifier, method)\n",
183 | " for i_datapoint in range(5):\n",
184 | " n_shapley_values = shapley_values[dataset][classifier][method][i_datapoint]\n",
185 | " fig, ax = plt.subplots(1, 4, figsize=(22.5, 6.75))\n",
186 | " ax0 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(1), axis=ax[0], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n",
187 | " ax0.set_ylabel('Feature Attribution')\n",
188 | " ax0.set_title('Shapley Values')\n",
189 | " ax1 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(2), axis=ax[1], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n",
190 | " ax1.set(yticklabels= []) \n",
191 | " ax1.set_title('Shapley Interaction Values')\n",
192 | " ax2 = nshap.plot_n_shapley(n_shapley_values.k_shapley_values(4), axis=ax[2], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n",
193 | " ax2.set(yticklabels= [])\n",
194 | " ax2.set_title('4-Shapley Values')\n",
195 | " ax3 = nshap.plot_n_shapley(n_shapley_values, axis=ax[3], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n",
196 | " ax3.set(yticklabels= []) \n",
197 | " ax3.set_title('Shapley-GAM')\n",
198 | " axes = [ax0, ax1, ax2, ax3]\n",
199 | " ymin = min([x.get_ylim()[0] for x in axes])\n",
200 | " ymax = max([x.get_ylim()[1] for x in axes])\n",
201 | " for x in axes:\n",
202 | " x.set_ylim((ymin, ymax))\n",
203 | " plt.tight_layout()\n",
204 | " plt.savefig(f'../../figures/n_shapley_values/{dataset}_{classifier}_{method}_{i_datapoint}.pdf')\n",
205 | " if i_datapoint == 0:\n",
206 | " plt.show()\n",
207 | " plt.close(fig)"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "### Plots of n-Shapley Values in the Appendix"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {
221 | "tags": []
222 | },
223 | "outputs": [],
224 | "source": [
225 | "sns.set_style(\"whitegrid\")\n",
226 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.25)\n",
227 | "\n",
228 | "from itertools import product\n",
229 | "\n",
230 | "for dataset in data_sets:\n",
231 | " feature_names = datasets.get_feature_names(dataset)\n",
232 | " if dataset == 'iris':\n",
233 | " continue\n",
234 | " for classifier in classifiers:\n",
235 | " method = 'proba'\n",
236 | " if dataset == 'housing':\n",
237 | " method = 'predict'\n",
238 | " if classifier == 'gam':\n",
239 | " method = 'decision'\n",
240 | " print(dataset, classifier, method)\n",
241 | " for i_datapoint in range(1):\n",
242 | " n_shapley_values = shapley_values[dataset][classifier][method][i_datapoint]\n",
243 | " if dataset == 'housing': # 8 features\n",
244 | " ncols = 4\n",
245 | " fig, ax = plt.subplots(2, ncols, figsize=(26, 14.75))\n",
246 | " else: # 10 features\n",
247 | " ncols = 5\n",
248 | " fig, ax = plt.subplots(2, ncols, figsize=(32, 14.75))\n",
249 | " for i in range(2):\n",
250 | " for j in range(ncols):\n",
251 | " k = 1 + ncols*i + j\n",
252 | " nshap.plot_n_shapley(n_shapley_values.k_shapley_values(k), axis=ax[i, j], legend=False, feature_names=feature_names, rotation=rotation[dataset])\n",
253 | " ax[i, j].set_title(f'{k}-Shapley Values')\n",
254 | " ax[0, 0].set_ylabel('Feature Attribution')\n",
255 | " ax[1, 0].set_ylabel('Feature Attribution')\n",
256 | " ax[0, 0].set_title('Shapley Values')\n",
257 | " ax[0, 1].set_title('Shapley Interaction Values')\n",
258 | " ax[1, ncols-1].set_title('Shapley-GAM')\n",
259 | " axes = [ax[i,j] for (i,j) in product(range(2), range(ncols))]\n",
260 | "            adjust_ylim(axes)\n",
261 | " for j in range(ncols):\n",
262 | " ax[0, j].set(xticklabels= [])\n",
263 | " for i in range(2):\n",
264 | " for j in range(1,ncols):\n",
265 | " ax[i, j].set(yticklabels= [])\n",
266 | " plt.tight_layout()\n",
267 | " plt.savefig(f'../../figures/n_shapley_values/apx_{dataset}_{classifier}_{method}_{i_datapoint}_full.pdf')\n",
268 | " if i_datapoint == 0:\n",
269 | " plt.show()\n",
270 | " plt.close(fig)"
271 | ]
272 | },
273 | {
274 | "cell_type": "markdown",
275 | "metadata": {},
276 | "source": [
277 | "### Example Visualizations"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {
284 | "tags": []
285 | },
286 | "outputs": [],
287 | "source": [
288 | "values = {(i,):0 for i in range(4)}\n",
289 | "values[(2,)] = 0.2\n",
290 | "values[(3,)] = -0.1\n",
291 | "n_shapley_values = nshap.nShapleyValues(values)\n",
292 | " \n",
293 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
294 | "nshap.plot_n_shapley(n_shapley_values.k_shapley_values(1), legend=False, axis=ax)\n",
295 | "plt.tight_layout()\n",
296 | "plt.savefig('../../figures/example1.pdf')\n",
297 | "print(values)\n",
298 | "plt.show()\n",
299 | "\n",
300 | "values[(1,2)] = 0.1 \n",
301 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
302 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n",
303 | "plt.tight_layout()\n",
304 | "plt.savefig('../../figures/example2.pdf')\n",
305 | "print(values)\n",
306 | "plt.show()\n",
307 | "\n",
308 | "values[(2,3)] = -0.1 \n",
309 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
310 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n",
311 | "plt.tight_layout()\n",
312 | "plt.savefig('../../figures/example3.pdf')\n",
313 | "print(values)\n",
314 | "plt.show()\n",
315 | "\n",
316 | "values[(1,2,3)] = 0.1\n",
317 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
318 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n",
319 | "plt.tight_layout()\n",
320 | "plt.savefig('../../figures/example4.pdf')\n",
321 | "print(values)\n",
322 | "plt.show()\n",
323 | "\n",
324 | "values[(0,1,2,3)] = -0.1 \n",
325 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
326 | "nshap.plot_n_shapley(nshap.nShapleyValues(values), legend=False, axis=ax)\n",
327 | "plt.tight_layout()\n",
328 | "plt.savefig('../../figures/example5.pdf')\n",
329 | "print(values)\n",
330 | "plt.show()"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {},
336 | "source": [
337 | "### The legend"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "import matplotlib.patches as mpatches\n",
347 | "from matplotlib.patches import Rectangle\n",
348 | "from matplotlib.transforms import Bbox\n",
349 | "\n",
350 | "fig, ax = plt.subplots(1, 1, figsize=(12, 1))\n",
351 | "\n",
352 | "# legend\n",
353 | "plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', \n",
354 | " '#17becf', 'black', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']\n",
355 | "\n",
356 | "color_patches = [mpatches.Patch(color=color) for color in plot_colors]\n",
357 | "labels = ['Main']\n",
358 | "labels.append('2nd order')\n",
359 | "labels.append('3rd order')\n",
360 | "for i in range(4, 10):\n",
361 | "    labels.append(f'{i}th')\n",
362 | "labels.append('10th order')\n",
363 | "ax.legend(color_patches, labels, ncol=11, fontsize=30, handletextpad=0.5, handlelength=1, handleheight=1)\n",
364 | "plt.axis('off')\n",
365 | "plt.savefig(f'../../figures/legend.pdf', bbox_inches=Bbox([[-13.5, -0.2], [11, 0.9]]))\n",
366 | "plt.show()"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": null,
372 | "metadata": {},
373 | "outputs": [],
374 | "source": [
375 | "import matplotlib.patches as mpatches\n",
376 | "from matplotlib.patches import Rectangle\n",
377 | "from matplotlib.transforms import Bbox\n",
378 | "\n",
379 | "fig, ax = plt.subplots(1, 1, figsize=(12, 1))\n",
380 | "\n",
381 | "# legend\n",
382 | "plot_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', \n",
383 | " '#17becf', 'black', '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']\n",
384 | "\n",
385 | "color_patches = [mpatches.Patch(color=color) for color in plot_colors]\n",
386 | "labels = ['Main']\n",
387 | "labels.append('2nd order')\n",
388 | "labels.append('3rd order')\n",
389 | "for i in range(4, 8):\n",
390 | "    labels.append(f'{i}th')\n",
391 | "ax.legend(color_patches, labels, ncol=11, fontsize=30, handletextpad=0.5, handlelength=1, handleheight=1)\n",
392 | "plt.axis('off')\n",
393 | "plt.savefig(f'../../figures/legend7.svg', bbox_inches=Bbox([[-13.5, -0.2], [11, 0.9]]))\n",
394 | "plt.show()"
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "### Shapley-GAM Figure in the paper"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {},
408 | "outputs": [],
409 | "source": [
410 | "sns.set_style(\"whitegrid\")\n",
411 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.9)"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {},
418 | "outputs": [],
419 | "source": [
420 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
421 | "ax = nshap.plot_n_shapley(shapley_values['credit']['gam']['decision'][0], axis=ax, legend=False, feature_names=datasets.get_feature_names('credit'), rotation=60)\n",
422 | "ax.set_ylabel('Feature Attribution')\n",
423 | "ax.set_title('Glassbox-GAM')\n",
424 | "plt.tight_layout()\n",
425 | "plt.savefig(f'../../figures/A.svg')\n",
426 | "plt.show()"
427 | ]
428 | },
429 | {
430 | "cell_type": "code",
431 | "execution_count": null,
432 | "metadata": {},
433 | "outputs": [],
434 | "source": [
435 | "fig, ax = plt.subplots(1, 1, figsize=(4.7, 6))\n",
436 | "ax = nshap.plot_n_shapley(shapley_values['housing']['gbtree']['predict'][0], axis=ax, legend=False, feature_names=datasets.get_feature_names('housing'), rotation=60)\n",
437 | "ax.set_title('Gradient Boosted Tree')\n",
438 | "print(ax.get_ylim())\n",
439 | "plt.tight_layout()\n",
440 | "plt.savefig(f'../../figures/B.svg')\n",
441 | "plt.show()"
442 | ]
443 | },
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "fig, ax = plt.subplots(1, 1, figsize=(5.5, 6))\n",
451 | "ax = nshap.plot_n_shapley(shapley_values['folk_travel']['knn']['proba'][1], axis=ax, legend=False, feature_names=datasets.get_feature_names('folk_travel'), rotation=60)\n",
452 | "ax.set_title('k-Nearest Neighbor')\n",
453 | "print(ax.get_ylim())\n",
454 | "ax.set_ylim((-0.32, 0.34))\n",
455 | "plt.tight_layout()\n",
456 | "plt.savefig(f'../../figures/C.svg')\n",
457 | "plt.show()"
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "metadata": {},
464 | "outputs": [],
465 | "source": [
466 | "n_shapley_values = {}\n",
467 | "for S in nshap.powerset(range(8)):\n",
468 | " if len(S) == 0:\n",
469 | " continue\n",
470 | " elif len(S) == 8:\n",
471 | " n_shapley_values[S] = 1\n",
472 | " else:\n",
473 | " n_shapley_values[S] = 0 \n",
474 | "n_shapley_values = nshap.nShapleyValues(n_shapley_values)\n",
475 | "\n",
476 | "fig, ax = plt.subplots(1, 1, figsize=(4.7, 6))\n",
477 | "ax = nshap.plot_n_shapley(n_shapley_values, axis=ax, legend=False, rotation=60)\n",
478 | "ax.set_title('8d Checkerboard Function')\n",
479 | "print(ax.get_ylim())\n",
480 | "ax.set_ylim((-0.32, 0.34))\n",
481 | "plt.yticks([]) \n",
482 | "plt.tight_layout()\n",
483 | "plt.savefig(f'../../figures/D.svg')\n",
484 | "plt.show()"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "### Partial dependence plots"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "sns.set_style(\"whitegrid\")\n",
501 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=2.6)"
502 | ]
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "metadata": {
508 | "tags": []
509 | },
510 | "outputs": [],
511 | "source": [
512 | "for dataset in data_sets:\n",
513 | " feature_names = datasets.get_feature_names(dataset)\n",
514 | " X_train, X_test, _, _, _ = datasets.load_dataset(dataset)\n",
515 | " num_datapoints = min(5000, X_train.shape[0]) \n",
516 | " for classifier in classifiers:\n",
517 | " method = 'proba'\n",
518 | " if dataset == 'housing':\n",
519 | " method = 'predict'\n",
520 | " if classifier == 'gam':\n",
521 | " method = 'decision'\n",
522 | " print(dataset, classifier, method)\n",
523 | " clf = paperutil.train_classifier(dataset, classifier)\n",
524 | " for i_feature in range(len(feature_names)):\n",
525 | " # collect data\n",
526 | " x = []\n",
527 | " nsv_list = []\n",
528 | " for i_datapoint in range(100):\n",
529 | " fname = f'../../results/n_shapley_values/{dataset}/{classifier}/observation_{i_datapoint}_{method}_{num_datapoints}.JSON'\n",
530 | " if os.path.exists(fname):\n",
531 | " n_shapley_values = nshap.load(fname)\n",
532 | " if method == 'proba': \n",
533 | "                # we computed the shapley values for the probability of the predicted class\n",
534 | " # but here we want to explain the probability of class 1, for all data points\n",
535 | " prediction = int( clf.predict( X_test[i_datapoint, :].reshape((1,-1)) ) )\n",
536 | " x.append(X_test[i_datapoint, i_feature])\n",
537 | " if prediction == 0: \n",
538 | " n_shapley_values = nshap.nShapleyValues({k:-v for k,v in n_shapley_values.items()})\n",
539 | " nsv_list.append(n_shapley_values) \n",
540 | " else:\n",
541 | " x.append(X_test[i_datapoint, i_feature])\n",
542 | " nsv_list.append(n_shapley_values)\n",
543 | " else:\n",
544 | " print(f'File {fname} not found')\n",
545 | " # plot\n",
546 | " fig, ax = plt.subplots(1, 4, figsize=(30, 5.5)) # appendix\n",
547 | " #fig, ax = plt.subplots(1, 4, figsize=(30, 6)) # paper\n",
548 | " y = [n_shapley_values.k_shapley_values(1)[(i_feature,)] for n_shapley_values in nsv_list]\n",
549 | " ax0 = sns.scatterplot(x=x, y=y, ax=ax[0], s=150)\n",
550 | " ax0.set_ylabel('Feature Attribution')\n",
551 | " ax0.set_title('Shapley Values')\n",
552 | " y = [n_shapley_values.k_shapley_values(2)[(i_feature,)] for n_shapley_values in nsv_list]\n",
553 | " ax1 = sns.scatterplot(x=x, y=y, ax=ax[1], s=150)\n",
554 | " ax1.set(yticklabels= []) \n",
555 | " ax1.set_title('Shapley Interaction Values')\n",
556 | " y = [n_shapley_values.k_shapley_values(4)[(i_feature,)] for n_shapley_values in nsv_list]\n",
557 | " ax2 = sns.scatterplot(x=x, y=y, ax=ax[2], s=150)\n",
558 | " ax2.set(yticklabels= [])\n",
559 | " ax2.set_title('4-Shapley Values')\n",
560 | " y = [n_shapley_values[(i_feature,)] for n_shapley_values in nsv_list]\n",
561 | " ax3 = sns.scatterplot(x=x, y=y, ax=ax[3], s=150)\n",
562 | " ax3.set(yticklabels= []) \n",
563 | " ax3.set_title('Shapley-GAM')\n",
564 | " axes = [ax0, ax1, ax2, ax3]\n",
565 | "        adjust_ylim(axes)\n",
566 | "        for axis in axes:\n",
567 | "            axis.set_xlabel(f'Value of Feature {feature_names[i_feature]}')\n",
568 | " plt.tight_layout()\n",
569 | " plt.savefig(f'../../figures/partial_dependence/{dataset}_{classifier}_{i_feature}_{method}.pdf')\n",
570 | " plt.show()\n",
571 | " plt.close(fig)"
572 | ]
573 | },
574 | {
575 | "cell_type": "markdown",
576 | "metadata": {},
577 | "source": [
578 | "### Recovery of GAM without interaction terms"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "metadata": {},
585 | "outputs": [],
586 | "source": [
587 | "from interpret.glassbox import ExplainableBoostingClassifier"
588 | ]
589 | },
590 | {
591 | "cell_type": "code",
592 | "execution_count": null,
593 | "metadata": {},
594 | "outputs": [],
595 | "source": [
596 | "X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset('folk_travel')"
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {},
603 | "outputs": [],
604 | "source": [
605 | "ebm = ExplainableBoostingClassifier(feature_names=feature_names, interactions=0, random_state=0)\n",
606 | "ebm.fit(X_train[:50000], Y_train[:50000])\n",
607 | "(ebm.predict(X_test) == Y_test).mean()"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": null,
613 | "metadata": {
614 | "tags": []
615 | },
616 | "outputs": [],
617 | "source": [
618 | "from interpret.provider import InlineProvider\n",
619 | "from interpret import set_visualize_provider\n",
620 | "\n",
621 | "set_visualize_provider(InlineProvider())\n",
622 | "\n",
623 | "from interpret import show\n",
624 | "\n",
625 | "ebm_global = ebm.explain_global()\n",
626 | "show(ebm_global)"
627 | ]
628 | },
629 | {
630 | "cell_type": "markdown",
631 | "metadata": {},
632 | "source": [
633 | "#### Compute KernelShap explanations"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": [
642 | "import shap \n",
643 | "\n",
644 | "X_train_summary = shap.kmeans(X_train, 25)\n",
645 | "kernel_explainer = shap.KernelExplainer(ebm.decision_function, X_train_summary)"
646 | ]
647 | },
648 | {
649 | "cell_type": "code",
650 | "execution_count": null,
651 | "metadata": {
652 | "tags": []
653 | },
654 | "outputs": [],
655 | "source": [
656 | "kernel_shap_values = []\n",
657 | "for i in range(100):\n",
658 | " kernel_shap_values.append( kernel_explainer.shap_values(X_test[i, :]) )"
659 | ]
660 | },
661 | {
662 | "cell_type": "code",
663 | "execution_count": null,
664 | "metadata": {},
665 | "outputs": [],
666 | "source": [
667 | "for ifeature in [1]:\n",
668 | " print(f'--------------------------- {ifeature} ---------------------')\n",
669 | " # partial influence of ifeature in the gam\n",
670 | " x_ifeature = []\n",
671 | " gam_v = []\n",
672 | " for i in range(100):\n",
673 | " x_hat = np.zeros((1,10))\n",
674 | " x_hat[0, ifeature] = X_test[i, ifeature]\n",
675 | " x_ifeature.append(X_test[i, ifeature])\n",
676 | " gam_v.append(ebm.decision_function(x_hat)[0])\n",
677 | " fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
678 | "    sns.scatterplot(x=x_ifeature, y=gam_v-np.mean(gam_v))\n",
679 | " plt.title('Explainable Boosting')\n",
680 | " plt.xlabel(f'{feature_names[ifeature]}')\n",
681 | " plt.ylabel(f'Score')\n",
682 | " plt.tight_layout()\n",
683 | " plt.savefig(f'../../figures/recovery_ebm.pdf')\n",
684 | " plt.show()\n",
685 | "\n",
686 | " # shapley value of feature i\n",
687 | " shapley_v = []\n",
688 | " for i in range(100):\n",
689 | " shapley_v.append(kernel_shap_values[i][ifeature])\n",
690 | " fig, ax = plt.subplots(1, 1, figsize=(6, 5))\n",
691 | "    sns.scatterplot(x=x_ifeature, y=shapley_v-np.mean(shapley_v))\n",
692 | " plt.title('kernel SHAP')\n",
693 | " plt.xlabel(f'{feature_names[ifeature]}')\n",
694 | " plt.ylabel(f'Score')\n",
695 | " plt.tight_layout()\n",
696 | " plt.savefig(f'../../figures/recovery_kernel_shap.pdf')\n",
697 | " plt.show()"
698 | ]
699 | },
700 | {
701 | "cell_type": "code",
702 | "execution_count": null,
703 | "metadata": {},
704 | "outputs": [],
705 | "source": [
706 | "sns.set_style(\"whitegrid\")\n",
707 | "sns.set_context(\"notebook\", rc={'axes.linewidth': 2, 'grid.linewidth': 1}, font_scale=1.4)\n",
708 | "\n",
709 | "img = plt.imread(\"../../figures/gam_curve.png\")\n",
710 | "fig, ax = plt.subplots(figsize=(7,4))\n",
711 | "ax.imshow(img, extent=[-2.1, 2.15, -0.99, 0.83], aspect='auto')\n",
712 | "sns.scatterplot(x=x_ifeature, y=shapley_v-np.mean(shapley_v), color='r', s=50)\n",
713 | "plt.xlabel(f'Value of Feature POWPUMA')\n",
714 | "plt.ylabel(f'Kernel SHAP Attribution')\n",
715 | "plt.tight_layout()\n",
716 | "plt.savefig(f'../../figures/recovery.svg')\n",
717 | "plt.show()"
718 | ]
719 | },
720 | {
721 | "cell_type": "markdown",
722 | "metadata": {},
723 | "source": [
724 | "### Accuracy vs Average Degree of Variable Interaction in the Shapley-GAM"
725 | ]
726 | },
727 | {
728 | "cell_type": "code",
729 | "execution_count": null,
730 | "metadata": {},
731 | "outputs": [],
732 | "source": [
733 | "accuracies = {}\n",
734 | "for dataset in data_sets:\n",
735 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)\n",
736 | " is_classification = datasets.is_classification(dataset)\n",
737 | " accuracies[dataset] = {}\n",
738 | " for classifier in classifiers:\n",
739 | " clf = paperutil.train_classifier(dataset, classifier)\n",
740 | " # accuracy / mse\n",
741 | " if is_classification:\n",
742 | " accuracies[dataset][classifier] = sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test))\n",
743 | " else:\n",
744 | " accuracies[dataset][classifier] = sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test))\n",
745 | " print(dataset, classifier, accuracies[dataset][classifier])"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": null,
751 | "metadata": {},
752 | "outputs": [],
753 | "source": [
754 | "complexities = {}\n",
755 | "for dataset in data_sets:\n",
756 | " complexities[dataset] = {}\n",
757 | " for classifier in classifiers:\n",
758 | " method = 'proba'\n",
759 | " if dataset == 'housing':\n",
760 | " method = 'predict'\n",
761 | " if method == 'proba' and classifier == 'svm':\n",
762 | " continue\n",
763 | " if classifier == 'gam':\n",
764 | " method = 'decision'\n",
765 | " v = []\n",
766 | " for n_shapley_values in shapley_values[dataset][classifier][method]: \n",
767 | " degree_contributions = n_shapley_values.get_degree_contributions()\n",
768 | " integral = np.sum(degree_contributions*list(range(1, len(degree_contributions)+1))) / np.sum(degree_contributions)\n",
769 | " v.append(integral)\n",
770 | " complexities[dataset][classifier] = np.mean(v)\n",
771 | " print(dataset, classifier, np.mean(v))"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": null,
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "sns.set_theme(font_scale=1.3)\n",
781 | "\n",
782 | "x = []\n",
783 | "y = []\n",
784 | "hue = []\n",
785 | "style = []\n",
786 | "for dataset in data_sets:\n",
787 | " if dataset == 'housing':\n",
788 | " continue\n",
789 | " for classifier in classifiers:\n",
790 | " x.append(complexities[dataset][classifier])\n",
791 | " y.append(accuracies[dataset][classifier])\n",
792 | " hue.append(dataset)\n",
793 | " style.append(classifier)\n",
794 | " \n",
795 | "fig, ax = plt.subplots(1, 1, figsize=(6, 3))\n",
796 | "ax = sns.scatterplot(x=x, y=y, hue=hue, style=style, s=200, ax=ax)\n",
797 | "\n",
798 | "handles, labels = ax.get_legend_handles_labels()\n",
799 | "plt.legend(ncol=1, \n",
800 | " bbox_to_anchor = (1., 1.03),\n",
801 | " handles=[handles[i] for i in [3, 0, 2, 1, 4, 6, 5, 7]], \n",
802 | " labels=['Iris', 'Income', 'Credit', 'Travel', 'GAM', 'GBTree', 'RF', 'KNN'],\n",
803 | " frameon=True,\n",
804 | " fontsize=12,\n",
805 | " markerscale = 1.8)\n",
806 | "\n",
807 | "plt.ylabel('Accuracy')\n",
808 | "plt.xlabel('Average Degree of Variable Interaction in Shapley-GAM')\n",
809 | "ax.set_xticks([1,2,3,4,5])\n",
810 | "ax.set_xlim([0.85,5])\n",
811 | "ax.set_ylim([0.6, 1.025])\n",
812 | "ax.set_yticks([0.6, 0.7, 0.8, 0.9, 1.0])\n",
813 | "plt.tight_layout()\n",
814 | "plt.savefig(f'../../figures/accuracy_interaction.svg')\n",
815 | "plt.show()"
816 | ]
817 | }
818 | ],
819 | "metadata": {
820 | "kernelspec": {
821 | "display_name": "Python 3",
822 | "language": "python",
823 | "name": "python3"
824 | },
825 | "language_info": {
826 | "codemirror_mode": {
827 | "name": "ipython",
828 | "version": 3
829 | },
830 | "file_extension": ".py",
831 | "mimetype": "text/x-python",
832 | "name": "python",
833 | "nbconvert_exporter": "python",
834 | "pygments_lexer": "ipython3",
835 | "version": "3.8.3"
836 | }
837 | },
838 | "nbformat": 4,
839 | "nbformat_minor": 4
840 | }
841 |
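
The complexity score in the final scatter plot is the attribution-weighted average interaction degree: degrees 1..d weighted by the per-degree attribution mass returned by get_degree_contributions(). A small sketch of just that formula, assuming the contributions arrive as a non-negative array:

    import numpy as np

    def average_interaction_degree(degree_contributions):
        degrees = np.arange(1, len(degree_contributions) + 1)
        return np.sum(degree_contributions * degrees) / np.sum(degree_contributions)

    print(average_interaction_degree(np.array([1.0, 0.0, 0.0])))  # 1.0: main effects only
    print(average_interaction_degree(np.array([0.5, 0.0, 0.5])))  # 2.0: mass split between degrees 1 and 3
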
--------------------------------------------------------------------------------
/notebooks/replicate-paper/hyperparameters.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Hyperparameter tuning"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import datasets\n",
17 | "import nshap\n",
18 | "\n",
19 | "import xgboost\n",
20 | "import sklearn\n",
21 | "from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor\n",
22 | "from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n",
23 | "from sklearn.svm import SVC, SVR\n",
24 | "from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor\n",
25 | "\n",
26 | "from sklearn.model_selection import GridSearchCV\n",
27 | "\n",
28 | "%load_ext autoreload\n",
29 | "%autoreload 2"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "data_sets = ['folk_income', 'folk_travel', 'housing', 'credit', 'iris']"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "### kNN"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "param_grid = {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 10, 15, 20, 25, 30, 50, 80]}\n",
55 | " \n",
56 | "for data_set in data_sets:\n",
57 | " X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(data_set)\n",
58 | " is_classification = datasets.is_classification(data_set)\n",
59 | " print(data_set, is_classification, feature_names)\n",
60 | " \n",
61 | " if is_classification:\n",
62 | " clf = GridSearchCV(KNeighborsClassifier(), param_grid) \n",
63 | " clf.fit(X_train, Y_train)\n",
64 | " else: \n",
65 | " clf = GridSearchCV(KNeighborsRegressor(), param_grid)\n",
66 | " clf.fit(X_train, Y_train)\n",
67 | " \n",
68 | " print(clf.best_params_)\n",
69 | " if is_classification:\n",
70 | " print( sklearn.metrics.accuracy_score(Y_test, clf.predict(X_test)) )\n",
71 | " else:\n",
72 | " print( sklearn.metrics.mean_squared_error(Y_test, clf.predict(X_test)) )"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "knn_k = {'folk_income': 30, \n",
82 | " 'folk_travel': 80, \n",
83 | " 'housing': 10, \n",
84 | " 'diabetes': 15, \n",
85 | " 'credit': 25,\n",
86 | " 'iris': 1} "
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "knn_k"
96 | ]
97 | }
98 | ],
99 | "metadata": {
100 | "kernelspec": {
101 | "display_name": "Python 3",
102 | "language": "python",
103 | "name": "python3"
104 | },
105 | "language_info": {
106 | "codemirror_mode": {
107 | "name": "ipython",
108 | "version": 3
109 | },
110 | "file_extension": ".py",
111 | "mimetype": "text/x-python",
112 | "name": "python",
113 | "nbconvert_exporter": "python",
114 | "pygments_lexer": "ipython3",
115 | "version": "3.8.3"
116 | }
117 | },
118 | "nbformat": 4,
119 | "nbformat_minor": 5
120 | }
121 |
--------------------------------------------------------------------------------
/notebooks/replicate-paper/paperutil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import datasets
4 |
5 | import xgboost
6 | from sklearn.ensemble import RandomForestClassifier,RandomForestRegressor
7 | from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
8 | from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
9 |
10 | def checkerboard_function(k, num_checkers=8):
11 | """ The k-dimensional checkerboard function, delivered within the unit cube.
12 | This is pure interaction. Enjoy!
13 |
14 | k: dimension of the checkerboard (k >= 2)
15 | num_checkers: number of checkers along an axis (num_checkers >= 2)
16 | """
17 | def f_checkerboard(X):
18 | if X.ndim == 1:
19 | return np.sum([int(num_checkers * X[i]) for i in range(k)]) % 2
20 |         # X.ndim == 2
21 | result = np.zeros(X.shape[0])
22 | for i_point, x in enumerate(X):
23 | result[i_point] = np.sum([int(num_checkers * x[i]) for i in range(k)]) % 2
24 | return result
25 | return f_checkerboard
26 |
27 | def train_classifier(dataset, classifier):
28 | """ Train the different classifiers that we use in the paper.
29 | """
30 | X_train, X_test, Y_train, Y_test, feature_names = datasets.load_dataset(dataset)
31 | is_classification = datasets.is_classification(dataset)
32 | if classifier == 'gam':
33 | if is_classification:
34 | clf = ExplainableBoostingClassifier(feature_names=feature_names, interactions=0, random_state=0)
35 | clf.fit(X_train, Y_train)
36 | else:
37 | clf = ExplainableBoostingRegressor(feature_names=feature_names, interactions=0, random_state=0)
38 | clf.fit(X_train, Y_train)
39 | elif classifier == 'rf':
40 | if is_classification:
41 | clf = RandomForestClassifier(n_estimators=100, random_state=0)
42 | clf.fit(X_train, Y_train)
43 | else:
44 | clf = RandomForestRegressor(n_estimators=100, random_state=0)
45 | clf.fit(X_train, Y_train)
46 | elif classifier == 'gbtree':
47 | if is_classification:
48 | clf = xgboost.XGBClassifier(n_estimators=100, use_label_encoder=False, random_state=0)
49 | clf.fit(X_train, Y_train)
50 | else:
51 | clf = xgboost.XGBRegressor(n_estimators=100, use_label_encoder=False, random_state=0)
52 | clf.fit(X_train, Y_train)
53 | elif classifier == 'knn':
54 | # determined with cross-validation
55 | knn_k = {'folk_income': 30,
56 | 'folk_travel': 80,
57 | 'housing': 10,
58 | 'diabetes': 15,
59 | 'credit': 25,
60 | 'iris': 1}
61 | if is_classification:
62 | clf = KNeighborsClassifier(n_neighbors = knn_k[dataset])
63 | clf.fit(X_train, Y_train)
64 | else:
65 | clf = KNeighborsRegressor(n_neighbors = knn_k[dataset])
66 | clf.fit(X_train, Y_train)
67 | return clf
68 |
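69 | 
70 | if __name__ == "__main__":
71 |     # A minimal sanity check for the checkerboard function (illustrative only,
72 |     # not part of the paper pipeline). For each point, f returns the parity
73 |     # (0 or 1) of the sum of the checker indices along the k coordinates.
74 |     rng = np.random.default_rng(0)
75 |     f = checkerboard_function(k=2, num_checkers=8)
76 |     X = rng.uniform(size=(5, 2))
77 |     print(f(X))  # an array with one 0/1 entry per point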
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # pyproject.toml
2 |
3 | [build-system]
4 | requires = ["setuptools", "wheel"]
5 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="nshap",
8 | version="0.2.0",
9 | author="Sebastian Bordt",
10 | author_email="sbordt@posteo.de",
11 | description="Python package to compute n-Shapley Values.",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/tml-tuebingen/nshap",
15 | classifiers=[
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ],
20 | package_dir={"": "src"},
21 | packages=["nshap"],
22 | python_requires=">=3.6",
23 | install_requires=["numpy", "matplotlib", "seaborn",],
24 | )
25 |
--------------------------------------------------------------------------------
/src/nshap/InteractionIndex.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 |
5 | import nshap
6 |
7 | #############################################################################################################################
8 | # Unique string identifiers for all supported interaction indices
9 | #############################################################################################################################
10 |
11 | SHAPLEY_VALUES = "Shapley Values"
12 | MOEBIUS_TRANSFORM = "Moebius Transform"
13 | SHAPLEY_INTERACTION = "Shapley Interaction"
14 | N_SHAPLEY_VALUES = "n-Shapley Values"
15 | SHAPLEY_TAYLOR = "Shapley Taylor"
16 | FAITH_SHAP = "Faith-Shap"
17 | BANZHAF = "Banzhaf"
18 | BANZHAF_INTERACTION = "Banzhaf Interaction"
19 | FAITH_BANZHAF = "Faith-Banzhaf"
20 |
21 | ALL_INDICES = [
22 | SHAPLEY_VALUES,
23 | MOEBIUS_TRANSFORM,
24 | SHAPLEY_INTERACTION,
25 | N_SHAPLEY_VALUES,
26 | SHAPLEY_TAYLOR,
27 | FAITH_SHAP,
28 | BANZHAF,
29 | BANZHAF_INTERACTION,
30 | FAITH_BANZHAF
31 | ]
32 |
33 | #############################################################################################################################
34 | # A single class for all interaction indices
35 | #############################################################################################################################
36 |
37 |
38 | class InteractionIndex(collections.UserDict):
39 | """Class for different interaction indices (n-Shapley Values, Shapley Taylor, Faith-Shap).
40 | Interaction indices are a Python dict with added functionality.
41 | """
42 |
43 | def __init__(self, index_type: str, values, n=None, d=None):
44 | """Initialize an interaction index from a dict of values.
45 |
46 | Args:
47 |             index_type (str): The type of the interaction index, one of the string identifiers defined above.
48 |             values (dict): The underlying dict of values.
49 |             n (int, optional): The order of the interaction index. If None, inferred from the values dict. Defaults to None.
50 |             d (int, optional): The input dimension. If None, inferred from the values dict. Defaults to None.
51 | """
52 | super().__init__(values)
53 | assert index_type in ALL_INDICES, f"{index_type} is not a supported interaction index"
54 | self.index_type = index_type
55 | self.n = n
56 | if (
57 | n is None
58 |         ):  # if n or d are not given as arguments, infer them from the values dict.
59 | self.n = max([len(x) for x in values.keys()])
60 | self.d = d
61 | if d is None:
62 | self.d = len([x[0] for x in values.keys() if len(x) == 1])
63 |
64 | def get_index_type(self):
65 | """Return the type of the interaction index (for example, "Shapley Taylor")
66 |
67 | Returns:
68 | str: See function description.
69 | """
70 | return self.index_type
71 |
72 | def get_input_dimension(self):
73 | """Return the input dimension of the function for which we computed the interaction index ('d').
74 |
75 | Returns:
76 | integer: See function description.
77 | """
78 | return self.d
79 |
80 | def get_order(self):
81 | """Return the order of the interaction index ('n').
82 |
83 | Returns:
84 | integer: See function description.
85 | """
86 | return self.n
87 |
88 | def sum(self):
89 | """Sum all the terms that are invovled in the interaction index.
90 |
91 | For many interaction indices, the result is equal to the value of the function that we attempt to explain, minus the value of the empty coalition.
92 |
93 | Returns:
94 | Float: The sum.
95 | """
96 | return np.sum([x for x in self.data.values()])
97 |
98 | def copy(self):
99 | """Return a copy of the current object.
100 |
101 | Returns:
102 | nshap.InteractionIndex: The copy.
103 | """
104 | return InteractionIndex(self.index_type, self.data.copy())
105 |
106 | def save(self, fname):
107 | """Save the interaction index to a JSON file.
108 |
109 | Args:
110 | fname (str): Filename.
111 | """
112 | nshap.save(self, fname)
113 |
114 | def plot(self, *args, **kwargs):
115 | """Generate a plots of the n-Shapley Values.
116 |
117 | This function simply calls nshap.plots.plot_n_shapley.
118 |
119 | Returns:
120 | The axis of the matplotlib plot.
121 | """
122 | return nshap.plot_interaction_index(self, *args, **kwargs)
123 |
124 | def shapley_values(self):
125 | """Compute the original Shapley Values.
126 |
127 | Returns:
128 |             numpy.ndarray: Shapley Values. If you prefer an InteractionIndex object, call k_shapley_values(1) instead.
129 | """
130 | assert (
131 | self.index_type == N_SHAPLEY_VALUES or self.index_type == MOEBIUS_TRANSFORM
132 | ), f"shapley_values only supports {N_SHAPLEY_VALUES} and {MOEBIUS_TRANSFORM}"
133 | shapley_values = self.k_shapley_values(1)
134 | shapley_values = np.array(list(shapley_values.values())).reshape((1, -1))
135 | return shapley_values
136 |
137 | def k_shapley_values(self, k):
138 | """Compute k-Shapley Values of lower order. Requires k <= n.
139 |
140 | Args:
141 | k (int): The desired order.
142 |
143 | Returns:
144 | nShapleyValues: k-Shapley Values.
145 | """
146 | assert (
147 | self.index_type == N_SHAPLEY_VALUES or self.index_type == MOEBIUS_TRANSFORM
148 | ), f"k_shapley_values only supports {N_SHAPLEY_VALUES} and {MOEBIUS_TRANSFORM}"
149 |         assert k <= self.n, "k_shapley_values requires k <= n"
150 |         result = self.copy()
151 |         while result.n != k:
152 |             result = result.k_shapley_values_n_minus_one()
153 |         return result
154 | 
155 |     def k_shapley_values_n_minus_one(self):
156 |         """Compute the k-Shapley Values of order n-1.
157 | 
158 |         This is a helper function that performs a single step of the recursion
159 |         in k_shapley_values, reducing the order of the index by one.
160 | 
161 |         Returns:
162 |             InteractionIndex: The k-Shapley Values of order n-1.
163 |         """
164 |         result = {}
165 |         # consider all subsets S with 1 <= |S| <= n-1
166 |         for S in nshap.powerset(range(self.d)):
167 |             # subsets of size 0 or of size n and larger are dropped
168 |             if (len(S) == 0) or (len(S) > self.n - 1):
169 | continue
170 | # we have the n-normalized effect
171 | S_effect = self.data.get(S, 0)
172 | # go over all subsets T of length n that contain S
173 | for T in nshap.powerset(range(self.d)):
174 | if (len(T) != self.n) or (not set(S).issubset(T)):
175 | continue
176 | # add the effect of T to S
177 | T_effect = self.data.get(
178 | T, 0
179 | ) # default to zero in case the dict is sparse
180 | # normalization
181 | S_effect = (
182 | S_effect - (nshap.bernoulli_numbers[len(T) - len(S)]) * T_effect
183 | )
184 | # now we have the normalized effect
185 | result[S] = S_effect
186 | return InteractionIndex(N_SHAPLEY_VALUES, result)
187 |
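188 | 
189 | if __name__ == "__main__":
190 |     # A minimal sketch (the numbers are made up for illustration): construct a
191 |     # Moebius Transform for d=2 by hand and reduce it to plain Shapley Values.
192 |     # The pairwise interaction (0, 1) is split evenly between both features.
193 |     values = InteractionIndex(MOEBIUS_TRANSFORM, {(0,): 1.0, (1,): 2.0, (0, 1): 0.5})
194 |     print(values.get_order(), values.get_input_dimension())  # 2 2
195 |     print(values.sum())  # 3.5
196 |     print(values.shapley_values())  # [[1.25 2.25]]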
--------------------------------------------------------------------------------
/src/nshap/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.0"
2 | __author__ = "Sebastian Bordt"
3 |
4 | from nshap.InteractionIndex import InteractionIndex
5 | from nshap.InteractionIndex import (
6 | SHAPLEY_VALUES,
7 | MOEBIUS_TRANSFORM,
8 | SHAPLEY_INTERACTION,
9 | N_SHAPLEY_VALUES,
10 | SHAPLEY_TAYLOR,
11 | FAITH_SHAP,
12 | BANZHAF,
13 | BANZHAF_INTERACTION,
14 | FAITH_BANZHAF,
15 | )
16 |
17 | from nshap.functions import (
18 | shapley_interaction_index,
19 | n_shapley_values,
20 | moebius_transform,
21 | shapley_values,
22 | faith_shap,
23 | shapley_taylor,
24 | faith_banzhaf,
25 | banzhaf_interaction_index,
26 | bernoulli_numbers,
27 | )
28 | from nshap.plot import plot_interaction_index
29 | from nshap.util import allclose, save, load, powerset
30 |
31 | from nshap.vfunc import memoized_vfunc
32 |
33 |
--------------------------------------------------------------------------------
/src/nshap/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 |
4 | import nshap
5 |
6 |
7 | #############################################################################################################################
8 | # Bernoulli numbers
9 | #############################################################################################################################
10 |
11 | bernoulli_numbers = np.array(
12 | [
13 | 1,
14 | -1 / 2,
15 | 1 / 6,
16 | 0,
17 | -1 / 30,
18 | 0,
19 | 1 / 42,
20 | 0,
21 | -1 / 30,
22 | 0,
23 | 5 / 66,
24 | 0,
25 | -691 / 2730,
26 | 0,
27 | 7 / 6,
28 | 0,
29 | -3617 / 510,
30 | 0,
31 | 43867 / 798,
32 | 0,
33 | ]
34 | )
35 |
36 |
37 | #############################################################################################################################
38 | # Input Validation (Used by all that follows)
39 | #############################################################################################################################
40 |
41 |
42 | def validate_inputs(x, v_func):
43 | if x.ndim == 1:
44 | x = x.reshape((1, -1))
45 | assert (
46 | x.shape[0] == 1
47 | ), "The nshap package only accepts single data points as input."
48 | dim = x.shape[1]
49 |     if not isinstance(v_func, nshap.memoized_vfunc):  # memoization
50 | v_func = nshap.memoized_vfunc(v_func)
51 | # for d>20, we would have to consider the numerics of the problem more carefully
52 | assert dim <= 20, "The nshap package only supports d<=20."
53 | return x, v_func, dim
54 |
55 |
56 | #############################################################################################################################
57 | # The Moebius Transform
58 | #############################################################################################################################
59 |
60 |
61 | def moebius_transform(x, v_func):
62 | """Compute the Moebius Transform of of the value function v_func at the data point x.
63 |
64 | Args:
65 | x (numpy.ndarray): A data point.
66 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
67 |
68 | Returns:
69 | nshap.InteractionIndex: The interaction index.
70 | """
71 | # validate input parameters
72 | x, v_func, dim = validate_inputs(x, v_func)
73 | # go over all subsets S of N with 1<=|S|<=d
74 | result = {}
75 | for S in nshap.powerset(set(range(dim))):
76 | if len(S) == 0:
77 | continue
78 | summands = []
79 | # go over all subsets T of S
80 | for T in nshap.powerset(S):
81 | summands.append(v_func(x, list(T)) * (-1) ** (len(S) - len(T)))
82 | result[S] = np.sum(summands)
83 | # return result
84 | return nshap.InteractionIndex(nshap.MOEBIUS_TRANSFORM, result)
85 |
86 |
87 | #############################################################################################################################
88 | # Shapley Values
89 | #############################################################################################################################
90 |
91 |
92 | def shapley_values(x, v_func):
93 | """Compute the original Shapley Values, according to the Shapley Formula.
94 |
95 | Args:
96 | x (numpy.ndarray): A data point.
97 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
98 |
99 | Returns:
100 | nshap.InteractionIndex: The interaction index.
101 | """
102 | # validate input parameters
103 | x, v_func, dim = validate_inputs(x, v_func)
104 | # go over all features
105 | result = {}
106 | for i_feature in range(dim):
107 | phi = 0
108 | S = set(range(dim))
109 | S.remove(i_feature)
110 | for subset in nshap.powerset(S):
111 | v = v_func(x, list(subset))
112 | subset_i = list(subset)
113 | subset_i.append(i_feature)
114 | subset_i.sort()
115 | v_i = v_func(x, subset_i)
116 |             phi = phi + math.factorial(len(subset)) * math.factorial(
117 |                 dim - len(subset) - 1
118 |             ) / math.factorial(dim) * (v_i - v)
119 | result[(i_feature,)] = phi
120 | # return result
121 | return nshap.InteractionIndex(nshap.SHAPLEY_VALUES, result)
122 |
123 |
124 | #############################################################################################################################
125 | # Shapley Interaction Index
126 | #############################################################################################################################
127 |
128 |
129 | def shapley_interaction_index(x, v_func, n=-1):
130 | """Compute the Shapley Interaction Index (https://link.springer.com/article/10.1007/s001820050125) at the data point x, and all S such that |S|<=n, given a coalition value function.
131 |
132 | Args:
133 | x (numpy.ndarray): A data point
134 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
135 | n (int): Order up to which the Shapley Interaction Index should be computed.
136 |
137 | Returns:
138 | nshap.InteractionIndex: The interaction index.
139 | """
140 | # validate input parameters
141 | x, v_func, dim = validate_inputs(x, v_func)
142 | if n == -1:
143 | n = dim
144 | # go over all subsets S of N with |S|<=n
145 | result = {}
146 | for S in nshap.powerset(set(range(dim))):
147 | if len(S) > n:
148 | continue
149 | # go over all subsets T of N\S
150 | phi = 0
151 | N_minus_S = set(range(dim)) - set(S)
152 | for T in nshap.powerset(N_minus_S):
153 | # go over all subsets L of S
154 | delta = 0
155 | for L in nshap.powerset(S):
156 | coalition = list(L)
157 | coalition.extend(list(T))
158 | coalition.sort()
159 | delta = delta + np.power(-1, len(S) - len(L)) * v_func(x, coalition)
160 |             phi = phi + delta * math.factorial(len(T)) * math.factorial(
161 |                 dim - len(T) - len(S)
162 |             ) / math.factorial(dim - len(S) + 1)
163 | result[S] = phi
164 | # return result
165 | return nshap.InteractionIndex(nshap.SHAPLEY_INTERACTION, result)
166 |
167 |
168 | #############################################################################################################################
169 | # n-Shapley Values
170 | #############################################################################################################################
171 |
172 |
173 | def n_shapley_values(x, v_func, n=-1):
174 | """This function provides an exact computation of n-Shapley Values (https://arxiv.org/abs/2209.04012) via their definition.
175 |
176 | Args:
177 | x (numpy.ndarray): A data point.
178 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
179 | n (int, optional): Order of n-Shapley Values or -1 for n=d. Defaults to -1.
180 |
181 | Returns:
182 | nshap.InteractionIndex: The interaction index.
183 | """
184 | # validate input parameters
185 | x, v_func, dim = validate_inputs(x, v_func)
186 | if n == -1:
187 | n = dim
188 | # first compute the shapley interaction index
189 | shapley_int_idx = shapley_interaction_index(x, v_func, n)
190 |     # the result dict for the data point x
191 | result = {}
192 | # consider all subsets S with 1<=|S|<=n
193 | for S in nshap.powerset(range(dim)):
194 | if (len(S) == 0) or (len(S) > n):
195 | continue
196 | # obtain the unnormalized effect (that is, delta_S(x))
197 | S_effect = shapley_int_idx[S]
198 | # go over all subsets T of length k+1, ..., n that contain S
199 | for T in nshap.powerset(range(dim)):
200 | if (len(T) <= len(S)) or (len(T) > n) or (not set(S).issubset(T)):
201 | continue
202 |             # get the effect of T, and subtract it from the effect of S
203 | T_effect = shapley_int_idx[T]
204 | # normalization with bernoulli_numbers
205 | S_effect = S_effect + (bernoulli_numbers[len(T) - len(S)]) * T_effect
206 | # now we have the normalized effect
207 | result[S] = S_effect
208 | # return result
209 | return nshap.InteractionIndex(nshap.N_SHAPLEY_VALUES, result)
210 |
211 |
212 | #############################################################################################################################
213 | # Shapley Taylor Interaction Index
214 | #############################################################################################################################
215 |
216 |
217 | def shapley_taylor(x, v_func, n=-1):
218 | """ Compute the Shapley Taylor Interaction Index (https://arxiv.org/abs/1902.05622) of the value function v_func at the data point x.
219 |
220 | Args:
221 | x (numpy.ndarray): A data point.
222 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
223 | n (int, optional): Order of the Shapley Taylor Interaction Index or -1 for n=d. Defaults to -1.
224 |
225 | Returns:
226 | nshap.InteractionIndex: The interaction index.
227 | """
228 | # validate input parameters
229 | x, v_func, dim = validate_inputs(x, v_func)
230 | if n == -1:
231 | n = dim
232 | # we first compute the moebius transform
233 | moebius = moebius_transform(x, v_func)
234 |     # then compute the Shapley Taylor Interaction Index from it
235 | result = {}
236 | # consider all subsets S with 1<=|S|<=n
237 | for S in nshap.powerset(range(dim)):
238 | if (len(S) == 0) or (len(S) > n):
239 | continue
240 | result[S] = moebius[S]
241 | # for |S|=n, average the higher-order effects
242 | if len(S) == n:
243 |             # go over all subsets T of [d] with |T| > n that contain S
244 | for T in nshap.powerset(range(dim)):
245 | if (len(T) <= len(S)) or (not set(S).issubset(T)):
246 | continue
247 | result[S] += moebius[T] / math.comb(len(T), len(S))
248 | # return result
249 | return nshap.InteractionIndex(nshap.SHAPLEY_TAYLOR, result)
250 |
251 |
252 | #############################################################################################################################
253 | # Faith-Shap Interaction Index
254 | #############################################################################################################################
255 |
256 |
257 | def faith_shap(x, v_func, n=-1):
258 | """ Compute the Faith-Shap Interaction Index (https://arxiv.org/abs/2203.00870) of the value function v_func at the data point x.
259 |
260 | Args:
261 | x (numpy.ndarray): A data point.
262 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
263 | n (int, optional): Order of the Interaction Index or -1 for n=d. Defaults to -1.
264 |
265 | Returns:
266 | nshap.InteractionIndex: The interaction index.
267 | """
268 | # validate input parameters
269 | x, v_func, dim = validate_inputs(x, v_func)
270 | if n == -1:
271 | n = dim
272 | # we first compute the moebius transform
273 | moebius = moebius_transform(x, v_func)
274 | # then compute the Faith-Shap Interaction Index
275 | result = {}
276 | # consider all subsets S with 1<=|S|<=n
277 | for S in nshap.powerset(range(dim)):
278 | if (len(S) == 0) or (len(S) > n):
279 | continue
280 | result[S] = moebius[S]
281 |         # go over all subsets T of [d] with |T| > n that contain S
282 | for T in nshap.powerset(range(dim)):
283 | if (len(T) <= n) or (not set(S).issubset(T)):
284 | continue
285 | # compare Theorem 19 in the Faith-Shap paper. In our notation, l=n.
286 | result[S] += (
287 | (-1) ** (n - len(S))
288 | * len(S)
289 | / (n + len(S))
290 | * math.comb(n, len(S))
291 | * math.comb(len(T) - 1, n)
292 | / math.comb(len(T) + n - 1, n + len(S))
293 | * moebius[T]
294 | )
295 | # return result
296 | return nshap.InteractionIndex(nshap.FAITH_SHAP, result)
297 |
298 |
299 | #############################################################################################################################
300 | # Banzhaf Values
301 | #############################################################################################################################
302 |
303 |
304 | #############################################################################################################################
305 | # Banzhaf Interaction Index
306 | #############################################################################################################################
307 |
308 |
309 | def banzhaf_interaction_index(x, v_func, n=-1):
310 | """Compute the Banzhaf Interaction Index (https://link.springer.com/article/10.1007/s001820050125) at the data point x, and all S such that |S|<=n, given a coalition value function.
311 |
312 | Args:
313 | x (numpy.ndarray): A data point
314 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
315 |         n (int): Order up to which the Banzhaf Interaction Index should be computed.
316 |
317 | Returns:
318 | nshap.InteractionIndex: The interaction index.
319 | """
320 | # validate input parameters
321 | x, v_func, dim = validate_inputs(x, v_func)
322 | if n == -1:
323 | n = dim
324 | # go over all subsets S of N with |S|<=n
325 | result = {}
326 | for S in nshap.powerset(set(range(dim))):
327 | if len(S) > n:
328 | continue
329 | # go over all subsets T of N\S
330 | phi = 0
331 | N_minus_S = set(range(dim)) - set(S)
332 | for T in nshap.powerset(N_minus_S):
333 | # go over all subsets L of S
334 | delta = 0
335 | for L in nshap.powerset(S):
336 | coalition = list(L)
337 | coalition.extend(list(T))
338 | coalition.sort()
339 | delta = delta + np.power(-1, len(S) - len(L)) * v_func(x, coalition)
340 | phi = phi + delta * (1 / np.power(2, dim - len(S)))
341 | result[S] = phi
342 | # return result
343 | return nshap.InteractionIndex(nshap.BANZHAF_INTERACTION, result)
344 |
345 |
346 | #############################################################################################################################
347 | # Faith-Banzhaf Interaction Index
348 | #############################################################################################################################
349 |
350 |
351 | def faith_banzhaf(x, v_func, n=-1):
352 | """ Compute the Faith-Banzhaf Interaction Index (https://arxiv.org/abs/2203.00870) of the value function v_func at the data point x.
353 |
354 | Args:
355 | x (numpy.ndarray): A data point.
356 | v_func (function): The value function. It takes two arguments: The datapoint x and a list with the indices of the coalition.
357 | n (int, optional): Order of the Interaction Index or -1 for n=d. Defaults to -1.
358 |
359 | Returns:
360 | nshap.InteractionIndex: The interaction index.
361 | """
362 | # validate input parameters
363 | x, v_func, dim = validate_inputs(x, v_func)
364 | if n == -1:
365 | n = dim
366 | # we first compute the moebius transform
367 | moebius = moebius_transform(x, v_func)
368 | # then compute the Faith-Banzhaf Interaction Index
369 | result = {}
370 | # consider all subsets S with 1<=|S|<=n
371 | for S in nshap.powerset(range(dim)):
372 | if (len(S) == 0) or (len(S) > n):
373 | continue
374 | result[S] = moebius[S]
375 |         # go over all subsets T of [d] with |T| > n that contain S
376 | for T in nshap.powerset(range(dim)):
377 | if (len(T) <= n) or (not set(S).issubset(T)):
378 | continue
379 | # compare Theorem 17 in the Faith-Shap paper. In our notation, l=n.
380 | result[S] += (
381 | (-1) ** (n - len(S))
382 | * np.power(0.5, len(T) - len(S))
383 | * math.comb(len(T) - len(S) - 1, n - len(S))
384 | * moebius[T]
385 | )
386 | # return result
387 | return nshap.InteractionIndex(nshap.FAITH_BANZHAF, result)
388 |
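389 | 
390 | if __name__ == "__main__":
391 |     # A minimal sketch with a hand-written value function (made up for
392 |     # illustration): v(x, S) sums the coordinates of x in S and adds a bonus of
393 |     # 1 whenever features 0 and 1 are both present. The Moebius Transform
394 |     # isolates this bonus as the pairwise interaction effect.
395 |     x = np.array([1.0, 2.0, 3.0])
396 | 
397 |     def v_func(x_point, subset):
398 |         x_point = x_point.flatten()
399 |         bonus = 1.0 if (0 in subset and 1 in subset) else 0.0
400 |         return float(np.sum(x_point[list(subset)])) + bonus
401 | 
402 |     print(moebius_transform(x, v_func)[(0, 1)])  # 1.0, the interaction effect
403 |     print(shapley_values(x, v_func)[(0,)])  # 1.5 = x_0 plus half of the bonus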
--------------------------------------------------------------------------------
/src/nshap/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import matplotlib.patches as mpatches
6 |
7 | import seaborn as sns
8 |
9 | import nshap
10 |
11 | # avoid type-3 fonts
12 | matplotlib.rcParams["pdf.fonttype"] = 42
13 | matplotlib.rcParams["ps.fonttype"] = 42
14 |
15 | #############################################################################################################################
16 | # Plots
17 | #############################################################################################################################
18 |
19 | plot_colors = [
20 | "#1f77b4",
21 | "#ff7f0e",
22 | "#2ca02c",
23 | "#d62728",
24 | "#9467bd",
25 | "#8c564b",
26 | "#e377c2",
27 | "#7f7f7f",
28 | "#bcbd22",
29 | "#17becf",
30 | "black",
31 | "#1f77b4",
32 | "#ff7f0e",
33 | "#2ca02c",
34 | "#d62728",
35 | "#9467bd",
36 | "#8c564b",
37 | "#e377c2",
38 | "#7f7f7f",
39 | "#bcbd22",
40 | "#17becf",
41 | ]
42 |
43 |
44 | def plot_interaction_index(
45 | I: nshap.InteractionIndex,
46 | max_degree=None,
47 | axis=None,
48 | feature_names=None,
49 | rotation=70,
50 | legend=True,
51 | fig_kwargs={"figsize": (6, 6)},
52 | barwidth=0.5,
53 | ):
54 | """Generate the plots in the paper.
55 |
56 | Args:
57 | n_shapley_values (nshap.nShapleyValues): The n-Shapley Values that we want to plot.
58 | max_degree (int, optional): Plots all effect of order larger than max_degree with a single color. Defaults to None.
59 | axis (optional): Matplotlib axis on which to plot. Defaults to None.
60 | feature_names (_type_, optional): Used to label the x-axis. Defaults to None.
61 | rotation (int, optional): Rotation for x-axis labels. Defaults to 70.
62 | legend (bool, optional): Plot legend. Defaults to True.
63 | fig_kwargs (dict, optional): fig_kwargs, handed down to matplotlib figure. Defaults to {"figsize": (6, 6)}.
64 | barwidth (float, optional): Widht of the bars. Defaults to 0.5.
65 |
66 | Returns:
67 | Matplotlib axis: The plot axis.
68 | """
69 |     if max_degree == 1:
70 |         I = I.k_shapley_values(1)  # keep an InteractionIndex; shapley_values() would return a plain array
71 | num_features = I.d
72 | vmax, vmin = 0, 0
73 | ax = axis
74 | if axis is None:
75 | _, ax = plt.subplots(**fig_kwargs)
76 | if max_degree is None or max_degree >= I.n:
77 | max_degree = I.n
78 | ax.axhline(y=0, color="black", linestyle="-") # line at 0
79 | for i_feature in range(num_features):
80 | bmin, bmax = 0, 0
81 | v = I[(i_feature,)]
82 | ax.bar(
83 | x=i_feature,
84 | height=v,
85 | width=barwidth,
86 | bottom=0,
87 | align="center",
88 | label=f"Feature {i_feature}",
89 | color=plot_colors[0],
90 | )
91 | bmin = min(bmin, v)
92 | bmax = max(bmax, v)
93 | # higher-order effects, up to max_degree
94 | for n_k in range(2, I.n + 1):
95 | v_pos = np.sum(
96 | [
97 | I[k] / len(k)
98 | for k in I.data.keys()
99 | if (len(k) == n_k and i_feature in k and I[k] > 0)
100 | ]
101 | )
102 | v_neg = np.sum(
103 | [
104 | I[k] / len(k)
105 | for k in I.data.keys()
106 | if (len(k) == n_k and i_feature in k and I[k] < 0)
107 | ]
108 | )
109 | # 'max_degree or higher'
110 | if n_k == max_degree:
111 | v_pos = np.sum(
112 | [
113 | I[k] / len(k)
114 | for k in I.data.keys()
115 | if (len(k) >= n_k and i_feature in k and I[k] > 0)
116 | ]
117 | )
118 | v_neg = np.sum(
119 | [
120 | I[k] / len(k)
121 | for k in I.data.keys()
122 | if (len(k) >= n_k and i_feature in k and I[k] < 0)
123 | ]
124 | )
125 | if v_pos > 0:
126 | ax.bar(
127 | x=i_feature,
128 | height=v_pos,
129 | width=barwidth,
130 | bottom=bmax,
131 | align="center",
132 | color=plot_colors[n_k - 1],
133 | )
134 | bmax = bmax + v_pos
135 | if v_neg < 0:
136 | ax.bar(
137 | x=i_feature,
138 | height=v_neg,
139 | width=barwidth,
140 | bottom=bmin,
141 | align="center",
142 | color=plot_colors[n_k - 1],
143 | )
144 | bmin = bmin + v_neg
145 | if n_k == max_degree: # no higher orders
146 | break
147 | vmin = min(vmin, bmin)
148 | vmax = max(vmax, bmax)
149 | # axes
150 | if feature_names is None:
151 | feature_names = [f"Feature {i+1}" for i in range(num_features)]
152 | ax.set_ylim([1.1 * vmin, 1.1 * vmax])
153 | ax.set_xticks(np.arange(num_features))
154 | ax.set_xticklabels(feature_names, rotation=rotation)
155 | # legend with custom labels
156 | color_patches = [mpatches.Patch(color=color) for color in plot_colors]
157 | lables = ["Main"]
158 | if max_degree > 1:
159 | lables.append("2nd order")
160 | if max_degree > 2:
161 | lables.append(f"3rd order")
162 | if max_degree > 3:
163 | for i_degree in range(4, I.n + 1):
164 | if i_degree == max_degree and max_degree < I.n:
165 | lables.append(f"{i_degree}-{I.n}th order")
166 | break
167 | else:
168 | if i_degree == I.n:
169 | lables.append(f"{i_degree}th order")
170 | else:
171 | lables.append(f"{i_degree}th")
172 | ax.legend(
173 | color_patches,
174 | lables,
175 | bbox_to_anchor=(1.02, 1),
176 | loc="upper left",
177 | borderaxespad=0,
178 | handletextpad=0.5,
179 | handlelength=1,
180 | handleheight=1,
181 | )
182 | ax.get_legend().set_visible(legend)
183 | ax.set_title(I.index_type)
184 | if axis is None:
185 | plt.show()
186 | return ax
187 |
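188 | 
189 | if __name__ == "__main__":
190 |     # A minimal usage sketch (the values and feature names are made up for
191 |     # illustration): main effects are drawn in the first color, the pairwise
192 |     # interaction is split between both features and stacked on top.
193 |     values = nshap.InteractionIndex(
194 |         nshap.N_SHAPLEY_VALUES, {(0,): 0.6, (1,): -0.3, (0, 1): 0.4}
195 |     )
196 |     plot_interaction_index(values, feature_names=["age", "income"])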
--------------------------------------------------------------------------------
/src/nshap/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from itertools import chain, combinations
4 |
5 | import json
6 | import re
7 |
8 | import nshap
9 |
10 |
11 | def allclose(dict1, dict2, rtol=1e-05, atol=1e-08):
12 | """Compare if two dicts of n-Shapley Values are close according to numpy.allclose.
13 |
14 | Useful for testing purposes.
15 |
16 | Args:
17 | dict1 (dict): The first dict.
18 | dict2 (dict): The second dict.
19 | rtol (float, optional): passed to numpy.allclose. Defaults to 1e-05.
20 | atol (float, optional): passed to numpy.allclose. Defaults to 1e-08.
21 |
22 | Returns:
23 | bool: Result of numpy.allclose.
24 | """
25 | if dict1.keys() != dict2.keys(): # both dictionaries need to have the same keys
26 | return False
27 | for key in dict1.keys():
28 | if not np.allclose(dict1[key], dict2[key], rtol=rtol, atol=atol):
29 | return False
30 | return True
31 |
32 |
33 | def save(values, fname):
34 | """Save an interaction index to a JSON file.
35 |
36 | Args:
37 | values (nshap.InteractionIndex): The interaction index.
38 | fname (str): Filename.
39 | """
40 | json_dict = {
41 | str(k): v for k, v in values.data.items()
42 | } # convert the integer tuples to strings
43 | json_dict["index_type"] = values.index_type
44 | json_dict["n"] = values.n
45 | json_dict["d"] = values.d
46 |
47 | with open(fname, "w+") as fp:
48 | fp.write(json.dumps(json_dict, indent=2))
49 |
50 |
51 | def to_int_tuple(str_tuple):
52 | """Convert string representations of integer tuples back to python integer tuples.
53 |
54 | This utility function is used to load n-Shapley Values from JSON.
55 |
56 | Args:
57 |         str_tuple (str): String representation of an integer tuple, for example "(1,2,3)"
58 |
59 | Returns:
60 | tuple: Tuple of integers, for example (1,2,3)
61 | """
62 | start = str_tuple.find("(") + 1
63 | first_comma = str_tuple.find(",")
64 | end = str_tuple.rfind(")")
65 | # is the string of the form "(1,)", i.e. with a trailing comma?
66 | if len(re.findall("[0-9]", str_tuple[first_comma:])) == 0:
67 | end = first_comma
68 | str_tuple = str_tuple[start:end]
69 | return tuple(map(int, str_tuple.split(",")))
70 |
71 |
72 | def read_remove(d, key):
73 | """Read and remove a key from a dict d.
74 | """
75 | r = d[key]
76 | del d[key]
77 | return r
78 |
79 |
80 | def load(fname):
81 | """Load an interaction index from a JSON file.
82 |
83 | Args:
84 | fname (str): Filename.
85 |
86 | Returns:
87 | nshap.InteractionIndex: The loaded interaction index.
88 | """
89 | with open(fname, "r") as fp:
90 | str_dump = json.load(fp)
91 |
92 | index_type = read_remove(str_dump, "index_type")
93 | n = read_remove(str_dump, "n")
94 | d = read_remove(str_dump, "d")
95 |
96 | # convert the remaining string tuples to integer tuples
97 | python_dict = {to_int_tuple(k): v for k, v in str_dump.items()}
98 | return nshap.InteractionIndex(index_type, python_dict, n, d)
99 |
100 |
101 | def powerset(iterable):
102 | "The powerset function from https://docs.python.org/3/library/itertools.html#itertools-recipes."
103 | s = list(iterable)
104 | return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
105 |
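106 | 
107 | if __name__ == "__main__":
108 |     # A minimal save/load round trip (the file name is an arbitrary choice for
109 |     # illustration). Keys are stored as strings like "(0, 1)" in the JSON file
110 |     # and converted back to integer tuples by to_int_tuple on loading.
111 |     values = nshap.InteractionIndex(
112 |         nshap.MOEBIUS_TRANSFORM, {(0,): 0.5, (1,): 1.5, (0, 1): -0.25}
113 |     )
114 |     save(values, "example_index.json")
115 |     print(allclose(values, load("example_index.json")))  # True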
--------------------------------------------------------------------------------
/src/nshap/vfunc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import functools
4 |
5 | import numbers
6 |
7 | #############################################################################################################################
8 | # Interventional SHAP
9 | #############################################################################################################################
10 |
11 |
12 | def interventional_shap(
13 | f, X, target=None, num_samples=1000, random_state=None, meomized=True
14 | ):
15 | """Approximate the value function of interventional SHAP.
16 |
17 | In order to compute n-Shapley Values, a data set is sampled once and then fixed for all evaluations of the value function.
18 |
19 | Args:
20 | f (function): The function to be explained. Will be called as f(x) where x has shape (1,d).
21 |         X (numpy.ndarray): Sample from the data distribution.
22 |         target (int, optional): Target class. Required if the output of f(x) is multi-dimensional. Defaults to None.
23 |         num_samples (int, optional): The number of samples drawn from X in order to estimate the value function. Defaults to 1000.
24 |         random_state (int, optional): Random state that is passed to np.random.default_rng. Used for reproducibility. Defaults to None.
25 |         meomized (bool, optional): Whether the returned value function should be memoized. Defaults to True.
26 | 
27 |     Returns:
28 |         function: The value function.
29 | """
30 | # sample background data set
31 | rng = np.random.default_rng(random_state)
32 | if num_samples < X.shape[0]:
33 | indices = rng.integers(low=0, high=X.shape[0], size=num_samples)
34 | X = X[indices, :]
35 | # the function
36 | def fn(x, subset):
37 | subset = list(subset)
38 | x = x.flatten()
39 | values = []
40 | for idx in range(X.shape[0]):
41 | x_sample = X[idx, :].copy() # copy here is important, otherwise we modify X
42 | x_sample[subset] = x[subset]
43 | v = f(x_sample.reshape((1, -1)))
44 | # handle different return types
45 | if isinstance(v, numbers.Number):
46 | values.append(v)
47 | elif v.ndim == 1:
48 | if target is None:
49 | values.append(v[0])
50 | else:
51 | values.append(v[target])
52 | else:
53 | assert (
54 | target is not None
55 | ), "f returns multi-dimensional array, but target is not specified"
56 | values.append(v[0, target])
57 | return np.mean(values)
58 |
59 |     if meomized:  # memoization
60 | fn = memoized_vfunc(fn)
61 | return fn
62 |
63 |
64 | #############################################################################################################################
65 | # Memoization for value functions
66 | #############################################################################################################################
67 |
68 |
69 | class memoized_vfunc(object):
70 | """Decorator to meomize a vfunc. Handles hashability of the numpy.ndarray and the list.
71 |
72 | The hash depends only on the values of x that are in the subset of coordinates.
73 |
74 | This function is able to handle parameters x of shape (d,) and (1,d).
75 |
76 |     Memoization helps to avoid duplicate evaluations of the value function, even if we perform
77 | computations that involve the same term multiple times.
78 |
79 | https://wiki.python.org/moin/PythonDecoratorLibrary
80 | """
81 |
82 | def __init__(self, func):
83 | """
84 |
85 | Args:
86 | func (function): The value function to be meomized.
87 | """
88 | self.func = func
89 | self.cache = {}
90 |
91 | def __call__(self, *args):
92 | x = args[0]
93 | if x.ndim == 2:
94 | assert (
95 | x.shape[0] == 1
96 | ), "Parameter x of value function has to be of shape (d,) or (1,d)."
97 | x = x[0]
98 | subset = tuple(args[1])
99 | x = tuple((x[i] for i in subset))
100 | hashable_args = (x, subset)
101 | if hashable_args in self.cache: # in cache
102 | return self.cache[hashable_args]
103 | else:
104 | value = self.func(*args)
105 | self.cache[hashable_args] = value
106 | return value
107 |
108 | def __repr__(self):
109 | """Return the function's docstring."""
110 | return f"Memoized value function: {self.func.__doc__}"
111 |
112 | def __get__(self, obj, objtype):
113 | """Support instance methods."""
114 | return functools.partial(self.__call__, obj)
115 |
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | import xgboost
2 |
3 | from folktables import ACSDataSource, ACSIncome
4 |
5 | from sklearn.preprocessing import StandardScaler
6 | from sklearn.model_selection import train_test_split
7 |
8 | import os
9 |
10 | data_root_dir = "../data/"
11 |
12 | paths = [data_root_dir]
13 | for p in paths:
14 | if not os.path.exists(p):
15 | os.mkdir(p)
16 |
17 |
18 | def folktables_income():
19 | # (down-)load the dataset
20 | data_source = ACSDataSource(
21 | survey_year="2016", horizon="1-Year", survey="person", root_dir=data_root_dir
22 | )
23 | data = data_source.get_data(states=["CA"], download=True)
24 |     X, Y, _ = ACSIncome.df_to_numpy(data)
25 | 
26 |     # the feature names, as defined in folktables
27 |     feature_names = ACSIncome.features
28 | 
29 |
30 | # zero mean and unit variance for all features
31 | X = StandardScaler().fit_transform(X)
32 |
33 | # train-test split
34 | X_train, X_test, Y_train, Y_test = train_test_split(
35 | X, Y, train_size=0.8, random_state=0
36 | )
37 |
38 | return X_train, X_test, Y_train, Y_test, feature_names
39 |
--------------------------------------------------------------------------------
/tests/tests.py:
--------------------------------------------------------------------------------
1 | import xgboost
2 |
3 | from test_util import folktables_income
4 |
5 | import nshap
6 |
7 |
8 | def test_n_shapley():
9 | """Compare different formulas for computing n-shapley values.
10 |
11 | (1) via delta_S
12 | (2) via the component functions of the shapley gam
13 |
14 | Then compare the shapley values resulting from (1) and (2) to the shapley values computed with the shapley formula.
15 | """
16 | X_train, X_test, Y_train, Y_test, _ = folktables_income()
17 | X_train = X_train[:, 0:5]
18 | X_test = X_test[:, 0:5]
19 | gbtree = xgboost.XGBClassifier()
20 | gbtree.fit(X_train, Y_train)
21 |
22 | vfunc = nshap.vfunc.interventional_shap(gbtree.predict_proba, X_train, target=0)
23 |
24 | n_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc)
25 | moebius = nshap.moebius_transform(X_test[0, :], vfunc)
26 | shapley_values = nshap.shapley_values(X_test[0, :], vfunc)
27 |
28 | assert nshap.allclose(n_shapley_values, moebius)
29 | for k in range(1, X_train.shape[1]):
30 | k_shapley_values = nshap.n_shapley_values(X_test[0, :], vfunc, k)
31 | assert nshap.allclose(n_shapley_values.k_shapley_values(k), k_shapley_values)
32 | assert nshap.allclose(
33 | n_shapley_values.k_shapley_values(k), moebius.k_shapley_values(k)
34 | )
35 | assert nshap.allclose(n_shapley_values.k_shapley_values(1), shapley_values)
36 | assert nshap.allclose(moebius.k_shapley_values(1), shapley_values)
37 |
38 |
--------------------------------------------------------------------------------