├── .gitignore ├── tests └── explainability │ ├── __init__.py │ ├── model_and_data_pyspark │ ├── gbtModelPySpark │ │ ├── data │ │ │ ├── _SUCCESS │ │ │ ├── ._SUCCESS.crc │ │ │ ├── part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet │ │ │ ├── part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet │ │ │ ├── part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet │ │ │ ├── part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet │ │ │ ├── .part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc │ │ │ ├── .part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc │ │ │ ├── .part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc │ │ │ └── .part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc │ │ ├── metadata │ │ │ ├── _SUCCESS │ │ │ ├── ._SUCCESS.crc │ │ │ ├── .part-00000.crc │ │ │ └── part-00000 │ │ └── treesMetadata │ │ │ ├── _SUCCESS │ │ │ ├── ._SUCCESS.crc │ │ │ ├── part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet │ │ │ ├── part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet │ │ │ ├── part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet │ │ │ ├── part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet │ │ │ ├── .part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc │ │ │ ├── .part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc │ │ │ ├── .part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc │ │ │ └── .part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc │ └── assemblerScaled │ │ └── metadata │ │ ├── _SUCCESS │ │ ├── ._SUCCESS.crc │ │ ├── .part-00000.crc │ │ └── part-00000 │ ├── model_and_data │ └── FICO_lr_model.pkl │ ├── test_pyspark_wrapper.py │ ├── test_strategy.py │ ├── test_counterfactuals.py │ ├── test_counterfactualproto.py │ ├── test_ale.py │ ├── test_serializer.py │ ├── test_explanation.py │ ├── test_counterfactual_basic.py │ ├── test_anchors_extended.py │ ├── test_clustering_tree_explainer.py │ ├── test_partial_dependence.py │ ├── test_shuffle_importance.py │ └── conftest.py ├── MANIFEST.in ├── docs ├── reference │ ├── explainers.md │ └── explanations.md └── index.md ├── requirements.txt ├── mercury └── explainability │ ├── explainers │ ├── _tree_splitters │ │ ├── __init__.py │ │ └── cut_finder.pyx │ ├── explainer.py │ ├── __init__.py │ ├── _dummy_alibi_explainers.py │ ├── shuffle_importance.py │ ├── counter_fact_basic.py │ ├── anchors.py │ ├── ale.py │ └── cf_strategies.py │ ├── explanations │ ├── __init__.py │ ├── shuffle_importance.py │ ├── anchors.py │ ├── clustering_tree_explanation.py │ ├── partial_dependence.py │ └── counter_factual.py │ ├── __init__.py │ ├── create_tutorials.py │ └── pyspark_utils.py ├── CHANGELOG.md ├── .bumpversion.cfg ├── setup.py ├── .github └── workflows │ ├── pypi_upload.yml │ └── test.yml ├── mkdocs.yml ├── pyproject.toml ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | -------------------------------------------------------------------------------- /tests/explainability/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/_SUCCESS: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/assemblerScaled/metadata/_SUCCESS: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/metadata/_SUCCESS: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/_SUCCESS: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft mercury/explainability/ 2 | 3 | recursive-include tutorials * 4 | -------------------------------------------------------------------------------- /docs/reference/explainers.md: -------------------------------------------------------------------------------- 1 | # Base Tests 2 | 3 | ::: mercury.explainability.explainers -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/._SUCCESS.crc: -------------------------------------------------------------------------------- 1 | crc -------------------------------------------------------------------------------- /docs/reference/explanations.md: -------------------------------------------------------------------------------- 1 | # Explanations 2 | 3 | ::: mercury.explainability.explanations -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/assemblerScaled/metadata/._SUCCESS.crc: -------------------------------------------------------------------------------- 1 | crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/metadata/._SUCCESS.crc: -------------------------------------------------------------------------------- 1 | crc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | numpy 3 | bokeh 4 | simanneal 5 | dill 6 | Cython 7 | graphviz 8 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/._SUCCESS.crc: -------------------------------------------------------------------------------- 1 | crc -------------------------------------------------------------------------------- /mercury/explainability/explainers/_tree_splitters/__init__.py: -------------------------------------------------------------------------------- 1 | from cut_finder import get_min_mistakes_cut 2 | from cut_finder import get_min_surrogate_cut 3 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data/FICO_lr_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data/FICO_lr_model.pkl -------------------------------------------------------------------------------- 
/tests/explainability/model_and_data_pyspark/assemblerScaled/metadata/.part-00000.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/assemblerScaled/metadata/.part-00000.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/metadata/.part-00000.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/metadata/.part-00000.crc -------------------------------------------------------------------------------- /mercury/explainability/explanations/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchors import AnchorsWithImportanceExplanation 2 | from .counter_factual import ( 3 | CounterfactualBasicExplanation, 4 | CounterfactualWithImportanceExplanation 5 | ) 6 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## Latest version 1.1.4 2 | 3 | | Release | Date | Main feature(s) | 4 | | -------- | ---- | --------------- | 5 | | 1.1.4 | 2025/03/25 | Implements create_tutorials(), minor fixes, adds support for python 3.12, improves documentation. | 6 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet 
-------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00000-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00001-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00002-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/data/.part-00003-76edbf1d-359e-4633-927a-8de47c6d57b3-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00000-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00001-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00002-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BBVA/mercury-explainability/HEAD/tests/explainability/model_and_data_pyspark/gbtModelPySpark/treesMetadata/.part-00003-96219efb-a895-4f33-82d1-347c6f01251c-c000.snappy.parquet.crc -------------------------------------------------------------------------------- /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.1.4 3 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>\w+)\.(?P<build>\d+))?
4 | serialize = 5 | {major}.{minor}.{patch}-{release}.{build} 6 | {major}.{minor}.{patch} 7 | commit = True 8 | tag = True 9 | 10 | [bumpversion:file:mercury/explainability/__init__.py] 11 | 12 | [bumpversion:file:README.md] 13 | 14 | [bumpversion:file:docs/index.md] 15 | 16 | [bumpversion:file:pyproject.toml] -------------------------------------------------------------------------------- /mercury/explainability/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.1.4' 2 | 3 | from .explainers.counter_fact_basic import CounterFactualExplainerBasic 4 | from .explainers.shuffle_importance import ShuffleImportanceExplainer 5 | from .explainers.explainer import MercuryExplainer 6 | from .explainers.partial_dependence import PartialDependenceExplainer 7 | from .explanations.anchors import AnchorsWithImportanceExplanation 8 | from .explanations.counter_factual import CounterfactualWithImportanceExplanation 9 | 10 | from .explainers import ALEExplainer 11 | from .explainers import AnchorsWithImportanceExplainer 12 | from .explainers import CounterfactualExplainer, CounterfactualProtoExplainer 13 | 14 | from .create_tutorials import create_tutorials 15 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/gbtModelPySpark/metadata/part-00000: -------------------------------------------------------------------------------- 1 | {"class":"org.apache.spark.ml.classification.GBTClassificationModel","timestamp":1595936771762,"sparkVersion":"2.4.3","uid":"GBTClassifier_bd1511e5939c","paramMap":{"maxIter":10,"featuresCol":"scaledFeatures","labelCol":"label"},"defaultParamMap":{"maxIter":20,"predictionCol":"prediction","stepSize":0.1,"seed":-2472803760548765012,"lossType":"logistic","subsamplingRate":1.0,"rawPredictionCol":"rawPrediction","checkpointInterval":10,"probabilityCol":"probability","maxMemoryInMB":256,"cacheNodeIds":false,"featuresCol":"features","maxDepth":5,"impurity":"gini","minInstancesPerNode":1,"maxBins":32,"labelCol":"label","validationTol":0.01,"minInfoGain":0.0,"featureSubsetStrategy":"all"},"numFeatures":20,"numTrees":10} 2 | -------------------------------------------------------------------------------- /tests/explainability/model_and_data_pyspark/assemblerScaled/metadata/part-00000: -------------------------------------------------------------------------------- 1 | {"class":"org.apache.spark.ml.feature.VectorAssembler","timestamp":1595936791697,"sparkVersion":"2.4.3","uid":"VectorAssembler_3d2fe536534a","paramMap":{"outputCol":"scaledFeatures","inputCols":["ExternalRiskEstimate","MSinceOldestTradeOpen","MSinceMostRecentTradeOpen","AverageMInFile","NumSatisfactoryTrades","NumTrades60Ever2DerogPubRec","NumTrades90Ever2DerogPubRec","PercentTradesNeverDelq","MaxDelq2PublicRecLast12M","MaxDelqEver","NumTotalTrades","NumTradesOpeninLast12M","PercentInstallTrades","MSinceMostRecentInqexcl7days","NumInqLast6M","NetFractionRevolvingBurden","NumRevolvingTradesWBalance","NumInstallTradesWBalance","NumBank2NatlTradesWHighUtilization","PercentTradesWBalance"]},"defaultParamMap":{"outputCol":"VectorAssembler_3d2fe536534a__output","handleInvalid":"error"}} 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | 3 | 4 | # Move tutorials inside mercury.explainability before packaging 5 | if os.path.exists('tutorials'): 6 
| shutil.move('tutorials', 'mercury/explainability/tutorials') 7 | 8 | 9 | from setuptools import setup, find_packages, Extension 10 | 11 | import numpy 12 | 13 | 14 | setup_args = dict( 15 | name = 'mercury-explainability', 16 | packages = find_packages(include = ['mercury*', 'tutorials*']), 17 | include_package_data = True, 18 | package_data = {'mypackage': ['tutorials/*', 'tutorials/data/*']}, 19 | ext_modules = [Extension('cut_finder', ['mercury/explainability/explainers/_tree_splitters/cut_finder.pyx'], 20 | extra_compile_args = ['-fopenmp', '-O2'], 21 | extra_link_args = ['-fopenmp'], 22 | include_dirs = [numpy.get_include()] 23 | ) 24 | ] 25 | 26 | ) 27 | 28 | setup(**setup_args) 29 | -------------------------------------------------------------------------------- /.github/workflows/pypi_upload.yml: -------------------------------------------------------------------------------- 1 | name: PyPi Upload 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v3 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.10' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install wheel 20 | pip install build 21 | pip install numpy 22 | - name: Build packages 23 | run: | 24 | python setup.py sdist 25 | - name: Publish package 26 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 27 | with: 28 | user: ${{ secrets.pypi_user }} 29 | password: ${{ secrets.pypi_password }} 30 | packages_dir: ./dist/ 31 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/explainer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import os 4 | import dill 5 | 6 | class MercuryExplainer(ABC): 7 | 8 | @abstractmethod 9 | def explain(self, data): 10 | pass 11 | 12 | def save(self, filename: str = "explainer.pkl"): 13 | """ 14 | Saves the explainer with its internal state to a file. 15 | 16 | Args: 17 | filename (str): Path where the explainer will be saved 18 | """ 19 | with open(filename, 'wb') as f: 20 | dill.dump(self, f) 21 | assert os.path.isfile(filename), "Error storing file" 22 | 23 | @classmethod 24 | def load(cls, filename: str = "explainer.pkl"): 25 | """ 26 | Loads a previously saved explainer with its internal state from a file.
27 | 28 | Args: 29 | filename (str): Path where the explainer is stored 30 | """ 31 | assert os.path.isfile(filename), "File does not exist or not a valid file" 32 | with open(filename, 'rb') as f: 33 | return dill.load(f) 34 | 35 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: mercury-explainability 2 | repo_url: https://github.com/BBVA/mercury-explainability/ 3 | repo_name: mercury-explainability 4 | theme: 5 | name: material 6 | features: 7 | - tabs 8 | - navigation.indexes 9 | icon: 10 | logo: material/book-open-page-variant 11 | repo: fontawesome/brands/github 12 | site_dir: site 13 | nav: 14 | - Home: index.md 15 | - Api: 16 | - explainers: reference/explainers.md 17 | - explanations: reference/explanations.md 18 | markdown_extensions: 19 | - codehilite 20 | - admonition 21 | - pymdownx.superfences 22 | - pymdownx.arithmatex: 23 | generic: true 24 | extra_css: 25 | - stylesheets/extra.css 26 | extra_javascript: 27 | - javascripts/config.js 28 | - https://polyfill.io/v3/polyfill.min.js?features=es6 29 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 30 | plugins: 31 | - mkdocstrings: 32 | handlers: 33 | python: 34 | options: 35 | show_root_heading: true 36 | show_submodules: true 37 | merge_init_into_class: true 38 | docstring_style: google 39 | dev_addr: 0.0.0.0:8080 -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Mercury-Explainability 2 | 3 | on: 4 | push: 5 | branches: [ "master", "develop" ] 6 | pull_request: 7 | branches: [ "master", "develop" ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install package 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install flake8 pytest build 27 | python -m pip install -e .[dev] 28 | - name: Lint with flake8 29 | run: | 30 | # stop the build if there are Python syntax errors or undefined names 31 | flake8 . --count --select=E9,F63,F7,F82 --show-source --max-line-length=140 --statistics 32 | - name: Test with pytest 33 | run: | 34 | pytest 35 | - name: Test build 36 | run: | 37 | python -m build -------------------------------------------------------------------------------- /mercury/explainability/create_tutorials.py: -------------------------------------------------------------------------------- 1 | import os, pkg_resources, shutil 2 | 3 | 4 | def create_tutorials(destination, silent = False): 5 | """ 6 | Copies mercury.explainability tutorial notebooks to `destination`. A folder will be created inside 7 | destination, named 'explainability_tutorials'. The folder `destination` must exist. 8 | 9 | Args: 10 | destination (str): The destination directory 11 | silent (bool): If True, suppresses output on success. 12 | 13 | Raises: 14 | ValueError: If `destination` is equal to source path. 
15 | 16 | Examples: 17 | >>> # copy tutorials to /tmp/explainability_tutorials 18 | >>> from mercury.explainability import create_tutorials 19 | >>> create_tutorials('/tmp') 20 | 21 | """ 22 | src = pkg_resources.resource_filename(__package__, 'tutorials') 23 | dst = os.path.abspath(destination) 24 | 25 | assert src != dst, 'Destination (%s) cannot be the same as source.' % src 26 | 27 | assert os.path.isdir(dst), 'Destination (%s) must be a directory.' % dst 28 | 29 | dst = os.path.join(dst, 'explainability_tutorials') 30 | 31 | assert not os.path.exists(dst), 'Destination (%s) already exists' % dst 32 | 33 | shutil.copytree(src, dst) 34 | 35 | if not silent: 36 | print('Tutorials copied to: %s' % dst) 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "numpy", "Cython"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "mercury-explainability" 7 | license = {file = "LICENSE.txt"} 8 | version = "1.1.4" 9 | authors = [ 10 | { name="Mercury Team", email="mercury.group@bbva.com" }, 11 | ] 12 | description = "Mercury's explainability is a library with implementations of different state-of-the-art methods in the field of explainability" 13 | readme = "README.md" 14 | requires-python = ">=3.7" 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Operating System :: OS Independent", 19 | ] 20 | dependencies = [ 21 | 'pandas', 22 | 'numpy', 23 | 'bokeh', 24 | 'simanneal', 25 | 'shap', 26 | 'dill', 27 | 'Cython', 28 | 'graphviz', 29 | ] 30 | 31 | [project.optional-dependencies] 32 | dev = [ 33 | 'pyspark', 34 | 'pytest', 35 | 'flake8', 36 | 'scikit-learn', 37 | 'alibi', 38 | 'tensorflow' 39 | ] 40 | doc = [ 41 | 'mkdocs', 42 | 'mkdocstrings[python]', 43 | 'mkdocs-material', 44 | 'mkdocs-minify-plugin==0.5.0', 45 | 'mkdocs-exclude', 46 | 'nbconvert', 47 | ] 48 | 49 | [tool.setuptools.packages.find] 50 | include = ["mercury*"] 51 | exclude = ["docs*", "tests*", "tutorials"] 52 | 53 | 54 | [project.urls] 55 | "Homepage" = "https://github.com/BBVA/mercury-explainability" 56 | "Bug Tracker" = "https://github.com/BBVA/mercury-explainability/issues" 57 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # mercury-explainability 2 | 3 | [![](https://github.com/BBVA/mercury-explainability/actions/workflows/test.yml/badge.svg)](https://github.com/BBVA/mercury-explainability) 4 | ![](https://img.shields.io/badge/latest-1.1.4-blue) 5 | 6 | ***mercury-explainability*** is a library with implementations of different state-of-the-art methods in the field of explainability. They are designed to work efficiently and to be easily integrated with the main Machine Learning frameworks. 7 | 8 | ## Mercury project at BBVA 9 | 10 | Mercury is a collaborative library that was developed by the Advanced Analytics community at BBVA. Originally, it was created as an [InnerSource](https://en.wikipedia.org/wiki/Inner_source) project but after some time, we decided to release certain parts of the project as Open Source. 11 | That's the case with the `mercury-explainability` package. 
12 | 13 | If you're interested in learning more about the Mercury project, we recommend reading this blog [post](https://www.bbvaaifactory.com/mercury-acelerando-la-reutilizacion-en-ciencia-de-datos-dentro-de-bbva/) from www.bbvaaifactory.com 14 | 15 | ## User installation 16 | 17 | The easiest way to install `mercury-explainability` is using ``pip``: 18 | 19 | pip install -U mercury-explainability 20 | 21 | ## Help and support 22 | 23 | This library is currently maintained by a dedicated team of data scientists and machine learning engineers from BBVA AI Factory. 24 | 25 | ### Documentation 26 | website: https://bbva.github.io/mercury-explainability/site/ 27 | 28 | ### Email 29 | mercury.group@bbva.com -------------------------------------------------------------------------------- /mercury/explainability/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | import signal 2 | 3 | 4 | def run_until_timeout(timeout, fn, *args, **kwargs): 5 | """ 6 | Timeout function to stop the execution in case it takes longer than the timeout argument. 7 | After timeout seconds it will raise an exception. 8 | 9 | Args: 10 | timeout (int): Number of seconds until the Exception is raised. 11 | fn (callable): Function to execute. 12 | *args (dict): args of the function fn. 13 | **kwargs (dict): keyword args passed to fn. 14 | 15 | Example: 16 | >>> explanation = run_until_timeout(timeout, explainer.explain, data=data) 17 | 18 | """ 19 | def signal_handler(signum, frame): 20 | raise Exception('Timeout: explanation took too long...') 21 | 22 | if timeout > 0: 23 | signal.signal(signal.SIGALRM, signal_handler) 24 | signal.alarm(timeout) 25 | else: 26 | signal.alarm(0) 27 | return fn(*args, **kwargs) 28 | 29 | 30 | from .counter_fact_basic import CounterFactualExplainerBasic 31 | from .shuffle_importance import ShuffleImportanceExplainer 32 | from .partial_dependence import PartialDependenceExplainer 33 | from .clustering_tree_explainer import ClusteringTreeExplainer 34 | from .explainer import MercuryExplainer 35 | 36 | # Classes with alibi dependencies 37 | try: 38 | from .ale import ALEExplainer 39 | from .anchors import AnchorsWithImportanceExplainer 40 | from .counter_fact_importance import CounterfactualExplainer, CounterfactualProtoExplainer 41 | except ModuleNotFoundError: 42 | # if import fails, then import Dummy Class with the same name which raises Error if instantiated 43 | from ._dummy_alibi_explainers import ( 44 | ALEExplainer, 45 | AnchorsWithImportanceExplainer, 46 | CounterfactualExplainer, 47 | CounterfactualProtoExplainer 48 | ) 49 | -------------------------------------------------------------------------------- /mercury/explainability/explanations/shuffle_importance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | class FeatureImportanceExplanation(): 6 | def __init__(self, data:dict, reverse:bool = False): 7 | """ 8 | This class holds the data related to the importance a given 9 | feature has for a model. 10 | 11 | Args: 12 | data (dict): 13 | Contains the result of the PartialDependenceExplainer. It must be in the 14 | form of: :: 15 | { 16 | 'feature_name': 1.0, 17 | 'feature_name2': 2.3, ... 18 | } 19 | 20 | reverse (bool): 21 | Whether to reverse sort the features by increasing order (i.e. Worst 22 | performance (latest) = Smallest value). Default False (decreasing order). 
23 | """ 24 | self.data = data 25 | self._sorted_features = sorted(list(data.items()), key=lambda i: i[1], 26 | reverse=not reverse) 27 | 28 | def plot(self, ax: "matplotlib.axes.Axes" = None, # noqa:F821 29 | figsize: tuple = (15, 15), limit_axis_x=False, **kwargs) -> "matplotlib.axes.Axes": # noqa:F821 30 | """ 31 | Plots a summary of the importances for each feature 32 | 33 | Args: 34 | figsize (tuple): Size of the plotted figure 35 | limit_axis_x (bool): Whether to adjust axis x to limit between the minimum and maximum feature values 36 | """ 37 | ax = ax if ax else plt.gca() 38 | 39 | feature_names = [i[0] for i in self._sorted_features] 40 | feature_values = [i[1] for i in self._sorted_features] 41 | ax.barh(feature_names, feature_values) 42 | 43 | if limit_axis_x: 44 | ax.set_xlim(min(feature_values), max(feature_values)) 45 | 46 | return ax 47 | 48 | def __getitem__(self, key:str)->float: 49 | """ 50 | Gets the feature importance of the desired feature. 51 | 52 | Args: 53 | key (str): Name of the feature. 54 | """ 55 | return self.data[key] 56 | 57 | def get_importances(self)->list: 58 | """ Returns a list of tuples (feature, importance) sorted by importances. 59 | """ 60 | return self._sorted_features 61 | -------------------------------------------------------------------------------- /tests/explainability/test_pyspark_wrapper.py: -------------------------------------------------------------------------------- 1 | from mercury.explainability.pyspark_utils import SparkWrapper 2 | 3 | import pytest 4 | 5 | def test_predict_np(spark_session, model_and_data): 6 | gbtModel = model_and_data['gbtModel'] 7 | assembler = model_and_data['assembler'] 8 | data_pd = model_and_data['data_pd'] 9 | 10 | shape = (1, data_pd.shape[1]) 11 | feat_names = list(data_pd.columns) 12 | 13 | wrap = SparkWrapper(gbtModel, 14 | feat_names, 15 | spark_session, 16 | model_inp_name='scaledFeatures', 17 | model_out_name='probability', 18 | vector_assembler=assembler 19 | ) 20 | 21 | out = wrap(data_pd.head(1).values).flatten() 22 | 23 | assert ( 24 | out[0] == pytest.approx(0.87631167, 0.01) and 25 | out[1] == pytest.approx(0.12368833, 0.01) 26 | ) 27 | 28 | def test_predict_pandas(spark_session, model_and_data): 29 | gbtModel = model_and_data['gbtModel'] 30 | assembler = model_and_data['assembler'] 31 | data_pd = model_and_data['data_pd'] 32 | 33 | shape = (1, data_pd.shape[1]) 34 | feat_names = list(data_pd.columns) 35 | 36 | wrap = SparkWrapper(gbtModel, 37 | feat_names, 38 | spark_session, 39 | model_inp_name='scaledFeatures', 40 | model_out_name='probability', 41 | #vector_assembler=assembler 42 | ) 43 | 44 | out = wrap(data_pd.head(1)).flatten() 45 | 46 | assert ( 47 | out[0] == pytest.approx(0.87631167, 0.01) and 48 | out[1] == pytest.approx(0.12368833, 0.01) 49 | ) 50 | 51 | 52 | def test_normalize_outs(spark_session, model_and_data): 53 | gbtModel = model_and_data['gbtModel'] 54 | assembler = model_and_data['assembler'] 55 | data_pd = model_and_data['data_pd'] 56 | 57 | shape = (1, data_pd.shape[1]) 58 | feat_names = list(data_pd.columns) 59 | 60 | wrap = SparkWrapper(gbtModel, 61 | feat_names, 62 | spark_session, 63 | model_inp_name='scaledFeatures', 64 | model_out_name='probability', 65 | probability_threshold=0.5 66 | ) 67 | 68 | out = wrap(data_pd.head(1)).flatten() 69 | 70 | assert ( 71 | out[0] == pytest.approx(0.87631167, 0.01) and 72 | out[1] == pytest.approx(0.12368833, 0.01) 73 | ) 74 | -------------------------------------------------------------------------------- 
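The PySpark tests above use `spark_session` and `model_and_data` fixtures, presumably provided by `tests/explainability/conftest.py`, which is listed in the tree but not reproduced in this excerpt. A minimal sketch of what such fixtures could look like follows; it assumes the serialized `gbtModelPySpark` and `assemblerScaled` artifacts from the tree above and a hypothetical CSV with the 20 FICO feature columns, so the real conftest.py may well differ.

```python
# Hypothetical sketch of the fixtures used by test_pyspark_wrapper.py;
# the actual tests/explainability/conftest.py is not shown here and may differ.
import pandas as pd
import pytest
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import GBTClassificationModel

BASE = "tests/explainability/model_and_data_pyspark"


@pytest.fixture(scope="session")
def spark_session():
    # One local Spark session shared by all PySpark tests.
    spark = (
        SparkSession.builder.master("local[2]")
        .appName("mercury-explainability-tests")
        .getOrCreate()
    )
    yield spark
    spark.stop()


@pytest.fixture(scope="session")
def model_and_data(spark_session):
    # Load the GBT model and VectorAssembler serialized under model_and_data_pyspark/ (see tree above).
    gbt_model = GBTClassificationModel.load(f"{BASE}/gbtModelPySpark")
    assembler = VectorAssembler.load(f"{BASE}/assemblerScaled")
    # Assumed source for the pandas frame with the 20 FICO features; not shown in this excerpt.
    data_pd = pd.read_csv(f"{BASE}/fico_features.csv")
    return {"gbtModel": gbt_model, "assembler": assembler, "data_pd": data_pd}
```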
/tests/explainability/test_strategy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import typing as TP 3 | import functools 4 | 5 | from mercury.explainability.explainers import cf_strategies as strat 6 | 7 | import pytest 8 | import numpy as np 9 | 10 | 11 | class StrategyTests(unittest.TestCase): 12 | 13 | # The 2D-Rosenbrock function à la classifier output 14 | def _rosenbrock(self, x: 'np.ndarray', a: int=1, b: int=100) -> 'np.ndarray': 15 | assert x.shape[1] == 2, 'Invalid input for Rosenbrock function' 16 | 17 | # return a numpy array, instead of a scalar to simulate a 18 | # classifier "probability" (2nd field will be ignored) 19 | def _rosenbrock(x): 20 | return np.array([(a - x[0])**2 + b * (x[1] - x[0]**2)**2]) 21 | 22 | return np.apply_along_axis(_rosenbrock, 1, x) 23 | # return np.array([[((a - x[:, 0])**2 + b * (x[:, 1] - x[:, 0]**2)**2)[0], 0.]]) 24 | 25 | def _strategy_with(self, fn: TP.Callable[['np.ndarray'], 'np.ndarray'], strategy: str='backtracking') -> None: 26 | if strategy == 'backtracking': 27 | s1 = strat.Backtracking(self.x0, self.bounds, self.backtracking_step, fn, self.class_idx, self.threshold, self.kernel) 28 | x, p, visited, explored = s1.run(max_iter=self.max_iter) 29 | assert abs(x[0] - 1) <= self.diff and abs(x[1] - 1) <= self.diff 30 | else: 31 | s2 = strat.SimulatedAnnealing(self.x0, self.bounds, self.simanneal_step, fn, self.class_idx, self.threshold, self.kernel) # type: strat.SimulatedAnnealing 32 | x, p, visited, _ = s2.run(tmax=20, tmin=1e-4, steps=5e4) 33 | assert abs(x[0] - 1) <= self.diff and abs(x[1] - 1) <= self.diff 34 | 35 | @pytest.fixture(autouse=True) 36 | def deterministic(self): 37 | np.random.seed(1) 38 | self.x0 = np.random.uniform(-3, 3, size=2) 39 | self.bounds = np.array([[-4, 4], [-4, 4]]) 40 | self.backtracking_step = np.array([0.01, 0.01]) 41 | self.simanneal_step = np.array([0.1, 0.1]) 42 | self.class_idx = 0 43 | self.threshold = 0. 44 | self.kernel = np.ones(2) 45 | self.max_iter = 2000 46 | self.diff = 1e-2 47 | 48 | # @pytest.mark.skip(reason='Not finding global minima yet') 49 | def test_backtracking_rosenbrock(self) -> None: 50 | self._strategy_with(functools.partial(self._rosenbrock, a=1, b=100)) # type: ignore 51 | 52 | def test_simanneal_rosenbrock(self) -> None: 53 | self._strategy_with(functools.partial(self._rosenbrock, a=1, b=100), strategy='simanneal') # type: ignore 54 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/_dummy_alibi_explainers.py: -------------------------------------------------------------------------------- 1 | # This module contains dummy versions of Explainers with Alibi dependencies. 
2 | # These classes are imported in the __init__.py if the import of the original class fails 3 | # If the user tries to import any of the classes, an error is raised indicating that alibi 4 | # must be installed first 5 | 6 | class _DummyAlibiExplainer: 7 | """ 8 | Class which raises an error if instantiated 9 | """ 10 | def __init__(self): 11 | raise ModuleNotFoundError("You need to install alibi library to use this explainer.") 12 | 13 | class ALEExplainer(_DummyAlibiExplainer): 14 | 15 | def __init__(self, predictor, target_names): 16 | super().__init__() 17 | 18 | class AnchorsWithImportanceExplainer(_DummyAlibiExplainer): 19 | 20 | def __init__(self, predict_fn=None, train_data=None, categorical_names=None, disc_perc=None): 21 | super().__init__() 22 | 23 | class CounterfactualExplainer(_DummyAlibiExplainer): 24 | 25 | def __init__( 26 | self, 27 | predict_fn=None, 28 | feature_names=None, 29 | shape=None, 30 | drop_features=None, 31 | distance_fn=None, 32 | target_proba=None, 33 | target_class=None, 34 | max_iter=None, 35 | early_stop=None, 36 | lam_init=None, 37 | max_lam_steps=None, 38 | tol=None, 39 | learning_rate_init=None, 40 | feature_range=None, 41 | eps=None, 42 | init=None, 43 | decay=None, 44 | write_dir=None, 45 | debug=None, 46 | sess=None 47 | ): 48 | super().__init__() 49 | 50 | class CounterfactualProtoExplainer(_DummyAlibiExplainer): 51 | 52 | def __init__( 53 | self, 54 | predict_fn=None, 55 | train_data=None, 56 | feature_names=None, 57 | shape=None, 58 | drop_features=None, 59 | kappa=None, 60 | beta=None, 61 | feature_range=None, 62 | gamma=None, 63 | ae_model=None, 64 | enc_model=None, 65 | theta=None, 66 | cat_vars=None, 67 | ohe=None, 68 | use_kdtree=None, 69 | learning_rate_init=None, 70 | max_iterations=None, 71 | c_init=None, 72 | c_steps=None, 73 | eps=None, 74 | clip=None, 75 | update_num_grad=None, 76 | write_dir=None, 77 | sess=None, 78 | trustscore_kwargs=None, 79 | d_type=None, 80 | w=None, 81 | disc_perc=None, 82 | standardize_cat_vars=None, 83 | smooth=None, 84 | center=None, 85 | update_feature_range=None 86 | ): 87 | super().__init__() 88 | -------------------------------------------------------------------------------- /tests/explainability/test_counterfactuals.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from mercury.explainability.explainers import CounterfactualExplainer 4 | import pytest 5 | import pickle 6 | import tensorflow as tf 7 | 8 | 9 | @pytest.fixture(scope='session') 10 | def model_and_data(): 11 | tf.compat.v1.disable_eager_execution() 12 | logRegModel = pickle.load(open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 13 | fit_data = pd.read_csv('./tests/explainability/model_and_data/fit_data_red.csv', index_col=0) 14 | explain_data = pd.read_csv('./tests/explainability/model_and_data/explain_data.csv', index_col=0) 15 | return { 16 | 'logRegModel': logRegModel, 17 | 'fit_data': fit_data, 18 | 'explain_data': explain_data 19 | } 20 | 21 | 22 | pytestmark = pytest.mark.usefixtures('model_and_data') 23 | 24 | 25 | # def test_cf_explain(model_and_data): 26 | # # Test it does not crash during __init__ (where fit is done) and 27 | # # that returns some counterfactual.
28 | # model = model_and_data['logRegModel'] 29 | # fit_data = model_and_data['fit_data'] 30 | # explain_data = model_and_data['explain_data'] 31 | # feature_names = list(explain_data.columns) 32 | 33 | # cfExplainer = CounterfactualExplainer( 34 | # predict_fn=model.predict_proba, 35 | # feature_names=feature_names 36 | # ) 37 | 38 | # cf_explanation = cfExplainer.explain(explain_data.head(1).values) 39 | 40 | # assert ( 41 | # cf_explanation.data['cf']['X'][0][0] == pytest.approx(1.3380102, 0.01) and 42 | # cf_explanation.data['cf']['X'][0][1] == pytest.approx(-1.5059024, 0.01) 43 | # ) 44 | 45 | 46 | # def test_cf_feature_importance(model_and_data): 47 | # # Test feat importance 48 | # model = model_and_data['logRegModel'] 49 | # fit_data = model_and_data['fit_data'] 50 | # explain_data = model_and_data['explain_data'] 51 | # feature_names = list(explain_data.columns) 52 | 53 | # tf.compat.v1.reset_default_graph() 54 | # cfExplainer = CounterfactualExplainer( 55 | # predict_fn=model.predict_proba, 56 | # feature_names=feature_names 57 | # ) 58 | 59 | # explanations = cfExplainer.get_feature_importance( 60 | # explain_data.head(3) 61 | # ) 62 | 63 | # assert ( 64 | # explanations.importances[0][0] == 'ExternalRiskEstimate' and 65 | # explanations.count_diffs_norm['ExternalRiskEstimate'] == 1.0 66 | # ) 67 | 68 | 69 | def test_cf_explain(model_and_data): 70 | # Skipped due to alibi (which is optional) expecting tensorflow 1.xx 71 | pass 72 | 73 | 74 | def test_cf_feature_importance(model_and_data): 75 | # Skipped due to alibi (which is optional) expecting tensorflow 1.xx 76 | pass 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mercury-explainability 2 | 3 | [![](https://github.com/BBVA/mercury-explainability/actions/workflows/test.yml/badge.svg)](https://github.com/BBVA/mercury-explainability) 4 | ![](https://img.shields.io/badge/latest-1.1.4-blue) 5 | [![Python 3.8](https://img.shields.io/badge/python-3.8-blue.svg)](https://www.python.org/downloads/release/python-3816/) 6 | [![Python 3.9](https://img.shields.io/badge/python-3.9-blue.svg)](https://www.python.org/downloads/release/python-3916/) 7 | [![Python 3.10](https://img.shields.io/badge/python-3.10-blue.svg)](https://www.python.org/downloads/release/python-31011/) 8 | [![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-3119/) 9 | [![Python 3.12](https://img.shields.io/badge/python-3.12-blue.svg)](https://www.python.org/downloads/release/python-3128/) 10 | [![Apache 2 license](https://shields.io/badge/license-Apache%202-blue)](http://www.apache.org/licenses/LICENSE-2.0) 11 | [![Ask Me Anything !](https://img.shields.io/badge/Ask%20me-anything-1abc9c.svg)](https://github.com/BBVA/mercury-explainability/issues) 12 | 13 | ***mercury-explainability*** is a library with implementations of different state-of-the-art methods in the field of explainability. They are designed to work efficiently and to be easily integrated with the main Machine Learning frameworks. 14 | 15 | ## Mercury project at BBVA 16 | 17 | Mercury is a collaborative library that was developed by the Advanced Analytics community at BBVA. Originally, it was created as an [InnerSource](https://en.wikipedia.org/wiki/Inner_source) project but after some time, we decided to release certain parts of the project as Open Source. 
18 | That's the case with the `mercury-explainability` package. 19 | 20 | The basic block of ***mercury-explainability*** is the `Explainer` class. Each one of the explainers in ***mercury-explainability*** offers a different method for explaining your models and often will return an `Explanation` type object containing the result of that particular explainer. 21 | 22 | The usage of most of the explainers you will find in this library follows this schema: 23 | 24 | ```python 25 | from mercury.explainability import ExplainerExample 26 | explainer = ExplainerExample(function_to_explain) 27 | explanation = explainer.explain(dataset) 28 | ``` 29 | 30 | Basically, you simply need to instantiate your desired `Explainer` (note that the above example `ExplainerExample` does not exist) 31 | providing your custom function you desire to get an explanation for, which usually will be your model’s inference or evaluation function. 32 | These explainers are ready to work efficiently with most of the frameworks you will likely use as a data scientist (yes, included *Spark*). 33 | 34 | If you're interested in learning more about the Mercury project, we recommend reading this blog [post](https://www.bbvaaifactory.com/mercury-acelerando-la-reutilizacion-en-ciencia-de-datos-dentro-de-bbva/) from www.bbvaaifactory.com 35 | 36 | ## User installation 37 | 38 | The easiest way to install `mercury-explainability` is using ``pip``: 39 | 40 | pip install -U mercury-explainability 41 | 42 | ## Help and support 43 | 44 | This library is currently maintained by a dedicated team of data scientists and machine learning engineers from BBVA. 45 | 46 | ### Documentation 47 | website: https://bbva.github.io/mercury-explainability/site/ 48 | 49 | ### Email 50 | mercury.group@bbva.com 51 | -------------------------------------------------------------------------------- /mercury/explainability/explanations/anchors.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import numpy as np 3 | import pandas as pd 4 | import heapq 5 | import itertools 6 | 7 | class AnchorsWithImportanceExplanation(object): 8 | """ 9 | Extended Anchors Explanations 10 | 11 | Args: 12 | explain_data: 13 | A pandas DataFrame containing the observations for which an explanation has to be found. 14 | explanations: 15 | A list containing the results of computing the explanations for explain_data. 16 | categorical: 17 | A dictionary containing as key the features that are categorical and as value, the possible 18 | categorical values. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | explain_data: pd.DataFrame, 24 | explanations: TP.List, 25 | categorical: dict = {} 26 | ) -> None: 27 | self.explain_data = explain_data 28 | self.explanations = explanations 29 | self.categorical = categorical 30 | 31 | def interpret_explanations(self, n_important_features: int) -> str: 32 | """ 33 | This method prints a report of the important features obtaiend. 34 | 35 | Args: 36 | n_important_features: 37 | The number of imporant features that will appear in the report. 38 | Defaults to 3. 39 | """ 40 | names = [] 41 | explanations_found = [explan for explan in self.explanations if not isinstance(explan, str)] 42 | for expl in explanations_found: 43 | for name in expl.data['anchor']: 44 | # split without an argument splits by spaces, and in every item in expl['names'] 45 | # the first word refers to the feature name. 
46 | if ( 47 | (' = ' in name) or 48 | ((len(self.categorical) > 0) and (name in [item for sublist in list(self.categorical.values()) for item in sublist])) 49 | ): 50 | names.append(name) 51 | else: 52 | names.append(' '.join(name[::-1].split('.', 1)[1][::-1].split()[:-1])) 53 | 54 | unique_names, count_names = np.unique(names, return_counts=True) 55 | top_feats = heapq.nlargest(n_important_features, count_names) 56 | print_values = ['The ', str(n_important_features), ' most common features are: '] 57 | unique_names_ordered = sorted(unique_names.tolist(), key=lambda x: count_names[unique_names.tolist().index(x)], reverse=True) 58 | count_names_ordered = sorted(count_names.tolist(), reverse=True) 59 | n_explanations = 0 60 | for unique_name, count_name in zip(unique_names_ordered[:n_important_features], count_names_ordered[:n_important_features]): 61 | if n_explanations == 0: 62 | print_values.append([unique_name, ' with a frequency of ', 63 | str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) ']) 64 | elif n_explanations == n_important_features - 1: 65 | print_values.append([' and ', unique_name, ' with a frequency of ', 66 | str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) ']) 67 | else: 68 | print_values.append([', ',unique_name, ' with a frequency of ', 69 | str(count_name), ' (', str(100 * count_name / len(explanations_found)), '%) ']) 70 | n_explanations += 1 71 | interptretation = ''.join(list(itertools.chain(*print_values))) 72 | print(interptretation) 73 | return interptretation 74 | -------------------------------------------------------------------------------- /tests/explainability/test_counterfactualproto.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from mercury.explainability.explainers import CounterfactualExplainer, CounterfactualProtoExplainer 4 | import pytest 5 | import pickle 6 | import tensorflow as tf 7 | # from tensorflow.python.framework import ops 8 | 9 | @pytest.fixture(scope='session') 10 | def model_and_data(): 11 | tf.compat.v1.disable_eager_execution() 12 | logRegModel = pickle.load(open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 13 | fit_data = pd.read_csv('./tests/explainability/model_and_data/fit_data_red.csv', index_col=0) 14 | explain_data = pd.read_csv('./tests/explainability/model_and_data/explain_data.csv', index_col=0) 15 | return { 16 | 'logRegModel': logRegModel, 17 | 'fit_data': fit_data, 18 | 'explain_data': explain_data 19 | } 20 | 21 | 22 | pytestmark = pytest.mark.usefixtures('model_and_data') 23 | 24 | 25 | # def test_cf_proto_explain(model_and_data): 26 | # # Test it does not crash during __init__ (where fit is done) and 27 | # # that returns some counterfactual. 
28 | # logRegModel = model_and_data['logRegModel'] 29 | # fit_data = model_and_data['fit_data'] 30 | # explain_data = model_and_data['explain_data'] 31 | # feature_names = list(explain_data.columns) 32 | 33 | # cfExplainer = CounterfactualProtoExplainer( 34 | # predict_fn=logRegModel.predict_proba, 35 | # train_data=fit_data, 36 | # use_kdtree=True, 37 | # ) 38 | 39 | # explanation = cfExplainer.explain(explain_data.head(1).values) 40 | 41 | # cfs = False 42 | # for c in list(explanation['data']['all'].values()): 43 | # if len(c) > 0: 44 | # cfs = True 45 | # break 46 | 47 | # assert cfs 48 | 49 | 50 | # def test_cfproto_feature_importance(model_and_data): 51 | # logRegModel = model_and_data['logRegModel'] 52 | # fit_data = model_and_data['fit_data'] 53 | # explain_data = model_and_data['explain_data'] 54 | # feature_names = list(explain_data.columns) 55 | 56 | # predict_fn = lambda x: logRegModel.predict_proba(x) 57 | 58 | # shape = (1,) + fit_data.shape[1:] 59 | 60 | # feature_range = ( 61 | # fit_data.min(axis=0), 62 | # fit_data.max(axis=0) 63 | # ) 64 | 65 | # c_init = 1. 66 | # c_steps = 10 67 | # eps = (1e-2, 1e-2) 68 | # theta = 10 69 | # max_iter = 1000 70 | 71 | # cfprotoExtendedExplainer = CounterfactualProtoExplainer( 72 | # predict_fn=predict_fn, 73 | # train_data=fit_data, 74 | # shape=shape, 75 | # feature_names=feature_names, 76 | # feature_range=feature_range, 77 | # c_init=c_init, 78 | # c_steps=c_steps, 79 | # max_iterations=max_iter, 80 | # theta=theta, 81 | # eps=eps, 82 | # update_num_grad=1 83 | # ) 84 | 85 | # explanations = cfprotoExtendedExplainer.get_feature_importance( 86 | # explain_data.head(3) 87 | # ) 88 | 89 | # assert ( 90 | # explanations.importances[0][0] == 'ExternalRiskEstimate' and 91 | # explanations.importances[0][1] == 1.0 and 92 | # explanations.importances[0][2] == -1.0 and 93 | # explanations.importances[1][0] == 'NetFractionRevolvingBurden' and 94 | # explanations.importances[1][1] >= 0.5 and 95 | # explanations.importances[1][2] == 1.0 96 | # ) 97 | 98 | 99 | def test_cf_proto_explain(model_and_data): 100 | # Skipped due to alibi (which is optional) expecting tensorflow 1.xx 101 | pass 102 | 103 | 104 | def test_cfproto_feature_importance(model_and_data): 105 | # Skipped due to alibi (which is optional) expecting tensorflow 1.xx 106 | pass 107 | -------------------------------------------------------------------------------- /mercury/explainability/explanations/clustering_tree_explanation.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List 3 | import numpy as np 4 | try: 5 | from graphviz import Source 6 | graphviz_available = True 7 | except Exception: 8 | graphviz_available = False 9 | 10 | 11 | class ClusteringTreeExplanation(): 12 | 13 | """ 14 | Explanation for ClusteringTreeExplainer. Represents a Decision Tree for the explanation of a clustering 15 | algorithm. 
16 | Using the plot method generates a visualization of the decision tree (requires graphviz package) 17 | 18 | Args: 19 | tree: the fitted decision tree 20 | feature_names: the feature names used in the decision tree 21 | 22 | """ 23 | 24 | def __init__( 25 | self, 26 | tree: "Node", # noqa: F821 27 | feature_names: List = None, 28 | ): 29 | self.tree = tree 30 | self.feature_names = feature_names 31 | 32 | def plot(self, filename: str = "tree_explanation", feature_names: List = None, scalers: dict = None): 33 | 34 | """ 35 | Generates a graphviz.Source object representing the decision tree, which can be visualized in a notebook 36 | or saved in a file. 37 | 38 | Args: 39 | filename: filename to save if render() method is called over the returned object 40 | feature_names: the feature names to use. If not specified, the feature names specified in the constructor 41 | are used. 42 | scalers: dictionary of scalers. If passed, the tree will show the denormalized value in the split instead 43 | of the normalized value. The key is the feature name and the scaler must have the `inverse_transform` 44 | method 45 | 46 | Returns: 47 | (graphviz.Source): object representing the decision tree. 48 | """ 49 | 50 | feature_names = self.feature_names if feature_names is None else feature_names 51 | scalers = {} if scalers is None else scalers 52 | 53 | if not graphviz_available: 54 | raise Exception("Required package is missing. Please install graphviz") 55 | 56 | if self.tree is not None: 57 | dot_str = ["digraph ClusteringTree {\n"] 58 | queue = [self.tree] 59 | nodes = [] 60 | edges = [] 61 | id = 0 62 | while len(queue) > 0: 63 | curr = queue.pop(0) 64 | if curr.is_leaf(): 65 | label = "%s\nsamples=%d\nmistakes=%d" % (str(self._get_node_split_value(curr)), curr.samples, curr.mistakes) 66 | else: 67 | feature_name = curr.feature if feature_names is None else feature_names[curr.feature] 68 | condition = "%s <= %.3f" % (feature_name, self._get_node_split_value(curr, feature_name, scalers)) 69 | label = "%s\nsamples=%d" % (condition, curr.samples) 70 | queue.append(curr.left) 71 | queue.append(curr.right) 72 | edges.append((id, id + len(queue) - 1)) 73 | edges.append((id, id + len(queue))) 74 | nodes.append({"id": id, 75 | "label": label, 76 | "node": curr}) 77 | id += 1 78 | for node in nodes: 79 | dot_str.append("n_%d [label=\"%s\"];\n" % (node["id"], node["label"])) 80 | for edge in edges: 81 | dot_str.append("n_%d -> n_%d;\n" % (edge[0], edge[1])) 82 | dot_str.append("}") 83 | dot_str = "".join(dot_str) 84 | s = Source(dot_str, filename=filename + '.gv', format="png") 85 | return s 86 | 87 | def _get_node_split_value(self, node, feature_name=None, scalers=None): 88 | if (feature_name is not None) and (scalers is not None) and (feature_name in scalers): 89 | return scalers[feature_name].inverse_transform(np.array([node.value]).reshape(1, -1))[0][0] 90 | else: 91 | return node.value 92 | -------------------------------------------------------------------------------- /mercury/explainability/pyspark_utils.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import pandas as pd 3 | import numpy as np 4 | 5 | 6 | class SparkWrapper: 7 | """ 8 | This class is an adaptor which allows Spark models to also be 9 | explained by Mercury. In order to explain your model you should wrap it 10 | with this. 
11 | 12 | Args: 13 | transformer: Trained PySpark model (transformer), or anything implementing a transform method 14 | feature_names: Name of the features the PySpark model uses. 15 | spark_session: Current spark session in use. 16 | model_inp_name: Name of the input column for the model (output name of the VectorAssembler) 17 | model_out_name: Output column name of the PySpark model. Default one is "probability". 18 | vector_assembler: If None, a default one will be created in order to transform the features to a Vector used by the PySpark model. 19 | probability_threshold: If >=0, the output of the model will be normalized taking into account the threshold. 20 | 21 | Example: 22 | ```python 23 | # model is a spark transformer (including a pipeline) already trained we want to explain 24 | >>> model_w = SparkWrapper( 25 | ... model, 26 | ... dataset.feature_names, 27 | ... spark_session=spark, 28 | ... model_inp_name="features", 29 | ... model_out_name="probability", 30 | ... ) 31 | # model_w is a model ready to be explained with Mercury explainers. 32 | ``` 33 | """ 34 | 35 | def __init__(self, 36 | transformer: 'pyspark.ml.Transformer', # noqa: F821 37 | feature_names: list, 38 | spark_session: 'pyspark.sql.SparkSession' = None, # noqa: F821 39 | model_inp_name: str = "features", 40 | model_out_name: str = "probability", 41 | vector_assembler: 'pyspark.sql.VectorAssembler' = None, # noqa: F821 42 | probability_threshold: float = -1 43 | ): 44 | 45 | import pyspark 46 | 47 | self.model = transformer 48 | self.spark_session = spark_session 49 | self.feature_names = feature_names 50 | self.model_out_name = model_out_name 51 | self.vector_assembler = vector_assembler 52 | self.probability_threshold = probability_threshold 53 | 54 | if self.vector_assembler is None: 55 | self.vector_assembler = pyspark.ml.feature.VectorAssembler( 56 | inputCols=feature_names, outputCol=model_inp_name 57 | ) 58 | 59 | @staticmethod 60 | def _transform_threshold( 61 | x: np.ndarray, 62 | threshold: float 63 | ) -> np.ndarray: 64 | 65 | if x[1] < threshold: 66 | x[1] = 0.5 * (x[1] / threshold) 67 | else: 68 | x[1] = 0.5 + 0.5 * ((x[1] - threshold) / (1 - threshold)) 69 | x[0] = 1 - x[1] 70 | return x 71 | 72 | def __call__(self, data: TP.Union['pd.DataFrame', np.ndarray] = None): 73 | if data.shape[1] != len(self.feature_names): 74 | raise ValueError("The input does not have the same number of features as the specified in self.feature_names.") 75 | 76 | x_to_be_predicted = data 77 | if type(data) is np.ndarray: 78 | data = pd.DataFrame(data, columns=self.feature_names) 79 | sp_df_x = self.spark_session.createDataFrame(data) 80 | x_to_be_predicted = self.vector_assembler.transform( 81 | sp_df_x, 82 | ) if self.vector_assembler else sp_df_x 83 | 84 | pred_out = np.stack(self.model.transform(x_to_be_predicted) 85 | .select(self.model_out_name).toPandas()[self.model_out_name].apply( 86 | lambda x: np.array(x.toArray()) 87 | ).values) 88 | 89 | if self.probability_threshold >= 0: 90 | pred_out = np.apply_along_axis( 91 | lambda x: self._transform_threshold(x, self.probability_threshold), 92 | 1, 93 | pred_out 94 | ) 95 | 96 | return pred_out -------------------------------------------------------------------------------- /tests/explainability/test_ale.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from matplotlib import pyplot as plt 4 | from mercury.explainability import ALEExplainer 5 | from mercury.explainability.explainers.ale import plot_ale 6 | 7 
| import pandas as pd 8 | import numpy as np 9 | import pickle 10 | 11 | import pytest 12 | 13 | 14 | @pytest.fixture(scope='session') 15 | def model_and_data(): 16 | logRegModel = pickle.load(open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 17 | fit_data = pd.read_csv('./tests/explainability/model_and_data_pyspark/data_ale.csv', index_col=0) 18 | return { 19 | 'logRegModel': logRegModel, 20 | 'data': fit_data, 21 | } 22 | 23 | def test_ale_bad_init(model_and_data): 24 | model = model_and_data['logRegModel'] 25 | data_pd = model_and_data['data'] 26 | with pytest.raises(AttributeError) as excinfo: 27 | ale_instance = ALEExplainer( 28 | lambda x: model.predict_proba(x), 29 | target_names=1 30 | ) 31 | 32 | with pytest.raises(AttributeError) as excinfo: 33 | ale_instance = ALEExplainer( 34 | lambda x: model.predict_proba(x), 35 | target_names=1 36 | ) 37 | 38 | def test_ale_bad_explain(model_and_data): 39 | model = model_and_data['logRegModel'] 40 | data_pd = model_and_data['data'] 41 | 42 | ale_instance = ALEExplainer( 43 | lambda x: model.predict_proba(x), 44 | target_names=['label'] 45 | ) 46 | 47 | with pytest.raises(ValueError) as excinfo: 48 | explanation = ale_instance.explain(np.array([1,2,3])) 49 | 50 | 51 | def test_ale_explain(model_and_data): 52 | model = model_and_data['logRegModel'] 53 | data_pd = model_and_data['data'] 54 | features = [c for c in list(data_pd.columns) if c not in ['label']] 55 | 56 | ale_instance = ALEExplainer( 57 | lambda x: model.predict_proba(x), 58 | target_names="label" 59 | ) 60 | 61 | explanation = ale_instance.explain(data_pd[features]) 62 | 63 | assert ( 64 | explanation.ale_values[0][0][0] == pytest.approx(-0.20772727415447495, rel=0.1, abs=0.5) and 65 | explanation.ale_values[0][0][1] == pytest.approx(0.20772727415447498, rel=0.1, abs=0.5) and 66 | explanation.ale_values[1][0][0] == pytest.approx(0.0027093534329286116, rel=0.1, abs=0.5) and 67 | explanation.ale_values[1][0][1] == pytest.approx(-0.0027093534329286047, rel=0.1, abs=0.5) and 68 | explanation.ale_values[2][0][0] == pytest.approx(0.006239338356358494, rel=0.1, abs=0.5) 69 | ) 70 | 71 | def test_ale_explain_ignoring(model_and_data): 72 | model = model_and_data['logRegModel'] 73 | data_pd = model_and_data['data'] 74 | features = [c for c in list(data_pd.columns) if c not in ['label']] 75 | 76 | ale_instance = ALEExplainer( 77 | lambda x: model.predict_proba(x), 78 | target_names="label" 79 | ) 80 | 81 | to_ignore = features[3:6] 82 | explanation = ale_instance.explain(data_pd[features], ignore_features=to_ignore) 83 | 84 | for e in explanation.feature_names: 85 | assert e not in to_ignore 86 | 87 | 88 | def test_plot_explanation(model_and_data): 89 | model = model_and_data['logRegModel'] 90 | data_pd = model_and_data['data'] 91 | features = [c for c in list(data_pd.columns) if c not in ['label']] 92 | 93 | ale_instance = ALEExplainer( 94 | lambda x: model.predict_proba(x), 95 | target_names='label' 96 | ) 97 | 98 | explanation = ale_instance.explain(data_pd[features]) 99 | 100 | axes = plot_ale(explanation, 101 | n_cols=1, 102 | fig_kw={'figwidth': 13, 'figheight': 20}, 103 | line_kw={'markersize': 3, 'marker': 'o', 'label': None} , 104 | sharey=None) 105 | 106 | assert axes.shape == (20, 1) 107 | 108 | fig, ax = plt.subplots(len(features)) 109 | axes = plot_ale(explanation, features=features, targets=['label'], ax=ax) 110 | 111 | # Test only plot of certain features 112 | axes = plot_ale(explanation, features=features[3:6]) 113 | assert axes.shape == (1,3) 114 | 
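# --- Illustrative sketch (not part of the original test suite) ---------------
# SparkWrapper in mercury/explainability/pyspark_utils.py above re-scales binary
# probabilities around a custom decision threshold via _transform_threshold.
# The standalone, pure-numpy sketch below reproduces that piecewise-linear
# mapping so the behaviour can be checked without Spark: scores below the
# threshold land in [0, 0.5) and scores at or above it land in [0.5, 1], so a
# model thresholded at, say, 0.3 behaves like one calibrated at 0.5.
import numpy as np

def rescale_probability(p1: float, threshold: float) -> np.ndarray:
    """Return the re-scaled [p0, p1] pair for a positive-class score p1."""
    if p1 < threshold:
        p1 = 0.5 * (p1 / threshold)
    else:
        p1 = 0.5 + 0.5 * ((p1 - threshold) / (1 - threshold))
    return np.array([1.0 - p1, p1])

# With threshold=0.3: a raw score equal to the threshold maps exactly to 0.5,
# and a score of 0.15 (half the threshold) maps to 0.25.
assert np.allclose(rescale_probability(0.30, 0.3), [0.5, 0.5])
assert np.allclose(rescale_probability(0.15, 0.3), [0.75, 0.25])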
-------------------------------------------------------------------------------- /tests/explainability/test_serializer.py: -------------------------------------------------------------------------------- 1 | from mercury.explainability.explainers import ( 2 | AnchorsWithImportanceExplainer, 3 | ALEExplainer, 4 | MercuryExplainer 5 | ) 6 | import pytest 7 | import pickle 8 | import pandas as pd 9 | import numpy as np 10 | import os 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def model_and_data_anchors(): 15 | logRegModel = pickle.load(open('tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 16 | fit_data = pd.read_csv('tests/explainability/model_and_data/fit_data_red.csv', index_col=0) 17 | explain_data = pd.read_csv('tests/explainability/model_and_data/explain_data.csv', index_col=0) 18 | return { 19 | 'logRegModel': logRegModel, 20 | 'fit_data': fit_data, 21 | 'explain_data': explain_data 22 | } 23 | 24 | @pytest.fixture(scope='session') 25 | def model_and_data_ale(): 26 | logRegModel = pickle.load(open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 27 | fit_data = pd.read_csv('./tests/explainability/model_and_data_pyspark/data_ale.csv', index_col=0) 28 | return { 29 | 'logRegModel': logRegModel, 30 | 'data': fit_data, 31 | } 32 | 33 | pytestmark = pytest.mark.usefixtures("model_and_data") 34 | 35 | def test_serializer_explainer(model_and_data_anchors): 36 | """ 37 | Testing out that the explainers are properly saved and then loaded 38 | back. 39 | """ 40 | logRegModel = model_and_data_anchors['logRegModel'] 41 | fit_data = model_and_data_anchors['fit_data'] 42 | explain_data = model_and_data_anchors['explain_data'] 43 | feature_names = list(explain_data.columns) 44 | 45 | TEST_FILE = "/tmp/explainer.pkl" 46 | 47 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 48 | train_data=fit_data, 49 | predict_fn=logRegModel.predict_proba, 50 | feature_names=feature_names 51 | ) 52 | anchorsExtendedExplainer.save(TEST_FILE) 53 | assert os.path.isfile(TEST_FILE), "File does not exist" 54 | 55 | anchorsExtendedExplainer_recovered = MercuryExplainer.load(TEST_FILE) 56 | assert type(anchorsExtendedExplainer_recovered) ==\ 57 | AnchorsWithImportanceExplainer, "Bad load" 58 | 59 | os.remove(TEST_FILE) 60 | 61 | def test_serializer_anchors_with_importance_explainer(model_and_data_anchors): 62 | 63 | logRegModel = model_and_data_anchors['logRegModel'] 64 | fit_data = model_and_data_anchors['fit_data'] 65 | explain_data = model_and_data_anchors['explain_data'] 66 | feature_names = list(explain_data.columns) 67 | 68 | TEST_FILE = "/tmp/explainer_anchor.pkl" 69 | 70 | explainer = AnchorsWithImportanceExplainer( 71 | train_data=fit_data, 72 | predict_fn=logRegModel.predict_proba, 73 | feature_names=feature_names 74 | ) 75 | explainer.save(TEST_FILE) 76 | 77 | explainer_loaded = MercuryExplainer.load(TEST_FILE) 78 | assert isinstance(explainer_loaded, AnchorsWithImportanceExplainer) 79 | assert explainer.params == explainer_loaded.params 80 | assert explainer.feature_values == explainer_loaded.feature_values 81 | 82 | # We are able to execute explain 83 | explanation_loaded = explainer_loaded.explain(fit_data.values[0]) 84 | 85 | os.remove(TEST_FILE) 86 | 87 | def test_serializer_ale_explainer(model_and_data_ale): 88 | 89 | model = model_and_data_ale['logRegModel'] 90 | data_pd = model_and_data_ale['data'] 91 | features = [c for c in list(data_pd.columns) if c not in ['label']] 92 | 93 | TEST_FILE = "/tmp/explainer_ale.pkl" 94 | 95 | explainer = ALEExplainer( 
96 | lambda x: model.predict_proba(x), 97 | target_names="label" 98 | ) 99 | explainer.save(TEST_FILE) 100 | 101 | explainer_loaded = MercuryExplainer.load(TEST_FILE) 102 | isinstance(explainer_loaded, ALEExplainer) 103 | 104 | # Check explanations 105 | explanation = explainer.explain(data_pd[features]) 106 | explanation_loaded = explainer_loaded.explain(data_pd[features]) 107 | 108 | for i in range(len(explanation.data)): 109 | assert np.all(explanation.data['ale_values'][i] == explanation_loaded.data['ale_values'][i]) 110 | 111 | os.remove(TEST_FILE) -------------------------------------------------------------------------------- /tests/explainability/test_explanation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import typing as TP 3 | 4 | from mercury.explainability.explanations import CounterfactualBasicExplanation 5 | 6 | import pytest 7 | import numpy as np 8 | 9 | 10 | class ExplanationTest(unittest.TestCase): 11 | """ A collection of simple 2D cases. """ 12 | 13 | # Simple 1D cases 14 | def test_invalid_shape(self): 15 | """ Case where from_.shape != to_.shape. """ 16 | with self.assertRaises(AssertionError): 17 | CounterfactualBasicExplanation( 18 | np.array([1.0]), # from 19 | np.array([]), # to 20 | 0., # p 21 | np.array([]), # path 22 | np.array([]), # path_ps 23 | np.array([[-1., 1.]]), # bounds 24 | np.array([]), # explored 25 | np.array([])) # explored_ps 26 | 27 | def test_invalid_invalid_probability(self): 28 | """ Case where probability is invalid. """ 29 | with self.assertRaises(AssertionError): 30 | CounterfactualBasicExplanation( 31 | np.array([1.0]), # from 32 | np.array([2.0]), # to 33 | -1., # p 34 | np.array([]), # path 35 | np.array([]), # path_ps 36 | np.array([[-1., 1.]]), # bounds 37 | np.array([]), # explored 38 | np.array([])) # explored_ps 39 | 40 | def test_invalid_path_shape(self): 41 | """ Case where path.shape != path_ps.shape. """ 42 | with self.assertRaises(AssertionError): 43 | CounterfactualBasicExplanation( 44 | np.array([1.0]), # from 45 | np.array([2.0]), # to 46 | 0., # p 47 | np.array([1.2, 1.8, 2.0]), # path 48 | np.array([0.8, 0.2]), # path_ps 49 | np.array([[-1., 1.]]), # bounds 50 | np.array([]), # explored 51 | np.array([])) # explored_ps 52 | 53 | def test_invalid_bounds(self): 54 | """ Case where bounds.shape[0] != from_.shape[0] """ 55 | with self.assertRaises(AssertionError): 56 | CounterfactualBasicExplanation( 57 | np.array([1.0]), # from 58 | np.array([2.0]), # to 59 | 0., # p 60 | np.array([1.2, 1.8, 2.0]), # path 61 | np.array([0.8, 0.2, 0.]), # path_ps 62 | np.array([]), # bounds 63 | np.array([]), # explored 64 | np.array([])) # explored_ps 65 | 66 | def test_invalid_path2(self): 67 | """ Case where explored.shape[0] != explored_ps.shape[0] """ 68 | with self.assertRaises(AssertionError): 69 | CounterfactualBasicExplanation( 70 | np.array([1.0]), # from 71 | np.array([2.0]), # to 72 | 0., # p 73 | np.array([1.2, 1.8, 2.]), # path 74 | np.array([0.8, 0.2, 0.01]), # path_ps 75 | np.array([[-1., 1.]]), # bounds 76 | np.array([1.1, 1.2, 1.5, 1.8, 2.]), # explored 77 | np.array([])) # explored_ps 78 | 79 | def test_invalid_labels(self): 80 | """ Case where len(labels) != bounds.shape[0]. 
""" 81 | with self.assertRaises(AssertionError): 82 | CounterfactualBasicExplanation( 83 | np.array([1.0]), # from 84 | np.array([2.0]), # to 85 | 0., # p 86 | np.array([1.2, 1.8, 2.]), # path 87 | np.array([0.8, 0.2, 0.01]), # path_ps 88 | np.array([[-1., 1.]]), # bounds 89 | np.array([1.1, 1.2, 1.5, 1.8, 2.]), # explored 90 | np.array([0.6, 0.3, 0.45, 0.1, 0.]), # explored_ps 91 | labels=['a', 'b']) # labels 92 | -------------------------------------------------------------------------------- /tests/explainability/test_counterfactual_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from mercury.explainability.explainers import CounterFactualExplainerBasic 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import pytest 8 | 9 | 10 | class CounterfactualBasicTest(unittest.TestCase): 11 | """Simple checks on Panellet.""" 12 | 13 | # The 2D-Rosenbrock function à la classifier output 14 | def _rosenbrock(self, x: "np.ndarray", a: int = 1, b: int = 100) -> "np.ndarray": 15 | assert x.shape[1] == 2, "Invalid input for Rosenbrock function" 16 | 17 | # return a numpy array, instead of a scalar to simulate a 18 | # classifier "probability" (2nd field will be ignored) 19 | def _clip_rosenbrock(x): 20 | return np.clip(np.array([(a - x[0]) ** 2 + b * (x[1] - x[0] ** 2) ** 2]), 0, 1) 21 | 22 | return np.apply_along_axis(_clip_rosenbrock, 1, x) 23 | # return np.array([[((a - x[:, 0])**2 + b * (x[:, 1] - x[:, 0]**2)**2)[0], 0.]]) 24 | 25 | def test_nodata_constructor_invalid_bounds1(self): 26 | with self.assertRaises(AssertionError): 27 | # Labels/bounds shape does not match 28 | CounterFactualExplainerBasic( 29 | None, self._rosenbrock, labels=["a", "b"], bounds=np.array([]) 30 | ) 31 | 32 | def test_nodata_constructor_invalid_bounds2(self): 33 | with self.assertRaises(AssertionError): 34 | # Invalid bounds shape 35 | CounterFactualExplainerBasic( 36 | None, self._rosenbrock, labels=["a"], bounds=np.array([0]) 37 | ) 38 | 39 | def test_nodata_constructor_invalid_labels(self): 40 | with self.assertRaises(AssertionError): 41 | # Labels/bounds shape does not match 42 | CounterFactualExplainerBasic( 43 | None, self._rosenbrock, labels=["a"], bounds=np.array([[0, 1], [0, 1]]) 44 | ) 45 | 46 | def test_nodata_roll_invalid_from(self): 47 | pan = CounterFactualExplainerBasic( 48 | None, self._rosenbrock, labels=["a"], bounds=np.array([[0, 1]]) 49 | ) 50 | with self.assertRaises(AssertionError): 51 | pan.explain(np.array([[1, 2]]), 0.1) 52 | 53 | def test_nodata_roll_invalid_bounds(self): 54 | pan = CounterFactualExplainerBasic( 55 | None, self._rosenbrock, labels=["a"], bounds=np.array([[0, 1]]) 56 | ) 57 | with self.assertRaises(AssertionError): 58 | pan.explain(np.array([1]), 0.1, bounds=np.array([[0, 1], [0, 1]])) 59 | 60 | def test_nodata_roll_invalid_step(self): 61 | pan = CounterFactualExplainerBasic( 62 | None, self._rosenbrock, labels=["a"], bounds=np.array([[0, 1]]) 63 | ) 64 | with self.assertRaises(AssertionError): 65 | pan.explain(np.array([1]), 0.1, step=np.array([0.01, 0.01])) 66 | 67 | def test_nodata_bt_strategy(self): 68 | df_data = pd.DataFrame(data={"a": [-10, 10], "b": [-10, 10]}) 69 | np_data = np.array([[-10, -10], [10, 10]]) 70 | pan = CounterFactualExplainerBasic( 71 | np_data, self._rosenbrock, labels=["a", "b"] 72 | ) 73 | pan.explain(np_data[0], 0.1, class_idx=0, strategy="backtracking") 74 | 75 | def test_build_wrong_dtype(self): 76 | with pytest.raises(TypeError) as excinfo: 77 | pan = 
CounterFactualExplainerBasic( 78 | [0,1], self._rosenbrock, labels=["a", "b"] 79 | ) 80 | 81 | def test_nodata_sima_strategy(self): 82 | df_data = pd.DataFrame(data={"a": [-10, 10], "b": [-10, 10]}) 83 | np_data = np.array([[-10, -10], [10, 10]]) 84 | pan = CounterFactualExplainerBasic( 85 | df_data, self._rosenbrock 86 | ) 87 | pan.explain(np_data[0], 0.1, 88 | class_idx=0, 89 | bounds=np.array([[0, 9], [0, 9]]), 90 | strategy="simanneal") 91 | 92 | def test_nodata_invalid_strategy(self): 93 | pan = CounterFactualExplainerBasic( 94 | None, self._rosenbrock, labels=["a"], bounds=np.array([[0, 1]]) 95 | ) 96 | with self.assertRaises(ValueError): 97 | pan.explain(np.array([1]), 0.1, strategy="meh") 98 | 99 | def test_data_constructor_invalid_labels(self): 100 | df_data = pd.DataFrame(data={"a": [-10, 10], "b": [-10, 10]}) 101 | np_data = np.array([[-10, -10], [10, 10]]) 102 | # this shold give no error 103 | CounterFactualExplainerBasic(df_data, self._rosenbrock) 104 | with self.assertRaises(AssertionError): 105 | # Labels/bounds shape does not match 106 | CounterFactualExplainerBasic(np_data, self._rosenbrock, labels=["a", "b", "c"]) 107 | 108 | def test_keep_explored_points_false(self): 109 | np_data = np.array([[-10, -10], [10, 10]]) 110 | pan = CounterFactualExplainerBasic( 111 | np_data, self._rosenbrock, labels=["a", "b"] 112 | ) 113 | explanation = pan.explain( 114 | np_data[0], 0.1, class_idx=0, strategy="backtracking", limit=2, max_iter=3, keep_explored_points=False 115 | ) 116 | assert len(explanation.explored) == 0 117 | 118 | def test_with_kernel_and_step(self): 119 | np_data = np.array([[-10, -10], [10, 10]]) 120 | pan = CounterFactualExplainerBasic( 121 | np_data, self._rosenbrock, labels=["a", "b"], 122 | ) 123 | explanation = pan.explain( 124 | np_data[0], 0.1, class_idx=0, strategy="backtracking", limit=1, max_iter=5, keep_explored_points=False, 125 | shuffle_limit=True, 126 | kernel=np.array([1.0, 0.]), step=np.array([0.2, 0.]) 127 | ) 128 | assert (pan.kernel == np.array([1.0, 0.])).all() 129 | assert (pan.step == np.array([0.2, 0.])).all() 130 | assert explanation.get_changes()[1] == 0 131 | -------------------------------------------------------------------------------- /tests/explainability/test_anchors_extended.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from mercury.explainability.explainers import AnchorsWithImportanceExplainer 4 | import pytest 5 | import pickle 6 | import tensorflow as tf 7 | 8 | @pytest.fixture(scope="session") 9 | def model_and_data(): 10 | logRegModel = pickle.load(open('./tests/explainability/model_and_data/FICO_lr_model.pkl', 'rb')) 11 | fit_data = pd.read_csv('./tests/explainability/model_and_data/fit_data_red.csv', index_col=0) 12 | explain_data = pd.read_csv('./tests/explainability/model_and_data/explain_data.csv', index_col=0) 13 | return { 14 | 'logRegModel': logRegModel, 15 | 'fit_data': fit_data, 16 | 'explain_data': explain_data 17 | } 18 | 19 | 20 | pytestmark = pytest.mark.usefixtures("model_and_data") 21 | 22 | 23 | def test_anchors_build(model_and_data): 24 | """ 25 | Testing out that we can obtain explanations and they make sense. 26 | For that we have to fit the anchors model and provide a seed to 27 | avoid randomness. 
28 | """ 29 | logRegModel = model_and_data['logRegModel'] 30 | fit_data = model_and_data['fit_data'] 31 | explain_data = model_and_data['explain_data'] 32 | 33 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 34 | predict_fn=logRegModel.predict_proba, 35 | train_data=fit_data, 36 | disc_perc=[10,20,30,40,50,60,70,80,90] 37 | ) 38 | 39 | with pytest.raises(AttributeError) as excinfo: 40 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 41 | predict_fn=logRegModel.predict_proba, 42 | train_data='wrong data', 43 | disc_perc=[10,20,30,40,50,60,70,80,90] 44 | ) 45 | 46 | with pytest.raises(AttributeError) as excinfo: 47 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 48 | predict_fn=logRegModel.predict_proba, 49 | train_data=fit_data, 50 | categorical_names=[] 51 | ) 52 | 53 | def test_anchors_fit_and_explain_precision(model_and_data): 54 | """ 55 | Testing out that we can obtain explanations and they make sense. 56 | For that we have to fit the anchors model and provide a seed to 57 | avoid randomness. 58 | """ 59 | logRegModel = model_and_data['logRegModel'] 60 | fit_data = model_and_data['fit_data'] 61 | explain_data = model_and_data['explain_data'] 62 | 63 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 64 | predict_fn=logRegModel.predict_proba, 65 | train_data=fit_data, 66 | disc_perc=[10,20,30,40,50,60,70,80,90] 67 | ) 68 | 69 | np.random.seed(42) 70 | explanation = anchorsExtendedExplainer.explain( 71 | explain_data.head(1).values, threshold=0.95 72 | ) 73 | 74 | assert explanation.data['precision'] > 0.95 75 | 76 | def test_anchors_fit_and_explain_coverage(model_and_data): 77 | """ 78 | Testing out that we can obtain explanations and they make sense. 79 | For that we have to fit the anchors model and provide a seed to 80 | avoid randomness. 81 | """ 82 | logRegModel = model_and_data['logRegModel'] 83 | fit_data = model_and_data['fit_data'] 84 | explain_data = model_and_data['explain_data'] 85 | 86 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 87 | predict_fn=logRegModel.predict_proba, 88 | train_data=fit_data, 89 | disc_perc=[10,20,30,40,50,60,70,80,90] 90 | ) 91 | 92 | np.random.seed(42) 93 | explanation = anchorsExtendedExplainer.explain( 94 | explain_data.head(1).values, threshold=0.95 95 | ) 96 | assert explanation.data['precision'] > 0.95 97 | 98 | def test_anchors_feature_importance_obtention(model_and_data): 99 | """ 100 | Testing out that we can obtain explanations and they make sense. 101 | For that we have to fit the anchors model and provide a seed to 102 | avoid randomness. 
103 | """ 104 | logRegModel = model_and_data['logRegModel'] 105 | fit_data = model_and_data['fit_data'] 106 | explain_data = model_and_data['explain_data'] 107 | 108 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 109 | predict_fn=logRegModel.predict_proba, 110 | train_data=fit_data, 111 | disc_perc=[10,20,30,40,50,60,70,80,90] 112 | ) 113 | 114 | np.random.seed(42) 115 | anchorsExplanations = anchorsExtendedExplainer.get_feature_importance( 116 | explain_data.head(10), print_every=10, print_explanations=True 117 | ) 118 | 119 | anchorsInterpretation = anchorsExplanations.interpret_explanations(n_important_features=3) 120 | assert ( 121 | 'ExternalRiskEstimate' in anchorsInterpretation and 122 | 'NetFractionRevolvingBurden' in anchorsInterpretation and 123 | 'AverageMInFile' in anchorsInterpretation 124 | ) 125 | 126 | def test_anchors_feature_importance_obtention_top_5(model_and_data): 127 | """ 128 | Testing out that we can obtain explanations and they make sense. 129 | For that we have to fit the anchors model and provide a seed to 130 | avoid randomness. 131 | """ 132 | logRegModel = model_and_data['logRegModel'] 133 | fit_data = model_and_data['fit_data'] 134 | explain_data = model_and_data['explain_data'] 135 | 136 | anchorsExtendedExplainer = AnchorsWithImportanceExplainer( 137 | predict_fn=logRegModel.predict_proba, 138 | train_data=fit_data, 139 | disc_perc=[10,20,30,40,50,60,70,80,90] 140 | ) 141 | 142 | np.random.seed(42) 143 | anchorsExplanations = anchorsExtendedExplainer.get_feature_importance( 144 | explain_data.head(10) 145 | ) 146 | 147 | anchorsInterpretation = anchorsExplanations.interpret_explanations(n_important_features=5) 148 | 149 | # I don't include the last two features because they only have a frequency of 1. 
150 | 151 | assert ( 152 | 'ExternalRiskEstimate' in anchorsInterpretation and 153 | 'NetFractionRevolvingBurden' in anchorsInterpretation and 154 | 'AverageMInFile' in anchorsInterpretation 155 | ) 156 | -------------------------------------------------------------------------------- /tests/explainability/test_clustering_tree_explainer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn.preprocessing import StandardScaler 3 | from sklearn.cluster import KMeans as Sklearn_KMeans 4 | 5 | from mercury.explainability.explainers.clustering_tree_explainer import ClusteringTreeExplainer 6 | 7 | pytestmark = pytest.mark.usefixtures("model_and_data_cte") 8 | 9 | 10 | def _assert_all_clusters_have_leafs(explanation): 11 | assert """label="0""" in str(explanation) 12 | assert """label="1""" in str(explanation) 13 | assert """label="2""" in str(explanation) 14 | 15 | def test_clustering_tree_explainer_pandas(model_and_data_cte): 16 | 17 | sk_kmeans = model_and_data_cte['sk_kmeans'] 18 | K = model_and_data_cte['K'] 19 | pandas_df = model_and_data_cte['pandas_df'] 20 | 21 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K) 22 | explanation = clustering_tree_explainer.explain(pandas_df) 23 | plot_explanation = explanation.plot(filename="test1") 24 | 25 | _assert_all_clusters_have_leafs(str(plot_explanation)) 26 | assert clustering_tree_explainer.tree._size() >= K 27 | assert clustering_tree_explainer.tree._max_depth() > 1 28 | 29 | score = clustering_tree_explainer.score(pandas_df) 30 | surrogate_score = clustering_tree_explainer.surrogate_score(pandas_df) 31 | assert isinstance(score, float) and isinstance(surrogate_score, float) 32 | assert surrogate_score >= score 33 | 34 | def test_clustering_tree_explainer_pandas_sk_pipeline(model_and_data_cte): 35 | 36 | sk_pipeline = model_and_data_cte['sk_pipeline'] 37 | K = model_and_data_cte['K'] 38 | pandas_df = model_and_data_cte['pandas_df'] 39 | 40 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_pipeline, max_leaves=K) 41 | explanation = clustering_tree_explainer.explain(pandas_df) 42 | plot_explanation = explanation.plot(filename="test1") 43 | 44 | _assert_all_clusters_have_leafs(str(plot_explanation)) 45 | assert clustering_tree_explainer.tree._size() >= K 46 | assert clustering_tree_explainer.tree._max_depth() > 1 47 | 48 | 49 | def test_clustering_tree_explainer_scalers(model_and_data_cte): 50 | 51 | K = model_and_data_cte['K'] 52 | pandas_df = model_and_data_cte['pandas_df'].copy() 53 | 54 | scalers = {} 55 | for c in pandas_df.columns: 56 | scalers[c] = StandardScaler() 57 | pandas_df[c] = scalers[c].fit_transform(pandas_df[c].values.reshape(-1, 1)) 58 | 59 | sk_kmeans = Sklearn_KMeans(K, random_state=42) 60 | sk_kmeans.fit(pandas_df) 61 | 62 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K) 63 | explanation = clustering_tree_explainer.explain(pandas_df) 64 | plot_explanation = explanation.plot(filename="test1", scalers=scalers) 65 | _assert_all_clusters_have_leafs(str(plot_explanation)) 66 | 67 | 68 | def test_clustering_tree_explainer_pandas_more_leaves(model_and_data_cte): 69 | 70 | """Test case when tree explainer grows to more leaves than clusters (ExKMC method)""" 71 | sk_kmeans = model_and_data_cte['sk_kmeans'] 72 | K = model_and_data_cte['K'] 73 | pandas_df = model_and_data_cte['pandas_df'] 74 | 75 | clustering_tree_explainer = 
ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K+3) 76 | explanation = clustering_tree_explainer.explain(pandas_df) 77 | plot_explanation = explanation.plot(filename="test1") 78 | 79 | _assert_all_clusters_have_leafs(str(plot_explanation)) 80 | # check that explanation has more nodes 81 | nodes = ["n_" + str(i) for i in range(7)] 82 | for n in nodes: 83 | assert n in str(plot_explanation) 84 | assert clustering_tree_explainer.tree._size() >= K 85 | 86 | def test_clustering_tree_explainer_spark_pipeline(model_and_data_cte): 87 | 88 | spark_pipeline_model = model_and_data_cte['spark_pipeline_model'] 89 | K = model_and_data_cte['K'] 90 | spark_df = model_and_data_cte['spark_df'] 91 | 92 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=spark_pipeline_model, max_leaves=K) 93 | explanation = clustering_tree_explainer.explain(spark_df) 94 | plot_explanation = explanation.plot(filename="test1") 95 | _assert_all_clusters_have_leafs(str(plot_explanation)) 96 | assert clustering_tree_explainer.tree._size() >= K 97 | 98 | 99 | def test_clustering_tree_explainer_spark_kmeans(model_and_data_cte): 100 | 101 | spark_kmeans_model = model_and_data_cte['spark_kmeans_model'] 102 | K = model_and_data_cte['K'] 103 | spark_df = model_and_data_cte['spark_df_2'] 104 | 105 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=spark_kmeans_model, max_leaves=K) 106 | explanation = clustering_tree_explainer.explain(spark_df) 107 | plot_explanation = explanation.plot(filename="test1", feature_names=["feature_1", "feature_2"]) 108 | 109 | _assert_all_clusters_have_leafs(str(plot_explanation)) 110 | assert clustering_tree_explainer.tree._size() >= K 111 | 112 | 113 | def test_clustering_tree_explainer_spark_subsampling(model_and_data_cte): 114 | 115 | spark_pipeline_model = model_and_data_cte['spark_pipeline_model'] 116 | K = model_and_data_cte['K'] 117 | spark_df = model_and_data_cte['spark_df'] 118 | 119 | clustering_tree_explainer = ClusteringTreeExplainer( 120 | clustering_model=spark_pipeline_model, max_leaves=K, verbose=True 121 | ) 122 | explanation = clustering_tree_explainer.explain(spark_df, subsample=0.9) 123 | plot_explanation = explanation.plot(filename="test1") 124 | 125 | # check that no node has all the 1000 samples 126 | assert "samples=1000" not in str(plot_explanation) 127 | 128 | 129 | def test_clustering_tree_explainer_no_ibb(model_and_data_cte): 130 | 131 | sk_kmeans = model_and_data_cte['sk_kmeans'] 132 | K = model_and_data_cte['K'] 133 | pandas_df = model_and_data_cte['pandas_df'] 134 | 135 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K, base_tree='NONE') 136 | explanation = clustering_tree_explainer.explain(pandas_df) 137 | plot_explanation = explanation.plot(filename="test1") 138 | 139 | _assert_all_clusters_have_leafs(str(plot_explanation)) 140 | 141 | 142 | def test_clustering_tree_explainer_errors(model_and_data_cte): 143 | 144 | sk_kmeans = model_and_data_cte['sk_kmeans'] 145 | K = model_and_data_cte['K'] 146 | pandas_df = model_and_data_cte['pandas_df'] 147 | 148 | with pytest.raises(Exception): 149 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K-2) 150 | explanation = clustering_tree_explainer.explain(pandas_df) 151 | 152 | with pytest.raises(Exception): 153 | clustering_tree_explainer = ClusteringTreeExplainer(clustering_model=sk_kmeans, max_leaves=K, base_tree="new_tree") 154 | explanation = clustering_tree_explainer.explain(pandas_df) 
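# --- Illustrative sketch (not part of the original test suite) ---------------
# ClusteringTreeExplanation.plot() (see clustering_tree_explanation.py above)
# returns a graphviz.Source; the tree is only written to disk when render() is
# called on that object. A minimal end-to-end sketch, assuming scikit-learn,
# mercury and the Graphviz system binaries are installed:
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from mercury.explainability.explainers.clustering_tree_explainer import ClusteringTreeExplainer

X, _ = make_blobs(n_samples=500, n_features=2, centers=3, random_state=0)
df = pd.DataFrame(X, columns=["feature_1", "feature_2"])
kmeans = KMeans(n_clusters=3, random_state=0).fit(df)

explainer = ClusteringTreeExplainer(clustering_model=kmeans, max_leaves=3)
explanation = explainer.explain(df)
source = explanation.plot(filename="clusters_tree")  # graphviz.Source object
source.render()  # writes clusters_tree.gv plus the rendered PNG next to it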
-------------------------------------------------------------------------------- /mercury/explainability/explainers/shuffle_importance.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Callable 2 | from .explainer import MercuryExplainer 3 | from ..explanations.shuffle_importance import FeatureImportanceExplanation 4 | 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.preprocessing import MinMaxScaler 9 | 10 | 11 | class ShuffleImportanceExplainer(MercuryExplainer): 12 | """ 13 | This explainer estimates the feature importance of each predictor for a 14 | given black-box model. The strategy consists of randomly shuffling one 15 | variable at a time and, on each step, checking how much a particular 16 | metric worsens. The features that make the model perform worst are 17 | the most important ones. 18 | 19 | Args: 20 | eval_fn (Callable): 21 | Custom evaluation function. It will receive a DataFrame with features and 22 | a Numpy array with the target outputs for each instance. It must 23 | implement an inference process and return a metric to score the model 24 | performance on the given data. This metric must be a real number. 25 | If we use a metric for which higher values mean a better model (like accuracy) 26 | and we use the parameter `normalize=True` (the default option), then it is 27 | recommended to return the negative of that metric in `eval_fn` to make 28 | the results more intuitive. 29 | In the case of PySpark explanations, the function will only receive 30 | a PySpark DataFrame as the first argument, already containing the target column, 31 | whereas the second argument will be None. 32 | normalize (bool): 33 | Whether to scale the feature importances between 0 and 1. If True, then 34 | it shows the relative importance of the features. 35 | If False, then the feature importances will be the value of the metric 36 | returned in `eval_fn` when shuffling the features. 37 | Default value is True 38 | 39 | Example: 40 | ```python 41 | # "Plain python" example 42 | >>> features = pd.read_csv(PATH_DATA) 43 | >>> targets = features['target'] # Targets 44 | >>> features = features.loc[:, FEATURE_NAMES] # DataFrame with only features 45 | >>> def my_inference_function(features, targets): 46 | ... predictions = model.predict(features) 47 | ... return mean_squared_error(targets, predictions) 48 | >>> explainer = ShuffleImportanceExplainer(my_inference_function) 49 | >>> explanation = explainer.explain(features, targets) 50 | >>> explanation.plot() 51 | 52 | # Explain a pyspark model (or pipeline) 53 | >>> features = sess.createDataFrame(pandas_dataframe) 54 | >>> target_colname = "target" # Column name with the ground truth labels 55 | >>> def my_inference_function(features, targets): 56 | ... model_inp = vectorAssembler.transform(features) 57 | ... model_out = my_pyspark_transformer.transform(model_inp) 58 | ... 
return my_evaluator.evaluate(model_out) 59 | >>> explainer = ShuffleImportanceExplainer(my_inference_function) 60 | >>> explanation = explainer.explain(features, target_colname) 61 | >>> explanation.plot() 62 | ``` 63 | """ 64 | def __init__(self, 65 | eval_fn: Callable[[Union["pd.DataFrame", "pyspark.sql.DataFrame"], Union["np.ndarray", str]], float], # noqa: F821 66 | normalize: bool = True 67 | ): 68 | self.eval_fn = eval_fn 69 | self.normalize = normalize 70 | 71 | def explain(self, 72 | predictors: Union["pd.DataFrame", "pyspark.sql.DataFrame"], # noqa: F821 73 | target: Union["np.ndarray", str] 74 | ) -> FeatureImportanceExplanation: 75 | """ 76 | Explains the model given the data. 77 | 78 | Args: 79 | predictors (Union[pandas.DataFrame, pyspark.sql.DataFrame]): 80 | DataFrame with the features the model needs and that will be explained. 81 | In the case of PySpark, this dataframe must also contain a column 82 | with the target. 83 | target (Union[numpy.ndarray, str]): 84 | The ground-truth target for each one of the instances. In the case of 85 | Pyspark, this should be the name of the column in the DataFrame which 86 | holds the target. 87 | 88 | Raises: 89 | ValueError: if type(predictors) == pyspark.sql.DataFrame && type(target) != str 90 | ValueError: if type(predictors) == pyspark.sql.DataFrame && target not in predictors.columns 91 | 92 | Returns: 93 | FeatureImportanceExplanation with the performances of the model 94 | """ 95 | 96 | implementation = self.__impl_base 97 | feature_names = [] 98 | # Cheap way of check if type(predictors) == pyspark.sql.DataFrame (without importing pyspark). 99 | if hasattr(type(predictors), 'toPandas'): 100 | if type(target) != str: 101 | raise ValueError("""If predictors is a Spark DataFrame, target should be the name \ 102 | of the tareget column (a str)""") 103 | implementation = self.__impl_pyspark 104 | feature_names = list(filter(lambda x: x!=target, predictors.columns)) 105 | if len(feature_names) == len(predictors.columns): 106 | raise ValueError(f"""`target` must be the name of the target column in the DataFrame. \ 107 | Value passed: {target}""") 108 | else: 109 | feature_names = list(predictors.columns) 110 | if type(target) == str: 111 | feature_names = list(filter(lambda x: x!=target, feature_names)) 112 | 113 | metrics = {} 114 | for col in feature_names: 115 | metrics[col] = implementation(predictors, target, col) 116 | if self.normalize: 117 | metrics = self._normalize_importances(metrics) 118 | return FeatureImportanceExplanation(metrics) 119 | 120 | def __impl_base(self, predictors, target, column): 121 | temp = predictors.copy() 122 | temp[column] = np.array(temp[column].sample(frac=1)) 123 | return self.eval_fn(temp, target) 124 | 125 | def __impl_pyspark(self, predictors, target, column): 126 | from pyspark.sql.functions import rand, row_number, monotonically_increasing_id 127 | from pyspark.sql.window import Window 128 | 129 | # Shuffle column values creating a temporal rand column and ordering by it. 130 | # Then, we merge the ordered DF with the old one removing the original and 131 | # temporal columns 132 | shuffled = predictors.select(column).withColumn('rand', rand()).orderBy('rand') 133 | # In order to merge back we also need to include a row number. 
134 | shuffled=shuffled.withColumn('shuff_id_idx', row_number().over(Window.orderBy(monotonically_increasing_id()))) 135 | temp = predictors.withColumn('shuff_id_idx', row_number().over(Window.orderBy(monotonically_increasing_id()))) 136 | 137 | temp = temp.drop(column).join(shuffled, on='shuff_id_idx').drop('shuff_id_idx').drop('rand') 138 | 139 | return self.eval_fn(temp, None) 140 | 141 | def _normalize_importances(self, metrics): 142 | df_metrics = pd.DataFrame.from_dict(metrics, orient="index", columns=["importance"]) 143 | df_metrics["importance"] = MinMaxScaler().fit_transform(df_metrics["importance"].values.reshape(-1,1)) 144 | return df_metrics["importance"].to_dict() 145 | -------------------------------------------------------------------------------- /tests/explainability/test_partial_dependence.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | from mercury.explainability import PartialDependenceExplainer 6 | 7 | from pyspark.ml import Pipeline 8 | from pyspark.ml.feature import StringIndexer, VectorAssembler 9 | from pyspark.sql.functions import when, col 10 | 11 | import pytest 12 | 13 | pytestmark = pytest.mark.usefixtures("model_and_data_pdp") 14 | 15 | 16 | def test_pandas_classification(model_and_data_pdp): 17 | """ Tests the explainer for multinomial classification task using "plain Python" 18 | models 19 | """ 20 | rf = model_and_data_pdp['rf_iris_sk'] 21 | features = model_and_data_pdp['iris_pd_df'] 22 | 23 | explainer = PartialDependenceExplainer(rf.predict_proba) 24 | explanation = explainer.explain(features) 25 | return_dict = explanation.data 26 | 27 | # Explanation should contain all the features 28 | assert list(return_dict.keys()) == list(features.columns) 29 | # Predictions should contain R^3 vectors 30 | assert return_dict[list(return_dict.keys())[0]]['preds'].shape[1] == 3 31 | # Quantiles should be R^3 32 | assert return_dict[list(return_dict.keys())[0]]['lower_quantile'].shape[1] == 3 33 | assert return_dict[list(return_dict.keys())[0]]['upper_quantile'].shape[1] == 3 34 | 35 | # # Check plotting doesnt crash 36 | # explanation.plot() 37 | 38 | 39 | def test_pandas_regression(model_and_data_pdp): 40 | rf = model_and_data_pdp['rf_houses_sk'] 41 | features = model_and_data_pdp['houses_pd_df'] 42 | explainer = PartialDependenceExplainer(rf.predict, verbose=True) 43 | 44 | features_to_ignore = ['HouseAge', 'AveBedrms', 'Population'] 45 | features_to_use = [f for f in list(features.columns) if f not in 46 | features_to_ignore] 47 | 48 | explanation = explainer.explain(features, ignore_feats=features_to_ignore) 49 | 50 | return_dict = explanation.data 51 | # Explanation should contain all the non-ignored features 52 | assert list(return_dict.keys()) == features_to_use 53 | # Predictions should contain real numbers 54 | assert len(return_dict[list(return_dict.keys())[0]]['preds'].shape) == 1 55 | # Quantiles should contain real numbers 56 | assert len(return_dict[list(return_dict.keys())[0]]['lower_quantile'].shape) == 1 57 | assert len(return_dict[list(return_dict.keys())[0]]['upper_quantile'].shape) == 1 58 | 59 | # # Check plotting doesnt crash 60 | # explanation.plot() 61 | 62 | 63 | def test_pandas_regression_with_categoricals(model_and_data_pdp): 64 | rf = model_and_data_pdp['rf_boston_sk'] 65 | features = model_and_data_pdp['boston_pd_df'] 66 | explainer = PartialDependenceExplainer(rf.predict, quantiles=False) 67 | explanation = 
explainer.explain(features, ignore_feats=['PTRATIO', 'B', 'LSTAT', 'AGE']) 68 | return_dict = explanation.data 69 | # Predictions should contain real numbers 70 | assert len(return_dict[list(return_dict.keys())[0]]['preds'].shape) == 1 71 | 72 | assert len(return_dict['NOX']['values']) == 50 73 | 74 | # # Check plotting doesnt crash 75 | # explanation.plot() 76 | 77 | 78 | def test_spark_classification(model_and_data_pdp): 79 | rf = model_and_data_pdp['rf_iris_sp'] 80 | assembler = model_and_data_pdp['assembler_iris'] 81 | features = model_and_data_pdp['iris_sp_df'] 82 | 83 | def my_pred_fn(data): 84 | temp_df = assembler.transform(data) 85 | return rf.transform(temp_df) 86 | 87 | features_to_ignore = ['petal_length','petal_width'] 88 | features_to_use = [f for f in list(features.columns) if f not in features_to_ignore] 89 | 90 | explainer = PartialDependenceExplainer(my_pred_fn, output_col='probability') 91 | explanation = explainer.explain(features, ignore_feats=features_to_ignore) 92 | return_dict = explanation.data 93 | 94 | # Explanation should contain all the non-ignored features 95 | assert list(return_dict.keys()) == features_to_use 96 | # Predictions should contain R^3 vectors 97 | assert return_dict[list(return_dict.keys())[0]]['preds'].shape[1] == 3 98 | # Quantiles should be R^3 99 | assert return_dict[list(return_dict.keys())[0]]['lower_quantile'].shape[1] == 3 100 | assert return_dict[list(return_dict.keys())[0]]['upper_quantile'].shape[1] == 3 101 | 102 | # # Check plotting doesnt crash 103 | # explanation.plot(filter_classes=[True, False, True], quantiles=[True, False, True]) 104 | 105 | 106 | def test_spark_regression(model_and_data_pdp): 107 | rf = model_and_data_pdp['rf_houses_sp'] 108 | assembler = model_and_data_pdp['assembler_houses'] 109 | features = model_and_data_pdp['houses_sp_df'] 110 | 111 | def my_pred_fn(data): 112 | temp_df = assembler.transform(data) 113 | return rf.transform(temp_df) 114 | 115 | features_to_ignore = ['AveBedrms', 'Population'] 116 | features_to_use = [f for f in list(features.columns) if f not in 117 | features_to_ignore] 118 | 119 | explainer = PartialDependenceExplainer(my_pred_fn, output_col='prediction', verbose=True, quantiles=False) 120 | explanation = explainer.explain(features, ignore_feats=features_to_ignore) 121 | return_dict = explanation.data 122 | 123 | # Explanation should contain all the non-ignored features 124 | assert list(return_dict.keys()) == features_to_use 125 | # Predictions should contain real numbers 126 | assert len(return_dict[list(return_dict.keys())[0]]['preds'].shape) == 1 127 | 128 | 129 | def test_spark_regression_with_categorical(model_and_data_pdp): 130 | rf = model_and_data_pdp['rf_boston_sp'] 131 | assembler = model_and_data_pdp['assembler_boston'] 132 | features = model_and_data_pdp['boston_sp_df'] 133 | 134 | # Emulate a categorical variable with strings 135 | features = features.withColumn('AGESTR', 136 | when(col('AGE') < 20, "YOUNG") 137 | .when(col('AGE') < 60, "MID") 138 | .when(col('AGE') < 100, "OLD") 139 | .otherwise("ELDER") 140 | ).drop('AGE')\ 141 | .withColumnRenamed('AGESTR', 'AGE') 142 | 143 | # Make a pipeline instead of a model 144 | indexer = StringIndexer(inputCol="AGE", outputCol="AGEInt") 145 | assembler = VectorAssembler( 146 | inputCols=['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGEInt', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'], 147 | outputCol='features' 148 | ) 149 | pipe = Pipeline(stages=[indexer, assembler, rf]) 150 | m = pipe.fit(features) 151 | 152 | def 
my_pred_fn(data): 153 | return m.transform(data) 154 | 155 | features_to_ignore = ['PTRATIO', 'B', 'LSTAT', 'CHAS', 'CRIM', 'TAX'] 156 | features_to_use = [f for f in list(features.columns) if f not in 157 | features_to_ignore] 158 | 159 | explainer = PartialDependenceExplainer(my_pred_fn, output_col='prediction', resolution=10) 160 | explanation = explainer.explain(features, ignore_feats=features_to_ignore) 161 | return_dict = explanation.data 162 | 163 | # Explanation should contain all the non-ignored features 164 | assert list(return_dict.keys()) == features_to_use 165 | # Predictions should contain real numbers 166 | assert len(return_dict[list(return_dict.keys())[0]]['preds'].shape) == 1 167 | # Assert integrity of categorical string variables 168 | assert type(explanation.data['AGE']['values'][0]) == str 169 | 170 | # # Check plotting doesnt crash 171 | # explanation.plot(quantiles=False) 172 | 173 | 174 | def test_explanation_plot(model_and_data_pdp): 175 | rf = model_and_data_pdp['rf_boston_sk'] 176 | features = model_and_data_pdp['boston_pd_df'] 177 | explainer = PartialDependenceExplainer(rf.predict) 178 | explanation = explainer.explain(features, ignore_feats=['PTRATIO', 'B', 'LSTAT', 'AGE']) 179 | 180 | assert len(explanation['CHAS'][0]) == 2 and len(explanation['CHAS'][1]) == 2 181 | 182 | _, ax = plt.subplots() 183 | explanation.plot_single('CRIM', ax=ax) 184 | assert ax.get_title() == 'CRIM' 185 | 186 | # # Check that plotting doesnt crash 187 | # explanation.plot(quantiles=True) 188 | -------------------------------------------------------------------------------- /tests/explainability/test_shuffle_importance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | import pyspark 5 | import pyspark.ml.regression as pysparkreg 6 | import pyspark.ml.classification as pysparkclas 7 | 8 | from pyspark.sql import SparkSession 9 | from pyspark.ml.evaluation import RegressionEvaluator 10 | from pyspark.ml.feature import VectorAssembler 11 | from pyspark.sql.functions import when, col 12 | from sklearn.datasets import fetch_california_housing, load_iris 13 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier 14 | from sklearn.metrics import mean_squared_error 15 | from mercury.explainability.explainers.shuffle_importance import ShuffleImportanceExplainer 16 | 17 | 18 | @pytest.fixture(scope="session") 19 | def model_and_data(): 20 | spark_sess = SparkSession.builder.config("k1", "v1").getOrCreate() 21 | 22 | houses = fetch_california_housing() 23 | houses_pd_df = pd.DataFrame(houses['data'], columns=houses['feature_names']) 24 | houses_pd_df['target'] = houses['target'] 25 | houses_sp_df = spark_sess.createDataFrame(houses_pd_df) 26 | 27 | # Fit sklearn RFs 28 | rf_houses_sk = RandomForestRegressor().fit(houses_pd_df[['MedInc', 'HouseAge', 29 | 'AveRooms', 'AveBedrms', 'Population', 'AveOccup','Latitude', 'Longitude']], 30 | houses_pd_df['target']) 31 | 32 | # Fit Spark RFs 33 | assembler_houses = VectorAssembler( 34 | inputCols=['MedInc','HouseAge','AveRooms','AveBedrms','Population','AveOccup','Latitude','Longitude'], 35 | outputCol='features') 36 | houses_sp_df_temp = assembler_houses.transform(houses_sp_df) 37 | rf_houses_sp = pysparkreg.RandomForestRegressor( 38 | featuresCol="features", 39 | labelCol="target" 40 | ).fit(houses_sp_df_temp) 41 | 42 | return { 43 | 'spark_sess': spark_sess, 44 | 'houses_pd_df': houses_pd_df, 45 | 'houses_sp_df': 
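# --- Illustrative sketch (not part of the original test suite) ---------------
# A standalone reminder of what a 1-D partial dependence curve is (this is the
# textbook definition, not PartialDependenceExplainer's implementation): for
# each candidate value of one feature, the whole column is clamped to that
# value and the model predictions are averaged over the dataset.
import numpy as np

def partial_dependence_1d(predict_fn, X, feature_idx, grid):
    """Average prediction over X (2-D array) for every value in `grid`."""
    averaged = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature_idx] = value       # clamp the feature to `value`
        averaged.append(np.mean(predict_fn(X_mod), axis=0))
    return np.asarray(averaged)

# Toy example: partial dependence of feature 0 for f(x) = 2*x0 + x1.
X = np.random.RandomState(0).uniform(size=(100, 2))
grid = np.linspace(0.0, 1.0, 5)
curve = partial_dependence_1d(lambda x: 2 * x[:, 0] + x[:, 1], X, 0, grid)
# `curve` grows linearly with the grid values, as expected for a linear model.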
houses_sp_df.drop("features"), 46 | 'rf_houses_sk': rf_houses_sk, 47 | 'rf_houses_sp': rf_houses_sp, 48 | 'assembler_houses': assembler_houses 49 | } 50 | 51 | 52 | pytestmark = pytest.mark.usefixtures("model_and_data") 53 | 54 | 55 | def test_importance_pyspark_target_doesnt_exist(model_and_data): 56 | rf = model_and_data['rf_houses_sp'] 57 | assembler = model_and_data['assembler_houses'] 58 | features = model_and_data['houses_sp_df'] 59 | evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='target') 60 | feat_names = [f for f in features.columns if f != 'target'] 61 | 62 | def eval_fn(features, target): 63 | t = assembler.transform(features) 64 | t = rf.transform(t) 65 | return evaluator.evaluate(t) 66 | 67 | expl = ShuffleImportanceExplainer(eval_fn) 68 | 69 | with pytest.raises(ValueError): 70 | # Target does not exist 71 | explanation = expl.explain(features, 'target_not_existing') 72 | 73 | 74 | def test_importance_pyspark_target_exists(model_and_data): 75 | rf = model_and_data['rf_houses_sp'] 76 | assembler = model_and_data['assembler_houses'] 77 | features = model_and_data['houses_sp_df'] 78 | evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='target') 79 | feat_names = [f for f in features.columns if f != 'target'] 80 | 81 | def eval_fn(features, target): 82 | t = assembler.transform(features) 83 | t = rf.transform(t) 84 | return evaluator.evaluate(t) 85 | 86 | expl = ShuffleImportanceExplainer(eval_fn) 87 | explanation = expl.explain(features, 'target') 88 | 89 | assert set(explanation.data.keys()) == set(feat_names) 90 | assert explanation.get_importances()[0][0] == 'MedInc' 91 | 92 | # Check explanation doesnt crash 93 | explanation['MedInc'] 94 | # Check yaxis plot contains correct label names 95 | ax = explanation.plot() 96 | assert len(ax.get_yaxis().get_ticklabels()) == len(feat_names) 97 | 98 | def test_importance_pyspark_with_categoricals(model_and_data): 99 | assembler = model_and_data['assembler_houses'] 100 | features = model_and_data['houses_sp_df'] 101 | evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='target') 102 | # Emulate a categorical variable 103 | feats = features.withColumn('AveRoomsCat', 104 | when(col('AveRooms') < 2, 0) 105 | .when(col('AveRooms') < 5, 1) 106 | .when(col('AveRooms') < 20, 3) 107 | .otherwise(10) 108 | ).drop('AveRooms')\ 109 | .withColumnRenamed('AveRoomsCat', 'AveRooms') 110 | 111 | houses_sp_df_temp = assembler.transform(feats) 112 | 113 | rf= pysparkreg.RandomForestRegressor( 114 | featuresCol="features", 115 | labelCol="target" 116 | ).fit(houses_sp_df_temp) 117 | 118 | feat_names = [f for f in feats.columns if f != 'target'] 119 | 120 | def eval_fn(features, target): 121 | t = assembler.transform(features) 122 | t = rf.transform(t) 123 | return evaluator.evaluate(t) 124 | 125 | expl = ShuffleImportanceExplainer(eval_fn) 126 | explanation = expl.explain(feats, 'target') 127 | 128 | assert set(explanation.data.keys()) == set(feat_names) 129 | assert explanation.get_importances()[0][0] == 'MedInc' 130 | 131 | 132 | def test_importance_pyspark_bad_input(model_and_data): 133 | rf = model_and_data['rf_houses_sp'] 134 | assembler = model_and_data['assembler_houses'] 135 | features = model_and_data['houses_sp_df'] 136 | evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='target') 137 | feat_names = [f for f in features.columns if f != 'target'] 138 | 139 | def eval_fn(features, target): 140 | t = assembler.transform(features) 141 | t = rf.transform(t) 142 | return 
evaluator.evaluate(t) 143 | 144 | expl = ShuffleImportanceExplainer(eval_fn) 145 | 146 | with pytest.raises(ValueError): 147 | # Explain target parameter for pysparkshould be a string 148 | explanation = expl.explain(features, [3,4,5,6,3]) 149 | 150 | def test_importance_pandas(model_and_data): 151 | rf = model_and_data['rf_houses_sk'] 152 | houses_df = model_and_data['houses_pd_df'] 153 | 154 | feature_names = ['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup','Latitude', 'Longitude'] 155 | tgt = 'target' 156 | 157 | feats = houses_df.loc[:, feature_names] 158 | targ = houses_df.loc[:, tgt] 159 | 160 | def eval_fn(features, target): 161 | pred = rf.predict(features) 162 | return mean_squared_error(target, pred) 163 | 164 | expl = ShuffleImportanceExplainer(eval_fn) 165 | explanation = expl.explain(feats, targ) 166 | assert explanation.get_importances()[0][0] == 'MedInc' 167 | 168 | # Same test without normalizatoin 169 | expl = ShuffleImportanceExplainer(eval_fn, normalize=False) 170 | explanation = expl.explain(feats, targ) 171 | assert explanation.get_importances()[0][0] == 'MedInc' 172 | 173 | 174 | def test_importance_bad_input_pandas(model_and_data): 175 | rf = model_and_data['rf_houses_sk'] 176 | houses_df = model_and_data['houses_pd_df'] 177 | 178 | feature_names = ['MedInc', 'HouseAge', 'AveRooms', 'AveBedrms', 'Population', 'AveOccup','Latitude', 'Longitude'] 179 | tgt = 'target' 180 | 181 | feats = houses_df.loc[:, feature_names] 182 | targ = houses_df.loc[:, tgt] 183 | 184 | def eval_fn(features, target): 185 | pred = rf.predict(features) 186 | return mean_squared_error(target, pred) 187 | 188 | expl = ShuffleImportanceExplainer(eval_fn) 189 | 190 | with pytest.raises(ValueError): 191 | # Explain target parameter for plain python should not be a string 192 | explanation = expl.explain(feats, "target") 193 | 194 | def test_importance_pandas_bug_target_col_not_filtered(model_and_data): 195 | rf = model_and_data['rf_houses_sk'] 196 | houses_df = model_and_data['houses_pd_df'] 197 | 198 | def eval_fn(data, target_col): 199 | X = data.loc[:, data.columns != target_col] 200 | Y = data.loc[:, target_col] 201 | Yh = rf.predict(X) 202 | return mean_squared_error(Y, Yh) 203 | 204 | expl = ShuffleImportanceExplainer(eval_fn) 205 | explanation = expl.explain(houses_df, 'target') 206 | features, importances = zip(*explanation.get_importances()) 207 | 208 | assert 'target' not in features 209 | 210 | -------------------------------------------------------------------------------- /tests/explainability/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pyspark import SparkConf, SparkContext 4 | from pyspark.sql import SparkSession 5 | 6 | from sklearn.datasets import fetch_california_housing, load_iris 7 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier 8 | from sklearn.datasets import make_blobs 9 | from sklearn.cluster import KMeans as Sklearn_KMeans 10 | from sklearn.preprocessing import StandardScaler 11 | from pyspark.ml.classification import GBTClassificationModel 12 | from sklearn.pipeline import Pipeline as SklearnPipeline 13 | import pyspark.ml.regression as pysparkreg 14 | import pyspark.ml.classification as pysparkclas 15 | from pyspark.ml.feature import VectorAssembler 16 | from pyspark.ml.clustering import KMeans as SparkKMeans 17 | from pyspark.ml import Pipeline 18 | 19 | 20 | import os 21 | import pandas as pd 22 | import pytest 23 | 24 | 
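# --- Illustrative sketch (not part of the original test suite) ---------------
# The ShuffleImportanceExplainer docstring recommends returning the *negative*
# of "higher is better" metrics (such as accuracy) when normalize=True. The
# tests above only cover MSE, so this is a small classification sketch of that
# pattern on the iris data (assumes scikit-learn and mercury are installed):
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from mercury.explainability.explainers.shuffle_importance import ShuffleImportanceExplainer

iris = load_iris()
features = pd.DataFrame(iris['data'], columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'])
targets = iris['target']
clf = RandomForestClassifier(random_state=0).fit(features, targets)

def neg_accuracy_eval(feats, targs):
    # Negate accuracy so that a larger value means a more degraded model
    return -accuracy_score(targs, clf.predict(feats))

explainer = ShuffleImportanceExplainer(neg_accuracy_eval)
explanation = explainer.explain(features, targets)
print(explanation.get_importances())  # (feature, importance) pairs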
@pytest.fixture(scope='package') 25 | def spark_context(): 26 | conf = (SparkConf().setMaster('local[2]').setAppName('Alibi Tests')) 27 | sc = SparkContext(conf=conf) 28 | 29 | logger = logging.getLogger('py4j') 30 | logger.setLevel(logging.WARN) 31 | 32 | return sc 33 | 34 | @pytest.fixture(scope='package') 35 | def spark_session(spark_context): 36 | return SparkSession.builder.appName('Alibi Tests').getOrCreate() 37 | 38 | @pytest.fixture(scope="module") 39 | def model_and_data(spark_session): 40 | gbtModel = GBTClassificationModel.load('tests/explainability/model_and_data_pyspark/gbtModelPySpark') 41 | assembler = VectorAssembler.load('tests/explainability/model_and_data_pyspark/assemblerScaled') 42 | data_pd = pd.read_csv('tests/explainability/model_and_data_pyspark/data_pandas_red.csv', index_col=0) 43 | data_pyspark = spark_session.createDataFrame(data_pd) 44 | data_preproc = assembler.transform(data_pyspark) 45 | return { 46 | 'gbtModel': gbtModel, 47 | 'assembler': assembler, 48 | 'data_pd': data_pd, 49 | 'data_pyspark': data_pyspark, 50 | 'data_preproc': data_preproc 51 | } 52 | 53 | @pytest.fixture(scope='module') 54 | def model_and_data_ale(spark_session): 55 | gbtModel = GBTClassificationModel.load('tests/explainability/model_and_data_pyspark/gbtModelPySpark') 56 | assembler = VectorAssembler.load('tests/explainability/model_and_data_pyspark/assemblerScaled') 57 | data_pd = pd.read_csv('tests/explainability/model_and_data_pyspark/data_ale.csv', index_col=0) 58 | data_pyspark = spark_session.createDataFrame(data_pd) 59 | data_preproc = assembler.transform(data_pyspark) 60 | return { 61 | 'gbtModel': gbtModel, 62 | 'assembler': assembler, 63 | 'data_pd': data_pd, 64 | 'data_pyspark': data_pyspark, 65 | 'data_preproc': data_preproc 66 | } 67 | 68 | @pytest.fixture(scope="session", autouse=True) 69 | def env_var_patching(): 70 | os.environ["MOMA_ENV"] = "test" 71 | os.environ["MERCURY_LOGGING_DISABLE"] = "1" 72 | 73 | 74 | @pytest.fixture(scope="module") 75 | def model_and_data_pdp(spark_session): 76 | spark_sess = spark_session 77 | 78 | iris = load_iris() 79 | houses = fetch_california_housing() 80 | 81 | houses_pd_df = pd.DataFrame(houses['data'], columns=houses['feature_names']) 82 | houses_pd_df['target'] = houses['target'] 83 | iris_pd_df = pd.DataFrame(iris['data'], columns=['sepal_length', 'sepal_width', 'petal_length', 'petal_width']) 84 | iris_pd_df['target'] = iris['target'] 85 | boston_pd_df = pd.read_csv("tests/explainability/model_and_data/boston.csv") 86 | 87 | houses_sp_df = spark_sess.createDataFrame(houses_pd_df) 88 | iris_sp_df = spark_sess.createDataFrame(iris_pd_df) 89 | boston_sp_df = spark_sess.createDataFrame(boston_pd_df) 90 | 91 | # Fit sklearn RFs 92 | rf_iris_sk = RandomForestClassifier().fit(iris_pd_df[['sepal_length', 'sepal_width', 93 | 'petal_length', 'petal_width']], iris_pd_df['target']) 94 | rf_houses_sk = RandomForestRegressor().fit(houses_pd_df[['MedInc', 'HouseAge', 95 | 'AveRooms', 'AveBedrms', 'Population', 'AveOccup','Latitude', 'Longitude']], 96 | houses_pd_df['target']) 97 | rf_boston_sk = RandomForestRegressor().fit( 98 | boston_pd_df[['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 99 | 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']], 100 | boston_pd_df['target'] 101 | ) 102 | 103 | # Fit Spark RFs 104 | assembler_iris = VectorAssembler(inputCols=['sepal_length','sepal_width','petal_length','petal_width'], 105 | outputCol='features') 106 | iris_sp_df_temp = assembler_iris.transform(iris_sp_df) 107 | rf_iris_sp = 
pysparkclas.RandomForestClassifier(featuresCol="features", 108 | labelCol="target").fit(iris_sp_df_temp) 109 | 110 | assembler_houses = VectorAssembler( 111 | inputCols=['MedInc','HouseAge','AveRooms','AveBedrms','Population','AveOccup','Latitude','Longitude'], 112 | outputCol='features') 113 | houses_sp_df_temp = assembler_houses.transform(houses_sp_df) 114 | rf_houses_sp = pysparkreg.RandomForestRegressor( 115 | featuresCol="features", 116 | labelCol="target" 117 | ).fit(houses_sp_df_temp) 118 | 119 | assembler_boston = VectorAssembler( 120 | inputCols=['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'], 121 | outputCol='features' 122 | ) 123 | boston_sp_df_temp = assembler_boston.transform(boston_sp_df) 124 | rf_boston_sp = pysparkreg.RandomForestRegressor( 125 | featuresCol='features', 126 | labelCol='target' 127 | ).fit(boston_sp_df_temp) 128 | 129 | return { 130 | 'spark_sess': spark_sess, 131 | 'iris_pd_df': iris_pd_df.loc[:, ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']], 132 | 'houses_pd_df': houses_pd_df.loc[:, ['MedInc','HouseAge','AveRooms','AveBedrms','Population','AveOccup','Latitude', 'Longitude']], 133 | 'boston_pd_df': boston_pd_df.loc[:, ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B','LSTAT']], 134 | 'iris_sp_df': iris_sp_df.drop("features").drop('target'), 135 | 'houses_sp_df': houses_sp_df.drop("features").drop('target'), 136 | 'boston_sp_df': boston_sp_df.drop("target"), 137 | 'rf_iris_sk': rf_iris_sk, 138 | 'rf_houses_sk': rf_houses_sk, 139 | 'rf_boston_sk': rf_boston_sk, 140 | 'rf_iris_sp': rf_iris_sp, 141 | 'assembler_iris':assembler_iris, 142 | 'rf_houses_sp': rf_houses_sp, 143 | 'assembler_houses': assembler_houses, 144 | 'assembler_boston': assembler_boston, 145 | 'rf_boston_sp': rf_boston_sp 146 | } 147 | 148 | @pytest.fixture(scope="module") 149 | def model_and_data_cte(spark_session): 150 | 151 | # Generate Dataset 152 | K = 3 153 | random_state = 42 154 | x_data, _ = make_blobs(n_samples=1000, n_features=2, centers=K, cluster_std=2.5, random_state=random_state) 155 | 156 | # K-means with pandas and sklearn 157 | features_names = ["feature_1", "feature_2"] 158 | pandas_df = pd.DataFrame(x_data, columns=features_names) 159 | sk_kmeans = Sklearn_KMeans(K, random_state=42) 160 | sk_kmeans.fit(x_data) 161 | 162 | # K-means sklearn pipeline 163 | sk_pipeline = SklearnPipeline(steps=[("scaler", StandardScaler()), ("kmeans", Sklearn_KMeans(K, random_state=random_state))]) 164 | sk_pipeline.fit(x_data) 165 | 166 | # K-means with spark dataframes (spark pipeline) 167 | spark_df = spark_session.createDataFrame(pandas_df) 168 | assembler = VectorAssembler(inputCols=features_names, outputCol="features") 169 | spark_kmeans = SparkKMeans(k=K, seed=random_state) 170 | spark_pipeline = Pipeline(stages=[assembler, spark_kmeans]) 171 | spark_pipeline_model = spark_pipeline.fit(spark_df) 172 | 173 | # K-means with spark dataframes (no pipeline) 174 | assembler = VectorAssembler(inputCols=features_names, outputCol="features") 175 | spark_df_2 = assembler.transform(spark_df).select("features") 176 | spark_kmeans = SparkKMeans(k=K, seed=random_state) 177 | spark_kmeans_model = spark_kmeans.fit(spark_df_2) 178 | 179 | return { 180 | 'spark_sess': spark_session, 181 | 'pandas_df': pandas_df, 182 | 'sk_kmeans': sk_kmeans, 183 | 'sk_pipeline': sk_pipeline, 184 | 'spark_df': spark_df, 185 | 'spark_pipeline_model': spark_pipeline_model, 186 | 'spark_df_2': spark_df_2, 187 | 'spark_kmeans_model': 
spark_kmeans_model, 188 | 'K': K 189 | } -------------------------------------------------------------------------------- /mercury/explainability/explanations/partial_dependence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import typing as TP 4 | 5 | from math import ceil 6 | 7 | class PartialDependenceExplanation(): 8 | """ 9 | This class holds the result of a Partial Dependence explanation and 10 | provides functionality for plotting those results via Partial Dependence Plots. 11 | 12 | Args: 13 | data (dict): 14 | Contains the result of the PartialDependenceExplainer. It must be in the 15 | form of: :: 16 | { 17 | 'feature_name': {'values': [...], 'preds': [...], 'lower_quantile': [...], 'upper_quantile': [...]}, 18 | 'feature_name2': {'values': [...], 'preds': [...], 'lower_quantile': [...], 'upper_quantile': [...]}, 19 | ... 20 | } 21 | 22 | """ 23 | def __init__(self, data): 24 | self.data = data 25 | 26 | def plot_single(self, var_name: str, ax=None, quantiles:TP.Union[bool, list] = False, filter_classes:list = None, **kwargs): 27 | """ 28 | Plots the partial dependence of a single variable. 29 | 30 | Args: 31 | var_name (str): 32 | Name of the desired variable to plot. 33 | ax (matplotlib.axes._subplots.AxesSubplot): 34 | Axes object on which the data will be plotted. 35 | quantiles (bool or list[bool]): 36 | Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions 37 | have high or low dispersion. If data doesn't contain the quantiles this parameter will be ignored. 38 | filter_classes (list): 39 | List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable 40 | is not categorical. 41 | """ 42 | # If user pass a single bool and prediction data is a multinomial, we conver the 43 | # single boolean to a mask array to only plot the quantile range over the selected 44 | # classes. 45 | if len(self.data[var_name]['preds'].shape)>=2: 46 | if type(quantiles) == list and len(quantiles) != self.data[var_name]['preds'].shape[1]: 47 | raise ValueError("len(quantiles) must be equal to the number of classes.") 48 | if type(quantiles) == bool: 49 | quantiles = [quantiles for i in range(self.data[var_name]['preds'].shape[1])] 50 | elif type(quantiles) == list and len(self.data[var_name]['preds'].shape)==1: 51 | quantiles = quantiles[0] 52 | 53 | if filter_classes is not None: 54 | filter_classes = np.where(filter_classes)[0].tolist() 55 | else: 56 | filter_classes = np.arange(self.data[var_name]['preds'].shape[-1]).tolist() 57 | if len(self.data[var_name]['preds'].shape) < 2: 58 | filter_classes = None 59 | 60 | ax = ax if ax else plt.gca() 61 | 62 | ax.set_title(var_name) 63 | ax.set_xlabel(f"{var_name} value") 64 | ax.set_ylabel("Avg model prediction") 65 | 66 | vals = np.array(self.data[var_name]['values']) 67 | int_locations = np.arange(len(vals)) 68 | 69 | non_numerical_values = False 70 | # Check if variable is categorical. 
If so, plot bars 71 | if self.data[var_name]['categorical'] and not type(vals[0]) == float: 72 | bar_width = .2 73 | class_nb = 0 if not filter_classes else len(filter_classes) 74 | 75 | if type(vals[0]) == float or type(vals[0]) == int: 76 | ax.set_xticks(self.data[var_name]['values']) 77 | else: 78 | non_numerical_values = True 79 | bar_offsets = np.linspace(-bar_width, bar_width, num=class_nb) / class_nb 80 | ax.set_xticks(int_locations) 81 | ax.set_xticklabels(self.data[var_name]['values']) 82 | 83 | if class_nb == 0: 84 | # If prediction is a single scalar 85 | if non_numerical_values: 86 | ax.bar(int_locations, self.data[var_name]['preds'], width=bar_width, label='Prediction',**kwargs) 87 | else: 88 | ax.bar(vals, self.data[var_name]['preds'], width=bar_width, label='Prediction', **kwargs) 89 | 90 | if quantiles: 91 | ax.errorbar( 92 | int_locations, 93 | self.data[var_name]['preds'], 94 | yerr=np.vstack([self.data[var_name]['lower_quantile'], 95 | self.data[var_name]['upper_quantile']]), 96 | fmt='ko', 97 | label='Quantiles', 98 | **kwargs 99 | ) 100 | 101 | else: 102 | # If prediction is multiclass 103 | for i in range(class_nb): 104 | if i in filter_classes: 105 | if non_numerical_values: 106 | ax.bar(int_locations + bar_offsets[i], self.data[var_name]['preds'][:,i], 107 | width=bar_width / class_nb, label=f'Class {i}',**kwargs) 108 | else: 109 | ax.bar(vals, self.data[var_name]['preds'][:,i], width=bar_width / class_nb, label=f'Class {i}', **kwargs) 110 | 111 | if quantiles[i]: 112 | ax.errorbar( 113 | int_locations + bar_offsets[i], 114 | self.data[var_name]['preds'][:, i], 115 | yerr=np.vstack([self.data[var_name]['lower_quantile'][:,i], 116 | self.data[var_name]['upper_quantile'][:,i]]), 117 | fmt='ko', 118 | label=f'Quantiles {i}', 119 | **kwargs 120 | ) 121 | 122 | if class_nb > 0: 123 | ax.legend() 124 | 125 | else: # Variable is continuous 126 | 127 | # Check whether prediction data is multinomial 128 | if filter_classes: 129 | objs = ax.plot(vals, self.data[var_name]['preds'][:, filter_classes], **kwargs) 130 | else: 131 | objs = ax.plot(vals, self.data[var_name]['preds'], **kwargs) 132 | if len(self.data[var_name]['preds'].shape)>=2: 133 | labels = [f"Class: {i}" for i in range(self.data[var_name]['preds'].shape[1])] 134 | # Filter labels 135 | labels = [l for i, l in enumerate(labels) if i in filter_classes] 136 | # Show labels 137 | ax.legend(iter(objs), labels) 138 | for i in range(self.data[var_name]['preds'].shape[1]): 139 | if quantiles[i] and len(self.data[var_name]['lower_quantile']) > 0: 140 | # Plot quantiles and a shaded band between them 141 | 142 | # We will need the color assigned to each one of the lines so the 143 | # shaded area also has that color. Since filtering can be done, we 144 | # extract the line index as the minimum between the current class 145 | # index and the maximum amount of lines on the canvas. 
146 | obj_index = min(i, len(objs) - 1) 147 | 148 | # Actually plot the shaded area 149 | ax.plot(vals, self.data[var_name]['lower_quantile'][:,i], ls='--', color=objs[obj_index].get_color(),**kwargs) 150 | ax.plot(vals, self.data[var_name]['upper_quantile'][:,i], ls='--', color=objs[obj_index].get_color(), **kwargs) 151 | ax.fill_between(vals, 152 | self.data[var_name]['lower_quantile'][:,i], self.data[var_name]['upper_quantile'][:,i], alpha=.05) 153 | else: # If target is not multinomial 154 | if quantiles and len(self.data[var_name]['lower_quantile']) > 0: 155 | # Plot quantiles and a shaded band between them 156 | ax.plot(vals, self.data[var_name]['lower_quantile'], ls='--', color=objs[0].get_color(),**kwargs) 157 | ax.plot(vals, self.data[var_name]['upper_quantile'], ls='--', color=objs[0].get_color(), **kwargs) 158 | ax.fill_between(vals, self.data[var_name]['lower_quantile'], self.data[var_name]['upper_quantile'], alpha=.05) 159 | 160 | def plot(self, ncols:int = 1, figsize:tuple = (15,15), quantiles:TP.Union[bool, list] = False, filter_classes:list = None, **kwargs): 161 | """ 162 | Plots a summary of all the partial dependences. 163 | 164 | Args: 165 | ncols (int): 166 | Number of columns of the summary. 1 as default. 167 | figsize (tuple): 168 | Size of the plotted figure 169 | quantiles (bool or list): 170 | Whether to also plot the quantiles and a shaded area between them. Useful to check whether the predictions 171 | have high or low dispersion. If this is a list of booleans, quantiles 172 | will be plotted filtered by class (i.e. `quantiles[0]` = `class number 0`). 173 | filter_classes (list): 174 | List of bool with the classes to plot. If None, all classes will be plotted. Ignored if the target variable 175 | is not categorical. 176 | """ 177 | features = list(self.data.keys()) 178 | 179 | fig, ax = plt.subplots(ceil(len(features) / ncols), ncols, figsize=figsize) 180 | 181 | for i, feat_name in enumerate(features): 182 | sbplt = ax[i] if ncols==1 or ncols==len(features) else ax[i // ncols, i % ncols] 183 | self.plot_single(feat_name, sbplt, quantiles=quantiles, filter_classes=filter_classes, **kwargs) 184 | 185 | def __getitem__(self, key:str): 186 | """ 187 | Gets the dependence data of the desired feature. 188 | 189 | Args: 190 | key (str): 191 | Name of the feature. 192 | """ 193 | return self.data[key]['values'], self.data[key]['preds'] 194 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/counter_fact_basic.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from .explainer import MercuryExplainer 6 | from .cf_strategies import SimulatedAnnealing, Backtracking 7 | 8 | from mercury.explainability.explanations.counter_factual import ( 9 | CounterfactualBasicExplanation 10 | ) 11 | 12 | 13 | class CounterFactualExplainerBasic(MercuryExplainer): 14 | """ 15 | Explains predictions on tabular (i.e. matrix) data for binary/multiclass classifiers. 16 | Currently two main strategies are implemented: one following a backtracking strategy and 17 | another following a probabilistic process (simulated annealing strategy). 18 | 19 | Args: 20 | train (TP.Union['np.ndarray', pd.DataFrame]): 21 | Training dataset to extract feature bounds from. 22 | fn (TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], TP.Union[float, 'np.ndarray']]): 23 | Classifier `predict_proba`-like function. 
Note that the returned probabilities 24 | must be valid, ie. the values must be between 0 and 1. 25 | labels (TP.List[str]): 26 | List of labels to be used when plotting results. If DataFrame used, labels take 27 | dataframe column names. Default is empty list. 28 | bounds (TP.Optional['np.ndarray']): 29 | Feature bounds used when no train data is provided (shape must match labels'). 30 | Default is None. 31 | n_steps (int): 32 | Parameter used to indicate how small/large steps should be when exploring 33 | the space (default is 200). 34 | 35 | Raises: 36 | AssertionError: 37 | if bounds.size <= 0 when no train data is provided | 38 | if bounds.ndim != 2 when no train data is provided | 39 | if bounds.shape[1] != 2 when no train data is provided | 40 | if bounds.shape[0] != len(labels) 41 | TypeError: 42 | if train is not a DataFrame or numpy array. 43 | """ 44 | 45 | def __init__(self, 46 | train: TP.Union[TP.Optional['np.ndarray'], TP.Optional[pd.DataFrame]], 47 | fn: TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], TP.Union[float, 'np.ndarray']], 48 | labels: TP.List[str] = [], 49 | bounds: TP.Optional['np.ndarray'] = None, 50 | n_steps: int = 200) -> None: 51 | if train is None: 52 | # If data is not provided, labels and bounds are required 53 | assert bounds.size > 0, 'Bounds are required if no data is provided' 54 | assert bounds.ndim == 2 and bounds.shape[1] == 2, 'min/max values are required for each feature' 55 | assert len(labels) == bounds.shape[0], \ 56 | 'Labels and bound shapes must match, got {} and {} respectively' \ 57 | .format(len(labels), bounds.shape[0]) 58 | # min/max values for each feature 59 | self.labels = labels 60 | self.bounds = bounds 61 | else: 62 | # Compute bounds 63 | if isinstance(train, pd.DataFrame): 64 | self.labels = train.columns.tolist() 65 | self.bounds = train.describe().loc[['min', 'max']].values.T 66 | assert len(self.labels) == self.bounds.shape[0], \ 67 | 'Labels and bound shapes must match, got {} and {} respectively' \ 68 | .format(len(self.labels), self.bounds.shape[0]) 69 | elif isinstance(train, np.ndarray): 70 | self.labels = labels 71 | self.bounds = np.stack([ 72 | np.apply_along_axis(np.min, 0, train), 73 | np.apply_along_axis(np.max, 0, train)], axis=1) 74 | assert len(self.labels) == self.bounds.shape[0], \ 75 | 'Labels and bound shapes must match, got {} and {} respectively' \ 76 | .format(len(self.labels), self.bounds.shape[0]) 77 | else: 78 | raise TypeError('Invalid type for argument train, got {} but expected numpy array or pandas dataframe'. 79 | format(type(train))) 80 | 81 | # Compute steps 82 | self.n_steps = n_steps 83 | self.step = (self.bounds[:, 1] - self.bounds[:, 0]) / self.n_steps 84 | 85 | # Function to be evaluated on optimization 86 | self.fn = fn 87 | 88 | def explain(self, 89 | from_: 'np.ndarray', 90 | threshold: float, 91 | class_idx: int = 1, 92 | kernel: TP.Optional['np.ndarray'] = None, 93 | bounds: TP.Optional['np.ndarray'] = None, 94 | step: TP.Optional['np.ndarray'] = None, 95 | strategy: str = 'backtracking', 96 | report: bool = False, 97 | keep_explored_points: bool = True, 98 | **kwargs) -> CounterfactualBasicExplanation: 99 | """ 100 | Roll the panellet down the valley and find an explanation. 101 | 102 | Args: 103 | from_ ('np.ndarray'): 104 | Starting point. 105 | threshold (float): 106 | Probability to be achieved (if path is found). 107 | class_idx (int): 108 | Class to be explained (e.g. 1 for binary classifiers). 
109 | kernel (TP.Optional['np.ndarray']): 110 | Used to penalize certain dimensions when trying to move around 111 | the probability space (some dimensions may be more difficult to explain, 112 | hence don't move along them). Default is np.ones(n), meaning all dimensions 113 | can be used to move around the space (must be a value between 0 and 1). 114 | bounds (TP.Optional['np.ndarray']): 115 | Feature bound values to be used when exploring the probability space. If not 116 | specified, the ones extracted from the training data are used instead. 117 | step (TP.Optional['np.ndarray']): 118 | Step values to be used when moving around the probability space. If not specified, 119 | training bounds are divided by 200 (arbitrary value) and these are used as step values. 120 | strategy (str): 121 | If 'backtracking', the backtracking strategy is used to move around the probability space. 122 | If 'simanneal', the simulated annealing strategy is used to move around the probability space. 123 | report (bool): 124 | Whether to report the algorithm progress during the execution. 125 | keep_explored_points (bool): 126 | Whether to keep the points that the algorithm explores. Setting it to False will decrease 127 | the computation time and memory usage in some cases. Default value is True. 128 | 129 | Raises: 130 | AssertionError: 131 | If `from_` number of dimensions is != 1 | 132 | If `from_` shape does not match `bounds` shape | 133 | If `bounds` shape is not valid | 134 | If `step` shape does not match `bounds` shape | 135 | ValueError: 136 | if strategy is not 'backtracking' or 'simanneal'. 137 | 138 | Returns: 139 | explanation (CounterfactualBasicExplanation): 140 | CounterfactualBasicExplanation with the solution found and how it differs from the starting point.
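Example:
    A minimal usage sketch (``train_df`` and the fitted classifier ``clf`` are
    illustrative placeholders, not objects defined in this module):
    ```python
    >>> explainer = CounterFactualExplainerBasic(
    ...     train=train_df,                  # pd.DataFrame with the training features
    ...     fn=clf.predict_proba)            # must return probabilities in [0, 1]
    >>> explanation = explainer.explain(
    ...     from_=train_df.iloc[0].values,   # 1-D starting point
    ...     threshold=0.8,                   # probability to reach for class_idx
    ...     strategy='backtracking')
    >>> explanation.get_changes(relative=False)  # per-feature deltas from start to solution
    ```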
141 | """ 142 | 143 | if kernel is None: 144 | self.kernel = np.ones(from_.shape[0]) 145 | else: 146 | self.kernel = kernel 147 | 148 | assert from_.ndim == 1, \ 149 | 'Invalid starting point shape, got {} but expected unidimensional vector'.format(from_.shape) 150 | 151 | if bounds is not None: 152 | assert from_.shape[0] == bounds.shape[0], \ 153 | 'Starting point and bounds shapes should match, got {} and {}'.format(from_.shape, bounds.shape[0]) 154 | 155 | # Update bounds based on reference point 156 | l_bounds = self.bounds.copy() 157 | for i, bound in enumerate(l_bounds): 158 | new_min = bound[0] 159 | new_max = bound[1] 160 | if from_[i] < bound[0]: 161 | new_min = from_[i] 162 | elif from_[i] > bound[1]: 163 | new_max = from_[i] 164 | l_bounds[i, 0] = new_min 165 | l_bounds[i, 1] = new_max 166 | 167 | # Update bounds, if new_bounds are specified 168 | if bounds is not None: 169 | assert bounds.shape == l_bounds.shape, \ 170 | 'Invalid dimensions for new bounds, got {} but expected {}'.format(bounds.shape, l_bounds.shape) 171 | for i, bound in enumerate(bounds): 172 | # Update bound only if starting point is within it in this dimension 173 | if bound[0] <= from_[i] and bound[1] >= from_[i]: 174 | l_bounds[i] = bounds[i] 175 | 176 | if step is not None: 177 | assert step.shape[0] == l_bounds.shape[0], \ 178 | 'Invalid step shape, got {} but expected {}'.format(step.shape, l_bounds.shape[0]) 179 | self.step = step 180 | 181 | if strategy == 'backtracking': 182 | # Backtracking strategy 183 | sol, p, visited, explored = Backtracking(from_, l_bounds, self.step, self.fn, class_idx, 184 | threshold=threshold, kernel=self.kernel, report=report, 185 | keep_explored_points=keep_explored_points).run( 186 | **kwargs) 187 | ps = visited[:, -1] 188 | visited = visited[:, :-1] 189 | if keep_explored_points: 190 | explored_points = explored[:, :-1] 191 | explored_ps = explored[:, -1] 192 | else: 193 | explored_points = np.array([]) 194 | explored_ps = np.array([]) 195 | return CounterfactualBasicExplanation( 196 | from_, sol, p, visited, ps, l_bounds, explored_points, 197 | explored_ps, labels=self.labels) 198 | elif strategy == 'simanneal': 199 | # Simulated Annealing strategy 200 | sol, p, visited, energies = SimulatedAnnealing(from_, l_bounds, self.step, self.fn, class_idx, 201 | threshold=threshold, kernel=self.kernel, 202 | report=report).run(**kwargs) 203 | return CounterfactualBasicExplanation(from_, sol, abs(p), visited, energies[:-1], l_bounds, 204 | labels=self.labels) 205 | else: 206 | raise ValueError('Invalid strategy') 207 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/anchors.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import numpy as np 3 | import pandas as pd 4 | import signal 5 | 6 | from .explainer import MercuryExplainer 7 | from mercury.explainability.explainers import run_until_timeout 8 | from alibi.explainers import AnchorTabular 9 | from mercury.explainability.explanations.anchors import AnchorsWithImportanceExplanation 10 | from alibi.api.interfaces import Explanation 11 | 12 | 13 | class AnchorsWithImportanceExplainer(AnchorTabular, MercuryExplainer): 14 | """ 15 | Extending Alibi's AnchorTabular implementation, this class allows for the 16 | computation of feature importance by means of calculating several anchors. 17 | Initialize the anchor tabular explainer.
18 | 19 | Args: 20 | predict_fn: Model prediction function 21 | train_data: Pandas DataFrame with the features 22 | disc_perc: List or tuple with percentiles (int) used for discretization. 23 | categorical_names: Dictionary where keys are feature columns and values are the categories for the feature 24 | 25 | Raises: 26 | AttributeError: if categorical_names is not a dict 27 | AttributeError: if train_data is not a pd.DataFrame 28 | 29 | Example: 30 | ```python 31 | >>> explain_data = pd.read_csv('./test/explain_data.csv') 32 | >>> model = MyModel() # (Trained) model prediction function (has to be callable) 33 | >>> explainer = AnchorsWithImportanceExplainer(model, explain_data) 34 | >>> explanation = explainer.explain(explain_data.head(10).values) # For the first 10 samples 35 | >>> explanation.interpret_explanations(n_important_features=2) 36 | # We can also get the feature importances for the first 10 samples. 37 | >>> explainer.get_feature_importance(explain_data=explain_data.head(10)) 38 | ``` 39 | """ 40 | 41 | def __init__( 42 | self, 43 | predict_fn: TP.Callable, 44 | train_data: pd.DataFrame, 45 | categorical_names: TP.Dict[str, TP.List] = {}, 46 | disc_perc: TP.Tuple[TP.Union[int, float], ...] = (25, 50, 75), 47 | *args, **kwargs 48 | ) -> None: 49 | if not isinstance(categorical_names, dict): 50 | raise AttributeError(""" 51 | The attribute categorical_names should be a dictionary 52 | where the keys are the categorical feature names and the 53 | values are the categories for each categorical feature. 54 | """) 55 | 56 | if not isinstance(train_data, pd.DataFrame): 57 | raise AttributeError(""" 58 | train_data should be a pandas DataFrame. 59 | """) 60 | 61 | super().__init__(predict_fn, list(train_data.columns), categorical_names) 62 | self.categorical_names = categorical_names 63 | 64 | super().fit( 65 | train_data=train_data.values, 66 | disc_perc=disc_perc, 67 | *args, **kwargs 68 | ) 69 | 70 | def explain(self, 71 | X: np.ndarray, 72 | threshold: float = 0.95, 73 | delta: float = 0.1, 74 | tau: float = 0.15, 75 | batch_size: int = 100, 76 | coverage_samples: int = 10000, 77 | beam_size: int = 1, 78 | stop_on_first: bool = False, 79 | max_anchor_size: TP.Optional[int] = None, 80 | min_samples_start: int = 100, 81 | n_covered_ex: int = 10, 82 | binary_cache_size: int = 10000, 83 | cache_margin: int = 1000, 84 | verbose: bool = False, 85 | verbose_every: int = 1, 86 | **kwargs: TP.Any) -> Explanation: 87 | """ 88 | Explain prediction made by classifier on instance `X`. 89 | 90 | Args: 91 | X: Instance to be explained. 92 | threshold: Minimum precision threshold. 93 | delta: Used to compute `beta`. 94 | tau: Margin between lower confidence bound and minimum precision or upper bound. 95 | batch_size: Batch size used for sampling. 96 | coverage_samples: Number of samples used to estimate coverage from during result search. 97 | beam_size: The number of anchors extended at each step of new anchors construction. 98 | stop_on_first: If ``True``, the beam search algorithm will return the 99 | first anchor that satisfies the probability constraint. 100 | max_anchor_size: Maximum number of features in result. 101 | min_samples_start: Min number of initial samples. 102 | n_covered_ex: How many examples where anchors apply to store for each anchor sampled during search 103 | (both examples where prediction on samples agrees/disagrees with `desired_label` are stored).
104 | binary_cache_size: The result search pre-allocates `binary_cache_size` batches for storing the binary arrays 105 | returned during sampling. 106 | cache_margin: When only ``max(cache_margin, batch_size)`` positions in the binary cache remain empty, a new cache 107 | of the same size is pre-allocated to continue buffering samples. 108 | verbose: Display updates during the anchor search iterations. 109 | verbose_every: Frequency of displayed iterations during anchor search process. 110 | 111 | Returns: 112 | explanation 113 | `Explanation` object containing the result explaining the instance with additional metadata as attributes. 114 | See usage at `AnchorTabular examples`_ for details. 115 | .. _AnchorTabular examples: 116 | https://docs.seldon.io/projects/alibi/en/latest/methods/Anchors.html 117 | """ 118 | exp = super().explain( 119 | X=X, 120 | threshold=threshold, 121 | delta=delta, 122 | tau=tau, 123 | batch_size=batch_size, 124 | coverage_samples=coverage_samples, 125 | beam_size=beam_size, 126 | stop_on_first=stop_on_first, 127 | max_anchor_size=max_anchor_size, 128 | min_samples_start=min_samples_start, 129 | n_covered_ex=n_covered_ex, 130 | binary_cache_size=binary_cache_size, 131 | cache_margin=cache_margin, 132 | verbose=verbose, 133 | verbose_every=verbose_every 134 | ) 135 | 136 | # This attribute makes pickle serialization crash, so we delete it. 137 | if hasattr(self, "mab"): 138 | delattr(self, "mab") 139 | 140 | return exp 141 | 142 | def get_feature_importance( 143 | self, 144 | explain_data: pd.DataFrame, 145 | threshold: float = 0.95, 146 | print_every: int = 0, 147 | print_explanations: bool = False, 148 | n_important_features: int = 3, 149 | tau: float = 0.15, 150 | timeout: int = 0) -> AnchorsWithImportanceExplanation: 151 | """ 152 | Args: 153 | explain_data: 154 | Pandas dataframe containing all the instances for which to find an anchor and therefore 155 | obtain feature importances. 156 | threshold: To be used in and passed down to the anchor explainer as defined on Alibi's documentation. 157 | Controls the minimum precision desired when looking for anchors. 158 | Defaults to 0.95. 159 | print_every: 160 | Logging information. 161 | Defaults to 0 - No logging 162 | print_explanations: 163 | Boolean that determines whether to print the explanations at the end of the method or not. 164 | Defaults to False. 165 | n_important_features: 166 | Number of top features that will be printed. 167 | Defaults to 3. 168 | tau: 169 | To be used in and passed down to the anchos explainer as defined on Alibi's documentation. 170 | Used within the multi-armed bandit part of the optimisation problem. 171 | Defaults to 0.15 172 | timeout: 173 | Maximum time to be spent looking for an Anchor in seconds. A value of 0 means that no timeout 174 | is set. 175 | Defaults to 0. 176 | 177 | Returns: 178 | A list containing all the explanations. 
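Example:
    A minimal sketch, assuming ``explainer`` was built as in the class-level example
    and ``explain_data`` is a pandas DataFrame with the instances to explain:
    ```python
    >>> explanation = explainer.get_feature_importance(
    ...     explain_data=explain_data.head(10),
    ...     threshold=0.95,    # minimum anchor precision
    ...     timeout=60,        # give up on an instance after 60 seconds
    ...     print_every=5)     # log progress every 5 processed observations
    >>> explanation.interpret_explanations(n_important_features=3)
    ```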
179 | """ 180 | 181 | explanations = [] 182 | if print_every > 0: 183 | print('Looking for a total of {} explanations'.format( 184 | len(explain_data)) 185 | ) 186 | for explain_datum_idx, explain_datum in explain_data.iterrows(): 187 | try: 188 | explanation = run_until_timeout(timeout, 189 | self.explain, 190 | explain_datum.values, 191 | threshold=threshold, 192 | tau=tau) 193 | explanations.append(explanation) 194 | except Exception: 195 | if print_every > 0: 196 | print('No anchor found for observation {}' 197 | .format(explain_datum_idx)) 198 | explanations.append('No explanation') 199 | 200 | # Unset timeout 201 | signal.alarm(0) 202 | 203 | if print_every > 0: 204 | if len(explanations) % print_every == 0: 205 | print( 206 | ("""A total of {} observations have been processed """ + 207 | """for explaining""").format(len(explanations))) 208 | print("{} anchors have already been found".format( 209 | sum([1 for explan in explanations 210 | if not isinstance(explan, str)]) 211 | )) 212 | # Here we have a list with all the anchors explanations that we've been able to find. 213 | anchorsExtendedExplanation = AnchorsWithImportanceExplanation( 214 | explain_data=explain_data, 215 | explanations=explanations, 216 | categorical=self.categorical_names 217 | ) 218 | if print_explanations: 219 | anchorsExtendedExplanation.interpret_explanations( 220 | n_important_features=n_important_features 221 | ) 222 | return anchorsExtendedExplanation 223 | 224 | def translate(self, explanation: Explanation) -> str: 225 | """ 226 | Translates an explanation into simple words 227 | 228 | Args: 229 | explanation: Alibi explanation object 230 | 231 | """ 232 | coverage = explanation['data']['coverage'] 233 | if type(explanation['data']['precision']) is np.ndarray: 234 | precision = explanation['data']['precision'][0] 235 | else: 236 | precision = explanation['data']['precision'] 237 | 238 | if coverage * precision < 0.1: 239 | quality = "POOR" 240 | elif 0.1 <= coverage * precision < 0.4: 241 | quality = "GOOD" 242 | else: 243 | quality = "GREAT" 244 | 245 | return "[{} explanation] This anchor explains a {}% of all records of its class with {}% confidence.".format( 246 | quality, 247 | round(100 * coverage, 2), 248 | round(100 * precision, 2) 249 | ) 250 | 251 | def save(self, filename): 252 | """Overwrite to ensure that we use MercuryExplainer.save""" 253 | MercuryExplainer.save(self, filename=filename) 254 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/ale.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | 6 | import copy 7 | import math 8 | 9 | from typing import no_type_check 10 | from itertools import count 11 | from matplotlib.gridspec import GridSpec 12 | from alibi.api.interfaces import Explanation, Explainer 13 | from alibi.api.defaults import DEFAULT_META_ALE, DEFAULT_DATA_ALE 14 | from alibi.explainers.ale import get_quantiles, bisect_fun, minimum_satisfied, adaptive_grid, ale_num 15 | 16 | from .explainer import MercuryExplainer 17 | 18 | 19 | class ALEExplainer(Explainer, MercuryExplainer): 20 | """ 21 | Accumulated Local Effects for tabular datasets. Current implementation supports first order 22 | feature effects of numerical features. 23 | 24 | Args: 25 | predictor: 26 | A callable that takes in an NxF array as input and outputs an NxT array (N - number of 27 | data points, F - number of features, T - number of outputs/targets (e.g. 1 for single output 28 | regression, >=2 for classification). 29 | target_names: 30 | A list of target/output names used for displaying results. 31 | """ 32 | 33 | def __init__(self, predictor: TP.Callable, target_names: TP.Union[TP.List[str], str]) -> None: 34 | super().__init__(meta=copy.deepcopy(DEFAULT_META_ALE)) 35 | 36 | if (not isinstance(target_names, list) and 37 | not isinstance(target_names, str)): 38 | raise AttributeError('The attribute target_names should be a string or list.') 39 | 40 | if type(target_names) == str: 41 | target_names = [target_names] 42 | 43 | self.predictor = predictor 44 | self.target_names = target_names 45 | 46 | def explain(self, X: pd.DataFrame, min_bin_points: int = 4, ignore_features: list = []) -> Explanation: 47 | """ 48 | Calculate the ALE curves for each feature with respect to the dataset `X`. 49 | 50 | Args: 51 | X: 52 | An NxF tabular dataset used to calculate the ALE curves. This is typically the training dataset 53 | or a representative sample. 54 | min_bin_points: 55 | Minimum number of points each discretized interval should contain to ensure more precise 56 | ALE estimation. 57 | ignore_features: 58 | Features that will be ignored while computing the ALE curves. Useful for reducing computing time 59 | if there are predictors we dont care about. 60 | 61 | Returns: 62 | An `Explanation` object containing the data and the metadata of the calculated ALE curves. 
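Example:
    A minimal sketch (``model`` is an illustrative fitted classifier and ``X_train``
    a pandas DataFrame with its training features; neither is defined in this module):
    ```python
    >>> ale_explainer = ALEExplainer(model.predict_proba, target_names=['class_0', 'class_1'])
    >>> exp = ale_explainer.explain(X_train, min_bin_points=4)
    >>> axes = plot_ale(exp, n_cols=3)  # plot_ale is defined in this same module
    ```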
63 | 64 | """ 65 | self.meta['params'].update(min_bin_points=min_bin_points) 66 | 67 | if not isinstance(X, pd.DataFrame): 68 | raise ValueError('X must be a pandas DataFrame') 69 | 70 | features = list(X.columns) 71 | n_features =len(features) 72 | 73 | self.feature_names = np.array(features) 74 | self.target_names = np.array(self.target_names) 75 | 76 | feature_values = [] 77 | ale_values = [] 78 | ale0 = [] 79 | feature_deciles = [] 80 | 81 | X = X[features].values 82 | 83 | # TODO: use joblib to paralelise? 84 | for feature, feat_name in enumerate(self.feature_names): 85 | if feat_name not in ignore_features: 86 | q, ale, a0 = ale_num( 87 | self.predictor, 88 | X=X, 89 | feature=feature, 90 | min_bin_points=min_bin_points 91 | ) 92 | deciles = get_quantiles(X[:, feature], num_quantiles=11) 93 | 94 | feature_values.append(q) 95 | ale_values.append(ale) 96 | ale0.append(a0) 97 | feature_deciles.append(deciles) 98 | 99 | constant_value = self.predictor(X).mean() 100 | # TODO: an ALE plot ideally requires a rugplot to gauge density of instances in the feature space. 101 | # I've replaced this with feature deciles which is coarser but has constant space complexity 102 | # as opposed to a rugplot. Alternatively, could consider subsampling to produce a rug with some 103 | # maximum number of points. 104 | return self._build_explanation( 105 | ale_values=ale_values, 106 | ale0=ale0, 107 | constant_value=constant_value, 108 | feature_values=feature_values, 109 | feature_deciles=feature_deciles, 110 | ignore_features=ignore_features 111 | ) 112 | 113 | def _build_explanation(self, 114 | ale_values: TP.List[np.ndarray], 115 | ale0: TP.List[np.ndarray], 116 | constant_value: float, 117 | feature_values: TP.List[np.ndarray], 118 | feature_deciles: TP.List[np.ndarray], 119 | ignore_features: TP.List = []) -> Explanation: 120 | """ 121 | Helper method to build the Explanation object. 122 | """ 123 | # TODO decide on the format for these lists of arrays 124 | # Currently each list element relates to a feature and each column relates to an output dimension, 125 | # this is different from e.g. SHAP but arguably more convenient for ALE. 126 | 127 | data = copy.deepcopy(DEFAULT_DATA_ALE) 128 | data.update( 129 | ale_values=ale_values, 130 | ale0=ale0, 131 | constant_value=constant_value, 132 | feature_values=feature_values, 133 | feature_names=[x for x in self.feature_names if x not in ignore_features], 134 | target_names=self.target_names, 135 | feature_deciles=feature_deciles 136 | ) 137 | 138 | return Explanation(meta=copy.deepcopy(self.meta), data=data) 139 | 140 | def save(self, filename): 141 | """Overwrite to ensure that we use MercuryExplainer.save""" 142 | MercuryExplainer.save(self, filename=filename) 143 | 144 | @no_type_check 145 | def plot_ale(exp: Explanation, 146 | features: TP.Union[TP.List[TP.Union[int, str]], str] = 'all', 147 | targets: TP.Union[TP.List[TP.Union[int, str]], str] = 'all', 148 | n_cols: int = 3, 149 | sharey: str = 'all', 150 | constant: bool = False, 151 | ax: TP.Union['plt.Axes', np.ndarray, None] = None, 152 | line_kw: TP.Optional[dict] = None, 153 | fig_kw: TP.Optional[dict] = None) -> 'np.ndarray': 154 | """ 155 | Plot ALE curves on matplotlib axes. 156 | 157 | Args: 158 | exp: An `Explanation` object produced by a call to the `ALE.explain` method. 159 | features: A list of features for which to plot the ALE curves or `all` for all features. 160 | Can be a mix of integers denoting feature index or strings denoting entries in 161 | `exp.feature_names`. 
Defaults to 'all'. 162 | targets: A list of targets for which to plot the ALE curves or `all` for all targets. 163 | Can be a mix of integers denoting target index or strings denoting entries in 164 | `exp.target_names`. Defaults to 'all'. 165 | n_cols: Number of columns to organize the resulting plot into. 166 | sharey: A parameter specifying whether the y-axis of the ALE curves should be on the same scale 167 | for several features. Possible values are `all`, `row`, `None`. 168 | constant: A parameter specifying whether the constant zeroth order effects should be added to the 169 | ALE first order effects. 170 | ax: A `matplotlib` axes object or a numpy array of `matplotlib` axes to plot on. 171 | line_kw: Keyword arguments passed to the `plt.plot` function. 172 | fig_kw: Keyword arguments passed to the `fig.set` function. 173 | 174 | Returns: 175 | An array of matplotlib axes with the resulting ALE plots. 176 | 177 | """ 178 | # line_kw and fig_kw values 179 | default_line_kw = {'markersize': 3, 'marker': 'o', 'label': None} 180 | if line_kw is None: 181 | line_kw = {} 182 | line_kw = {**default_line_kw, **line_kw} 183 | 184 | default_fig_kw = {'tight_layout': 'tight'} 185 | if fig_kw is None: 186 | fig_kw = {} 187 | fig_kw = {**default_fig_kw, **fig_kw} 188 | 189 | if features == 'all': 190 | selected_features = range(0, len(exp.feature_names)) 191 | else: 192 | selected_features = [] 193 | for ix, f in enumerate(features): 194 | if isinstance(f, str): 195 | try: 196 | selected_features.append(exp.feature_names.index(f)) 197 | except ValueError: 198 | raise ValueError("Feature name {} does not exist.".format(f)) 199 | elif isinstance(f, int): 200 | selected_features.append(f) 201 | n_features = len(selected_features) 202 | 203 | if targets == 'all': 204 | targets = range(0, len(exp.target_names)) 205 | else: 206 | for ix, t in enumerate(targets): 207 | if isinstance(t, str): 208 | try: 209 | t = np.argwhere(exp.target_names == t).item() 210 | except ValueError: 211 | raise ValueError("Target name {} does not exist.".format(t)) 212 | targets[ix] = t 213 | 214 | # make axes 215 | if ax is None: 216 | fig, ax = plt.subplots() 217 | 218 | if isinstance(ax, plt.Axes) and n_features != 1: 219 | ax.set_axis_off() # treat passed axis as a canvas for subplots 220 | fig = ax.figure 221 | n_cols = min(n_cols, n_features) 222 | n_rows = math.ceil(n_features / n_cols) 223 | 224 | axes = np.empty((n_rows, n_cols), dtype=np.object_) 225 | axes_ravel = axes.ravel() 226 | # gs = GridSpecFromSubplotSpec(n_rows, n_cols, subplot_spec=ax.get_subplotspec()) 227 | gs = GridSpec(n_rows, n_cols) 228 | for i, spec in zip(range(n_features), gs): 229 | # determine which y-axes should be shared 230 | if sharey == 'all': 231 | cond = i != 0 232 | elif sharey == 'row': 233 | cond = i % n_cols != 0 234 | else: 235 | cond = False 236 | 237 | if cond: 238 | axes_ravel[i] = fig.add_subplot(spec, sharey=axes_ravel[i - 1]) 239 | continue 240 | axes_ravel[i] = fig.add_subplot(spec) 241 | 242 | else: # array-like 243 | if isinstance(ax, plt.Axes): 244 | ax = np.array(ax) 245 | if ax.size < n_features: 246 | raise ValueError("Expected ax to have {} axes, got {}".format(n_features, ax.size)) 247 | axes = np.atleast_2d(ax) 248 | axes_ravel = axes.ravel() 249 | fig = axes_ravel[0].figure 250 | 251 | # make plots 252 | for ix, feature, ax_ravel in \ 253 | zip(count(), selected_features, axes_ravel): 254 | _ = _plot_one_ale_num(exp=exp, 255 | feature=feature, 256 | targets=targets, 257 | constant=constant, 258 | ax=ax_ravel, 259 
| legend=not ix, # only one legend 260 | line_kw=line_kw) 261 | 262 | # if explicit labels passed, handle the legend here as the axis passed might be repeated 263 | if line_kw['label'] is not None: 264 | axes_ravel[0].legend() 265 | 266 | fig.set(**fig_kw) 267 | # TODO: should we return just axes or ax + axes 268 | return axes 269 | 270 | @no_type_check 271 | def _plot_one_ale_num(exp: Explanation, 272 | feature: int, 273 | targets: TP.List[int], 274 | constant: bool = False, 275 | ax: 'plt.Axes' = None, 276 | legend: bool = True, 277 | line_kw: dict = None) -> 'plt.Axes': 278 | """ 279 | Plots the ALE of exactly one feature on one axes. 280 | """ 281 | import matplotlib.pyplot as plt 282 | from matplotlib import transforms 283 | 284 | if ax is None: 285 | ax = plt.gca() 286 | 287 | # add zero baseline 288 | ax.axhline(0, color='grey') 289 | 290 | # Sometimes within the computation of the ale values, we get more values corresponding 291 | # to the feature than values corresponding to the target or vice-versa, i.e. the number 292 | # of X's and Y's is not the same and therefore is not possible to properly build the ale 293 | # plot. These conditions help ensure len(x) equals len(y). 294 | if len(exp.feature_values[feature]) == len(exp.ale_values[feature][:, targets]): 295 | lines = ax.plot( 296 | exp.feature_values[feature], 297 | exp.ale_values[feature][:, targets] + constant * exp.constant_value, 298 | **line_kw 299 | ) 300 | elif len(exp.feature_values[feature]) < len(exp.ale_values[feature][:, targets]): 301 | diff = len(exp.ale_values[feature][:, targets]) - len(exp.feature_values[feature]) 302 | x = np.append(exp.feature_values[feature], np.repeat(exp.feature_values[feature][-1], diff)) 303 | y = exp.ale_values[feature][:, targets] 304 | lines = ax.plot( 305 | x, y, **line_kw 306 | ) 307 | elif len(exp.feature_values[feature]) > len(exp.ale_values[feature][:, targets]): 308 | diff = len(exp.feature_values[feature]) - len(exp.ale_values[feature][:, targets]) 309 | y = np.append(exp.ale_values[feature][:, targets], np.repeat(exp.ale_values[feature][:, targets][-1], diff)) 310 | x = exp.feature_values[feature] 311 | lines = ax.plot( 312 | x, y, **line_kw 313 | ) 314 | 315 | # add decile markers to the bottom of the plot 316 | trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) 317 | ax.vlines(exp.feature_deciles[feature][1:], 0, 0.05, transform=trans) 318 | 319 | ax.set_xlabel(exp.feature_names[feature]) 320 | ax.set_ylabel('ALE') 321 | 322 | if legend: 323 | # if no explicit labels passed, just use target names 324 | if line_kw['label'] is None: 325 | ax.legend(lines, exp.target_names[targets]) 326 | 327 | return ax 328 | -------------------------------------------------------------------------------- /mercury/explainability/explanations/counter_factual.py: -------------------------------------------------------------------------------- 1 | import typing as TP 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import bokeh.plotting as BP 6 | import bokeh.io as BPIO 7 | 8 | from bokeh.models import ColorBar, LinearColorMapper 9 | from bokeh.layouts import layout, row, column 10 | 11 | 12 | class CounterfactualBasicExplanation(object): 13 | """ 14 | A Panallet explanation. 15 | 16 | Args: 17 | from_ (np.ndarray): 18 | Starting point. 19 | to_ (np.ndarray): 20 | Found solution. 21 | p (float): 22 | Probability of found solution. 23 | path (np.ndarray): 24 | Path followed to get to the found solution. 
25 | path_ps (np.ndarray): 26 | Probabilities of each path step. 27 | bounds (np.ndarray): 28 | Feature bounds used when exploring the probability space. 29 | explored (np.ndarray): 30 | Points explored but not visited (available only when backtracking 31 | strategy is used, empty for Simulated Annealing) 32 | explored_ps (np.ndarray): 33 | Probabilities of explored points (available only when backtracking 34 | strategy is used, empty for Simulated Annealing) 35 | labels (TP.Optional[TP.List[str]]): 36 | Labels to be used for each point dimension (used when plotting). 37 | 38 | Raises: 39 | AssertionError: if from_ shape != to_.shape 40 | AssertionError: if dim(from_) != 1 41 | AssertionError: if not 0 <= p <= 1 42 | AssertionError: if path.shape[0] != path_ps.shape[0] 43 | AssertionError: if bounds.shape[0] != from_.shape[0] 44 | AssertionError: if explored.shape[0] != explored_ps.shape[0] 45 | AssertionError: if len(labels) > 0 and len(labels) != bounds.shape[0] 46 | """ 47 | def __init__(self, 48 | from_: 'np.ndarray', 49 | to_: 'np.ndarray', 50 | p: float, 51 | path: 'np.ndarray', 52 | path_ps: 'np.ndarray', 53 | bounds: 'np.ndarray', 54 | explored: 'np.ndarray' = np.array([]), 55 | explored_ps: 'np.ndarray' = np.array([]), 56 | labels: TP.Optional[TP.List[str]] = []) -> None: 57 | # Initial/end points 58 | assert from_.shape == to_.shape and from_.ndim == 1, 'Invalid dimensions' 59 | self.from_ = from_ 60 | self.to_ = to_ 61 | 62 | # Found solution probability 63 | assert p >= 0 and p <= 1, 'Invalid probability' 64 | self.p = p 65 | 66 | # Path followed till solution is found 67 | assert path.shape[0] == path_ps.shape[0], \ 68 | 'Invalid shape for path probabilities, got {} but expected {}'.format(path.shape[0], path_ps.shape[0]) 69 | self.path = path 70 | self.path_ps = path_ps 71 | 72 | # Used bounds in the solution 73 | assert bounds.shape[0] == self.from_.shape[0], 'Invalid bounds shape' 74 | self.bounds = bounds 75 | 76 | assert explored.shape[0] == explored_ps.shape[0], \ 77 | 'Invalid shape for explored probabilities, got {} but expected {}'.format(explored.shape[0], 78 | explored_ps.shape[0]) 79 | self.explored = explored 80 | 81 | if labels is not None and len(labels) > 0: 82 | assert len(labels) == self.bounds.shape[0], 'Invalid number of labels' 83 | self.labels = labels 84 | 85 | def get_changes(self, relative=True) -> 'np.ndarray': 86 | """ 87 | Returns relative/absolute changes between initial and ending point. 88 | 89 | Args: 90 | relative (bool): 91 | True for relative changes, False for absolute changes. 92 | 93 | Returns: 94 | (np.ndarray) Relative or absolute changes for each feature. 95 | """ 96 | 97 | if relative: 98 | # Avoid divs by zero 99 | aux = self.from_.copy() 100 | aux[aux == 0.] = 1. 101 | return (self.to_.squeeze() - self.from_.squeeze()) * 100 / (np.sign(aux) * aux.squeeze()) 102 | else: 103 | return self.to_.squeeze() - self.from_.squeeze() 104 | 105 | @staticmethod 106 | def plot_butterfly(data, 107 | columns, 108 | axis, 109 | title: str = '', 110 | decimals: int = 1, 111 | positive_color: str = '#0A5FB4', 112 | negative_color: str = '#DA3851', 113 | special_color: str = '#48AE64') -> None: # pragma: no cover 114 | data_ = data.copy().squeeze() 115 | num_labels = [(' {:.' 
+ str(decimals) + 'f} ').format(float(x)) for x in data_] 116 | are_normal = np.isfinite(data_) 117 | are_special = False == are_normal 118 | if sum(are_normal): 119 | imputed_pos = max(0., np.max(data_[are_normal])) 120 | imputed_neg = min(0., np.min(data_[are_normal])) 121 | else: 122 | imputed_pos, imputed_neg = 1., -1. 123 | data_[np.logical_and(are_special, np.isnan(data_))] = 0. 124 | data_[np.logical_and(are_special, np.isposinf(data_))] = imputed_pos 125 | data_[np.logical_and(are_special, np.isneginf(data_))] = imputed_neg 126 | rec = axis.barh(range(data_.size), data_) 127 | axis.tick_params(top='off', bottom='on', left='off', right='off', labelleft='off', labelbottom='on') 128 | for i, (value, label, num_label, is_normal) in enumerate(zip(data_, columns, num_labels, are_normal)): 129 | if is_normal: 130 | color = negative_color if value < 0 else positive_color 131 | else: 132 | color = special_color 133 | num_align, label_align = ('left', 'right') if value > 0 else ('right', 'left') 134 | axis.text(value, i, num_label, ha=num_align, va='center', color=color, size='smaller') 135 | axis.text(0, i, ' {} '.format(label), ha=label_align, va='center', color=color, size='smaller') 136 | rec[i].set_color(color) 137 | axis.set_axis_off() 138 | axis.set_title(title) 139 | axis.title.set_text(title) 140 | 141 | def show(self, figsize: TP.Tuple[int, int] = (12, 6), debug: bool = False, 142 | path: TP.Optional[str] = None, backend='matplotlib') -> None: # pragma: no cover 143 | """ 144 | Creates a plot with the explanation. 145 | 146 | Args: 147 | figsize (tuple): 148 | Width and height of the figure (inches if matplotlib backend is used, 149 | pixels for bokeh backend). 150 | debug (bool): 151 | Display verbose information (debug mode). 152 | """ 153 | 154 | def _show(from_: 'np.ndarray', to_: 'np.ndarray', backend='matplotlib', 155 | path: TP.Optional[str] = None, debug: bool = False) -> None: 156 | """ Backend specific show method. 
""" 157 | 158 | if backend == 'matplotlib': 159 | # It seems we can't decouple figure from axes 160 | fig = plt.figure(figsize=figsize) 161 | 162 | # LIME-like hbars showing relative differences 163 | ax = plt.subplot2grid((2, 5), (0, 1)) 164 | CounterfactualBasicExplanation.plot_butterfly( 165 | self.get_changes(relative=False), self.labels, ax, 166 | title='Absolute delta') 167 | 168 | ax = plt.subplot2grid((2, 5), (0, 3)) 169 | CounterfactualBasicExplanation.plot_butterfly( 170 | self.get_changes(relative=True), self.labels, ax, 171 | title='Relative delta') 172 | 173 | # Probabilities 174 | ax = plt.subplot2grid((2, 5), (1, 0), colspan=5) 175 | xs = np.arange(len(self.path_ps)) 176 | ys = self.path_ps 177 | cax = ax.scatter(xs, ys, c=ys) 178 | fig.colorbar(cax) 179 | ax.plot(xs, ys, '--', c='k', linewidth=.2, alpha=.3) 180 | ax.grid() 181 | ax.set_title('Visited itinerary') 182 | ax.set_xlabel('# Iteration') 183 | ax.set_ylabel('probability') 184 | plt.tight_layout() 185 | 186 | if path is not None: 187 | plt.savefig(path, output='pdf') 188 | else: 189 | plt.show() 190 | 191 | elif backend == 'bokeh': 192 | # LIME-like hbars showing relative differences 193 | values = self.get_changes() 194 | fig1 = BP.figure(plot_width=400, plot_height=300, y_range=self.labels, 195 | x_range=(min(values), max(values)), x_axis_label='Relative change') 196 | colors = np.where(values <= 0, '#ff0000', '#00ff00') 197 | fig1.hbar(y=self.labels, height=0.75, right=values, fill_color=colors) 198 | 199 | # LIME-like hbars showing absolute differences 200 | values = self.get_changes(relative=False) 201 | fig2 = BP.figure(plot_width=400, plot_height=300, y_range=self.labels, 202 | x_range=(min(values), max(values)), x_axis_label='Absolute change') 203 | colors = np.where(values <= 0, '#ff0000', '#00ff00') 204 | fig2.hbar(y=self.labels, height=0.75, right=values, fill_color=colors, line_color=None) 205 | 206 | # Probabilities 207 | fig3 = BP.figure(plot_width=800, plot_height=200, x_axis_label='Step', y_axis_label='p') 208 | xs = np.arange(self.path_ps.size) 209 | ys = self.path_ps 210 | color_mapper = LinearColorMapper(palette='Viridis256', low=min(self.path_ps), 211 | high=max(self.path_ps)) 212 | color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0)) 213 | fig3.circle(xs, ys, size=5, fill_color={'field': 'y', 'transform': color_mapper}, 214 | fill_alpha=.3, line_color=None) 215 | fig3.add_layout(color_bar, 'left') 216 | fig3.line(xs, ys, line_dash='dashed', line_alpha=.3, line_width=.2) 217 | 218 | row1 = row([fig1, fig2]) 219 | row2 = row([fig3]) 220 | lyt = column([row1, row2]) 221 | 222 | if path is not None: 223 | BPIO.export_png(lyt, filename=path) 224 | else: 225 | BPIO.output_notebook(hide_banner=True) 226 | BP.show(lyt) 227 | 228 | else: 229 | raise ValueError('Unsupported backend') 230 | 231 | if debug: 232 | self.__verbose() 233 | 234 | _show(self.from_, self.to_, debug=debug, path=path, backend=backend) 235 | 236 | def __verbose(self): # pragma: no cover 237 | """ Internal debug information. 
""" 238 | 239 | print('Used bounds:') 240 | for i in range(self.bounds.shape[0]): 241 | label = self.labels[i] if self.labels else '' 242 | print('\t[{}] {}: [{}, {}]'.format(i, label, self.bounds[i][0], self.bounds[i][1])) 243 | print('Starting point: {}'.format(self.from_)) 244 | print('Found solution: {} with probability {}'.format(self.to_, self.p)) 245 | print('Changes:') 246 | for i in range(self.from_.shape[0]): 247 | if self.from_[i] != self.to_[i]: 248 | label = self.labels[i] if self.labels else '' 249 | print('\t[{}] {}: {} -> {}'.format(i, label, self.from_[i], self.to_[i])) 250 | 251 | 252 | class CounterfactualWithImportanceExplanation(object): 253 | """ 254 | Extended Counterfactual Explanations 255 | 256 | Args: 257 | explain_data: 258 | A pandas DataFrame containing the observations for which an explanation has to be found. 259 | counterfactuals: 260 | 261 | importances: 262 | A list of tuples containing the importance values of the features. 263 | count_diffs: 264 | A dictionary containing the count differences of the features. 265 | count_diffs_norm: 266 | A dictionary containing the normalized count differences of the features. 267 | """ 268 | 269 | def __init__( 270 | self, 271 | explain_data: pd.DataFrame, 272 | counterfactuals: TP.List[dict], 273 | importances: TP.List[TP.Tuple], 274 | count_diffs: dict, 275 | count_diffs_norm: dict 276 | ) -> None: 277 | self.explain_data = explain_data 278 | self.counterfactuals = counterfactuals 279 | self.importances = importances 280 | self.count_diffs = count_diffs 281 | self.count_diffs_norm = count_diffs_norm 282 | 283 | def interpret_explanations(self, n_important_features: int = 3) -> str: 284 | """ 285 | This method prints a report of the important features obtaiend. 286 | 287 | Args: 288 | n_important_features: 289 | The number of imporant features that will appear in the report. 290 | Defaults to 3. 291 | """ 292 | 293 | importances_str = [] 294 | for n in range(n_important_features): 295 | importance_str = [imp if isinstance(imp, str) else '{:.2f}'.format(imp) for imp in self.importances[n]] 296 | importances_str.append(importance_str) 297 | 298 | count_diffs_norm_str = [] 299 | for n in range(n_important_features): 300 | count_diffs_i = list(self.count_diffs_norm.items())[n] 301 | count_diff_norm_str = '{} {:.2f}'.format(count_diffs_i[0], count_diffs_i[1]) 302 | count_diffs_norm_str.append(count_diff_norm_str) 303 | 304 | interptretation = """The {} most important features and their importance values according to the first metric (amount features change) are: 305 | {}. 
306 | 
307 | According to the second metric (times features change), these importances are:
308 | {}""".format(
309 |             n_important_features,
310 |             ' AND '.join([' '.join(imp_str) for imp_str in importances_str]),
311 |             ' AND '.join(count_diffs_norm_str)
312 |         )
313 |         print(interpretation)
314 |         return interpretation
315 | 
--------------------------------------------------------------------------------
/mercury/explainability/explainers/_tree_splitters/cut_finder.pyx:
--------------------------------------------------------------------------------
1 | # distutils: language = c
2 | # cython: boundscheck = False
3 | # cython: wraparound = False
4 | # cython: profile = False
5 | 
6 | # file origin: https://github.com/navefr/ExKMC
7 | # copy date: 2025/02/20
8 | 
9 | import numpy as np
10 | cimport numpy as np
11 | cimport cython
12 | from cython.parallel import prange
13 | from libc.stdlib cimport malloc, free
14 | 
15 | ctypedef np.int32_t NP_INT_t
16 | ctypedef np.float64_t NP_FLOAT_t
17 | 
18 | cdef extern from "<math.h>" nogil:
19 |     const float INFINITY
20 | 
21 | 
22 | cdef extern from "<limits.h>":
23 |     const int INT_MIN
24 |     const int INT_MAX
25 | 
26 | 
27 | cdef struct IMM_Cut:
28 |     int col
29 |     float threshold
30 | 
31 | 
32 | @cython.boundscheck(False)
33 | @cython.wraparound(False)
34 | def get_min_mistakes_cut(NP_FLOAT_t[:,:] X, NP_INT_t[:] y, NP_FLOAT_t[:,:] centers, NP_INT_t[:] valid_centers, NP_INT_t[:] valid_cols, int njobs):
35 |     cdef int n = X.shape[0]
36 |     cdef int k = centers.shape[0]
37 |     cdef int d = valid_cols.shape[0]
38 |     cdef int *centers_count = <int *> malloc(k * sizeof(int))
39 |     cdef NP_FLOAT_t *cols_thresholds = <NP_FLOAT_t *> malloc(d * sizeof(NP_FLOAT_t))
40 |     cdef int *cols_mistakes = <int *> malloc(d * sizeof(int))
41 |     cdef int col
42 |     cdef int best_col = -1
43 |     cdef NP_FLOAT_t best_threshold
44 |     cdef int min_mistakes = INT_MAX
45 | 
46 |     # Count the number of data points for each center.
47 |     # This information will be helpful for fast mistakes calculation, once a threshold passes a center.
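    # For example, with k = 2 centers and labels y = [0, 1, 0], the loop below yields
    # centers_count = [2, 1]. Together with left_centers_count (filled while sweeping the
    # threshold), it lets the main loop update the mistake count in O(1) whenever the
    # threshold crosses a center: mistakes += centers_count[c] - 2 * left_centers_count[c].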
48 |     for i in range(k):
49 |         centers_count[i] = 0
50 |     for i in range(n):
51 |         centers_count[y[i]] += 1
52 | 
53 |     if njobs is None or njobs <= 1:
54 |         # Iterate over valid coordinates
55 |         for col in range(d):
56 |             if valid_cols[col] == 1:
57 |                 update_col_min_mistakes_cut(X, y, centers, valid_centers, centers_count, cols_thresholds, cols_mistakes, col, n, d, k)
58 |     else:
59 |         # Iterate over valid coordinates
60 |         for col in prange(d, nogil=True, num_threads=njobs):
61 |             if valid_cols[col] == 1:
62 |                 update_col_min_mistakes_cut(X, y, centers, valid_centers, centers_count, cols_thresholds, cols_mistakes, col, n, d, k)
63 | 
64 |     for col in range(d):
65 |         # This is a valid column
66 |         if valid_cols[col] == 1:
67 |             # We found a valid split
68 |             if cols_mistakes[col] != -1:
69 |                 # This is a better cut
70 |                 if cols_mistakes[col] < min_mistakes:
71 |                     best_col = col
72 |                     min_mistakes = cols_mistakes[col]
73 | 
74 |     if best_col != -1:
75 |         best_threshold = cols_thresholds[best_col]
76 | 
77 |     free(cols_thresholds)
78 |     free(cols_mistakes)
79 |     free(centers_count)
80 | 
81 |     if best_col == -1:
82 |         return None
83 |     else:
84 |         return IMM_Cut(best_col, best_threshold)
85 | 
86 | 
87 | @cython.boundscheck(False)
88 | @cython.wraparound(False)
89 | cdef void update_col_min_mistakes_cut(NP_FLOAT_t[:,:] X, NP_INT_t[:] y, NP_FLOAT_t[:,:] centers, NP_INT_t[:] valid_centers, int* centers_count, NP_FLOAT_t *cols_thresholds, int *cols_mistakes, int col, int n, int d, int k) nogil:
90 |     cdef int i
91 |     cdef int ix
92 |     cdef int ic
93 |     cdef int mistakes
94 |     cdef NP_FLOAT_t prev_threshold
95 |     cdef NP_FLOAT_t threshold
96 |     cdef NP_FLOAT_t max_val
97 |     cdef np.int64_t[:] data_order
98 |     cdef np.int64_t[:] centers_order
99 |     cdef int *left_centers_count = <int *> malloc(k * sizeof(int))
100 |     cdef int curr_center_idx
101 |     cdef bint valid_found = 0
102 |     cdef bint is_center_threshold
103 |     cdef int min_mistakes = INT_MAX
104 |     cdef NP_FLOAT_t best_threshold
105 | 
106 |     # Sort data points and centers
107 |     with gil:
108 |         data_order = np.asarray(X[:,col]).argsort()
109 |         centers_order = np.asarray(centers[:, col]).argsort()
110 | 
111 |     # Find maximal value of valid centers. Possible threshold must be strictly smaller than that.
112 |     max_val = -INFINITY
113 |     for i in range(k):
114 |         if valid_centers[i] == 1:
115 |             if centers[i, col] > max_val:
116 |                 max_val = centers[i, col]
117 | 
118 |     # For each center, count number of data points associated with it that are smaller than the current threshold.
119 |     # This information will be helpful for fast mistakes calculation, once a threshold passes a center.
120 |     for i in range(k):
121 |         left_centers_count[i] = 0
122 | 
123 |     # Initialize pointers to data points and centers sorted lists.
124 |     ix = 0
125 |     ic = 0
126 | 
127 |     # Initialize number of mistakes.
128 |     mistakes = 0
129 | 
130 |     # Advance center index to the first valid one
131 |     while valid_centers[centers_order[ic]] == 0:
132 |         ic += 1
133 | 
134 |     # The first threshold is the value of the first valid center.
135 |     threshold = centers[centers_order[ic], col]
136 | 
137 |     # The first threshold is a center.
138 |     # All data points that are smaller or equal to the threshold will be accumulated prior to the main loop.
139 |     is_center_threshold = 1
140 | 
141 |     # Advance data point index to the first valid center (which is the first valid threshold).
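    # (A data point counts as a mistake here when its own center lies at or beyond the
    #  threshold, i.e. the candidate cut would separate the point from its center.)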
142 | # Count mistakes of points smaller than the first valid threshold 143 | while ix < n and X[data_order[ix], col] <= threshold: 144 | curr_center_idx = y[data_order[ix]] # Center of the current data point 145 | left_centers_count[curr_center_idx] += 1 146 | if centers[curr_center_idx, col] >= threshold: 147 | mistakes += 1 148 | ix += 1 149 | 150 | # Corner case. 151 | # Exactly n - 1 data points are on the left of the first center - this is a valid cut, but we won't enter the main loop. 152 | if ix == n - 1: 153 | # Recalculate the number of mistakes. 154 | # In this corner case all points except one are to the left of the current threshold. 155 | mistakes = 0 156 | ic = 0 157 | # Go over all valid centers. 158 | while ic < k: 159 | if valid_centers[ic] != 0: 160 | # If a center is to the right of the current threshold, then all of its point are considered as mistakes. 161 | # (Perhaps except to the last point that will be corrected later). 162 | if centers[ic, col] > threshold: 163 | mistakes += centers_count[ic] 164 | ic += 1 165 | # Find the center of the single point that is on the right of the threshold. 166 | # If the center is also to the right of the current threshold, then remove one mistake. 167 | ic = y[data_order[n - 1]] 168 | if centers[ic, col] > threshold: 169 | mistakes -= 1 170 | # Update best cut. 171 | if mistakes < min_mistakes: 172 | valid_found = 1 173 | best_col = col 174 | best_threshold = threshold 175 | min_mistakes = mistakes 176 | 177 | # Main loop 178 | while ix < n - 1 and ic < k: 179 | 180 | # If threshold reached to the last valid center, the loop should end. 181 | if threshold >= max_val: 182 | break 183 | 184 | # In case this threshold is associated to a data point 185 | if is_center_threshold == 0: 186 | # Find data point center 187 | curr_center_idx = y[data_order[ix]] 188 | # Increase the count of points smaller than the threshold associated with the center 189 | left_centers_count[curr_center_idx] += 1 190 | 191 | # Update the mistakes count 192 | if centers[curr_center_idx, col] >= threshold: 193 | mistakes += 1 194 | elif centers[curr_center_idx, col] < threshold: 195 | mistakes -= 1 196 | 197 | # Move to the next data point index 198 | ix += 1 199 | 200 | # In case this threshold is associated to a center 201 | else: 202 | # Update mistakes count 203 | # left points are no longer mistakes 204 | # right points (which equal to total points - left points) are now mistakes 205 | mistakes += centers_count[centers_order[ic]] - 2 * left_centers_count[centers_order[ic]] 206 | 207 | # Move to the next center index 208 | ic += 1 209 | while valid_centers[centers_order[ic]] == 0 and ic < k: 210 | ic += 1 211 | 212 | prev_threshold = threshold 213 | 214 | # Find next threshold 215 | # in case of equality, data points arrive before centers in order to correctly find left count 216 | if X[data_order[ix], col] <= centers[centers_order[ic], col]: 217 | threshold = X[data_order[ix], col] 218 | is_center_threshold = 0 219 | else: 220 | threshold = centers[centers_order[ic], col] 221 | is_center_threshold = 1 222 | 223 | # Update best cut (only if the next threshold is not equal to the current one) 224 | if (prev_threshold != threshold) and (mistakes < min_mistakes): 225 | valid_found = 1 226 | best_threshold = prev_threshold 227 | min_mistakes = mistakes 228 | 229 | free(left_centers_count) 230 | 231 | if valid_found == 1: 232 | cols_thresholds[col] = best_threshold 233 | cols_mistakes[col] = min_mistakes 234 | else: 235 | cols_thresholds[col] = -1.0 236 | 
cols_mistakes[col] = -1
237 | 
238 | 
239 | cdef struct Surrogate_Cut:
240 |     int col
241 |     float threshold
242 |     NP_FLOAT_t cost
243 |     int center_left
244 |     int center_right
245 | 
246 | 
247 | @cython.boundscheck(False)
248 | @cython.wraparound(False)
249 | def get_min_surrogate_cut(NP_FLOAT_t[:,:] X, NP_FLOAT_t[:,:] X_center_dot, NP_FLOAT_t[:] X_sum_all_center_dot, NP_FLOAT_t[:] centers_norm_sqr, int njobs):
250 |     cdef int n = X.shape[0]
251 |     cdef int k = X_center_dot.shape[1]
252 |     cdef int d = X.shape[1]
253 |     cdef int col
254 |     cdef NP_FLOAT_t *thresholds = <NP_FLOAT_t *> malloc(d * sizeof(NP_FLOAT_t))
255 |     cdef NP_FLOAT_t *costs = <NP_FLOAT_t *> malloc(d * sizeof(NP_FLOAT_t))
256 |     cdef int *left_centers = <int *> malloc(d * sizeof(int))
257 |     cdef int *right_centers = <int *> malloc(d * sizeof(int))
258 |     cdef (NP_FLOAT_t, NP_FLOAT_t, int, int) cut
259 |     cdef int best_col = -1
260 |     cdef NP_FLOAT_t best_cost = INFINITY
261 |     cdef NP_FLOAT_t best_threshold
262 |     cdef int best_center_left
263 |     cdef int best_center_right
264 | 
265 | 
266 |     if njobs is None or njobs <= 1:
267 |         # Iterate over valid coordinates
268 |         for col in range(d):
269 |             update_col_surrogate_cut(X, X_center_dot, centers_norm_sqr, X_sum_all_center_dot, n, d, k, col, thresholds, costs, left_centers, right_centers)
270 |     else:
271 |         # Iterate over valid coordinates
272 |         for col in prange(d, nogil=True, num_threads=njobs):
273 |             update_col_surrogate_cut(X, X_center_dot, centers_norm_sqr, X_sum_all_center_dot, n, d, k, col, thresholds, costs, left_centers, right_centers)
274 | 
275 |     for col in range(d):
276 |         # This is a valid cut
277 |         if left_centers[col] != -1:
278 |             # This is a better cut
279 |             if costs[col] < best_cost:
280 |                 best_col = col
281 |                 best_cost = costs[col]
282 | 
283 | 
284 |     if best_col != -1:
285 |         best_threshold = thresholds[best_col]
286 |         best_center_left = left_centers[best_col]
287 |         best_center_right = right_centers[best_col]
288 | 
289 |     free(thresholds)
290 |     free(costs)
291 |     free(left_centers)
292 |     free(right_centers)
293 | 
294 |     if best_col == -1:
295 |         return None
296 |     else:
297 |         return Surrogate_Cut(best_col, best_threshold, best_cost, best_center_left, best_center_right)
298 | 
299 | 
300 | @cython.boundscheck(False)
301 | @cython.wraparound(False)
302 | cdef void update_col_surrogate_cut(NP_FLOAT_t[:,:] X, NP_FLOAT_t[:,:] X_center_dot, NP_FLOAT_t[:] centers_norm_sqr, NP_FLOAT_t[:] X_sum_all_center_dot, int n, int d, int k, int col, NP_FLOAT_t *thresholds, NP_FLOAT_t *costs, int *left_centers, int *right_centers) nogil:
303 |     cdef int i
304 |     cdef int ix
305 |     cdef int ic
306 |     cdef int n_left
307 |     cdef int n_right
308 |     cdef NP_FLOAT_t[:] curr_X_center_dot
309 |     cdef NP_FLOAT_t *X_sum_left_center_dot = <NP_FLOAT_t *> malloc(k * sizeof(NP_FLOAT_t))
310 |     cdef NP_FLOAT_t *X_sum_right_center_dot = <NP_FLOAT_t *> malloc(k * sizeof(NP_FLOAT_t))
311 |     cdef NP_FLOAT_t cost
312 |     cdef NP_FLOAT_t left_cost
313 |     cdef NP_FLOAT_t right_cost
314 |     cdef int left_center
315 |     cdef int right_center
316 |     cdef NP_FLOAT_t cur_total_cost
317 |     cdef NP_FLOAT_t prev_threshold
318 |     cdef NP_FLOAT_t threshold
319 |     cdef np.int64_t[:] data_order
320 |     cdef bint valid_found = 0
321 |     cdef NP_FLOAT_t best_cost = INFINITY
322 |     cdef NP_FLOAT_t best_threshold
323 |     cdef int best_center_left
324 |     cdef int best_center_right
325 | 
326 |     # Sort data points
327 |     with gil:
328 |         data_order = np.asarray(X[:,col]).argsort()
329 | 
330 |     ix = 0
331 | 
332 |     curr_X_center_dot = X_center_dot[data_order[0]]
333 |     for i in range(k):
334 |         X_sum_left_center_dot[i] = curr_X_center_dot[i]
335 | 
X_sum_right_center_dot[i] = X_sum_all_center_dot[i] 336 | for i in range(k): 337 | X_sum_right_center_dot[i] -= curr_X_center_dot[i] 338 | n_left = 1 339 | n_right = n - 1 340 | 341 | threshold = X[data_order[0], col] 342 | 343 | while ix < n - 1: 344 | 345 | ix += 1 346 | 347 | prev_threshold = threshold 348 | threshold = X[data_order[ix], col] 349 | 350 | if prev_threshold != threshold: 351 | 352 | left_cost = INFINITY 353 | right_cost = INFINITY 354 | for ic in range(k): 355 | cost = n_left * centers_norm_sqr[ic] - 2 * X_sum_left_center_dot[ic] 356 | if cost < left_cost: 357 | left_cost = cost 358 | left_center = ic 359 | cost = n_right * centers_norm_sqr[ic] - 2 * X_sum_right_center_dot[ic] 360 | if cost < right_cost: 361 | right_cost = cost 362 | right_center = ic 363 | 364 | cur_total_cost = left_cost + right_cost 365 | 366 | # Add dot product of current vector and each center to the left 367 | # Subtract dot product of current vector and each center to the right 368 | curr_X_center_dot = X_center_dot[data_order[ix]] 369 | for i in range(k): 370 | X_sum_left_center_dot[i] = X_sum_left_center_dot[i] + curr_X_center_dot[i] 371 | X_sum_right_center_dot[i] = X_sum_right_center_dot[i] - curr_X_center_dot[i] 372 | 373 | n_left += 1 374 | n_right -= 1 375 | 376 | if (prev_threshold != threshold) and (cur_total_cost < best_cost): 377 | valid_found = 1 378 | best_threshold = prev_threshold 379 | best_cost = cur_total_cost 380 | best_center_left = left_center 381 | best_center_right = right_center 382 | 383 | 384 | free(X_sum_left_center_dot) 385 | free(X_sum_right_center_dot) 386 | 387 | if valid_found == 1: 388 | thresholds[col] = best_threshold 389 | costs[col] = best_cost 390 | left_centers[col] = best_center_left 391 | right_centers[col] = best_center_right 392 | else: 393 | thresholds[col] = -1.0 394 | costs[col] = -1.0 395 | left_centers[col] = -1 396 | right_centers[col] = -1 397 | -------------------------------------------------------------------------------- /mercury/explainability/explainers/cf_strategies.py: -------------------------------------------------------------------------------- 1 | """ 2 | Optimization strategies for simple counterfactual explanations. 3 | """ 4 | 5 | import random 6 | import typing as TP 7 | from abc import ABCMeta, abstractmethod 8 | import queue 9 | import sys 10 | 11 | from simanneal import Annealer 12 | import numpy as np 13 | import pandas as pd 14 | 15 | 16 | class Strategy(metaclass=ABCMeta): 17 | """ Base class for explanation strategies. """ 18 | 19 | def __init__(self, 20 | state: 'np.ndarray', 21 | bounds: 'np.ndarray', 22 | step: 'np.ndarray', 23 | fn: TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], 'np.ndarray'], 24 | class_idx: int = 1, 25 | threshold: float = 0., 26 | kernel: TP.Optional['np.ndarray'] = None, 27 | report: bool = False) -> None: 28 | 29 | # Starting point 30 | self.state = state 31 | # Dimension bounds 32 | assert bounds.shape[0] == state.shape[0] 33 | self.bounds = bounds 34 | assert step.shape[0] == bounds.shape[0] 35 | self.step = step 36 | # Classifier fn 37 | self.fn = fn 38 | # Class 39 | self.class_idx = class_idx 40 | # Explored points 41 | self.explored = [self.state] 42 | # Probability threshold 43 | assert threshold >= 0. and threshold <= 1. 44 | self.threshold = threshold 45 | # Energies 46 | self.energies = [self.fn(self.state.reshape(1, -1))[:, self.class_idx][0]] 47 | # Kernel 48 | if kernel is None: 49 | kernel = np.ones(self.state.shape[0]) 50 | # minimization or maximization problem? 
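        # If the classifier already scores the starting point above the threshold, the
        # search must push the probability down (self.min = True); otherwise it must push it up.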
51 | cur_prob = self.fn(self.state.reshape(1, -1))[0, self.class_idx] 52 | if cur_prob > self.threshold: 53 | self.min = True 54 | else: 55 | self.min = False 56 | 57 | self.kernel = kernel 58 | self.report = report 59 | 60 | @abstractmethod 61 | def run(self, *args, **kwargs) -> TP.Tuple['np.ndarray', float, 'np.ndarray', TP.Optional['np.ndarray']]: 62 | pass 63 | 64 | @abstractmethod 65 | def update(self, *args, **kwargs) -> None: 66 | pass 67 | 68 | 69 | class SimulatedAnnealing(Strategy, Annealer): 70 | """ 71 | Simulated Annealing strategy. 72 | 73 | Args: 74 | state (np.ndarray): 75 | Initial state (initial starting point). 76 | bounds (np.ndarray): 77 | Bounds to be used when moving around the probability space defined by `fn`. 78 | step (np.ndarray): 79 | Step size values to be used when moving around the probability space defined by `fn`. 80 | Lower values may take more time/steps to find a solution while too large values may 81 | make impossible to find a solution. 82 | fn (TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], 'np.ndarray']): 83 | Classifier `predict_proba`- like function. 84 | class_idx (int): 85 | Class to be explained (e.g. 1 for binary classifiers). Default value is 1. 86 | threshold (float): 87 | Probability to be achieved (if path is found). Default value is 0.0. 88 | kernel (TP.Optional['np.ndarray']): 89 | Used to penalize certain dimensions when trying to move around the probability 90 | space (some dimensions may be more difficult to explain, hence don't move along them). 91 | report (bool): 92 | Whether to display probability updates during space search. 93 | """ 94 | 95 | def __init__(self, 96 | state: 'np.ndarray', 97 | bounds: 'np.ndarray', 98 | step: 'np.ndarray', 99 | fn: TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], 'np.ndarray'], 100 | class_idx: int = 1, 101 | threshold: float = 0., 102 | kernel: TP.Optional['np.ndarray'] = None, 103 | report: bool = False) -> None: 104 | Strategy.__init__(self, state, bounds, step, fn, class_idx, threshold, kernel, report) 105 | # Annealer's __init__ uses signal, see Annealers's init: 106 | # https://github.com/perrygeo/simanneal/blob/b2576eb75d88f8b8c91d959a44dd708706bb108e/simanneal/anneal.py#L52 107 | # This may break execution (depending on where simanneal runs, e.g. threads) 108 | # Overload it! 109 | self.state = Annealer.copy_state(self, state) 110 | 111 | def move(self) -> None: 112 | """ Move step in Simulated Annealing. """ 113 | self.state = np.random.uniform(-0.1, 0.1, size=self.state.shape[0]) * self.kernel * self.step + self.state 114 | for i, bound in enumerate(self.bounds): 115 | if self.kernel[i] != 0: 116 | if self.state[i] < bound[0]: 117 | self.state[i] = bound[0] 118 | elif self.state[i] > bound[1]: 119 | self.state[i] = bound[1] 120 | self.explored.append(self.state) 121 | 122 | def energy(self) -> None: 123 | """ Energy step in Simulated Annealing. """ 124 | p = self.fn(np.array(self.state.reshape(1, -1)))[0, self.class_idx] 125 | # energy check 126 | if self.min: 127 | if self.threshold is not None and p < self.threshold: 128 | p = self.threshold 129 | else: 130 | if self.threshold is not None and p > self.threshold: 131 | p = self.threshold 132 | self.energies.append(p) 133 | value = p if self.min else -p 134 | return value 135 | 136 | def best_solution(self, n: int = 3) -> 'np.ndarray': 137 | """ Returns the n best solutions found during the Simulated Annealing. 138 | 139 | Args: 140 | n (int): Number of solutions to be retrieved. Default value is 3. 
141 | """ 142 | 143 | ps = self.fn(np.array(self.explored))[:, self.class_idx] 144 | sorted_ps_idxs = np.argsort(ps) 145 | if not self.min: 146 | sorted_ps_idxs = sorted_ps_idxs[::-1] 147 | return np.array(self.explored)[sorted_ps_idxs[:n]] 148 | 149 | def run(self, *args, **kwargs) -> TP.Tuple['np.ndarray', float, 'np.ndarray', TP.Optional['np.ndarray']]: 150 | """ Kick off Simulated Annealing. 151 | 152 | Args: 153 | **kwargs (dict): 154 | Simulated Annealing specific arguments 155 | - tmin: min temperature 156 | - tmax: max temperature 157 | - steps: number of iterations 158 | 159 | Returns: 160 | result (TP.Tuple['np.ndarray', float, 'np.ndarray', TP.Optional['np.ndarray']]): 161 | Tuple containing found solution, probability achieved, points visited and 162 | corresponding energies (i.e. probabilities). 163 | """ 164 | 165 | if 'tmin' in kwargs: 166 | self.Tmin = kwargs['tmin'] 167 | if 'tmax' in kwargs: 168 | self.Tmax = kwargs['tmax'] 169 | if 'steps' in kwargs: 170 | self.steps = kwargs['steps'] 171 | self.updates = self.steps 172 | # self.updates = 100 173 | sol, p = self.anneal() # type: TP.Tuple['np.ndarray', float] 174 | return sol, p, np.array(self.explored), np.array(self.energies) 175 | 176 | def update(self, *args, **kwargs) -> None: 177 | # TODO: By default use default's simanneal update method if necessary 178 | if self.report: 179 | Annealer.update(self, *args, **kwargs) 180 | 181 | 182 | class MyPriorityQueue(queue.PriorityQueue): 183 | 184 | def __init__(self): 185 | super().__init__() 186 | 187 | def get_same_priority(self, limit=None, block=False, shuffle_limit=False): 188 | e = self.get(block=block) 189 | elements = [e] 190 | prio = e[0] 191 | while self.qsize(): 192 | e = self.get(block=block) 193 | if e[0] != prio: 194 | self.put(e) 195 | break 196 | elements.append(e) 197 | # if we already reached the limit and no shuffle, we stop here 198 | if (limit is not None) and (not shuffle_limit) and (len(elements) >= limit): 199 | return elements 200 | if limit is not None and len(elements) > limit: 201 | random.shuffle(elements) 202 | for e in elements[limit:]: 203 | self.put(e) 204 | return elements[:limit] 205 | return elements 206 | 207 | 208 | class Backtracking(Strategy): 209 | """ 210 | Backtracking strategy. 211 | 212 | Args: 213 | state ('np.ndarray'): Initial state (initial starting point). 214 | bounds (np.ndarray): Bounds to be used when moving around the probability space defined by `fn`. 215 | step (np.ndarray): Step size values to be used when moving around the probability space defined by `fn`. 216 | Lower values may take more time/steps to find a solution while too large values may 217 | make impossible to find a solution. 218 | fn (TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], 'np.ndarray']): Classifier `predict_proba`- like function. 219 | class_idx (int): Class to be explained (e.g. 1 for binary classifiers). Default value is 1. 220 | threshold (float): Probability to be achieved (if path is found). Default value is 0.0. 221 | kernel (TP.Optional['np.ndarray']): Used to penalize certain dimensions when trying to move around the probability 222 | space (some dimensions may be more difficult to explain, hence don't move along them). 223 | report (bool): Whether to display probability updates during space search. 224 | keep_explored_points (bool): Whether to keep the points that the algorithm explores. Setting it to False will 225 | decrease the computation time and memory usage in some cases. Default value is True. 
226 | 227 | """ 228 | 229 | def __init__(self, 230 | state: 'np.ndarray', 231 | bounds: 'np.ndarray', 232 | step: 'np.ndarray', 233 | fn: TP.Callable[[TP.Union['np.ndarray', pd.DataFrame]], 'np.ndarray'], 234 | class_idx: int = 1, 235 | threshold: float = 0., 236 | kernel: TP.Optional['np.ndarray'] = None, 237 | report: bool = False, 238 | keep_explored_points: bool = True, 239 | ) -> None: 240 | super().__init__(state, bounds, step, fn, class_idx, threshold, kernel, report) 241 | self.keep_explored_points = keep_explored_points 242 | 243 | def run(self, *args, **kwargs) -> TP.Tuple['np.ndarray', float, 'np.ndarray', TP.Optional['np.ndarray']]: 244 | """ Kick off the backtracking. 245 | 246 | Args: 247 | **kwargs (dict): Backtracking specific arguments 248 | - max_iter: max number of iterations 249 | 250 | Returns: 251 | result (TP.Tuple['np.ndarray', float, 'np.ndarray', TP.Optional['np.ndarray']]): 252 | Tuple containing found solution, probability achieved, points visited and 253 | corresponding energies (i.e. probabilities). 254 | """ 255 | 256 | # Backtracking specific arguments 257 | if 'max_iter' in kwargs: 258 | max_iter = kwargs['max_iter'] 259 | else: 260 | max_iter = 0 261 | if 'limit' in kwargs: 262 | limit = kwargs['limit'] 263 | else: 264 | limit = None 265 | if 'shuffle_limit' in kwargs: 266 | shuffle_limit = kwargs['shuffle_limit'] 267 | else: 268 | shuffle_limit = False 269 | 270 | # Backtracking starting point. 271 | point = self.state.reshape(1, -1) 272 | curr_prob = self.fn(point)[0, self.class_idx] 273 | best_point = point.copy() 274 | best_prob = curr_prob 275 | 276 | # Define condition (>= or <=) from current point and given threshold. 277 | # and the priority modifier (ascending or descending order) using prio_mod. 278 | if self.threshold > curr_prob: 279 | condition = float.__ge__ 280 | prio_mod = lambda x: -x 281 | else: 282 | condition = float.__le__ 283 | prio_mod = lambda x: x 284 | self.kernel = 1.1 - self.kernel 285 | 286 | visited = np.hstack([point, [[curr_prob]]]) 287 | explored = np.hstack([point, [[curr_prob]]]) if self.keep_explored_points else np.array([]) 288 | 289 | q = MyPriorityQueue() # type: ignore 290 | # Avoid duplicates and querying known points. 291 | cache = {str(point.tolist()): curr_prob} 292 | points = [point] 293 | counter = 0 294 | while condition(self.threshold, curr_prob) and (not max_iter or counter < max_iter): 295 | # Explore neighbours. 296 | neighbours = np.zeros((0, self.step.size)) 297 | neighbours_l = [] 298 | neighbours_kernel = [] 299 | local_cache = set() 300 | for p_ in points: 301 | for i, delta in enumerate(self.step): 302 | # In every direction. 303 | for sign in (-1, 1): 304 | aux = p_.copy() 305 | new_aux_i = aux[0, i] + sign * delta 306 | if new_aux_i >= self.bounds[i][0] and new_aux_i <= self.bounds[i][1]: 307 | aux[0, i] = new_aux_i 308 | aux_l = aux.tolist() 309 | str_aux_l = str(aux_l) 310 | if str_aux_l not in cache and str_aux_l not in local_cache: 311 | # If point is not in cache (visited) -> enqueue. 312 | neighbours = np.vstack([neighbours, aux]) 313 | neighbours_l.append(aux_l) 314 | neighbours_kernel.append(self.kernel[i]) 315 | local_cache.add(str_aux_l) 316 | if len(neighbours_l): 317 | assert neighbours.shape[0] == len(neighbours_kernel) == len(neighbours_l), \ 318 | 'Number of neighbours should match.' 
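                # Score every new neighbour in one batched classifier call, then enqueue each
                # one with a priority equal to its probability weighted by the kernel penalty
                # of the dimension that was changed.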
319 | probs = self.fn(neighbours)[:, self.class_idx] 320 | for n_idx, kernel, pt in zip( 321 | range(probs.shape[0]), neighbours_kernel, neighbours_l): 322 | prob = float(probs[n_idx]) 323 | prio = prob * kernel 324 | q.put((prio_mod(prio), prob, pt)) 325 | if self.keep_explored_points: 326 | explored = np.vstack( 327 | [explored, 328 | np.hstack([neighbours[n_idx].reshape(1, -1), [[prob]]])]) 329 | cache[str(pt)] = prob 330 | try: 331 | elements = q.get_same_priority(limit=limit, block=False, shuffle_limit=shuffle_limit) 332 | curr_prob = max([x[1] for x in elements]) 333 | points = [np.array(x[2]) for x in elements] 334 | except queue.Empty: 335 | self.visited = visited 336 | self.explored = explored 337 | return best_point[0], best_prob, self.visited, self.explored 338 | if condition(curr_prob, best_prob): 339 | best_prob = curr_prob 340 | best_point = points[0] 341 | for point in points: 342 | visited = np.vstack([visited, np.hstack([point, 343 | [[curr_prob]]])]) 344 | 345 | if self.report: 346 | 347 | iter_string = f"{counter}/{max_iter}" 348 | update_string = f"\r{iter_string}\t\t{round(best_prob, 2)}" 349 | if counter == 0: 350 | print('\rIteration\tBest Prob', file=sys.stderr) 351 | print(update_string, file=sys.stderr, end="") 352 | else: 353 | print(update_string, file=sys.stderr, end="") 354 | sys.stderr.flush() 355 | 356 | counter += 1 if max_iter else 0 357 | 358 | self.explored = explored 359 | self.visited = visited 360 | return best_point[0], best_prob, self.visited, self.explored 361 | 362 | def update(self, *args, **kwargs): 363 | if self.report: 364 | # TODO: An implementation for updates should be provided 365 | pass 366 | --------------------------------------------------------------------------------
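A minimal usage sketch of the strategies and explanation classes above, assuming only a toy predict_proba stand-in and hand-picked bounds, step sizes, labels and annealing temperatures (none of these values come from the package sources):

import numpy as np

from mercury.explainability.explainers.cf_strategies import SimulatedAnnealing
from mercury.explainability.explanations.counter_factual import CounterfactualBasicExplanation


def predict_proba(X: np.ndarray) -> np.ndarray:
    """Toy stand-in for a binary classifier's predict_proba (logistic score on the feature sum)."""
    p1 = 1.0 / (1.0 + np.exp(-(X.sum(axis=1) - 4.0)))
    return np.column_stack([1.0 - p1, p1])


x0 = np.array([0.2, 1.5, 3.0])                             # observation to explain (made up)
bounds = np.array([[0.0, 1.0], [0.0, 5.0], [0.0, 10.0]])   # per-feature search box (made up)
step = np.array([0.05, 0.1, 0.2])                          # per-feature step sizes (made up)
labels = ['feat_a', 'feat_b', 'feat_c']

strategy = SimulatedAnnealing(x0, bounds, step, predict_proba, class_idx=1, threshold=0.8)
solution, energy, visited, energies = strategy.run(tmin=1e-3, tmax=1.0, steps=500)

# anneal() returns a signed energy (negative when maximizing); its magnitude is the probability.
# The strategy may record one more energy than visited points, so align both before wrapping.
n = min(visited.shape[0], energies.shape[0])
explanation = CounterfactualBasicExplanation(
    from_=x0, to_=solution, p=float(abs(energy)),
    path=visited[:n], path_ps=energies[:n],
    bounds=bounds, labels=labels)

print(explanation.get_changes(relative=False))   # absolute per-feature deltas
# explanation.show(backend='matplotlib')         # butterfly charts plus the visited-path plot

Backtracking exposes the same constructor and run() interface; swapping it in only changes the run() keyword arguments (max_iter, limit and shuffle_limit instead of the annealing temperatures).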