├── .github ├── CODEOWNERS └── workflows │ └── python-app.yml ├── .gitignore ├── Example.html ├── Example.ipynb ├── LICENSE ├── README.md ├── assets ├── jpmorgan-logo.svg └── xai_coe-logo.png ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── cfshap │ ├── __init__.py │ ├── attribution │ ├── __init__.py │ ├── composite.py │ └── shap.py │ ├── background │ ├── __init__.py │ ├── counterfactual_adapter.py │ └── opposite.py │ ├── base.py │ ├── counterfactuals │ ├── __init__.py │ ├── composition.py │ ├── knn.py │ └── project.py │ ├── evaluation │ ├── __init__.py │ ├── attribution │ │ ├── __init__.py │ │ ├── _basic.py │ │ └── induced_counterfactual │ │ │ ├── __init__.py │ │ │ ├── _actionability.py │ │ │ └── _utils.py │ └── counterfactuals │ │ ├── __init__.py │ │ └── plausibility.py │ ├── trend.py │ └── utils │ ├── __init__.py │ ├── _python.py │ ├── _utils.py │ ├── parallel │ ├── __init__.py │ ├── batch.py │ └── utils.py │ ├── preprocessing │ ├── __init__.py │ └── scaling │ │ ├── __init__.py │ │ ├── madscaler.py │ │ ├── multiscaler.py │ │ └── quantiletransformer.py │ ├── project.py │ ├── random │ ├── __init__.py │ ├── _sample.py │ └── _stability.py │ └── tree.py └── tests └── test_usage.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | # The order matters with a later match taking precedence, 3 | # Default Code Owners will be your maintainer team. 4 | # Set a review when someone opens a pull request. 5 | * @jpmorganchase/XAI-CoE_maintainer 6 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ${{ matrix.os }} 19 | strategy: 20 | fail-fast: false 21 | matrix: 22 | os: [ubuntu-latest] 23 | python-version: [3.6, 3.7, 3.8, 3.9] 24 | 25 | steps: 26 | - uses: actions/checkout@v3 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v3 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip setuptools wheel 34 | pip install flake8 pytest 35 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 36 | - name: Install package (with extras test) 37 | run: | 38 | python -m pip install .[test] 39 | - name: Lint (flake8) 40 | run: | 41 | # Stop the build if there are Python syntax errors or undefined names 42 | flake8 . --select=E9,F63,F7,F82 --show-source 43 | # exit-zero treats all errors as warnings. 44 | flake8 . 
--exit-zero --show-source 45 | - name: Test (pytest) 46 | run: | 47 | pytest 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | **/__pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | **/.pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | **/.ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Windows 132 | **/Thumbs.db 133 | 134 | # Other cache 135 | **/.pyc 136 | 137 | **/.vscode 138 | 139 | 140 | links.sh 141 | links_init.py 142 | clean.sh 143 | 144 | a.yml 145 | 146 | src/emutils/data/lendingclub 147 | 148 | flake.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | JPMorgan Logo
3 | Explainable AI Center of Excellence Logo
4 | 
5 | 
6 | 
7 | 
8 | [![License](https://img.shields.io/github/license/jpmorganchase/cf-shap)](https://github.com/jpmorganchase/cf-shap/blob/master/LICENSE)
9 | [![Maintainer](https://img.shields.io/badge/maintainer-Emanuele_Albini-blue)](https://www.emanuelealbini.com)
10 | 
11 | 
12 | # Counterfactual SHAP (`cf-shap`)
13 | A modular framework for the generation of counterfactual feature attribution explanations (a.k.a., feature importance).
14 | This Python package implements the algorithms proposed in the following paper.
15 | If you use this package, please cite our work.
16 | 
17 | **Counterfactual Shapley Additive Explanations**
18 | Emanuele Albini, Jason Long, Danial Dervovic and Daniele Magazzeni
19 | J.P. Morgan AI Research
20 | [ACM](https://dl.acm.org/doi/abs/10.1145/3531146.3533168) | [ArXiv](https://arxiv.org/abs/2110.14270)
21 | 
22 | ```
23 | @inproceedings{Albini2022,
24 |   title = {Counterfactual {{Shapley Additive Explanations}}},
25 |   booktitle = {2022 {{ACM Conference}} on {{Fairness}}, {{Accountability}}, and {{Transparency}}},
26 |   author = {Albini, Emanuele and Long, Jason and Dervovic, Danial and Magazzeni, Daniele},
27 |   year = {2022},
28 |   series = {{{FAccT}} '22},
29 |   pages = {1054--1070},
30 |   doi = {10.1145/3531146.3533168}
31 | }
32 | ```
33 | 
34 | **Note that this repository contains the package with the algorithms for the generation of the explanations proposed in the paper and their evaluations _but not_ the experiments themselves.** If you are interested in reproducing the results of the paper, please refer to the [cf-shap-facct22](https://github.com/jpmorganchase/cf-shap-facct22) repository (that uses the algorithms implemented in this repository).
35 | 
36 | ## 1. Installation
37 | To install the package manually, simply use the following commands. Note that this package depends on the shap>=0.39.0 package: you may want to install it and the other dependencies manually (using conda or pip).
38 | 
39 | ```bash
40 | # Clone the repo into the `cf-shap` directory
41 | git clone https://github.com/jpmorganchase/cf-shap.git
42 | 
43 | # Install the package in editable mode
44 | pip install -e cf-shap
45 | ```
46 | The package has been tested with Python 3.6 and 3.8, but it should work with other Python 3 versions as well.
47 | See `setup.py` and `requirements.txt` for more details on the dependencies of the package.
48 | 
49 | ## 2. Usage Example
50 | Check out `Example.ipynb` or `Example.html` for a basic usage example of the package.
51 | 
52 | ## 3. Contacts and Issues
53 | 
54 | For further information or queries on this work, you can contact the _Explainable AI Center of Excellence at J.P. Morgan_ ([xai.coe@jpmchase.com](mailto:xai.coe@jpmchase.com)) or [Emanuele Albini](https://www.emanuelealbini.com), the main author of the paper.
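For quick reference, the snippet below sketches how the main components of the package fit together (the full walkthrough is in `Example.ipynb`). It is an illustrative sketch only: it assumes an XGBoost classifier and NumPy arrays `X_train`, `y_train` and `X_test` (these names are not part of the package), and it uses a simple prediction-based background generator; the counterfactual generators in `cfshap.counterfactuals` can be plugged in the same way.

```python
import xgboost as xgb

from cfshap.attribution import CompositeExplainer, TreeExplainer
from cfshap.background import DifferentPredictionBackgroundGenerator

# X_train, y_train, X_test: NumPy arrays with training data, labels and query instances
model = xgb.XGBClassifier().fit(X_train, y_train)

explainer = CompositeExplainer(
    # Background: training points predicted in the opposite class w.r.t. each query instance
    DifferentPredictionBackgroundGenerator(model, X_train, max_samples=1000),
    # Attribution: (interventional) TreeSHAP computed against the per-instance background
    TreeExplainer(model),
)

explanation = explainer(X_test)
explanation.values       # feature attributions, one row per query instance
explanation.backgrounds  # the background dataset used for each query instance
explanation.trends       # feature trends (NaN here, since no trend estimator is configured)
```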
55 | 56 | If you have issues using the package, feel free to open an issue on the GitHub, or contact the authors using the contacts above. We will try to address any issue as soon as possible. 57 | 58 | ## 4. Disclaimer 59 | 60 | This repository was prepared for informational purposes by the Artificial Intelligence Research group of JPMorgan Chase & Co. and its affiliates (``JP Morgan''), and is not a product of the Research Department of JP Morgan. JP Morgan makes no representation and warranty whatsoever and disclaims all liability, for the completeness, accuracy or reliability of the information contained herein. This document is not intended as investment research or investment advice, or a recommendation, offer or solicitation for the purchase or sale of any security, financial instrument, financial product or service, or to be used in any way for evaluating the merits of participating in any transaction, and shall not constitute a solicitation under any jurisdiction or to any person, if such solicitation under such jurisdiction or to such person would be unlawful. 61 | 62 | The code is provided for illustrative purposes only and is not intended to be used for trading or investment purposes. The code is provided "as is" and without warranty of any kind, express or implied, including without limitation, any warranty of merchantability or fitness for a particular purpose. In no event shall JP Morgan be liable for any direct, indirect, incidental, special or consequential damages, including, without limitation, lost revenues, lost profits, or loss of prospective economic advantage, resulting from the use of the code. 63 | -------------------------------------------------------------------------------- /assets/jpmorgan-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 8 | 10 | 12 | 15 | 20 | 23 | 26 | 32 | 37 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /assets/xai_coe-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmorganchase/cf-shap/ba58c61d4e110a950808d18dbe85b67af862549b/assets/xai_coe-logo.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | scipy==1.5.2 3 | pandas==1.1.5 4 | scikit-learn==0.24.2 5 | xgboost==1.3.3 6 | joblib==1.0.1 7 | tqdm==4.61 8 | numba==0.51.2 9 | shap==0.39.0 10 | cached_property 11 | dataclasses 12 | typing_extensions -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = cfshap 3 | version = 0.0.2 4 | author = Emanuele Albini 5 | 6 | description = Counterfactual SHAP: a framework for counterfactual feature importance 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | 10 | url = https://github.com/jpmorganchase/cf-shap 11 | project_urls = 12 | Source Code = https://github.com/jpmorganchase/cf-shap 13 | Bug Tracker = 
https://github.com/jpmorganchase/cf-shap/issues 14 | Author Website = https://www.emanuelealbini.com 15 | 16 | classifiers = 17 | Programming Language :: Python :: 3 18 | Programming Language :: Python :: 3.6 19 | Operating System :: OS Independent 20 | 21 | keywords = 22 | XAI 23 | Explainability 24 | Explainable AI 25 | Counterfactuals 26 | Algorithmic Recourse 27 | Contrastive Explanations 28 | Feature Importance 29 | Feature Attribution 30 | Machine Learning 31 | Shapley values 32 | SHAP 33 | FAccT22 34 | 35 | platform = any 36 | 37 | [options] 38 | include_package_data = True 39 | package_dir = 40 | = src 41 | packages = find: 42 | 43 | python_requires = >=3.6 44 | install_requires = 45 | numpy 46 | scipy 47 | pandas 48 | scikit-learn 49 | joblib 50 | tqdm 51 | numba>=0.51.2 52 | shap==0.39.0 53 | cached_property 54 | dataclasses 55 | typing_extensions 56 | 57 | [options.extras_require] 58 | test = pytest; xgboost; scikit-learn 59 | 60 | [options.packages.find] 61 | where = src 62 | 63 | 64 | [flake8] 65 | max-line-length = 200 66 | exclude = .git,__pycache__,docs,old,build,dist,venv 67 | # Blank lines, trailins spaces, etc. 68 | extend-ignore = W29,W39 69 | per-file-ignores = 70 | # imported but unused / unable to detect undefined names 71 | __init__.py: F401,F403 72 | # imported but unused / module level import not at top of file 73 | src/emutils/imports.py: F401, E402 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import pkg_resources 3 | 4 | # setup.cfg 5 | pkg_resources.require('setuptools>=39.2') 6 | 7 | setuptools.setup() -------------------------------------------------------------------------------- /src/cfshap/__init__.py: -------------------------------------------------------------------------------- 1 | import pkg_resources 2 | 3 | try: 4 | __version__ = pkg_resources.get_distribution(__name__).version 5 | except pkg_resources.DistributionNotFound: 6 | __version__ = "0.0.0" -------------------------------------------------------------------------------- /src/cfshap/attribution/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Feature Attribution / Importance techniques. 5 | """ 6 | 7 | from .shap import * 8 | from .composite import * 9 | -------------------------------------------------------------------------------- /src/cfshap/attribution/composite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | CF-SHAP Explainer, it uses a CF generator togheter with an explainer that supports at-runtime background distributions. 
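
A minimal usage sketch (illustrative names: `model` is a trained tree-ensemble classifier and
`X_train`, `X_queries` are 2D NumPy arrays; any background generator or counterfactual method
from this package can be used in place of the one below):

    from cfshap.attribution import CompositeExplainer, TreeExplainer
    from cfshap.background import DifferentPredictionBackgroundGenerator

    explainer = CompositeExplainer(
        DifferentPredictionBackgroundGenerator(model, X_train),
        TreeExplainer(model),
    )
    explanation = explainer(X_queries)  # attrdict with .values, .trends and .backgrounds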
5 | """ 6 | 7 | from typing import Union 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from ..base import ( 12 | BaseExplainer, 13 | ExplainerSupportsDynamicBackground, 14 | BackgroundGenerator, 15 | CounterfactualMethod, 16 | ListOf2DArrays, 17 | ) 18 | 19 | from ..background import CounterfactualMethodBackgroundGeneratorAdapter 20 | from ..utils import attrdict 21 | 22 | __all__ = [ 23 | 'CompositeExplainer', 24 | 'CFExplainer', 25 | ] 26 | 27 | 28 | class CompositeExplainer(BaseExplainer): 29 | """CompositeExplainer allows to compose: 30 | - a background generator (or a counterfactual generation method), and 31 | - an explainer (i.e., feature importance method) that suppports a background distributions 32 | into a new explainer that uses the dataset obtained with the background generator 33 | as background for the feature importance method. 34 | 35 | """ 36 | def __init__( 37 | self, 38 | background_generator: Union[CounterfactualMethod, BackgroundGenerator], 39 | explainer: ExplainerSupportsDynamicBackground, 40 | n_top: Union[None, int] = None, 41 | verbose: bool = True, 42 | ): 43 | """ 44 | 45 | Args: 46 | background_generator (Union[CounterfactualMethod, BackgroundGenerator]): A background generator. 47 | explainer (ExplainerSupportsDynamicBackground): An explainer supporting a dynamic background dataset. 48 | n_top (Union[None, int], optional): Number of top counterfactuals (only if background_generator is a counterfactual method). Defaults to None. 49 | verbose (bool, optional): If True, will be verbose. Defaults to True. 50 | 51 | """ 52 | 53 | super().__init__(background_generator.model) 54 | if isinstance(background_generator, CounterfactualMethod): 55 | background_generator = CounterfactualMethodBackgroundGeneratorAdapter(background_generator, n_top=n_top) 56 | else: 57 | if n_top is not None: 58 | raise NotImplementedError('n_top is supported only for counterfactual methods.') 59 | 60 | self.background_generator = background_generator 61 | self.explainer = explainer 62 | self.verbose = verbose 63 | 64 | def get_backgrounds( 65 | self, 66 | X: np.ndarray, 67 | background_data: Union[None, ListOf2DArrays] = None, 68 | ) -> ListOf2DArrays: 69 | """Generate the background datasets for the query instances. 70 | 71 | Args: 72 | X (np.ndarray): The query instances 73 | background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None. 74 | By default it will recompute the background for each query instance using the backgroud generator passed in the constructor. 75 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances. 
76 | 
77 |         Returns:
78 |             ListOf2DArrays : A list/array of background datasets
79 |                 nb_query_instances x nb_background_samples x nb_features
80 |         """
81 |         X = self.preprocess(X)
82 | 
83 |         # If we do not have any background then we compute it
84 |         if background_data is None:
85 |             return self.background_generator.get_backgrounds(X)
86 | 
87 |         # If we have some background we use it
88 |         else:
89 |             assert len(background_data) == X.shape[0]
90 |             return background_data
91 | 
92 |     def _get_backgrounds_iterator(self, X, background_data):
93 |         iters = zip(X, background_data)
94 |         if len(X) > 100 and self.verbose:
95 |             iters = tqdm(iters)
96 | 
97 |         return iters
98 | 
99 |     def get_attributions(
100 |         self,
101 |         X: np.array,
102 |         background_data: Union[None, ListOf2DArrays] = None,
103 |         **kwargs,
104 |     ) -> np.ndarray:
105 |         """Generate the feature attributions for the query instances.
106 | 
107 |         Args:
108 |             X (np.ndarray): The query instances
109 |             background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
110 |                 By default it will recompute the background for each query instance using the background generator passed in the constructor.
111 |                 One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
112 | 
113 |         Returns:
114 |             np.ndarray : An array of feature attributions
115 |                 nb_query_instances x nb_features
116 |         """
117 |         X = self.preprocess(X)
118 |         backgrounds = self.get_backgrounds(X, background_data)
119 | 
120 |         shapvals = []
121 |         for x, background in self._get_backgrounds_iterator(X, backgrounds):
122 |             if len(background) > 0:
123 |                 # Set background
124 |                 self.explainer.data = background
125 |                 # Compute Shapley values
126 |                 shapvals.append(self.explainer.get_attributions(x.reshape(1, -1), **kwargs)[0])
127 |             else:
128 |                 shapvals.append(np.full(X.shape[1], np.nan))
129 |         return np.array(shapvals)
130 | 
131 |     def get_trends(
132 |         self,
133 |         X: np.array,
134 |         background_data: Union[None, ListOf2DArrays] = None,
135 |     ):
136 |         """Generate the feature trends for the query instances.
137 | 
138 |         Args:
139 |             X (np.ndarray): The query instances
140 |             background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
141 |                 By default it will recompute the background for each query instance using the background generator passed in the constructor.
142 |                 One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances.
143 | 
144 |         Returns:
145 |             np.ndarray : An array of feature trends (+1 / -1)
146 |                 nb_query_instances x nb_features
147 |         """
148 |         X = self.preprocess(X)
149 |         backgrounds = self.get_backgrounds(X, background_data)
150 | 
151 |         trends = []
152 |         for x, background in self._get_backgrounds_iterator(X, backgrounds):
153 |             if len(background) > 0:
154 |                 self.explainer.data = background
155 |                 trends.append(self.explainer.get_trends(x.reshape(1, -1))[0])
156 |             else:
157 |                 trends.append(np.full(X.shape[1], np.nan))
158 |         return np.array(trends)
159 | 
160 |     def __call__(
161 |         self,
162 |         X: np.array,
163 |         background_data: Union[None, ListOf2DArrays] = None,
164 |     ):
165 |         """Generate the explanations for the query instances.
166 | 
167 |         Args:
168 |             X (np.ndarray): The query instances
169 |             background_data (Union[None, ListOf2DArrays], optional): The background datasets. Defaults to None.
170 | By default it will recompute the background for each query instance using the backgroud generator passed in the constructor. 171 | One may consider passing the background_data here to accelerate the execution if the method is called multiple times on the same query instances. 172 | 173 | Returns: 174 | attrdict : The explanations. 175 | 176 | See BaseExplainer.__call__ for more details on the output format. 177 | """ 178 | X = self.preprocess(X) 179 | backgrounds = self.get_backgrounds(X, background_data) 180 | 181 | shapvals = [] 182 | 183 | # Check if trends is implemented 184 | try: 185 | self.get_trends(X[:1]) 186 | trends = [] 187 | except NotImplementedError: 188 | trends = None 189 | 190 | for x, background in self._get_backgrounds_iterator(X, backgrounds): 191 | if len(background) > 0: 192 | self.explainer.data = background 193 | shapvals.append(self.explainer.get_attributions(x.reshape(1, -1))[0]) 194 | if trends is not None: 195 | trends.append(self.explainer.get_trends(x.reshape(1, -1))[0]) 196 | else: 197 | shapvals.append(np.full(X.shape[1], np.nan)) 198 | if trends is not None: 199 | trends.append(np.full(X.shape[1], np.nan)) 200 | return attrdict( 201 | backgrounds=backgrounds, 202 | values=np.array(shapvals), 203 | trends=np.array(trends) if trends is not None else None, 204 | ) 205 | 206 | 207 | # Alias 208 | CFExplainer = CompositeExplainer -------------------------------------------------------------------------------- /src/cfshap/attribution/shap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | SHAP Wrapper 5 | """ 6 | 7 | from abc import ABC 8 | from typing import Union 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from shap.maskers import Independent as SHAPIndependent 13 | from shap import TreeExplainer as SHAPTreeExplainer 14 | from shap.explainers import Exact as SHAPExactExplainer 15 | 16 | from ..base import (BaseExplainer, BaseSupportsDynamicBackground, Model, TrendEstimatorProtocol) 17 | from ..utils import get_shap_compatible_model 18 | 19 | __all__ = [ 20 | 'TreeExplainer', 21 | 'ExactExplainer', 22 | ] 23 | 24 | 25 | class SHAPExplainer(BaseExplainer, BaseSupportsDynamicBackground, ABC): 26 | @property 27 | def explainer(self): 28 | if self._data is None: 29 | self._raise_data_error() 30 | return self._explainer 31 | 32 | def get_backgrounds(self, X): 33 | return np.broadcast_to(self.data, [X.shape[0], *self.data.shape]) 34 | 35 | def get_trends(self, X): 36 | if self.trend_estimator is None: 37 | return np.full(X.shape, np.nan) 38 | else: 39 | return np.array([self.trend_estimator(x, self.data) for x in X]) 40 | 41 | 42 | class TreeExplainer(SHAPExplainer): 43 | """Monkey-patched and improved wrapper for (Interventional) TreeSHAP 44 | """ 45 | def __init__( 46 | self, 47 | model: Model, 48 | data: Union[None, np.ndarray] = None, 49 | max_samples: int = 1e10, 50 | class_index: int = 1, 51 | trend_estimator: Union[None, TrendEstimatorProtocol] = None, 52 | **kwargs, 53 | ): 54 | """Create the TreeSHAP explainer. 55 | 56 | Args: 57 | model (Model): The model 58 | data (Union[None, np.ndarray], optional): The background dataset. Defaults to None. 59 | If None it must be set dynamically using self.data = ... 60 | max_samples (int, optional): Maximum number of (random) samples to draw from the background dataset. Defaults to 1e10. 61 | class_index (int, optional): Class for which SHAP values will be computed. 
Defaults to 1 (positive class in a binary classification setting). 62 | trend_estimator (Union[None, TrendEstimatorProtocol], optional): The feature trend estimator. Defaults to None. 63 | **kwargs: Other arguments to be passed to shap.TreeExplainer.__init__ 64 | 65 | """ 66 | super().__init__(model, None) 67 | 68 | self.max_samples = max_samples 69 | self.class_index = class_index 70 | self.trend_estimator = trend_estimator 71 | self.kwargs = kwargs 72 | 73 | # Allow only for interventional SHAP 74 | if ('feature_perturbation' in kwargs) and (kwargs['feature_perturbation'] != 'interventional'): 75 | raise NotImplementedError() 76 | 77 | self.data = data 78 | 79 | @BaseSupportsDynamicBackground.data.setter 80 | def data(self, data): 81 | if data is not None: 82 | if not hasattr(self, '_explainer'): 83 | self._explainer = SHAPTreeExplainer( 84 | get_shap_compatible_model(self.model), 85 | data=SHAPIndependent(self.preprocess(data), max_samples=self.max_samples), 86 | **self.kwargs, 87 | ) 88 | 89 | # Copied and adapted from SHAP 0.39.0 __init__ 90 | # -- Start of copy -- 91 | self._explainer.data = self._explainer.model.data = self._data = SHAPIndependent( 92 | self.preprocess(data), max_samples=self.max_samples).data 93 | self._explainer.data_missing = self._explainer.model.data_missing = pd.isna(self._explainer.data) 94 | try: 95 | self._explainer.expected_value = self._explainer.model.predict(self._explainer.data).mean(0) 96 | except ValueError: 97 | raise Exception( 98 | "Currently TreeExplainer can only handle models with categorical splits when " 99 | "feature_perturbation=\"tree_path_dependent\" and no background data is passed. Please try again using " 100 | "shap.TreeExplainer(model, feature_perturbation=\"tree_path_dependent\").") 101 | if hasattr(self._explainer.expected_value, '__len__') and len(self._explainer.expected_value) == 1: 102 | self._explainer.expected_value = self._explainer.expected_value[0] 103 | 104 | # if our output format requires binary classification to be represented as two outputs then we do that here 105 | if self._explainer.model.model_output == "probability_doubled" and self._explainer.expected_value is not None: 106 | self._explainer.expected_value = [1 - self._explainer.expected_value, self._explainer.expected_value] 107 | # -- End of copy -- 108 | else: 109 | self._data = None 110 | 111 | def get_attributions(self, *args, **kwargs) -> np.ndarray: 112 | """Compute SHAP values 113 | Proxy for shap.TreeExplainer.shap_values(*args, **kwargs) 114 | 115 | Returns: 116 | np.ndarray: nb_samples x nb_features array with SHAP values of class self.class_index 117 | """ 118 | r = self.explainer.shap_values(*args, **kwargs) 119 | 120 | # Select Shapley values of a specific class 121 | if isinstance(r, list): 122 | return r[self.class_index] 123 | else: 124 | return r 125 | 126 | 127 | class ExactExplainer(SHAPExplainer): 128 | def __init__(self, model, data=None, max_samples=1e10, function: str = 'predict', **kwargs): 129 | """Monkey-patched and improved constructor for TreeSHAP 130 | 131 | Args: 132 | model (any): A model 133 | data (np.ndarray, optional): Reference data. Defaults to None. 134 | max_samples (int, optional): Maximum number of samples in the reference data. Defaults to 1e10. 135 | function (str): the name of the function to call on the model. 
136 |             **kwargs: Other arguments to be passed to shap.ExactExplainer.__init__
137 |         """
138 | 
139 |         super().__init__(model, None)
140 | 
141 |         self.max_samples = max_samples
142 |         self.function = function
143 |         self.kwargs = kwargs
144 | 
145 |         if ('link' in kwargs) and (kwargs['link'] != 'identity'):
146 |             raise NotImplementedError()
147 | 
148 |         self.data = data
149 | 
150 |     @BaseSupportsDynamicBackground.data.setter
151 |     def data(self, data):
152 |         if data is not None:
153 |             self._data = self.preprocess(data)
154 |             self._explainer = SHAPExactExplainer(
155 |                 lambda X: getattr(self.model, self.function)(X),
156 |                 SHAPIndependent(self._data, max_samples=self.max_samples),
157 |                 **self.kwargs,
158 |             )
159 |         else:
160 |             self._data = None
161 | 
162 |     def get_attributions(self, *args, **kwargs):
163 |         """Compute SHAP values
164 |         Proxy for shap.ExactExplainer.__call__(*args, **kwargs)
165 | 
166 |         Returns:
167 |             np.ndarray: nb_samples x nb_features array with SHAP values
168 |         """
169 |         return self.explainer(*args, **kwargs).values
--------------------------------------------------------------------------------
/src/cfshap/background/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 | 
4 | Background/Dataset Generation Techniques.
5 | """
6 | 
7 | from .counterfactual_adapter import *
8 | from .opposite import *
--------------------------------------------------------------------------------
/src/cfshap/background/counterfactual_adapter.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 | 
4 | Adapter from a counterfactual method to a background generator.
5 | """
6 | from typing import Union
7 | import numpy as np
8 | 
9 | from ..base import (BaseBackgroundGenerator, BackgroundGenerator, CounterfactualMethod, MultipleCounterfactualMethod,
10 |                     ListOf2DArrays)
11 | 
12 | from ..utils import (
13 |     get_top_counterfactuals,
14 |     expand_dims_counterfactuals,
15 | )
16 | 
17 | __all__ = ['CounterfactualMethodBackgroundGeneratorAdapter']
18 | 
19 | 
20 | class CounterfactualMethodBackgroundGeneratorAdapter(BaseBackgroundGenerator, BackgroundGenerator):
21 |     """Adapter to make a counterfactual method into a background generator"""
22 |     def __init__(
23 |         self,
24 |         counterfactual_method: CounterfactualMethod,
25 |         n_top: Union[None, int] = None,  # By default: All
26 |     ):
27 |         """
28 | 
29 |         Args:
30 |             counterfactual_method (CounterfactualMethod): The counterfactual method
31 |             n_top (Union[None, int], optional): Number of top-counterfactuals to select as background. Defaults to None (all).
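
        Example (illustrative: `cf_method` can be any object implementing the
            `MultipleCounterfactualMethod` protocol, e.g. one of the generators in
            `cfshap.counterfactuals`):

                generator = CounterfactualMethodBackgroundGeneratorAdapter(cf_method, n_top=10)
                backgrounds = generator.get_backgrounds(X_queries)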
32 | """ 33 | self.counterfactual_method = counterfactual_method 34 | self.n_top = n_top 35 | 36 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays: 37 | """Generate the background datasets for each query instance 38 | 39 | Args: 40 | X (np.ndarray): The query instances 41 | 42 | Returns: 43 | ListOf2DArrays: A list/array of background datasets 44 | nb_query_intances x nb_background_points x nb_features 45 | """ 46 | 47 | X = self.preprocess(X) 48 | 49 | # If we do not have any background then we compute it 50 | if isinstance(self.counterfactual_method, MultipleCounterfactualMethod): 51 | if self.n_top is None: 52 | return self.counterfactual_method.get_multiple_counterfactuals(X) 53 | elif self.n_top == 1: 54 | return expand_dims_counterfactuals(self.counterfactual_method.get_counterfactuals(X)) 55 | else: 56 | return get_top_counterfactuals( 57 | self.counterfactual_method.get_multiple_counterfactuals(X), 58 | X, 59 | n_top=self.n_top, 60 | nan=False, 61 | ) 62 | elif isinstance(self.counterfactual_method, CounterfactualMethod): 63 | if self.n_top is not None and self.n_top != 1: 64 | raise ValueError('Counterfactual methodology do not supportthe generation of multiple counterfactuals.') 65 | return np.expand_dims(self.counterfactual_method.get_counterfactuals(X), axis=1) 66 | else: 67 | raise NotImplementedError('Unsupported CounterfactualMethod.') -------------------------------------------------------------------------------- /src/cfshap/background/opposite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Background data computed as samples with different label or prediction. 5 | """ 6 | 7 | import numpy as np 8 | 9 | from ..base import ( 10 | Model, 11 | BaseBackgroundGenerator, 12 | ListOf2DArrays, 13 | ) 14 | 15 | __all__ = ['DifferentLabelBackgroundGenerator', 'DifferentPredictionBackgroundGenerator'] 16 | 17 | 18 | class BaseDifferent_BackgrondGenerator(BaseBackgroundGenerator): 19 | def __init__(self, model, max_samples=None, random_state=None): 20 | super().__init__(model, None, random_state) 21 | 22 | self.max_samples = max_samples 23 | 24 | @property 25 | def data(self): 26 | return self._data 27 | 28 | def _save_data(self, data): 29 | data = data.copy() 30 | data = self.sample(data, self.max_samples) 31 | data.flags.writeable = False 32 | return data 33 | 34 | def _set_data(self, X, y): 35 | self._data = X 36 | self._data_per_class = {j: self._save_data(X[y != j]) for j in np.unique(y)} 37 | 38 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays: 39 | X = self.preprocess(X) 40 | y = self.model.predict(X) 41 | return [self._data_per_class[j] for j in y] 42 | 43 | 44 | class DifferentPredictionBackgroundGenerator(BaseDifferent_BackgrondGenerator): 45 | """The background dataset will be constituted of the points with a different PREDICTION to that of the query instance. 46 | """ 47 | def __init__(self, model: Model, data: np.ndarray, **kwargs): 48 | super().__init__(model, **kwargs) 49 | 50 | # Based on the prediction 51 | data = self.preprocess(data) 52 | self._set_data(data, self.model.predict(data)) 53 | 54 | 55 | class DifferentLabelBackgroundGenerator(BaseDifferent_BackgrondGenerator): 56 | """The background dataset will be constituted of the points with a different LABEL to that of the query instance. 
57 | """ 58 | def __init__(self, model: Model, X: np.ndarray, y: np.ndarray, **kwargs): 59 | super().__init__(model, **kwargs) 60 | 61 | # Based on the label 62 | X = self.preprocess(X) 63 | self._set_data(X, y) 64 | -------------------------------------------------------------------------------- /src/cfshap/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | This module contains base classes and interfaces (protocol in Python jargon). 5 | 6 | Note: this module contains classes that are more general than needed for this package. 7 | This is to allow for future integration in a more general XAI package. 8 | 9 | Most of the interfaces, base classes and methods are self-explanatory. 10 | 11 | """ 12 | 13 | from abc import ABC, abstractmethod 14 | import warnings 15 | from typing import Union, List 16 | 17 | try: 18 | # In Python >= 3.8 this functionality is included in the standard library 19 | from typing import Protocol 20 | from typing import runtime_checkable 21 | except (ImportError, ModuleNotFoundError): 22 | # Python < 3.8 - Backward Compatibility through package 23 | from typing_extensions import Protocol 24 | from typing_extensions import runtime_checkable 25 | 26 | import numpy as np 27 | 28 | from .utils import attrdict 29 | from .utils.random import np_sample 30 | 31 | __all__ = [ 32 | 'Scaler', 33 | 'Model', 34 | 'ModelWithDecisionFunction', 35 | 'XGBWrapping', 36 | 'Explainer', 37 | 'ExplainerSupportsDynamicBackground', 38 | 'BaseExplainer', 39 | 'BaseSupportsDynamicBackground', 40 | 'BaseGroupExplainer', 41 | 'BackgroundGenerator', 42 | 'CounterfactualMethod', 43 | 'MultipleCounterfactualMethod', 44 | 'MultipleCounterfactualMethodSupportsWrapping', 45 | 'MultipleCounterfactualMethodWrappable', 46 | 'BaseCounterfactualMethod', 47 | 'BaseMultipleCounterfactualMethod', 48 | 'TrendEstimatorProtocol', 49 | 'ListOf2DArrays', 50 | 'CounterfactualEvaluationScorer', 51 | 'BaseCounterfactualEvaluationScorer', 52 | ] 53 | 54 | ListOf2DArrays = Union[List[np.ndarray], np.ndarray] 55 | 56 | # ------------------- MODELs, etc. ------------------------- 57 | 58 | 59 | @runtime_checkable 60 | class Model(Protocol): 61 | """Protocol for a ML model""" 62 | def predict(self, X: np.ndarray) -> np.ndarray: 63 | pass 64 | 65 | def predict_proba(self, X: np.ndarray) -> np.ndarray: 66 | pass 67 | 68 | 69 | @runtime_checkable 70 | class ModelWithDecisionFunction(Model, Protocol): 71 | """Protocol for a Model with a decision function as well""" 72 | def decision_function(self, X: np.ndarray) -> np.ndarray: 73 | pass 74 | 75 | 76 | @runtime_checkable 77 | class XGBWrapping(Model, Protocol): 78 | """Protocol for an XGBoost model wrapper""" 79 | def get_booster(self): 80 | pass 81 | 82 | 83 | @runtime_checkable 84 | class Scaler(Protocol): 85 | """Protocol for a Scaler""" 86 | def transform(self, X: np.ndarray) -> np.ndarray: 87 | pass 88 | 89 | def inverse_transform(self, X: np.ndarray) -> np.ndarray: 90 | pass 91 | 92 | 93 | # ------------------- Explainers, etc. ------------------------- 94 | 95 | 96 | class BaseClass(ABC): 97 | """Base class for all explainability methods""" 98 | def __init__(self, model: Model, scaler: Union[Scaler, None] = None, random_state: int = 2021): 99 | 100 | self._model = model 101 | self._scaler = scaler 102 | self.random_state = random_state 103 | 104 | # model and scaler cannot be changed at runtime. Set as properties. 
105 | @property 106 | def model(self): 107 | return self._model 108 | 109 | @property 110 | def scaler(self): 111 | return self._scaler 112 | 113 | def preprocess(self, X: np.ndarray): 114 | if not isinstance(X, np.ndarray): 115 | raise ValueError('Must pass a NumPy array.') 116 | 117 | if len(X.shape) != 2: 118 | raise ValueError("The input data must be a 2D matrix.") 119 | 120 | if X.shape[0] == 0: 121 | raise ValueError( 122 | "An empty array was passed! You must pass a non-empty array of samples in order to generate explanations." 123 | ) 124 | 125 | return X 126 | 127 | def sample(self, X: np.ndarray, n: int): 128 | if n is not None: 129 | X = np_sample(X, n, random_state=self.random_state, safe=True) 130 | 131 | return X 132 | 133 | def scale(self, X: np.ndarray): 134 | if self.scaler is None: 135 | return X 136 | else: 137 | return self.scaler.transform(X) 138 | 139 | 140 | @runtime_checkable 141 | class Explainer(Protocol): 142 | """Protocol for an Explainer (a feature attribution/importance method). 143 | 144 | Attributes: 145 | model (Model): The model for which the feature importance is computed 146 | scaler (Scaler, optional): The scaler for the data. Default to None (i.e., no scaling). 147 | 148 | Methods: 149 | get_attributions(X): Returns the feature attributions. 150 | 151 | Optional Methods: 152 | get_trends(X): Returns the feature trends. 153 | get_backgrounds(X): Returns the background datasets. 154 | 155 | To build a new explainer one can easily extend BaseExplainer. 156 | """ 157 | 158 | model: Model 159 | scaler: Union[Scaler, None] 160 | 161 | def get_attributions(self, X): 162 | pass 163 | 164 | # Optional 165 | # def get_trends(self, X): 166 | # pass 167 | 168 | # def get_backgrounds(self, X): 169 | # pass 170 | 171 | 172 | @runtime_checkable 173 | class SupportsDynamicBackground(Protocol): 174 | """Additional Protocol for a class that supports at-runtime change of the background data.""" 175 | @property 176 | def data(self): 177 | pass 178 | 179 | @data.setter 180 | def data(self, data): 181 | pass 182 | 183 | 184 | @runtime_checkable 185 | class ExplainerSupportsDynamicBackground(Explainer, SupportsDynamicBackground, Protocol): 186 | """Protocol for an Explainer that supports at-runtime change of the background data""" 187 | pass 188 | 189 | 190 | class BaseExplainer(BaseClass, ABC): 191 | """Base class for a feature attribution/importance method""" 192 | @abstractmethod 193 | def get_attributions(self, X: np.ndarray) -> np.ndarray: 194 | """Generate the feature attributions for query instances X""" 195 | pass 196 | 197 | def get_trends(self, X: np.ndarray) -> np.ndarray: 198 | """Generate the feature trends for query instances X""" 199 | raise NotImplementedError('trends method is not implemented!') 200 | 201 | def get_backgrounds(self, X: np.ndarray) -> np.ndarray: 202 | """Returns the background datasets for query instances X""" 203 | raise NotImplementedError('get_backgrounds method is not implemented!') 204 | 205 | def __call__(self, X: np.ndarray) -> attrdict: 206 | """Returns the explanations 207 | 208 | Args: 209 | X (np.ndarray): The query instances 210 | 211 | Returns: 212 | attrdict: An attrdict (i.e., a dict which fields can be accessed also through attributes) with the following attributes: 213 | - .values : the feature attributions 214 | - .backgrounds : the background datasets (if any) 215 | - .trends : the feature trends (if any) 216 | """ 217 | X = self.preprocess(X) 218 | return attrdict( 219 | values=self.get_attributions(X), 220 | 
backgrounds=self.get_backgrounds(X), 221 | trends=self.get_trends(X), 222 | ) 223 | 224 | # Alias for 'get_attributions' for backward-compatibility 225 | def shap_values(self, *args, **kwargs): 226 | return self.get_attributions(*args, **kwargs) 227 | 228 | 229 | class BaseSupportsDynamicBackground(ABC): 230 | """Base class for a class that supports at-runtime change of the background data.""" 231 | @property 232 | def data(self): 233 | if self._data is None: 234 | self._raise_data_error() 235 | return self._data 236 | 237 | def _raise_data_error(self): 238 | raise ValueError('Must set background data first.') 239 | 240 | @data.setter 241 | @abstractmethod 242 | def data(self, data): 243 | pass 244 | 245 | 246 | class BaseGroupExplainer: 247 | """Base class for an explainer (feature attribution) for groups of features.""" 248 | def preprocess_groups(self, feature_groups: List[List[int]], nb_features): 249 | 250 | features_in_groups = sum(feature_groups, []) 251 | nb_groups = len(feature_groups) 252 | 253 | if nb_groups > nb_features: 254 | raise ValueError('There are more groups than features.') 255 | 256 | if len(set(features_in_groups)) != len(features_in_groups): 257 | raise ValueError('Some features are in multiple groups!') 258 | 259 | if len(set(features_in_groups)) < nb_features: 260 | raise ValueError('Not all the features are in groups') 261 | 262 | if any([len(x) == 0 for x in feature_groups]): 263 | raise ValueError('Some feature groups are empty!') 264 | 265 | return feature_groups 266 | 267 | 268 | # ------------------------------------- BACKGROUND GENERATOR -------------------------------------- 269 | 270 | 271 | @runtime_checkable 272 | class BackgroundGenerator(Protocol): 273 | """Protocol for a Background Generator: can be used together with an explainer to dynamicly generate backgrounds for each instance (see `composite`)""" 274 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays: 275 | """Returns the background datasets for the query instances. 276 | 277 | Args: 278 | X (np.ndarray): The query instances. 279 | 280 | Returns: 281 | ListOf2DArrays: The background datasets. 282 | """ 283 | pass 284 | 285 | 286 | class BaseBackgroundGenerator(BaseClass, ABC, BackgroundGenerator): 287 | """Base class for a background generator.""" 288 | @abstractmethod 289 | def get_backgrounds(self, X: np.ndarray) -> ListOf2DArrays: 290 | pass 291 | 292 | 293 | # ------------------------------------- TREND ESTIMATOR -------------------------------------- 294 | 295 | 296 | @runtime_checkable 297 | class TrendEstimatorProtocol(Protocol): 298 | """Protocol for a feature Trend Estimator""" 299 | def predict(self, X: np.ndarray, YY: ListOf2DArrays) -> np.ndarray: 300 | pass 301 | 302 | def __call__(self, x: np.ndarray, Y: np.ndarray) -> np.ndarray: 303 | pass 304 | 305 | 306 | # ------------------- Counterfactuals, etc. 
------------------------- 307 | 308 | 309 | @runtime_checkable 310 | class CounterfactualMethod(Protocol): 311 | """Protocol for a counterfactual generation method (that generate a single counterfactual per query instance).""" 312 | model: Model 313 | 314 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray: 315 | pass 316 | 317 | 318 | @runtime_checkable 319 | class MultipleCounterfactualMethod(CounterfactualMethod, Protocol): 320 | """Protocol for a counterfactual generation method (that generate a single OR MULTIPLE counterfactuals per query instance).""" 321 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays: 322 | pass 323 | 324 | 325 | class BaseCounterfactualMethod(BaseClass, ABC, CounterfactualMethod): 326 | """Base class for a counterfactual generation method (that generate a single counterfactual per query instance).""" 327 | def __init__(self, *args, **kwargs): 328 | super().__init__(*args, **kwargs) 329 | 330 | self.invalid_counterfactual = 'raise' 331 | 332 | def _invalid_response(self, invalid: Union[None, str]) -> str: 333 | invalid = invalid or self.invalid_counterfactual 334 | assert invalid in ('nan', 'raise', 'ignore') 335 | return invalid 336 | 337 | def postprocess( 338 | self, 339 | X: np.ndarray, 340 | XC: np.ndarray, 341 | invalid: Union[None, str] = None, 342 | ) -> np.ndarray: 343 | """Post-process counterfactuals 344 | 345 | Args: 346 | X (np.ndarray : nb_samples x nb_features): The query instances 347 | XC (np.ndarray : nb_samples x nb_features): The counterfactuals 348 | invalid (Union[None, str], optional): It can have the following values. Defaults to None ('raise'). 349 | - 'nan': invalid counterfactuals (non changing prediction) will be marked with NaN 350 | - 'raise': an error will be raised if invalid counterfactuals are passed 351 | - 'ignore': Nothing will be node. Invalid counterfactuals will be returned. 352 | 353 | Returns: 354 | np.ndarray: The post-processed counterfactuals 355 | """ 356 | 357 | invalid = self._invalid_response(invalid) 358 | 359 | # Mask with the non-flipped counterfactuals 360 | not_flipped_mask = (self.model.predict(X) == self.model.predict(XC)) 361 | if not_flipped_mask.sum() > 0: 362 | if invalid == 'raise': 363 | self._raise_invalid() 364 | elif invalid == 'nan': 365 | self._warn_invalid() 366 | XC[not_flipped_mask, :] = np.nan 367 | 368 | return XC 369 | 370 | def _warn_invalid(self): 371 | warnings.warn('!! 
ERROR: Some counterfactuals are NOT VALID (will be set to NaN)')
372 | 
373 |     def _raise_invalid(self):
374 |         raise RuntimeError('Invalid counterfactuals')
375 | 
376 |     def _raise_nan(self):
377 |         raise RuntimeError('NaN counterfactuals are generated before post-processing.')
378 | 
379 |     def _raise_inf(self):
380 |         raise RuntimeError('+/-inf counterfactuals are generated before post-processing.')
381 | 
382 |     @abstractmethod
383 |     def get_counterfactuals(self, X: np.ndarray) -> np.ndarray:
384 |         pass
385 | 
386 | 
387 | class BaseMultipleCounterfactualMethod(BaseCounterfactualMethod):
388 |     """Base class for a counterfactual generation method (that generates one OR MULTIPLE counterfactuals per query instance)."""
389 |     def multiple_postprocess(
390 |         self,
391 |         X: np.ndarray,
392 |         XX_C: ListOf2DArrays,
393 |         invalid: Union[None, str] = None,
394 |         allow_nan: bool = True,
395 |         allow_inf: bool = False,
396 |     ) -> ListOf2DArrays:
397 |         """Post-process multiple counterfactuals
398 | 
399 |         Args:
400 |             X (np.ndarray : nb_samples x nb_features): The query instances
401 |             XX_C (ListOf2DArrays : nb_samples x nb_counterfactuals x nb_features): The counterfactuals
402 |             invalid (Union[None, str], optional): It can have the following values. Defaults to None ('raise').
403 |                 - 'nan': invalid counterfactuals (non-changing prediction) will be marked with NaN
404 |                 - 'raise': an error will be raised if invalid counterfactuals are passed
405 |                 - 'ignore': Nothing will be done. Invalid counterfactuals will be returned.
406 |             allow_nan (bool, optional): If True, allows NaN counterfactuals as input (invalid). If False, it raises an error. Defaults to True.
407 |             allow_inf (bool, optional): If True, allows infinite values in counterfactuals. If False, it raises an error. Defaults to False.
408 | 409 | Returns: 410 | ListOf2DArrays : The post-processed counterfactuals 411 | """ 412 | 413 | invalid = self._invalid_response(invalid) 414 | 415 | # Reshape (for zero-length arrays) 416 | XX_C = [X_C.reshape(-1, X.shape[1]) for X_C in XX_C] 417 | 418 | # Check for NaN and Inf 419 | for XC in XX_C: 420 | if not allow_nan and np.isnan(XC).sum() != 0: 421 | self._raise_nan() 422 | if not allow_inf and np.isinf(XC).sum() != 0: 423 | self._raise_inf() 424 | 425 | # Mask with the non-flipped counterfactuals 426 | nb_counters = np.array([X_c.shape[0] for X_c in XX_C]) 427 | not_flipped_mask = np.equal( 428 | np.repeat(self.model.predict(X), nb_counters), 429 | self.model.predict(np.concatenate(XX_C, axis=0)), 430 | ) 431 | if not_flipped_mask.sum() > 0: 432 | if invalid == 'raise': 433 | print('X, f(X) :', X, self.model.predict(X)) 434 | print('X_C, f(X_C) :', XX_C, self.model.predict(np.concatenate(XX_C, axis=0))) 435 | self._raise_invalid() 436 | elif invalid == 'nan': 437 | self._warn_invalid() 438 | sections = np.cumsum(nb_counters[:-1]) 439 | not_flipped_mask = np.split(not_flipped_mask, indices_or_sections=sections) 440 | 441 | # Set them to nan 442 | for i, nfm in enumerate(not_flipped_mask): 443 | XX_C[i][nfm, :] = np.nan 444 | 445 | return XX_C 446 | 447 | def multiple_trace_postprocess(self, X, XTX_counter, invalid=None): 448 | invalid = self._invalid_response(invalid) 449 | 450 | # Reshape (for zero-length arrays) 451 | XTX_counter = [[X_C.reshape(-1, X.shape[1]) for X_C in TX_C] for TX_C in XTX_counter] 452 | 453 | # Mask with the non-flipped counterfactuals 454 | shapess = [[X_C.shape[0] for X_C in TX_C] for TX_C in XTX_counter] 455 | shapes = [sum(S) for S in shapess] 456 | 457 | X_counter = np.concatenate([np.concatenate(TX_C, axis=0) for TX_C in XTX_counter], axis=0) 458 | not_flipped_mask = np.equal( 459 | np.repeat(self.model.predict(X), shapes), 460 | self.model.predict(X_counter), 461 | ) 462 | if not_flipped_mask.sum() > 0: 463 | if invalid == 'raise': 464 | self._raise_invalid() 465 | elif invalid == 'nan': 466 | self._warn_invalid() 467 | sections = np.cumsum(shapes[:-1]) 468 | sectionss = [np.cumsum(s[:-1]) for s in shapess] 469 | not_flipped_mask = np.split(not_flipped_mask, indices_or_sections=sections) 470 | not_flipped_mask = [np.split(NFM, indices_or_sections=s) for NFM, s in zip(not_flipped_mask, sectionss)] 471 | 472 | # Set them to nan 473 | for i, NFM in enumerate(not_flipped_mask): 474 | for j, nfm in enumerate(NFM): 475 | X_counter[i][j][nfm, :] = np.nan 476 | 477 | return XTX_counter 478 | 479 | @abstractmethod 480 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays: 481 | pass 482 | 483 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray: 484 | return np.array([X_C[0] for X_C in self.get_multiple_counterfactuals(X)]) 485 | 486 | # Alias backward compatibility 487 | def diverse_postprocess(self, *args, **kwargs): 488 | return self.multiple_postprocess(*args, **kwargs) 489 | 490 | def diverse_trace_postprocess(self, *args, **kwargs): 491 | return self.multiple_trace_postprocess(*args, **kwargs) 492 | 493 | 494 | @runtime_checkable 495 | class Wrappable(Protocol): 496 | verbose: Union[int, bool] 497 | 498 | 499 | @runtime_checkable 500 | class SupportsWrapping(Protocol): 501 | @property 502 | def data(self): 503 | pass 504 | 505 | @data.setter 506 | @abstractmethod 507 | def data(self, data): 508 | pass 509 | 510 | 511 | @runtime_checkable 512 | class MultipleCounterfactualMethodSupportsWrapping(MultipleCounterfactualMethod, 
SupportsWrapping, Protocol): 513 | """Protocol for a counterfactual method that can be wrapped by another one 514 | (i.e., the output of a SupportsWrapping method can be used as background data of another)""" 515 | pass 516 | 517 | 518 | @runtime_checkable 519 | class MultipleCounterfactualMethodWrappable(MultipleCounterfactualMethod, Wrappable, Protocol): 520 | """Protocol for a counterfactual method that can used as wrapping for another one 521 | (i.e., a Wrappable method can use the ouput of an another CFX method as input)""" 522 | pass 523 | 524 | 525 | # ------------------------ EVALUATION ----------------------- 526 | 527 | 528 | @runtime_checkable 529 | class CounterfactualEvaluationScorer(Protocol): 530 | """Protocol for an evaluation method that returns an array of scores (float) for a list of counterfactuals.""" 531 | def score(self, X: np.ndarray) -> np.ndarray: 532 | pass 533 | 534 | 535 | class BaseCounterfactualEvaluationScorer(ABC): 536 | @abstractmethod 537 | def score(self, X: np.ndarray) -> np.ndarray: 538 | pass 539 | -------------------------------------------------------------------------------- /src/cfshap/counterfactuals/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Counterfactuals Generation (Algorithmic Recourse) Techniques. 5 | """ 6 | 7 | from .knn import * 8 | from .project import * 9 | from .composition import * -------------------------------------------------------------------------------- /src/cfshap/counterfactuals/composition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Counterfactual method that composes two children methods. 5 | One as a wrapper and one as a wrapped method. 6 | The wrapper method will use the counterfactuals generated by the wrapped method as background. 
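
Usage sketch (illustrative only: `model`, `X_train` and `X_query` are placeholders, and the two methods below are just one possible choice; any pair of methods implementing the wrapping/wrapped protocols can be composed):

    from cfshap.counterfactuals import (
        KNNCounterfactuals,
        BisectionProjectionDBCounterfactuals,
        compose_diverse_counterfactual_method,
    )

    knn = KNNCounterfactuals(model, scaler=None, X=X_train,
                             nb_diverse_counterfactuals=100,
                             distance='cityblock')
    proj = BisectionProjectionDBCounterfactuals(model, data=X_train)

    # The last argument is the innermost (wrapped) method: its counterfactuals
    # become the background data of the wrapping (projection) method.
    method = compose_diverse_counterfactual_method(proj, knn)
    XX_C = method.get_multiple_counterfactuals(X_query)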
7 | """ 8 | 9 | __all__ = [ 10 | 'CounterfactualComposition', 11 | 'compose_diverse_counterfactual_method', 12 | ] 13 | 14 | from copy import deepcopy 15 | import numpy as np 16 | from tqdm import tqdm 17 | 18 | from ..base import ( 19 | BaseMultipleCounterfactualMethod, 20 | MultipleCounterfactualMethodSupportsWrapping, 21 | MultipleCounterfactualMethodWrappable, 22 | ) 23 | 24 | class CounterfactualComposition(BaseMultipleCounterfactualMethod): 25 | def __init__( 26 | self, 27 | wrapping_instance: MultipleCounterfactualMethodSupportsWrapping, 28 | wrapped_instance: MultipleCounterfactualMethodWrappable, 29 | verbose=True, 30 | ): 31 | 32 | if wrapping_instance.model != wrapped_instance.model: 33 | raise ValueError('Models of wrapping and wrapped method differs.') 34 | 35 | super().__init__(wrapping_instance.model, None) 36 | 37 | if not isinstance(wrapped_instance, MultipleCounterfactualMethodWrappable): 38 | raise ValueError('wrapped_instance do not implement necessary interface.') 39 | 40 | if not isinstance(wrapping_instance, MultipleCounterfactualMethodSupportsWrapping): 41 | raise ValueError('wrapping_instance do not implement necessary interface.') 42 | 43 | # Scalers are ignored (they are already into the wrapped/wrapping instances) 44 | super().__init__(wrapping_instance.model, scaler=None) 45 | 46 | wrapping_instance = deepcopy(wrapping_instance) 47 | wrapped_instance = deepcopy(wrapped_instance) 48 | 49 | wrapped_instance.verbose = 0 50 | 51 | self.wrapping_instance = wrapping_instance 52 | self.wrapped_instance = wrapped_instance 53 | 54 | self.data = None 55 | self.verbose = verbose 56 | 57 | def get_multiple_counterfactuals(self, X): 58 | X = self.preprocess(X) 59 | 60 | # Get underlying counterfactuals 61 | X_Cs = self.wrapped_instance.get_multiple_counterfactuals(X) 62 | 63 | iters = enumerate(zip(X, X_Cs)) 64 | if len(X) > 100 and self.verbose: 65 | iters = tqdm(iters, desc='Composition CF: Step 2') 66 | 67 | # Postprocess them 68 | for i, (x, X_C) in iters: 69 | # Set background 70 | self.data = self.wrapping_instance.data = X_C 71 | 72 | # Compute counterfactuals 73 | X_Cs[i] = self.wrapping_instance.get_multiple_counterfactuals(np.array([x]))[0] 74 | 75 | return X_Cs 76 | 77 | 78 | def compose_diverse_counterfactual_method(*args, **kwargs): 79 | if len(args) < 2: 80 | raise ValueError('At least 2 counterfactual methods methods must be passed.') 81 | 82 | # Most inner 83 | composition = args[-1] 84 | 85 | # Iteratively compose 86 | for wrapping_instance in args[-2::-1]: 87 | composition = CounterfactualComposition(wrapping_instance, composition, **kwargs) 88 | 89 | return composition 90 | -------------------------------------------------------------------------------- /src/cfshap/counterfactuals/knn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Implementation of K-Nearest Neighbours Counterfactuals 5 | """ 6 | 7 | __all__ = ['KNNCounterfactuals'] 8 | 9 | from typing import Union 10 | 11 | import numpy as np 12 | from sklearn.neighbors import NearestNeighbors 13 | 14 | from ..base import ( 15 | BaseMultipleCounterfactualMethod, 16 | Model, 17 | Scaler, 18 | ListOf2DArrays, 19 | ) 20 | from ..utils import keydefaultdict 21 | 22 | class KNNCounterfactuals(BaseMultipleCounterfactualMethod): 23 | """Returns the K Nearest Neighbours of the query instance with a different prediction. 
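
    A brief usage sketch (`model`, `scaler`, `X_train` and `X_query` are placeholders; see the constructor below for the full set of parameters):

        cf = KNNCounterfactuals(model, scaler, X_train,
                                n_neighbors=10, distance='euclidean')
        XX_C = cf.get_multiple_counterfactuals(X_query)   # up to 10 neighbours of a different class per instance
        X_C = cf.get_counterfactuals(X_query)             # only the single closest one per instance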
24 | 25 | """ 26 | def __init__( 27 | self, 28 | model: Model, 29 | scaler: Union[None, Scaler], 30 | X: np.ndarray, 31 | nb_diverse_counterfactuals: Union[None, int, float] = None, 32 | n_neighbors: Union[None, int, float] = None, 33 | distance: str = None, 34 | max_samples: int = int(1e10), 35 | random_state: int = 2021, 36 | verbose: int = 0, 37 | **distance_params, 38 | ): 39 | """ 40 | 41 | Args: 42 | model (Model): The model. 43 | scaler (Union[None, Scaler]): The scaler for the data. 44 | X (np.ndarray): The background dataset. 45 | nb_diverse_counterfactuals (Union[None, int, float], optional): Number of counterfactuals to generate. Defaults to None. 46 | n_neighbors (Union[None, int, float], optional): Number of neighbours to generate. Defaults to None. 47 | Note that this is an alias for nb_diverse_counterfactuals in this class. 48 | distance (str, optional): The distance metric to use for K-NN. Defaults to None. 49 | max_samples (int, optional): Number of samples of the background to use at most. Defaults to int(1e10). 50 | random_state (int, optional): Random seed. Defaults to 2021. 51 | verbose (int, optional): Level of verbosity. Defaults to 0. 52 | **distance_params: Additional parameters for the distance metric 53 | """ 54 | 55 | assert nb_diverse_counterfactuals is not None or n_neighbors is not None, 'nb_diverse_counterfactuals or n_neighbors must be set.' 56 | 57 | super().__init__(model, scaler, random_state) 58 | 59 | self._metric, self._metric_params = distance, distance_params 60 | self.__nb_diverse_counterfactuals = nb_diverse_counterfactuals 61 | self.__n_neighbors = n_neighbors 62 | self.max_samples = max_samples 63 | 64 | self.data = X 65 | self.verbose = verbose 66 | 67 | @property 68 | def data(self): 69 | return self._data 70 | 71 | @data.setter 72 | def data(self, data): 73 | self._raw_data = self.preprocess(data) 74 | if self.max_samples < len(self._raw_data): 75 | self._raw_data = self.sample(self._raw_data, n=self.max_samples) 76 | self._preds = self.model.predict(self._raw_data) 77 | 78 | # In the base class this two information are equivalent 79 | if self.__n_neighbors is None: 80 | self.__n_neighbors = self.__nb_diverse_counterfactuals 81 | if self.__nb_diverse_counterfactuals is None: 82 | self.__nb_diverse_counterfactuals = self.__n_neighbors 83 | 84 | def get_nb_of_items(nb): 85 | if np.isinf(nb): 86 | return keydefaultdict(lambda pred: self._data[pred].shape[0]) 87 | elif isinstance(nb, int) and nb >= 1: 88 | return keydefaultdict(lambda pred: min(nb, self._data[pred].shape[0])) 89 | elif isinstance(nb, float) and nb <= 1.0 and nb > 0.0: 90 | return keydefaultdict(lambda pred: int(max(1, round(len(self._data[pred]) * nb)))) 91 | else: 92 | raise ValueError( 93 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)' 94 | ) 95 | 96 | self._n_neighbors = get_nb_of_items(self.__n_neighbors) 97 | self._nb_diverse_counterfactuals = get_nb_of_items(self.__nb_diverse_counterfactuals) 98 | 99 | # We will be searching neighbors of a different class 100 | self._data = keydefaultdict(lambda pred: self._raw_data[self._preds != pred]) 101 | 102 | self._nn = keydefaultdict(lambda pred: NearestNeighbors( 103 | n_neighbors=self._n_neighbors[pred], 104 | metric=self._metric, 105 | p=self._metric_params['p'] if 'p' in self._metric_params else 2, 106 | metric_params=self._metric_params, 107 | ).fit(self.scale(self._data[pred]))) 108 | 109 | def get_counterfactuals(self, X: np.ndarray) -> np.ndarray: 110 | """Generate the 
closest counterfactual for each query instance""" 111 | 112 | # Pre-process 113 | X = self.preprocess(X) 114 | 115 | preds = self.model.predict(X) 116 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)} 117 | 118 | X_counter = np.zeros_like(X) 119 | 120 | for pred, indices in preds_indices.items(): 121 | _, neighbors_indices = self._nn[pred].kneighbors(self.scale(X), n_neighbors=1) 122 | X_counter[indices] = self._data[pred][neighbors_indices.flatten()] 123 | 124 | # Post-process 125 | X_counter = self.postprocess(X, X_counter, invalid='raise') 126 | 127 | return X_counter 128 | 129 | def get_multiple_counterfactuals(self, X: np.ndarray) -> ListOf2DArrays: 130 | """Generate the multiple closest counterfactuals for each query instance""" 131 | 132 | # Pre-condition 133 | assert self.__n_neighbors == self.__nb_diverse_counterfactuals, ( 134 | 'n_neighbors and nb_diverse_counterfactuals are set to different values' 135 | f'({self.__n_neighbors} != {self.__nb_diverse_counterfactuals}).' 136 | 'When both are set they must be set to the same value.') 137 | 138 | # Pre-process 139 | X = self.preprocess(X) 140 | 141 | preds = self.model.predict(X) 142 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)} 143 | 144 | X_counter = [ 145 | np.full((self._nb_diverse_counterfactuals[preds[i]], X.shape[1]), np.nan) for i in range(X.shape[0]) 146 | ] 147 | 148 | for pred, indices in preds_indices.items(): 149 | _, neighbors_indices = self._nn[pred].kneighbors(self.scale(X[indices]), n_neighbors=None) 150 | counters = self._data[pred][neighbors_indices.flatten()].reshape(len(indices), 151 | self._nb_diverse_counterfactuals[pred], -1) 152 | for e, i in enumerate(indices): 153 | # We use :counters[e].shape[0] so it raises an exception if shape are not coherent. 154 | X_counter[i][:counters[e].shape[0]] = counters[e] 155 | 156 | # Post-process 157 | X_counter = self.diverse_postprocess(X, X_counter, invalid='raise') 158 | 159 | return X_counter -------------------------------------------------------------------------------- /src/cfshap/counterfactuals/project.py: -------------------------------------------------------------------------------- 1 | """ 2 | Counterfactuals as projection onto the decision boundary (estimated by bisection). 
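
The idea, sketched below for intuition only (`bisect_to_boundary` is a hypothetical, simplified function; the actual implementation is `cfshap.utils.project.find_decision_boundary_bisection`, which also handles scaling, early stopping and batches of reference points):

    import numpy as np

    def bisect_to_boundary(x, y, model, num=100):
        # Assumes x and y are 1-D arrays for which model.predict gives different classes.
        fx = model.predict(np.array([x]))[0]
        lo, hi = x, y                      # f(lo) == f(x), f(hi) != f(x)
        for _ in range(num):
            mid = (lo + hi) / 2
            if model.predict(np.array([mid]))[0] == fx:
                lo = mid                   # still on the query side of the boundary
            else:
                hi = mid                   # crossed the boundary: tighten from above
        return hi                          # a counterfactual just across the decision boundary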
3 | """ 4 | 5 | __author__ = 'Emanuele Albini' 6 | __all__ = ['BisectionProjectionDBCounterfactuals'] 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from ..base import BaseMultipleCounterfactualMethod 12 | from ..utils import keydefaultdict 13 | from ..utils.project import find_decision_boundary_bisection 14 | 15 | 16 | class BisectionProjectionDBCounterfactuals(BaseMultipleCounterfactualMethod): 17 | def __init__(self, 18 | model, 19 | data, 20 | num_iters=100, 21 | earlystop_error=1e-8, 22 | scaler=None, 23 | max_samples=1e10, 24 | nb_diverse_counterfactuals=None, 25 | random_state=0, 26 | verbose=1, 27 | **kwargs): 28 | 29 | if scaler is not None: 30 | raise NotImplementedError('Scaling is not supported by Cone.') 31 | 32 | super().__init__(model, scaler, random_state=random_state) 33 | 34 | self.num_iters = num_iters 35 | self.earlystop_error = earlystop_error 36 | self.kwargs = kwargs 37 | 38 | self._max_samples = min(max_samples or np.inf, nb_diverse_counterfactuals or np.inf) 39 | 40 | self.data = data 41 | self.verbose = verbose 42 | 43 | @property 44 | def data(self): 45 | return self._data 46 | 47 | @data.setter 48 | def data(self, data): 49 | if data is not None: 50 | self._raw_data = self.preprocess(data) 51 | self._preds = self.model.predict(self._raw_data) 52 | 53 | self._data = keydefaultdict( 54 | lambda pred: self.sample(self._raw_data[self._preds != pred], n=self._max_samples)) 55 | else: 56 | self._raw_data = None 57 | self._preds = None 58 | self._data = None 59 | 60 | def __get_counterfactuals(self, x, pred): 61 | if self.data is not None: 62 | return np.array( 63 | find_decision_boundary_bisection( 64 | x=x, 65 | Y=self.data[pred], 66 | model=self.model, 67 | scaler=self.scaler, 68 | num=self.num_iters, 69 | error=self.earlystop_error, 70 | method='counterfactual', 71 | desc=None, 72 | # model_parallelism=1, 73 | # n_jobs=1, 74 | **self.kwargs, 75 | )) 76 | else: 77 | raise ValueError('Invalid state. `self.data` must be set.') 78 | 79 | def get_multiple_counterfactuals(self, X): 80 | X = self.preprocess(X) 81 | preds = self.model.predict(X) 82 | X_counter = [] 83 | iters = tqdm(zip(X, preds)) if self.verbose else zip(X, preds) 84 | for x, pred in iters: 85 | X_counter.append(self.__get_counterfactuals(x, pred)) 86 | 87 | return self.diverse_postprocess(X, X_counter) 88 | -------------------------------------------------------------------------------- /src/cfshap/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Evaluation methods for explanations. 5 | """ 6 | -------------------------------------------------------------------------------- /src/cfshap/evaluation/attribution/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Evaluation methods for feature attributions. 
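
Usage sketch (`phi` below is a toy attribution matrix, nb_samples x nb_features):

    import numpy as np
    from cfshap.evaluation.attribution import feature_attributions_statistics

    phi = np.array([[0.5, -0.2, 0.0],
                    [0.1,  0.3, 0.2]])
    per_sample = feature_attributions_statistics(phi)           # pd.DataFrame, one row per sample
    averages = feature_attributions_statistics(phi, mean=True)  # pd.Series of averaged statistics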
5 | """ 6 | 7 | from ._basic import * -------------------------------------------------------------------------------- /src/cfshap/evaluation/attribution/_basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Basic evaluation utilities for feature attributions 5 | """ 6 | 7 | from typing import Union 8 | import numpy as np 9 | import pandas as pd 10 | 11 | __all__ = [ 12 | 'feature_attributions_statistics', 13 | ] 14 | 15 | 16 | def feature_attributions_statistics(phi: np.ndarray, mean=False) -> Union[pd.DataFrame, pd.Series]: 17 | """Calculate some statistics on feature attributions 18 | 19 | Args: 20 | phi (np.ndarray: nb_samples x nb_features): The feature attributions 21 | mean (bool, optional): Calculate the mean. Defaults to True. 22 | 23 | Returns: 24 | The statistics 25 | pd.DataFrame: By default 26 | pd.Series: If mean is True. 27 | """ 28 | stats = pd.DataFrame() 29 | stats['Non-computable Phi (NaN)'] = np.any(np.isnan(phi), axis=1) 30 | stats['Non-Positive Phi'] = np.all(phi < 0, axis=1) 31 | stats['Non-Negative Phi'] = np.all(phi > 0, axis=1) 32 | stats['All Zeros Phi'] = np.all(phi == 0, axis=1) 33 | stats['Number of Feature with Positive Phi'] = np.sum(phi > 0, axis=1) 34 | stats['Number of Feature with Negative Phi'] = np.sum(phi < 0, axis=1) 35 | stats['Number of Feature with Zero Phi'] = np.sum(phi == 0, axis=1) 36 | 37 | if mean: 38 | # Rename columns accordingly 39 | def __rename(name): 40 | if name.startswith('Number'): 41 | return 'Average ' + name 42 | else: 43 | return 'Percentage of ' + name 44 | 45 | stats = stats.rename(columns={name: __rename(name) for name in stats.columns.values}) 46 | 47 | # Mean over all the attributions in the dataset 48 | stats = stats.mean() 49 | return stats -------------------------------------------------------------------------------- /src/cfshap/evaluation/attribution/induced_counterfactual/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Method to evaluate the actionability of a feature attribution technique. 5 | 6 | This is the implementation of the method officialy proposed in our paper. 7 | Therein called counterfactual-ability. 
8 | 9 | Counterfactual Shapely Additive Explanations 10 | https://arxiv.org/pdf/2110.14270.pdf 11 | """ 12 | 13 | from ._actionability import * -------------------------------------------------------------------------------- /src/cfshap/evaluation/attribution/induced_counterfactual/_actionability.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from ....utils import attrdict 8 | from ....utils.preprocessing import MultiScaler 9 | 10 | from ....base import Model 11 | from ....utils.tree import TreeEnsemble, get_shap_compatible_model 12 | from ._utils import ( 13 | transform_trees, 14 | mask_values, 15 | inverse_actionabiltiy, 16 | ) 17 | 18 | __all__ = ['TreeInducedCounterfactualGeneratorV2'] 19 | 20 | 21 | class TreeInducedCounterfactualGeneratorV2: 22 | def __init__( 23 | self, 24 | model: Model, 25 | data: np.ndarray = None, 26 | multiscaler: np.ndarray = None, 27 | global_feature_trends: Union[None, list, np.ndarray] = None, 28 | random_state: int = 0, 29 | ): 30 | assert data is not None or multiscaler is not None 31 | 32 | self.model = model 33 | self.data = data 34 | self.multiscaler = multiscaler 35 | self.global_feature_trends = global_feature_trends 36 | self.random_state = random_state 37 | 38 | @property 39 | def multiscaler(self): 40 | return self._multiscaler 41 | 42 | @multiscaler.setter 43 | def multiscaler(self, multiscaler): 44 | self._multiscaler = multiscaler 45 | 46 | @property 47 | def data(self): 48 | raise NotImplementedError() 49 | 50 | @data.setter 51 | def data(self, data): 52 | if data is not None: 53 | self._multiscaler = MultiScaler(self._data) 54 | 55 | @property 56 | def model(self): 57 | return self._model 58 | 59 | @property 60 | def shap_model(self): 61 | return self._shap_model 62 | 63 | @model.setter 64 | def model(self, model): 65 | self._model = model 66 | self._shap_model = get_shap_compatible_model(self.model) 67 | 68 | def transform( 69 | self, 70 | X: np.ndarray, 71 | explanations: attrdict, 72 | K: Tuple[int, float] = (1, 2, 3, 4, 5, np.inf), 73 | action_cost_normalization: str = 'quantile', 74 | action_cost_aggregation: str = 'L1', # L0, L1, L2, Linf 75 | action_strategy: str = 'proportional', # 'equal', 'proportional', 'random' 76 | action_scope: str = 'positive', # 'positive', 'negative', 'all' 77 | action_direction: str = 'local', # 'local', 'global', 'random' 78 | precision: float = 1e-4, 79 | show_progress: bool = True, 80 | starting_sample: int = 0, 81 | max_samples: Union[int, None] = None, 82 | nan_explanation: str = 'raise', # 'ignore' 83 | counters: np.ndarray = None, 84 | costs: np.ndarray = None, 85 | desc: str = "", 86 | ): 87 | 88 | # Must pass trends if global trends are used 89 | assert self.global_feature_trends is not None or action_direction != 'global' 90 | # Must pass local trends if local trends are used 91 | assert hasattr(explanations, 'trends') or action_direction != 'local' 92 | 93 | assert (counters is None and costs is None) or (counters is not None and costs is not None) 94 | 95 | assert action_cost_aggregation[0] == 'L' 96 | action_cost_aggregation = float(action_cost_aggregation[1:]) 97 | 98 | random_state = np.random.RandomState(self.random_state) 99 | 100 | # Compute stuff from the results 101 | Xn = self.multiscaler.transform(X, method=action_cost_normalization) 102 | Pred = self.model.predict(X) 103 | Phi = explanations.values 104 | 105 | # Let's raise some 
warning in case of NaN 106 | nb_PhiNaN = np.any(np.isnan(Phi), axis=1).sum() 107 | if nb_PhiNaN > 0: 108 | logging.warning(f'There are {nb_PhiNaN} NaN explanations.') 109 | if nb_PhiNaN == len(X): 110 | logging.error('All explanations are NaN.') 111 | 112 | # Pre-compute directions of change 113 | if action_direction == 'local': 114 | Directions = explanations.trends 115 | elif action_direction == 'global': 116 | Directions = (-1 * (np.broadcast_to(Pred.reshape(-1, 1), X.shape) - .5) * 2 * 117 | np.broadcast_to(np.array(self.global_feature_trends), X.shape)).astype(int) 118 | elif action_direction == 'random': 119 | Directions = random_state.choice((-1, 1), X.shape) 120 | else: 121 | raise ValueError('Invalid direction_strategy.') 122 | 123 | # If we don't have a direction we set Phi to 0 124 | Phi = Phi * (Directions != 0) 125 | 126 | logging.info(f"There are {np.any(Directions == 0, axis = 1).sum()}/{len(X)} samples with some tau_i = 0") 127 | 128 | assert np.all((Directions == -1) | (Directions == +1) | (np.isnan(Phi)) 129 | | ((Directions == 0) & (Phi == 0))), "Some trends are not +1 or -1" 130 | 131 | # Compute the order of feature-change 132 | # NOTE: This must be done before the action_strategy computation below! 133 | if action_scope == 'all': 134 | Order = np.argsort(-1 * np.abs(Phi), axis=1) 135 | elif action_scope == 'positive': 136 | Order = np.argsort(-1 * Phi, axis=1) 137 | Phi = Phi * (Phi > 0) 138 | elif action_scope == 'negative': 139 | Order = np.argsort(+1 * Phi, axis=1) 140 | Phi = Phi * (Phi < 0) 141 | else: 142 | raise ValueError('Invalid action_scope.') 143 | 144 | assert np.all((Phi >= 0) | (np.isnan(Phi))), "Something weird happend: Phi should be >= 0 at this point" 145 | 146 | # Modify Phi based on the action strategy 147 | if action_strategy == 'proportional': 148 | pass 149 | elif action_strategy == 'equal': 150 | Phi = 1 * (Phi > 0) 151 | elif action_strategy == 'random': 152 | Phi = random_state.random(Phi.shape) * (Phi != 0) * (~np.isnan(Phi)) 153 | elif action_strategy == 'noisy': 154 | Phi = Phi * random_state.random(Phi.shape) 155 | else: 156 | raise ValueError('Invalid action_strategy.') 157 | 158 | if len(np.unique(Pred)) != 1: 159 | raise RuntimeError('Samples have different predicted class! 
This is weird: stopping.') 160 | 161 | assert np.all((Phi >= 0) | (np.isnan(Phi))), "Something weird happend: Phi should still be >= 0 at this point" 162 | 163 | # Do some shape checks 164 | assert X.shape == Xn.shape 165 | assert X.shape == Phi.shape 166 | assert X.shape == Order.shape 167 | assert X.shape == Directions.shape 168 | assert (X.shape[0], ) == Pred.shape 169 | assert np.all(self.model.predict(X) == Pred) 170 | assert np.all(Pred % Pred.astype(np.int) == 0) 171 | 172 | # Compute trees (and cost-normalized version) 173 | trees = TreeEnsemble(self.shap_model).trees 174 | ntrees = transform_trees(trees, self.multiscaler, action_cost_normalization) 175 | 176 | # Run the experiments from scratch 177 | if counters is None: 178 | counters = [] 179 | costs = [] 180 | 181 | # Resume the experiments 182 | else: 183 | assert counters.shape[0] == costs.shape[0] == (min(X[starting_sample:].shape[0], max_samples or np.inf)) 184 | assert counters.shape[1] == costs.shape[1] 185 | assert counters.shape[1] >= len(K) 186 | assert counters.shape[2] == X.shape[1] 187 | 188 | counters = counters.copy() 189 | costs = costs.copy() 190 | 191 | counters = counters.transpose((1, 0, 2)) # Top-K x Samples x nb_features 192 | costs = costs.transpose((1, 0)) # Top-K x Samples 193 | recompute_mask = np.isinf(costs) 194 | 195 | for kidx, k in enumerate(K): 196 | # Mask only top-k features 197 | PhiMasked = mask_values(Phi, Order, k) 198 | 199 | # Create object to be iterated 200 | iters = np.array(list(zip(X, Xn, PhiMasked, Directions, Pred))[starting_sample:], dtype=np.object) 201 | if max_samples is not None: 202 | iters = iters[:max_samples] 203 | if not isinstance(counters, list): 204 | iters = iters[recompute_mask[kidx]] 205 | 206 | if show_progress: 207 | iters = tqdm( 208 | iters, 209 | desc=(f'{desc} - (A)St={action_strategy}/D={action_direction}/Sc={action_scope}/K={k} ' 210 | '(C)N={action_cost_normalization}/Agg={action_cost_aggregation}'), 211 | ) 212 | 213 | # Iterate over all samples 214 | results = [] 215 | 216 | for a, args in enumerate(iters): 217 | # print(a) # TQDM may not be precise 218 | results.append( 219 | inverse_actionabiltiy( 220 | *args, 221 | model=self.model, 222 | multiscaler=self.multiscaler, 223 | normalization=action_cost_normalization, 224 | trees=trees, 225 | ntrees=ntrees, 226 | degree=action_cost_aggregation, 227 | precision=precision, 228 | error=nan_explanation, 229 | )) 230 | 231 | counters_ = [r[0] for r in results] 232 | costs_ = [r[1] for r in results] 233 | if isinstance(counters, list): 234 | counters.append(counters_) 235 | costs.append(costs_) 236 | else: 237 | if recompute_mask[kidx].sum() > 0: 238 | counters[kidx][recompute_mask[kidx]] = np.array(counters_) 239 | costs[kidx][recompute_mask[kidx]] = np.array(costs_) 240 | 241 | # Counterfactuals 242 | if isinstance(counters, list): 243 | counters = np.array(counters) # Top-K x Samples x nb_features 244 | counters = counters.transpose((1, 0, 2)) # Samples x Top-K x nb_features 245 | 246 | # Costs 247 | if isinstance(costs, list): 248 | costs = np.array(costs) # Top-K x Samples 249 | costs = costs.transpose((1, 0)) # Samples x Top-K 250 | 251 | return counters, costs -------------------------------------------------------------------------------- /src/cfshap/evaluation/attribution/induced_counterfactual/_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import warnings 4 | from copy import deepcopy 5 | 6 | import numba 7 | import 
numpy as np 8 | 9 | __all__ = ['transform_trees', 'mask_values', 'inverse_actionabiltiy'] 10 | 11 | EPSILON = np.finfo(float).eps 12 | 13 | 14 | def mask_values(Phi, Order, k): 15 | if np.isinf(k): 16 | k = Phi.shape[1] 17 | 18 | # Select the subset 19 | SubsetCols = Order[:, :k] 20 | SubsetRows = np.tile(np.arange(Phi.shape[0]), (SubsetCols.shape[1], 1)).T 21 | 22 | # Create a mask 23 | OrderMask = np.zeros(Phi.shape, dtype=np.int) 24 | OrderMask[SubsetRows, SubsetCols] = 1 25 | 26 | # Apply the mask 27 | PhiMasked = Phi * OrderMask 28 | 29 | return PhiMasked 30 | 31 | 32 | def transform_trees(trees, multiscaler, method): 33 | ntrees = deepcopy(trees) 34 | for i in range(len(ntrees)): 35 | for j in range(len(ntrees[i].features)): 36 | if ntrees[i].children_left[j] != -1: # A split (not a leaf) 37 | ntrees[i].thresholds[j] = multiscaler.value_transform(ntrees[i].thresholds[j], ntrees[i].features[j], 38 | method) 39 | return ntrees 40 | 41 | 42 | def cost_trees(trees, ntrees, xn): 43 | ctrees = deepcopy(trees) 44 | for i in range(len(ctrees)): 45 | ctrees[i].thresholds_c = np.zeros(ctrees[i].thresholds.shape) 46 | for j in range(len(ctrees[i].thresholds_c)): 47 | if ctrees[i].children_left[j] != -1: 48 | ctrees[i].thresholds_c[j] = ntrees[i].thresholds[j] - xn[ctrees[i].features[j]] 49 | return ctrees 50 | 51 | 52 | @numba.jit(nopython=True, nogil=True) 53 | def filter_region(region, x, phi, tau, f, mc, Mc, m, M): 54 | # Get contraint 55 | t = tau[f] 56 | v = phi[f] 57 | 58 | if v == 0.0: 59 | # Must be 0-cost 60 | if (Mc < 0.0) or (mc > 0.0): 61 | return None 62 | else: 63 | # Must be trend/sign compatible (direction of the change to the feature) 64 | if (t > 0 and Mc <= 0.0) or (t < 0 and mc >= 0.0): 65 | return None 66 | 67 | # This is a super-rare error caused by np.float32 <> np.float64 conflicts 68 | # TreeEnsemble returns np.float32, but the dataset could be np.float64 69 | # The multiscaler then passed with a np.float32 (64) spits a np.float32 (64), it does not change type 70 | # This results in a wrong mc/Mc = thr_n - xn, i.e., sgn(thr_n - xn) != sgn(thr - x) 71 | if (t > 0 and M <= x[f]) or (t < 0 and m >= x[f]): 72 | return None 73 | 74 | if t > 0: 75 | mc = np.maximum(0.0, mc) 76 | m = np.maximum(x[f], m) 77 | if t < 0: 78 | Mc = np.minimum(0.0, Mc) 79 | M = np.minimum(x[f], M) 80 | 81 | assert mc < Mc and m < M 82 | 83 | # Precompute direction constraints 84 | if mc == 0.0 or Mc == 0.0: 85 | # We need this because inf * 0 = Nan (raise error) 86 | vm = 0.0 87 | elif mc * Mc > 0.0: 88 | vm = np.minimum(np.abs(mc), np.abs(Mc)) 89 | else: 90 | vm = 0.0 91 | vM = np.maximum(np.abs(mc), np.abs(Mc)) 92 | vm, vM = vm / v, vM / v 93 | 94 | for f_ in range(region.shape[0]): 95 | mc_, Mc_, m_, M_ = region[f_][0], region[f_][1], region[f_][2], region[f_][3] 96 | # Must be overlapping 97 | if f == f_: 98 | if M <= m_ or m >= M_: 99 | return None 100 | 101 | # Must be direction-compatible (proportion between features) 102 | v_ = phi[f_] 103 | if v != 0.0 and v_ != 0.0: 104 | if mc_ == 0.0 or Mc_ == 0.0: 105 | # We need this because inf * 0 = Nan (raise error) 106 | vm_ = 0.0 107 | elif mc_ * Mc_ > 0.0: 108 | vm_ = np.minimum(np.abs(mc_), np.abs(Mc_)) 109 | else: 110 | vm_ = 0.0 111 | vM_ = np.maximum(np.abs(mc_), np.abs(Mc_)) 112 | vm_, vM_ = vm_ / v_, vM_ / v_ 113 | if (vM <= vm_ or vm >= vM_): 114 | return None 115 | 116 | nregion = region.copy() 117 | nregion[f][0] = np.maximum(region[f][0], mc) 118 | nregion[f][1] = np.minimum(region[f][1], Mc) 119 | nregion[f][2] = 
np.maximum(region[f][2], m) 120 | nregion[f][3] = np.minimum(region[f][3], M) 121 | return nregion 122 | 123 | 124 | @numba.jit(nopython=True, nogil=True) 125 | def filter_children_regions(tree_features, tree_children_left, tree_children_right, tree_thresholds, tree_thresholds_c, 126 | region, x, phi, tau, n): 127 | 128 | # Non-compatible: return empty array 129 | if region is None: 130 | return np.zeros((0, x.shape[0], 4)) 131 | 132 | # Recurse 133 | if tree_children_left[n] != -1: 134 | f = tree_features[n] 135 | l = tree_children_left[n] # noqa: E741 136 | r = tree_children_right[n] 137 | t = tree_thresholds[n] 138 | tc = tree_thresholds_c[n] 139 | 140 | # print('Region', region) 141 | 142 | regionl = filter_region(region, x, phi, tau, f, -np.inf, tc, -np.inf, t) 143 | regionr = filter_region(region, x, phi, tau, f, tc, np.inf, t, np.inf) 144 | 145 | # print('Region L', regionl) 146 | # print('Region R', regionr) 147 | 148 | return np.concatenate((filter_children_regions(tree_features, tree_children_left, tree_children_right, 149 | tree_thresholds, tree_thresholds_c, regionl, x, phi, tau, l), 150 | filter_children_regions(tree_features, tree_children_left, tree_children_right, 151 | tree_thresholds, tree_thresholds_c, regionr, x, phi, tau, r))) 152 | 153 | # Recurse termination 154 | else: 155 | return np.expand_dims(region, 0) 156 | 157 | 158 | @numba.jit(nopython=True, nogil=True) 159 | def filter_tree_regions(tree_features, tree_children_left, tree_children_right, tree_thresholds, tree_thresholds_c, 160 | regions, x, phi, tau): 161 | cregions = np.zeros((0, x.shape[0], 4)) 162 | for region in regions: 163 | filtered_children = filter_children_regions(tree_features, tree_children_left, tree_children_right, 164 | tree_thresholds, tree_thresholds_c, region, x, phi, tau, 0) 165 | cregions = np.concatenate((cregions, filtered_children)) 166 | return cregions 167 | 168 | 169 | def filter_regions(trees, x, phi, tau): 170 | regions = np.tile(np.array([-np.inf, np.inf, -np.inf, np.inf]), (1, x.shape[0], 1)) 171 | for t, tree in enumerate(trees): 172 | regions = filter_tree_regions(tree.features, tree.children_left, tree.children_right, tree.thresholds, 173 | tree.thresholds_c, regions, x, phi, tau) 174 | # There must be at least a region 175 | assert len(regions) > 0 176 | regions = [{i: tuple(v) for i, v in enumerate(region)} for region in regions] 177 | return np.array(regions) 178 | 179 | 180 | def assert_regions(regions): 181 | for region in regions: 182 | for f, (mc, Mc, m, M) in region.items(): 183 | assert mc < Mc and m < M 184 | return regions 185 | 186 | 187 | def cost_in_region(region, cost): 188 | for f, (mc, Mc, _, _) in region.items(): 189 | if cost[f] > Mc or cost[f] < mc: 190 | return False 191 | return True 192 | 193 | 194 | def point_in_region(region, point): 195 | for f, (_, _, m, M) in region.items(): 196 | if point[f] > M or point[f] < m: 197 | return False 198 | return True 199 | 200 | 201 | def find_region_candidates(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass): 202 | X_costs = [] 203 | 204 | # print('>>>>>>>>>>>> REGION', region, '<<<<<<<<<<<<<<<<<<<<') 205 | 206 | # Compute candidates in the region 207 | for f, (mc, Mc, m, M) in region.items(): 208 | # print(f, mc, Mc, phi[f], end=' / ') 209 | if phi[f] == 0.0: # => f is not a constraining variable 210 | # print('NC phi=0.0') 211 | continue 212 | if Mc < 0: # => decrease f 213 | delta_star = Mc - precision 214 | if delta_star < mc: 215 | delta_star = (mc + Mc) / 2 216 | elif mc > 0: # => 
increase f 217 | delta_star = mc + precision 218 | if delta_star > Mc: 219 | delta_star = (mc + Mc) / 2 220 | elif mc <= 0 and Mc >= 0: # keep f the same (not constraining) 221 | # print('NC mc <= 0 and Mc >= 0') 222 | continue 223 | else: 224 | raise AssertionError('This should never happen.') 225 | 226 | # print(delta_star) 227 | # print('Dstar=', delta_star, 'Tau=', tau, '/ phi=', phi) 228 | X_costs.append(tau * phi / phi[f] * np.abs(delta_star)) 229 | 230 | # Means that the regions is the current region (i.e., the region where x lies) 231 | if len(X_costs) == 0: 232 | return None, None 233 | 234 | # print('costs prev:', X_costs) 235 | 236 | # Filter points outside the region (in terms of cost) 237 | X_costs = np.array([cost for cost in X_costs if cost_in_region(region, cost)]) 238 | 239 | # print('costs post 1:', X_costs) 240 | 241 | if len(X_costs) > 1: 242 | logging.debug( 243 | f'More than one direction-compatible point in the region (cost-space) with precision {precision}.') 244 | 245 | if len(X_costs) == 0: 246 | logging.debug( 247 | f'Less than one direction-compatible point in the region (cost-space) with precision {precision}.') 248 | return X_costs, X_costs # Return empty (both X_p and X_costs) 249 | 250 | # Filter points outside the region (on the input space) 251 | X_p = multiscaler.inverse_transform(np.tile(xn, (X_costs.shape[0], 1)) + X_costs, method=normalization) 252 | X_p = np.where(X_costs == 0.0, np.tile(x, (X_costs.shape[0], 1)), X_p) 253 | 254 | if bypass is True: 255 | return X_p, X_costs 256 | 257 | 258 | # print('X_p post 1:', X_p) 259 | 260 | mask = np.array([point_in_region(region, x) for x in X_p]) 261 | X_p = X_p[mask] 262 | X_costs = X_costs[mask] 263 | 264 | # print('costs post 2:', X_costs) 265 | # print('X_p post 2:', X_p) 266 | 267 | if len(X_costs) > 1: 268 | logging.debug(f'More than one direction-compatible point in the region (x-space) with precision {precision}.') 269 | 270 | if len(X_costs) == 0: 271 | logging.debug(f'Less than one direction-compatible point in the region (x-space) with precision {precision}.') 272 | 273 | return X_p, X_costs 274 | 275 | 276 | def _find_point_in_region_with_min_cost(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass): 277 | X_p = np.zeros((0, x.shape[0])) 278 | X_costs = np.zeros((0, x.shape[0])) 279 | 280 | # There are 0 or more than 1 candidates with this precision 281 | while len(X_costs) != 1 and precision > EPSILON: 282 | # print(f'Precision {precision}') 283 | X_p_, X_costs_ = find_region_candidates(region, x, xn, phi, tau, multiscaler, normalization, precision, bypass) 284 | precision = precision / 2 285 | 286 | # Means that the regions is the current regions 287 | if X_costs_ is None: 288 | return np.array([x]), np.zeros((1, x.shape[0])) 289 | 290 | # More than one with the same cost and its fine 291 | if len(X_costs_) == 0 and len(X_costs) > 1: 292 | break 293 | 294 | X_costs = X_costs_ 295 | X_p = X_p_ 296 | return X_p, X_costs 297 | 298 | 299 | def find_point_in_region_with_min_cost(region, x, xn, phi, tau, degree, multiscaler, normalization, precision): 300 | X_p, X_costs = _find_point_in_region_with_min_cost(region, 301 | x, 302 | xn, 303 | phi, 304 | tau, 305 | multiscaler, 306 | normalization, 307 | precision, 308 | bypass=False) 309 | 310 | if len(X_costs) == 0: 311 | logging.warning( 312 | f'Less than one direction-compatible point in the region with precision {precision}/{EPSILON}: point will be (slightly) out of the region.' 
313 | ) 314 | X_p, X_costs = _find_point_in_region_with_min_cost(region, 315 | x, 316 | xn, 317 | phi, 318 | tau, 319 | multiscaler, 320 | normalization, 321 | precision, 322 | bypass=True) 323 | 324 | if len(X_costs) == 0: 325 | logging.error( 326 | f'No solution. Less than one direction-compatible point in the region with precision {precision}/{EPSILON}.' 327 | ) 328 | raise RuntimeError() 329 | 330 | # Compute costs 331 | costs = np.linalg.norm(np.abs(np.array(X_costs)), degree, axis=1) 332 | index_min = np.argmin(costs) 333 | # costs_min = X_costs[index_min] 334 | cost_min = costs[index_min] 335 | x_min = X_p[index_min] 336 | # x_min = multiscaler.single_inverse_transform(xn + costs_min, method=normalization) 337 | # x_min = np.where(cost_min == 0.0, x, x_min) 338 | 339 | # print('x/c', x_min, cost_min) 340 | 341 | # quantile_overflow handling: will return none when cost exceed 1.0 or less than 0.0 342 | if np.any(np.isnan(x_min)): 343 | return None, np.inf 344 | 345 | if not point_in_region(region, x_min): 346 | logging.warning(f"Point not in region: x = {x_min} / Region = {region}") 347 | 348 | return x_min, cost_min 349 | 350 | 351 | def find_points_in_regions(regions, x, xn, phi, tau, degree, multiscaler, normalization, precision): 352 | X_prime, dist_prime = [], [] 353 | for region in regions: 354 | x_min, cost_min = find_point_in_region_with_min_cost(region, x, xn, phi, tau, degree, multiscaler, 355 | normalization, precision) 356 | if x_min is not None: 357 | X_prime.append(x_min) 358 | dist_prime.append(cost_min) 359 | return np.array(X_prime), np.array(dist_prime) 360 | 361 | 362 | def inverse_actionabiltiy( 363 | x, 364 | xn, 365 | phi, 366 | tau, 367 | y_pred, 368 | model, 369 | multiscaler, 370 | trees, 371 | ntrees, 372 | degree, 373 | normalization, 374 | precision, 375 | error, 376 | ): 377 | 378 | x_counter, cost_counter = np.full(x.shape, np.nan), np.inf 379 | 380 | if np.any(np.isnan(phi)): 381 | if error == 'raise': 382 | raise RuntimeError('NaN feature attribution.') 383 | elif error == 'warning': 384 | warnings.warn('NaN feature attribution.') 385 | elif error == 'ignore': 386 | pass 387 | else: 388 | raise ValueError('Invalid error argument.') 389 | return x_counter, cost_counter 390 | 391 | # Compute costs in trees 392 | ctrees = cost_trees(trees, ntrees, xn) 393 | 394 | # Find regions (of the input space) that are phi-compatible 395 | regions = filter_regions(ctrees, x, phi=phi, tau=tau) 396 | assert_regions(regions) 397 | 398 | # Fint the regions counterfactual and cost 399 | X_prime, costs_prime = find_points_in_regions( 400 | regions, 401 | x, 402 | xn, 403 | phi, 404 | tau, 405 | degree=degree, 406 | multiscaler=multiscaler, 407 | normalization=normalization, 408 | precision=precision, 409 | ) 410 | 411 | # Order the regions based on the cost 412 | costs_sort = np.argsort(costs_prime) 413 | X_prime, costs_prime = X_prime[costs_sort], costs_prime[costs_sort] 414 | 415 | # Generate the prediction points corresponding to regions of the input 416 | start = time.perf_counter() 417 | y_prime = model.predict(X_prime) 418 | logging.debug(f'Prediction of X_prime.shape = {X_prime.shape} took {time.perf_counter()-start}') 419 | 420 | # Find first occurence of change in output (with min cost) 421 | for xp, yp, cp in zip(X_prime, y_prime, costs_prime): 422 | if yp != y_pred: 423 | x_counter = xp 424 | cost_counter = cp 425 | break 426 | 427 | return x_counter, cost_counter 428 | -------------------------------------------------------------------------------- 
/src/cfshap/evaluation/counterfactuals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmorganchase/cf-shap/ba58c61d4e110a950808d18dbe85b67af862549b/src/cfshap/evaluation/counterfactuals/__init__.py -------------------------------------------------------------------------------- /src/cfshap/evaluation/counterfactuals/plausibility.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | from ...base import CounterfactualEvaluationScorer, BaseCounterfactualEvaluationScorer 7 | from ...utils import keydefaultdict 8 | from ...utils.random import np_sample 9 | 10 | 11 | class BaseNNDistance(CounterfactualEvaluationScorer, BaseCounterfactualEvaluationScorer): 12 | def __init__( 13 | self, 14 | scaler, 15 | X, 16 | distance, 17 | n_neighbors: Union[int, float] = 5, 18 | max_samples=int(1e10), 19 | random_state=2021, 20 | **distance_params, 21 | ): 22 | 23 | self._scaler = scaler 24 | self._metric, self._metric_params = distance, distance_params 25 | self._n_neighbors = n_neighbors 26 | self._max_samples = max_samples 27 | self.random_state = random_state 28 | self.data = X 29 | 30 | @property 31 | def data(self): 32 | return self._data 33 | 34 | 35 | class NNDistance(BaseNNDistance): 36 | """ 37 | Plausibility metrics for counterfactuals. 38 | 39 | It computes the (average) distance from the K Nearest Neighbours of the point 40 | 41 | Note it considers all the points (regardless of their class as neighbours). 42 | """ 43 | def __init__( 44 | self, 45 | *args, 46 | **kwargs, 47 | ): 48 | super().__init__(*args, **kwargs) 49 | 50 | @BaseNNDistance.data.setter 51 | def data(self, data): 52 | 53 | # Sample 54 | self._data = np_sample(np.asarray(data), random_state=self.random_state, n=self._max_samples, safe=True) 55 | 56 | if isinstance(self._n_neighbors, int) and self._n_neighbors >= 1: 57 | n_neighbors = min(self._n_neighbors, len(self._data)) 58 | elif isinstance(self._n_neighbors, float) and self._n_neighbors <= 1.0 and self._n_neighbors > 0.0: 59 | n_neighbors = int(max(1, round(len(self._data) * self._n_neighbors))) 60 | else: 61 | raise ValueError( 62 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)') 63 | 64 | # We will be searching neighbors 65 | self._nn = NearestNeighbors( 66 | n_neighbors=n_neighbors, 67 | metric=self._metric, 68 | p=self._metric_params['p'] if 'p' in self._metric_params else 2, 69 | metric_params=self._metric_params, 70 | ).fit(self._scaler.transform(self._data)) 71 | 72 | def score(self, X: np.ndarray): 73 | X = np.asarray(X) 74 | 75 | avg_dist = np.full(X.shape[0], np.nan) 76 | nan_mask = np.any(np.isnan(X), axis=1) 77 | 78 | if (~nan_mask).sum() > 0: 79 | neigh_dist, _ = self._nn.kneighbors(self._scaler.transform(X[~nan_mask]), n_neighbors=None) 80 | neigh_dist = neigh_dist.mean(axis=1) 81 | avg_dist[~nan_mask] = neigh_dist 82 | 83 | return avg_dist 84 | 85 | 86 | class yNNDistance(BaseNNDistance): 87 | """ 88 | Plausibility metrics for counterfactuals. 89 | 90 | It computes the (average) distance from the K Nearest COUNTERFACTUAL Neighbours of the point 91 | 92 | Contrary to NNDistance, it considers as neighbours only points that have a different class. 
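
    A usage sketch (`model`, `scaler`, `X_train` and `X_C` -- the counterfactuals being evaluated -- are placeholders):

        from cfshap.evaluation.counterfactuals.plausibility import NNDistance, yNNDistance

        scorer = yNNDistance(model, scaler, X_train, distance='euclidean', n_neighbors=5)
        scores = scorer.score(X_C)   # average distance to the nearest neighbours; lower = more plausible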
93 | """ 94 | def __init__(self, model, *args, **kwargs): 95 | self.model = model 96 | super().__init__(*args, **kwargs) 97 | 98 | @BaseNNDistance.data.setter 99 | def data(self, data): 100 | self._raw_data = np.asarray(data) 101 | 102 | # Sample 103 | self._raw_data = np_sample(self._raw_data, random_state=self.random_state, n=self._max_samples, safe=True) 104 | 105 | # Predict 106 | self._preds = self.model.predict(self._raw_data) 107 | 108 | if isinstance(self._n_neighbors, int) and self._n_neighbors >= 1: 109 | n_neighbors = keydefaultdict(lambda pred: min(self._n_neighbors, len(self._data[pred]))) 110 | elif isinstance(self._n_neighbors, float) and self._n_neighbors <= 1.0 and self._n_neighbors > 0.0: 111 | n_neighbors = keydefaultdict(lambda pred: int(max(1, round(len(self._data[pred]) * self._n_neighbors)))) 112 | else: 113 | raise ValueError( 114 | 'Invalid n_neighbors, it must be the number of neighbors (int) or the fraction of the dataset (float)') 115 | 116 | # We will be searching neighbors of a different class 117 | self._data = keydefaultdict(lambda pred: self._raw_data[self._preds == pred]) 118 | 119 | self._nn = keydefaultdict(lambda pred: NearestNeighbors( 120 | n_neighbors=n_neighbors[pred], 121 | metric=self._metric, 122 | p=self._metric_params['p'] if 'p' in self._metric_params else 2, 123 | metric_params=self._metric_params, 124 | ).fit(self._scaler.transform(self._data[pred]))) 125 | 126 | def score(self, X: np.ndarray): 127 | X = np.asarray(X) 128 | 129 | nan_mask = np.any(np.isnan(X), axis=1) 130 | 131 | preds = self.model.predict(X[~nan_mask]) 132 | preds_indices = {pred: np.argwhere(preds == pred).flatten() for pred in np.unique(preds)} 133 | avg_dist_ = np.full(preds.shape[0], np.nan) 134 | 135 | for pred, indices in preds_indices.items(): 136 | neigh_dist, _ = self._nn[pred].kneighbors(self._scaler.transform(X[~nan_mask][indices]), n_neighbors=None) 137 | neigh_dist = neigh_dist.mean(axis=1) 138 | avg_dist_[preds_indices[pred]] = neigh_dist 139 | 140 | avg_dist = np.full(X.shape[0], np.nan) 141 | avg_dist[~nan_mask] = avg_dist_ 142 | return avg_dist 143 | -------------------------------------------------------------------------------- /src/cfshap/trend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Feature Trend Estimators 5 | """ 6 | 7 | import numpy as np 8 | from sklearn.base import BaseEstimator 9 | 10 | __all__ = ['TrendEstimator', 'DummyTrendEstimator'] 11 | 12 | 13 | class DummyTrendEstimator(BaseEstimator): 14 | def __init__(self, trends): 15 | self.trends = trends 16 | 17 | def predict(self, x=None, Y=None): 18 | return np.asarray(self.trends) 19 | 20 | 21 | class TrendEstimator(BaseEstimator): 22 | def __init__(self, strategy='mean'): 23 | self.strategy = strategy 24 | 25 | @staticmethod 26 | def __step_function(x, y): 27 | return 2 * np.heaviside(y - x, 0) - 1 28 | 29 | def __preprocess(self, x, Y): 30 | if not isinstance(x, np.ndarray): 31 | raise ValueError('Must pass a NumPy array.') 32 | 33 | if not isinstance(Y, np.ndarray): 34 | raise ValueError('Must pass a NumPy array.') 35 | 36 | if len(Y.shape) != 2: 37 | raise ValueError('Y must be a 2D matrix.') 38 | 39 | if len(x.shape) != 1: 40 | x = x.flatten() 41 | 42 | if x.shape[0] != Y.shape[1]: 43 | raise ValueError('x and Y sizes must be coherent.') 44 | 45 | return x, Y 46 | 47 | def __preprocess3D(self, X, YY): 48 | if not isinstance(X, np.ndarray): 49 | raise ValueError('Must pass a NumPy array.') 50 | 51 | if not 
isinstance(YY[0], np.ndarray): 52 | raise ValueError('Must pass a list of NumPy array.') 53 | 54 | if len(X.shape) != 2: 55 | raise ValueError('X must be a 2D matrix.') 56 | 57 | if len(X) != len(YY): 58 | raise ValueError('X and Y must have the same length.') 59 | 60 | if X.shape[1] != YY[0].shape[1]: 61 | raise ValueError('X and Y[0] sizes must be coherent.') 62 | 63 | return X, YY 64 | 65 | def __call__(self, x, Y): 66 | x, Y = self.__preprocess(x, Y) 67 | 68 | if self.strategy == 'mean': 69 | return self.__step_function(x, Y.mean(axis=0)) 70 | else: 71 | raise ValueError('Invalid strategy.') 72 | 73 | # NOTE: Duplicated for efficiency on 3D data 74 | def predict(self, X, YY): 75 | X, YY = self.__preprocess3D(X, YY) 76 | 77 | if self.strategy == 'mean': 78 | return np.array([self.__step_function(x, Y.mean(axis=0)) for x, Y in zip(X, YY)]) 79 | else: 80 | raise ValueError('Invalid strategy.') 81 | -------------------------------------------------------------------------------- /src/cfshap/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._python import * 2 | from ._utils import * 3 | -------------------------------------------------------------------------------- /src/cfshap/utils/_python.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | This module contains general utilities for Python programming language. 5 | - Extension and improvement of Python data structures (e.g., attrdict) 6 | - Extension of Python language (e.g., static variables for functions, function caching) 7 | - Implementations of design patterns in Python (e.g., Singleton) 8 | """ 9 | 10 | from abc import ABCMeta 11 | from collections import defaultdict 12 | from collections.abc import Mapping 13 | from pprint import pformat 14 | import json 15 | 16 | __all__ = [ 17 | 'static_vars', 18 | 'keydefaultdict', 19 | 'ignorenonedict', 20 | 'attrdict', 21 | 'function_cache', 22 | 'Singleton', 23 | ] 24 | 25 | 26 | def static_vars(**kwargs): 27 | """Decorator to create static variables for a function 28 | 29 | Usage: 30 | ```python 31 | @static_vars(i = 0) 32 | def function(x): 33 | i += 1 34 | return x + i 35 | 36 | function(0) # 1 37 | function(0) # 2 38 | function(10) # 13 39 | ``` 40 | """ 41 | def decorate(func): 42 | for k in kwargs: 43 | setattr(func, k, kwargs[k]) 44 | return func 45 | 46 | return decorate 47 | 48 | 49 | class keydefaultdict(defaultdict): 50 | """ 51 | Extension of defaultdict that support 52 | passing the key to the default_factory 53 | """ 54 | def __missing__(self, key): 55 | if self.default_factory is None: 56 | raise KeyError(key, "Must pass a default factory with a single argument.") 57 | else: 58 | ret = self[key] = self.default_factory(key) 59 | return ret 60 | 61 | 62 | class ignorenonedict(dict): 63 | def __init__(self, other=None, **kwargs): 64 | super().__init__() 65 | self.update(other, **kwargs) 66 | 67 | def __setitem__(self, key, value): 68 | if value is not None: 69 | super().__setitem__(key, value) 70 | 71 | def update(self, other=None, **kwargs): 72 | if other is not None: 73 | for k, v in other.items() if isinstance(other, Mapping) else other: 74 | self.__setitem__(k, v) 75 | for k, v in kwargs.items(): 76 | self.__setitem__(k, v) 77 | 78 | 79 | class attrdict(dict): 80 | """ 81 | Attributes-dict bounded structure for paramenters 82 | -> When a dictionary key is set the corresponding attribute is set 83 | -> When an attribute is set the 
corresponding dictionary key is set 84 | 85 | Usage: 86 | 87 | # Create the object 88 | args = AttrDict() 89 | 90 | args.a = 1 91 | print(args.a) # 1 92 | print(args['a']) # 1 93 | 94 | args['b'] = 2 95 | print(args.b) # 2 96 | print(args['b']) # 2 97 | 98 | """ 99 | def __init__(self, *args, **kwargs): 100 | super(attrdict, self).__init__(*args, **kwargs) 101 | self.__dict__ = self 102 | 103 | def repr(self): 104 | return dict(self) 105 | 106 | def __repr__(self): 107 | return pformat(self.repr()) 108 | 109 | def __str__(self): 110 | return self.__repr__() 111 | 112 | def update_defaults(self, d: dict): 113 | for k, v in d.items(): 114 | self.setdefault(k, v) 115 | 116 | def save_json(self, file_name): 117 | with open(file_name, 'w') as fp: 118 | json.dump(self.repr(), fp) 119 | 120 | def copy(self): 121 | return type(self)(self) 122 | 123 | 124 | class FunctionCache(dict): 125 | def __init__(self, f): 126 | self.f = f 127 | 128 | def get(self, *args, **kwargs): 129 | key = hash(str(args)) + hash(str(kwargs)) 130 | if key in self: 131 | return self[key] 132 | else: 133 | ret = self[key] = self.f(*args, **kwargs) 134 | return ret 135 | 136 | 137 | # Decorator 138 | def function_cache(f, name='cache'): 139 | """" 140 | 141 | Usage Example: 142 | 143 | @function_cache(lambda X: expensive_function(X)) 144 | @function_cache(lambda X: expensive_function2(X), name = 'second_cache') 145 | def f(X, y): 146 | return expensive_function(Y) - f.cache.get(X) + f.second_cache.get(X) 147 | 148 | 149 | X, Y = ..., ... 150 | for y in Y: 151 | f(X, y) # The function is called multiple times with X 152 | 153 | 154 | 155 | """ 156 | def decorate(func): 157 | setattr(func, name, FunctionCache(f)) 158 | return func 159 | 160 | return decorate 161 | 162 | 163 | class cached_property(object): 164 | def __init__(self, func): 165 | self.__doc__ = getattr(func, "__doc__") 166 | 167 | self.func = func 168 | 169 | def __get__(self, obj, cls): 170 | if obj is None: 171 | return self 172 | 173 | value = obj.__dict__[self.func.__name__] = self.func(obj) 174 | 175 | return value 176 | 177 | 178 | class Singleton(ABCMeta): 179 | """ 180 | Singleton META-CLASS from https://stackoverflow.com/questions/6760685/creating-a-singleton-in-python 181 | 182 | Usage Example: 183 | class Logger(metaclass=Singleton): 184 | pass 185 | 186 | What it does: 187 | When you call logger with Logger(), Python first asks the metaclass of Logger, Singleton, what to do, 188 | allowing instance creation to be pre-empted. This process is the same as Python asking a class what to do 189 | by calling __getattr__ when you reference one of it's attributes by doing myclass.attribute. 
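
    For example (illustrative), with the Logger class above, every call returns the same object:

        a = Logger()
        b = Logger()
        assert a is b   # the first call created the instance; later calls reuse it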
190 | 
191 |     """
192 |     _instances = {}
193 | 
194 |     def __call__(cls, *args, **kwargs):
195 |         if cls not in cls._instances:
196 |             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
197 |         return cls._instances[cls]
--------------------------------------------------------------------------------
/src/cfshap/utils/_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 | 
4 | This module implements some general utilities for explanations
5 | """
6 | 
7 | import numpy as np
8 | 
9 | from ..base import Model, XGBWrapping, ListOf2DArrays
10 | 
11 | __all__ = [
12 |     'get_shap_compatible_model',
13 |     'get_top_counterfactuals',
14 |     'expand_dims_counterfactuals',
15 |     'get_top_counterfactual_costs',
16 | ]
17 | 
18 | 
19 | def get_top_counterfactuals(
20 |     XX_C: ListOf2DArrays,
21 |     X: np.ndarray,
22 |     n_top: int = 1,
23 |     nan: bool = True,
24 |     return_2d=True,
25 | ):
26 |     """Extracts the top-x counterfactuals from the input
27 | 
28 |     Args:
29 |         XX_C (ListOf2DArrays): List of arrays (for each sample) of counterfactuals (multiple)
30 |         X (np.ndarray): The query instances/samples
31 |         n_top (int, optional): Number of counterfactuals to extract. Defaults to 1.
32 |         nan (bool, optional): If True, return NaN placeholders for samples without counterfactuals; if False, drop them. Defaults to True.
33 |         return_2d (bool, optional): Return a 2D array if n_top = 1. Defaults to True.
34 | 
35 |     Returns:
36 |         ListOf2DArrays: The top-counterfactuals
37 |             If n_top = 1 and return_2d = True:
38 |                 nb_samples x nb_features
39 |                 Note: if nan = False, nb_samples may be less than X.shape[0]
40 |             Else:
41 |                 nb_samples (List) x nb_counterfactuals x nb_features
42 |                 Note: if nan = False, nb_counterfactuals may be less than n_top
43 | 
44 |     """
45 |     assert X.shape[0] == len(XX_C)
46 | 
47 |     if n_top == 1 and return_2d is True:
48 |         XC = np.full(X.shape, np.nan)
49 |         for i, XC_ in enumerate(XX_C):
50 |             if XC_.shape[0] > 0:
51 |                 XC[i] = XC_[0]
52 |         if nan is False:
53 |             XC = XC[~np.any(np.isnan(XC), axis=1)]
54 |         return XC
55 |     else:
56 |         if nan is True:
57 |             XX_C_top = [np.full((n_top, X.shape[1]), np.nan) for _ in range(len(XX_C))]
58 |             for i, XC_ in enumerate(XX_C):
59 |                 if XC_.shape[0] > 0:
60 |                     n_xc = min(n_top, XC_.shape[0])
61 |                     XX_C_top[i][:n_xc] = XC_[:n_xc]
62 |             return XX_C_top
63 |         else:
64 |             XX_C = [XC[~np.any(np.isnan(XC), axis=1)] for XC in XX_C]
65 |             return [XC[:n_top] for XC in XX_C]
66 | 
67 | 
68 | def expand_dims_counterfactuals(X):
69 |     """Expand dimensions of a list of counterfactuals
70 | 
71 |     Args:
72 |         X (np.ndarray): The counterfactuals
73 | 
74 |     Returns:
75 |         List[np.ndarray]: A list of arrays of counterfactuals. All arrays will have length 1.
76 | 
77 |     """
78 | 
79 |     if isinstance(X, np.ndarray):
80 |         return [x.reshape(1, -1) for x in X]
81 |     else:
82 |         return [np.array([x]) for x in X]
83 | 
84 | 
85 | def get_top_counterfactual_costs(costs: ListOf2DArrays, X: np.ndarray, nan: bool = True):
86 |     """Extract the costs of the top-x counterfactuals
87 | 
88 |     Pre-condition: Counterfactuals costs for each sample must be ordered (ascending).
89 | 
90 |     Args:
91 |         costs (ListOf2DArrays): List of arrays (for each sample) of counterfactuals costs (multiple).
92 |         X (np.ndarray): The query instances/samples
93 |         nan (bool, optional): If True, return NaN cost if no counterfactual exists, skips the sample otherwise. Defaults to True.
94 | 95 | Returns: 96 | np.ndarray: Array of the costs of the best counterfactuals 97 | """ 98 | if nan: 99 | return np.array([c[0] if len(c) > 0 else np.nan for c in costs]) 100 | else: 101 | raise NotImplementedError() 102 | 103 | 104 | def get_shap_compatible_model(model: Model): 105 | """Get the SHAP-compatible model 106 | 107 | Args: 108 | model (Model): A model 109 | 110 | Raises: 111 | ValueError: If the model is not supported 112 | 113 | Returns: 114 | object: SHAP-compatible model. 115 | """ 116 | # Extract XGBoost-wrapped models 117 | if isinstance(model, XGBWrapping): 118 | model = model.get_booster() 119 | 120 | # Explicit error handling 121 | if not (type(model).__module__ == 'xgboost.core' and type(model).__name__ == 'Booster'): 122 | raise ValueError("'get_booster' must return an XGBoost Booster object.") 123 | 124 | return model 125 | -------------------------------------------------------------------------------- /src/cfshap/utils/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Parallelization plotting utilities. 5 | """ 6 | 7 | from .batch import * 8 | from .utils import * 9 | -------------------------------------------------------------------------------- /src/cfshap/utils/parallel/batch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Parallelization Utilities for batch processing with JobLib. 5 | """ 6 | 7 | import logging 8 | from typing import Union 9 | from tqdm import tqdm 10 | 11 | from .utils import Element, split, join, max_cpu_count 12 | 13 | __all__ = ['batch_process'] 14 | 15 | 16 | def _parallel_joblib(f, splits, n_jobs): 17 | # pylint: disable=import-outside-toplevel 18 | from joblib import Parallel, delayed 19 | # pylint: enable=import-outside-toplevel 20 | return Parallel(n_jobs=n_jobs)(delayed(f)(s) for s in splits) 21 | 22 | 23 | def batch_process( 24 | f, 25 | iterable: Element, 26 | split_size: int = None, 27 | batch_size: int = None, 28 | desc='Batches', 29 | verbose=1, 30 | n_jobs: Union[None, int, str] = None, 31 | ) -> Element: 32 | ''' 33 | Batch-processing for iterable that support list comprehension. 
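
    Example (illustrative sketch):

        # Square the numbers 0..9 in batches of (at most) 4 elements
        squares = batch_process(
            lambda batch: [v * v for v in batch],
            iterable=list(range(10)),
            batch_size=4,
            n_jobs=1,
        )
        # squares == [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]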
34 | 
35 |     f : function
36 |     iterable : An iterable to be batch processed (list, NumPy array or DataFrame)
37 |     split_size/batch_size : maximum size of each batch (pass exactly one of the two)
38 |     '''
39 | 
40 |     assert split_size is None or batch_size is None
41 |     if split_size is None:
42 |         split_size = batch_size
43 | 
44 |     if n_jobs is None:
45 |         n_jobs = 1
46 |     if n_jobs == 'auto':
47 |         n_jobs = max(1, max_cpu_count() - 1)
48 | 
49 |     # Split
50 |     if verbose > 1:
51 |         logging.info('Splitting the iterable for batch process...')
52 |     splits = split(iterable, split_size)
53 |     if verbose > 1:
54 |         logging.info('Splitting done.')
55 | 
56 |     # Attach TQDM
57 |     if desc is not None and desc is not False and verbose > 0:
58 |         splits = tqdm(splits, desc=str(desc))
59 | 
60 |     # Batch process
61 |     if n_jobs > 1 and len(iterable) > split_size:
62 |         logging.info(f'Using joblib with {n_jobs} jobs')
63 |         results = _parallel_joblib(f, splits, n_jobs)
64 |     else:
65 |         results = [f(s) for s in splits]
66 | 
67 |     # Join
68 |     if verbose > 1:
69 |         logging.info('Joining the iterable after batch process...')
70 |     results = join(results)
71 |     if verbose > 1:
72 |         logging.info('Joining done.')
73 | 
74 |     # Return
75 |     return results
--------------------------------------------------------------------------------
/src/cfshap/utils/parallel/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 | 
4 | General utilities needed when parallelizing (e.g., determining number of CPUs, split an array)
5 | """
6 | 
7 | from typing import Union, List, Tuple
8 | import itertools
9 | import multiprocessing
10 | 
11 | import pandas as pd
12 | import numpy as np
13 | 
14 | __all__ = [
15 |     'max_cpu_count',
16 | ]
17 | 
18 | Element = Union[list, np.ndarray, pd.DataFrame]
19 | Container = Union[List[Element], np.ndarray]
20 | 
21 | 
22 | def max_cpu_count(reserved=1):
23 |     count = multiprocessing.cpu_count()
24 |     return max(1, count - reserved)
25 | 
26 | 
27 | def nb_splits(iterable: Element, split_size: int) -> int:
28 |     return int(np.ceil(len(iterable) / split_size))
29 | 
30 | 
31 | def split_indexes(iterable: Element, split_size: int) -> List[Tuple[int, int]]:
32 |     nb = nb_splits(iterable, split_size)
33 |     split_size = int(np.ceil(len(iterable) / nb))  # Equalize sizes of the splits
34 |     return [(split_id * split_size, (split_id + 1) * split_size) for split_id in range(0, nb)]
35 | 
36 | 
37 | def split(iterable: Element, split_size) -> Container:
38 |     """
39 |     iterable : The object that should be split in pieces
40 |     split_size: Maximum split size (the size is equalized among all splits)
41 |     """
42 |     if isinstance(iterable, (np.ndarray, pd.DataFrame)):
43 |         return np.array_split(iterable, nb_splits(iterable, split_size))
44 |     else:
45 |         return [iterable[a:b] for a, b in split_indexes(iterable, split_size)]
46 | 
47 | 
48 | def join(iterable: Container, axis: int = 0) -> Element:
49 |     # Get rid of iterators
50 |     iterable = list(iterable)
51 | 
52 |     assert len(set((type(a) for a in iterable))) == 1, "Expecting a non-empty iterable of objects of the same type."
53 | 
54 |     if isinstance(iterable[0], pd.DataFrame):
55 |         return pd.concat(iterable, axis=axis)
56 |     elif isinstance(iterable[0], np.ndarray):
57 |         return np.concatenate(iterable, axis=axis)
58 |     elif isinstance(iterable[0], list):
59 |         return list(itertools.chain.from_iterable(iterable))
60 |     else:
61 |         raise TypeError('Can handle only pandas, numpy and lists. 
No iterators.') 62 | -------------------------------------------------------------------------------- /src/cfshap/utils/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | ML / Data Science Pre-processing Utilities 5 | """ 6 | 7 | from .scaling import * 8 | -------------------------------------------------------------------------------- /src/cfshap/utils/preprocessing/scaling/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantiletransformer import * 2 | from .madscaler import * 3 | from .multiscaler import * -------------------------------------------------------------------------------- /src/cfshap/utils/preprocessing/scaling/madscaler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the (Mean/Median) Absolute Deviation Scalers 3 | 4 | NOTE: The acronym MAD (Mean/Median Absolute Deviation) has many different meaning. 5 | This module implements all the four possible alternatives: 6 | - Mean Absolute Deviation Scaler from the Mean 7 | - Median Absolute Deviation Scaler from the Median 8 | - Mean Absolute Deviation Scaler from the Median 9 | - Median Absolute Deviation Scaler from the Mean 10 | 11 | """ 12 | 13 | __all__ = [ 14 | 'MeanAbsoluteDeviationFromMeanScaler', 15 | 'MeanAbsoluteDeviationFromMedianScaler', 16 | 'MedianAbsoluteDeviationFromMeanScaler', 17 | 'MedianAbsoluteDeviationFromMedianScaler', 18 | 'MedianAbsoluteDeviationScaler', 19 | 'MeanAbsoluteDeviationScaler', 20 | ] 21 | __author__ = 'Emanuele Albini' 22 | 23 | from abc import ABC, abstractmethod 24 | import numpy as np 25 | from scipy import sparse 26 | from scipy import stats 27 | from sklearn.base import BaseEstimator, TransformerMixin 28 | from sklearn.utils.validation import (check_is_fitted, FLOAT_DTYPES) 29 | from sklearn.utils.sparsefuncs import inplace_column_scale 30 | from sklearn.utils import check_array 31 | 32 | 33 | def mean_absolute_deviation(data, axis=None): 34 | return np.mean(np.absolute(data - np.mean(data, axis)), axis) 35 | 36 | 37 | def mean_absolute_deviation_from_median(data, axis=None): 38 | return np.mean(np.absolute(data - np.median(data, axis)), axis) 39 | 40 | 41 | def median_absolute_deviation_from_mean(data, axis=None): 42 | return np.median(np.absolute(data - np.mean(data, axis)), axis) 43 | 44 | 45 | class AbsoluteDeviationScaler(TransformerMixin, BaseEstimator, ABC): 46 | """ 47 | This class is the interface and base class for the Mean/Median Absolute Deviation Scalers. 48 | It implements the scikit-learn API for scalers. 49 | 50 | This class is based on scikit-learn StandardScaler and RobustScaler. 
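
    Example (illustrative sketch, using the MedianAbsoluteDeviationScaler alias defined below):

        scaler = MedianAbsoluteDeviationScaler().fit(X_train)   # X_train: 2D array (assumed)
        X_scaled = scaler.transform(X_train)                    # (X - median) / MAD
        X_back = scaler.inverse_transform(X_scaled)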
51 | """ 52 | def __init__(self, *, copy=True, with_centering=True, with_scaling=True): 53 | self.with_centering = with_centering 54 | self.with_scaling = with_scaling 55 | self.copy = copy 56 | 57 | @abstractmethod 58 | def _center(self, X): 59 | pass 60 | 61 | @abstractmethod 62 | def _scale(self, X): 63 | pass 64 | 65 | def _check_inputs(self, X): 66 | try: 67 | X = self._validate_data(X, 68 | accept_sparse='csc', 69 | estimator=self, 70 | dtype=FLOAT_DTYPES, 71 | force_all_finite='allow-nan') 72 | except AttributeError: 73 | X = check_array(X, 74 | accept_sparse='csr', 75 | copy=self.copy, 76 | estimator=self, 77 | dtype=FLOAT_DTYPES, 78 | force_all_finite='allow-nan') 79 | return X 80 | 81 | def fit(self, X, y=None): 82 | X = self._check_inputs(X) 83 | 84 | if self.with_centering: 85 | if sparse.issparse(X): 86 | raise ValueError("Cannot center sparse matrices: use `with_centering=False`" 87 | " instead. See docstring for motivation and alternatives.") 88 | self.center_ = self._center(X) 89 | else: 90 | self.center_ = None 91 | 92 | if self.with_scaling: 93 | self.scale_ = self._scale(X) 94 | else: 95 | self.scale_ = None 96 | 97 | return self 98 | 99 | def transform(self, X): 100 | """Center and scale the data. 101 | Parameters 102 | ---------- 103 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 104 | The data used to scale along the specified axis. 105 | Returns 106 | ------- 107 | X_tr : {ndarray, sparse matrix} of shape (n_samples, n_features) 108 | Transformed array. 109 | """ 110 | check_is_fitted(self, ['center_', 'scale_']) 111 | X = self._check_inputs(X) 112 | 113 | if sparse.issparse(X): 114 | if self.with_scaling: 115 | inplace_column_scale(X, 1.0 / self.scale_) 116 | else: 117 | if self.with_centering: 118 | X -= self.center_ 119 | if self.with_scaling: 120 | X /= self.scale_ 121 | return X 122 | 123 | def inverse_transform(self, X): 124 | """Scale back the data to the original representation 125 | Parameters 126 | ---------- 127 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 128 | The rescaled data to be transformed back. 129 | Returns 130 | ------- 131 | X_tr : {ndarray, sparse matrix} of shape (n_samples, n_features) 132 | Transformed array. 
133 | """ 134 | check_is_fitted(self, ['center_', 'scale_']) 135 | X = check_array(X, 136 | accept_sparse=('csr', 'csc'), 137 | copy=self.copy, 138 | estimator=self, 139 | dtype=FLOAT_DTYPES, 140 | force_all_finite='allow-nan') 141 | 142 | if sparse.issparse(X): 143 | if self.with_scaling: 144 | inplace_column_scale(X, self.scale_) 145 | else: 146 | if self.with_scaling: 147 | X *= self.scale_ 148 | if self.with_centering: 149 | X += self.center_ 150 | return X 151 | 152 | def _more_tags(self): 153 | return {'allow_nan': True} 154 | 155 | 156 | class MeanAbsoluteDeviationFromMeanScaler(AbsoluteDeviationScaler): 157 | """Mean absolute deviation scaler (from the mean) 158 | It scales using the the MEAN deviation from the MEAN 159 | """ 160 | def _center(self, X): 161 | return np.nanmean(X, axis=0) 162 | 163 | def _scale(self, X): 164 | return mean_absolute_deviation(X, axis=0) 165 | 166 | 167 | class MedianAbsoluteDeviationFromMedianScaler(AbsoluteDeviationScaler): 168 | """Median absolute deviation scaler (from the median) 169 | It scales using the the MEDIAN deviation from the MEDIAN 170 | """ 171 | def _center(self, X): 172 | return np.nanmedian(X, axis=0) 173 | 174 | def _scale(self, X): 175 | return stats.median_absolute_deviation(X, axis=0) 176 | 177 | 178 | class MeanAbsoluteDeviationFromMedianScaler(AbsoluteDeviationScaler): 179 | """Mean absolute deviation scaler (from the median) 180 | It scales using the the MEAN deviation from the MEDIAN 181 | """ 182 | def _center(self, X): 183 | return np.nanmean(X, axis=0) 184 | 185 | def _scale(self, X): 186 | return mean_absolute_deviation_from_median(X, axis=0) 187 | 188 | 189 | class MedianAbsoluteDeviationFromMeanScaler(AbsoluteDeviationScaler): 190 | """Median absolute deviation scaler (from the mean) 191 | It scales using the the MEDIAN deviation from the MEAN 192 | """ 193 | def _center(self, X): 194 | return np.nanmean(X, axis=0) 195 | 196 | def _scale(self, X): 197 | return median_absolute_deviation_from_mean(X, axis=0) 198 | 199 | 200 | # Aliases 201 | # By default we return the matching absolute scalers (median-median and mean-mean). 202 | MeanAbsoluteDeviationScaler = MeanAbsoluteDeviationFromMeanScaler 203 | MedianAbsoluteDeviationScaler = MedianAbsoluteDeviationFromMedianScaler -------------------------------------------------------------------------------- /src/cfshap/utils/preprocessing/scaling/multiscaler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the MultiScaler. 3 | The multi scaler is a scaler that allows for different scaling within the same class through an argument passed to the `transform` methods. 4 | e.g., STD, MAD, Quantile, etc. 5 | 6 | """ 7 | 8 | __all__ = [ 9 | 'MultiScaler', 10 | 'IdentityScaler', 11 | 'get_scaler_name', 12 | 'SCALERS', 13 | ] 14 | __author__ = 'Emanuele Albini' 15 | 16 | from typing import Union 17 | 18 | import numpy as np 19 | import scipy as sp 20 | 21 | from sklearn.preprocessing import MinMaxScaler, StandardScaler 22 | from sklearn.base import BaseEstimator, TransformerMixin 23 | 24 | from ... 
import keydefaultdict 25 | 26 | from .madscaler import ( 27 | MedianAbsoluteDeviationScaler, 28 | MeanAbsoluteDeviationScaler, 29 | MeanAbsoluteDeviationFromMedianScaler, 30 | MedianAbsoluteDeviationFromMeanScaler, 31 | ) 32 | from .quantiletransformer import EfficientQuantileTransformer 33 | 34 | # NOTE: This will be deprecated, it is confusing 35 | _DISTANCE_TO_SCALER = { 36 | 'euclidean': 'std', 37 | 'manhattan': 'mad', 38 | 'cityblock': 'mad', 39 | 'percshift': 'quantile', 40 | } 41 | 42 | _SCALER_TO_NAME = { 43 | 'quantile': 'Quantile', 44 | 'std': 'Standard', 45 | 'mad': 'Median Absolute Dev. (from median)', 46 | 'Mad': 'Mean Absolute Dev. (from mean)', 47 | 'madM': 'Median Absolute Dev. from mean', 48 | 'Madm': 'Mean Absolute Dev. from median', 49 | 'minmax': 'Min Max', 50 | 'quantile_nan': 'Quantile w/NaN OF', 51 | 'quantile_sum': 'Quantile w/OF-Σ', 52 | None: 'Identity', 53 | } 54 | 55 | SCALERS = list(_SCALER_TO_NAME.keys()) 56 | 57 | 58 | def _get_method_safe(method): 59 | if method is None or method in SCALERS: 60 | return method 61 | # NOTE: This will be deprecated, it is confusing 62 | elif method in list(_DISTANCE_TO_SCALER.keys()): 63 | return _DISTANCE_TO_SCALER[method] 64 | else: 65 | raise ValueError('Invalid normalization method.') 66 | 67 | 68 | def get_scaler_name(method): 69 | return _SCALER_TO_NAME[_get_method_safe(method)] 70 | 71 | 72 | class IdentityScaler(TransformerMixin, BaseEstimator): 73 | """A dummy/identity scaler compatible with the sklearn interface for scalers 74 | It returns the same input it receives. 75 | 76 | """ 77 | def __init__(self): 78 | pass 79 | 80 | def fit(self, X, y=None): 81 | return self 82 | 83 | def transform(self, X): 84 | return X 85 | 86 | def inverse_transform(self, X): 87 | return X 88 | 89 | 90 | def get_transformer_class(method): 91 | method = _get_method_safe(method) 92 | 93 | if method is None: 94 | return IdentityScaler 95 | elif method == 'std': 96 | return StandardScaler 97 | elif method == 'minmax': 98 | return MinMaxScaler 99 | elif method == 'mad': 100 | return MedianAbsoluteDeviationScaler 101 | elif method == 'Mad': 102 | return MeanAbsoluteDeviationScaler 103 | elif method == 'Madm': 104 | return MeanAbsoluteDeviationFromMedianScaler 105 | elif method == 'madM': 106 | return MedianAbsoluteDeviationFromMeanScaler 107 | elif method == 'quantile': 108 | return EfficientQuantileTransformer # PercentileShifterCached 109 | elif method == 'quantile_nan': 110 | return EfficientQuantileTransformer 111 | elif method == 'quantile_sum': 112 | return EfficientQuantileTransformer 113 | 114 | 115 | def get_transformer_kwargs(method): 116 | method = _get_method_safe(method) 117 | 118 | if method == 'quantile_nan': 119 | return dict(overflow="nan") 120 | elif method == 'quantile_sum': 121 | return dict(overflow="sum") 122 | else: 123 | return dict() 124 | 125 | 126 | class MultiScaler: 127 | """Multi-Scaler 128 | 129 | Raises: 130 | Exception: If an invalid normalization is used 131 | 132 | """ 133 | # Backward compatibility 134 | NORMALIZATION = SCALERS 135 | 136 | def __init__(self, data: np.ndarray = None): 137 | """Constructor 138 | 139 | - The data on which to train the scalers can be passed here (in the constructor), or 140 | - It can also be passe later using the .fit(data) method. 141 | 142 | Args: 143 | data (pd.DataFrame): A dataframe with the features, optional. Default to None. 
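
        Example (illustrative sketch, assuming `X_train` is a 2D NumPy array):

            scaler = MultiScaler(X_train)
            X_q = scaler.transform(X_train, method='quantile')        # to quantile space
            X_back = scaler.inverse_transform(X_q, method='quantile')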
144 | 
145 |         """
146 | 
147 |         self.__suppress_warning = False
148 | 
149 |         if data is not None:
150 |             self.fit(data)
151 | 
152 |     def fit(self, data: np.ndarray):
153 |         self.data = np.asarray(data)
154 | 
155 |         self.transformers = keydefaultdict(lambda method: get_transformer_class(method)
156 |                                            (**get_transformer_kwargs(method)).fit(self.data))
157 | 
158 |         def single_transformer(method, f):
159 |             return get_transformer_class(method)(**get_transformer_kwargs(method)).fit(self.data[:, f].reshape(-1, 1))
160 | 
161 |         self.single_transformers = keydefaultdict(lambda args: single_transformer(*args))
162 |         self.data_transformed = keydefaultdict(lambda method: self.transform(self.data, method))
163 | 
164 |         self.covs = keydefaultdict(lambda method_lib: self.__compute_covariance_matrix(self.data, *method_lib))
165 | 
166 |         self._lower_bounds = keydefaultdict(lambda method: self.data_transformed[method].min(axis=0))
167 |         self._upper_bounds = keydefaultdict(lambda method: self.data_transformed[method].max(axis=0))
168 | 
169 |     def lower_bounds(self, method):
170 |         return self._lower_bounds[method]
171 | 
172 |     def upper_bounds(self, method):
173 |         return self._upper_bounds[method]
174 | 
175 |     def suppress_warnings(self, value=True):
176 |         self.__suppress_warning = value
177 | 
178 |     def __compute_covariance_matrix(self, data, method, lib):
179 |         if lib == 'np':
180 |             return sp.linalg.inv(np.cov((self.transform(data, method=method)), rowvar=False))
181 |         elif lib == 'tf':
182 |             from ..tf import inv_cov as tf_inv_cov
183 |             return tf_inv_cov(self.transform(data, method=method))
184 |         else:
185 |             raise ValueError('Invalid lib.')
186 | 
187 |     def transform(self, data: np.ndarray, method: str, **kwargs):
188 |         """Normalize the data according to the "method" passed
189 | 
190 |         Args:
191 |             data (np.ndarray): The data to be normalized (nb_samples x nb_features)
192 |             method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
193 | 
194 |         Raises:
195 |             ValueError: Invalid normalization
196 | 
197 |         Returns:
198 |             np.ndarray: Normalized array
199 |         """
200 |         method = _get_method_safe(method)
201 |         return self.transformers[method].transform(data)
202 | 
203 |     def inverse_transform(self, data: np.ndarray, method: str = 'std'):
204 |         """Un-normalize the data according to the "method" passed
205 | 
206 |         Args:
207 |             data (np.ndarray): The data to be un-normalized (nb_samples x nb_features)
208 |             method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
209 | 
210 |         Raises:
211 |             ValueError: Invalid normalization
212 | 
213 |         Returns:
214 |             np.ndarray: Un-normalized array
215 |         """
216 |         method = _get_method_safe(method)
217 |         return self.transformers[method].inverse_transform(data)
218 | 
219 |     def feature_deviation(self, method: str = 'std', phi: Union[float, int] = 1):
220 |         """Get the deviation of each feature according to the normalization method
221 | 
222 |         Args:
223 |             method (str, optional): Normalization (see class documentation for details on the available scalings). Defaults to 'std'.
224 |             phi (Union[float, int]): The fraction of the STD/MAD/MINMAX. Default to 1.
225 | 226 | Raises: 227 | ValueError: Invalid normalization 228 | 229 | Returns: 230 | np.ndarray: Deviations, shape = (nb_features, ) 231 | """ 232 | method = _get_method_safe(method) 233 | transformer = self.transformers[method] 234 | if 'scale_' in dir(transformer): 235 | return transformer.scale_ * phi 236 | else: 237 | return np.ones(self.data.shape[1]) * phi 238 | 239 | def feature_transform(self, x: np.ndarray, f: int, method: str): 240 | x = np.asarray(x) 241 | transformer = self.get_feature_transformer(method, f) 242 | return transformer.transform(x.reshape(-1, 1))[:, 0] 243 | 244 | def value_transform(self, x: float, f: int, method: str): 245 | x = np.asarray([x]) 246 | transformer = self.get_feature_transformer(method, f) 247 | return transformer.transform(x.reshape(-1, 1))[:, 0][0] 248 | 249 | def shift_transform(self, X, shifts, method, **kwargs): 250 | transformer = self.get_transformer(method) 251 | if 'shift' in dir(transformer): 252 | return transformer.shift_transform(X, shifts=shifts, **kwargs) 253 | else: 254 | return X + shifts 255 | 256 | def move_transform(self, X, costs, method, **kwargs): 257 | transformer = self.get_transformer(method) 258 | assert costs.shape[0] == X.shape[1] 259 | return transformer.inverse_transform(transformer.transform(X) + np.tile(costs, (X.shape[0], 1))) 260 | 261 | def get_transformer(self, method: str): 262 | return self.transformers[_get_method_safe(method)] 263 | 264 | def get_feature_transformer(self, method: str, f: int): 265 | return self.single_transformers[(_get_method_safe(method), f)] 266 | 267 | def single_transform(self, x, *args, **kwargs): 268 | return self.transform(np.array([x]), *args, **kwargs)[0] 269 | 270 | def single_inverse_transform(self, x, *args, **kwargs): 271 | return self.inverse_transform(np.array([x]), *args, **kwargs)[0] 272 | 273 | def single_shift_transform(self, x, shift, **kwargs): 274 | return self.shift_transform(np.array([x]), np.array([shift]), **kwargs)[0] 275 | 276 | # NOTE: This must be deprecated, it does not fit here. 277 | def covariance_matrix(self, data: np.ndarray, method: Union[None, str], lib='np'): 278 | if data is None: 279 | return self.covs[(method, lib)] 280 | else: 281 | return self.__compute_covariance_matrix(data, method, lib) 282 | 283 | # NOTE: This must be deprecated, it does not fit here. 284 | def covariance_matrices(self, data: np.ndarray, methods=None, lib='np'): 285 | """Compute the covariance matrices 286 | 287 | Args: 288 | data (np.ndarray): The data from which to extract the covariance 289 | 290 | Returns: 291 | Dict[np.ndarray]: Dictionary (for each normalization method) of covariance matrices 292 | """ 293 | # If no method is passed we compute for all of them 294 | if methods is None: 295 | methods = self.NORMALIZATION 296 | 297 | return {method: self.covariance_matrix(data, method, lib) for method in methods} -------------------------------------------------------------------------------- /src/cfshap/utils/preprocessing/scaling/quantiletransformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements an efficient and exact version of the scikit-learn QuantileTransformer. 
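
Example (illustrative sketch, assuming `X_train` is a 2D NumPy array):

    scaler = EfficientQuantileTransformer()
    scaler.fit(X_train)
    Q = scaler.transform(X_train)          # per-feature quantiles in [0, 1]
    X_back = scaler.inverse_transform(Q)   # back to the original feature space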
3 | 4 | Note: This module has been inspired from scikit-learn QuantileTransformer 5 | (and it is effectively an extension of QuantileTransformer API) 6 | See https://github.com/scikit-learn/scikit-learn/blob/2beed5584/sklearn/preprocessing 7 | 8 | """ 9 | 10 | __all__ = ['EfficientQuantileTransformer'] 11 | __author__ = 'Emanuele Albini' 12 | 13 | import numpy as np 14 | from scipy import sparse 15 | from sklearn.utils import check_random_state 16 | from sklearn.utils.validation import (check_is_fitted, FLOAT_DTYPES) 17 | from sklearn.preprocessing import QuantileTransformer 18 | from sklearn.utils import check_array 19 | 20 | 21 | class EfficientQuantileTransformer(QuantileTransformer): 22 | """ 23 | This class directly extends and improve the efficiency of scikit-learn QuantileTransformer 24 | 25 | Note: The efficient implementation will be only used if: 26 | - The input are NumPy arrays (NOT scipy sparse matrices) 27 | The flag self.smart_fit_ marks when the efficient implementation is being used. 28 | 29 | """ 30 | def __init__( 31 | self, 32 | *, 33 | subsample=np.inf, 34 | random_state=None, 35 | copy=True, 36 | overflow=None, # "nan" or "sum" 37 | ): 38 | """Initialize the transformer 39 | 40 | Args: 41 | subsample (int, optional): Number of samples to use to create the quantile space. Defaults to np.inf. 42 | random_state (int, optional): Random seed (sampling happen only if subsample < number of samples fitted). Defaults to None. 43 | copy (bool, optional): If False, passed arrays will be edited in place. Defaults to True. 44 | overflow (str, optional): Overflow strategy. Defaults to None. 45 | When doing the inverse transformation if a quantile > 1 or < 0 is passed then: 46 | - None > Nothing is done. max(0, min(1, q)) will be used. The 0% or 100% reference will be returned. 47 | - 'sum' > It will add proportionally, e.g., q = 1.2 will result in adding 20% more quantile to the 100% reference. 48 | - 'nan' > It will return NaN 49 | """ 50 | self.ignore_implicit_zeros = False 51 | self.n_quantiles_ = np.inf 52 | self.output_distribution = 'uniform' 53 | self.subsample = subsample 54 | self.random_state = random_state 55 | self.overflow = overflow 56 | self.copy = copy 57 | 58 | def _smart_fit(self, X, random_state): 59 | n_samples, n_features = X.shape 60 | self.references_ = [] 61 | self.quantiles_ = [] 62 | for col in X.T: 63 | # Do sampling if necessary 64 | if self.subsample < n_samples: 65 | subsample_idx = random_state.choice(n_samples, size=self.subsample, replace=False) 66 | col = col.take(subsample_idx, mode='clip') 67 | col = np.sort(col) 68 | quantiles = np.sort(np.unique(col)) 69 | references = 0.5 * np.array( 70 | [np.searchsorted(col, v, side='left') + np.searchsorted(col, v, side='right') 71 | for v in quantiles]) / n_samples 72 | self.quantiles_.append(quantiles) 73 | self.references_.append(references) 74 | 75 | def fit(self, X, y=None): 76 | """Compute the quantiles used for transforming. 77 | Parameters 78 | ---------- 79 | X : {array-like} of shape (n_samples, n_features) 80 | The data used to scale along the features axis. 81 | 82 | y : None 83 | Ignored. 84 | Returns 85 | ------- 86 | self : object 87 | Fitted transformer. 88 | """ 89 | 90 | if self.subsample <= 1: 91 | raise ValueError("Invalid value for 'subsample': %d. " 92 | "The number of subsamples must be at least two." 
% self.subsample) 93 | 94 | X = self._check_inputs(X, in_fit=True, copy=False) 95 | n_samples = X.shape[0] 96 | 97 | if n_samples <= 1: 98 | raise ValueError("Invalid value for samples: %d. " 99 | "The number of samples to fit for must be at least two." % n_samples) 100 | 101 | rng = check_random_state(self.random_state) 102 | 103 | # Create the quantiles of reference 104 | self.smart_fit_ = not sparse.issparse(X) 105 | if self.smart_fit_: # <<<<<- New case 106 | self._smart_fit(X, rng) 107 | else: 108 | raise NotImplementedError('EfficientQuantileTransformer handles only NON-sparse matrices!') 109 | 110 | return self 111 | 112 | def _smart_transform_col(self, X_col, quantiles, references, inverse): 113 | """Private function to transform a single feature.""" 114 | 115 | isfinite_mask = ~np.isnan(X_col) 116 | X_col_finite = X_col[isfinite_mask] 117 | # Simply Interpolate 118 | if not inverse: 119 | X_col[isfinite_mask] = np.interp(X_col_finite, quantiles, references) 120 | else: 121 | X_col[isfinite_mask] = np.interp(X_col_finite, references, quantiles) 122 | 123 | return X_col 124 | 125 | def _check_inputs(self, X, in_fit, accept_sparse_negative=False, copy=False): 126 | """Check inputs before fit and transform.""" 127 | try: 128 | X = self._validate_data(X, 129 | reset=in_fit, 130 | accept_sparse=False, 131 | copy=copy, 132 | dtype=FLOAT_DTYPES, 133 | force_all_finite='allow-nan') 134 | except AttributeError: # Old sklearn version (_validate_data do not exists) 135 | X = check_array(X, accept_sparse=False, copy=self.copy, dtype=FLOAT_DTYPES, force_all_finite='allow-nan') 136 | 137 | # we only accept positive sparse matrix when ignore_implicit_zeros is 138 | # false and that we call fit or transform. 139 | with np.errstate(invalid='ignore'): # hide NaN comparison warnings 140 | if (not accept_sparse_negative and not self.ignore_implicit_zeros 141 | and (sparse.issparse(X) and np.any(X.data < 0))): 142 | raise ValueError('QuantileTransformer only accepts' ' non-negative sparse matrices.') 143 | 144 | # check the output distribution 145 | if self.output_distribution not in ('normal', 'uniform'): 146 | raise ValueError("'output_distribution' has to be either 'normal'" 147 | " or 'uniform'. Got '{}' instead.".format(self.output_distribution)) 148 | 149 | return X 150 | 151 | def _transform(self, X, inverse=False): 152 | """Forward and inverse transform. 153 | Parameters 154 | ---------- 155 | X : ndarray of shape (n_samples, n_features) 156 | The data used to scale along the features axis. 157 | inverse : bool, default=False 158 | If False, apply forward transform. If True, apply 159 | inverse transform. 160 | Returns 161 | ------- 162 | X : ndarray of shape (n_samples, n_features) 163 | Projected data. 164 | """ 165 | for feature_idx in range(X.shape[1]): 166 | X[:, feature_idx] = self._smart_transform_col(X[:, feature_idx], self.quantiles_[feature_idx], 167 | self.references_[feature_idx], inverse) 168 | 169 | return X 170 | 171 | def transform(self, X): 172 | """Feature-wise transformation of the data. 173 | Parameters 174 | ---------- 175 | X : {array-like} of shape (n_samples, n_features) 176 | The data used to scale along the features axis. 177 | 178 | Returns 179 | ------- 180 | Xt : {ndarray, sparse matrix} of shape (n_samples, n_features) 181 | The projected data. 
182 | """ 183 | check_is_fitted(self, ['quantiles_', 'references_', 'smart_fit_']) 184 | X = self._check_inputs(X, in_fit=False, copy=self.copy) 185 | return self._transform(X, inverse=False) 186 | 187 | def inverse_transform(self, X): 188 | """Back-projection to the original space. 189 | Parameters 190 | ---------- 191 | X : {array-like} of shape (n_samples, n_features) 192 | The data used to scale along the features axis. 193 | 194 | Returns 195 | ------- 196 | Xt : {ndarray, sparse matrix} of (n_samples, n_features) 197 | The projected data. 198 | """ 199 | check_is_fitted(self, ['quantiles_', 'references_', 'smart_fit_']) 200 | X = self._check_inputs(X, in_fit=False, accept_sparse_negative=False, copy=self.copy) 201 | 202 | if self.overflow is None: 203 | T = self._transform(X, inverse=True) 204 | elif self.overflow == 'nan': 205 | NaN_mask = np.ones(X.shape) 206 | NaN_mask[(X > 1) | (X < 0)] = np.nan 207 | T = NaN_mask * self._transform(X, inverse=True) 208 | 209 | elif self.overflow == 'sum': 210 | ones = self._transform(np.ones(X.shape), inverse=True) 211 | zeros = self._transform(np.zeros(X.shape), inverse=True) 212 | 213 | # Standard computation 214 | T = self._transform(X.copy(), inverse=True) 215 | 216 | # Deduct already computed part 217 | X = np.where((X > 0), np.maximum(X - 1, 0), X) 218 | 219 | # After this X > 0 => Remaining quantile > 1.00 220 | # and X < 0 => Remaining quantile < 0.00 221 | 222 | T = T + (X > 1) * np.floor(X) * (ones - zeros) 223 | X = np.where((X > 1), np.maximum(X - np.floor(X), 0), X) 224 | T = T + (X > 0) * (ones - self._transform(1 - X.copy(), inverse=True)) 225 | 226 | T = T - (X < -1) * np.floor(-X) * (ones - zeros) 227 | X = np.where((X < -1), np.minimum(X + np.floor(-X), 0), X) 228 | T = T - (X < 0) * (self._transform(-X.copy(), inverse=True) - zeros) 229 | 230 | # Set the NaN the values that have not been reached after a certaing amount of iterations 231 | # T[(X > 0) | (X < 0)] = np.nan 232 | 233 | else: 234 | raise ValueError('Invalid value for overflow.') 235 | 236 | return T 237 | 238 | 239 | # %% 240 | -------------------------------------------------------------------------------- /src/cfshap/utils/project.py: -------------------------------------------------------------------------------- 1 | """Functions to find the decision boundary of a classifier using bisection. 2 | """ 3 | 4 | __author__ = 'Emanuele Albini' 5 | __all__ = [ 6 | 'find_decision_boundary_bisection', 7 | ] 8 | 9 | import numpy as np 10 | 11 | from ..base import Model, Scaler 12 | from .parallel import batch_process 13 | 14 | 15 | def find_decision_boundary_bisection( 16 | x: np.ndarray, 17 | Y: np.ndarray, 18 | model: Model, 19 | scaler: Scaler = None, 20 | num: int = 1000, 21 | n_jobs: int = 1, 22 | error: float = 1e-10, 23 | method: str = 'mean', 24 | mem_coeff: float = 1, 25 | model_parallelism: int = 1, 26 | desc='Find Decision Boundary Ellipse (batch)', 27 | ) -> np.ndarray: 28 | """Find the decision boundary between x and multiple points Y using bisection. 29 | 30 | Args: 31 | x (np.ndarray): A point. 32 | Y (np.ndarray): An array of points. 33 | model (Model): A model (that implements model.predict(X)) 34 | scaler (Scaler, optional): A scaler. Defaults to None (no scaling). 35 | num (int, optional): The (maximum) number of bisection steps. Defaults to 1000. 36 | n_jobs (int, optional): The number of parallel jobs. Defaults to 1. 37 | error (float, optional): The early stopping error for the bisection. Defaults to 1e-10. 
38 | method (str, optional): The method to find the point. Defaults to 'mean'. 39 | - 'left': On the left of the decision boundary (closer to x). 40 | - 'right': On the right of the decision boundary (closer to y). 41 | - 'mean': The mean of the two points (it can be either on the left of right side). 42 | - 'counterfactual': alias for 'right'. 43 | mem_coeff (float, optional): The coefficient for the job split (the higher the bigger every single job will be). Defaults to 1. 44 | model_parallelism (int, optional): Factor to enable parallel bisection. Defaults to 1. 45 | desc (str, optional): TQDM progress bar description. Defaults to 'Find Decision Boundary Ellipse (batch)'. 46 | 47 | Returns: 48 | np.ndarray: The points close to the decision boundary. 49 | 50 | NOTE: This function perform much better when num << 1,000,000 51 | because it batches together multiple points when predicting 52 | (this is also done to avoid memory issues). 53 | """ 54 | 55 | assert model_parallelism >= 1 56 | assert method in ['mean', 'left', 'right', 'counterfactual'] 57 | 58 | if scaler is not None: 59 | assert hasattr(scaler, 'transform'), 'Scaler must have a `transform` method.' 60 | assert hasattr(scaler, 'inverse_transform'), 'Scaler must have a `inverse_transform` method.' 61 | 62 | def _find_decision_boundary_bisection(x: np.ndarray, Y: np.ndarray): 63 | 64 | x = np.asarray(x).copy() 65 | Y = np.asarray(Y).copy() 66 | X = np.tile(x, (Y.shape[0], 1)) 67 | Yrange = np.arange(Y.shape[0]) 68 | 69 | model_parallelism_ = int(np.ceil(model_parallelism / len(Y))) 70 | 71 | # Compute the predictions for the two points and check that they are different 72 | pred_x = model.predict(np.array([x])) 73 | pred_Y = model.predict(np.array(Y)) 74 | assert np.all(pred_x != pred_Y), 'The predictions for x and Y are the same. They must be different.' 
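
        # Invariant of the bisection loop below: every X[i] keeps the same prediction
        # as x and every Y[i] keeps the opposite one, so each pair (X[i], Y[i]) always
        # brackets the decision boundary; the bracket is halved (or sliced, in the
        # parallel case) up to `num` times or until it shrinks below `error`.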
75 | 76 | for _ in range(num): 77 | # 'Pure' bisection 78 | if model_parallelism_ == 1: 79 | if scaler is None: 80 | M = (X + Y) / 2 81 | else: 82 | X_ = scaler.transform(X) 83 | Y_ = scaler.transform(Y) 84 | M_ = (X_ + Y_) / 2 85 | M = scaler.inverse_transform(M_) 86 | 87 | # Cast to proper type (e.g., if X and/or Y are integers) with proper precision 88 | if M.dtype != X.dtype: 89 | X = X.astype(M.dtype) 90 | if M.dtype != Y.dtype: 91 | Y = Y.astype(M.dtype) 92 | 93 | preds = (model.predict(M) != pred_x) 94 | different_indexes = np.argwhere(preds) 95 | non_different_indexes = np.argwhere(~preds) 96 | 97 | Y[different_indexes] = M[different_indexes] 98 | X[non_different_indexes] = M[non_different_indexes] 99 | 100 | # Parallel bisection 101 | else: 102 | if scaler is None: 103 | M = np.concatenate(np.linspace(X, Y, model_parallelism_ + 1, endpoint=False)[1:]) 104 | else: 105 | X_ = scaler.transform(X) 106 | Y_ = scaler.transform(Y) 107 | M_ = np.concatenate(np.linspace(X_, Y_, model_parallelism_, endpoint=False)) 108 | M = scaler.inverse_transform(M_) 109 | 110 | # Predict 111 | preds = (model.predict(M) != pred_x).reshape(model_parallelism_, -1).T 112 | 113 | # Rebuild M 114 | M = np.concatenate([ 115 | np.expand_dims(X, axis=0), 116 | M.reshape(model_parallelism_, Y.shape[0], -1), 117 | np.expand_dims(Y, axis=0) 118 | ]) 119 | 120 | # Find next index 121 | left_index = np.array([np.searchsorted(preds[i], 1) for i in range(len(Y))]) 122 | right_index = left_index + 1 123 | 124 | # Update left and right 125 | X = M[left_index, Yrange] 126 | Y = M[right_index, Yrange] 127 | 128 | # Early stopping 129 | err = np.max(np.linalg.norm((X - Y), axis=1, ord=2)) 130 | if err < error: 131 | break 132 | 133 | # The decision boundary is in between the change points 134 | if method == 'mean': 135 | return (X + Y) / 2 136 | # Same predictions 137 | elif method == 'left': 138 | return X 139 | # Counterfactual 140 | elif method == 'right' or method == 'counterfactual': 141 | return Y 142 | else: 143 | raise ValueError('Invalid method.') 144 | 145 | return batch_process( 146 | lambda Y: list(_find_decision_boundary_bisection(x=x, Y=Y)), 147 | iterable=Y, 148 | # Optimal split size (cpu-mem threadoff) 149 | split_size=int(max(1, 100 * 1000 * 1000 / num / (x.shape[0]))) * mem_coeff, 150 | desc=desc, 151 | n_jobs=min(n_jobs, len(Y)), 152 | ) -------------------------------------------------------------------------------- /src/cfshap/utils/random/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Random utilities. 5 | """ 6 | 7 | from ._sample import * 8 | from ._stability import * -------------------------------------------------------------------------------- /src/cfshap/utils/random/_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Random sampling utilities. 5 | """ 6 | 7 | from typing import Union 8 | import numpy as np 9 | import pandas as pd 10 | 11 | __all__ = [ 12 | 'sample_data', 13 | 'np_sample', 14 | ] 15 | 16 | 17 | def sample_data( 18 | X: Union[pd.DataFrame, np.ndarray], 19 | n=None, 20 | frac=None, 21 | random_state=None, 22 | replace=False, 23 | **kwargs, 24 | ) -> Union[pd.DataFrame, np.ndarray]: 25 | assert frac is None or n is None, "Cannot specify both `n` and `frac`" 26 | assert not (frac is None and n is None), "One of `n` or `frac` must be passed." 
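
    # Dispatch: DataFrames use pandas' own .sample(); NumPy arrays fall back to
    # np_sample below (where `frac` is first converted into a row count `n`).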
27 | 28 | if isinstance(X, pd.DataFrame): 29 | return X.sample( 30 | n=n, 31 | frac=frac, 32 | random_state=random_state, 33 | replace=replace, 34 | **kwargs, 35 | ) 36 | elif isinstance(X, np.ndarray): 37 | if frac is not None: 38 | n = int(np.ceil(len(X) * frac)) 39 | return np_sample(X, n=n, replace=replace, random_state=random_state, **kwargs) 40 | else: 41 | raise NotImplementedError('Unsupported dataset type.') 42 | 43 | 44 | def np_sample( 45 | a: Union[np.ndarray, int], 46 | n: Union[int, None], 47 | replace: bool = False, 48 | seed: Union[None, int] = None, 49 | random_state: Union[None, int] = None, 50 | safe: bool = False, 51 | ) -> np.ndarray: 52 | """Randomly sample on axis 0 of a NumPy array 53 | 54 | Args: 55 | a (Union[np.ndarray, int]): The array to be samples, or the integer (max) for an `range` 56 | n (int or None): Number of samples to be draw. If None, it sample all the samples. 57 | replace (bool, optional): Repeat samples or not. Defaults to False. 58 | seed (Union[None, int], optional): Random seed for NumPy. Defaults to None. 59 | random_state (Union[None, int], optional): Alias for seed. Defaults to None. 60 | safe (bool, optional) : Safely handle `n` or not. If True and replace = False, and n > len(a), then n = len(a) 61 | 62 | Returns: 63 | np.ndarray: A random sample 64 | """ 65 | assert random_state is None or seed is None 66 | 67 | if random_state is not None: 68 | seed = random_state 69 | 70 | if seed is not None: 71 | random_state = np.random.RandomState(seed) 72 | else: 73 | random_state = np.random 74 | 75 | # Range case 76 | if isinstance(a, int): 77 | if safe and n > a: 78 | n = a 79 | return random_state.choice(a, n, replace=replace) 80 | # Array sampling case 81 | else: 82 | if n is None: 83 | n = len(a) 84 | if safe and n > len(a): 85 | n = len(a) 86 | return a[random_state.choice(a.shape[0], n, replace=replace)] -------------------------------------------------------------------------------- /src/cfshap/utils/random/_stability.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Emanuele Albini 3 | 4 | Random stability (i.e., deterministic execution) utilities. 5 | """ 6 | 7 | import logging 8 | import os 9 | import random 10 | 11 | __all__ = [ 12 | 'random_stability', 13 | ] 14 | 15 | 16 | def random_stability( 17 | seed_value=0, 18 | deterministic=True, 19 | verbose=True, 20 | ): 21 | '''Set random seed/global random states to the specified value for a series of libraries: 22 | 23 | - Python environment 24 | - Python random package 25 | - NumPy/Scipy 26 | - Tensorflow 27 | - Keras 28 | - Torch 29 | 30 | seed_value (int): random seed 31 | deterministic (bool) : negatively effect performance making (parallel) operations deterministic. Default to True. 32 | verbose (bool): Output verbose log. Default to True. 
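
    Example (illustrative):

        random_stability(42, deterministic=True)
        # From this point on, draws from `random`, NumPy and (if installed)
        # TensorFlow / PyTorch are seeded with 42.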
33 |     '''
34 |     # pylint: disable=bare-except
35 | 
36 |     outputs = []
37 | 
38 |     # Python environment
39 |     os.environ['PYTHONHASHSEED'] = str(seed_value)
40 |     outputs.append('PYTHONHASHSEED (env)')
41 | 
42 |     # Python random
43 |     random.seed(seed_value)
44 |     outputs.append('random')
45 | 
46 |     try:
47 |         import numpy as np
48 | 
49 |         np.random.seed(seed_value)
50 |         outputs.append('NumPy')
51 |     except ModuleNotFoundError:
52 |         pass
53 | 
54 |     # TensorFlow 2
55 |     try:
56 |         import tensorflow as tf
57 |         tf.random.set_seed(seed_value)
58 |         if deterministic:
59 |             outputs.append('TensorFlow 2 (deterministic)')
60 |             tf.config.threading.set_inter_op_parallelism_threads(1)
61 |             tf.config.threading.set_intra_op_parallelism_threads(1)
62 |         else:
63 |             outputs.append('TensorFlow 2 (parallel, non-deterministic)')
64 |     except (ModuleNotFoundError, ImportError, AttributeError):
65 |         pass
66 | 
67 |     # TensorFlow 1 & Keras ? Not sure it works
68 |     try:
69 |         import tensorflow as tf
70 |         if tf.__version__ < '2':
71 |             try:
72 |                 import tensorflow.compat.v1
73 |                 from tensorflow.compat.v1 import set_random_seed
74 |             except (ModuleNotFoundError, ImportError):
75 |                 from tensorflow import set_random_seed
76 | 
77 |             try:
78 |                 import tensorflow.compat.v1  # noqa: F811
79 |                 from tensorflow.compat.v1 import ConfigProto
80 |             except (ModuleNotFoundError, ImportError):
81 |                 from tensorflow import ConfigProto
82 | 
83 |             set_random_seed(seed_value)
84 |             if deterministic:
85 |                 outputs.append('TensorFlow 1 (deterministic)')
86 |                 session_conf = ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
87 |             else:
88 |                 outputs.append('TensorFlow 1 (parallel, non-deterministic)')
89 |                 session_conf = ConfigProto()
90 |             session_conf.gpu_options.allow_growth = True
91 |             sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
92 |             try:
93 |                 from keras import backend as K
94 |                 K.set_session(sess)
95 |                 outputs.append('Keras')
96 |             except (ModuleNotFoundError, ImportError, AttributeError):
97 |                 logging.warning('Keras random stability failed.')
98 |     except (ModuleNotFoundError, ImportError, AttributeError):
99 |         pass
100 | 
101 |     try:
102 |         import torch
103 |         torch.manual_seed(seed_value)
104 |         torch.cuda.manual_seed_all(seed_value)
105 |         if deterministic:
106 |             torch.backends.cudnn.deterministic = True
107 |             torch.backends.cudnn.benchmark = False
108 |             outputs.append('PyTorch (deterministic)')
109 |         else:
110 |             outputs.append('PyTorch (parallel, non-deterministic)')
111 | 
112 |     except (ModuleNotFoundError, ImportError, AttributeError):
113 |         pass
114 |     # pylint: enable=bare-except
115 | 
116 |     if verbose:
117 |         logging.info('Random seed (%d) set for: %s', seed_value, ", ".join(outputs))
118 | 
--------------------------------------------------------------------------------
/src/cfshap/utils/tree.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Emanuele Albini
3 | 
4 | Utilities for tree-based models explainability.
5 | """ 6 | 7 | from collections import defaultdict 8 | import numpy as np 9 | 10 | from ..utils import get_shap_compatible_model 11 | 12 | try: 13 | # Newer versions of SHAPss 14 | from shap.explainers._tree import TreeEnsemble 15 | except (ImportError, ModuleNotFoundError): 16 | # Older versions of SHAP 17 | from shap.explainers.tree import TreeEnsemble 18 | 19 | __all__ = [ 20 | 'TreeEnsemble', 21 | 'get_splits', 22 | 'splits_to_values', 23 | 'SplitTransformer', 24 | ] 25 | 26 | 27 | def get_splits(model, nb_features): 28 | """Extract feature splits from a tree-based model 29 | In order to extract the trees this method rely on the SHAP API: 30 | any model supported by SHAP is supported. 31 | 32 | Args: 33 | model (Any tree-based model): The model 34 | nb_features (int): The number of features (sometimes it cannot be induced by the model) 35 | 36 | Returns: 37 | List[np.ndarray]: A list of splits for each of the features 38 | nb_features x nb_splits (it may different for each feature) 39 | """ 40 | 41 | # Convert the model to something that SHAP API can understand 42 | model = get_shap_compatible_model(model) 43 | 44 | # Extract the trees (using TreeEnseble API from SHAP) 45 | ensemble = TreeEnsemble(model) 46 | 47 | splits = defaultdict(list) 48 | for tree in ensemble.trees: 49 | for i in range(len(tree.features)): 50 | if tree.children_left[i] != -1: # i.e., it is a split (and not a leaf) 51 | splits[tree.features[i]].append(tree.thresholds[i]) 52 | 53 | assert len(splits) <= nb_features 54 | assert max(list(splits.keys())) < nb_features 55 | 56 | return [np.sort(np.array(list(set(splits[i])))) for i in range(nb_features)] 57 | 58 | 59 | def __splits_to_values(splits, how: str, eps: float): 60 | if len(splits) == 0: 61 | return [0] 62 | 63 | if how == 'left': 64 | return ([splits[0] - eps] + [(splits[i] + eps) for i in range(len(splits) - 1)] + [splits[-1] + eps]) 65 | elif how == 'right': 66 | return ([splits[0] - eps] + [(splits[i + 1] - eps) for i in range(len(splits) - 1)] + [splits[-1] + eps]) 67 | elif how == 'center': 68 | return ([splits[0] - eps] + [(splits[i] + splits[i + 1]) / 2 69 | for i in range(len(splits) - 1)] + [splits[-1] + eps]) 70 | else: 71 | raise ValueError('Invalid mode.') 72 | 73 | 74 | def splits_to_values(splits, how: str = 'center', eps: float = 1e-6): 75 | """Convert lists of splits in values (in between splits) 76 | 77 | Args: 78 | splits (List[np.ndarray]): List of splits (one per each feature) obtained using get_feature_splits 79 | how (str): Where should the values be wrt to the splits. Default to 'center'. 80 | 'center': precisely in between the two splits (except for the first and last split) 81 | 'right': split - eps 82 | 'left': split + eps 83 | e.g. if we indicate with | the splits and the values with x 84 | 'center': x| x | x | x | x |x 85 | 'left': x|x |x |x |x |x 86 | 'right': x| x| x| x| x|x 87 | 88 | eps (float, optional): epsilon to use for the conversion of splits to values. Defaults to 1e-6. 
89 | 90 | Returns: 91 | List[np.ndarray]: List of values for the features supported based on the splits (one list per feature) 92 | """ 93 | 94 | return [np.unique(np.array(__splits_to_values(s, how=how, eps=eps))) for s in splits] 95 | 96 | 97 | class SplitTransformer: 98 | def __init__(self, model, nb_features=None): 99 | """The initialization is lazy because we may not know yet the number of features""" 100 | if hasattr(model, 'get_booster'): 101 | model = model.get_booster() 102 | self.model = model 103 | self.nb_features = nb_features 104 | self.values = None 105 | self.splits = None 106 | 107 | self._initialize() 108 | 109 | def _initialize(self, X=None): 110 | if self.nb_features is None and X is not None: 111 | self.nb_features = X.shape[1] 112 | if (self.values is None or self.splits is None) and (self.nb_features is not None): 113 | self.splits = get_splits(self.model, self.nb_features) 114 | self.values = splits_to_values(self.splits, how='center') 115 | 116 | def get_nb_value(self, i): 117 | return len(self.splits[i]) + 1 118 | 119 | def transform(self, X): 120 | """Takes inputs and transform them into discrete splits""" 121 | A = np.asarray(X) 122 | self._initialize(A) 123 | for i in range(A.shape[1]): 124 | if len(self.splits[i]) > 0: 125 | A[:, i] = np.digitize(A[:, i], self.splits[i]) 126 | else: 127 | A[:, i] = 0 128 | return A.astype(int) 129 | 130 | def inverse_transform(self, X): 131 | """Takes splits and transform them into continuous inputs 132 | Note: The post-condition T^(-1)(T(X)) == X may not hold. 133 | """ 134 | X_ = np.asarray(X).astype(int) 135 | self._initialize(X) 136 | A = np.full(X.shape, np.nan) 137 | for i in range(X.shape[1]): 138 | A[:, i] = self.values[i][X_[:, i]] 139 | return A 140 | 141 | 142 | # import numba 143 | # try: 144 | # from shap.explainers._tree import TreeEnsemble 145 | # except: 146 | # # Older versions 147 | # from shap.explainers.tree import TreeEnsemble 148 | # @numba.jit(nopython=True, nogil=True) 149 | # def predict_tree(x, region, features, values, thresholds, children_left, children_right, children_default): 150 | # n = 0 151 | # while n != -1: # -1 => Leaf 152 | # f = features[n] 153 | # t = thresholds[n] 154 | # v = x[f] 155 | 156 | # if v < t: # Left 157 | # # Constrain region 158 | # region[f, 1] = np.minimum(region[f, 1], t) 159 | # # print(f'T{i}: x_{f} = {v} < {t}') 160 | 161 | # # Go to child 162 | # n = children_left[n] 163 | 164 | # elif v >= t: # Right 165 | # # Constrain region 166 | # region[f, 0] = np.maximum(region[f, 0], t) 167 | # # print(f'T{i}: x_{f} = {v} >= {t}') 168 | 169 | # # Go to child 170 | # n = children_right[n] 171 | 172 | # else: # Missing 173 | # n = children_default[n] 174 | # region[f] = np.nan 175 | 176 | # return region 177 | 178 | # def predict_region(x, region, ensemble): 179 | # for i, tree in enumerate(ensemble.trees): 180 | # region = predict_tree( 181 | # x, region, tree.features, tree.values, tree.thresholds, tree.children_left, tree.children_right, 182 | # tree.children_default 183 | # ) 184 | # return region 185 | 186 | # def predict_regions(X, model): 187 | # ensemble = TreeEnsemble(model) 188 | # regions = np.tile( 189 | # np.array([-np.inf, np.inf], dtype=ensemble.trees[0].thresholds.dtype), (X.shape[0], X.shape[1], 1) 190 | # ) 191 | # return np.stack([predict_region(x, region, ensemble) for x, region in zip(X, regions)]) -------------------------------------------------------------------------------- /tests/test_usage.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests of basic usage. 3 | """ 4 | 5 | import pytest 6 | import itertools 7 | 8 | from sklearn.datasets import load_boston 9 | from xgboost import XGBClassifier 10 | 11 | from cfshap.utils.preprocessing import EfficientQuantileTransformer 12 | from cfshap.counterfactuals import KNNCounterfactuals 13 | from cfshap.attribution import TreeExplainer, CompositeExplainer 14 | from cfshap.trend import TrendEstimator 15 | 16 | 17 | def test_usage(): 18 | 19 | dataset = load_boston() 20 | X = dataset.data 21 | y = (dataset.target > 21).astype(int) 22 | 23 | model = XGBClassifier() 24 | model.fit(X, y) 25 | 26 | scaler = EfficientQuantileTransformer() 27 | scaler.fit(X) 28 | 29 | trend_estimator = TrendEstimator(strategy='mean') 30 | 31 | explainer = CompositeExplainer( 32 | KNNCounterfactuals( 33 | model=model, 34 | X=X, 35 | n_neighbors=100, 36 | distance='cityblock', 37 | scaler=scaler, 38 | max_samples=10000, 39 | ), 40 | TreeExplainer( 41 | model, 42 | data=None, 43 | trend_estimator=trend_estimator, 44 | max_samples=10000, 45 | ), 46 | ) 47 | 48 | return explainer(X[:10]) 49 | --------------------------------------------------------------------------------
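Illustrative usage sketch for `cfshap.utils.get_top_counterfactuals` (toy data invented for the example; not taken from the repository):

import numpy as np

from cfshap.utils import get_top_counterfactuals

X = np.zeros((2, 3))                                   # 2 query points, 3 features
XX_C = [
    np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]),      # two counterfactuals for X[0]
    np.empty((0, 3)),                                  # no counterfactual for X[1]
]

XC = get_top_counterfactuals(XX_C, X, n_top=1, return_2d=True)
# XC[0] == [1., 1., 1.]; XC[1] is all-NaN because X[1] has no counterfactual.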