├── .github
│   ├── CODEOWNERS
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── Example.html
├── Example.ipynb
├── LICENSE
├── README.md
├── assets
│   ├── jpmorgan-logo.svg
│   └── xai_coe-logo.png
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── src
│   └── cfshap
│       ├── __init__.py
│       ├── attribution
│       │   ├── __init__.py
│       │   ├── composite.py
│       │   └── shap.py
│       ├── background
│       │   ├── __init__.py
│       │   ├── counterfactual_adapter.py
│       │   └── opposite.py
│       ├── base.py
│       ├── counterfactuals
│       │   ├── __init__.py
│       │   ├── composition.py
│       │   ├── knn.py
│       │   └── project.py
│       ├── evaluation
│       │   ├── __init__.py
│       │   ├── attribution
│       │   │   ├── __init__.py
│       │   │   ├── _basic.py
│       │   │   └── induced_counterfactual
│       │   │       ├── __init__.py
│       │   │       ├── _actionability.py
│       │   │       └── _utils.py
│       │   └── counterfactuals
│       │       ├── __init__.py
│       │       └── plausibility.py
│       ├── trend.py
│       └── utils
│           ├── __init__.py
│           ├── _python.py
│           ├── _utils.py
│           ├── parallel
│           │   ├── __init__.py
│           │   ├── batch.py
│           │   └── utils.py
│           ├── preprocessing
│           │   ├── __init__.py
│           │   └── scaling
│           │       ├── __init__.py
│           │       ├── madscaler.py
│           │       ├── multiscaler.py
│           │       └── quantiletransformer.py
│           ├── project.py
│           ├── random
│           │   ├── __init__.py
│           │   ├── _sample.py
│           │   └── _stability.py
│           └── tree.py
└── tests
    └── test_usage.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Each line is a file pattern followed by one or more owners.
2 | # The order matters: a later match takes precedence.
3 | # The default code owners are the maintainer team.
4 | # They will be requested for review when someone opens a pull request.
5 | * @jpmorganchase/XAI-CoE_maintainer
6 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, then lint and run tests across several Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 |
18 | runs-on: ${{ matrix.os }}
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [ubuntu-latest]
23 | python-version: [3.6, 3.7, 3.8, 3.9]
24 |
25 | steps:
26 | - uses: actions/checkout@v3
27 | - name: Set up Python ${{ matrix.python-version }}
28 | uses: actions/setup-python@v3
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip setuptools wheel
34 | pip install flake8 pytest
35 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
36 | - name: Install package (with extras test)
37 | run: |
38 | python -m pip install .[test]
39 | - name: Lint (flake8)
40 | run: |
41 | # Stop the build if there are Python syntax errors or undefined names
42 | flake8 . --select=E9,F63,F7,F82 --show-source
43 | # exit-zero treats all errors as warnings.
44 | flake8 . --exit-zero --show-source
45 | - name: Test (pytest)
46 | run: |
47 | pytest
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | **/__pycache__
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | **/.pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | **/.ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Windows
132 | **/Thumbs.db
133 |
134 | # Other cache
135 | **/.pyc
136 |
137 | **/.vscode
138 |
139 |
140 | links.sh
141 | links_init.py
142 | clean.sh
143 |
144 | a.yml
145 |
146 | src/emutils/data/lendingclub
147 |
148 | flake.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | [License](https://github.com/jpmorganchase/cf-shap/blob/master/LICENSE)
9 | [Emanuele Albini](https://www.emanuelealbini.com)
10 |
11 |
12 | # Counterfactual SHAP (`cf-shap`)
13 | A modular framework for the generation of counterfactual feature attribution explanations (a.k.a. counterfactual feature importance).
14 | This Python package implements the algorithms proposed in the following paper.
15 | If you use this package, please cite our work.
16 |
17 | **Counterfactual Shapley Additive Explanations**
18 | Emanuele Albini, Jason Long, Danial Dervovic and Daniele Magazzeni
19 | J.P. Morgan AI Research
20 | [ACM](https://dl.acm.org/doi/abs/10.1145/3531146.3533168) | [ArXiv](https://arxiv.org/abs/2110.14270)
21 |
22 | ```
23 | @inproceedings{Albini2022,
24 | title = {Counterfactual {{Shapley Additive Explanations}}},
25 | booktitle = {2022 {{ACM Conference}} on {{Fairness}}, {{Accountability}}, and {{Transparency}}},
26 | author = {Albini, Emanuele and Long, Jason and Dervovic, Danial and Magazzeni, Daniele},
27 | year = {2022},
28 | series = {{{FAccT}} '22},
29 | pages = {1054--1070},
30 | doi = {10.1145/3531146.3533168}
31 | }
32 | ```
33 |
34 | **Note that this repository contains the package with the algorithms for the generation of the explanations proposed in the paper and their evaluation, _but not_ the experiments themselves.** If you are interested in reproducing the results of the paper, please refer to the [cf-shap-facct22](https://github.com/jpmorganchase/cf-shap-facct22) repository (which uses the algorithms implemented in this repository).
35 |
36 | ## 1. Installation
37 | To install the package manually, simply use the following commands. Note that this package depends on the `shap>=0.39.0` package: you may want to install it and the other dependencies manually, using conda or pip (see the snippet at the end of this section).
38 |
39 | ```bash
40 | # Clone the repo into the `cf-shap` directory
41 | git clone https://github.com/jpmorganchase/cf-shap.git
42 |
43 | # Install the package in editable mode
44 | pip install -e cf-shap
45 | ```
46 | The package has been tested with Python 3.6 and 3.8; any Python version >= 3.6 should work.
47 | See `setup.py` and `requirements.txt` for more details on the dependencies of the package.
48 |
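If you prefer to pin the exact dependency versions used during development, you can (optionally) install them from `requirements.txt` before installing the package, for example:

```bash
# Optional: install the pinned dependency versions listed in requirements.txt
pip install -r cf-shap/requirements.txt
```
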
49 | ## 2. Usage Example
50 | Check out `Example.ipynb` or `Example.html` for a basic usage example of the package.
51 |
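For orientation, below is a minimal, self-contained sketch of how the main pieces compose. This is an illustrative snippet, not the notebook's code: the toy data and the XGBoost model are assumptions, while the imported classes are those defined in this package.

```python
import numpy as np
from xgboost import XGBClassifier

from cfshap.attribution import CompositeExplainer, TreeExplainer
from cfshap.background import DifferentPredictionBackgroundGenerator

# Toy data and model, purely for illustration
X = np.random.RandomState(0).rand(500, 5)
y = (X.sum(axis=1) > 2.5).astype(int)
model = XGBClassifier().fit(X, y)

# A background generator (points with a different prediction than the query)
# composed with (interventional) TreeSHAP, which receives the background at runtime.
explainer = CompositeExplainer(
    DifferentPredictionBackgroundGenerator(model, X),
    TreeExplainer(model),
)

explanation = explainer(X[:10])
print(explanation.values.shape)  # (nb_query_instances, nb_features)
```
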
52 | ## 3. Contacts and Issues
53 |
54 | For further information or queries on this work you can contact the _Explainable AI Center of Excellence at J.P. Morgan_ ([xai.coe@jpmchase.com](mailto:xai.coe@jpmchase.com)) or [Emanuele Albini](https://www.emanuelealbini.com), the main author of the paper.
55 |
56 | If you have issues using the package, feel free to open an issue on GitHub, or contact the authors using the contacts above. We will try to address any issue as soon as possible.
57 |
58 | ## 4. Disclaimer
59 |
60 | This repository was prepared for informational purposes by the Artificial Intelligence Research group of JPMorgan Chase & Co. and its affiliates ("JP Morgan"), and is not a product of the Research Department of JP Morgan. JP Morgan makes no representation and warranty whatsoever and disclaims all liability, for the completeness, accuracy or reliability of the information contained herein. This document is not intended as investment research or investment advice, or a recommendation, offer or solicitation for the purchase or sale of any security, financial instrument, financial product or service, or to be used in any way for evaluating the merits of participating in any transaction, and shall not constitute a solicitation under any jurisdiction or to any person, if such solicitation under such jurisdiction or to such person would be unlawful.
61 |
62 | The code is provided for illustrative purposes only and is not intended to be used for trading or investment purposes. The code is provided "as is" and without warranty of any kind, express or implied, including without limitation, any warranty of merchantability or fitness for a particular purpose. In no event shall JP Morgan be liable for any direct, indirect, incidental, special or consequential damages, including, without limitation, lost revenues, lost profits, or loss of prospective economic advantage, resulting from the use of the code.
63 |
--------------------------------------------------------------------------------
/assets/jpmorgan-logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
48 |
--------------------------------------------------------------------------------
/assets/xai_coe-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpmorganchase/cf-shap/ba58c61d4e110a950808d18dbe85b67af862549b/assets/xai_coe-logo.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.2
2 | scipy==1.5.2
3 | pandas==1.1.5
4 | scikit-learn==0.24.2
5 | xgboost==1.3.3
6 | joblib==1.0.1
7 | tqdm==4.61
8 | numba==0.51.2
9 | shap==0.39.0
10 | cached_property
11 | dataclasses
12 | typing_extensions
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = cfshap
3 | version = 0.0.2
4 | author = Emanuele Albini
5 |
6 | description = Counterfactual SHAP: a framework for counterfactual feature importance
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 |
10 | url = https://github.com/jpmorganchase/cf-shap
11 | project_urls =
12 | Source Code = https://github.com/jpmorganchase/cf-shap
13 | Bug Tracker = https://github.com/jpmorganchase/cf-shap/issues
14 | Author Website = https://www.emanuelealbini.com
15 |
16 | classifiers =
17 | Programming Language :: Python :: 3
18 | Programming Language :: Python :: 3.6
19 | Operating System :: OS Independent
20 |
21 | keywords =
22 | XAI
23 | Explainability
24 | Explainable AI
25 | Counterfactuals
26 | Algorithmic Recourse
27 | Contrastive Explanations
28 | Feature Importance
29 | Feature Attribution
30 | Machine Learning
31 | Shapley values
32 | SHAP
33 | FAccT22
34 |
35 | platform = any
36 |
37 | [options]
38 | include_package_data = True
39 | package_dir =
40 | = src
41 | packages = find:
42 |
43 | python_requires = >=3.6
44 | install_requires =
45 | numpy
46 | scipy
47 | pandas
48 | scikit-learn
49 | joblib
50 | tqdm
51 | numba>=0.51.2
52 | shap==0.39.0
53 | cached_property
54 | dataclasses
55 | typing_extensions
56 |
57 | [options.extras_require]
58 | test = pytest; xgboost; scikit-learn
59 |
60 | [options.packages.find]
61 | where = src
62 |
63 |
64 | [flake8]
65 | max-line-length = 200
66 | exclude = .git,__pycache__,docs,old,build,dist,venv
67 | # Blank lines, trailing spaces, etc.
68 | extend-ignore = W29,W39
69 | per-file-ignores =
70 | # imported but unused / unable to detect undefined names
71 | __init__.py: F401,F403
72 | # imported but unused / module level import not at top of file
73 | src/emutils/imports.py: F401, E402
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | import pkg_resources
3 |
4 | # setup.cfg
5 | pkg_resources.require('setuptools>=39.2')
6 |
7 | setuptools.setup()
--------------------------------------------------------------------------------
/src/cfshap/__init__.py:
--------------------------------------------------------------------------------
1 | import pkg_resources
2 |
3 | try:
4 | __version__ = pkg_resources.get_distribution(__name__).version
5 | except pkg_resources.DistributionNotFound:
6 | __version__ = "0.0.0"
--------------------------------------------------------------------------------
/src/cfshap/attribution/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Feature Attribution / Importance techniques.
5 | """
6 |
7 | from .shap import *
8 | from .composite import *
9 |
--------------------------------------------------------------------------------
/src/cfshap/attribution/composite.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | CF-SHAP Explainer: it uses a counterfactual (CF) generator together with an explainer that supports at-runtime background distributions.
5 | """
6 |
7 | from typing import Union
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from ..base import (
12 | BaseExplainer,
13 | ExplainerSupportsDynamicBackground,
14 | BackgroundGenerator,
15 | CounterfactualMethod,
16 | ListOf2DArrays,
17 | )
18 |
19 | from ..background import CounterfactualMethodBackgroundGeneratorAdapter
20 | from ..utils import attrdict
21 |
22 | __all__ = [
23 | 'CompositeExplainer',
24 | 'CFExplainer',
25 | ]
26 |
27 |
28 | class CompositeExplainer(BaseExplainer):
29 | """CompositeExplainer allows to compose:
30 | - a background generator (or a counterfactual generation method), and
31 | - an explainer (i.e., feature importance method) that suppports a background distributions
32 | into a new explainer that uses the dataset obtained with the background generator
33 | as background for the feature importance method.
34 |
35 | """
36 | def __init__(
37 | self,
38 | background_generator: Union[CounterfactualMethod, BackgroundGenerator],
39 | explainer: ExplainerSupportsDynamicBackground,
40 | n_top: Union[None, int] = None,
41 | verbose: bool = True,
42 | ):
43 | """
44 |
45 | Args:
46 | background_generator (Union[CounterfactualMethod, BackgroundGenerator]): A background generator.
47 | explainer (ExplainerSupportsDynamicBackground): An explainer supporting a dynamic background dataset.
48 | n_top (Union[None, int], optional): Number of top counterfactuals (only if background_generator is a counterfactual method). Defaults to None.
49 | verbose (bool, optional): If True, will be verbose. Defaults to True.
50 |
51 | """
52 |
53 | super().__init__(background_generator.model)
54 | if isinstance(background_generator, CounterfactualMethod):
55 | background_generator = CounterfactualMethodBackgroundGeneratorAdapter(background_generator, n_top=n_top)
56 | else:
57 | if n_top is not None:
58 | raise NotImplementedError('n_top is supported only for counterfactual methods.')
59 |
60 | self.background_generator = background_generator
61 | self.explainer = explainer
62 | self.verbose = verbose
63 |
64 | def get_backgrounds(
65 | self,
66 | X: np.ndarray,
67 | background_data: Union[None, ListOf2DArrays] = None,
68 | ) -> ListOf2DArrays:
69 | """Generate the background datasets for the query instances.
70 |
71 | Args:
72 | X (np.ndarray): The query instances
73 | background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
74 | By default it will recompute the background for each query instance using the background generator passed in the constructor.
75 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
76 |
77 | Returns:
78 | ListOf2DArrays : A list/array of background datasets
79 | nb_query_instances x nb_background_samples x nb_features
80 | """
81 | X = self.preprocess(X)
82 |
83 | # If we do not have any background then we compute it
84 | if background_data is None:
85 | return self.background_generator.get_backgrounds(X)
86 |
87 | # If we have some background we use it
88 | else:
89 | assert len(background_data) == X.shape[0]
90 | return background_data
91 |
92 | def _get_backgrounds_iterator(self, X, background_data):
93 | iters = zip(X, background_data)
94 | if len(X) > 100 and self.verbose:
95 | iters = tqdm(iters)
96 |
97 | return iters
98 |
99 | def get_attributions(
100 | self,
101 | X: np.array,
102 | background_data: Union[None, ListOf2DArrays] = None,
103 | **kwargs,
104 | ) -> np.ndarray:
105 | """Generate the feature attributions for the query instances.
106 |
107 | Args:
108 | X (np.ndarray): The query instances
109 | background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
110 | By default it will recompute the background for each query instance using the background generator passed in the constructor.
111 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
112 |
113 | Returns:
114 | np.ndarray : An array of feature attributions
115 | nb_query_instances x nb_features
116 | """
117 | X = self.preprocess(X)
118 | backgrounds = self.get_backgrounds(X, background_data)
119 |
120 | shapvals = []
121 | for x, background in self._get_backgrounds_iterator(X, backgrounds):
122 | if len(background) > 0:
123 | # Set background
124 | self.explainer.data = background
125 | # Compute Shapley values
126 | shapvals.append(self.explainer.get_attributions(x.reshape(1, -1), **kwargs)[0])
127 | else:
128 | shapvals.append(np.full(X.shape[1], np.nan))
129 | return np.array(shapvals)
130 |
131 | def get_trends(
132 | self,
133 | X: np.array,
134 | background_data: Union[None, ListOf2DArrays] = None,
135 | ):
136 | """Generate the feature trends for the query instances.
137 |
138 | Args:
139 | X (np.ndarray): The query instances
140 | background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
141 | By default it will recompute the background for each query instance using the background generator passed in the constructor.
142 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
143 |
144 | Returns:
145 | np.ndarray : An array of feature trends (+1 / -1)
146 | nb_query_instances x nb_features
147 | """
148 | X = self.preprocess(X)
149 | backgrounds = self.get_backgrounds(X, background_data)
150 |
151 | trends = []
152 | for x, background in self._get_backgrounds_iterator(X, backgrounds):
153 | if len(background) > 0:
154 | self.explainer.data = background
155 | trends.append(self.explainer.get_trends(x.reshape(1, -1))[0])
156 | else:
157 | trends.append(np.full(X.shape[1], np.nan))
158 | return np.array(trends)
159 |
160 | def __call__(
161 | self,
162 | X: np.array,
163 | background_data: Union[None, ListOf2DArrays] = None,
164 | ):
165 | """Generate the explanations for the query instances.
166 |
167 | Args:
168 | X (np.ndarray): The query instances
169 | background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
170 | By default it will recompute the background for each query instance using the background generator passed in the constructor.
171 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
172 |
173 | Returns:
174 | attrdict : The explanations.
175 |
176 | See BaseExplainer.__call__ for more details on the output format.
177 | """
178 | X = self.preprocess(X)
179 | backgrounds = self.get_backgrounds(X, background_data)
180 |
181 | shapvals = []
182 |
183 | # Check if trends is implemented
184 | try:
185 | self.get_trends(X[:1])
186 | trends = []
187 | except NotImplementedError:
188 | trends = None
189 |
190 | for x, background in self._get_backgrounds_iterator(X, backgrounds):
191 | if len(background) > 0:
192 | self.explainer.data = background
193 | shapvals.append(self.explainer.get_attributions(x.reshape(1, -1))[0])
194 | if trends is not None:
195 | trends.append(self.explainer.get_trends(x.reshape(1, -1))[0])
196 | else:
197 | shapvals.append(np.full(X.shape[1], np.nan))
198 | if trends is not None:
199 | trends.append(np.full(X.shape[1], np.nan))
200 | return attrdict(
201 | backgrounds=backgrounds,
202 | values=np.array(shapvals),
203 | trends=np.array(trends) if trends is not None else None,
204 | )
205 |
206 |
207 | # Alias
208 | CFExplainer = CompositeExplainer
--------------------------------------------------------------------------------
/src/cfshap/attribution/shap.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | SHAP Wrapper
5 | """
6 |
7 | from abc import ABC
8 | from typing import Union
9 | import numpy as np
10 | import pandas as pd
11 |
12 | from shap.maskers import Independent as SHAPIndependent
13 | from shap import TreeExplainer as SHAPTreeExplainer
14 | from shap.explainers import Exact as SHAPExactExplainer
15 |
16 | from ..base import (BaseExplainer, BaseSupportsDynamicBackground, Model, TrendEstimatorProtocol)
17 | from ..utils import get_shap_compatible_model
18 |
19 | __all__ = [
20 | 'TreeExplainer',
21 | 'ExactExplainer',
22 | ]
23 |
24 |
25 | class SHAPExplainer(BaseExplainer, BaseSupportsDynamicBackground, ABC):
26 | @property
27 | def explainer(self):
28 | if self._data is None:
29 | self._raise_data_error()
30 | return self._explainer
31 |
32 | def get_backgrounds(self, X):
33 | return np.broadcast_to(self.data, [X.shape[0], *self.data.shape])
34 |
35 | def get_trends(self, X):
36 | if self.trend_estimator is None:
37 | return np.full(X.shape, np.nan)
38 | else:
39 | return np.array([self.trend_estimator(x, self.data) for x in X])
40 |
41 |
42 | class TreeExplainer(SHAPExplainer):
43 | """Monkey-patched and improved wrapper for (Interventional) TreeSHAP
44 | """
45 | def __init__(
46 | self,
47 | model: Model,
48 | data: Union[None, np.ndarray] = None,
49 | max_samples: int = 1e10,
50 | class_index: int = 1,
51 | trend_estimator: Union[None, TrendEstimatorProtocol] = None,
52 | **kwargs,
53 | ):
54 | """Create the TreeSHAP explainer.
55 |
56 | Args:
57 | model (Model): The model
58 | data (Union[None, np.ndarray], optional): The background dataset. Defaults to None.
59 | If None it must be set dynamically using self.data = ...
60 | max_samples (int, optional): Maximum number of (random) samples to draw from the background dataset. Defaults to 1e10.
61 | class_index (int, optional): Class for which SHAP values will be computed. Defaults to 1 (positive class in a binary classification setting).
62 | trend_estimator (Union[None, TrendEstimatorProtocol], optional): The feature trend estimator. Defaults to None.
63 | **kwargs: Other arguments to be passed to shap.TreeExplainer.__init__
64 |
65 | """
66 | super().__init__(model, None)
67 |
68 | self.max_samples = max_samples
69 | self.class_index = class_index
70 | self.trend_estimator = trend_estimator
71 | self.kwargs = kwargs
72 |
73 | # Allow only for interventional SHAP
74 | if ('feature_perturbation' in kwargs) and (kwargs['feature_perturbation'] != 'interventional'):
75 | raise NotImplementedError()
76 |
77 | self.data = data
78 |
79 | @BaseSupportsDynamicBackground.data.setter
80 | def data(self, data):
81 | if data is not None:
82 | if not hasattr(self, '_explainer'):
83 | self._explainer = SHAPTreeExplainer(
84 | get_shap_compatible_model(self.model),
85 | data=SHAPIndependent(self.preprocess(data), max_samples=self.max_samples),
86 | **self.kwargs,
87 | )
88 |
89 | # Copied and adapted from SHAP 0.39.0 __init__
90 | # -- Start of copy --
91 | self._explainer.data = self._explainer.model.data = self._data = SHAPIndependent(
92 | self.preprocess(data), max_samples=self.max_samples).data
93 | self._explainer.data_missing = self._explainer.model.data_missing = pd.isna(self._explainer.data)
94 | try:
95 | self._explainer.expected_value = self._explainer.model.predict(self._explainer.data).mean(0)
96 | except ValueError:
97 | raise Exception(
98 | "Currently TreeExplainer can only handle models with categorical splits when "
99 | "feature_perturbation=\"tree_path_dependent\" and no background data is passed. Please try again using "
100 | "shap.TreeExplainer(model, feature_perturbation=\"tree_path_dependent\").")
101 | if hasattr(self._explainer.expected_value, '__len__') and len(self._explainer.expected_value) == 1:
102 | self._explainer.expected_value = self._explainer.expected_value[0]
103 |
104 | # if our output format requires binary classification to be represented as two outputs then we do that here
105 | if self._explainer.model.model_output == "probability_doubled" and self._explainer.expected_value is not None:
106 | self._explainer.expected_value = [1 - self._explainer.expected_value, self._explainer.expected_value]
107 | # -- End of copy --
108 | else:
109 | self._data = None
110 |
111 | def get_attributions(self, *args, **kwargs) -> np.ndarray:
112 | """Compute SHAP values
113 | Proxy for shap.TreeExplainer.shap_values(*args, **kwargs)
114 |
115 | Returns:
116 | np.ndarray: nb_samples x nb_features array with SHAP values of class self.class_index
117 | """
118 | r = self.explainer.shap_values(*args, **kwargs)
119 |
120 | # Select Shapley values of a specific class
121 | if isinstance(r, list):
122 | return r[self.class_index]
123 | else:
124 | return r
125 |
126 |
127 | class ExactExplainer(SHAPExplainer):
128 | def __init__(self, model, data=None, max_samples=1e10, function: str = 'predict', **kwargs):
129 | """Monkey-patched and improved constructor for TreeSHAP
130 |
131 | Args:
132 | model (any): A model
133 | data (np.ndarray, optional): Reference data. Defaults to None.
134 | max_samples (int, optional): Maximum number of samples in the reference data. Defaults to 1e10.
135 | function (str): the name of the function to call on the model.
136 | **kwargs: Other arguments to be passed to shap.ExactExplainer.__init__
137 | """
138 |
139 | super().__init__(model, None)
140 |
141 | self.max_samples = max_samples
142 | self.function = function
143 | self.kwargs = kwargs
144 |
145 | if ('link' in kwargs) and (kwargs['link'] != 'identity'):
146 | raise NotImplementedError()
147 |
148 | self.data = data
149 |
150 | @BaseSupportsDynamicBackground.data.setter
151 | def data(self, data):
152 | if data is not None:
153 | self._data = self.preprocess(data)
154 | self._explainer = SHAPExactExplainer(
155 | lambda X: getattr(self.model, self.function)(X),
156 | SHAPIndependent(self._data, max_samples=self.max_samples),
157 | **self.kwargs,
158 | )
159 | else:
160 | self._data = None
161 |
162 | def get_attributions(self, *args, **kwargs):
163 | """Compute SHAP values
164 | Proxy for shap.ExactExplainer.shap_values(*args, **kwargs)
165 |
166 | Returns:
167 | np.ndarray: nb_samples x nb_features array with SHAP values
168 | """
169 | return self.explainer(*args, **kwargs).values
--------------------------------------------------------------------------------
/src/cfshap/background/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Background/Dataset Generation Techniques.
5 | """
6 |
7 | from .counterfactual_adapter import *
8 | from .opposite import *
--------------------------------------------------------------------------------
/src/cfshap/background/counterfactual_adapter.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Adapter from a counterfactual method to background generator.
5 | """
6 | from typing import Union
7 | import numpy as np
8 |
9 | from ..base import (BaseBackgroundGenerator, BackgroundGenerator, CounterfactualMethod, MultipleCounterfactualMethod,
10 | ListOf2DArrays)
11 |
12 | from ..utils import (
13 | get_top_counterfactuals,
14 | expand_dims_counterfactuals,
15 | )
16 |
17 | __all__ = ['CounterfactualMethodBackgroundGeneratorAdapter']
18 |
19 |
20 | class CounterfactualMethodBackgroundGeneratorAdapter(BaseBackgroundGenerator, BackgroundGenerator):
21 | """Adapter to make a counterfactual method into a background generator"""
22 | def __init__(
23 | self,
24 | counterfactual_method: CounterfactualMethod,
25 | n_top: Union[None, int] = None, # By default: All
26 | ):
27 | """
28 |
29 | Args:
30 | counterfactual_method (CounterfactualMethod): The counterfactual method
31 | n_top (Union[None, int], optional): Number of top-counterfactuals to select as background. Defaults to None (all).
32 | """
33 | self.counterfactual_method = counterfactual_method
34 | self.n_top = n_top
35 |
36 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays:
37 | """Generate the background datasets for each query instance
38 |
39 | Args:
40 | X (np.ndarray): The query instances
41 |
42 | Returns:
43 | ListOf2DArrays: A list/array of background datasets
44 | nb_query_instances x nb_background_points x nb_features
45 | """
46 |
47 | X = self.preprocess(X)
48 |
49 | # Compute the counterfactuals and use them as the background datasets
50 | if isinstance(self.counterfactual_method, MultipleCounterfactualMethod):
51 | if self.n_top is None:
52 | return self.counterfactual_method.get_multiple_counterfactuals(X)
53 | elif self.n_top == 1:
54 | return expand_dims_counterfactuals(self.counterfactual_method.get_counterfactuals(X))
55 | else:
56 | return get_top_counterfactuals(
57 | self.counterfactual_method.get_multiple_counterfactuals(X),
58 | X,
59 | n_top=self.n_top,
60 | nan=False,
61 | )
62 | elif isinstance(self.counterfactual_method, CounterfactualMethod):
63 | if self.n_top is not None and self.n_top != 1:
64 | raise ValueError('The counterfactual method does not support the generation of multiple counterfactuals.')
65 | return np.expand_dims(self.counterfactual_method.get_counterfactuals(X), axis=1)
66 | else:
67 | raise NotImplementedError('Unsupported CounterfactualMethod.')
--------------------------------------------------------------------------------
/src/cfshap/background/opposite.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Background data computed as samples with different label or prediction.
5 | """
6 |
7 | import numpy as np
8 |
9 | from ..base import (
10 | Model,
11 | BaseBackgroundGenerator,
12 | ListOf2DArrays,
13 | )
14 |
15 | __all__ = ['DifferentLabelBackgroundGenerator', 'DifferentPredictionBackgroundGenerator']
16 |
17 |
18 | class BaseDifferent_BackgroundGenerator(BaseBackgroundGenerator):
19 | def __init__(self, model, max_samples=None, random_state=None):
20 | super().__init__(model, None, random_state)
21 |
22 | self.max_samples = max_samples
23 |
24 | @property
25 | def data(self):
26 | return self._data
27 |
28 | def _save_data(self, data):
29 | data = data.copy()
30 | data = self.sample(data, self.max_samples)
31 | data.flags.writeable = False
32 | return data
33 |
34 | def _set_data(self, X, y):
35 | self._data = X
36 | self._data_per_class = {j: self._save_data(X[y != j]) for j in np.unique(y)}
37 |
38 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays:
39 | X = self.preprocess(X)
40 | y = self.model.predict(X)
41 | return [self._data_per_class[j] for j in y]
42 |
43 |
44 | class DifferentPredictionBackgroundGenerator(BaseDifferent_BackgroundGenerator):
45 | """The background dataset consists of the points with a PREDICTION different from that of the query instance.
46 | """
47 | def __init__(self, model: Model, data: np.ndarray, **kwargs):
48 | super().__init__(model, **kwargs)
49 |
50 | # Based on the prediction
51 | data = self.preprocess(data)
52 | self._set_data(data, self.model.predict(data))
53 |
54 |
55 | class DifferentLabelBackgroundGenerator(BaseDifferent_BackgroundGenerator):
56 | """The background dataset consists of the points with a LABEL different from that of the query instance.
57 | """
58 | def __init__(self, model: Model, X: np.ndarray, y: np.ndarray, **kwargs):
59 | super().__init__(model, **kwargs)
60 |
61 | # Based on the label
62 | X = self.preprocess(X)
63 | self._set_data(X, y)
64 |
--------------------------------------------------------------------------------
/src/cfshap/base.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | This module contains base classes and interfaces (protocol in Python jargon).
5 |
6 | Note: this module contains classes that are more general than needed for this package.
7 | This is to allow for future integration in a more general XAI package.
8 |
9 | Most of the interfaces, base classes and methods are self-explanatory.
10 |
11 | """
12 |
13 | from abc import ABC, abstractmethod
14 | import warnings
15 | from typing import Union, List
16 |
17 | try:
18 | # In Python >= 3.8 this functionality is included in the standard library
19 | from typing import Protocol
20 | from typing import runtime_checkable
21 | except (ImportError, ModuleNotFoundError):
22 | # Python < 3.8 - Backward Compatibility through package
23 | from typing_extensions import Protocol
24 | from typing_extensions import runtime_checkable
25 |
26 | import numpy as np
27 |
28 | from .utils import attrdict
29 | from .utils.random import np_sample
30 |
31 | __all__ = [
32 | 'Scaler',
33 | 'Model',
34 | 'ModelWithDecisionFunction',
35 | 'XGBWrapping',
36 | 'Explainer',
37 | 'ExplainerSupportsDynamicBackground',
38 | 'BaseExplainer',
39 | 'BaseSupportsDynamicBackground',
40 | 'BaseGroupExplainer',
41 | 'BackgroundGenerator',
42 | 'CounterfactualMethod',
43 | 'MultipleCounterfactualMethod',
44 | 'MultipleCounterfactualMethodSupportsWrapping',
45 | 'MultipleCounterfactualMethodWrappable',
46 | 'BaseCounterfactualMethod',
47 | 'BaseMultipleCounterfactualMethod',
48 | 'TrendEstimatorProtocol',
49 | 'ListOf2DArrays',
50 | 'CounterfactualEvaluationScorer',
51 | 'BaseCounterfactualEvaluationScorer',
52 | ]
53 |
54 | ListOf2DArrays = Union[List[np.ndarray], np.ndarray]
55 |
56 | # ------------------- MODELs, etc. -------------------------
57 |
58 |
59 | @runtime_checkable
60 | class Model(Protocol):
61 | """Protocol for a ML model"""
62 | def predict(self, X: np.ndarray) -> np.ndarray:
63 | pass
64 |
65 | def predict_proba(self, X: np.ndarray) -> np.ndarray:
66 | pass
67 |
68 |
69 | @runtime_checkable
70 | class ModelWithDecisionFunction(Model, Protocol):
71 | """Protocol for a Model with a decision function as well"""
72 | def decision_function(self, X: np.ndarray) -> np.ndarray:
73 | pass
74 |
75 |
76 | @runtime_checkable
77 | class XGBWrapping(Model, Protocol):
78 | """Protocol for an XGBoost model wrapper"""
79 | def get_booster(self):
80 | pass
81 |
82 |
83 | @runtime_checkable
84 | class Scaler(Protocol):
85 | """Protocol for a Scaler"""
86 | def transform(self, X: np.ndarray) -> np.ndarray:
87 | pass
88 |
89 | def inverse_transform(self, X: np.ndarray) -> np.ndarray:
90 | pass
91 |
92 |
93 | # ------------------- Explainers, etc. -------------------------
94 |
95 |
96 | class BaseClass(ABC):
97 | """Base class for all explainability methods"""
98 | def __init__(self, model: Model, scaler: Union[Scaler, None] = None, random_state: int = 2021):
99 |
100 | self._model = model
101 | self._scaler = scaler
102 | self.random_state = random_state
103 |
104 | # model and scaler cannot be changed at runtime. Set as properties.
105 | @property
106 | def model(self):
107 | return self._model
108 |
109 | @property
110 | def scaler(self):
111 | return self._scaler
112 |
113 | def preprocess(self, X: np.ndarray):
114 | if not isinstance(X, np.ndarray):
115 | raise ValueError('Must pass a NumPy array.')
116 |
117 | if len(X.shape) != 2:
118 | raise ValueError("The input data must be a 2D matrix.")
119 |
120 | if X.shape[0] == 0:
121 | raise ValueError(
122 | "An empty array was passed! You must pass a non-empty array of samples in order to generate explanations."
123 | )
124 |
125 | return X
126 |
127 | def sample(self, X: np.ndarray, n: int):
128 | if n is not None:
129 | X = np_sample(X, n, random_state=self.random_state, safe=True)
130 |
131 | return X
132 |
133 | def scale(self, X: np.ndarray):
134 | if self.scaler is None:
135 | return X
136 | else:
137 | return self.scaler.transform(X)
138 |
139 |
140 | @runtime_checkable
141 | class Explainer(Protocol):
142 | """Protocol for an Explainer (a feature attribution/importance method).
143 |
144 | Attributes:
145 | model (Model): The model for which the feature importance is computed
146 | scaler (Scaler, optional): The scaler for the data. Default to None (i.e., no scaling).
147 |
148 | Methods:
149 | get_attributions(X): Returns the feature attributions.
150 |
151 | Optional Methods:
152 | get_trends(X): Returns the feature trends.
153 | get_backgrounds(X): Returns the background datasets.
154 |
155 | To build a new explainer one can easily extend BaseExplainer.
156 | """
157 |
158 | model: Model
159 | scaler: Union[Scaler, None]
160 |
161 | def get_attributions(self, X):
162 | pass
163 |
164 | # Optional
165 | # def get_trends(self, X):
166 | # pass
167 |
168 | # def get_backgrounds(self, X):
169 | # pass
170 |
171 |
172 | @runtime_checkable
173 | class SupportsDynamicBackground(Protocol):
174 | """Additional Protocol for a class that supports at-runtime change of the background data."""
175 | @property
176 | def data(self):
177 | pass
178 |
179 | @data.setter
180 | def data(self, data):
181 | pass
182 |
183 |
184 | @runtime_checkable
185 | class ExplainerSupportsDynamicBackground(Explainer, SupportsDynamicBackground, Protocol):
186 | """Protocol for an Explainer that supports at-runtime change of the background data"""
187 | pass
188 |
189 |
190 | class BaseExplainer(BaseClass, ABC):
191 | """Base class for a feature attribution/importance method"""
192 | @abstractmethod
193 | def get_attributions(self, X: np.ndarray) -> np.ndarray:
194 | """Generate the feature attributions for query instances X"""
195 | pass
196 |
197 | def get_trends(self, X: np.ndarray) -> np.ndarray:
198 | """Generate the feature trends for query instances X"""
199 | raise NotImplementedError('trends method is not implemented!')
200 |
201 | def get_backgrounds(self, X: np.ndarray) -> np.ndarray:
202 | """Returns the background datasets for query instances X"""
203 | raise NotImplementedError('get_backgrounds method is not implemented!')
204 |
205 | def __call__(self, X: np.ndarray) -> attrdict:
206 | """Returns the explanations
207 |
208 | Args:
209 | X (np.ndarray): The query instances
210 |
211 | Returns:
212 | attrdict: An attrdict (i.e., a dict whose fields can also be accessed as attributes) with the following attributes:
213 | - .values : the feature attributions
214 | - .backgrounds : the background datasets (if any)
215 | - .trends : the feature trends (if any)
216 | """
217 | X = self.preprocess(X)
218 | return attrdict(
219 | values=self.get_attributions(X),
220 | backgrounds=self.get_backgrounds(X),
221 | trends=self.get_trends(X),
222 | )
223 |
224 | # Alias for 'get_attributions' for backward-compatibility
225 | def shap_values(self, *args, **kwargs):
226 | return self.get_attributions(*args, **kwargs)
227 |
228 |
229 | class BaseSupportsDynamicBackground(ABC):
230 | """Base class for a class that supports at-runtime change of the background data."""
231 | @property
232 | def data(self):
233 | if self._data is None:
234 | self._raise_data_error()
235 | return self._data
236 |
237 | def _raise_data_error(self):
238 | raise ValueError('Must set background data first.')
239 |
240 | @data.setter
241 | @abstractmethod
242 | def data(self, data):
243 | pass
244 |
245 |
246 | class BaseGroupExplainer:
247 | """Base class for an explainer (feature attribution) for groups of features."""
248 | def preprocess_groups(self, feature_groups: List[List[int]], nb_features):
249 |
250 | features_in_groups = sum(feature_groups, [])
251 | nb_groups = len(feature_groups)
252 |
253 | if nb_groups > nb_features:
254 | raise ValueError('There are more groups than features.')
255 |
256 | if len(set(features_in_groups)) != len(features_in_groups):
257 | raise ValueError('Some features are in multiple groups!')
258 |
259 | if len(set(features_in_groups)) < nb_features:
260 | raise ValueError('Not all the features are in groups')
261 |
262 | if any([len(x) == 0 for x in feature_groups]):
263 | raise ValueError('Some feature groups are empty!')
264 |
265 | return feature_groups
266 |
267 |
268 | # ------------------------------------- BACKGROUND GENERATOR --------------------------------------
269 |
270 |
271 | @runtime_checkable
272 | class BackgroundGenerator(Protocol):
273 | """Protocol for a Background Generator: can be used together with an explainer to dynamicly generate backgrounds for each instance (see `composite`)"""
274 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays:
275 | """Returns the background datasets for the query instances.
276 |
277 | Args:
278 | X (np.ndarray): The query instances.
279 |
280 | Returns:
281 | ListOf2DArrays: The background datasets.
282 | """
283 | pass
284 |
285 |
286 | class BaseBackgroundGenerator(BaseClass, ABC, BackgroundGenerator):
287 | """Base class for a background generator."""
288 | @abstractmethod
289 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays:
290 | pass
291 |
292 |
293 | # ------------------------------------- TREND ESTIMATOR --------------------------------------
294 |
295 |
296 | @runtime_checkable
297 | class TrendEstimatorProtocol(Protocol):
298 | """Protocol for a feature Trend Estimator"""
299 | def predict(self, X: np.ndarray, YY: ListOf2DArrays) -> np.ndarray:
300 | pass
301 |
302 | def __call__(self, x: np.ndarray, Y: np.ndarray) -> np.ndarray:
303 | pass
304 |
305 |
306 | # ------------------- Counterfactuals, etc. -------------------------
307 |
308 |
309 | @runtime_checkable
310 | class CounterfactualMethod(Protocol):
311 | """Protocol for a counterfactual generation method (that generate a single counterfactual per query instance)."""
312 | model: Model
313 |
314 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray:
315 | pass
316 |
317 |
318 | @runtime_checkable
319 | class MultipleCounterfactualMethod(CounterfactualMethod, Protocol):
320 | """Protocol for a counterfactual generation method (that generate a single OR MULTIPLE counterfactuals per query instance)."""
321 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays:
322 | pass
323 |
324 |
325 | class BaseCounterfactualMethod(BaseClass, ABC, CounterfactualMethod):
326 | """Base class for a counterfactual generation method (that generate a single counterfactual per query instance)."""
327 | def __init__(self, *args, **kwargs):
328 | super().__init__(*args, **kwargs)
329 |
330 | self.invalid_counterfactual = 'raise'
331 |
332 | def _invalid_response(self, invalid: Union[None, str]) -> str:
333 | invalid = invalid or self.invalid_counterfactual
334 | assert invalid in ('nan', 'raise', 'ignore')
335 | return invalid
336 |
337 | def postprocess(
338 | self,
339 | X: np.ndarray,
340 | XC: np.ndarray,
341 | invalid: Union[None, str] = None,
342 | ) -> np.ndarray:
343 | """Post-process counterfactuals
344 |
345 | Args:
346 | X (np.ndarray : nb_samples x nb_features): The query instances
347 | XC (np.ndarray : nb_samples x nb_features): The counterfactuals
348 |             invalid (Union[None, str], optional): It can have the following values. Defaults to None ('raise').
349 |                 - 'nan': invalid counterfactuals (not changing the prediction) will be marked with NaN
350 |                 - 'raise': an error will be raised if invalid counterfactuals are passed
351 |                 - 'ignore': nothing will be done; invalid counterfactuals will be returned as-is.
352 |
353 | Returns:
354 | np.ndarray: The post-processed counterfactuals
355 | """
356 |
357 | invalid = self._invalid_response(invalid)
358 |
359 | # Mask with the non-flipped counterfactuals
360 | not_flipped_mask = (self.model.predict(X) == self.model.predict(XC))
361 | if not_flipped_mask.sum() > 0:
362 | if invalid == 'raise':
363 | self._raise_invalid()
364 | elif invalid == 'nan':
365 | self._warn_invalid()
366 | XC[not_flipped_mask, :] = np.nan
367 |
368 | return XC
369 |
370 | def _warn_invalid(self):
371 | warnings.warn('!! ERROR: Some counterfactuals are NOT VALID (will be set to NaN)')
372 |
373 | def _raise_invalid(self):
374 | raise RuntimeError('Invalid counterfactuals')
375 |
376 | def _raise_nan(self):
377 | raise RuntimeError('NaN counterfactuals are generated before post-processing.')
378 |
379 | def _raise_inf(self):
380 | raise RuntimeError('+/-inf counterfactuals are generated before post-processing.')
381 |
382 | @abstractmethod
383 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray:
384 | pass
385 |
386 |
387 | class BaseMultipleCounterfactualMethod(BaseCounterfactualMethod):
388 |     """Base class for a counterfactual generation method (that generates a single OR MULTIPLE counterfactuals per query instance)."""
389 | def multiple_postprocess(
390 | self,
391 | X: np.ndarray,
392 | XX_C: ListOf2DArrays,
393 | invalid: Union[None, str] = None,
394 | allow_nan: bool = True,
395 | allow_inf: bool = False,
396 | ) -> ListOf2DArrays:
397 | """Post-process multiple counterfactuals
398 |
399 | Args:
400 | X (np.ndarray : nb_samples x nb_features): The query instances
401 | XX_C (ListOf2DArrays : nb_samples x nb_counterfactuals x nb_features): The counterfactuals
402 |             invalid (Union[None, str], optional): It can have the following values. Defaults to None ('raise').
403 |                 - 'nan': invalid counterfactuals (not changing the prediction) will be marked with NaN
404 |                 - 'raise': an error will be raised if invalid counterfactuals are passed
405 |                 - 'ignore': nothing will be done; invalid counterfactuals will be returned as-is.
406 |             allow_nan (bool, optional): If True, allows NaN counterfactuals as input (treated as invalid). If False, it raises an error. Defaults to True.
407 |             allow_inf (bool, optional): If True, allows infinite values in the counterfactuals. If False, it raises an error. Defaults to False.
408 |
409 | Returns:
410 | ListOf2DArrays : The post-processed counterfactuals
411 | """
412 |
413 | invalid = self._invalid_response(invalid)
414 |
415 | # Reshape (for zero-length arrays)
416 | XX_C = [X_C.reshape(-1, X.shape[1]) for X_C in XX_C]
417 |
418 | # Check for NaN and Inf
419 | for XC in XX_C:
420 | if not allow_nan and np.isnan(XC).sum() != 0:
421 | self._raise_nan()
422 | if not allow_inf and np.isinf(XC).sum() != 0:
423 | self._raise_inf()
424 |
425 | # Mask with the non-flipped counterfactuals
426 | nb_counters = np.array([X_c.shape[0] for X_c in XX_C])
427 | not_flipped_mask = np.equal(
428 | np.repeat(self.model.predict(X), nb_counters),
429 | self.model.predict(np.concatenate(XX_C, axis=0)),
430 | )
431 | if not_flipped_mask.sum() > 0:
432 | if invalid == 'raise':
433 | print('X, f(X) :', X, self.model.predict(X))
434 | print('X_C, f(X_C) :', XX_C, self.model.predict(np.concatenate(XX_C, axis=0)))
435 | self._raise_invalid()
436 | elif invalid == 'nan':
437 | self._warn_invalid()
438 | sections = np.cumsum(nb_counters[:-1])
439 | not_flipped_mask = np.split(not_flipped_mask, indices_or_sections=sections)
440 |
441 | # Set them to nan
442 | for i, nfm in enumerate(not_flipped_mask):
443 | XX_C[i][nfm, :] = np.nan
444 |
445 | return XX_C
446 |
447 | def multiple_trace_postprocess(self, X, XTX_counter, invalid=None):
448 | invalid = self._invalid_response(invalid)
449 |
450 | # Reshape (for zero-length arrays)
451 | XTX_counter = [[X_C.reshape(-1, X.shape[1]) for X_C in TX_C] for TX_C in XTX_counter]
452 |
453 | # Mask with the non-flipped counterfactuals
454 | shapess = [[X_C.shape[0] for X_C in TX_C] for TX_C in XTX_counter]
455 | shapes = [sum(S) for S in shapess]
456 |
457 | X_counter = np.concatenate([np.concatenate(TX_C, axis=0) for TX_C in XTX_counter], axis=0)
458 | not_flipped_mask = np.equal(
459 | np.repeat(self.model.predict(X), shapes),
460 | self.model.predict(X_counter),
461 | )
462 | if not_flipped_mask.sum() > 0:
463 | if invalid == 'raise':
464 | self._raise_invalid()
465 | elif invalid == 'nan':
466 | self._warn_invalid()
467 | sections = np.cumsum(shapes[:-1])
468 | sectionss = [np.cumsum(s[:-1]) for s in shapess]
469 | not_flipped_mask = np.split(not_flipped_mask, indices_or_sections=sections)
470 | not_flipped_mask = [np.split(NFM, indices_or_sections=s) for NFM, s in zip(not_flipped_mask, sectionss)]
471 |
472 | # Set them to nan
473 | for i, NFM in enumerate(not_flipped_mask):
474 | for j, nfm in enumerate(NFM):
475 |                         XTX_counter[i][j][nfm, :] = np.nan
476 |
477 | return XTX_counter
478 |
479 | @abstractmethod
480 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays:
481 | pass
482 |
483 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray:
484 | return np.array([X_C[0] for X_C in self.get_multiple_counterfactuals(X)])
485 |
486 |     # Aliases for backward compatibility
487 | def diverse_postprocess(self, *args, **kwargs):
488 | return self.multiple_postprocess(*args, **kwargs)
489 |
490 | def diverse_trace_postprocess(self, *args, **kwargs):
491 | return self.multiple_trace_postprocess(*args, **kwargs)
492 |
493 |
494 | @runtime_checkable
495 | class Wrappable(Protocol):
496 | verbose: Union[int, bool]
497 |
498 |
499 | @runtime_checkable
500 | class SupportsWrapping(Protocol):
501 | @property
502 | def data(self):
503 | pass
504 |
505 | @data.setter
506 | @abstractmethod
507 | def data(self, data):
508 | pass
509 |
510 |
511 | @runtime_checkable
512 | class MultipleCounterfactualMethodSupportsWrapping(MultipleCounterfactualMethod, SupportsWrapping, Protocol):
513 | """Protocol for a counterfactual method that can be wrapped by another one
514 | (i.e., the output of a SupportsWrapping method can be used as background data of another)"""
515 | pass
516 |
517 |
518 | @runtime_checkable
519 | class MultipleCounterfactualMethodWrappable(MultipleCounterfactualMethod, Wrappable, Protocol):
520 |     """Protocol for a counterfactual method that can be used as a wrapper for another one
521 |     (i.e., a Wrappable method can use the output of another CFX method as input)"""
522 | pass
523 |
524 |
525 | # ------------------------ EVALUATION -----------------------
526 |
527 |
528 | @runtime_checkable
529 | class CounterfactualEvaluationScorer(Protocol):
530 | """Protocol for an evaluation method that returns an array of scores (float) for a list of counterfactuals."""
531 | def score(self, X: np.ndarray) -> np.ndarray:
532 | pass
533 |
534 |
535 | class BaseCounterfactualEvaluationScorer(ABC):
536 | @abstractmethod
537 | def score(self, X: np.ndarray) -> np.ndarray:
538 | pass
539 |
--------------------------------------------------------------------------------
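Illustrative usage of the protocols defined in `base.py` above (a minimal sketch, not part of the repository): because the protocols are `runtime_checkable`, any object exposing the required methods satisfies an `isinstance` check, so custom components can be plugged in without subclassing. The class and data below are hypothetical.

import numpy as np
from cfshap.base import BackgroundGenerator

class FixedBackgroundGenerator:
    """Toy generator returning the same fixed background for every query instance."""
    def __init__(self, background: np.ndarray):
        self.background = background

    def get_backgrounds(self, X: np.ndarray):
        # One (identical) background dataset per query instance
        return [self.background for _ in range(X.shape[0])]

generator = FixedBackgroundGenerator(np.random.default_rng(0).normal(size=(100, 5)))
assert isinstance(generator, BackgroundGenerator)  # True: structural (duck-typing) check only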
/src/cfshap/counterfactuals/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Counterfactuals Generation (Algorithmic Recourse) Techniques.
5 | """
6 |
7 | from .knn import *
8 | from .project import *
9 | from .composition import *
--------------------------------------------------------------------------------
/src/cfshap/counterfactuals/composition.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Counterfactual method that composes two children methods.
5 | One as a wrapper and one as a wrapped method.
6 | The wrapper method will use the counterfactuals generated by the wrapped method as background.
7 | """
8 |
9 | __all__ = [
10 | 'CounterfactualComposition',
11 | 'compose_diverse_counterfactual_method',
12 | ]
13 |
14 | from copy import deepcopy
15 | import numpy as np
16 | from tqdm import tqdm
17 |
18 | from ..base import (
19 | BaseMultipleCounterfactualMethod,
20 | MultipleCounterfactualMethodSupportsWrapping,
21 | MultipleCounterfactualMethodWrappable,
22 | )
23 |
24 | class CounterfactualComposition(BaseMultipleCounterfactualMethod):
25 | def __init__(
26 | self,
27 | wrapping_instance: MultipleCounterfactualMethodSupportsWrapping,
28 | wrapped_instance: MultipleCounterfactualMethodWrappable,
29 | verbose=True,
30 | ):
31 |
32 | if wrapping_instance.model != wrapped_instance.model:
33 |             raise ValueError('Models of the wrapping and wrapped methods differ.')
34 |
35 | super().__init__(wrapping_instance.model, None)
36 |
37 | if not isinstance(wrapped_instance, MultipleCounterfactualMethodWrappable):
38 |             raise ValueError('wrapped_instance does not implement the necessary interface.')
39 |
40 | if not isinstance(wrapping_instance, MultipleCounterfactualMethodSupportsWrapping):
41 |             raise ValueError('wrapping_instance does not implement the necessary interface.')
42 |
43 | # Scalers are ignored (they are already into the wrapped/wrapping instances)
44 | super().__init__(wrapping_instance.model, scaler=None)
45 |
46 | wrapping_instance = deepcopy(wrapping_instance)
47 | wrapped_instance = deepcopy(wrapped_instance)
48 |
49 | wrapped_instance.verbose = 0
50 |
51 | self.wrapping_instance = wrapping_instance
52 | self.wrapped_instance = wrapped_instance
53 |
54 | self.data = None
55 | self.verbose = verbose
56 |
57 | def get_multiple_counterfactuals(self, X):
58 | X = self.preprocess(X)
59 |
60 | # Get underlying counterfactuals
61 | X_Cs = self.wrapped_instance.get_multiple_counterfactuals(X)
62 |
63 | iters = enumerate(zip(X, X_Cs))
64 | if len(X) > 100 and self.verbose:
65 | iters = tqdm(iters, desc='Composition CF: Step 2')
66 |
67 | # Postprocess them
68 | for i, (x, X_C) in iters:
69 | # Set background
70 | self.data = self.wrapping_instance.data = X_C
71 |
72 | # Compute counterfactuals
73 | X_Cs[i] = self.wrapping_instance.get_multiple_counterfactuals(np.array([x]))[0]
74 |
75 | return X_Cs
76 |
77 |
78 | def compose_diverse_counterfactual_method(*args, **kwargs):
79 | if len(args) < 2:
80 |         raise ValueError('At least 2 counterfactual methods must be passed.')
81 |
82 | # Most inner
83 | composition = args[-1]
84 |
85 | # Iteratively compose
86 | for wrapping_instance in args[-2::-1]:
87 | composition = CounterfactualComposition(wrapping_instance, composition, **kwargs)
88 |
89 | return composition
90 |
--------------------------------------------------------------------------------
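A hedged usage sketch of `compose_diverse_counterfactual_method` from the file above (illustrative; `method_a`, `method_b` and `X_query` are assumed to exist): `method_a` must support wrapping (it exposes a `data` setter), `method_b` must be wrappable, and both must share the same model. The innermost method runs first, and its counterfactuals become the background of the wrapper.

# Hypothetical, already-constructed counterfactual generators sharing the same model.
composed = compose_diverse_counterfactual_method(method_a, method_b, verbose=False)
XX_C = composed.get_multiple_counterfactuals(X_query)  # one 2D array of counterfactuals per query instance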
/src/cfshap/counterfactuals/knn.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Implementation of K-Nearest Neighbours Counterfactuals
5 | """
6 |
7 | __all__ = ['KNNCounterfactuals']
8 |
9 | from typing import Union
10 |
11 | import numpy as np
12 | from sklearn.neighbors import NearestNeighbors
13 |
14 | from ..base import (
15 | BaseMultipleCounterfactualMethod,
16 | Model,
17 | Scaler,
18 | ListOf2DArrays,
19 | )
20 | from ..utils import keydefaultdict
21 |
22 | class KNNCounterfactuals(BaseMultipleCounterfactualMethod):
23 | """Returns the K Nearest Neighbours of the query instance with a different prediction.
24 |
25 | """
26 | def __init__(
27 | self,
28 | model: Model,
29 | scaler: Union[None, Scaler],
30 | X: np.ndarray,
31 | nb_diverse_counterfactuals: Union[None, int, float] = None,
32 | n_neighbors: Union[None, int, float] = None,
33 | distance: str = None,
34 | max_samples: int = int(1e10),
35 | random_state: int = 2021,
36 | verbose: int = 0,
37 | **distance_params,
38 | ):
39 | """
40 |
41 | Args:
42 | model (Model): The model.
43 | scaler (Union[None, Scaler]): The scaler for the data.
44 | X (np.ndarray): The background dataset.
45 | nb_diverse_counterfactuals (Union[None, int, float], optional): Number of counterfactuals to generate. Defaults to None.
46 | n_neighbors (Union[None, int, float], optional): Number of neighbours to generate. Defaults to None.
47 | Note that this is an alias for nb_diverse_counterfactuals in this class.
48 | distance (str, optional): The distance metric to use for K-NN. Defaults to None.
49 | max_samples (int, optional): Number of samples of the background to use at most. Defaults to int(1e10).
50 | random_state (int, optional): Random seed. Defaults to 2021.
51 | verbose (int, optional): Level of verbosity. Defaults to 0.
52 | **distance_params: Additional parameters for the distance metric
53 | """
54 |
55 | assert nb_diverse_counterfactuals is not None or n_neighbors is not None, 'nb_diverse_counterfactuals or n_neighbors must be set.'
56 |
57 | super().__init__(model, scaler, random_state)
58 |
59 | self._metric, self._metric_params = distance, distance_params
60 | self.__nb_diverse_counterfactuals = nb_diverse_counterfactuals
61 | self.__n_neighbors = n_neighbors
62 | self.max_samples = max_samples
63 |
64 | self.data = X
65 | self.verbose = verbose
66 |
67 | @property
68 | def data(self):
69 | return self._data
70 |
71 | @data.setter
72 | def data(self, data):
73 | self._raw_data = self.preprocess(data)
74 | if self.max_samples < len(self._raw_data):
75 | self._raw_data = self.sample(self._raw_data, n=self.max_samples)
76 | self._preds = self.model.predict(self._raw_data)
77 |
78 |         # In this class these two parameters are equivalent (aliases of each other)
79 | if self.__n_neighbors is None:
80 | self.__n_neighbors = self.__nb_diverse_counterfactuals
81 | if self.__nb_diverse_counterfactuals is None:
82 | self.__nb_diverse_counterfactuals = self.__n_neighbors
83 |
84 | def get_nb_of_items(nb):
85 | if np.isinf(nb):
86 | return keydefaultdict(lambda pred: self._data[pred].shape[0])
87 | elif isinstance(nb, int) and nb >= 1:
88 | return keydefaultdict(lambda pred: min(nb, self._data[pred].shape[0]))
89 | elif isinstance(nb, float) and nb <= 1.0 and nb > 0.0:
90 | return keydefaultdict(lambda pred: int(max(1, round(len(self._data[pred]) * nb))))
91 | else:
92 | raise ValueError(
93 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)'
94 | )
95 |
96 | self._n_neighbors = get_nb_of_items(self.__n_neighbors)
97 | self._nb_diverse_counterfactuals = get_nb_of_items(self.__nb_diverse_counterfactuals)
98 |
99 | # We will be searching neighbors of a different class
100 | self._data = keydefaultdict(lambda pred: self._raw_data[self._preds != pred])
101 |
102 | self._nn = keydefaultdict(lambda pred: NearestNeighbors(
103 | n_neighbors=self._n_neighbors[pred],
104 | metric=self._metric,
105 | p=self._metric_params['p'] if 'p' in self._metric_params else 2,
106 | metric_params=self._metric_params,
107 | ).fit(self.scale(self._data[pred])))
108 |
109 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray:
110 | """Generate the closest counterfactual for each query instance"""
111 |
112 | # Pre-process
113 | X = self.preprocess(X)
114 |
115 | preds = self.model.predict(X)
116 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)}
117 |
118 | X_counter = np.zeros_like(X)
119 |
120 | for pred, indices in preds_indices.items():
121 |             _, neighbors_indices = self._nn[pred].kneighbors(self.scale(X[indices]), n_neighbors=1)
122 | X_counter[indices] = self._data[pred][neighbors_indices.flatten()]
123 |
124 | # Post-process
125 | X_counter = self.postprocess(X, X_counter, invalid='raise')
126 |
127 | return X_counter
128 |
129 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays:
130 | """Generate the multiple closest counterfactuals for each query instance"""
131 |
132 | # Pre-condition
133 | assert self.__n_neighbors == self.__nb_diverse_counterfactuals, (
134 |             'n_neighbors and nb_diverse_counterfactuals are set to different values '
135 |             f'({self.__n_neighbors} != {self.__nb_diverse_counterfactuals}). '
136 | 'When both are set they must be set to the same value.')
137 |
138 | # Pre-process
139 | X = self.preprocess(X)
140 |
141 | preds = self.model.predict(X)
142 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)}
143 |
144 | X_counter = [
145 | np.full((self._nb_diverse_counterfactuals[preds[i]], X.shape[1]), np.nan) for i in range(X.shape[0])
146 | ]
147 |
148 | for pred, indices in preds_indices.items():
149 | _, neighbors_indices = self._nn[pred].kneighbors(self.scale(X[indices]), n_neighbors=None)
150 | counters = self._data[pred][neighbors_indices.flatten()].reshape(len(indices),
151 | self._nb_diverse_counterfactuals[pred], -1)
152 | for e, i in enumerate(indices):
153 |                 # We use :counters[e].shape[0] so it raises an exception if shapes are not coherent.
154 | X_counter[i][:counters[e].shape[0]] = counters[e]
155 |
156 | # Post-process
157 | X_counter = self.diverse_postprocess(X, X_counter, invalid='raise')
158 |
159 | return X_counter
--------------------------------------------------------------------------------
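A hedged usage sketch of `KNNCounterfactuals` from the file above (illustrative; `clf`, `X_train` and `X_query` are assumed to exist, and passing `scaler=None` is assumed to mean that no scaling is applied):

# `clf` is any fitted classifier exposing .predict(); X_train is the background dataset.
cfs = KNNCounterfactuals(
    model=clf,
    scaler=None,            # assumption: no scaling; pass a fitted scaler otherwise
    X=X_train,
    n_neighbors=10,         # 10 nearest neighbours of the opposite predicted class
    distance='cityblock',   # L1 distance for the K-NN search
)
closest = cfs.get_counterfactuals(X_query)            # nb_samples x nb_features
diverse = cfs.get_multiple_counterfactuals(X_query)   # list of (10 x nb_features) arrays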
/src/cfshap/counterfactuals/project.py:
--------------------------------------------------------------------------------
1 | """
2 | Counterfactuals as projection onto the decision boundary (estimated by bisection).
3 | """
4 |
5 | __author__ = 'Emanuele Albini'
6 | __all__ = ['BisectionProjectionDBCounterfactuals']
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from ..base import BaseMultipleCounterfactualMethod
12 | from ..utils import keydefaultdict
13 | from ..utils.project import find_decision_boundary_bisection
14 |
15 |
16 | class BisectionProjectionDBCounterfactuals(BaseMultipleCounterfactualMethod):
17 | def __init__(self,
18 | model,
19 | data,
20 | num_iters=100,
21 | earlystop_error=1e-8,
22 | scaler=None,
23 | max_samples=1e10,
24 | nb_diverse_counterfactuals=None,
25 | random_state=0,
26 | verbose=1,
27 | **kwargs):
28 |
29 | if scaler is not None:
30 |             raise NotImplementedError('Scaling is not supported by this counterfactual method.')
31 |
32 | super().__init__(model, scaler, random_state=random_state)
33 |
34 | self.num_iters = num_iters
35 | self.earlystop_error = earlystop_error
36 | self.kwargs = kwargs
37 |
38 | self._max_samples = min(max_samples or np.inf, nb_diverse_counterfactuals or np.inf)
39 |
40 | self.data = data
41 | self.verbose = verbose
42 |
43 | @property
44 | def data(self):
45 | return self._data
46 |
47 | @data.setter
48 | def data(self, data):
49 | if data is not None:
50 | self._raw_data = self.preprocess(data)
51 | self._preds = self.model.predict(self._raw_data)
52 |
53 | self._data = keydefaultdict(
54 | lambda pred: self.sample(self._raw_data[self._preds != pred], n=self._max_samples))
55 | else:
56 | self._raw_data = None
57 | self._preds = None
58 | self._data = None
59 |
60 | def __get_counterfactuals(self, x, pred):
61 | if self.data is not None:
62 | return np.array(
63 | find_decision_boundary_bisection(
64 | x=x,
65 | Y=self.data[pred],
66 | model=self.model,
67 | scaler=self.scaler,
68 | num=self.num_iters,
69 | error=self.earlystop_error,
70 | method='counterfactual',
71 | desc=None,
72 | # model_parallelism=1,
73 | # n_jobs=1,
74 | **self.kwargs,
75 | ))
76 | else:
77 | raise ValueError('Invalid state. `self.data` must be set.')
78 |
79 | def get_multiple_counterfactuals(self, X):
80 | X = self.preprocess(X)
81 | preds = self.model.predict(X)
82 | X_counter = []
83 | iters = tqdm(zip(X, preds)) if self.verbose else zip(X, preds)
84 | for x, pred in iters:
85 | X_counter.append(self.__get_counterfactuals(x, pred))
86 |
87 | return self.diverse_postprocess(X, X_counter)
88 |
--------------------------------------------------------------------------------
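A hedged sketch of `BisectionProjectionDBCounterfactuals` from the file above (illustrative; `clf`, `X_train` and `X_query` are assumed): each counterfactual is found by bisecting between the query point and background points of the opposite class until the decision boundary is approximated.

cfs = BisectionProjectionDBCounterfactuals(
    model=clf,
    data=X_train,
    num_iters=50,                   # bisection iterations per counterfactual
    nb_diverse_counterfactuals=20,  # use at most 20 background points per query
)
XX_C = cfs.get_multiple_counterfactuals(X_query)  # one array of counterfactuals per query point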
/src/cfshap/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Evaluation methods for explanations.
5 | """
6 |
--------------------------------------------------------------------------------
/src/cfshap/evaluation/attribution/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Evaluation methods for feature attributions.
5 | """
6 |
7 | from ._basic import *
--------------------------------------------------------------------------------
/src/cfshap/evaluation/attribution/_basic.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Basic evaluation utilities for feature attributions
5 | """
6 |
7 | from typing import Union
8 | import numpy as np
9 | import pandas as pd
10 |
11 | __all__ = [
12 | 'feature_attributions_statistics',
13 | ]
14 |
15 |
16 | def feature_attributions_statistics(phi: np.ndarray, mean=False) -> Union[pd.DataFrame, pd.Series]:
17 | """Calculate some statistics on feature attributions
18 |
19 | Args:
20 | phi (np.ndarray: nb_samples x nb_features): The feature attributions
21 |         mean (bool, optional): Calculate the mean over all samples. Defaults to False.
22 |
23 | Returns:
24 | The statistics
25 | pd.DataFrame: By default
26 | pd.Series: If mean is True.
27 | """
28 | stats = pd.DataFrame()
29 | stats['Non-computable Phi (NaN)'] = np.any(np.isnan(phi), axis=1)
30 |     stats['Non-Positive Phi'] = np.all(phi <= 0, axis=1)
31 |     stats['Non-Negative Phi'] = np.all(phi >= 0, axis=1)
32 | stats['All Zeros Phi'] = np.all(phi == 0, axis=1)
33 | stats['Number of Feature with Positive Phi'] = np.sum(phi > 0, axis=1)
34 | stats['Number of Feature with Negative Phi'] = np.sum(phi < 0, axis=1)
35 | stats['Number of Feature with Zero Phi'] = np.sum(phi == 0, axis=1)
36 |
37 | if mean:
38 | # Rename columns accordingly
39 | def __rename(name):
40 | if name.startswith('Number'):
41 | return 'Average ' + name
42 | else:
43 | return 'Percentage of ' + name
44 |
45 | stats = stats.rename(columns={name: __rename(name) for name in stats.columns.values})
46 |
47 | # Mean over all the attributions in the dataset
48 | stats = stats.mean()
49 | return stats
--------------------------------------------------------------------------------
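An illustrative call to `feature_attributions_statistics` above (toy data, not from the repository):

import numpy as np
from cfshap.evaluation.attribution import feature_attributions_statistics

phi = np.array([
    [0.5, -0.2, 0.0],       # mixed attributions
    [0.0,  0.0, 0.0],       # all-zero attributions
    [np.nan, 1.0, 2.0],     # non-computable attribution (NaN)
])
print(feature_attributions_statistics(phi))             # one row of flags/counts per sample
print(feature_attributions_statistics(phi, mean=True))  # averaged over the samples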
/src/cfshap/evaluation/attribution/induced_counterfactual/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Method to evaluate the actionability of a feature attribution technique.
5 |
6 | This is the implementation of the method officially proposed in our paper,
7 | therein called counterfactual-ability.
8 | 
9 | Counterfactual Shapley Additive Explanations
10 | https://arxiv.org/pdf/2110.14270.pdf
11 | """
12 |
13 | from ._actionability import *
--------------------------------------------------------------------------------
/src/cfshap/evaluation/attribution/induced_counterfactual/_actionability.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from ....utils import attrdict
8 | from ....utils.preprocessing import MultiScaler
9 |
10 | from ....base import Model
11 | from ....utils.tree import TreeEnsemble, get_shap_compatible_model
12 | from ._utils import (
13 | transform_trees,
14 | mask_values,
15 | inverse_actionabiltiy,
16 | )
17 |
18 | __all__ = ['TreeInducedCounterfactualGeneratorV2']
19 |
20 |
21 | class TreeInducedCounterfactualGeneratorV2:
22 | def __init__(
23 | self,
24 | model: Model,
25 | data: np.ndarray = None,
26 |         multiscaler: Union[None, MultiScaler] = None,
27 | global_feature_trends: Union[None, list, np.ndarray] = None,
28 | random_state: int = 0,
29 | ):
30 | assert data is not None or multiscaler is not None
31 |
32 | self.model = model
33 | self.data = data
34 | self.multiscaler = multiscaler
35 | self.global_feature_trends = global_feature_trends
36 | self.random_state = random_state
37 |
38 | @property
39 | def multiscaler(self):
40 | return self._multiscaler
41 |
42 | @multiscaler.setter
43 | def multiscaler(self, multiscaler):
44 | self._multiscaler = multiscaler
45 |
46 | @property
47 | def data(self):
48 | raise NotImplementedError()
49 |
50 | @data.setter
51 | def data(self, data):
52 | if data is not None:
53 |             self._multiscaler = MultiScaler(data)
54 |
55 | @property
56 | def model(self):
57 | return self._model
58 |
59 | @property
60 | def shap_model(self):
61 | return self._shap_model
62 |
63 | @model.setter
64 | def model(self, model):
65 | self._model = model
66 | self._shap_model = get_shap_compatible_model(self.model)
67 |
68 | def transform(
69 | self,
70 | X: np.ndarray,
71 | explanations: attrdict,
72 | K: Tuple[int, float] = (1, 2, 3, 4, 5, np.inf),
73 | action_cost_normalization: str = 'quantile',
74 | action_cost_aggregation: str = 'L1', # L0, L1, L2, Linf
75 | action_strategy: str = 'proportional', # 'equal', 'proportional', 'random'
76 | action_scope: str = 'positive', # 'positive', 'negative', 'all'
77 | action_direction: str = 'local', # 'local', 'global', 'random'
78 | precision: float = 1e-4,
79 | show_progress: bool = True,
80 | starting_sample: int = 0,
81 | max_samples: Union[int, None] = None,
82 | nan_explanation: str = 'raise', # 'ignore'
83 | counters: np.ndarray = None,
84 | costs: np.ndarray = None,
85 | desc: str = "",
86 | ):
87 |
88 | # Must pass trends if global trends are used
89 | assert self.global_feature_trends is not None or action_direction != 'global'
90 | # Must pass local trends if local trends are used
91 | assert hasattr(explanations, 'trends') or action_direction != 'local'
92 |
93 | assert (counters is None and costs is None) or (counters is not None and costs is not None)
94 |
95 | assert action_cost_aggregation[0] == 'L'
96 | action_cost_aggregation = float(action_cost_aggregation[1:])
97 |
98 | random_state = np.random.RandomState(self.random_state)
99 |
100 | # Compute stuff from the results
101 | Xn = self.multiscaler.transform(X, method=action_cost_normalization)
102 | Pred = self.model.predict(X)
103 | Phi = explanations.values
104 |
105 | # Let's raise some warning in case of NaN
106 | nb_PhiNaN = np.any(np.isnan(Phi), axis=1).sum()
107 | if nb_PhiNaN > 0:
108 | logging.warning(f'There are {nb_PhiNaN} NaN explanations.')
109 | if nb_PhiNaN == len(X):
110 | logging.error('All explanations are NaN.')
111 |
112 | # Pre-compute directions of change
113 | if action_direction == 'local':
114 | Directions = explanations.trends
115 | elif action_direction == 'global':
116 | Directions = (-1 * (np.broadcast_to(Pred.reshape(-1, 1), X.shape) - .5) * 2 *
117 | np.broadcast_to(np.array(self.global_feature_trends), X.shape)).astype(int)
118 | elif action_direction == 'random':
119 | Directions = random_state.choice((-1, 1), X.shape)
120 | else:
121 | raise ValueError('Invalid direction_strategy.')
122 |
123 | # If we don't have a direction we set Phi to 0
124 | Phi = Phi * (Directions != 0)
125 |
126 | logging.info(f"There are {np.any(Directions == 0, axis = 1).sum()}/{len(X)} samples with some tau_i = 0")
127 |
128 | assert np.all((Directions == -1) | (Directions == +1) | (np.isnan(Phi))
129 | | ((Directions == 0) & (Phi == 0))), "Some trends are not +1 or -1"
130 |
131 | # Compute the order of feature-change
132 | # NOTE: This must be done before the action_strategy computation below!
133 | if action_scope == 'all':
134 | Order = np.argsort(-1 * np.abs(Phi), axis=1)
135 | elif action_scope == 'positive':
136 | Order = np.argsort(-1 * Phi, axis=1)
137 | Phi = Phi * (Phi > 0)
138 | elif action_scope == 'negative':
139 | Order = np.argsort(+1 * Phi, axis=1)
140 | Phi = Phi * (Phi < 0)
141 | else:
142 | raise ValueError('Invalid action_scope.')
143 |
144 |         assert np.all((Phi >= 0) | (np.isnan(Phi))), "Something weird happened: Phi should be >= 0 at this point"
145 |
146 | # Modify Phi based on the action strategy
147 | if action_strategy == 'proportional':
148 | pass
149 | elif action_strategy == 'equal':
150 | Phi = 1 * (Phi > 0)
151 | elif action_strategy == 'random':
152 | Phi = random_state.random(Phi.shape) * (Phi != 0) * (~np.isnan(Phi))
153 | elif action_strategy == 'noisy':
154 | Phi = Phi * random_state.random(Phi.shape)
155 | else:
156 | raise ValueError('Invalid action_strategy.')
157 |
158 | if len(np.unique(Pred)) != 1:
159 | raise RuntimeError('Samples have different predicted class! This is weird: stopping.')
160 |
161 |         assert np.all((Phi >= 0) | (np.isnan(Phi))), "Something weird happened: Phi should still be >= 0 at this point"
162 |
163 | # Do some shape checks
164 | assert X.shape == Xn.shape
165 | assert X.shape == Phi.shape
166 | assert X.shape == Order.shape
167 | assert X.shape == Directions.shape
168 | assert (X.shape[0], ) == Pred.shape
169 | assert np.all(self.model.predict(X) == Pred)
170 |         assert np.all(Pred % Pred.astype(int) == 0)
171 |
172 | # Compute trees (and cost-normalized version)
173 | trees = TreeEnsemble(self.shap_model).trees
174 | ntrees = transform_trees(trees, self.multiscaler, action_cost_normalization)
175 |
176 | # Run the experiments from scratch
177 | if counters is None:
178 | counters = []
179 | costs = []
180 |
181 | # Resume the experiments
182 | else:
183 | assert counters.shape[0] == costs.shape[0] == (min(X[starting_sample:].shape[0], max_samples or np.inf))
184 | assert counters.shape[1] == costs.shape[1]
185 | assert counters.shape[1] >= len(K)
186 | assert counters.shape[2] == X.shape[1]
187 |
188 | counters = counters.copy()
189 | costs = costs.copy()
190 |
191 | counters = counters.transpose((1, 0, 2)) # Top-K x Samples x nb_features
192 | costs = costs.transpose((1, 0)) # Top-K x Samples
193 | recompute_mask = np.isinf(costs)
194 |
195 | for kidx, k in enumerate(K):
196 | # Mask only top-k features
197 | PhiMasked = mask_values(Phi, Order, k)
198 |
199 | # Create object to be iterated
200 |             iters = np.array(list(zip(X, Xn, PhiMasked, Directions, Pred))[starting_sample:], dtype=object)
201 | if max_samples is not None:
202 | iters = iters[:max_samples]
203 | if not isinstance(counters, list):
204 | iters = iters[recompute_mask[kidx]]
205 |
206 | if show_progress:
207 | iters = tqdm(
208 | iters,
209 | desc=(f'{desc} - (A)St={action_strategy}/D={action_direction}/Sc={action_scope}/K={k} '
210 |                           f'(C)N={action_cost_normalization}/Agg={action_cost_aggregation}'),
211 | )
212 |
213 | # Iterate over all samples
214 | results = []
215 |
216 | for a, args in enumerate(iters):
217 | # print(a) # TQDM may not be precise
218 | results.append(
219 | inverse_actionabiltiy(
220 | *args,
221 | model=self.model,
222 | multiscaler=self.multiscaler,
223 | normalization=action_cost_normalization,
224 | trees=trees,
225 | ntrees=ntrees,
226 | degree=action_cost_aggregation,
227 | precision=precision,
228 | error=nan_explanation,
229 | ))
230 |
231 | counters_ = [r[0] for r in results]
232 | costs_ = [r[1] for r in results]
233 | if isinstance(counters, list):
234 | counters.append(counters_)
235 | costs.append(costs_)
236 | else:
237 | if recompute_mask[kidx].sum() > 0:
238 | counters[kidx][recompute_mask[kidx]] = np.array(counters_)
239 | costs[kidx][recompute_mask[kidx]] = np.array(costs_)
240 |
241 | # Counterfactuals
242 | if isinstance(counters, list):
243 | counters = np.array(counters) # Top-K x Samples x nb_features
244 | counters = counters.transpose((1, 0, 2)) # Samples x Top-K x nb_features
245 |
246 | # Costs
247 | if isinstance(costs, list):
248 | costs = np.array(costs) # Top-K x Samples
249 | costs = costs.transpose((1, 0)) # Samples x Top-K
250 |
251 | return counters, costs
--------------------------------------------------------------------------------
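A hedged sketch of running the counterfactual-ability evaluation implemented above (illustrative; `model`, `X`, `X_train` and `explainer` are assumed to exist, the model is assumed to be a tree ensemble supported by `get_shap_compatible_model`, and `explainer(X)` is assumed to return an attrdict with `.values` and `.trends`, as required by `transform`):

import numpy as np

explanations = explainer(X)  # assumed: attrdict with .values (attributions) and .trends
generator = TreeInducedCounterfactualGeneratorV2(model=model, data=X_train)
counters, costs = generator.transform(
    X,
    explanations,
    K=(1, 2, np.inf),                # act on the top-1, top-2, ..., all features
    action_strategy='proportional',  # move features proportionally to their attribution
    action_scope='positive',         # only act on features with positive attribution
    action_direction='local',        # follow the per-instance trends of the explanation
)
# counters: nb_samples x len(K) x nb_features, costs: nb_samples x len(K)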
/src/cfshap/evaluation/attribution/induced_counterfactual/_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import warnings
4 | from copy import deepcopy
5 |
6 | import numba
7 | import numpy as np
8 |
9 | __all__ = ['transform_trees', 'mask_values', 'inverse_actionabiltiy']
10 |
11 | EPSILON = np.finfo(float).eps
12 |
13 |
14 | def mask_values(Phi, Order, k):
15 | if np.isinf(k):
16 | k = Phi.shape[1]
17 |
18 | # Select the subset
19 | SubsetCols = Order[:, :k]
20 | SubsetRows = np.tile(np.arange(Phi.shape[0]), (SubsetCols.shape[1], 1)).T
21 |
22 | # Create a mask
23 |     OrderMask = np.zeros(Phi.shape, dtype=int)
24 | OrderMask[SubsetRows, SubsetCols] = 1
25 |
26 | # Apply the mask
27 | PhiMasked = Phi * OrderMask
28 |
29 | return PhiMasked
30 |
31 |
32 | def transform_trees(trees, multiscaler, method):
33 | ntrees = deepcopy(trees)
34 | for i in range(len(ntrees)):
35 | for j in range(len(ntrees[i].features)):
36 | if ntrees[i].children_left[j] != -1: # A split (not a leaf)
37 | ntrees[i].thresholds[j] = multiscaler.value_transform(ntrees[i].thresholds[j], ntrees[i].features[j],
38 | method)
39 | return ntrees
40 |
41 |
42 | def cost_trees(trees, ntrees, xn):
43 | ctrees = deepcopy(trees)
44 | for i in range(len(ctrees)):
45 | ctrees[i].thresholds_c = np.zeros(ctrees[i].thresholds.shape)
46 | for j in range(len(ctrees[i].thresholds_c)):
47 | if ctrees[i].children_left[j] != -1:
48 | ctrees[i].thresholds_c[j] = ntrees[i].thresholds[j] - xn[ctrees[i].features[j]]
49 | return ctrees
50 |
51 |
52 | @numba.jit(nopython=True, nogil=True)
53 | def filter_region(region, x, phi, tau, f, mc, Mc, m, M):
54 |     # Get constraint
55 | t = tau[f]
56 | v = phi[f]
57 |
58 | if v == 0.0:
59 | # Must be 0-cost
60 | if (Mc < 0.0) or (mc > 0.0):
61 | return None
62 | else:
63 | # Must be trend/sign compatible (direction of the change to the feature)
64 | if (t > 0 and Mc <= 0.0) or (t < 0 and mc >= 0.0):
65 | return None
66 |
67 |     # This is a super-rare error caused by np.float32 <> np.float64 conflicts:
68 |     # TreeEnsemble returns np.float32, but the dataset could be np.float64.
69 |     # The multiscaler returns the same dtype it is given (np.float32 or np.float64); it does not change the type.
70 |     # This results in a wrong mc/Mc = thr_n - xn, i.e., sgn(thr_n - xn) != sgn(thr - x)
71 | if (t > 0 and M <= x[f]) or (t < 0 and m >= x[f]):
72 | return None
73 |
74 | if t > 0:
75 | mc = np.maximum(0.0, mc)
76 | m = np.maximum(x[f], m)
77 | if t < 0:
78 | Mc = np.minimum(0.0, Mc)
79 | M = np.minimum(x[f], M)
80 |
81 | assert mc < Mc and m < M
82 |
83 | # Precompute direction constraints
84 | if mc == 0.0 or Mc == 0.0:
85 | # We need this because inf * 0 = Nan (raise error)
86 | vm = 0.0
87 | elif mc * Mc > 0.0:
88 | vm = np.minimum(np.abs(mc), np.abs(Mc))
89 | else:
90 | vm = 0.0
91 | vM = np.maximum(np.abs(mc), np.abs(Mc))
92 | vm, vM = vm / v, vM / v
93 |
94 | for f_ in range(region.shape[0]):
95 | mc_, Mc_, m_, M_ = region[f_][0], region[f_][1], region[f_][2], region[f_][3]
96 | # Must be overlapping
97 | if f == f_:
98 | if M <= m_ or m >= M_:
99 | return None
100 |
101 | # Must be direction-compatible (proportion between features)
102 | v_ = phi[f_]
103 | if v != 0.0 and v_ != 0.0:
104 | if mc_ == 0.0 or Mc_ == 0.0:
105 | # We need this because inf * 0 = Nan (raise error)
106 | vm_ = 0.0
107 | elif mc_ * Mc_ > 0.0:
108 | vm_ = np.minimum(np.abs(mc_), np.abs(Mc_))
109 | else:
110 | vm_ = 0.0
111 | vM_ = np.maximum(np.abs(mc_), np.abs(Mc_))
112 | vm_, vM_ = vm_ / v_, vM_ / v_
113 | if (vM <= vm_ or vm >= vM_):
114 | return None
115 |
116 | nregion = region.copy()
117 | nregion[f][0] = np.maximum(region[f][0], mc)
118 | nregion[f][1] = np.minimum(region[f][1], Mc)
119 | nregion[f][2] = np.maximum(region[f][2], m)
120 | nregion[f][3] = np.minimum(region[f][3], M)
121 | return nregion
122 |
123 |
124 | @numba.jit(nopython=True, nogil=True)
125 | def filter_children_regions(tree_features, tree_children_left, tree_children_right, tree_thresholds, tree_thresholds_c,
126 | region, x, phi, tau, n):
127 |
128 | # Non-compatible: return empty array
129 | if region is None:
130 | return np.zeros((0, x.shape[0], 4))
131 |
132 | # Recurse
133 | if tree_children_left[n] != -1:
134 | f = tree_features[n]
135 | l = tree_children_left[n] # noqa: E741
136 | r = tree_children_right[n]
137 | t = tree_thresholds[n]
138 | tc = tree_thresholds_c[n]
139 |
140 | # print('Region', region)
141 |
142 | regionl = filter_region(region, x, phi, tau, f, -np.inf, tc, -np.inf, t)
143 | regionr = filter_region(region, x, phi, tau, f, tc, np.inf, t, np.inf)
144 |
145 | # print('Region L', regionl)
146 | # print('Region R', regionr)
147 |
148 | return np.concatenate((filter_children_regions(tree_features, tree_children_left, tree_children_right,
149 | tree_thresholds, tree_thresholds_c, regionl, x, phi, tau, l),
150 | filter_children_regions(tree_features, tree_children_left, tree_children_right,
151 | tree_thresholds, tree_thresholds_c, regionr, x, phi, tau, r)))
152 |
153 | # Recurse termination
154 | else:
155 | return np.expand_dims(region, 0)
156 |
157 |
158 | @numba.jit(nopython=True, nogil=True)
159 | def filter_tree_regions(tree_features, tree_children_left, tree_children_right, tree_thresholds, tree_thresholds_c,
160 | regions, x, phi, tau):
161 | cregions = np.zeros((0, x.shape[0], 4))
162 | for region in regions:
163 | filtered_children = filter_children_regions(tree_features, tree_children_left, tree_children_right,
164 | tree_thresholds, tree_thresholds_c, region, x, phi, tau, 0)
165 | cregions = np.concatenate((cregions, filtered_children))
166 | return cregions
167 |
168 |
169 | def filter_regions(trees, x, phi, tau):
170 | regions = np.tile(np.array([-np.inf, np.inf, -np.inf, np.inf]), (1, x.shape[0], 1))
171 | for t, tree in enumerate(trees):
172 | regions = filter_tree_regions(tree.features, tree.children_left, tree.children_right, tree.thresholds,
173 | tree.thresholds_c, regions, x, phi, tau)
174 | # There must be at least a region
175 | assert len(regions) > 0
176 | regions = [{i: tuple(v) for i, v in enumerate(region)} for region in regions]
177 | return np.array(regions)
178 |
179 |
180 | def assert_regions(regions):
181 | for region in regions:
182 | for f, (mc, Mc, m, M) in region.items():
183 | assert mc < Mc and m < M
184 | return regions
185 |
186 |
187 | def cost_in_region(region, cost):
188 | for f, (mc, Mc, _, _) in region.items():
189 | if cost[f] > Mc or cost[f] < mc:
190 | return False
191 | return True
192 |
193 |
194 | def point_in_region(region, point):
195 | for f, (_, _, m, M) in region.items():
196 | if point[f] > M or point[f] < m:
197 | return False
198 | return True
199 |
200 |
201 | def find_region_candidates(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass):
202 | X_costs = []
203 |
204 | # print('>>>>>>>>>>>> REGION', region, '<<<<<<<<<<<<<<<<<<<<')
205 |
206 | # Compute candidates in the region
207 | for f, (mc, Mc, m, M) in region.items():
208 | # print(f, mc, Mc, phi[f], end=' / ')
209 | if phi[f] == 0.0: # => f is not a constraining variable
210 | # print('NC phi=0.0')
211 | continue
212 | if Mc < 0: # => decrease f
213 | delta_star = Mc - precision
214 | if delta_star < mc:
215 | delta_star = (mc + Mc) / 2
216 | elif mc > 0: # => increase f
217 | delta_star = mc + precision
218 | if delta_star > Mc:
219 | delta_star = (mc + Mc) / 2
220 | elif mc <= 0 and Mc >= 0: # keep f the same (not constraining)
221 | # print('NC mc <= 0 and Mc >= 0')
222 | continue
223 | else:
224 | raise AssertionError('This should never happen.')
225 |
226 | # print(delta_star)
227 | # print('Dstar=', delta_star, 'Tau=', tau, '/ phi=', phi)
228 | X_costs.append(tau * phi / phi[f] * np.abs(delta_star))
229 |
230 |     # Means that the region is the current region (i.e., the region where x lies)
231 | if len(X_costs) == 0:
232 | return None, None
233 |
234 | # print('costs prev:', X_costs)
235 |
236 | # Filter points outside the region (in terms of cost)
237 | X_costs = np.array([cost for cost in X_costs if cost_in_region(region, cost)])
238 |
239 | # print('costs post 1:', X_costs)
240 |
241 | if len(X_costs) > 1:
242 | logging.debug(
243 | f'More than one direction-compatible point in the region (cost-space) with precision {precision}.')
244 |
245 | if len(X_costs) == 0:
246 | logging.debug(
247 | f'Less than one direction-compatible point in the region (cost-space) with precision {precision}.')
248 | return X_costs, X_costs # Return empty (both X_p and X_costs)
249 |
250 | # Filter points outside the region (on the input space)
251 | X_p = multiscaler.inverse_transform(np.tile(xn, (X_costs.shape[0], 1)) + X_costs, method=normalization)
252 | X_p = np.where(X_costs == 0.0, np.tile(x, (X_costs.shape[0], 1)), X_p)
253 |
254 | if bypass is True:
255 | return X_p, X_costs
256 |
257 |
258 | # print('X_p post 1:', X_p)
259 |
260 | mask = np.array([point_in_region(region, x) for x in X_p])
261 | X_p = X_p[mask]
262 | X_costs = X_costs[mask]
263 |
264 | # print('costs post 2:', X_costs)
265 | # print('X_p post 2:', X_p)
266 |
267 | if len(X_costs) > 1:
268 | logging.debug(f'More than one direction-compatible point in the region (x-space) with precision {precision}.')
269 |
270 | if len(X_costs) == 0:
271 | logging.debug(f'Less than one direction-compatible point in the region (x-space) with precision {precision}.')
272 |
273 | return X_p, X_costs
274 |
275 |
276 | def _find_point_in_region_with_min_cost(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass):
277 | X_p = np.zeros((0, x.shape[0]))
278 | X_costs = np.zeros((0, x.shape[0]))
279 |
280 | # There are 0 or more than 1 candidates with this precision
281 | while len(X_costs) != 1 and precision > EPSILON:
282 | # print(f'Precision {precision}')
283 | X_p_, X_costs_ = find_region_candidates(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass)
284 | precision = precision / 2
285 |
286 |         # Means that the region is the current region (the one where x lies)
287 | if X_costs_ is None:
288 | return np.array([x]), np.zeros((1, x.shape[0]))
289 |
290 | # More than one with the same cost and its fine
291 | if len(X_costs_) == 0 and len(X_costs) > 1:
292 | break
293 |
294 | X_costs = X_costs_
295 | X_p = X_p_
296 | return X_p, X_costs
297 |
298 |
299 | def find_point_in_region_with_min_cost(region, x, xn, phi, tau, degree, multiscaler, normalization, precision):
300 | X_p, X_costs = _find_point_in_region_with_min_cost(region,
301 | x,
302 | xn,
303 | phi,
304 | tau,
305 | multiscaler,
306 | normalization,
307 | precision,
308 | bypass=False)
309 |
310 | if len(X_costs) == 0:
311 | logging.warning(
312 | f'Less than one direction-compatible point in the region with precision {precision}/{EPSILON}: point will be (slightly) out of the region.'
313 | )
314 | X_p, X_costs = _find_point_in_region_with_min_cost(region,
315 | x,
316 | xn,
317 | phi,
318 | tau,
319 | multiscaler,
320 | normalization,
321 | precision,
322 | bypass=True)
323 |
324 | if len(X_costs) == 0:
325 | logging.error(
326 | f'No solution. Less than one direction-compatible point in the region with precision {precision}/{EPSILON}.'
327 | )
328 | raise RuntimeError()
329 |
330 | # Compute costs
331 | costs = np.linalg.norm(np.abs(np.array(X_costs)), degree, axis=1)
332 | index_min = np.argmin(costs)
333 | # costs_min = X_costs[index_min]
334 | cost_min = costs[index_min]
335 | x_min = X_p[index_min]
336 | # x_min = multiscaler.single_inverse_transform(xn + costs_min, method=normalization)
337 | # x_min = np.where(cost_min == 0.0, x, x_min)
338 |
339 | # print('x/c', x_min, cost_min)
340 |
341 |     # quantile_overflow handling: returns None when the cost exceeds 1.0 or is less than 0.0
342 | if np.any(np.isnan(x_min)):
343 | return None, np.inf
344 |
345 | if not point_in_region(region, x_min):
346 | logging.warning(f"Point not in region: x = {x_min} / Region = {region}")
347 |
348 | return x_min, cost_min
349 |
350 |
351 | def find_points_in_regions(regions, x, xn, phi, tau, degree, multiscaler, normalization, precision):
352 | X_prime, dist_prime = [], []
353 | for region in regions:
354 | x_min, cost_min = find_point_in_region_with_min_cost(region, x, xn, phi, tau, degree, multiscaler,
355 | normalization, precision)
356 | if x_min is not None:
357 | X_prime.append(x_min)
358 | dist_prime.append(cost_min)
359 | return np.array(X_prime), np.array(dist_prime)
360 |
361 |
362 | def inverse_actionabiltiy(
363 | x,
364 | xn,
365 | phi,
366 | tau,
367 | y_pred,
368 | model,
369 | multiscaler,
370 | trees,
371 | ntrees,
372 | degree,
373 | normalization,
374 | precision,
375 | error,
376 | ):
377 |
378 | x_counter, cost_counter = np.full(x.shape, np.nan), np.inf
379 |
380 | if np.any(np.isnan(phi)):
381 | if error == 'raise':
382 | raise RuntimeError('NaN feature attribution.')
383 | elif error == 'warning':
384 | warnings.warn('NaN feature attribution.')
385 | elif error == 'ignore':
386 | pass
387 | else:
388 | raise ValueError('Invalid error argument.')
389 | return x_counter, cost_counter
390 |
391 | # Compute costs in trees
392 | ctrees = cost_trees(trees, ntrees, xn)
393 |
394 | # Find regions (of the input space) that are phi-compatible
395 | regions = filter_regions(ctrees, x, phi=phi, tau=tau)
396 | assert_regions(regions)
397 |
398 |     # Find each region's counterfactual and its cost
399 | X_prime, costs_prime = find_points_in_regions(
400 | regions,
401 | x,
402 | xn,
403 | phi,
404 | tau,
405 | degree=degree,
406 | multiscaler=multiscaler,
407 | normalization=normalization,
408 | precision=precision,
409 | )
410 |
411 | # Order the regions based on the cost
412 | costs_sort = np.argsort(costs_prime)
413 | X_prime, costs_prime = X_prime[costs_sort], costs_prime[costs_sort]
414 |
415 | # Generate the prediction points corresponding to regions of the input
416 | start = time.perf_counter()
417 | y_prime = model.predict(X_prime)
418 | logging.debug(f'Prediction of X_prime.shape = {X_prime.shape} took {time.perf_counter()-start}')
419 |
420 |     # Find the first occurrence of a change in output (with minimum cost)
421 | for xp, yp, cp in zip(X_prime, y_prime, costs_prime):
422 | if yp != y_pred:
423 | x_counter = xp
424 | cost_counter = cp
425 | break
426 |
427 | return x_counter, cost_counter
428 |
--------------------------------------------------------------------------------
/src/cfshap/evaluation/counterfactuals/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jpmorganchase/cf-shap/ba58c61d4e110a950808d18dbe85b67af862549b/src/cfshap/evaluation/counterfactuals/__init__.py
--------------------------------------------------------------------------------
/src/cfshap/evaluation/counterfactuals/plausibility.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | from sklearn.neighbors import NearestNeighbors
5 |
6 | from ...base import CounterfactualEvaluationScorer, BaseCounterfactualEvaluationScorer
7 | from ...utils import keydefaultdict
8 | from ...utils.random import np_sample
9 |
10 |
11 | class BaseNNDistance(CounterfactualEvaluationScorer, BaseCounterfactualEvaluationScorer):
12 | def __init__(
13 | self,
14 | scaler,
15 | X,
16 | distance,
17 | n_neighbors: Union[int, float] = 5,
18 | max_samples=int(1e10),
19 | random_state=2021,
20 | **distance_params,
21 | ):
22 |
23 | self._scaler = scaler
24 | self._metric, self._metric_params = distance, distance_params
25 | self._n_neighbors = n_neighbors
26 | self._max_samples = max_samples
27 | self.random_state = random_state
28 | self.data = X
29 |
30 | @property
31 | def data(self):
32 | return self._data
33 |
34 |
35 | class NNDistance(BaseNNDistance):
36 | """
37 | Plausibility metrics for counterfactuals.
38 |
39 | It computes the (average) distance from the K Nearest Neighbours of the point
40 |
41 |     Note it considers all the points (regardless of their class) as neighbours.
42 | """
43 | def __init__(
44 | self,
45 | *args,
46 | **kwargs,
47 | ):
48 | super().__init__(*args, **kwargs)
49 |
50 | @BaseNNDistance.data.setter
51 | def data(self, data):
52 |
53 | # Sample
54 | self._data = np_sample(np.asarray(data), random_state=self.random_state, n=self._max_samples, safe=True)
55 |
56 | if isinstance(self._n_neighbors, int) and self._n_neighbors >= 1:
57 | n_neighbors = min(self._n_neighbors, len(self._data))
58 | elif isinstance(self._n_neighbors, float) and self._n_neighbors <= 1.0 and self._n_neighbors > 0.0:
59 | n_neighbors = int(max(1, round(len(self._data) * self._n_neighbors)))
60 | else:
61 | raise ValueError(
62 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)')
63 |
64 | # We will be searching neighbors
65 | self._nn = NearestNeighbors(
66 | n_neighbors=n_neighbors,
67 | metric=self._metric,
68 | p=self._metric_params['p'] if 'p' in self._metric_params else 2,
69 | metric_params=self._metric_params,
70 | ).fit(self._scaler.transform(self._data))
71 |
72 | def score(self, X: np.ndarray):
73 | X = np.asarray(X)
74 |
75 | avg_dist = np.full(X.shape[0], np.nan)
76 | nan_mask = np.any(np.isnan(X), axis=1)
77 |
78 | if (~nan_mask).sum() > 0:
79 | neigh_dist, _ = self._nn.kneighbors(self._scaler.transform(X[~nan_mask]), n_neighbors=None)
80 | neigh_dist = neigh_dist.mean(axis=1)
81 | avg_dist[~nan_mask] = neigh_dist
82 |
83 | return avg_dist
84 |
85 |
86 | class yNNDistance(BaseNNDistance):
87 | """
88 | Plausibility metrics for counterfactuals.
89 |
90 | It computes the (average) distance from the K Nearest COUNTERFACTUAL Neighbours of the point
91 |
92 | Contrary to NNDistance, it considers as neighbours only points that have a different class.
93 | """
94 | def __init__(self, model, *args, **kwargs):
95 | self.model = model
96 | super().__init__(*args, **kwargs)
97 |
98 | @BaseNNDistance.data.setter
99 | def data(self, data):
100 | self._raw_data = np.asarray(data)
101 |
102 | # Sample
103 | self._raw_data = np_sample(self._raw_data, random_state=self.random_state, n=self._max_samples, safe=True)
104 |
105 | # Predict
106 | self._preds = self.model.predict(self._raw_data)
107 |
108 | if isinstance(self._n_neighbors, int) and self._n_neighbors >= 1:
109 | n_neighbors = keydefaultdict(lambda pred: min(self._n_neighbors, len(self._data[pred])))
110 | elif isinstance(self._n_neighbors, float) and self._n_neighbors <= 1.0 and self._n_neighbors > 0.0:
111 | n_neighbors = keydefaultdict(lambda pred: int(max(1, round(len(self._data[pred]) * self._n_neighbors))))
112 | else:
113 | raise ValueError(
114 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)')
115 |
116 |         # We will be searching neighbours with the same predicted class as the point being scored (the counterfactual)
117 | self._data = keydefaultdict(lambda pred: self._raw_data[self._preds == pred])
118 |
119 | self._nn = keydefaultdict(lambda pred: NearestNeighbors(
120 | n_neighbors=n_neighbors[pred],
121 | metric=self._metric,
122 | p=self._metric_params['p'] if 'p' in self._metric_params else 2,
123 | metric_params=self._metric_params,
124 | ).fit(self._scaler.transform(self._data[pred])))
125 |
126 | def score(self, X: np.ndarray):
127 | X = np.asarray(X)
128 |
129 | nan_mask = np.any(np.isnan(X), axis=1)
130 |
131 | preds = self.model.predict(X[~nan_mask])
132 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)}
133 | avg_dist_ = np.full(preds.shape[0], np.nan)
134 |
135 | for pred, indices in preds_indices.items():
136 | neigh_dist, _ = self._nn[pred].kneighbors(self._scaler.transform(X[~nan_mask][indices]), n_neighbors=None)
137 | neigh_dist = neigh_dist.mean(axis=1)
138 | avg_dist_[preds_indices[pred]] = neigh_dist
139 |
140 | avg_dist = np.full(X.shape[0], np.nan)
141 | avg_dist[~nan_mask] = avg_dist_
142 | return avg_dist
143 |
--------------------------------------------------------------------------------
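An illustrative sketch of scoring counterfactual plausibility with `NNDistance` above (`X_train` and a 2D array of counterfactuals `X_cf` are assumed; a plain scikit-learn StandardScaler is used here since the scorer only relies on its `transform` method):

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(X_train)
scorer = NNDistance(scaler, X_train, distance='cityblock', n_neighbors=5)
scores = scorer.score(X_cf)  # lower average distance to the 5 nearest neighbours => more plausible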
/src/cfshap/trend.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Feature Trend Estimators
5 | """
6 |
7 | import numpy as np
8 | from sklearn.base import BaseEstimator
9 |
10 | __all__ = ['TrendEstimator', 'DummyTrendEstimator']
11 |
12 |
13 | class DummyTrendEstimator(BaseEstimator):
14 | def __init__(self, trends):
15 | self.trends = trends
16 |
17 | def predict(self, x=None, Y=None):
18 | return np.asarray(self.trends)
19 |
20 |
21 | class TrendEstimator(BaseEstimator):
22 | def __init__(self, strategy='mean'):
23 | self.strategy = strategy
24 |
25 | @staticmethod
26 | def __step_function(x, y):
27 | return 2 * np.heaviside(y - x, 0) - 1
28 |
29 | def __preprocess(self, x, Y):
30 | if not isinstance(x, np.ndarray):
31 | raise ValueError('Must pass a NumPy array.')
32 |
33 | if not isinstance(Y, np.ndarray):
34 | raise ValueError('Must pass a NumPy array.')
35 |
36 | if len(Y.shape) != 2:
37 | raise ValueError('Y must be a 2D matrix.')
38 |
39 | if len(x.shape) != 1:
40 | x = x.flatten()
41 |
42 | if x.shape[0] != Y.shape[1]:
43 | raise ValueError('x and Y sizes must be coherent.')
44 |
45 | return x, Y
46 |
47 | def __preprocess3D(self, X, YY):
48 | if not isinstance(X, np.ndarray):
49 | raise ValueError('Must pass a NumPy array.')
50 |
51 | if not isinstance(YY[0], np.ndarray):
52 | raise ValueError('Must pass a list of NumPy array.')
53 |
54 | if len(X.shape) != 2:
55 | raise ValueError('X must be a 2D matrix.')
56 |
57 | if len(X) != len(YY):
58 | raise ValueError('X and Y must have the same length.')
59 |
60 | if X.shape[1] != YY[0].shape[1]:
61 | raise ValueError('X and Y[0] sizes must be coherent.')
62 |
63 | return X, YY
64 |
65 | def __call__(self, x, Y):
66 | x, Y = self.__preprocess(x, Y)
67 |
68 | if self.strategy == 'mean':
69 | return self.__step_function(x, Y.mean(axis=0))
70 | else:
71 | raise ValueError('Invalid strategy.')
72 |
73 | # NOTE: Duplicated for efficiency on 3D data
74 | def predict(self, X, YY):
75 | X, YY = self.__preprocess3D(X, YY)
76 |
77 | if self.strategy == 'mean':
78 | return np.array([self.__step_function(x, Y.mean(axis=0)) for x, Y in zip(X, YY)])
79 | else:
80 | raise ValueError('Invalid strategy.')
81 |
--------------------------------------------------------------------------------
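A small illustrative example of `TrendEstimator` above (toy numbers, not from the repository): the trend of a feature is +1 when the mean of the reference points `Y` lies above the query value, and -1 otherwise.

import numpy as np
from cfshap.trend import TrendEstimator

estimator = TrendEstimator(strategy='mean')
x = np.array([1.0, 5.0])               # query instance
Y = np.array([[2.0, 4.0],              # reference points (e.g., counterfactuals); mean(Y) = [2.5, 3.5]
              [3.0, 3.0]])
print(estimator(x, Y))                 # [ 1. -1.]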
/src/cfshap/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from ._python import *
2 | from ._utils import *
3 |
--------------------------------------------------------------------------------
/src/cfshap/utils/_python.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | This module contains general utilities for Python programming language.
5 | - Extension and improvement of Python data structures (e.g., attrdict)
6 | - Extension of Python language (e.g., static variables for functions, function caching)
7 | - Implementations of design patterns in Python (e.g., Singleton)
8 | """
9 |
10 | from abc import ABCMeta
11 | from collections import defaultdict
12 | from collections.abc import Mapping
13 | from pprint import pformat
14 | import json
15 |
16 | __all__ = [
17 | 'static_vars',
18 | 'keydefaultdict',
19 | 'ignorenonedict',
20 | 'attrdict',
21 | 'function_cache',
22 | 'Singleton',
23 | ]
24 |
25 |
26 | def static_vars(**kwargs):
27 | """Decorator to create static variables for a function
28 |
29 | Usage:
30 | ```python
31 | @static_vars(i = 0)
32 | def function(x):
33 |         function.i += 1
34 |         return x + function.i
35 |
36 | function(0) # 1
37 | function(0) # 2
38 | function(10) # 13
39 | ```
40 | """
41 | def decorate(func):
42 | for k in kwargs:
43 | setattr(func, k, kwargs[k])
44 | return func
45 |
46 | return decorate
47 |
48 |
49 | class keydefaultdict(defaultdict):
50 | """
51 |     Extension of defaultdict that supports
52 | passing the key to the default_factory
53 | """
54 | def __missing__(self, key):
55 | if self.default_factory is None:
56 | raise KeyError(key, "Must pass a default factory with a single argument.")
57 | else:
58 | ret = self[key] = self.default_factory(key)
59 | return ret
60 |
61 |
62 | class ignorenonedict(dict):
63 | def __init__(self, other=None, **kwargs):
64 | super().__init__()
65 | self.update(other, **kwargs)
66 |
67 | def __setitem__(self, key, value):
68 | if value is not None:
69 | super().__setitem__(key, value)
70 |
71 | def update(self, other=None, **kwargs):
72 | if other is not None:
73 | for k, v in other.items() if isinstance(other, Mapping) else other:
74 | self.__setitem__(k, v)
75 | for k, v in kwargs.items():
76 | self.__setitem__(k, v)
77 |
78 |
79 | class attrdict(dict):
80 | """
81 |     Attribute-dict bound structure for parameters
82 | -> When a dictionary key is set the corresponding attribute is set
83 | -> When an attribute is set the corresponding dictionary key is set
84 |
85 | Usage:
86 |
87 | # Create the object
88 |     args = attrdict()
89 |
90 | args.a = 1
91 | print(args.a) # 1
92 | print(args['a']) # 1
93 |
94 | args['b'] = 2
95 | print(args.b) # 2
96 | print(args['b']) # 2
97 |
98 | """
99 | def __init__(self, *args, **kwargs):
100 | super(attrdict, self).__init__(*args, **kwargs)
101 | self.__dict__ = self
102 |
103 | def repr(self):
104 | return dict(self)
105 |
106 | def __repr__(self):
107 | return pformat(self.repr())
108 |
109 | def __str__(self):
110 | return self.__repr__()
111 |
112 | def update_defaults(self, d: dict):
113 | for k, v in d.items():
114 | self.setdefault(k, v)
115 |
116 | def save_json(self, file_name):
117 | with open(file_name, 'w') as fp:
118 | json.dump(self.repr(), fp)
119 |
120 | def copy(self):
121 | return type(self)(self)
122 |
123 |
124 | class FunctionCache(dict):
125 | def __init__(self, f):
126 | self.f = f
127 |
128 | def get(self, *args, **kwargs):
129 | key = hash(str(args)) + hash(str(kwargs))
130 | if key in self:
131 | return self[key]
132 | else:
133 | ret = self[key] = self.f(*args, **kwargs)
134 | return ret
135 |
136 |
137 | # Decorator
138 | def function_cache(f, name='cache'):
139 |     """Decorator that attaches a cache of an expensive function (of the arguments) to the decorated function.
140 | 
141 |     Usage Example:
142 | 
143 |     @function_cache(lambda X: expensive_function(X))
144 |     @function_cache(lambda X: expensive_function2(X), name = 'second_cache')
145 |     def f(X, y):
146 |         return y + f.cache.get(X) + f.second_cache.get(X)
147 | 
148 | 
149 |     X, Y = ..., ...
150 |     for y in Y:
151 |         f(X, y)  # expensive_function(X) and expensive_function2(X) are computed only once
152 | 
153 | 
154 | 
155 |     """
156 | def decorate(func):
157 | setattr(func, name, FunctionCache(f))
158 | return func
159 |
160 | return decorate
161 |
162 |
163 | class cached_property(object):
164 | def __init__(self, func):
165 | self.__doc__ = getattr(func, "__doc__")
166 |
167 | self.func = func
168 |
169 | def __get__(self, obj, cls):
170 | if obj is None:
171 | return self
172 |
173 | value = obj.__dict__[self.func.__name__] = self.func(obj)
174 |
175 | return value
176 |
177 |
178 | class Singleton(ABCMeta):
179 | """
180 | Singleton META-CLASS from https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python
181 |
182 | Usage Example:
183 | class Logger(metaclass=Singleton):
184 | pass
185 |
186 | What it does:
187 |     When you call Logger(), Python first asks the metaclass of Logger, Singleton, what to do,
188 |     allowing instance creation to be pre-empted. This is analogous to Python asking a class what to do
189 |     by calling __getattr__ when you reference one of its attributes with myclass.attribute.
190 |
191 | """
192 | _instances = {}
193 |
194 | def __call__(cls, *args, **kwargs):
195 | if cls not in cls._instances:
196 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
197 | return cls._instances[cls]
--------------------------------------------------------------------------------
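A minimal usage sketch for `keydefaultdict` and `attrdict` from the module above (illustrative only, not part of the repository; it assumes the package is installed as `cfshap`):

```python
from cfshap.utils import keydefaultdict, attrdict

# keydefaultdict: the default factory receives the missing key itself
squares = keydefaultdict(lambda k: k * k)
print(squares[4])           # 16 (computed and stored on first access)

# attrdict: attribute access and item access stay in sync
args = attrdict(a=1)
args.b = 2
print(args['a'], args.b)    # 1 2
```
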
/src/cfshap/utils/_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | This module implements some general utilities for explanations
5 | """
6 |
7 | import numpy as np
8 |
9 | from ..base import Model, XGBWrapping, ListOf2DArrays
10 |
11 | __all__ = [
12 | 'get_shap_compatible_model',
13 | 'get_top_counterfactuals',
14 | 'expand_dims_counterfactuals',
15 | 'get_top_counterfactual_costs',
16 | ]
17 |
18 |
19 | def get_top_counterfactuals(
20 | XX_C: ListOf2DArrays,
21 | X: np.ndarray,
22 | n_top: int = 1,
23 | nan: bool = True,
24 | return_2d=True,
25 | ):
26 | """Extracts the top-x counterfactuals from the input
27 |
28 | Args:
29 | XX_C (ListOf2DArrays): List of arrays (for each sample) of counterfactuals (multiple)
30 | X (np.ndarray): The query instances/samples
31 | n_top (int, optional): Number of counterfactuals to extract. Defaults to 1.
32 |         nan (bool, optional): If True, pad missing counterfactuals with NaN; if False, drop them. Defaults to True.
33 | return_2d (bool, optional): Return a 2D array if n_top = 1. Defaults to True.
34 |
35 | Returns:
36 | ListOf2DArrays: The top-counterfactuals
37 | If n_top = 1 and return_2d = True:
38 | nb_samples x nb_features
39 |                 Note: if nan = False, nb_samples may be less than X.shape[0]
40 | Else:
41 | nb_samples (List) x nb_counterfactuals x nb_features
42 | Note: if nan = False, nb_counterfactuals may be less than n_top
43 |
44 | """
45 | assert X.shape[0] == len(XX_C)
46 |
47 | if n_top == 1 and return_2d is True:
48 | XC = np.full(X.shape, np.nan)
49 | for i, XC_ in enumerate(XX_C):
50 | if XC_.shape[0] > 0:
51 | XC[i] = XC_[0]
52 | if nan is False:
53 |             XC = XC[~np.any(np.isnan(XC), axis=1)]
54 | return XC
55 | else:
56 | if nan is True:
57 |             XX_C_top = [np.full((n_top, X.shape[1]), np.nan) for _ in range(len(X))]
58 | for i, XC_ in enumerate(XX_C):
59 | if XC_.shape[0] > 0:
60 | n_xc = min(n_top, XC_.shape[0])
61 | XX_C_top[i][:n_xc] = XC_[:n_xc]
62 | return XX_C_top
63 | else:
64 | XX_C = [XC[~np.any(np.isnan(XC), axis=1)] for XC in XX_C]
65 | return [XC[:n_top] for XC in XX_C]
66 |
67 |
68 | def expand_dims_counterfactuals(X):
69 | """Expand dimensions of a list of counterfactuals
70 |
71 | Args:
72 | X (np.ndarray): The counterfactuals
73 |
74 | Returns:
75 | List[np.ndarray]: A list of arrays of counterfactuals. All arrays will have length 1.
76 |
77 | """
78 |
79 | if isinstance(X, np.ndarray):
80 |         return [x.reshape(1, -1) for x in X]
81 | else:
82 | return [np.array([x]) for x in X]
83 |
84 |
85 | def get_top_counterfactual_costs(costs: ListOf2DArrays, X: np.ndarray, nan: bool = True):
86 | """Extract the costs of the top-x counterfactuals
87 |
88 | Pre-condition: Counterfactuals costs for each sample must be ordered (ascending).
89 |
90 | Args:
91 | costs (ListOf2DArrays): List of arrays (for each sample) of counterfactuals costs (multiple).
92 | X (np.ndarray): The query instances/samples
93 | nan (bool, optional): If True, return NaN cost if no counterfactual exists, skips the sample otherwise. Defaults to True.
94 |
95 | Returns:
96 | np.ndarray: Array of the costs of the best counterfactuals
97 | """
98 | if nan:
99 | return np.array([c[0] if len(c) > 0 else np.nan for c in costs])
100 | else:
101 | raise NotImplementedError()
102 |
103 |
104 | def get_shap_compatible_model(model: Model):
105 | """Get the SHAP-compatible model
106 |
107 | Args:
108 | model (Model): A model
109 |
110 | Raises:
111 | ValueError: If the model is not supported
112 |
113 | Returns:
114 | object: SHAP-compatible model.
115 | """
116 | # Extract XGBoost-wrapped models
117 | if isinstance(model, XGBWrapping):
118 | model = model.get_booster()
119 |
120 | # Explicit error handling
121 | if not (type(model).__module__ == 'xgboost.core' and type(model).__name__ == 'Booster'):
122 |         raise ValueError("Only XGBoost Booster models (or XGBWrapping wrappers around them) are supported.")
123 |
124 | return model
125 |
--------------------------------------------------------------------------------
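A minimal sketch of `get_top_counterfactuals` on hypothetical data (illustrative only; the arrays below are made up):

```python
import numpy as np
from cfshap.utils import get_top_counterfactuals

X = np.zeros((3, 2))                      # three query points with two features
XX_C = [
    np.array([[1.0, 1.0], [2.0, 2.0]]),   # two counterfactuals for sample 0
    np.empty((0, 2)),                     # no counterfactual found for sample 1
    np.array([[3.0, 3.0]]),               # one counterfactual for sample 2
]

top = get_top_counterfactuals(XX_C, X)    # 3 x 2 array; the row for sample 1 is NaN
print(top)
```
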
/src/cfshap/utils/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Parallelization utilities.
5 | """
6 |
7 | from .batch import *
8 | from .utils import *
9 |
--------------------------------------------------------------------------------
/src/cfshap/utils/parallel/batch.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Parallelization Utilities for batch processing with JobLib.
5 | """
6 |
7 | import logging
8 | from typing import Union
9 | from tqdm import tqdm
10 |
11 | from .utils import Element, split, join, max_cpu_count
12 |
13 | __all__ = ['batch_process']
14 |
15 |
16 | def _parallel_joblib(f, splits, n_jobs):
17 | # pylint: disable=import-outside-toplevel
18 | from joblib import Parallel, delayed
19 | # pylint: enable=import-outside-toplevel
20 | return Parallel(n_jobs=n_jobs)(delayed(f)(s) for s in splits)
21 |
22 |
23 | def batch_process(
24 | f,
25 | iterable: Element,
26 | split_size: int = None,
27 | batch_size: int = None,
28 | desc='Batches',
29 | verbose=1,
30 | n_jobs: Union[None, int, str] = None,
31 | ) -> Element:
32 | '''
33 |     Batch-processing for iterables that support slicing (lists, NumPy arrays, DataFrames).
34 | 
35 |     f : function applied to each batch
36 |     iterable : An iterable to be batch-processed (list, NumPy array or DataFrame)
37 |     split_size/batch_size : maximum size of each batch (pass only one of the two)
38 | '''
39 |
40 | assert split_size is None or batch_size is None
41 | if split_size is None:
42 | split_size = batch_size
43 |
44 | if n_jobs is None:
45 | n_jobs = 1
46 | if n_jobs == 'auto':
47 |         n_jobs = max(1, max_cpu_count() - 1)
48 |
49 | # Split
50 | if verbose > 1:
51 | logging.info('Splitting the iterable for batch process...')
52 | splits = split(iterable, split_size)
53 | if verbose > 1:
54 | logging.info('Splitting done.')
55 |
56 | # Attach TQDM
57 | if desc is not None and desc is not False and verbose > 0:
58 | splits = tqdm(splits, desc=str(desc))
59 |
60 | # Batch process
61 | if n_jobs > 1 and len(iterable) > split_size:
62 | logging.info(f'Using joblib with {n_jobs} jobs')
63 | results = _parallel_joblib(f, splits, n_jobs)
64 | else:
65 | results = [f(s) for s in splits]
66 |
67 | # Join
68 | if verbose > 1:
69 | logging.info('Joining the iterable after batch process...')
70 | results = join(results)
71 | if verbose > 1:
72 | logging.info('Joining done.')
73 |
74 | # Return
75 | return results
--------------------------------------------------------------------------------
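A minimal sketch of `batch_process` (illustrative only; the toy function and data are made up):

```python
import numpy as np
from cfshap.utils.parallel import batch_process

data = np.arange(12)

# The function receives one batch at a time; the batch results are joined back together
result = batch_process(lambda batch: batch * 2, data, batch_size=5, verbose=0)
print(result)   # equivalent to data * 2
```
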
/src/cfshap/utils/parallel/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | General utilities needed when parallelizing (e.g., determining number of CPUs, split an array)
5 | """
6 |
7 | from typing import Union, List, Tuple
8 | import itertools
9 | import multiprocessing
10 |
11 | import pandas as pd
12 | import numpy as np
13 |
14 | __all__ = [
15 | 'max_cpu_count',
16 | ]
17 |
18 | Element = Union[list, np.ndarray, pd.DataFrame]
19 | Container = Union[List[Element], np.ndarray]
20 |
21 |
22 | def max_cpu_count(reserved=1):
23 | count = multiprocessing.cpu_count()
24 | return max(1, count - reserved)
25 |
26 |
27 | def nb_splits(iterable: Element, split_size: int) -> int:
28 | return int(np.ceil(len(iterable) / split_size))
29 |
30 |
31 | def split_indexes(iterable: Element, split_size: int) -> List[Tuple[int, int]]:
32 | nb = nb_splits(iterable, split_size)
33 | split_size = int(np.ceil(len(iterable) / nb)) # Equalize sizes of the splits
34 | return [(split_id * split_size, (split_id + 1) * split_size) for split_id in range(0, nb)]
35 |
36 |
37 | def split(iterable: Element, split_size) -> Container:
38 | """
39 |     iterable : The object that should be split into pieces
40 | split_size: Maximum split size (the size is equalized among all splits)
41 | """
42 | if isinstance(iterable, (np.ndarray, pd.DataFrame)):
43 | return np.array_split(iterable, nb_splits(iterable, split_size))
44 | else:
45 | return [iterable[a:b] for a, b in split_indexes(iterable, split_size)]
46 |
47 |
48 | def join(iterable: Container, axis: int = 0) -> Element:
49 | # Get rid of iterators
50 | iterable = list(iterable)
51 |
52 | assert len(set((type(a) for a in iterable))) == 1, "Expecting a non-empty iterable of objects of the same type."
53 |
54 | if isinstance(iterable[0], pd.DataFrame):
55 | return pd.concat(iterable, axis=axis)
56 | elif isinstance(iterable[0], np.ndarray):
57 | return np.concatenate(iterable, axis=axis)
58 | elif isinstance(iterable[0], list):
59 | return list(itertools.chain.from_iterable(iterable))
60 | else:
61 | raise TypeError('Can handle only pandas, numpy and lists. No iterators.')
62 |
--------------------------------------------------------------------------------
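A minimal sketch of `split` and `join` (illustrative only):

```python
import numpy as np
from cfshap.utils.parallel.utils import split, join

chunks = split(np.arange(10), 4)    # ceil(10 / 4) = 3 chunks of roughly equal size
print([len(c) for c in chunks])     # [4, 3, 3]
print(join(chunks))                 # back to a single array of length 10
```
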
/src/cfshap/utils/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | ML / Data Science Pre-processing Utilities
5 | """
6 |
7 | from .scaling import *
8 |
--------------------------------------------------------------------------------
/src/cfshap/utils/preprocessing/scaling/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantiletransformer import *
2 | from .madscaler import *
3 | from .multiscaler import *
--------------------------------------------------------------------------------
/src/cfshap/utils/preprocessing/scaling/madscaler.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements the (Mean/Median) Absolute Deviation Scalers
3 |
4 | NOTE: The acronym MAD (Mean/Median Absolute Deviation) has several different meanings.
5 | This module implements all four possible alternatives:
6 | - Mean Absolute Deviation Scaler from the Mean
7 | - Median Absolute Deviation Scaler from the Median
8 | - Mean Absolute Deviation Scaler from the Median
9 | - Median Absolute Deviation Scaler from the Mean
10 |
11 | """
12 |
13 | __all__ = [
14 | 'MeanAbsoluteDeviationFromMeanScaler',
15 | 'MeanAbsoluteDeviationFromMedianScaler',
16 | 'MedianAbsoluteDeviationFromMeanScaler',
17 | 'MedianAbsoluteDeviationFromMedianScaler',
18 | 'MedianAbsoluteDeviationScaler',
19 | 'MeanAbsoluteDeviationScaler',
20 | ]
21 | __author__ = 'Emanuele Albini'
22 |
23 | from abc import ABC, abstractmethod
24 | import numpy as np
25 | from scipy import sparse
26 | from scipy import stats
27 | from sklearn.base import BaseEstimator, TransformerMixin
28 | from sklearn.utils.validation import (check_is_fitted, FLOAT_DTYPES)
29 | from sklearn.utils.sparsefuncs import inplace_column_scale
30 | from sklearn.utils import check_array
31 |
32 |
33 | def mean_absolute_deviation(data, axis=None):
34 | return np.mean(np.absolute(data - np.mean(data, axis)), axis)
35 |
36 |
37 | def mean_absolute_deviation_from_median(data, axis=None):
38 | return np.mean(np.absolute(data - np.median(data, axis)), axis)
39 |
40 |
41 | def median_absolute_deviation_from_mean(data, axis=None):
42 | return np.median(np.absolute(data - np.mean(data, axis)), axis)
43 |
44 |
45 | class AbsoluteDeviationScaler(TransformerMixin, BaseEstimator, ABC):
46 | """
47 | This class is the interface and base class for the Mean/Median Absolute Deviation Scalers.
48 | It implements the scikit-learn API for scalers.
49 |
50 | This class is based on scikit-learn StandardScaler and RobustScaler.
51 | """
52 | def __init__(self, *, copy=True, with_centering=True, with_scaling=True):
53 | self.with_centering = with_centering
54 | self.with_scaling = with_scaling
55 | self.copy = copy
56 |
57 | @abstractmethod
58 | def _center(self, X):
59 | pass
60 |
61 | @abstractmethod
62 | def _scale(self, X):
63 | pass
64 |
65 | def _check_inputs(self, X):
66 | try:
67 | X = self._validate_data(X,
68 | accept_sparse='csc',
69 | estimator=self,
70 | dtype=FLOAT_DTYPES,
71 | force_all_finite='allow-nan')
72 | except AttributeError:
73 | X = check_array(X,
74 | accept_sparse='csr',
75 | copy=self.copy,
76 | estimator=self,
77 | dtype=FLOAT_DTYPES,
78 | force_all_finite='allow-nan')
79 | return X
80 |
81 | def fit(self, X, y=None):
82 | X = self._check_inputs(X)
83 |
84 | if self.with_centering:
85 | if sparse.issparse(X):
86 | raise ValueError("Cannot center sparse matrices: use `with_centering=False`"
87 | " instead. See docstring for motivation and alternatives.")
88 | self.center_ = self._center(X)
89 | else:
90 | self.center_ = None
91 |
92 | if self.with_scaling:
93 | self.scale_ = self._scale(X)
94 | else:
95 | self.scale_ = None
96 |
97 | return self
98 |
99 | def transform(self, X):
100 | """Center and scale the data.
101 | Parameters
102 | ----------
103 | X : {array-like, sparse matrix} of shape (n_samples, n_features)
104 | The data used to scale along the specified axis.
105 | Returns
106 | -------
107 | X_tr : {ndarray, sparse matrix} of shape (n_samples, n_features)
108 | Transformed array.
109 | """
110 | check_is_fitted(self, ['center_', 'scale_'])
111 | X = self._check_inputs(X)
112 |
113 | if sparse.issparse(X):
114 | if self.with_scaling:
115 | inplace_column_scale(X, 1.0 / self.scale_)
116 | else:
117 | if self.with_centering:
118 | X -= self.center_
119 | if self.with_scaling:
120 | X /= self.scale_
121 | return X
122 |
123 | def inverse_transform(self, X):
124 | """Scale back the data to the original representation
125 | Parameters
126 | ----------
127 | X : {array-like, sparse matrix} of shape (n_samples, n_features)
128 | The rescaled data to be transformed back.
129 | Returns
130 | -------
131 | X_tr : {ndarray, sparse matrix} of shape (n_samples, n_features)
132 | Transformed array.
133 | """
134 | check_is_fitted(self, ['center_', 'scale_'])
135 | X = check_array(X,
136 | accept_sparse=('csr', 'csc'),
137 | copy=self.copy,
138 | estimator=self,
139 | dtype=FLOAT_DTYPES,
140 | force_all_finite='allow-nan')
141 |
142 | if sparse.issparse(X):
143 | if self.with_scaling:
144 | inplace_column_scale(X, self.scale_)
145 | else:
146 | if self.with_scaling:
147 | X *= self.scale_
148 | if self.with_centering:
149 | X += self.center_
150 | return X
151 |
152 | def _more_tags(self):
153 | return {'allow_nan': True}
154 |
155 |
156 | class MeanAbsoluteDeviationFromMeanScaler(AbsoluteDeviationScaler):
157 | """Mean absolute deviation scaler (from the mean)
158 |     It scales using the MEAN deviation from the MEAN
159 | """
160 | def _center(self, X):
161 | return np.nanmean(X, axis=0)
162 |
163 | def _scale(self, X):
164 | return mean_absolute_deviation(X, axis=0)
165 |
166 |
167 | class MedianAbsoluteDeviationFromMedianScaler(AbsoluteDeviationScaler):
168 | """Median absolute deviation scaler (from the median)
169 |     It scales using the MEDIAN deviation from the MEDIAN
170 | """
171 | def _center(self, X):
172 | return np.nanmedian(X, axis=0)
173 |
174 | def _scale(self, X):
175 | return stats.median_absolute_deviation(X, axis=0)
176 |
177 |
178 | class MeanAbsoluteDeviationFromMedianScaler(AbsoluteDeviationScaler):
179 | """Mean absolute deviation scaler (from the median)
180 |     It scales using the MEAN deviation from the MEDIAN
181 | """
182 | def _center(self, X):
183 | return np.nanmean(X, axis=0)
184 |
185 | def _scale(self, X):
186 | return mean_absolute_deviation_from_median(X, axis=0)
187 |
188 |
189 | class MedianAbsoluteDeviationFromMeanScaler(AbsoluteDeviationScaler):
190 | """Median absolute deviation scaler (from the mean)
191 |     It scales using the MEDIAN deviation from the MEAN
192 | """
193 | def _center(self, X):
194 | return np.nanmean(X, axis=0)
195 |
196 | def _scale(self, X):
197 | return median_absolute_deviation_from_mean(X, axis=0)
198 |
199 |
200 | # Aliases
201 | # By default we return the matching absolute scalers (median-median and mean-mean).
202 | MeanAbsoluteDeviationScaler = MeanAbsoluteDeviationFromMeanScaler
203 | MedianAbsoluteDeviationScaler = MedianAbsoluteDeviationFromMedianScaler
--------------------------------------------------------------------------------
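A minimal sketch of one of the MAD scalers above (illustrative only; the data is random). `MeanAbsoluteDeviationScaler` is used here because it relies only on NumPy, whereas the median-from-median variant uses `scipy.stats.median_absolute_deviation`, which has been removed in newer SciPy releases:

```python
import numpy as np
from cfshap.utils.preprocessing import MeanAbsoluteDeviationScaler

X = np.random.RandomState(0).normal(loc=3.0, scale=2.0, size=(100, 2))

scaler = MeanAbsoluteDeviationScaler().fit(X)
Z = scaler.transform(X.copy())          # centered on the mean, scaled by the mean absolute deviation
X_back = scaler.inverse_transform(Z)    # recovers the original values (up to floating point error)
```
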
/src/cfshap/utils/preprocessing/scaling/multiscaler.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements the MultiScaler.
3 | The multi scaler is a scaler that allows for different scaling within the same class through an argument passed to the `transform` methods.
4 | e.g., STD, MAD, Quantile, etc.
5 |
6 | """
7 |
8 | __all__ = [
9 | 'MultiScaler',
10 | 'IdentityScaler',
11 | 'get_scaler_name',
12 | 'SCALERS',
13 | ]
14 | __author__ = 'Emanuele Albini'
15 |
16 | from typing import Union
17 |
18 | import numpy as np
19 | import scipy as sp
20 |
21 | from sklearn.preprocessing import MinMaxScaler, StandardScaler
22 | from sklearn.base import BaseEstimator, TransformerMixin
23 |
24 | from ... import keydefaultdict
25 |
26 | from .madscaler import (
27 | MedianAbsoluteDeviationScaler,
28 | MeanAbsoluteDeviationScaler,
29 | MeanAbsoluteDeviationFromMedianScaler,
30 | MedianAbsoluteDeviationFromMeanScaler,
31 | )
32 | from .quantiletransformer import EfficientQuantileTransformer
33 |
34 | # NOTE: This will be deprecated, it is confusing
35 | _DISTANCE_TO_SCALER = {
36 | 'euclidean': 'std',
37 | 'manhattan': 'mad',
38 | 'cityblock': 'mad',
39 | 'percshift': 'quantile',
40 | }
41 |
42 | _SCALER_TO_NAME = {
43 | 'quantile': 'Quantile',
44 | 'std': 'Standard',
45 | 'mad': 'Median Absolute Dev. (from median)',
46 | 'Mad': 'Mean Absolute Dev. (from mean)',
47 | 'madM': 'Median Absolute Dev. from mean',
48 | 'Madm': 'Mean Absolute Dev. from median',
49 | 'minmax': 'Min Max',
50 | 'quantile_nan': 'Quantile w/NaN OF',
51 | 'quantile_sum': 'Quantile w/OF-Σ',
52 | None: 'Identity',
53 | }
54 |
55 | SCALERS = list(_SCALER_TO_NAME.keys())
56 |
57 |
58 | def _get_method_safe(method):
59 | if method is None or method in SCALERS:
60 | return method
61 | # NOTE: This will be deprecated, it is confusing
62 | elif method in list(_DISTANCE_TO_SCALER.keys()):
63 | return _DISTANCE_TO_SCALER[method]
64 | else:
65 | raise ValueError('Invalid normalization method.')
66 |
67 |
68 | def get_scaler_name(method):
69 | return _SCALER_TO_NAME[_get_method_safe(method)]
70 |
71 |
72 | class IdentityScaler(TransformerMixin, BaseEstimator):
73 | """A dummy/identity scaler compatible with the sklearn interface for scalers
74 | It returns the same input it receives.
75 |
76 | """
77 | def __init__(self):
78 | pass
79 |
80 | def fit(self, X, y=None):
81 | return self
82 |
83 | def transform(self, X):
84 | return X
85 |
86 | def inverse_transform(self, X):
87 | return X
88 |
89 |
90 | def get_transformer_class(method):
91 | method = _get_method_safe(method)
92 |
93 | if method is None:
94 | return IdentityScaler
95 | elif method == 'std':
96 | return StandardScaler
97 | elif method == 'minmax':
98 | return MinMaxScaler
99 | elif method == 'mad':
100 | return MedianAbsoluteDeviationScaler
101 | elif method == 'Mad':
102 | return MeanAbsoluteDeviationScaler
103 | elif method == 'Madm':
104 | return MeanAbsoluteDeviationFromMedianScaler
105 | elif method == 'madM':
106 | return MedianAbsoluteDeviationFromMeanScaler
107 | elif method == 'quantile':
108 | return EfficientQuantileTransformer # PercentileShifterCached
109 | elif method == 'quantile_nan':
110 | return EfficientQuantileTransformer
111 | elif method == 'quantile_sum':
112 | return EfficientQuantileTransformer
113 |
114 |
115 | def get_transformer_kwargs(method):
116 | method = _get_method_safe(method)
117 |
118 | if method == 'quantile_nan':
119 | return dict(overflow="nan")
120 | elif method == 'quantile_sum':
121 | return dict(overflow="sum")
122 | else:
123 | return dict()
124 |
125 |
126 | class MultiScaler:
127 | """Multi-Scaler
128 |
129 | Raises:
130 | Exception: If an invalid normalization is used
131 |
132 | """
133 | # Backward compatibility
134 | NORMALIZATION = SCALERS
135 |
136 | def __init__(self, data: np.ndarray = None):
137 | """Constructor
138 |
139 | - The data on which to train the scalers can be passed here (in the constructor), or
140 |         - It can also be passed later using the .fit(data) method.
141 |
142 | Args:
143 |             data (np.ndarray): The data on which to fit the scalers, optional. Defaults to None.
144 |
145 | """
146 |
147 |         self.__suppress_warning = False
148 | 
149 |         if data is not None:
150 |             self.fit(data)
151 |
152 | def fit(self, data: np.ndarray):
153 | self.data = np.asarray(data)
154 |
155 | self.transformers = keydefaultdict(lambda method: get_transformer_class(method)
156 | (**get_transformer_kwargs(method)).fit(self.data))
157 |
158 | def single_transformer(method, f):
159 | return get_transformer_class(method)(**get_transformer_kwargs(method)).fit(self.data[:, f].reshape(-1, 1))
160 |
161 | self.single_transformers = keydefaultdict(lambda args: single_transformer(*args))
162 | self.data_transformed = keydefaultdict(lambda method: self.transform(self.data, method))
163 |
164 |         self.covs = keydefaultdict(lambda key: self.__compute_covariance_matrix(self.data, *key))
165 |
166 | self._lower_bounds = keydefaultdict(lambda method: self.data_transformed[method].min(axis=0))
167 | self._upper_bounds = keydefaultdict(lambda method: self.data_transformed[method].max(axis=0))
168 |
169 | def lower_bounds(self, method):
170 | return self._lower_bounds[method]
171 |
172 | def upper_bounds(self, method):
173 |         return self._upper_bounds[method]
174 |
175 | def suppress_warnings(self, value=True):
176 | self.__suppress_warning = value
177 |
178 | def __compute_covariance_matrix(self, data, method, lib):
179 | if lib == 'np':
180 | return sp.linalg.inv(np.cov((self.transform(data, method=method)), rowvar=False))
181 | elif lib == 'tf':
182 | from ..tf import inv_cov as tf_inv_cov
183 | return tf_inv_cov(self.transform(data, method=method))
184 | else:
185 | raise ValueError('Invalid lib.')
186 |
187 | def transform(self, data: np.ndarray, method: str, **kwargs):
188 | """Normalize the data according to the "method" passed
189 |
190 | Args:
191 | data (np.ndarray): The data to be normalized (nb_samples x nb_features)
192 | method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
193 |
194 | Raises:
195 | ValueError: Invalid normalization
196 |
197 | Returns:
198 | np.ndarray: Normalized array
199 | """
200 | method = _get_method_safe(method)
201 | return self.transformers[method].transform(data)
202 |
203 | def inverse_transform(self, data: np.ndarray, method: str = 'std'):
204 |         """Un-normalize the data according to the "method" passed
205 |
206 | Args:
207 | data (np.ndarray): The data to be un-normalized (nb_samples x nb_features)
208 | method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
209 |
210 | Raises:
211 | ValueError: Invalid normalization
212 |
213 | Returns:
214 | np.ndarray: Un-normalized array
215 | """
216 | method = _get_method_safe(method)
217 | return self.transformers[method].inverse_transform(data)
218 |
219 | def feature_deviation(self, method: str = 'std', phi: Union[float, int] = 1):
220 | """Get the deviation of each feature according to the normalization method
221 |
222 | Args:
223 |             method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
224 | phi (Union[float, int]): The fraction of the STD/MAD/MINMAX. Default to 1.
225 |
226 | Raises:
227 | ValueError: Invalid normalization
228 |
229 | Returns:
230 | np.ndarray: Deviations, shape = (nb_features, )
231 | """
232 | method = _get_method_safe(method)
233 | transformer = self.transformers[method]
234 | if 'scale_' in dir(transformer):
235 | return transformer.scale_ * phi
236 | else:
237 | return np.ones(self.data.shape[1]) * phi
238 |
239 | def feature_transform(self, x: np.ndarray, f: int, method: str):
240 | x = np.asarray(x)
241 | transformer = self.get_feature_transformer(method, f)
242 | return transformer.transform(x.reshape(-1, 1))[:, 0]
243 |
244 | def value_transform(self, x: float, f: int, method: str):
245 | x = np.asarray([x])
246 | transformer = self.get_feature_transformer(method, f)
247 | return transformer.transform(x.reshape(-1, 1))[:, 0][0]
248 |
249 | def shift_transform(self, X, shifts, method, **kwargs):
250 | transformer = self.get_transformer(method)
251 | if 'shift' in dir(transformer):
252 | return transformer.shift_transform(X, shifts=shifts, **kwargs)
253 | else:
254 | return X + shifts
255 |
256 | def move_transform(self, X, costs, method, **kwargs):
257 | transformer = self.get_transformer(method)
258 | assert costs.shape[0] == X.shape[1]
259 | return transformer.inverse_transform(transformer.transform(X) + np.tile(costs, (X.shape[0], 1)))
260 |
261 | def get_transformer(self, method: str):
262 | return self.transformers[_get_method_safe(method)]
263 |
264 | def get_feature_transformer(self, method: str, f: int):
265 | return self.single_transformers[(_get_method_safe(method), f)]
266 |
267 | def single_transform(self, x, *args, **kwargs):
268 | return self.transform(np.array([x]), *args, **kwargs)[0]
269 |
270 | def single_inverse_transform(self, x, *args, **kwargs):
271 | return self.inverse_transform(np.array([x]), *args, **kwargs)[0]
272 |
273 | def single_shift_transform(self, x, shift, **kwargs):
274 | return self.shift_transform(np.array([x]), np.array([shift]), **kwargs)[0]
275 |
276 | # NOTE: This must be deprecated, it does not fit here.
277 | def covariance_matrix(self, data: np.ndarray, method: Union[None, str], lib='np'):
278 | if data is None:
279 | return self.covs[(method, lib)]
280 | else:
281 | return self.__compute_covariance_matrix(data, method, lib)
282 |
283 | # NOTE: This must be deprecated, it does not fit here.
284 | def covariance_matrices(self, data: np.ndarray, methods=None, lib='np'):
285 | """Compute the covariance matrices
286 |
287 | Args:
288 | data (np.ndarray): The data from which to extract the covariance
289 |
290 | Returns:
291 | Dict[np.ndarray]: Dictionary (for each normalization method) of covariance matrices
292 | """
293 | # If no method is passed we compute for all of them
294 | if methods is None:
295 | methods = self.NORMALIZATION
296 |
297 | return {method: self.covariance_matrix(data, method, lib) for method in methods}
--------------------------------------------------------------------------------
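A minimal sketch of `MultiScaler` (illustrative only; the data is random):

```python
import numpy as np
from cfshap.utils.preprocessing import MultiScaler

X = np.random.RandomState(0).normal(size=(200, 3))

ms = MultiScaler(X)                         # scalers are created lazily and fitted on X
Z_std = ms.transform(X, method='std')       # standard scaling
Z_q = ms.transform(X, method='quantile')    # quantile scaling (EfficientQuantileTransformer)
dev = ms.feature_deviation(method='std')    # per-feature deviation under the chosen scaling
```
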
/src/cfshap/utils/preprocessing/scaling/quantiletransformer.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements an efficient and exact version of the scikit-learn QuantileTransformer.
3 |
4 | Note: This module has been inspired from scikit-learn QuantileTransformer
5 | (and it is effectively an extension of QuantileTransformer API)
6 | See https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/preprocessing
7 |
8 | """
9 |
10 | __all__ = ['EfficientQuantileTransformer']
11 | __author__ = 'Emanuele Albini'
12 |
13 | import numpy as np
14 | from scipy import sparse
15 | from sklearn.utils import check_random_state
16 | from sklearn.utils.validation import (check_is_fitted, FLOAT_DTYPES)
17 | from sklearn.preprocessing import QuantileTransformer
18 | from sklearn.utils import check_array
19 |
20 |
21 | class EfficientQuantileTransformer(QuantileTransformer):
22 | """
23 |     This class directly extends and improves the efficiency of the scikit-learn QuantileTransformer
24 | 
25 |     Note: The efficient implementation will only be used if:
26 |     - The inputs are NumPy arrays (NOT scipy sparse matrices)
27 | The flag self.smart_fit_ marks when the efficient implementation is being used.
28 |
29 | """
30 | def __init__(
31 | self,
32 | *,
33 | subsample=np.inf,
34 | random_state=None,
35 | copy=True,
36 | overflow=None, # "nan" or "sum"
37 | ):
38 | """Initialize the transformer
39 |
40 | Args:
41 | subsample (int, optional): Number of samples to use to create the quantile space. Defaults to np.inf.
42 | random_state (int, optional): Random seed (sampling happen only if subsample < number of samples fitted). Defaults to None.
43 | copy (bool, optional): If False, passed arrays will be edited in place. Defaults to True.
44 | overflow (str, optional): Overflow strategy. Defaults to None.
45 | When doing the inverse transformation if a quantile > 1 or < 0 is passed then:
46 | - None > Nothing is done. max(0, min(1, q)) will be used. The 0% or 100% reference will be returned.
47 | - 'sum' > It will add proportionally, e.g., q = 1.2 will result in adding 20% more quantile to the 100% reference.
48 | - 'nan' > It will return NaN
49 | """
50 | self.ignore_implicit_zeros = False
51 | self.n_quantiles_ = np.inf
52 | self.output_distribution = 'uniform'
53 | self.subsample = subsample
54 | self.random_state = random_state
55 | self.overflow = overflow
56 | self.copy = copy
57 |
58 | def _smart_fit(self, X, random_state):
59 | n_samples, n_features = X.shape
60 | self.references_ = []
61 | self.quantiles_ = []
62 | for col in X.T:
63 | # Do sampling if necessary
64 | if self.subsample < n_samples:
65 | subsample_idx = random_state.choice(n_samples, size=self.subsample, replace=False)
66 | col = col.take(subsample_idx, mode='clip')
67 | col = np.sort(col)
68 | quantiles = np.sort(np.unique(col))
69 | references = 0.5 * np.array(
70 | [np.searchsorted(col, v, side='left') + np.searchsorted(col, v, side='right')
71 | for v in quantiles]) / n_samples
72 | self.quantiles_.append(quantiles)
73 | self.references_.append(references)
74 |
75 | def fit(self, X, y=None):
76 | """Compute the quantiles used for transforming.
77 | Parameters
78 | ----------
79 | X : {array-like} of shape (n_samples, n_features)
80 | The data used to scale along the features axis.
81 |
82 | y : None
83 | Ignored.
84 | Returns
85 | -------
86 | self : object
87 | Fitted transformer.
88 | """
89 |
90 | if self.subsample <= 1:
91 | raise ValueError("Invalid value for 'subsample': %d. "
92 | "The number of subsamples must be at least two." % self.subsample)
93 |
94 | X = self._check_inputs(X, in_fit=True, copy=False)
95 | n_samples = X.shape[0]
96 |
97 | if n_samples <= 1:
98 | raise ValueError("Invalid value for samples: %d. "
99 | "The number of samples to fit for must be at least two." % n_samples)
100 |
101 | rng = check_random_state(self.random_state)
102 |
103 | # Create the quantiles of reference
104 | self.smart_fit_ = not sparse.issparse(X)
105 | if self.smart_fit_: # <<<<<- New case
106 | self._smart_fit(X, rng)
107 | else:
108 | raise NotImplementedError('EfficientQuantileTransformer handles only NON-sparse matrices!')
109 |
110 | return self
111 |
112 | def _smart_transform_col(self, X_col, quantiles, references, inverse):
113 | """Private function to transform a single feature."""
114 |
115 | isfinite_mask = ~np.isnan(X_col)
116 | X_col_finite = X_col[isfinite_mask]
117 | # Simply Interpolate
118 | if not inverse:
119 | X_col[isfinite_mask] = np.interp(X_col_finite, quantiles, references)
120 | else:
121 | X_col[isfinite_mask] = np.interp(X_col_finite, references, quantiles)
122 |
123 | return X_col
124 |
125 | def _check_inputs(self, X, in_fit, accept_sparse_negative=False, copy=False):
126 | """Check inputs before fit and transform."""
127 | try:
128 | X = self._validate_data(X,
129 | reset=in_fit,
130 | accept_sparse=False,
131 | copy=copy,
132 | dtype=FLOAT_DTYPES,
133 | force_all_finite='allow-nan')
134 | except AttributeError: # Old sklearn version (_validate_data do not exists)
135 | X = check_array(X, accept_sparse=False, copy=self.copy, dtype=FLOAT_DTYPES, force_all_finite='allow-nan')
136 |
137 | # we only accept positive sparse matrix when ignore_implicit_zeros is
138 | # false and that we call fit or transform.
139 | with np.errstate(invalid='ignore'): # hide NaN comparison warnings
140 | if (not accept_sparse_negative and not self.ignore_implicit_zeros
141 | and (sparse.issparse(X) and np.any(X.data < 0))):
142 | raise ValueError('QuantileTransformer only accepts' ' non-negative sparse matrices.')
143 |
144 | # check the output distribution
145 | if self.output_distribution not in ('normal', 'uniform'):
146 | raise ValueError("'output_distribution' has to be either 'normal'"
147 | " or 'uniform'. Got '{}' instead.".format(self.output_distribution))
148 |
149 | return X
150 |
151 | def _transform(self, X, inverse=False):
152 | """Forward and inverse transform.
153 | Parameters
154 | ----------
155 | X : ndarray of shape (n_samples, n_features)
156 | The data used to scale along the features axis.
157 | inverse : bool, default=False
158 | If False, apply forward transform. If True, apply
159 | inverse transform.
160 | Returns
161 | -------
162 | X : ndarray of shape (n_samples, n_features)
163 | Projected data.
164 | """
165 | for feature_idx in range(X.shape[1]):
166 | X[:, feature_idx] = self._smart_transform_col(X[:, feature_idx], self.quantiles_[feature_idx],
167 | self.references_[feature_idx], inverse)
168 |
169 | return X
170 |
171 | def transform(self, X):
172 | """Feature-wise transformation of the data.
173 | Parameters
174 | ----------
175 | X : {array-like} of shape (n_samples, n_features)
176 | The data used to scale along the features axis.
177 |
178 | Returns
179 | -------
180 | Xt : {ndarray, sparse matrix} of shape (n_samples, n_features)
181 | The projected data.
182 | """
183 | check_is_fitted(self, ['quantiles_', 'references_', 'smart_fit_'])
184 | X = self._check_inputs(X, in_fit=False, copy=self.copy)
185 | return self._transform(X, inverse=False)
186 |
187 | def inverse_transform(self, X):
188 | """Back-projection to the original space.
189 | Parameters
190 | ----------
191 | X : {array-like} of shape (n_samples, n_features)
192 | The data used to scale along the features axis.
193 |
194 | Returns
195 | -------
196 | Xt : {ndarray, sparse matrix} of (n_samples, n_features)
197 | The projected data.
198 | """
199 | check_is_fitted(self, ['quantiles_', 'references_', 'smart_fit_'])
200 | X = self._check_inputs(X, in_fit=False, accept_sparse_negative=False, copy=self.copy)
201 |
202 | if self.overflow is None:
203 | T = self._transform(X, inverse=True)
204 | elif self.overflow == 'nan':
205 | NaN_mask = np.ones(X.shape)
206 | NaN_mask[(X > 1) | (X < 0)] = np.nan
207 | T = NaN_mask * self._transform(X, inverse=True)
208 |
209 | elif self.overflow == 'sum':
210 | ones = self._transform(np.ones(X.shape), inverse=True)
211 | zeros = self._transform(np.zeros(X.shape), inverse=True)
212 |
213 | # Standard computation
214 | T = self._transform(X.copy(), inverse=True)
215 |
216 | # Deduct already computed part
217 | X = np.where((X > 0), np.maximum(X - 1, 0), X)
218 |
219 | # After this X > 0 => Remaining quantile > 1.00
220 | # and X < 0 => Remaining quantile < 0.00
221 |
222 | T = T + (X > 1) * np.floor(X) * (ones - zeros)
223 | X = np.where((X > 1), np.maximum(X - np.floor(X), 0), X)
224 | T = T + (X > 0) * (ones - self._transform(1 - X.copy(), inverse=True))
225 |
226 | T = T - (X < -1) * np.floor(-X) * (ones - zeros)
227 | X = np.where((X < -1), np.minimum(X + np.floor(-X), 0), X)
228 | T = T - (X < 0) * (self._transform(-X.copy(), inverse=True) - zeros)
229 |
230 |             # Set to NaN the values that have not been reached after a certain number of iterations
231 | # T[(X > 0) | (X < 0)] = np.nan
232 |
233 | else:
234 | raise ValueError('Invalid value for overflow.')
235 |
236 | return T
237 |
238 |
239 | # %%
240 |
--------------------------------------------------------------------------------
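A minimal sketch of `EfficientQuantileTransformer`, including the `overflow` behaviour (illustrative only; the data is random):

```python
import numpy as np
from cfshap.utils.preprocessing import EfficientQuantileTransformer

X = np.random.RandomState(0).exponential(size=(500, 2))

qt = EfficientQuantileTransformer().fit(X)
Q = qt.transform(X)                 # exact quantiles in [0, 1]
X_back = qt.inverse_transform(Q)    # back to the original feature space

# With overflow='nan', inverse-transforming a quantile outside [0, 1] yields NaN
qt_nan = EfficientQuantileTransformer(overflow='nan').fit(X)
print(qt_nan.inverse_transform(np.array([[1.5, 0.5]])))   # first entry is NaN, second is ~the median of feature 1
```
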
/src/cfshap/utils/project.py:
--------------------------------------------------------------------------------
1 | """Functions to find the decision boundary of a classifier using bisection.
2 | """
3 |
4 | __author__ = 'Emanuele Albini'
5 | __all__ = [
6 | 'find_decision_boundary_bisection',
7 | ]
8 |
9 | import numpy as np
10 |
11 | from ..base import Model, Scaler
12 | from .parallel import batch_process
13 |
14 |
15 | def find_decision_boundary_bisection(
16 | x: np.ndarray,
17 | Y: np.ndarray,
18 | model: Model,
19 | scaler: Scaler = None,
20 | num: int = 1000,
21 | n_jobs: int = 1,
22 | error: float = 1e-10,
23 | method: str = 'mean',
24 | mem_coeff: float = 1,
25 | model_parallelism: int = 1,
26 | desc='Find Decision Boundary Ellipse (batch)',
27 | ) -> np.ndarray:
28 | """Find the decision boundary between x and multiple points Y using bisection.
29 |
30 | Args:
31 | x (np.ndarray): A point.
32 | Y (np.ndarray): An array of points.
33 | model (Model): A model (that implements model.predict(X))
34 | scaler (Scaler, optional): A scaler. Defaults to None (no scaling).
35 | num (int, optional): The (maximum) number of bisection steps. Defaults to 1000.
36 | n_jobs (int, optional): The number of parallel jobs. Defaults to 1.
37 | error (float, optional): The early stopping error for the bisection. Defaults to 1e-10.
38 | method (str, optional): The method to find the point. Defaults to 'mean'.
39 | - 'left': On the left of the decision boundary (closer to x).
40 | - 'right': On the right of the decision boundary (closer to y).
41 |             - 'mean': The mean of the two points (it can be on either the left or the right side).
42 | - 'counterfactual': alias for 'right'.
43 | mem_coeff (float, optional): The coefficient for the job split (the higher the bigger every single job will be). Defaults to 1.
44 | model_parallelism (int, optional): Factor to enable parallel bisection. Defaults to 1.
45 | desc (str, optional): TQDM progress bar description. Defaults to 'Find Decision Boundary Ellipse (batch)'.
46 |
47 | Returns:
48 | np.ndarray: The points close to the decision boundary.
49 |
50 |     NOTE: This function performs much better when num << 1,000,000
51 | because it batches together multiple points when predicting
52 | (this is also done to avoid memory issues).
53 | """
54 |
55 | assert model_parallelism >= 1
56 | assert method in ['mean', 'left', 'right', 'counterfactual']
57 |
58 | if scaler is not None:
59 | assert hasattr(scaler, 'transform'), 'Scaler must have a `transform` method.'
60 | assert hasattr(scaler, 'inverse_transform'), 'Scaler must have a `inverse_transform` method.'
61 |
62 | def _find_decision_boundary_bisection(x: np.ndarray, Y: np.ndarray):
63 |
64 | x = np.asarray(x).copy()
65 | Y = np.asarray(Y).copy()
66 | X = np.tile(x, (Y.shape[0], 1))
67 | Yrange = np.arange(Y.shape[0])
68 |
69 | model_parallelism_ = int(np.ceil(model_parallelism / len(Y)))
70 |
71 | # Compute the predictions for the two points and check that they are different
72 | pred_x = model.predict(np.array([x]))
73 | pred_Y = model.predict(np.array(Y))
74 | assert np.all(pred_x != pred_Y), 'The predictions for x and Y are the same. They must be different.'
75 |
76 | for _ in range(num):
77 | # 'Pure' bisection
78 | if model_parallelism_ == 1:
79 | if scaler is None:
80 | M = (X + Y) / 2
81 | else:
82 | X_ = scaler.transform(X)
83 | Y_ = scaler.transform(Y)
84 | M_ = (X_ + Y_) / 2
85 | M = scaler.inverse_transform(M_)
86 |
87 | # Cast to proper type (e.g., if X and/or Y are integers) with proper precision
88 | if M.dtype != X.dtype:
89 | X = X.astype(M.dtype)
90 | if M.dtype != Y.dtype:
91 | Y = Y.astype(M.dtype)
92 |
93 | preds = (model.predict(M) != pred_x)
94 | different_indexes = np.argwhere(preds)
95 | non_different_indexes = np.argwhere(~preds)
96 |
97 | Y[different_indexes] = M[different_indexes]
98 | X[non_different_indexes] = M[non_different_indexes]
99 |
100 | # Parallel bisection
101 | else:
102 | if scaler is None:
103 | M = np.concatenate(np.linspace(X, Y, model_parallelism_ + 1, endpoint=False)[1:])
104 | else:
105 | X_ = scaler.transform(X)
106 | Y_ = scaler.transform(Y)
107 | M_ = np.concatenate(np.linspace(X_, Y_, model_parallelism_, endpoint=False))
108 | M = scaler.inverse_transform(M_)
109 |
110 | # Predict
111 | preds = (model.predict(M) != pred_x).reshape(model_parallelism_, -1).T
112 |
113 | # Rebuild M
114 | M = np.concatenate([
115 | np.expand_dims(X, axis=0),
116 | M.reshape(model_parallelism_, Y.shape[0], -1),
117 | np.expand_dims(Y, axis=0)
118 | ])
119 |
120 | # Find next index
121 | left_index = np.array([np.searchsorted(preds[i], 1) for i in range(len(Y))])
122 | right_index = left_index + 1
123 |
124 | # Update left and right
125 | X = M[left_index, Yrange]
126 | Y = M[right_index, Yrange]
127 |
128 | # Early stopping
129 | err = np.max(np.linalg.norm((X - Y), axis=1, ord=2))
130 | if err < error:
131 | break
132 |
133 | # The decision boundary is in between the change points
134 | if method == 'mean':
135 | return (X + Y) / 2
136 | # Same predictions
137 | elif method == 'left':
138 | return X
139 | # Counterfactual
140 | elif method == 'right' or method == 'counterfactual':
141 | return Y
142 | else:
143 | raise ValueError('Invalid method.')
144 |
145 | return batch_process(
146 | lambda Y: list(_find_decision_boundary_bisection(x=x, Y=Y)),
147 | iterable=Y,
148 |         # Optimal split size (CPU-memory trade-off)
149 | split_size=int(max(1, 100 * 1000 * 1000 / num / (x.shape[0]))) * mem_coeff,
150 | desc=desc,
151 | n_jobs=min(n_jobs, len(Y)),
152 | )
--------------------------------------------------------------------------------
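A minimal sketch of `find_decision_boundary_bisection` with a toy model (illustrative only; `ThresholdModel` is a made-up stand-in for any object exposing `predict`):

```python
import numpy as np
from cfshap.utils.project import find_decision_boundary_bisection

class ThresholdModel:
    """Toy classifier: class 1 iff the first feature is positive."""
    def predict(self, X):
        return (np.asarray(X)[:, 0] > 0).astype(int)

x = np.array([-1.0, 0.0])                  # predicted as class 0
Y = np.array([[2.0, 0.0], [3.0, 1.0]])     # both rows predicted as class 1

B = find_decision_boundary_bisection(x=x, Y=Y, model=ThresholdModel(), num=100)
# Each returned point lies (numerically) on the boundary between x and the corresponding row of Y,
# i.e., its first feature is approximately 0.
```
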
/src/cfshap/utils/random/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Random utilities.
5 | """
6 |
7 | from ._sample import *
8 | from ._stability import *
--------------------------------------------------------------------------------
/src/cfshap/utils/random/_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Random sampling utilities.
5 | """
6 |
7 | from typing import Union
8 | import numpy as np
9 | import pandas as pd
10 |
11 | __all__ = [
12 | 'sample_data',
13 | 'np_sample',
14 | ]
15 |
16 |
17 | def sample_data(
18 | X: Union[pd.DataFrame, np.ndarray],
19 | n=None,
20 | frac=None,
21 | random_state=None,
22 | replace=False,
23 | **kwargs,
24 | ) -> Union[pd.DataFrame, np.ndarray]:
25 | assert frac is None or n is None, "Cannot specify both `n` and `frac`"
26 | assert not (frac is None and n is None), "One of `n` or `frac` must be passed."
27 |
28 | if isinstance(X, pd.DataFrame):
29 | return X.sample(
30 | n=n,
31 | frac=frac,
32 | random_state=random_state,
33 | replace=replace,
34 | **kwargs,
35 | )
36 | elif isinstance(X, np.ndarray):
37 | if frac is not None:
38 | n = int(np.ceil(len(X) * frac))
39 | return np_sample(X, n=n, replace=replace, random_state=random_state, **kwargs)
40 | else:
41 | raise NotImplementedError('Unsupported dataset type.')
42 |
43 |
44 | def np_sample(
45 | a: Union[np.ndarray, int],
46 | n: Union[int, None],
47 | replace: bool = False,
48 | seed: Union[None, int] = None,
49 | random_state: Union[None, int] = None,
50 | safe: bool = False,
51 | ) -> np.ndarray:
52 | """Randomly sample on axis 0 of a NumPy array
53 |
54 | Args:
55 |         a (Union[np.ndarray, int]): The array to be sampled, or an integer (the upper bound of a `range`)
56 |         n (int or None): Number of samples to draw. If None, all the samples are drawn.
57 | replace (bool, optional): Repeat samples or not. Defaults to False.
58 | seed (Union[None, int], optional): Random seed for NumPy. Defaults to None.
59 | random_state (Union[None, int], optional): Alias for seed. Defaults to None.
60 | safe (bool, optional) : Safely handle `n` or not. If True and replace = False, and n > len(a), then n = len(a)
61 |
62 | Returns:
63 | np.ndarray: A random sample
64 | """
65 | assert random_state is None or seed is None
66 |
67 | if random_state is not None:
68 | seed = random_state
69 |
70 | if seed is not None:
71 | random_state = np.random.RandomState(seed)
72 | else:
73 | random_state = np.random
74 |
75 | # Range case
76 | if isinstance(a, int):
77 | if safe and n > a:
78 | n = a
79 | return random_state.choice(a, n, replace=replace)
80 | # Array sampling case
81 | else:
82 | if n is None:
83 | n = len(a)
84 | if safe and n > len(a):
85 | n = len(a)
86 | return a[random_state.choice(a.shape[0], n, replace=replace)]
--------------------------------------------------------------------------------
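A minimal sketch of the sampling helpers (illustrative only):

```python
import numpy as np
from cfshap.utils.random import np_sample, sample_data

A = np.arange(20).reshape(10, 2)

rows = np_sample(A, n=3, seed=42)                 # 3 rows drawn without replacement
idx = np_sample(10, n=3, seed=42)                 # 3 indices drawn from range(10)
half = sample_data(A, frac=0.5, random_state=0)   # 5 rows (frac * len(A), rounded up)
```
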
/src/cfshap/utils/random/_stability.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Random stability (i.e., deterministic execution) utilities.
5 | """
6 |
7 | import logging
8 | import os
9 | import random
10 |
11 | __all__ = [
12 | 'random_stability',
13 | ]
14 |
15 |
16 | def random_stability(
17 | seed_value=0,
18 | deterministic=True,
19 | verbose=True,
20 | ):
21 | '''Set random seed/global random states to the specified value for a series of libraries:
22 |
23 | - Python environment
24 | - Python random package
25 | - NumPy/Scipy
26 | - Tensorflow
27 | - Keras
28 | - Torch
29 |
30 | seed_value (int): random seed
31 |     deterministic (bool) : make (parallel) operations deterministic; this may negatively affect performance. Default to True.
32 | verbose (bool): Output verbose log. Default to True.
33 | '''
34 | # pylint: disable=bare-except
35 |
36 | outputs = []
37 |
38 | # Python environment
39 | os.environ['PYTHONHASHSEED'] = str(seed_value)
40 | outputs.append('PYTHONHASHSEED (env)')
41 |
42 | # Python random
43 | random.seed(seed_value)
44 | outputs.append('random')
45 |
46 | try:
47 | import numpy as np
48 |
49 | np.random.seed(seed_value)
50 | outputs.append('NumPy')
51 | except ModuleNotFoundError:
52 | pass
53 |
54 | # TensorFlow 2
55 | try:
56 | import tensorflow as tf
57 | tf.random.set_seed(seed_value)
58 | if deterministic:
59 | outputs.append('TensorFlow 2 (deterministic)')
60 | tf.config.threading.set_inter_op_parallelism_threads(1)
61 |             tf.config.threading.set_intra_op_parallelism_threads(1)
62 | else:
63 | outputs.append('TensorFlow 2 (parallel, non-deterministic)')
64 | except (ModuleNotFoundError, ImportError, AttributeError):
65 | pass
66 |
67 | # TensorFlow 1 & Keras ? Not sure it works
68 | try:
69 | import tensorflow as tf
70 | if tf.__version__ < '2':
71 | try:
72 | import tensorflow.compat.v1
73 |                 from tensorflow.compat.v1 import set_random_seed
74 | except (ModuleNotFoundError, ImportError):
75 |                 from tensorflow import set_random_seed
76 |
77 | try:
78 | import tensorflow.compat.v1 # noqa: F811
79 |                 from tensorflow.compat.v1 import ConfigProto
80 | except (ModuleNotFoundError, ImportError):
81 |                 from tensorflow import ConfigProto
82 |
83 | set_random_seed(seed_value)
84 | if deterministic:
85 | outputs.append('TensorFlow 1 (deterministic)')
86 | session_conf = ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
87 | else:
88 | outputs.append('TensorFlow 1 (parallel, non-deterministic)')
89 | session_conf = ConfigProto()
90 | session_conf.gpu_options.allow_growth = True
91 | sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
92 | try:
93 | from keras import backend as K
94 | K.set_session(sess)
95 | outputs.append('Keras')
96 | except (ModuleNotFoundError, ImportError, AttributeError):
97 |                 logging.warning('Keras random stability failed.')
98 | except (ModuleNotFoundError, ImportError, AttributeError):
99 | pass
100 |
101 | try:
102 | import torch
103 | torch.manual_seed(seed_value)
104 | torch.cuda.manual_seed_all(seed_value)
105 | if deterministic:
106 | torch.backends.cudnn.deterministic = True
107 | torch.backends.cudnn.benchmark = False
108 | outputs.append('PyTorch (deterministic)')
109 | else:
110 | outputs.append('PyTorch (parallel, non-deterministic)')
111 |
112 | except (ModuleNotFoundError, ImportError, AttributeError):
113 | pass
114 | # pylint: enable=bare-except
115 |
116 | if verbose:
117 | logging.info('Random seed (%d) set for: %s', seed_value, ", ".join(outputs))
118 |
--------------------------------------------------------------------------------
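A minimal sketch of `random_stability` (illustrative only):

```python
from cfshap.utils.random import random_stability

# Seeds the Python environment, `random`, NumPy and, if installed, TensorFlow / Keras / PyTorch
random_stability(seed_value=42, deterministic=True, verbose=False)
```
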
/src/cfshap/utils/tree.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 |
4 | Utilities for tree-based models explainability.
5 | """
6 |
7 | from collections import defaultdict
8 | import numpy as np
9 |
10 | from ..utils import get_shap_compatible_model
11 |
12 | try:
13 |     # Newer versions of SHAP
14 | from shap.explainers._tree import TreeEnsemble
15 | except (ImportError, ModuleNotFoundError):
16 | # Older versions of SHAP
17 | from shap.explainers.tree import TreeEnsemble
18 |
19 | __all__ = [
20 | 'TreeEnsemble',
21 | 'get_splits',
22 | 'splits_to_values',
23 | 'SplitTransformer',
24 | ]
25 |
26 |
27 | def get_splits(model, nb_features):
28 | """Extract feature splits from a tree-based model
29 |     In order to extract the trees this method relies on the SHAP API:
30 | any model supported by SHAP is supported.
31 |
32 | Args:
33 | model (Any tree-based model): The model
34 |         nb_features (int): The number of features (sometimes it cannot be inferred from the model)
35 |
36 | Returns:
37 | List[np.ndarray]: A list of splits for each of the features
38 |             nb_features x nb_splits (it may differ for each feature)
39 | """
40 |
41 | # Convert the model to something that SHAP API can understand
42 | model = get_shap_compatible_model(model)
43 |
44 |     # Extract the trees (using the TreeEnsemble API from SHAP)
45 | ensemble = TreeEnsemble(model)
46 |
47 | splits = defaultdict(list)
48 | for tree in ensemble.trees:
49 | for i in range(len(tree.features)):
50 | if tree.children_left[i] != -1: # i.e., it is a split (and not a leaf)
51 | splits[tree.features[i]].append(tree.thresholds[i])
52 |
53 | assert len(splits) <= nb_features
54 | assert max(list(splits.keys())) < nb_features
55 |
56 | return [np.sort(np.array(list(set(splits[i])))) for i in range(nb_features)]
57 |
58 |
59 | def __splits_to_values(splits, how: str, eps: float):
60 | if len(splits) == 0:
61 | return [0]
62 |
63 | if how == 'left':
64 | return ([splits[0] - eps] + [(splits[i] + eps) for i in range(len(splits) - 1)] + [splits[-1] + eps])
65 | elif how == 'right':
66 | return ([splits[0] - eps] + [(splits[i + 1] - eps) for i in range(len(splits) - 1)] + [splits[-1] + eps])
67 | elif how == 'center':
68 | return ([splits[0] - eps] + [(splits[i] + splits[i + 1]) / 2
69 | for i in range(len(splits) - 1)] + [splits[-1] + eps])
70 | else:
71 | raise ValueError('Invalid mode.')
72 |
73 |
74 | def splits_to_values(splits, how: str = 'center', eps: float = 1e-6):
75 | """Convert lists of splits in values (in between splits)
76 |
77 | Args:
78 |         splits (List[np.ndarray]): List of splits (one per feature) obtained using get_splits
79 | how (str): Where should the values be wrt to the splits. Default to 'center'.
80 | 'center': precisely in between the two splits (except for the first and last split)
81 | 'right': split - eps
82 | 'left': split + eps
83 | e.g. if we indicate with | the splits and the values with x
84 | 'center': x| x | x | x | x |x
85 | 'left': x|x |x |x |x |x
86 | 'right': x| x| x| x| x|x
87 |
88 | eps (float, optional): epsilon to use for the conversion of splits to values. Defaults to 1e-6.
89 |
90 | Returns:
91 | List[np.ndarray]: List of values for the features supported based on the splits (one list per feature)
92 | """
93 |
94 | return [np.unique(np.array(__splits_to_values(s, how=how, eps=eps))) for s in splits]
95 |
96 |
97 | class SplitTransformer:
98 | def __init__(self, model, nb_features=None):
99 | """The initialization is lazy because we may not know yet the number of features"""
100 | if hasattr(model, 'get_booster'):
101 | model = model.get_booster()
102 | self.model = model
103 | self.nb_features = nb_features
104 | self.values = None
105 | self.splits = None
106 |
107 | self._initialize()
108 |
109 | def _initialize(self, X=None):
110 | if self.nb_features is None and X is not None:
111 | self.nb_features = X.shape[1]
112 | if (self.values is None or self.splits is None) and (self.nb_features is not None):
113 | self.splits = get_splits(self.model, self.nb_features)
114 | self.values = splits_to_values(self.splits, how='center')
115 |
116 | def get_nb_value(self, i):
117 | return len(self.splits[i]) + 1
118 |
119 | def transform(self, X):
120 |         """Takes inputs and transforms them into discrete split-bin indices"""
121 |         A = np.array(X)  # Copy to avoid mutating the input
122 | self._initialize(A)
123 | for i in range(A.shape[1]):
124 | if len(self.splits[i]) > 0:
125 | A[:, i] = np.digitize(A[:, i], self.splits[i])
126 | else:
127 | A[:, i] = 0
128 | return A.astype(int)
129 |
130 | def inverse_transform(self, X):
131 |         """Takes splits and transforms them into continuous inputs
132 | Note: The post-condition T^(-1)(T(X)) == X may not hold.
133 | """
134 | X_ = np.asarray(X).astype(int)
135 | self._initialize(X)
136 | A = np.full(X.shape, np.nan)
137 | for i in range(X.shape[1]):
138 | A[:, i] = self.values[i][X_[:, i]]
139 | return A
140 |
141 |
142 | # import numba
143 | # try:
144 | # from shap.explainers._tree import TreeEnsemble
145 | # except:
146 | # # Older versions
147 | # from shap.explainers.tree import TreeEnsemble
148 | # @numba.jit(nopython=True, nogil=True)
149 | # def predict_tree(x, region, features, values, thresholds, children_left, children_right, children_default):
150 | # n = 0
151 | # while n != -1: # -1 => Leaf
152 | # f = features[n]
153 | # t = thresholds[n]
154 | # v = x[f]
155 |
156 | # if v < t: # Left
157 | # # Constrain region
158 | # region[f, 1] = np.minimum(region[f, 1], t)
159 | # # print(f'T{i}: x_{f} = {v} < {t}')
160 |
161 | # # Go to child
162 | # n = children_left[n]
163 |
164 | # elif v >= t: # Right
165 | # # Constrain region
166 | # region[f, 0] = np.maximum(region[f, 0], t)
167 | # # print(f'T{i}: x_{f} = {v} >= {t}')
168 |
169 | # # Go to child
170 | # n = children_right[n]
171 |
172 | # else: # Missing
173 | # n = children_default[n]
174 | # region[f] = np.nan
175 |
176 | # return region
177 |
178 | # def predict_region(x, region, ensemble):
179 | # for i, tree in enumerate(ensemble.trees):
180 | # region = predict_tree(
181 | # x, region, tree.features, tree.values, tree.thresholds, tree.children_left, tree.children_right,
182 | # tree.children_default
183 | # )
184 | # return region
185 |
186 | # def predict_regions(X, model):
187 | # ensemble = TreeEnsemble(model)
188 | # regions = np.tile(
189 | # np.array([-np.inf, np.inf], dtype=ensemble.trees[0].thresholds.dtype), (X.shape[0], X.shape[1], 1)
190 | # )
191 | # return np.stack([predict_region(x, region, ensemble) for x, region in zip(X, regions)])
--------------------------------------------------------------------------------
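A minimal sketch of `get_splits` and `SplitTransformer` (illustrative only; it assumes `xgboost` and `shap` are installed, and the model and data below are made up):

```python
import numpy as np
from xgboost import XGBClassifier
from cfshap.utils.tree import get_splits, SplitTransformer

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

model = XGBClassifier(n_estimators=10, max_depth=3).fit(X, y)

splits = get_splits(model.get_booster(), nb_features=X.shape[1])   # split thresholds per feature

st = SplitTransformer(model)            # nb_features is inferred lazily on the first transform
bins = st.transform(X.copy())           # split-bin index of each value
X_repr = st.inverse_transform(bins)     # a representative value inside each bin
```
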
/tests/test_usage.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests of basic usage.
3 | """
4 |
5 | import pytest
6 | import itertools
7 |
8 | from sklearn.datasets import load_boston
9 | from xgboost import XGBClassifier
10 |
11 | from cfshap.utils.preprocessing import EfficientQuantileTransformer
12 | from cfshap.counterfactuals import KNNCounterfactuals
13 | from cfshap.attribution import TreeExplainer, CompositeExplainer
14 | from cfshap.trend import TrendEstimator
15 |
16 |
17 | def test_usage():
18 |
19 | dataset = load_boston()
20 | X = dataset.data
21 | y = (dataset.target > 21).astype(int)
22 |
23 | model = XGBClassifier()
24 | model.fit(X, y)
25 |
26 | scaler = EfficientQuantileTransformer()
27 | scaler.fit(X)
28 |
29 | trend_estimator = TrendEstimator(strategy='mean')
30 |
31 | explainer = CompositeExplainer(
32 | KNNCounterfactuals(
33 | model=model,
34 | X=X,
35 | n_neighbors=100,
36 | distance='cityblock',
37 | scaler=scaler,
38 | max_samples=10000,
39 | ),
40 | TreeExplainer(
41 | model,
42 | data=None,
43 | trend_estimator=trend_estimator,
44 | max_samples=10000,
45 | ),
46 | )
47 |
48 | return explainer(X[:10])
49 |
--------------------------------------------------------------------------------