├── src └── interpretability │ ├── models │ ├── __init__.py │ ├── base.py │ ├── linear_regression.py │ ├── multilayer_perceptron.py │ └── recurrent_neural_net.py │ ├── interpretability_models │ ├── utils │ │ ├── __init__.py │ │ ├── simplex_schedulers.py │ │ ├── data.py │ │ └── io.py │ ├── __init__.py │ ├── base.py │ ├── shap_explainer.py │ ├── dynamask_explainer.py │ └── symbolic_pursuit_explainer.py │ ├── __init__.py │ ├── utils │ └── pip.py │ └── exceptions │ └── exceptions.py ├── docs ├── contributing.rst ├── readme.rst ├── authors.rst ├── changelog.rst ├── license.rst ├── requirements.txt ├── Makefile ├── _static │ └── .gitignore ├── index.rst └── conf.py ├── AUTHORS.rst ├── images ├── user_inter_face_upload.png ├── Short_intro_video_thumbnail.png ├── interpretability_suite_image.png └── Interpretability_method_flow_diagram.svg ├── CHANGELOG.rst ├── tests └── conftest.py ├── pyproject.toml ├── .readthedocs.yml ├── requirements.txt ├── .coveragerc ├── setup.py ├── LICENSE.txt ├── .gitignore ├── tox.ini ├── setup.cfg ├── Notebooks ├── Tutorial_05_implement_symbolic_pursuit.ipynb └── Tutorial_02_implement_simplex_time_series.ipynb ├── CONTRIBUTING.rst └── README.md /src/interpretability/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. _readme: 2 | .. include:: ../README.md 3 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changes: 2 | .. include:: ../CHANGELOG.rst 3 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * robsdavis 6 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. 
include:: ../LICENSE.txt 8 | -------------------------------------------------------------------------------- /images/user_inter_face_upload.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanderschaarlab/Interpretability/HEAD/images/user_inter_face_upload.png -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/__init__.py: -------------------------------------------------------------------------------- 1 | # Import explainers to hoist the explainer classes to the top-level explainers package 2 | -------------------------------------------------------------------------------- /images/Short_intro_video_thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanderschaarlab/Interpretability/HEAD/images/Short_intro_video_thumbnail.png -------------------------------------------------------------------------------- /images/interpretability_suite_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanderschaarlab/Interpretability/HEAD/images/interpretability_suite_image.png -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1 6 | =========== 7 | 8 | - The initial version of interpretability 9 | - Implements a Python interface for Simplex, Dynamask, SHAP, and Symbolic Pursuit 10 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dummy conftest.py for interpretability. 3 | 4 | If you don't know what this is for, just leave it empty. 5 | Read more about conftest.py under: 6 | - https://docs.pytest.org/en/stable/fixture.html 7 | - https://docs.pytest.org/en/stable/writing_plugins.html 8 | """ 9 | 10 | # import pytest 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD!
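# (No static version is declared here: setuptools_scm, configured below,
# derives the package version from git metadata. For example, a tag "v0.1"
# builds as version "0.1", and untagged commits get a dev suffix based on
# the distance to the last tag under the "no-guess-dev" scheme.)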
3 | requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools_scm] 7 | # For smarter version schemes and other configuration options, 8 | # check out https://github.com/pypa/setuptools_scm 9 | version_scheme = "no-guess-dev" 10 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/utils/simplex_schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class Scheduler: 5 | def __init__(self, n_epoch): 6 | self.n_epoch = n_epoch 7 | 8 | 9 | class ExponentialScheduler(Scheduler): 10 | def __init__(self, x_init: float, x_final: float, n_epoch: int): 11 | Scheduler.__init__(self, n_epoch) 12 | self.step_factor = math.exp(math.log(x_final / x_init) / n_epoch) 13 | 14 | def step(self, x): 15 | return x * self.step_factor 16 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | # Build documentation with MkDocs 12 | #mkdocs: 13 | # configuration: mkdocs.yml 14 | 15 | # Optionally build your docs in additional formats such as PDF 16 | formats: 17 | - pdf 18 | 19 | python: 20 | version: 3.8 21 | install: 22 | - requirements: docs/requirements.txt 23 | - {path: ., method: pip} 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | beautifulsoup4==4.11.1 2 | dill==0.3.6 3 | dynamask==0.0.1 4 | # invase==0.0.3 5 | # lfxai==0.1.1 6 | matplotlib==3.5.3 7 | matplotlib-inline==0.1.6 8 | numpy==1.22.4 9 | pandas==1.3.5 10 | pydot==1.4.2 11 | PyScaffold==4.3.1 12 | scikit-learn==0.24.2 13 | scipy==1.8.1 14 | seaborn==0.11.2 15 | shap==0.41.0 16 | simplexai==0.0.2 17 | sktime==0.13.2 18 | symbolic-pursuit==0.0.1 19 | sympy==1.6.2 20 | tensorflow==2.10.0 21 | termcolor==2.0.1 22 | torch==1.12.1 23 | torchvision==0.13.1 24 | tqdm==4.64.1 25 | typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work 26 | xgboost==1.6.2 -------------------------------------------------------------------------------- /src/interpretability/models/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | 4 | 5 | class BlackBox(torch.nn.Module): 6 | @abc.abstractmethod 7 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 8 | """ 9 | Evaluates the latent representation for the example x 10 | :param x: input features 11 | :return: 12 | """ 13 | return 14 | 15 | @abc.abstractmethod 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | """ 18 | Evaluates the output for the example x 19 | :param x: input features 20 | :return: 21 | """ 22 | return 23 | -------------------------------------------------------------------------------- /src/interpretability/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if sys.version_info[:2] >= (3, 8): 4 | # TODO: Import directly (no need for conditional) when 
`python_requires = >= 3.8` 5 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 6 | else: 7 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 8 | 9 | try: 10 | # Change here if project is renamed and does not equal the package name 11 | dist_name = __name__ 12 | __version__ = version(dist_name) 13 | except PackageNotFoundError: # pragma: no cover 14 | __version__ = "unknown" 15 | finally: 16 | del version, PackageNotFoundError 17 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = interpretability 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for interpretability. 3 | Use setup.cfg to configure your project. 4 | 5 | This file was generated with PyScaffold 4.3.1. 6 | PyScaffold helps you to put up the scaffold of your new Python project. 7 | Learn more under: https://pyscaffold.org/ 8 | """ 9 | from setuptools import setup 10 | 11 | if __name__ == "__main__": 12 | try: 13 | setup(use_scm_version={"version_scheme": "no-guess-dev"}) 14 | except: # noqa 15 | print( 16 | "\n\nAn error occurred while building the project, " 17 | "please ensure you have the most updated version of setuptools, " 18 | "setuptools_scm and wheel with:\n" 19 | " pip install -U setuptools setuptools_scm wheel\n\n" 20 | ) 21 | raise 22 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for ReadTheDocs, check .readthedocs.yml. 2 | # To build the module reference correctly, make sure every external package 3 | # under `install_requires` in `setup.cfg` is also listed here! 
4 | sphinx>=3.2.1 5 | # sphinx_rtd_theme 6 | beautifulsoup4>=4.11.1 7 | dill>=0.3.6 8 | dynamask>=0.0.1 9 | # invase>=0.0.3 10 | # lfxai>=0.1.1 11 | matplotlib>=3.5.3 12 | matplotlib-inline>=0.1.6 13 | numpy>=1.22.4 14 | pandas>=1.3.5 15 | pydot>=1.4.2 16 | scikit-learn==0.24.2 17 | scipy>=1.8.1 18 | seaborn>=0.11.2 19 | shap>=0.41.0 20 | simplexai>=0.0.2 21 | sktime>=0.13.2 22 | symbolic-pursuit>=0.0.1 23 | sympy>=1.6.2 24 | tensorflow>=2.10.0 25 | termcolor>=2.0.1 26 | torch>=1.12.1 27 | torchvision>=0.13.1 28 | tqdm>=4.64.1 29 | typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work 30 | xgboost>=1.6.2 -------------------------------------------------------------------------------- /src/interpretability/utils/pip.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | from pathlib import Path 3 | import subprocess 4 | import sys 5 | 6 | current_dir = Path(__file__).parent 7 | 8 | predefined = { 9 | "shap": "shap>=0.41.0", 10 | "combo": "git+https://github.com/yzhao062/combo", 11 | "symbolic_pursuit": "git+https://github.com/vanderschaarlab/Symbolic-Pursuit", 12 | } 13 | 14 | 15 | def install(packages: list) -> None: 16 | for package in packages: 17 | install_pack = package 18 | if package in predefined: 19 | install_pack = predefined[package] 20 | print(f"Installing {install_pack}") 21 | 22 | try: 23 | subprocess.check_call( 24 | [sys.executable, "-m", "pip", "install", install_pack], 25 | stdout=subprocess.DEVNULL, 26 | stderr=subprocess.DEVNULL, 27 | ) 28 | except BaseException as e: 29 | print(f"failed to install package {package}: {e}") 30 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class TabularDataset(Dataset): 7 | def __init__(self, X, y=None) -> None: 8 | self.X = X 9 | self.y = y.astype(int) 10 | 11 | def __len__(self) -> int: 12 | return len(self.X) 13 | 14 | def __getitem__(self, i: int) -> tuple: 15 | data = torch.tensor(self.X.iloc[i, :], dtype=torch.float32) 16 | target = self.y.iloc[i] 17 | return data, target 18 | 19 | 20 | class TimeSeriesDataset(Dataset): 21 | def __init__(self, X, y=None) -> None: 22 | self.X = X 23 | self.y = y.astype(int) 24 | 25 | def __len__(self) -> int: 26 | return self.X.shape[0] 27 | 28 | def __getitem__(self, i: int) -> tuple: 29 | data = torch.tensor(self.X[i], dtype=torch.float32) 30 | target = torch.tensor(self.y[i], dtype=torch.float32) 31 | return data, target 32 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 robsdavis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of 
the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | AUTODOCDIR = api 11 | 12 | # User-friendly check for sphinx-build 13 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 14 | $(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") 15 | endif 16 | 17 | .PHONY: help clean Makefile 18 | 19 | # Put it first so that "make" without argument is like "make help". 20 | help: 21 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | clean: 24 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 25 | 26 | # Catch-all target: route all unknown targets to Sphinx using the new 27 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
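# For example, "make html" expands to:
#   sphinx-build -M html . _build
# and extra flags can be passed through O, e.g. "make html O=-q".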
28 | %: Makefile 29 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dill as pkl 3 | from copy import deepcopy 4 | import torch 5 | 6 | cwd = os.path.abspath(".") 7 | 8 | pkl.settings["recurse"] = True 9 | 10 | 11 | def save_explainer(explainer, save_path, verbose=True): 12 | save_path = os.path.join(cwd, save_path) 13 | if verbose: 14 | print(f"Saving explainer to: {save_path}") 15 | with open(save_path, "wb") as f: 16 | 17 | pkl.dump(explainer, f) 18 | 19 | 20 | def load_explainer(save_path, join_to_cwd_to_save_path=True): 21 | if join_to_cwd_to_save_path: 22 | with open(os.path.join(cwd, save_path), "rb") as f: 23 | return pkl.load(f) 24 | else: 25 | with open(save_path, "rb") as f: 26 | return pkl.load(f) 27 | 28 | 29 | def check_attribute_eq(attribute, explainer, explainer_from_file): 30 | print(f"Comparing {attribute}") 31 | if isinstance(getattr(explainer, attribute), torch.Tensor): 32 | exp = torch.equal( 33 | getattr(explainer, attribute), getattr(explainer_from_file, attribute) 34 | ) 35 | else: 36 | exp = getattr(explainer, attribute) == getattr(explainer_from_file, attribute) 37 | if not exp: 38 | try: 39 | assert getattr(explainer, attribute) != deepcopy( 40 | getattr(explainer, attribute) 41 | ) 42 | print(f"\t{attribute} not comparable") 43 | except AssertionError as e: 44 | print(f"\t{attribute} is not equal") 45 | raise e 46 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/base.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | from abc import ABCMeta, abstractmethod 3 | from typing import Optional, Union, List 4 | 5 | # third party 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | class Explainer(metaclass=ABCMeta): 11 | def __init__(self) -> None: 12 | self.has_been_fit = False 13 | self.explanation = None 14 | ... 15 | 16 | @staticmethod 17 | @abstractmethod 18 | def name() -> str: 19 | ... 20 | 21 | @staticmethod 22 | @abstractmethod 23 | def pretty_name() -> str: 24 | ... 25 | 26 | @staticmethod 27 | def type() -> str: 28 | return "explainer" 29 | 30 | @abstractmethod 31 | def fit(self, X: pd.DataFrame) -> pd.DataFrame: 32 | """ 33 | The function to fit the explainer to the data 34 | """ 35 | ... 36 | 37 | @abstractmethod 38 | def explain(self) -> pd.DataFrame: 39 | """ 40 | The function to get the explanation data from the explainer 41 | """ 42 | ... 43 | 44 | 45 | class Explanation(metaclass=ABCMeta): 46 | def __init__(self) -> None: 47 | ... 48 | 49 | @staticmethod 50 | @abstractmethod 51 | def name() -> str: 52 | ...
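    # (The Explainer ABC above encodes the two-step contract used throughout
    # the package: a concrete subclass implements name()/pretty_name(), then
    # fit(X) (where required) followed by explain() yields an Explanation
    # instance such as FeatureExplanation below.)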
53 | 54 | @staticmethod 55 | def type() -> str: 56 | return "explanation" 57 | 58 | 59 | class FeatureExplanation(Explanation): 60 | def __init__(self, feature_importances: Union[pd.DataFrame, List]) -> None: 61 | self.feature_importances = feature_importances 62 | super().__init__() 63 | 64 | @staticmethod 65 | def name() -> str: 66 | return "Feature Explanation" 67 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | */__pycache__/ 12 | __pycache__/ 13 | .cache/* 14 | .*.swp 15 | */.ipynb_checkpoints/* 16 | .DS_Store 17 | 18 | # Project files 19 | .ropeproject 20 | .project 21 | .pydevproject 22 | .settings 23 | .idea 24 | .vscode 25 | tags 26 | 27 | # Package files 28 | *.egg 29 | *.eggs/ 30 | .installed.cfg 31 | *.egg-info 32 | 33 | # Unittest and coverage 34 | htmlcov/* 35 | .coverage 36 | .coverage.* 37 | .tox 38 | junit*.xml 39 | coverage.xml 40 | .pytest_cache/ 41 | 42 | # Build and docs folder/files 43 | build/* 44 | dist/* 45 | sdist/* 46 | docs/api/* 47 | docs/_rst/* 48 | docs/_build/* 49 | cover/* 50 | MANIFEST 51 | 52 | # Per-project virtualenvs 53 | .venv*/ 54 | .conda*/ 55 | .python-version 56 | 57 | # Notebooks for testing implementations 58 | */notebooks/ 59 | notebooks/ 60 | Notebooks/ 61 | */implement*.py 62 | 63 | # Package files 64 | *.egg 65 | *.eggs/ 66 | .installed.cfg 67 | *.egg-info 68 | 69 | # Unittest and coverage 70 | htmlcov/* 71 | .coverage 72 | .coverage.* 73 | .tox 74 | junit*.xml 75 | coverage.xml 76 | .pytest_cache/ 77 | 78 | # images 79 | images/ 80 | !images/Short_intro_video_thumbnail.png 81 | !images/Interpretability_method_flow_diagram.svg 82 | images/archived_images/ 83 | 84 | 85 | 86 | # output 87 | output/ 88 | # Keep empty output folder 89 | !output/.gitkeep 90 | 91 | 92 | # data 93 | notebooks/data/ 94 | notebooks/data/MNIST/ 95 | notebooks/data/MNIST/raw/ 96 | data/ 97 | raw.githubusercontent.com/ 98 | 99 | # resources 100 | resources/ 101 | resources/data_scalers 102 | resources/saved_explainers 103 | 104 | # models 105 | models/ 106 | models/resources 107 | models/resources/trained_models 108 | 109 | -------------------------------------------------------------------------------- /src/interpretability/models/linear_regression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from interpretability.models.base import BlackBox 5 | 6 | 7 | class LinearRegression(BlackBox): 8 | def __init__(self, n_cont: int = 3, input_feature_num=26, output_dim=2) -> None: 9 | """ 10 | Mortality predictor MLP 11 | :param n_cont: number of continuous features among the output features 12 | """ 13 | super().__init__() 14 | self.n_cont = n_cont 15 | self.lin = nn.Linear(input_feature_num, output_dim) 16 | self.bn1 = nn.BatchNorm1d(self.n_cont) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | x_cont, x_disc = x[:, : self.n_cont], x[:, self.n_cont :] 20 | x_cont = self.bn1(x_cont) 21 | x = torch.cat([x_cont, x_disc], 1) 22 | x = self.lin(x) 23 | x = F.log_softmax(x, dim=-1) 24 | return x 25 | 26 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 27 | """ 28 | Latent space is the input space for linear regression 29 | """ 30 | return x 31 | 32 | 
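    # Note: for a linear model the "latent space" is simply the input space
    # (latent_representation above is the identity), so SimplEx-style
    # explainers that consume latent representations receive the raw features.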
def probabilities(self, x: torch.Tensor) -> torch.Tensor: 33 | """ 34 | Returns the class probabilities for the input x 35 | :param x: input features 36 | :return: probabilities 37 | """ 38 | x = self.latent_representation(x) 39 | x = self.lin(x) 40 | x = F.softmax(x, dim=-1) 41 | return x 42 | 43 | def predict(self, x: torch.Tensor) -> torch.Tensor: 44 | probs = self.probabilities(x) 45 | preds = torch.argmax(probs) 46 | print(preds) 47 | return preds 48 | 49 | return 50 | 51 | def latent_to_presoftmax(self, h: torch.Tensor) -> torch.Tensor: 52 | """ 53 | Maps a latent representation to a preactivation output 54 | :param h: latent representations 55 | :return: presoftmax activations 56 | """ 57 | h = self.lin(h) 58 | return h 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | */__pycache__/ 12 | __pycache__/ 13 | .cache/* 14 | .*.swp 15 | */.ipynb_checkpoints/* 16 | .DS_Store 17 | 18 | # Project files 19 | .ropeproject 20 | .project 21 | .pydevproject 22 | .settings 23 | .idea 24 | .vscode 25 | tags 26 | 27 | # Package files 28 | *.egg 29 | *.eggs/ 30 | .installed.cfg 31 | *.egg-info 32 | 33 | # Unittest and coverage 34 | htmlcov/* 35 | .coverage 36 | .coverage.* 37 | .tox 38 | junit*.xml 39 | coverage.xml 40 | .pytest_cache/ 41 | 42 | # Build and docs folder/files 43 | build/* 44 | dist/* 45 | sdist/* 46 | docs/api/* 47 | docs/_rst/* 48 | docs/_build/* 49 | cover/* 50 | MANIFEST 51 | 52 | # Per-project virtualenvs 53 | .venv*/ 54 | .conda*/ 55 | .python-version 56 | 57 | # Notebooks for testing implementations 58 | */notebooks/ 59 | notebooks/ 60 | Notebooks/* 61 | !Notebooks/Tutorial_01_implement_simplex.ipynb 62 | !Notebooks/Tutorial_02_implement_simplex_time_series.ipynb 63 | !Notebooks/Tutorial_03_implement_dynamask.ipynb 64 | !Notebooks/Tutorial_04_implement_shap.ipynb 65 | !Notebooks/Tutorial_05_implement_symbolic_pursuit.ipynb 66 | 67 | # Package files 68 | *.egg 69 | *.eggs/ 70 | .installed.cfg 71 | *.egg-info 72 | 73 | # Unittest and coverage 74 | htmlcov/* 75 | .coverage 76 | .coverage.* 77 | .tox 78 | junit*.xml 79 | coverage.xml 80 | .pytest_cache/ 81 | 82 | # images 83 | images/* 84 | !images/Short_intro_video_thumbnail.png 85 | !images/Interpretability_method_flow_diagram.svg 86 | !images/interpretability_suite_image.png 87 | !images/user_inter_face_upload.png 88 | images/archived_images/ 89 | symbolic_pursuit_expression.png 90 | symbolic_pursuit_projections.png 91 | 92 | 93 | 94 | # output 95 | output/ 96 | # Keep empty output folder 97 | !output/.gitkeep 98 | 99 | 100 | # data 101 | notebooks/data/ 102 | notebooks/data/MNIST/ 103 | notebooks/data/MNIST/raw/ 104 | data/ 105 | raw.githubusercontent.com/ 106 | 107 | # resources 108 | resources/* 109 | !resources/saved_models 110 | resources/saved_models/* 111 | 112 | 113 | # models 114 | **/models/train_*.py 115 | src/interpretability/models/resources/ 116 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | interpretability 3 | ================ 4 | 5 | This is the documentation of **interpretability**. 6 | 7 | .. note:: 8 | 9 | This is the main page of your project's `Sphinx`_ documentation. 
10 | It is formatted in `reStructuredText`_. Add additional pages 11 | by creating rst-files in ``docs`` and adding them to the `toctree`_ below. 12 | Use then `references`_ in order to link them from this page, e.g. 13 | :ref:`authors` and :ref:`changes`. 14 | 15 | It is also possible to refer to the documentation of other Python packages 16 | with the `Python domain syntax`_. By default you can reference the 17 | documentation of `Sphinx`_, `Python`_, `NumPy`_, `SciPy`_, `matplotlib`_, 18 | `Pandas`_, `Scikit-Learn`_. You can add more by extending the 19 | ``intersphinx_mapping`` in your Sphinx's ``conf.py``. 20 | 21 | The pretty useful extension `autodoc`_ is activated by default and lets 22 | you include documentation from docstrings. Docstrings can be written in 23 | `Google style`_ (recommended!), `NumPy style`_ and `classical style`_. 24 | 25 | 26 | Contents 27 | ======== 28 | 29 | .. toctree:: 30 | :maxdepth: 2 31 | 32 | Overview 33 | Contributions & Help 34 | License 35 | Authors 36 | Changelog 37 | Module Reference 38 | 39 | 40 | Indices and tables 41 | ================== 42 | 43 | * :ref:`genindex` 44 | * :ref:`modindex` 45 | * :ref:`search` 46 | 47 | .. _toctree: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html 48 | .. _reStructuredText: https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 49 | .. _references: https://www.sphinx-doc.org/en/stable/markup/inline.html 50 | .. _Python domain syntax: https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#the-python-domain 51 | .. _Sphinx: https://www.sphinx-doc.org/ 52 | .. _Python: https://docs.python.org/ 53 | .. _Numpy: https://numpy.org/doc/stable 54 | .. _SciPy: https://docs.scipy.org/doc/scipy/reference/ 55 | .. _matplotlib: https://matplotlib.org/contents.html# 56 | .. _Pandas: https://pandas.pydata.org/pandas-docs/stable 57 | .. _Scikit-Learn: https://scikit-learn.org/stable 58 | .. _autodoc: https://www.sphinx-doc.org/en/master/ext/autodoc.html 59 | .. _Google style: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings 60 | .. _NumPy style: https://numpydoc.readthedocs.io/en/latest/format.html 61 | .. _classical style: https://www.sphinx-doc.org/en/master/domains.html#info-field-lists 62 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | 11 | [testenv] 12 | description = Invoke pytest to run automated tests 13 | setenv = 14 | TOXINIDIR = {toxinidir} 15 | passenv = 16 | HOME 17 | SETUPTOOLS_* 18 | extras = 19 | testing 20 | commands = 21 | pytest {posargs} 22 | 23 | 24 | # # To run `tox -e lint` you need to make sure you have a 25 | # # `.pre-commit-config.yaml` file. 
See https://pre-commit.com 26 | # [testenv:lint] 27 | # description = Perform static analysis and style checks 28 | # skip_install = True 29 | # deps = pre-commit 30 | # passenv = 31 | # HOMEPATH 32 | # PROGRAMDATA 33 | # SETUPTOOLS_* 34 | # commands = 35 | # pre-commit run --all-files {posargs:--show-diff-on-failure} 36 | 37 | 38 | [testenv:{build,clean}] 39 | description = 40 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 41 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 42 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 43 | skip_install = True 44 | changedir = {toxinidir} 45 | deps = 46 | build: build[virtualenv] 47 | passenv = 48 | SETUPTOOLS_* 49 | commands = 50 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 51 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 52 | build: python -m build {posargs} 53 | 54 | 55 | [testenv:{docs,doctests,linkcheck}] 56 | description = 57 | docs: Invoke sphinx-build to build the docs 58 | doctests: Invoke sphinx-build to run doctests 59 | linkcheck: Check for broken links in the documentation 60 | passenv = 61 | SETUPTOOLS_* 62 | setenv = 63 | DOCSDIR = {toxinidir}/docs 64 | BUILDDIR = {toxinidir}/docs/_build 65 | docs: BUILD = html 66 | doctests: BUILD = doctest 67 | linkcheck: BUILD = linkcheck 68 | deps = 69 | -r {toxinidir}/docs/requirements.txt 70 | # ^ requirements.txt shared with Read The Docs 71 | commands = 72 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 73 | 74 | 75 | [testenv:publish] 76 | description = 77 | Publish the package you have been developing to a package index server. 78 | By default, it uses testpypi. If you really want to publish your package 79 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 80 | skip_install = True 81 | changedir = {toxinidir} 82 | passenv = 83 | # See: https://twine.readthedocs.io/en/latest/ 84 | TWINE_USERNAME 85 | TWINE_PASSWORD 86 | TWINE_REPOSITORY 87 | TWINE_REPOSITORY_URL 88 | deps = twine 89 | commands = 90 | python -m twine check dist/* 91 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 92 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = interpretability 8 | description = An interface for different interpretability methods. 
9 | author = robsdavis 10 | author_email = robsdavis473@gmail.com 11 | license = MIT 12 | license_files = LICENSE.txt 13 | long_description = file: README.md 14 | long_description_content_type = text/markdown; charset=UTF-8 15 | url = https://github.com/vanderschaarlab/Interpretability 16 | # project_urls = 17 | # Documentation = https://pyscaffold.org/ 18 | # Source = https://github.com/pyscaffold/pyscaffold/ 19 | # Changelog = https://pyscaffold.org/en/latest/changelog.html 20 | # Tracker = https://github.com/pyscaffold/pyscaffold/issues 21 | # Conda-Forge = https://anaconda.org/conda-forge/pyscaffold 22 | # Download = https://pypi.org/project/PyScaffold/#files 23 | # Twitter = https://twitter.com/PyScaffold 24 | 25 | # Change if running only on Windows, Mac or Linux (comma-separated) 26 | platforms = any 27 | 28 | # Add here all kinds of additional classifiers as defined under 29 | # https://pypi.org/classifiers/ 30 | classifiers = 31 | Development Status :: 4 - Beta 32 | Programming Language :: Python 33 | 34 | 35 | [options] 36 | zip_safe = False 37 | packages = find_namespace: 38 | include_package_data = True 39 | package_dir = 40 | =src 41 | 42 | # Require a min/specific Python version (comma-separated conditions) 43 | # python_requires = >=3.8 44 | 45 | # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. 46 | # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in 47 | # new major versions. This works if the required packages follow Semantic Versioning. 48 | # For more information, check out https://semver.org/. 49 | install_requires = 50 | beautifulsoup4 51 | dill 52 | dynamask 53 | matplotlib 54 | matplotlib-inline 55 | numpy 56 | pandas 57 | pydot 58 | PyScaffold 59 | scikit-learn 60 | scipy 61 | seaborn 62 | shap 63 | simplexai 64 | sktime 65 | symbolic-pursuit 66 | sympy 67 | tensorflow 68 | termcolor 69 | torch 70 | torchvision 71 | tqdm 72 | xgboost 73 | importlib-metadata; python_version<"3.8" 74 | 75 | 76 | [options.packages.find] 77 | where = src 78 | exclude = 79 | tests 80 | 81 | [options.extras_require] 82 | # Add here additional requirements for extra features, to install with: 83 | # `pip install interpretability[PDF]` like: 84 | # PDF = ReportLab; RXP 85 | 86 | # Add here test requirements (semicolon/line-separated) 87 | testing = 88 | setuptools 89 | pytest 90 | pytest-cov 91 | 92 | [options.entry_points] 93 | # Add here console scripts like: 94 | # console_scripts = 95 | # script_name = interpretability.module:function 96 | # For example: 97 | # console_scripts = 98 | # fibonacci = interpretability.skeleton:run 99 | # And any other entry points, for example: 100 | # pyscaffold.cli = 101 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 102 | 103 | [tool:pytest] 104 | # Specify command line options as you would do when invoking pytest directly. 105 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 106 | # in order to write a coverage file that can be read by Jenkins. 107 | # CAUTION: --cov flags may prohibit setting breakpoints while debugging. 108 | # Comment those flags to avoid this pytest issue. 
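# (For a one-off debugging run, `pytest --no-cov` disables coverage
# without editing this file.)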
109 | addopts = 110 | --cov interpretability --cov-report term-missing 111 | --verbose 112 | norecursedirs = 113 | dist 114 | build 115 | .tox 116 | testpaths = tests 117 | # Use pytest markers to select/deselect specific tests 118 | # markers = 119 | # slow: mark tests as slow (deselect with '-m "not slow"') 120 | # system: mark end-to-end system tests 121 | 122 | [devpi:upload] 123 | # Options for the devpi: PyPI server and packaging tool 124 | # VCS export must be deactivated since we are using setuptools-scm 125 | no_vcs = 1 126 | formats = bdist_wheel 127 | 128 | [flake8] 129 | # Some sane defaults for the code style checker flake8 130 | max_line_length = 88 131 | extend_ignore = E203, W503 132 | # ^ Black-compatible 133 | # E203 and W503 have edge cases handled by black 134 | exclude = 135 | .tox 136 | build 137 | dist 138 | .eggs 139 | docs/conf.py 140 | 141 | [pyscaffold] 142 | # PyScaffold's parameters when the project was created. 143 | # This will be used when updating. Do not change! 144 | version = 4.3.1 145 | package = interpretability 146 | extensions = 147 | no_skeleton 148 | -------------------------------------------------------------------------------- /src/interpretability/models/multilayer_perceptron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from interpretability.models.base import BlackBox 5 | 6 | 7 | class DiabetesMLPRegressor(BlackBox): 8 | def __init__(self, input_feature_num=10) -> None: 9 | """ 10 | Mortality predictor MLP 11 | :param n_cont: number of continuous features among the output features 12 | """ 13 | super().__init__() 14 | self.lin1 = nn.Linear(input_feature_num, 200) 15 | self.lin2 = nn.Linear(200, 50) 16 | self.lin3 = nn.Linear(50, 1) 17 | self.drops = nn.Dropout(0.3) 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | x = self.latent_representation(x) 21 | x = self.lin3(x) 22 | return x 23 | 24 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 25 | x = F.relu(self.lin1(x)) 26 | x = self.drops(x) 27 | x = F.relu(self.lin2(x)) 28 | x = self.drops(x) 29 | return x 30 | 31 | def predict(self, x: torch.Tensor) -> torch.Tensor: 32 | return self.forward(x).detach().numpy() 33 | 34 | 35 | class IrisMLP(BlackBox): 36 | def __init__(self, n_cont: int = 3, input_feature_num=26) -> None: 37 | """ 38 | Mortality predictor MLP 39 | :param n_cont: number of continuous features among the output features 40 | """ 41 | super().__init__() 42 | self.n_cont = n_cont 43 | self.lin1 = nn.Linear(input_feature_num, 200) 44 | self.lin2 = nn.Linear(200, 50) 45 | self.lin3 = nn.Linear(50, 3) 46 | self.bn1 = nn.BatchNorm1d(self.n_cont) 47 | self.drops = nn.Dropout(0.3) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | x = self.latent_representation(x) 51 | x = self.lin3(x) 52 | x = F.log_softmax(x, dim=-1) 53 | return x 54 | 55 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 56 | x_cont, x_disc = x[:, : self.n_cont], x[:, self.n_cont :] 57 | x_cont = self.bn1(x_cont) 58 | x = torch.cat([x_cont, x_disc], 1) 59 | x = F.relu(self.lin1(x)) 60 | x = self.drops(x) 61 | x = F.relu(self.lin2(x)) 62 | x = self.drops(x) 63 | return x 64 | 65 | def probabilities(self, x: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Returns the class probabilities for the input x 68 | :param x: input features 69 | :return: probabilities 70 | """ 71 | x = self.latent_representation(x) 72 | x = self.lin3(x) 73 | x 
= F.softmax(x, dim=-1) 74 | return x 75 | 76 | def predict(self, x: torch.Tensor) -> torch.Tensor: 77 | probs = self.probabilities(x) 78 | preds = torch.argmax(probs) 79 | return preds 80 | 81 | def latent_to_presoftmax(self, h: torch.Tensor) -> torch.Tensor: 82 | """ 83 | Maps a latent representation to a preactivation output 84 | :param h: latent representations 85 | :return: presoftmax activations 86 | """ 87 | h = self.lin3(h) 88 | return h 89 | 90 | 91 | class WineMLP(BlackBox): 92 | def __init__(self, n_cont: int = 11, input_feature_num=11) -> None: 93 | """ 94 | Mortality predictor MLP 95 | :param n_cont: number of continuous features among the output features 96 | """ 97 | super().__init__() 98 | self.n_cont = n_cont 99 | self.lin1 = nn.Linear(input_feature_num, 200) 100 | self.lin2 = nn.Linear(200, 50) 101 | self.lin3 = nn.Linear(50, 7) 102 | self.bn1 = nn.BatchNorm1d(self.n_cont) 103 | self.drops = nn.Dropout(0.3) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | x = self.latent_representation(x) 107 | x = self.lin3(x) 108 | x = F.log_softmax(x, dim=-1) 109 | return x 110 | 111 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 112 | x_cont, x_disc = x[:, : self.n_cont], x[:, self.n_cont :] 113 | x_cont = self.bn1(x_cont) 114 | x = torch.cat([x_cont, x_disc], 1) 115 | x = F.relu(self.lin1(x)) 116 | x = self.drops(x) 117 | x = F.relu(self.lin2(x)) 118 | x = self.drops(x) 119 | return x 120 | 121 | def probabilities(self, x: torch.Tensor) -> torch.Tensor: 122 | """ 123 | Returns the class probabilities for the input x 124 | :param x: input features 125 | :return: probabilities 126 | """ 127 | x = self.latent_representation(x) 128 | x = self.lin3(x) 129 | x = F.softmax(x, dim=-1) 130 | return x 131 | 132 | def latent_to_presoftmax(self, h: torch.Tensor) -> torch.Tensor: 133 | """ 134 | Maps a latent representation to a preactivation output 135 | :param h: latent representations 136 | :return: presoftmax activations 137 | """ 138 | h = self.lin3(h) 139 | return h 140 | -------------------------------------------------------------------------------- /src/interpretability/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class ExplainCalledBeforeFit(Exception): 2 | """ 3 | Exception raised when explain is called before fit. 4 | """ 5 | 6 | def __init__( 7 | self, 8 | explainer_has_been_fit, 9 | exception_message="The explainer must be fit before explain() is called. Please call .fit() first.", 10 | ): 11 | self.explainer_has_been_fit = explainer_has_been_fit 12 | self.message = exception_message 13 | super().__init__(self.message) 14 | 15 | def __str__(self): 16 | return f"{self.message}" 17 | 18 | 19 | class MeasureFitQualityCalledBeforeFit(Exception): 20 | """ 21 | Exception raised when measure_fit_quality() is called before fit(). 22 | """ 23 | 24 | def __init__( 25 | self, 26 | explainer_has_been_fit, 27 | exception_message="The explainer must be fit before measure_fit_quality() is called. Please call .fit() first.", 28 | ): 29 | self.explainer_has_been_fit = explainer_has_been_fit 30 | self.message = exception_message 31 | super().__init__(self.message) 32 | 33 | def __str__(self): 34 | return f"{self.message}" 35 | 36 | 37 | class ModelsLatentRepresentationsNotAccessible(Exception): 38 | """ 39 | Exception raised when latent_representation() is called on the model object, but the model has no such method. 
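    In practice this means the wrapped model must expose latent_representation(x)
    alongside forward(x), as the BlackBox base class in
    interpretability.models.base does.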
40 | """ 41 | 42 | def __init__( 43 | self, 44 | exception_message="The model object has no 'latent_representation()' method, which receives the input data and returns their latent space representation. This method is a requirement for using SimplEx. The method should be the same as the full forward() method, but stop short of the final layer. For help in adding a method please see the examples of the models here: https://github.com/vanderschaarlab/Simplex/tree/main/src/simplexai/models", 45 | ): 46 | self.message = exception_message 47 | super().__init__(self.message) 48 | 49 | def __str__(self): 50 | return f"{self.message}" 51 | 52 | 53 | class InvalidEstimatorType(Exception): 54 | """ 55 | Exception raised when the estimator type is not in the list of valid types (usually classifier or regressor). 56 | """ 57 | 58 | def __init__( 59 | self, 60 | estimator_type, 61 | valid_estimator_types=[], 62 | ): 63 | self.estimator_type = estimator_type 64 | self.valid_estimator_types = valid_estimator_types 65 | self.message = f"Estimator_type \"{self.estimator_type}\" not valid. Please use one of the following values: {', '.join(self.valid_estimator_types)}." 66 | super().__init__(self.message) 67 | 68 | def __str__(self): 69 | return f"{self.message}" 70 | 71 | 72 | class InvalidShapeForModelOutput(Exception): 73 | """ 74 | Exception raised when the estimator output is of an invalid shape. 75 | """ 76 | 77 | def __init__( 78 | self, 79 | output_shape: int, 80 | ): 81 | self.message = f"Invalid shape of {output_shape} for output from the forward call of the estimator. The explainer supports single or multi-label classification and single label regression only." 82 | super().__init__(self.message) 83 | 84 | def __str__(self): 85 | return f"{self.message}" 86 | 87 | 88 | class ExampleImportanceThresholdTooHigh(Exception): 89 | """ 90 | Exception raised when the example importance threshold is too high, such that no examples in the explanation have an importance above the threshold. 91 | """ 92 | 93 | def __init__( 94 | self, 95 | example_importance_threshold: int, 96 | max_importance: float, 97 | ): 98 | self.message = f"example_importance_threshold of {example_importance_threshold} is higher than the highest example importance value of {max_importance:0.2f}. Please reduce the example_importance_threshold to below {max_importance:0.2f} in order to see the examples." 99 | super().__init__(self.message) 100 | 101 | def __str__(self): 102 | return f"{self.message}" 103 | 104 | 105 | class NoDataToExplain(Exception): 106 | """ 107 | Exception raised when a fit call is made without any data to explain. 108 | """ 109 | 110 | def __init__( 111 | self, 112 | ): 113 | self.message = "No data to explain has been passed to the explainer. This is only allowed when re-fitting a previously fit explainer." 114 | super().__init__(self.message) 115 | 116 | def __str__(self): 117 | return f"{self.message}" 118 | 119 | 120 | # class InvalidTimeStepAxis(Exception): 121 | # """ 122 | # 123 | # """ 124 | 125 | # def __init__( 126 | # self, 127 | # time_step_axis, 128 | # ): 129 | # self.message = ( 130 | # f"Value given for time_step_axis must be 0 or 1, not {time_step_axis}."
131 | # ) 132 | # super().__init__(self.message) 133 | 134 | # def __str__(self): 135 | # return f"{self.message}" 136 | -------------------------------------------------------------------------------- /src/interpretability/models/recurrent_neural_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from interpretability.models.base import BlackBox 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class MortalityGRU(BlackBox): 13 | def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob=0.2): 14 | super().__init__() 15 | self.hidden_dim = hidden_dim 16 | self.n_layers = n_layers 17 | 18 | self.gru = nn.GRU( 19 | input_dim, 20 | hidden_dim, 21 | n_layers, 22 | batch_first=True, # dropout=drop_prob 23 | ) 24 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 25 | self.fc2 = nn.Linear(hidden_dim, output_dim) 26 | self.sigmoid = nn.Sigmoid() 27 | 28 | def forward(self, x): 29 | x = self.latent_representation(x) 30 | x = self.fc2(x) 31 | x = self.sigmoid(x) 32 | return x 33 | 34 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 35 | x, h = self.gru(x) 36 | x = x[:, -1, :] 37 | x = self.fc1(x) 38 | return x 39 | 40 | 41 | class ArrowHeadGRU(BlackBox): 42 | def __init__(self, input_dim, hidden_dim, output_dim, n_layers, drop_prob=0.2): 43 | super().__init__() 44 | self.hidden_dim = hidden_dim 45 | self.n_layers = n_layers 46 | 47 | self.gru = nn.GRU( 48 | input_dim, 49 | hidden_dim, 50 | n_layers, 51 | batch_first=True, # dropout=drop_prob 52 | ) 53 | self.fc1 = nn.Linear(hidden_dim, hidden_dim) 54 | self.fc2 = nn.Linear(hidden_dim, output_dim) 55 | self.sigmoid = nn.Sigmoid() 56 | 57 | def forward(self, x): 58 | x = self.latent_representation(x) 59 | x = self.fc2(x) 60 | x = self.sigmoid(x) 61 | return x 62 | 63 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 64 | x, h = self.gru(x) 65 | x = x[:, -1, :] 66 | x = self.fc1(x) 67 | return x 68 | 69 | 70 | class ConvNet(BlackBox): 71 | def __init__( 72 | self, 73 | input_dim=1, 74 | hidden_dim=64, 75 | kernel_size=3, 76 | output_dim=1, 77 | drop_prob=0.2, 78 | activation_func="sigmoid", 79 | ): 80 | super().__init__() 81 | self.hidden_dim = hidden_dim 82 | 83 | self.convInput = nn.Conv1d(input_dim, hidden_dim, kernel_size, padding="same") 84 | self.convHidden1 = nn.Conv1d( 85 | hidden_dim, hidden_dim, kernel_size, padding="same" 86 | ) 87 | self.convHidden2 = nn.Conv1d( 88 | hidden_dim, hidden_dim, kernel_size, padding="same" 89 | ) 90 | self.bn1 = nn.BatchNorm1d(hidden_dim) 91 | self.bn2 = nn.BatchNorm1d(hidden_dim) 92 | self.bn3 = nn.BatchNorm1d(hidden_dim) 93 | self.relu = nn.ReLU() 94 | self.pool = nn.AdaptiveMaxPool1d(hidden_dim) 95 | self.flatten = nn.Flatten() 96 | self.fc1 = nn.Linear(hidden_dim**2, output_dim) 97 | # self.fc2 = nn.Linear(hidden_dim, output_dim) 98 | if activation_func == "sigmoid": 99 | self.activation_func = nn.Sigmoid() 100 | elif activation_func == "softmax": 101 | self.activation_func = nn.Softmax(dim=-1) 102 | elif not activation_func: 103 | self.activation_func = None 104 | 105 | def forward(self, x): 106 | x = self.latent_representation(x) 107 | x = self.fc1(x) 108 | if self.activation_func: 109 | x = self.activation_func(x) 110 | return x 111 | 112 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 113 | x = 
torch.transpose(x, 1, 2) 114 | x = self.relu(self.bn1(self.convInput(x))) 115 | x = self.relu(self.bn2(self.convHidden1(x))) 116 | x = self.relu(self.bn3(self.convHidden2(x))) 117 | x = self.flatten(self.pool(x)) 118 | return x 119 | 120 | 121 | class GRU(BlackBox): 122 | def __init__( 123 | self, input_dim=1, hidden_dim=5, output_dim=1, n_layers=3, drop_prob=0.2 124 | ): 125 | super().__init__() 126 | self.hidden_dim = hidden_dim 127 | self.n_layers = n_layers 128 | 129 | self.gru = nn.GRU( 130 | input_dim, 131 | hidden_dim, 132 | n_layers, 133 | batch_first=True, # dropout=drop_prob 134 | ) 135 | self.fc1 = nn.Linear(hidden_dim, output_dim) 136 | # self.fc2 = nn.Linear(hidden_dim, output_dim) 137 | self.sigmoid = nn.Sigmoid() 138 | 139 | def forward(self, x): 140 | x = self.latent_representation(x) 141 | x = self.fc1(x) 142 | x = self.sigmoid( 143 | x, 144 | ) 145 | return x 146 | 147 | def latent_representation(self, x: torch.Tensor) -> torch.Tensor: 148 | x, h = self.gru(x) 149 | x = x[:, -1, :] 150 | return x 151 | 152 | 153 | class ShallowRegressionLSTM(nn.Module): 154 | def __init__(self, num_sensors, hidden_units): 155 | super().__init__() 156 | self.num_sensors = num_sensors # this is the number of features 157 | self.hidden_units = hidden_units 158 | self.num_layers = 1 159 | 160 | self.lstm = nn.LSTM( 161 | input_size=num_sensors, 162 | hidden_size=hidden_units, 163 | batch_first=True, 164 | num_layers=self.num_layers, 165 | ) 166 | 167 | self.linear = nn.Linear(in_features=self.hidden_units, out_features=1) 168 | 169 | def latent_representation(self, x): 170 | batch_size = x.shape[0] 171 | h0 = torch.zeros( 172 | self.num_layers, batch_size, self.hidden_units 173 | ).requires_grad_() 174 | c0 = torch.zeros( 175 | self.num_layers, batch_size, self.hidden_units 176 | ).requires_grad_() 177 | 178 | _, (hn, _) = self.lstm(x, (h0, c0)) 179 | 180 | return hn[0] 181 | 182 | def forward(self, x): 183 | x = self.latent_representation(x) 184 | out = self.linear( 185 | x 186 | ).flatten() # First dim of Hn is num_layers, which is set to 1 above. 
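        # After flatten(), `out` has shape (batch_size,): one scalar
        # regression value per input sequence.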
187 | 188 | return out 189 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/shap_explainer.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | import sys 3 | import copy 4 | from typing import Any, List, Tuple, Optional, Union 5 | from abc import abstractmethod 6 | from pathlib import Path 7 | import os 8 | import pickle as pkl 9 | 10 | # third party 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from torch.utils.data import DataLoader 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | # Interpretability relative 19 | from .utils import data 20 | from .base import Explainer, FeatureExplanation 21 | 22 | # Interpretability absolute 23 | from interpretability.utils.pip import install 24 | from interpretability.exceptions.exceptions import ExplainCalledBeforeFit 25 | 26 | # shap 27 | for retry in range(2): 28 | try: 29 | # third party 30 | import shap 31 | 32 | break 33 | except ImportError: 34 | depends = ["shap"] 35 | install(depends) 36 | 37 | 38 | class ShapExplainerBase(Explainer): 39 | def __init__(self, model, X_explain, y_explain, *argv, **kwargs) -> None: 40 | super(ShapExplainerBase, self).__init__(*argv, **kwargs) 41 | self.has_been_fit = True 42 | self.shap_values = None 43 | self.inner_explainer_save_path = None 44 | 45 | @staticmethod 46 | def type() -> str: 47 | return "explainer" 48 | 49 | def fit(self): 50 | print("SHAP explainers do not need to be fit. Please simply call explain().") 51 | 52 | def explain(self, *argv, **kwargs) -> pd.DataFrame: 53 | """ 54 | The function to get the explanation data from the explainer 55 | """ 56 | ... 57 | 58 | def summary_plot( 59 | self, 60 | explanation: List = None, 61 | show=True, 62 | save_path="temp_shap_plot.png", 63 | **kwargs 64 | ): 65 | """ 66 | Plot the feature importances using the shap summary_plot function. 67 | 68 | Args: 69 | explanation (List, optional): A list of shap_values from an explanation call. Defaults to None, which means it uses the values staored in self.shap_values, generated by the last call of explain(). 70 | """ 71 | if not explanation: 72 | explanation = self.shap_values 73 | shap.summary_plot(explanation, self.explain_inputs, show=show, **kwargs) 74 | if not show: 75 | plt.savefig(save_path) 76 | 77 | # WATERFALL NOT WORKING 78 | # def plot_waterfall(self, class_idx, explanation: List = None): 79 | 80 | # if not explanation: 81 | # explanation = self.shap_values 82 | # shap.plots.waterfall(explanation[class_idx]) 83 | 84 | 85 | class ShapKernelExplainer(ShapExplainerBase): 86 | """ 87 | This is a light-weight wrapper for the kernel explainer from "SHAP", which is 88 | available from . 89 | Additional functionality from the source class is accessible via the 'explainer' 90 | object. 
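    A minimal usage sketch (illustrative only; here `model` is any callable
    mapping an input array to predictions, and `X_explain`/`y_explain` are
    the data to explain):

        explainer = ShapKernelExplainer(model, X_explain, y_explain)
        explanation = explainer.explain(nsamples=100)
        explainer.summary_plot()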
91 | """ 92 | 93 | def __init__(self, model, X_explain, y_explain, *argv, **kwargs) -> None: 94 | self.explain_inputs = X_explain 95 | self.explainer = shap.KernelExplainer( 96 | model, self.explain_inputs, *argv, **kwargs 97 | ) 98 | 99 | @staticmethod 100 | def name() -> str: 101 | return "shap_kernel_explainer" 102 | 103 | @staticmethod 104 | def pretty_name() -> str: 105 | return "SHAP Kernel Explainer" 106 | 107 | @staticmethod 108 | def type() -> str: 109 | return "explainer" 110 | 111 | def explain(self, *argv, **kwargs) -> pd.DataFrame: 112 | """ 113 | The function to get the explanation data from the explainer 114 | """ 115 | self.shap_values = self.explainer.shap_values( 116 | self.explain_inputs, *argv, **kwargs 117 | ) 118 | self.explanation = FeatureExplanation(self.shap_values) 119 | return self.explanation 120 | 121 | 122 | class ShapGradientExplainer(ShapExplainerBase): 123 | """ 124 | This is a light-weight wrapper for the Gradient explainer from "SHAP", which is 125 | available from . 126 | Additional functionality from the source class is accessible via the 'explainer' 127 | object. 128 | """ 129 | 130 | def __init__(self, model, X_explain, y_explain, *argv, **kwargs) -> None: 131 | self.explain_inputs = X_explain 132 | self.explainer = shap.GradientExplainer( 133 | model, self.explain_inputs, *argv, **kwargs 134 | ) 135 | 136 | @staticmethod 137 | def name() -> str: 138 | return "shap_gradient_explainer" 139 | 140 | @staticmethod 141 | def pretty_name() -> str: 142 | return "SHAP Gradient Explainer" 143 | 144 | @staticmethod 145 | def type() -> str: 146 | return "explainer" 147 | 148 | def explain(self, *argv, **kwargs) -> pd.DataFrame: 149 | """ 150 | The function to get the explanation data from the explainer 151 | """ 152 | self.shap_values = self.explainer.shap_values( 153 | self.explain_inputs, *argv, **kwargs 154 | ) 155 | self.explanation = FeatureExplanation(self.shap_values) 156 | return self.explanation 157 | 158 | 159 | class ShapDeepExplainer(ShapExplainerBase): 160 | """ 161 | This is a light-weight wrapper for the Deep explainer from "SHAP", which is 162 | available from . 163 | Additional functionality from the source class is accessible via the 'explainer' 164 | object. 165 | """ 166 | 167 | def __init__(self, model, X_explain, y_explain, *argv, **kwargs) -> None: 168 | explain_data = data.TabularDataset(X_explain, y_explain) 169 | explain_loader = DataLoader( 170 | explain_data, batch_size=len(y_explain), shuffle=True 171 | ) 172 | self.explain_inputs, explain_targets = next(iter(explain_loader)) 173 | self.explainer = shap.DeepExplainer(model, self.explain_inputs, *argv, **kwargs) 174 | 175 | @staticmethod 176 | def name() -> str: 177 | return "shap_deep_explainer" 178 | 179 | @staticmethod 180 | def pretty_name() -> str: 181 | return "SHAP Deep Explainer" 182 | 183 | @staticmethod 184 | def type() -> str: 185 | return "explainer" 186 | 187 | def explain(self, *argv, **kwargs) -> pd.DataFrame: 188 | """ 189 | The function to get the explanation data from the explainer 190 | """ 191 | self.shap_values = self.explainer.shap_values( 192 | self.explain_inputs, *argv, **kwargs 193 | ) 194 | self.explanation = FeatureExplanation(self.shap_values) 195 | return self.explanation 196 | 197 | 198 | class ShapTreeExplainer(ShapExplainerBase): 199 | """ 200 | This is a light-weight wrapper for the Tree explainer from "SHAP", which is 201 | available from . 202 | Additional functionality from the source class is accessible via the 'explainer' 203 | object. 
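    A minimal usage sketch (illustrative only; here `model` is a fitted
    tree-based estimator, e.g. an xgboost or scikit-learn tree model):

        explainer = ShapTreeExplainer(model, X_explain)
        explanation = explainer.explain()
        explainer.summary_plot()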
204 | """ 205 | 206 | def __init__(self, model, X_explain, *argv, **kwargs) -> None: 207 | self.explain_inputs = X_explain 208 | self.explainer = shap.TreeExplainer(model, X_explain, *argv, **kwargs) 209 | 210 | @staticmethod 211 | def name() -> str: 212 | return "shap_tree_explainer" 213 | 214 | @staticmethod 215 | def pretty_name() -> str: 216 | return "SHAP Tree Explainer" 217 | 218 | @staticmethod 219 | def type() -> str: 220 | return "explainer" 221 | 222 | def explain(self, *argv, **kwargs) -> pd.DataFrame: 223 | """ 224 | The function to get the explanation data from the explainer 225 | """ 226 | self.shap_values = self.explainer.shap_values( 227 | self.explain_inputs, *argv, **kwargs 228 | ) 229 | self.explanation = FeatureExplanation(self.shap_values) 230 | return self.explanation 231 | 232 | 233 | class ShapLinearExplainer(ShapExplainerBase): 234 | """ 235 | This is a light-weight wrapper for the linear explainer from "SHAP", which is 236 | available from . 237 | Additional functionality from the source class is accessible via the 'explainer' 238 | object. 239 | """ 240 | 241 | def __init__(self, model, X_explain, *argv, **kwargs) -> None: 242 | self.explain_inputs = X_explain 243 | self.explainer = shap.LinearExplainer(model, X_explain, *argv, **kwargs) 244 | 245 | @staticmethod 246 | def name() -> str: 247 | return "shap_linear_explainer" 248 | 249 | @staticmethod 250 | def pretty_name() -> str: 251 | return "SHAP Linear Explainer" 252 | 253 | @staticmethod 254 | def type() -> str: 255 | return "explainer" 256 | 257 | def explain(self, X_explain=None, *argv, **kwargs) -> pd.DataFrame: 258 | """ 259 | The function to get the explanation data from the explainer 260 | """ 261 | 262 | self.shap_values = self.explainer.shap_values( 263 | self.explain_inputs, *argv, **kwargs 264 | ) 265 | self.explanation = FeatureExplanation(self.shap_values) 266 | return self.explanation 267 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is execfile()d with the current directory set to its containing dir. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import os 11 | import sys 12 | import shutil 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | __location__ = os.path.dirname(__file__) 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | sys.path.insert(0, os.path.join(__location__, "../src")) 22 | 23 | # -- Run sphinx-apidoc ------------------------------------------------------- 24 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 25 | # `sphinx-build -b html . _build/html`. See Issue: 26 | # https://github.com/readthedocs/readthedocs.org/issues/1139 27 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 28 | # setup.py install" in the RTD Advanced Settings. 
29 | # Additionally it helps us to avoid running apidoc manually 30 | 31 | try: # for Sphinx >= 1.7 32 | from sphinx.ext import apidoc 33 | except ImportError: 34 | from sphinx import apidoc 35 | 36 | output_dir = os.path.join(__location__, "api") 37 | module_dir = os.path.join(__location__, "../src/interpretability") 38 | try: 39 | shutil.rmtree(output_dir) 40 | except FileNotFoundError: 41 | pass 42 | 43 | try: 44 | import sphinx 45 | 46 | cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}" 47 | 48 | args = cmd_line.split(" ") 49 | if tuple(sphinx.__version__.split(".")) >= ("1", "7"): 50 | # This is a rudimentary parse_version to avoid external dependencies 51 | args = args[1:] 52 | 53 | apidoc.main(args) 54 | except Exception as e: 55 | print("Running `sphinx-apidoc` failed!\n{}".format(e)) 56 | 57 | # -- General configuration --------------------------------------------------- 58 | 59 | # If your documentation needs a minimal Sphinx version, state it here. 60 | # needs_sphinx = '1.0' 61 | 62 | # Add any Sphinx extension module names here, as strings. They can be extensions 63 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 64 | extensions = [ 65 | "sphinx.ext.autodoc", 66 | "sphinx.ext.intersphinx", 67 | "sphinx.ext.todo", 68 | "sphinx.ext.autosummary", 69 | "sphinx.ext.viewcode", 70 | "sphinx.ext.coverage", 71 | "sphinx.ext.doctest", 72 | "sphinx.ext.ifconfig", 73 | "sphinx.ext.mathjax", 74 | "sphinx.ext.napoleon", 75 | ] 76 | 77 | # Add any paths that contain templates here, relative to this directory. 78 | templates_path = ["_templates"] 79 | 80 | # The suffix of source filenames. 81 | source_suffix = ".rst" 82 | 83 | # The encoding of source files. 84 | # source_encoding = 'utf-8-sig' 85 | 86 | # The master toctree document. 87 | master_doc = "index" 88 | 89 | # General information about the project. 90 | project = "interpretability" 91 | copyright = "2022, robsdavis" 92 | 93 | # The version info for the project you're documenting, acts as replacement for 94 | # |version| and |release|, also used in various other places throughout the 95 | # built documents. 96 | # 97 | # version: The short X.Y version. 98 | # release: The full version, including alpha/beta/rc tags. 99 | # If you don’t need the separation provided between version and release, 100 | # just set them both to the same value. 101 | try: 102 | from interpretability import __version__ as version 103 | except ImportError: 104 | version = "" 105 | 106 | if not version or version.lower() == "unknown": 107 | version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD 108 | 109 | release = version 110 | 111 | # The language for content autogenerated by Sphinx. Refer to documentation 112 | # for a list of supported languages. 113 | # language = None 114 | 115 | # There are two options for replacing |today|: either, you set today to some 116 | # non-false value, then it is used: 117 | # today = '' 118 | # Else, today_fmt is used as the format for a strftime call. 119 | # today_fmt = '%B %d, %Y' 120 | 121 | # List of patterns, relative to source directory, that match files and 122 | # directories to ignore when looking for source files. 123 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] 124 | 125 | # The reST default role (used for this markup: `text`) to use for all documents. 126 | # default_role = None 127 | 128 | # If true, '()' will be appended to :func: etc. cross-reference text. 
129 | # add_function_parentheses = True 130 | 131 | # If true, the current module name will be prepended to all description 132 | # unit titles (such as .. function::). 133 | # add_module_names = True 134 | 135 | # If true, sectionauthor and moduleauthor directives will be shown in the 136 | # output. They are ignored by default. 137 | # show_authors = False 138 | 139 | # The name of the Pygments (syntax highlighting) style to use. 140 | pygments_style = "sphinx" 141 | 142 | # A list of ignored prefixes for module index sorting. 143 | # modindex_common_prefix = [] 144 | 145 | # If true, keep warnings as "system message" paragraphs in the built documents. 146 | # keep_warnings = False 147 | 148 | # If this is True, todo emits a warning for each TODO entries. The default is False. 149 | todo_emit_warnings = True 150 | 151 | 152 | # -- Options for HTML output ------------------------------------------------- 153 | 154 | # The theme to use for HTML and HTML Help pages. See the documentation for 155 | # a list of builtin themes. 156 | html_theme = "alabaster" 157 | 158 | # Theme options are theme-specific and customize the look and feel of a theme 159 | # further. For a list of options available for each theme, see the 160 | # documentation. 161 | html_theme_options = { 162 | "sidebar_width": "300px", 163 | "page_width": "1200px" 164 | } 165 | 166 | # Add any paths that contain custom themes here, relative to this directory. 167 | # html_theme_path = [] 168 | 169 | # The name for this set of Sphinx documents. If None, it defaults to 170 | # " v documentation". 171 | # html_title = None 172 | 173 | # A shorter title for the navigation bar. Default is the same as html_title. 174 | # html_short_title = None 175 | 176 | # The name of an image file (relative to this directory) to place at the top 177 | # of the sidebar. 178 | # html_logo = "" 179 | 180 | # The name of an image file (within the static path) to use as favicon of the 181 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 182 | # pixels large. 183 | # html_favicon = None 184 | 185 | # Add any paths that contain custom static files (such as style sheets) here, 186 | # relative to this directory. They are copied after the builtin static files, 187 | # so a file named "default.css" will overwrite the builtin "default.css". 188 | html_static_path = ["_static"] 189 | 190 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 191 | # using the given strftime format. 192 | # html_last_updated_fmt = '%b %d, %Y' 193 | 194 | # If true, SmartyPants will be used to convert quotes and dashes to 195 | # typographically correct entities. 196 | # html_use_smartypants = True 197 | 198 | # Custom sidebar templates, maps document names to template names. 199 | # html_sidebars = {} 200 | 201 | # Additional templates that should be rendered to pages, maps page names to 202 | # template names. 203 | # html_additional_pages = {} 204 | 205 | # If false, no module index is generated. 206 | # html_domain_indices = True 207 | 208 | # If false, no index is generated. 209 | # html_use_index = True 210 | 211 | # If true, the index is split into individual pages for each letter. 212 | # html_split_index = False 213 | 214 | # If true, links to the reST sources are added to the pages. 215 | # html_show_sourcelink = True 216 | 217 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 218 | # html_show_sphinx = True 219 | 220 | # If true, "(C) Copyright ..." is shown in the HTML footer. 
Default is True. 221 | # html_show_copyright = True 222 | 223 | # If true, an OpenSearch description file will be output, and all pages will 224 | # contain a tag referring to it. The value of this option must be the 225 | # base URL from which the finished HTML is served. 226 | # html_use_opensearch = '' 227 | 228 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 229 | # html_file_suffix = None 230 | 231 | # Output file base name for HTML help builder. 232 | htmlhelp_basename = "interpretability-doc" 233 | 234 | 235 | # -- Options for LaTeX output ------------------------------------------------ 236 | 237 | latex_elements = { 238 | # The paper size ("letterpaper" or "a4paper"). 239 | # "papersize": "letterpaper", 240 | # The font size ("10pt", "11pt" or "12pt"). 241 | # "pointsize": "10pt", 242 | # Additional stuff for the LaTeX preamble. 243 | # "preamble": "", 244 | } 245 | 246 | # Grouping the document tree into LaTeX files. List of tuples 247 | # (source start file, target name, title, author, documentclass [howto/manual]). 248 | latex_documents = [ 249 | ("index", "user_guide.tex", "interpretability Documentation", "robsdavis", "manual") 250 | ] 251 | 252 | # The name of an image file (relative to this directory) to place at the top of 253 | # the title page. 254 | # latex_logo = "" 255 | 256 | # For "manual" documents, if this is true, then toplevel headings are parts, 257 | # not chapters. 258 | # latex_use_parts = False 259 | 260 | # If true, show page references after internal links. 261 | # latex_show_pagerefs = False 262 | 263 | # If true, show URL addresses after external links. 264 | # latex_show_urls = False 265 | 266 | # Documents to append as an appendix to all manuals. 267 | # latex_appendices = [] 268 | 269 | # If false, no module index is generated. 270 | # latex_domain_indices = True 271 | 272 | # -- External mapping -------------------------------------------------------- 273 | python_version = ".".join(map(str, sys.version_info[0:2])) 274 | intersphinx_mapping = { 275 | "sphinx": ("https://www.sphinx-doc.org/en/master", None), 276 | "python": ("https://docs.python.org/" + python_version, None), 277 | "matplotlib": ("https://matplotlib.org", None), 278 | "numpy": ("https://numpy.org/doc/stable", None), 279 | "sklearn": ("https://scikit-learn.org/stable", None), 280 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 281 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 282 | "setuptools": ("https://setuptools.pypa.io/en/stable/", None), 283 | "pyscaffold": ("https://pyscaffold.org/en/stable", None), 284 | } 285 | 286 | print(f"loading configurations for {project} {version} ...", file=sys.stderr) 287 | -------------------------------------------------------------------------------- /Notebooks/Tutorial_05_implement_symbolic_pursuit.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","metadata":{},"source":["# Tutorial 5 - Symbolic Pursuit\n","\n","In this tutorial we create a symbolic pursuit explainer object and use it to get an explanation of the predictive model.
The explainer is then saved to disk and can be given to someone else to view in the [Interpretability Suite App](https://vanderschaarlab-demo-interpretabi-interpretability-suite-1uteyn.streamlit.app/).\n","\n","We will be explaining the predictions of a scikit-learn regression model (a linear regression in the code below) trained on the diabetes dataset, also from scikit-learn.\n","\n","### Import the relevant modules"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["# IMPORTS\n","\n","# Third Party\n","from sklearn.datasets import load_diabetes\n","from sklearn.model_selection import train_test_split\n","from sklearn.neural_network import MLPRegressor\n","from sklearn.ensemble import RandomForestRegressor\n","from sklearn.linear_model import LinearRegression\n","from sklearn.metrics import mean_squared_error\n","\n","# Interpretability\n","from interpretability.interpretability_models import symbolic_pursuit_explainer\n","from interpretability.interpretability_models.utils import io"]},{"cell_type":"markdown","metadata":{},"source":["### Load the data \n","Load the data and split it into separate sets for training the predictive model, fitting the explainer, and testing the explainer."]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[],"source":["\n","X, y = load_diabetes(return_X_y=True, as_frame=True)\n","\n","feature_names = X.columns.to_list()\n","X_mlp_train, X_test, y_mlp_train, y_test = train_test_split(X, y, test_size=0.5)\n","X_expl_train, X_explain_test, y_expl_train, y_explain_test = train_test_split(X_test, y_test, test_size=0.2)\n","\n","X_mlp_train, y_mlp_train, X_expl_train, y_expl_train, X_explain_test, y_explain_test = (\n"," X_mlp_train.to_numpy(),\n"," y_mlp_train.to_numpy(),\n"," X_expl_train.to_numpy(),\n"," y_expl_train.to_numpy(),\n"," X_explain_test.to_numpy(),\n"," y_explain_test.to_numpy(),\n",")"]},{"cell_type":"markdown","metadata":{},"source":["### Train the model\n","\n","We will train our own model using the scikit-learn library here. 
We simply have to initialize it and fit it, before it is ready to pass to the SymbolicPursuitExplainer in the next step."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAOwAAAAQCAYAAAAVg5N2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAABJ0AAASdAHeZh94AAAIIUlEQVR4nO2bfcyXVRnHPygODY0avrBe5guTQtLeFdKQJw0tyKKytoaJm5BTByqgRrOLr82ElghqTVL3oOU/li8TxScUWYSSbfowbZg6EFCX1CAcimTA0x/XufHm/t3373ef8/u1cnu+22/nuc+5rnOu6zrXebvOeQb09fXRj370472BgcUMSR8BrgXOBoYCfwMeAGRm/6xbsaT5wOeAEcDhwNvAplDXLWa2NUc7BehuUeVeMzuw0/JKmgz8OnxONbPbO0lflyfGXiW8ZwCXAmOADwJbgeeARWa2LNAMBSYBE4ATgQ8D7wS6bqDbzPZW1D8BmAGcwLs2fhpYYGZrKnii+yWGR9JG4OgKk2wxs2EdaGMKCX4Za68YXQ4oMA4PFV8A/Bm4EdgQGl8TOr0uLgcGA48Ci4C7gd3AXOBZSR/N0a4FVPF7PNA8UmygXXmDDLcAb9ZRKJY+kifGXvn6fwY8hg/2B4EbgIeBI4BxOdJzgduAU4CngIXAvcAngNuBeyQNKKl/PvAQ8BmgJ8j2DPB14IkwGRV5ovslsS/foNxnfl5Cm9LG2or6K/0yxV4xuhRX2F8CRwLTzezmnBALcIe6DrioosEi3m9mu4qZkq4D5gA/BC4GMLO1uHEaICmbkX5VUpwsb3DObnw1ug+Y1UyZWPoEntr2ypVNBWYDdwLTzOydQvlBuc8XgXOAh/MrqaQ5uPN+C/gmPoizsmFB5i3ASWb291xZF+601wK/KYid0i8pPNvNbC71EdVGrF+2Ya/auuxbYcPsMx7YCPyiQGfAW8B5kga3qhSgzPkC7gnp8a3qkHQiMBp4DV818mXtyjsd+BI+277VSpYE+iieWHtJGoQ72GZKBmuo89+5vx83s6XFba+ZvQ7cGj7HFao4GveRp/LOF/hWAjvwlTwvV3S/dNr3ytDJNpr4ZbS9YpHfEneFdHlJp+4AngDeFwRtB18L6bM1aKeF9A4z21MoS5ZX0khgHn7GW9VKiFj6VJ4KVNnry3jn3wfslTRB0lWSZkgaE9lGNrB3F/Jfws+5J0s6PF8gaSxwGL4dzyOlX1L7cpCkyZLmBL27JDXEOdpsowxVfpliryhd8lvij4X0xYoKX8JnqBHAiqbq7C/oLOBQYAh+zjoNd755LfgOASYDe/AzVhFJ8koaiAeANuNbzVbyR9Gn8uR469rr8yHdBfTiZ9F8PauAb5vZP2rI+v3w2ZMvM7Ntkq4CFgDrJD2Ab++H49vrR4EfFKpM6ZdU3xvGu8G8DC9LusDM/tABuRrQzC8T7RWlS36FHRLSNyoqzPI/UKVMBWbhW47LcOfrAca3ciTgO6GtHjN7paQ8Vd4fA58GppjZ2y1kSKFP5clQ115HhnQ20Ad8EZ/BTwKWA2OB39Zobx4+2JeZ2e+LhWa2ED/bDgSmAlfjAaxXgCXFrR9p/ZLC0w2cgTv6YDzyvRg4BnhE0ic7IFcZmvplgr2idGm41uk0spC0pKOAL+AO0itpopk904Q123Ys7pQskk7BV7wbqq4j2qFP5ckjwl7ZZLsbOMfMNobv5yRNAl4ATpc0psnVy3RgJvBX4LwKmiuBnwI34dHu14GPA9cDd0v6lJldGatnuzAzFbL+Alwk6U1cp7n4NVan0dQvU+wVo0t+hc1mmCGUI8vf3lSdCpjZFjO7H992DAXuqqKVNAp31leBZRVkUfKGrd9d+JbomlbyxtKn8lShhr22h7Q3N1gz3p1AtlqeXCHrpfiVwzqgy8y2ldCMA+YDD5rZFWa2wcx2holjEh50mSnpuBxbih910veyANrYQn7bbbTyy0R7NUODLvkB+0JIR1QwZ1HKqjNALZjZJtxJRhUP5jk0CzZliJX30EA7EtglqS/74VtQgNtC3sIE+pQ2WqKJvTL9t1ewZo8ADikWSLoMuBmfybtCpLgME0O6skSunfh10AH49r8oV4wfddL3sqNDMdrbiTZa+WWKvZqhQZf8gM0aGS+p+KDiMOBUYCfwp5qNNcOHQtqgtKSD8e3ZHuCOJnXEyvuvUF/ZrzfQrA7faxLoU9qoizJ7rcDPricU9Q/IglAv5zNDUORG/H6xq+JMlWFQSKuuIrL8/JVSih910veyKO+GQn5bbdT0yxR7NUODLvvOsGa2XtJyfAt2CT4D75MXH+WLzWzffWK42zoIWJ+/85M0An9Std8BPxjqJ3jA5MmKJ2rn4s/rHqoINiXJG4I/F5bVJWkuPuvdWXg2GEWf2EaSvcxsk6SlePRxBj4IM77xwFn46tuTy78Gv7h/Gg9kNWyDC/gj/uRxmqTFZvZarq6v4E6+C3gyJ1e0H8XyhCuzzfk6Qv4x+LkRCo8TUuQqoI5fRtsrVpdi0OniUNlN8vepz+NP2brwrcKPCvQr8MviY/EL6QxfBa6XtBqf4bcCRwGnA8fhB/GpFUpn246yl01FxMr7/4pUe12CTwIL5O9Xe/G++Aa+ElyYTQKSzscH6x7csaZLxVgHG81sSe77d/i94ZnA85LuD7KMxLd/A4CrrfGdc0q/xPB8Fz8LrsLfW+/Ar04mAAfj58uy54nt+Esdv0yxV5Qu+20NzGw9fve3JCgyMzAvAkaXdEwVHsO3DUfgIe7Z+NO3bfhsNsrM1hWZwmxzGs2DTf8Nef/XSLKXmb0KfBafiY/HV9pxwFLgVDO7N0d+bEgPxK+MrOQ3pVD/XnwyuRw/R0/CbTwa75+zzGxRiVzR/RLJsxJ/rzsc+B5wBT65rQbOByZWvPxK8pe6fploryhdBvT/e10/+vHewX8AzB5JXenZEZEAAAAASUVORK5CYII=","text/latex":["$\\displaystyle 0.37044436280050785$"],"text/plain":["0.37044436280050785"]},"execution_count":16,"metadata":{},"output_type":"execute_result"}],"source":["## Load the model\n","\n","model = LinearRegression()\n","model.fit(X_mlp_train, y_mlp_train)\n","model.score(X_explain_test, y_explain_test)"]},{"cell_type":"markdown","metadata":{},"source":["### 
Initialize and fit Symbolic Pursuit Explainer\n","Initialize the explainer object by passing the model's predict function and the data to explain. The fit step can take considerable time; this can be reduced by lowering the `patience` argument. The default value for `patience` is 10, and reducing it significantly below that value may affect the quality of the explanation."]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[],"source":["\n","my_explainer = symbolic_pursuit_explainer.SymbolicPursuitExplainer(\n"," model.predict, X_expl_train, feature_names=feature_names, patience=1\n",")\n","my_explainer.fit()"]},{"cell_type":"markdown","metadata":{},"source":["### Measure fit quality\n","This prints the Mean Squared Error for both the predictive model and the learned symbolic model."]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["MSE score for the model: 3400.9267047529574\n","MSE score for the Symbolic Regressor: 3913.308267786906\n"]}],"source":["my_explainer.measure_fit_quality(X_explain_test, y_explain_test)"]},{"cell_type":"markdown","metadata":{},"source":["### Get the explanation\n","Get the explanation of the predictive model in terms of its symbolic expression and projections."]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["0\n"]}],"source":["explanation = my_explainer.explain(X_explain_test[3], taylor_expansion_order=2)"]},{"cell_type":"markdown","metadata":{},"source":["### Show the symbolic expression and projections"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["1.28673528875237*meijerg(((-0.0238334588166228,), (0.677142219058804,)), ((-1.89775855719856, -0.76336887618189), (-0.434727427351685,)), 1.0*[ReLU(P1)])\n","P1 = -0.261979271495097*X1 - 0.568545152086398*X10 + 0.884542293937177*X2 - 1.89124887830286*X3 - 0.840798636983763*X4 - 0.0968543199716866*X5 + 0.700029016346244*X6 + 1.11613582817466*X7 + 0.113530407585627*X8 - 2.18748153588367*X9\n","\n"]},{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
agesexbmibps1s2s3s4s5s6
00.00.00.00.00.00.00.00.00.00.0
\n","
"],"text/plain":[" age sex bmi bp s1 s2 s3 s4 s5 s6\n","0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["The taylor expansion that calculates feature interactions is not available. Try fitting the SymbolicPursuitExplainer for more iterations, by increasing `patience` or reducing `loss_tol`.\n"]}],"source":["my_explainer.summary_plot(show_expression=False, show_feature_importance=True, show_feature_interactions=True)"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["[1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05, 1e-05]\n","[201.90674082 201.90674082 201.90674082 201.90674082 201.90674082\n"," 201.90674082 201.90674082 201.90674082 201.90674082 201.90674082]\n"]}],"source":["print(explanation.feature_importance)\n","print(my_explainer.symbolic_model.predict(X_explain_test[1]))"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["0\n"]}],"source":["print(explanation.taylor_expansion)"]},{"cell_type":"markdown","metadata":{},"source":["### Save the explainer to file\n","This file can now be uploaded to the [Interpretability Suite App](https://vanderschaarlab-demo-interpretabi-interpretability-suite-1uteyn.streamlit.app/). This provides a non-programtic interface with which to view the various explanations, allowing you to send the explainer to a colleague who is less fluent in python."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Saving explainer to: /home/rob/Documents/projects/Interpretability/Notebooks/my_new_diabetes_sklearn_random_forrest_symbolic_pursuit_explainer.p\n"]}],"source":["\n","io.save_explainer(\n"," my_explainer, \"my_new_diabetes_sklearn_random_forrest_symbolic_pursuit_explainer.p\"\n",")"]}],"metadata":{"kernelspec":{"display_name":"Python 3.8.13 ('interp')","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"6fd73f071793638ac14baf0ff0f19e5ab81431475f40d47f0df0002312a62017"}}},"nbformat":4,"nbformat_minor":2} 2 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/dynamask_explainer.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | import sys 3 | import copy 4 | from typing import Any, List, Optional, Union 5 | from abc import abstractmethod 6 | 7 | # third party 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | import numpy as np 11 | import pandas as pd 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | 16 | # interpretability relative 17 | from .utils import data 18 | from .base import Explainer, FeatureExplanation 19 | 20 | # Interpretability absolute 21 | from interpretability.utils.pip import install 22 | from interpretability.exceptions import exceptions 23 | 24 | # dynamask 25 | for retry in range(2): 26 | try: 27 | # third party 28 | import dynamask 29 | 30 | break 31 | except ImportError: 32 | depends = ["dynamask"] 33 | install(depends) 34 | 35 | from dynamask.attribution import mask, mask_group, perturbation 36 | from dynamask.utils import losses 37 
| 38 | 39 | class DynamaskExplainer(Explainer): 40 | def __init__( 41 | self, 42 | model: Any, 43 | perturbation_method: str = "gaussian_blur", 44 | group: bool = False, 45 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 46 | ) -> None: 47 | """ 48 | Initialises the mask. 49 | 50 | Args: 51 | model (Any): The model to explain. Must be a trained pytorch model. 52 | perturbation_method (str): The method used to create and apply perturbations to inputs based on masks. Defaults to "gaussian_blur". 53 | group (bool): Boolean value to select whether or not to use a MaskGroup. MaskGroups allow fitting several masks of different areas simultaneously. Defaults to False. 54 | device (str): The device to send torch.tensors to. Defaults to "cuda" if torch.cuda.is_available() else "cpu". 55 | """ 56 | self.DEVICE = device 57 | model = model.train() 58 | model = model.to(self.DEVICE) 59 | 60 | def f(x): 61 | x = x.unsqueeze(0) 62 | out = model(x).float() 63 | return out 64 | 65 | self.model = f 66 | available_perturbation_methods = { 67 | "fade_moving_average": perturbation.FadeMovingAverage, 68 | "gaussian_blur": perturbation.GaussianBlur, 69 | "fade_moving_average_window": perturbation.FadeMovingAverageWindow, 70 | "fade_moving_average_past_window": perturbation.FadeMovingAveragePastWindow, 71 | "fade_reference": perturbation.FadeReference, 72 | } 73 | self.perturbation_method = available_perturbation_methods[perturbation_method] 74 | self.perturbation = None 75 | self.perturbation_baseline = None 76 | self.group = group 77 | self.mask_class = mask_group.MaskGroup if group else mask.Mask 78 | self.mask = None 79 | self.loss_function = None 80 | self.all_data = None 81 | self.explain_data = None 82 | self.explain_target = None 83 | super().__init__() 84 | 85 | def fit( 86 | self, 87 | explain_id: int, 88 | X: Optional[np.array] = None, 89 | loss_function: str = "mse", 90 | target: Optional[np.array] = None, 91 | baseline: Optional[torch.tensor] = None, 92 | area_list: Union[np.array, List] = np.arange(0.001, 0.051, 0.001), 93 | ): 94 | """ 95 | Trains the mask. 96 | 97 | Args: 98 | explain_id (int): The index in X of the record to explain. 99 | X (np.array, optional): The data to be explained. 100 | loss_function (str): The name of the loss function to use, e.g. "cross_entropy", "log_loss", "log_loss_target", or "mse". Defaults to "mse". 101 | target (np.array, optional): The target for the data being explained. Defaults to None. If none provided targets are generated from the blackbox model. 102 | baseline (torch.tensor, optional): A baseline for the perturbation method. Only required for fade_reference. Defaults to None. 103 | area_list (Union[np.array, List]): List of areas for the group mask. Defaults to np.arange(0.001, 0.051, 0.001). 104 | 105 | Returns: 106 | None
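
        Example (a minimal sketch; ``net`` is a hypothetical trained pytorch
        time-series model, ``X`` a numpy array of shape
        (n_samples, time_steps, n_features) and ``y`` the matching targets)::

            explainer = DynamaskExplainer(net, perturbation_method="gaussian_blur")
            explainer.fit(explain_id=0, X=X, loss_function="mse", target=y[0])
            explanation = explainer.explain()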
107 | """ 108 | if X is not None: 109 | self.all_data = X 110 | else: 111 | if self.all_data is None: 112 | raise exceptions.NoDataToExplain 113 | if target is not None: 114 | self.target = torch.tensor(target).to(self.DEVICE).detach() 115 | self.explain_data = ( 116 | torch.tensor(self.all_data[explain_id]).float().to(self.DEVICE) 117 | ) 118 | 119 | available_loss_functions = { 120 | "cross_entropy": losses.cross_entropy, 121 | "log_loss": losses.log_loss, 122 | "log_loss_target": losses.log_loss_target, 123 | "mse": losses.mse, 124 | } 125 | self.loss_function = available_loss_functions[loss_function] 126 | # Fit a mask to the input with a Gaussian Blur perturbation: 127 | if self.perturbation_method == perturbation.FadeReference: 128 | if baseline: 129 | self.perturbation_baseline = baseline 130 | else: 131 | self.perturbation_baseline = torch.zeros(size=self.explain_data.shape) 132 | self.perturbation = self.perturbation_method( 133 | self.DEVICE, self.perturbation_baseline 134 | ) 135 | else: 136 | self.perturbation = self.perturbation_method(self.DEVICE) 137 | self.mask = self.mask_class(self.perturbation, self.DEVICE) 138 | 139 | print("Fitting Dynamask") 140 | if self.group: 141 | self.mask.fit( 142 | self.explain_data, 143 | self.model, 144 | area_list, 145 | loss_function=self.loss_function, 146 | n_epoch=1000, 147 | initial_mask_coeff=0.5, 148 | size_reg_factor_init=0.1, 149 | size_reg_factor_dilation=100, 150 | learning_rate=0.1, 151 | momentum=0.9, 152 | time_reg_factor=0, 153 | ) 154 | else: 155 | self.mask.fit( 156 | self.explain_data, 157 | self.model, 158 | loss_function=self.loss_function, 159 | target=self.target, 160 | n_epoch=500, 161 | keep_ratio=0.5, 162 | initial_mask_coeff=0.5, 163 | size_reg_factor_init=0.5, 164 | size_reg_factor_dilation=100, 165 | time_reg_factor=0, 166 | learning_rate=1.0e-1, 167 | momentum=0.9, 168 | ) 169 | self.has_been_fit = True 170 | 171 | def refit(self, explain_id: int): 172 | """A Helper function to fit the model again with the same parameters but for a different data record. 173 | 174 | Args: 175 | explain_id (int): The id of the record to get the explanation for by refitting 176 | """ 177 | print("Re-fitting dynamask") 178 | self.fit(explain_id) 179 | 180 | def explain( 181 | self, 182 | ids_time: Union[List, np.array] = None, 183 | ids_feature: Union[List, np.array] = None, 184 | smooth: bool = False, 185 | sigma: float = 1.0, 186 | get_mask_from_group_method: str = "best", 187 | extremal_mask_threshold: float = 0.01, 188 | ) -> FeatureExplanation: 189 | """ 190 | Get the explanation from the trained mask. 191 | 192 | Args: 193 | ids_time (Union[list, np.array], optional): A list of time steps to focus to explanation on. Defaults to None leading to all time steps being included in the explanation. 194 | ids_feature (Union[list, np.array], optional): A list of features to focus to explanation on. Defaults to None leading to all features being included in the explanation. 195 | smooth (bool, optional): A boolean value to state weather or not to smooth the mask (i.e. interpolate between extreme values to provide a smooth transition in the time dimention). Defaults to False. 196 | sigma (float, optional): Width of the smoothing Gaussian kernel.. Defaults to 1.0. 197 | get_mask_from_group_method (str, optional): Can take values of "best" or "extremal". "best" returns the mask with lowest error. "extremal" returns the extremal mask for the acceptable error threshold. Defaults to "best". 
198 | extremal_mask_threshold (float, optional): The acceptable error threshold for extremal masks. Defaults to 0.01. 199 | 200 | Returns: 201 | FeatureExplanation: A simple feature importance pd.DataFrame where columns refer to the time steps and rows refer to the features. 202 | """ 203 | if self.has_been_fit: 204 | if self.group: 205 | # Select the mask lazily so that only the requested method is evaluated 206 | if get_mask_from_group_method == "best": 207 | selected_mask = self.mask.get_best_mask() 208 | elif get_mask_from_group_method == "extremal": 209 | selected_mask = self.mask.get_extremal_mask(extremal_mask_threshold) 210 | else: 211 | raise ValueError(f"Unknown get_mask_from_group_method: {get_mask_from_group_method}") 212 | submask_tensor_np = selected_mask.mask_tensor.numpy() 213 | df = pd.DataFrame( 214 | data=np.transpose(submask_tensor_np), 215 | ) 216 | else: 217 | if smooth: 218 | mask_tensor = self.mask.get_smooth_mask(sigma) 219 | else: 220 | mask_tensor = self.mask.mask_tensor 221 | # Extract submask from ids 222 | submask_tensor_np = self.mask.extract_submask( 223 | mask_tensor, ids_time, ids_feature 224 | ).numpy() 225 | df = pd.DataFrame( 226 | data=np.transpose(submask_tensor_np), 227 | index=ids_feature, 228 | columns=ids_time, 229 | ) 230 | self.explanation = FeatureExplanation(df) 231 | return self.explanation 232 | else: 233 | raise exceptions.ExplainCalledBeforeFit(self.has_been_fit) 234 | 235 | def summary_plot( 236 | self, 237 | explanation: List = None, 238 | show: bool = True, 239 | save_path: str = "temp_dynamask_plot.png", 240 | ) -> None: 241 | """This method plots (part of) the mask. 242 | 243 | Args: 244 | explanation (List, optional): The FeatureExplanation returned by .explain(). Defaults to None, in which case it is assumed the explanation is from the result of the previous explain() call. 245 | show (bool, optional): Boolean value to decide if the plot is displayed. Defaults to True. 246 | save_path (str, optional): The path with which to save the plot if show is set to false. Defaults to "temp_dynamask_plot.png".
247 | Returns: 248 | None 249 | """ 250 | if not explanation: 251 | explanation = self.explanation 252 | sns.set() 253 | # Generate heatmap plot 254 | color_map = sns.diverging_palette(10, 133, as_cmap=True) 255 | sns.heatmap( 256 | data=explanation.feature_importances, 257 | cmap=color_map, 258 | cbar_kws={"label": "Mask"}, 259 | vmin=0, 260 | vmax=1, 261 | ) 262 | plt.xlabel("Time") 263 | plt.ylabel("Feature Number") 264 | plt.title("Mask coefficients over time") 265 | if show: 266 | plt.show() 267 | else: 268 | plt.savefig(save_path) 269 | 270 | @staticmethod 271 | def name() -> str: 272 | return "dynamask" 273 | 274 | @staticmethod 275 | def pretty_name() -> str: 276 | return "Dynamask" 277 | 278 | @staticmethod 279 | def type() -> str: 280 | return "explainer" 281 | -------------------------------------------------------------------------------- /src/interpretability/interpretability_models/symbolic_pursuit_explainer.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | import sys 3 | import os 4 | import copy 5 | from typing import Any, List, Tuple, Optional, Union 6 | from abc import abstractmethod 7 | import inspect 8 | import itertools 9 | 10 | # third party 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from torch.utils.data import DataLoader 15 | import sympy as smp # We use sympy to display mathematical expressions 16 | from sklearn.metrics import ( 17 | mean_squared_error, 18 | accuracy_score, 19 | ) # we are going to assess the quality of the SymbolicRegressor based on the MSE 20 | from PIL import Image 21 | import matplotlib.pyplot as plt 22 | import seaborn as sns 23 | 24 | # Interpretability relative 25 | from .utils import data 26 | from .base import Explainer, Explanation 27 | 28 | # Interpretability absolute 29 | from interpretability.utils.pip import install 30 | from interpretability.exceptions import exceptions 31 | 32 | # symbolic-pursuit 33 | for retry in range(2): 34 | try: 35 | # third party 36 | import symbolic_pursuit 37 | 38 | break 39 | except ImportError: 40 | depends = ["symbolic-pursuit"] 41 | install(depends) 42 | from symbolic_pursuit import models 43 | 44 | 45 | class SymbolicPursuitExplanation(Explanation): 46 | """ 47 | The explanation object for symbolic pursuit 48 | """ 49 | 50 | def __init__( 51 | self, 52 | expression, 53 | projections, 54 | x0: np.array, 55 | feature_importance: List, 56 | taylor_expansion: smp.core.add.Add, 57 | model_fit_quality: Optional[float] = None, 58 | fit_quality: Optional[float] = None, 59 | ) -> None: 60 | """Initialize the explanation object 61 | 62 | Args: 63 | expression (smp.core.add.Add): The symbolic expression of the model. 64 | projections (List): The projections in the symbolic expression. 65 | x0 (np.array): The record to evaluate the feature importance and feature interaction for. 66 | feature_importance (List): The feature importance produced by SymbolicPursuitExplainer.symbolic_model.get_feature_importance(x0). 67 | taylor_expansion (smp.core.add.Add): The Taylor expansion produced by SymbolicPursuitExplainer.symbolic_model.get_taylor(x0, order). 68 | model_fit_quality (Optional[float]): The MSE score for the predictive model based on a test dataset. Needs measure_fit_quality() to be run. Defaults to None. 69 | fit_quality (Optional[float]): The MSE score for the symbolic model based on a test dataset. Needs measure_fit_quality() to be run. Defaults to None.
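
        Example (a sketch of reading the fields; ``explanation`` is assumed to
        come from a ``SymbolicPursuitExplainer.explain()`` call)::

            print(explanation.expression)          # symbolic expression of the model
            print(explanation.feature_importance)  # per-feature importances at x0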
70 | 71 | """ 72 | self.expression = expression 73 | self.projections = projections 74 | self.x0 = x0 75 | self.feature_importance = feature_importance 76 | self.taylor_expansion = taylor_expansion 77 | self.model_fit_quality = model_fit_quality 78 | self.fit_quality = fit_quality 79 | super().__init__() 80 | print(smp.latex(smp.expand(taylor_expansion))) 81 | 82 | @staticmethod 83 | def name() -> str: 84 | return "Symbolic Pursuit Explanation" 85 | 86 | 87 | class SymbolicPursuitExplainer(Explainer): 88 | def __init__( 89 | self, model: Any, X_explain: np.array, feature_names: List = [], *argv, **kwargs 90 | ) -> None: 91 | """ 92 | SymbolicPursuitExplainer 93 | 94 | This explainer can take a very long time to fit. If fitting time is an issue there are several 95 | options you can pass to reduce it, such as increased `loss_tol` or reduced `patience`. 96 | 97 | Args: 98 | model (Any): The model to approximate. 99 | X_explain (np.array): The data used to fit the SymbolicRegressor. 100 | loss_tol: The tolerance for the loss under which the pursuit stops. Defaults to 1.0e-3, 101 | ratio_tol: A new term is added only if new_loss / old_loss < ratio_tol. Defaults to 0.9, 102 | maxiter: Maximum number of iterations for optimization. Defaults to 100, 103 | eps: The smallest representable number such that 1.0 + eps != 1.0. Defaults to 1.0e-5, 104 | random_seed (int): The random seed for reproducibility. This is passed to . Defaults to 42, 105 | baselines (List): Defaults to list(load_h().keys()), 106 | task_type (str): Either the string "classification" or "regression". Defaults to "regression", 107 | patience (int) : A hard limit on the number of optimisation loops in fit(). Defaults to 10, 108 | """ 109 | self.model = model 110 | self.X_explain = X_explain 111 | self.model_fit_quality = None 112 | self.fit_quality = None 113 | if feature_names: 114 | self.feature_names = feature_names 115 | else: 116 | self.feature_names = list(range(X_explain.shape[1])) 117 | 118 | super().__init__() 119 | 120 | smp.init_printing() 121 | self.symbolic_model = models.SymbolicRegressor(*argv, **kwargs) 122 | 123 | @staticmethod 124 | def name() -> str: 125 | return "symbolic_pursuit_explainer" 126 | 127 | @staticmethod 128 | def pretty_name() -> str: 129 | return "Symbolic Pursuit Explainer" 130 | 131 | def fit(self): 132 | """ 133 | Fit the symbolic Regressor 134 | """ 135 | # try to fit with numpy array (which works for some models e.g. 
sklearn models) 136 | for retry in range(2): 137 | try: 138 | self.symbolic_model.fit(self.model, self.X_explain) 139 | break 140 | # If that fails due to expecting a different type for X_explain, try again with X_explain as a torch tensor. 141 | # This works for pytorch models. 142 | except TypeError: 143 | self.X_explain = torch.Tensor(self.X_explain) 144 | self.has_been_fit = True 145 | 146 | def measure_fit_quality(self, X_test: np.array, y_test: np.array): 147 | """Compare the predictive model and the symbolic model on a test set, printing accuracy (classification) or MSE (regression) for both.""" 148 | if self.has_been_fit: 149 | self.X_test = X_test 150 | self.y_test = y_test 151 | if self.symbolic_model.task_type == "classification": 152 | for retry in range(2): 153 | try: 154 | self.fit_quality = accuracy_score( 155 | self.y_test, self.symbolic_model.predict(self.X_test) 156 | ) 157 | break 158 | except TypeError: 159 | self.X_test = torch.Tensor(self.X_test) 160 | self.y_test = torch.Tensor(self.y_test) 161 | for retry in range(2): 162 | try: 163 | self.model_fit_quality = accuracy_score( 164 | self.y_test, self.model(self.X_test) 165 | ) 166 | break 167 | except TypeError: 168 | self.X_test = torch.Tensor(self.X_test) 169 | self.y_test = torch.Tensor(self.y_test) 170 | 171 | print(f"Accuracy score for the model: {self.model_fit_quality}") 172 | print(f"Accuracy score for the Symbolic Regressor: {self.fit_quality}") 173 | elif self.symbolic_model.task_type == "regression": 174 | for retry in range(2): 175 | try: 176 | self.fit_quality = mean_squared_error( 177 | self.y_test, self.symbolic_model.predict(self.X_test) 178 | ) 179 | break 180 | except TypeError: 181 | self.X_test = torch.Tensor(self.X_test) 182 | self.y_test = torch.Tensor(self.y_test) 183 | for retry in range(2): 184 | try: 185 | self.model_fit_quality = mean_squared_error( 186 | self.y_test, self.model(self.X_test) 187 | ) 188 | break 189 | except TypeError: 190 | self.X_test = torch.Tensor(self.X_test) 191 | self.y_test = torch.Tensor(self.y_test) 192 | 193 | print(f"MSE score for the model: {self.model_fit_quality}") 194 | print(f"MSE score for the Symbolic Regressor: {self.fit_quality}") 195 | else: 196 | raise exceptions.MeasureFitQualityCalledBeforeFit(self.has_been_fit) 197 | 198 | def explain( 199 | self, x0: np.array = None, taylor_expansion_order: int = 2 200 | ) -> pd.DataFrame: 201 | """ 202 | The function to get the explanation data from the explainer 203 | """ 204 | if self.has_been_fit: 205 | expression = self.symbolic_model.get_expression() 206 | projections = self.symbolic_model.get_projections() 207 | feature_importance = self.symbolic_model.get_feature_importance(x0) 208 | taylor_expansion = self.symbolic_model.get_taylor( 209 | x0, taylor_expansion_order 210 | ) 211 | self.explanation = SymbolicPursuitExplanation( 212 | expression, 213 | projections, 214 | x0, 215 | feature_importance, 216 | taylor_expansion, 217 | self.model_fit_quality, 218 | self.fit_quality, 219 | ) 220 | return self.explanation 221 | else: 222 | raise exceptions.ExplainCalledBeforeFit(self.has_been_fit) 223 | 224 | def summary_plot( 225 | self, 226 | file_prefix="symbolic_pursuit", 227 | show_expression=True, 228 | show_feature_importance=True, 229 | show_feature_interactions=True, 230 | save_folder=".", 231 | ): 232 | """ 233 | Plot the latex'ed equations if LaTeX is installed. 234 | """ 235 | 236 | def create_coefficient_heatmap_from_second_order_taylor_expansion( 237 | expression, 238 | ): 239 | expression = smp.Poly(expression) 240 | symbols = list(expression.free_symbols) 241 | symbol_pairs = itertools.product(symbols, repeat=2) 242 | coeffs_dict = {} 243 | for
s_p in symbol_pairs: 244 | coeffs_dict[f"{s_p[0]}{s_p[1]}"] = expression.coeff_monomial( 245 | s_p[0] * s_p[1] 246 | ) 247 | coeffs_dict = dict(sorted(coeffs_dict.items())) 248 | coeffs_dict_reoriented = {} 249 | for i in range(len(self.feature_names)): 250 | coeffs_dict_reoriented[f"{self.feature_names[i]}"] = [ 251 | float(coeffs_dict[f"X{i}X{j}"]) 252 | for j in range(len(self.feature_names)) 253 | ] 254 | coeffs = pd.DataFrame(data=coeffs_dict_reoriented, index=self.feature_names) 255 | mask = np.triu(coeffs) 256 | # np.fill_diagonal(mask, 0) 257 | figure = sns.heatmap( 258 | coeffs, annot=True, mask=mask, fmt=".2f", annot_kws={"fontsize": 4} 259 | ).get_figure() 260 | return figure 261 | 262 | if show_expression: 263 | """Show image of latex'ed expression if possible else print expression to console. Print projections to console.""" 264 | try: 265 | save_path_stem = os.path.abspath(save_folder) 266 | save_path_stem = os.path.join(save_path_stem, file_prefix) 267 | smp.preview( 268 | self.explanation.expression, 269 | viewer="file", 270 | filename=save_path_stem + "_expression.png", 271 | dvioptions=["-D", "1200"], 272 | ) 273 | expression_img = Image.open(save_path_stem + "_expression.png") 274 | expression_img.show() 275 | except RuntimeError as e: 276 | print( 277 | "For an output that does not require latex set `show_expression=False`." 278 | ) 279 | raise e 280 | else: 281 | print(self.explanation.expression) 282 | self.symbolic_model.print_projections() 283 | 284 | if show_feature_importance: 285 | """Print a dataframe of feature importance (display if in Notebook)""" 286 | feature_importance = self.symbolic_model.get_feature_importance( 287 | self.explanation.x0 288 | ) 289 | feature_importance_zip = zip(self.feature_names, feature_importance) 290 | feature_importance_dict = { 291 | k: [round(v, 2)] for k, v in feature_importance_zip 292 | } 293 | feature_importance_df = pd.DataFrame(data=feature_importance_dict) 294 | try: 295 | display(feature_importance_df) 296 | except NameError:  # display() is only available inside IPython/Jupyter 297 | print(feature_importance_df) 298 | 299 | if show_feature_interactions: 300 | """Show a heatmap of the feature interactions""" 301 | taylor_expansion = self.symbolic_model.get_taylor(self.explanation.x0, 2) 302 | if taylor_expansion == 0: 303 | print( 304 | "The taylor expansion that calculates feature interactions is not available. Try fitting the SymbolicPursuitExplainer for more iterations, by increasing `patience` or reducing `loss_tol`." 305 | ) 306 | else: 307 | taylor_expand_expr = smp.expand(taylor_expansion) 308 | heatmap = create_coefficient_heatmap_from_second_order_taylor_expansion( 309 | taylor_expand_expr 310 | ) 311 | plt.show() 312 | 313 | def symbolic_predict( 314 | self, 315 | predict_array: np.array, 316 | ): 317 | return self.symbolic_model.predict(predict_array) 318 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. todo:: THIS IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 2 | 3 | The document assumes you are using a source repository service that promotes a 4 | contribution model similar to `GitHub's fork and pull request workflow`_. 5 | While this is true for the majority of services (like GitHub, GitLab, 6 | BitBucket), it might not be the case for private repositories (e.g., when 7 | using Gerrit).
8 | 9 | Also notice that the code examples might refer to GitHub URLs or the text 10 | might use GitHub specific terminology (e.g., *Pull Request* instead of *Merge 11 | Request*). 12 | 13 | Please make sure to check the document with these assumptions in mind 14 | and update things accordingly. 15 | 16 | .. todo:: Provide the correct links/replacements at the bottom of the document. 17 | 18 | .. todo:: You might want to have a look at `PyScaffold's contributor's guide`_, 19 | 20 | especially if your project is open source. The text should be very similar to 21 | this template, but there are a few extra contents that you might decide to 22 | also include, like mentioning labels of your issue tracker or automated 23 | releases. 24 | 25 | 26 | ============ 27 | Contributing 28 | ============ 29 | 30 | Welcome to ``interpretability`` contributor's guide. 31 | 32 | This document focuses on getting any potential contributor familiarized 33 | with the development processes, but `other kinds of contributions`_ are also 34 | appreciated. 35 | 36 | If you are new to using git_ or have never collaborated in a project previously, 37 | please have a look at `contribution-guide.org`_. Other resources are also 38 | listed in the excellent `guide created by FreeCodeCamp`_ [#contrib1]_. 39 | 40 | Please notice, all users and contributors are expected to be **open, 41 | considerate, reasonable, and respectful**. When in doubt, `Python Software 42 | Foundation's Code of Conduct`_ is a good reference in terms of behavior 43 | guidelines. 44 | 45 | 46 | Issue Reports 47 | ============= 48 | 49 | If you experience bugs or general issues with ``interpretability``, please have a look 50 | at the `issue tracker`_. If you don't see anything useful there, please feel 51 | free to file an issue report. 52 | 53 | .. tip:: 54 | Please don't forget to include the closed issues in your search. 55 | Sometimes a solution was already reported, and the problem is considered 56 | **solved**. 57 | 58 | New issue reports should include information about your programming environment 59 | (e.g., operating system, Python version) and steps to reproduce the problem. 60 | Please try also to simplify the reproduction steps to a very minimal example 61 | that still illustrates the problem you are facing. By removing other factors, 62 | you help us to identify the root cause of the issue. 63 | 64 | 65 | Documentation Improvements 66 | ========================== 67 | 68 | You can help improve ``interpretability`` docs by making them more readable and coherent, or 69 | by adding missing information and correcting mistakes. 70 | 71 | ``interpretability`` documentation uses Sphinx_ as its main documentation compiler. 72 | This means that the docs are kept in the same repository as the project code, and 73 | that any documentation update is done in the same way as a code contribution. 74 | 75 | .. todo:: Don't forget to mention which markup language you are using. 76 | 77 | e.g., reStructuredText_ or CommonMark_ with MyST_ extensions. 78 | 79 | .. todo:: If your project is hosted on GitHub, you can also mention the following tip: 80 | 81 | .. tip:: 82 | Please notice that the `GitHub web interface`_ provides a quick way of 83 | proposing changes to ``interpretability``'s files. While this mechanism can 84 | be tricky for normal code contributions, it works perfectly fine for 85 | contributing to the docs, and can be quite handy.
86 | 87 | If you are interested in trying this method out, please navigate to 88 | the ``docs`` folder in the source repository_, find the file you 89 | would like to propose changes to and click on the little pencil icon at the 90 | top, to open `GitHub's code editor`_. Once you finish editing the file, 91 | please write a message in the form at the bottom of the page describing 92 | the changes you have made and the motivations behind them, and 93 | submit your proposal. 94 | 95 | When working on documentation changes in your local machine, you can 96 | compile them using |tox|_:: 97 | 98 | tox -e docs 99 | 100 | and use Python's built-in web server for a preview in your web browser 101 | (``http://localhost:8000``):: 102 | 103 | python3 -m http.server --directory 'docs/_build/html' 104 | 105 | 106 | Code Contributions 107 | ================== 108 | 109 | .. todo:: Please include a reference or explanation about the internals of the project. 110 | 111 | An architecture description, design principles or at least a summary of the 112 | main concepts will make it easy for potential contributors to get started 113 | quickly. 114 | 115 | Submit an issue 116 | --------------- 117 | 118 | Before you work on any non-trivial code contribution it's best to first create 119 | a report in the `issue tracker`_ to start a discussion on the subject. 120 | This often provides additional considerations and avoids unnecessary work. 121 | 122 | Create an environment 123 | --------------------- 124 | 125 | Before you start coding, we recommend creating an isolated `virtual 126 | environment`_ to avoid any problems with your installed Python packages. 127 | This can easily be done via either |virtualenv|_:: 128 | 129 | virtualenv 130 | source /bin/activate 131 | 132 | or Miniconda_:: 133 | 134 | conda create -n interpretability python=3 six virtualenv pytest pytest-cov 135 | conda activate interpretability 136 | 137 | Clone the repository 138 | -------------------- 139 | 140 | #. Create a user account on |the repository service| if you do not already have one. 141 | #. Fork the project repository_: click on the *Fork* button near the top of the 142 | page. This creates a copy of the code under your account on |the repository service|. 143 | #. Clone this copy to your local disk:: 144 | 145 | git clone git@github.com:YourLogin/interpretability.git 146 | cd interpretability 147 | 148 | #. You should run:: 149 | 150 | pip install -U pip setuptools -e . 151 | 152 | to be able to import the package under development in the Python REPL. 153 | 154 | .. todo:: if you are not using pre-commit, please remove the following item: 155 | 156 | #. Install |pre-commit|_:: 157 | 158 | pip install pre-commit 159 | pre-commit install 160 | 161 | ``interpretability`` comes with a lot of hooks configured to automatically help the 162 | developer to check the code being written. 163 | 164 | Implement your changes 165 | ---------------------- 166 | 167 | #. Create a branch to hold your changes:: 168 | 169 | git checkout -b my-feature 170 | 171 | and start making changes. Never work on the main branch! 172 | 173 | #. Start your work on this branch. Don't forget to add docstrings_ to new 174 | functions, modules and classes, especially if they are part of public APIs. 175 | 176 | #. Add yourself to the list of contributors in ``AUTHORS.rst``. 177 | 178 | #. When you're done editing, do:: 179 | 180 | git add 181 | git commit 182 | 183 | to record your changes in git_. 184 | 185 | .. 
todo:: if you are not using pre-commit, please remove the following item: 186 | 187 | Please make sure to see the validation messages from |pre-commit|_ and fix 188 | any issues it reports. 189 | This should automatically use flake8_/black_ to check/fix the code style 190 | in a way that is compatible with the project. 191 | 192 | .. important:: Don't forget to add unit tests and documentation in case your 193 | contribution adds an additional feature and is not just a bugfix. 194 | 195 | Moreover, writing a `descriptive commit message`_ is highly recommended. 196 | In case of doubt, you can check the commit history with:: 197 | 198 | git log --graph --decorate --pretty=oneline --abbrev-commit --all 199 | 200 | to look for recurring communication patterns. 201 | 202 | #. Please check that your changes don't break any unit tests with:: 203 | 204 | tox 205 | 206 | (after having installed |tox|_ with ``pip install tox`` or ``pipx``). 207 | 208 | You can also use |tox|_ to run several other pre-configured tasks in the 209 | repository. Try ``tox -av`` to see a list of the available checks. 210 | 211 | Submit your contribution 212 | ------------------------ 213 | 214 | #. If everything works fine, push your local branch to |the repository service| with:: 215 | 216 | git push -u origin my-feature 217 | 218 | #. Go to the web page of your fork and click |contribute button| 219 | to send your changes for review. 220 | 221 | .. todo:: if you are using GitHub, you can uncomment the following paragraph 222 | 223 | Find more detailed information in `creating a PR`_. You might also want to open 224 | the PR as a draft first and mark it as ready for review after the feedback 225 | from the continuous integration (CI) system or any required fixes. 226 | 227 | 228 | Troubleshooting 229 | --------------- 230 | 231 | The following tips can be used when facing problems building or testing the 232 | package: 233 | 234 | #. Make sure to fetch all the tags from the upstream repository_. 235 | The command ``git describe --abbrev=0 --tags`` should return the version you 236 | are expecting. If you are trying to run CI scripts in a fork repository, 237 | make sure to push all the tags. 238 | You can also try to remove all the egg files or the complete egg folder, i.e., 239 | ``.eggs``, as well as the ``*.egg-info`` folders in the ``src`` folder or 240 | potentially in the root of your project. 241 | 242 | #. Sometimes |tox|_ misses newly added dependencies, especially in 243 | ``setup.cfg`` and ``docs/requirements.txt``. If you find any problems with 244 | missing dependencies when running a command with |tox|_, try to recreate the 245 | ``tox`` environment using the ``-r`` flag. For example, instead of:: 246 | 247 | tox -e docs 248 | 249 | Try running:: 250 | 251 | tox -r -e docs 252 | 253 | #. Make sure to have a reliable |tox|_ installation that uses the correct 254 | Python version (e.g., 3.7+). When in doubt you can run:: 255 | 256 | tox --version 257 | # OR 258 | which tox 259 | 260 | If you have trouble and are seeing weird errors upon running |tox|_, you can 261 | also try to create a dedicated `virtual environment`_ with a |tox|_ binary 262 | freshly installed. For example:: 263 | 264 | virtualenv .venv 265 | source .venv/bin/activate 266 | .venv/bin/pip install tox 267 | .venv/bin/tox -e all 268 | 269 | #. `Pytest can drop you`_ in an interactive session in case an error occurs. 270 | In order to do that you need to pass a ``--pdb`` option (for example by 271 | running ``tox -- -k --pdb``).
211 | Submit your contribution 212 | ------------------------ 213 | 214 | #. If everything works fine, push your local branch to |the repository service| with:: 215 | 216 | git push -u origin my-feature 217 | 218 | #. Go to the web page of your fork and click |contribute button| 219 | to send your changes for review. 220 | 221 | .. todo:: if you are using GitHub, you can uncomment the following paragraph 222 | 223 | Find more detailed information in `creating a PR`_. You might also want to open 224 | the PR as a draft first and mark it as ready for review after the feedback 225 | from the continuous integration (CI) system or any required fixes. 226 | 227 | 228 | Troubleshooting 229 | --------------- 230 | 231 | The following tips can be used when you face problems building or testing the 232 | package: 233 | 234 | #. Make sure to fetch all the tags from the upstream repository_. 235 | The command ``git describe --abbrev=0 --tags`` should return the version you 236 | are expecting. If you are trying to run CI scripts in a fork repository, 237 | make sure to push all the tags. 238 | You can also try to remove all the egg files or the complete egg folder, i.e., 239 | ``.eggs``, as well as the ``*.egg-info`` folders in the ``src`` folder or 240 | potentially in the root of your project. 241 | 242 | #. Sometimes |tox|_ misses out when new dependencies are added, especially to 243 | ``setup.cfg`` and ``docs/requirements.txt``. If you find any problems with 244 | missing dependencies when running a command with |tox|_, try to recreate the 245 | ``tox`` environment using the ``-r`` flag. For example, instead of:: 246 | 247 | tox -e docs 248 | 249 | Try running:: 250 | 251 | tox -r -e docs 252 | 253 | #. Make sure to have a reliable |tox|_ installation that uses the correct 254 | Python version (e.g., 3.7+). When in doubt you can run:: 255 | 256 | tox --version 257 | # OR 258 | which tox 259 | 260 | If you have trouble and are seeing weird errors upon running |tox|_, you can 261 | also try to create a dedicated `virtual environment`_ with a |tox|_ binary 262 | freshly installed. For example:: 263 | 264 | virtualenv .venv 265 | source .venv/bin/activate 266 | .venv/bin/pip install tox 267 | .venv/bin/tox -e all 268 | 269 | #. `Pytest can drop you`_ in an interactive session in case an error occurs. 270 | In order to do that you need to pass a ``--pdb`` option (for example by 271 | running ``tox -- -k <NAME OF THE FAILING TEST> --pdb``). 272 | You can also set up breakpoints manually instead of using the ``--pdb`` option. 273 | 274 | 275 | Maintainer tasks 276 | ================ 277 | 278 | Releases 279 | -------- 280 | 281 | .. todo:: This section assumes you are using PyPI to publicly release your package. 282 | 283 | If instead you are using a different/private package index, please update 284 | the instructions accordingly. 285 | 286 | If you are part of the group of maintainers and have correct user permissions 287 | on PyPI_, the following steps can be used to release a new version for 288 | ``interpretability``: 289 | 290 | #. Make sure all unit tests are successful. 291 | #. Tag the current commit on the main branch with a release tag, e.g., ``v1.2.3``. 292 | #. Push the new tag to the upstream repository_, e.g., ``git push upstream v1.2.3`` 293 | #. Clean up the ``dist`` and ``build`` folders with ``tox -e clean`` 294 | (or ``rm -rf dist build``) 295 | to avoid confusion with old builds and Sphinx docs. 296 | #. Run ``tox -e build`` and check that the files in ``dist`` have 297 | the correct version (no ``.dirty`` or git_ hash) according to the git_ tag. 298 | Also check the sizes of the distributions; if they are too big (e.g., > 299 | 500KB), unwanted clutter may have been accidentally included. 300 | #. Run ``tox -e publish -- --repository pypi`` and check that everything was 301 | uploaded to PyPI_ correctly. 302 | 303 | 304 | 305 | .. [#contrib1] Even though these resources focus on open source projects and 306 | communities, the general ideas behind collaborating with other developers 307 | to collectively create software are general and can be applied to all sorts 308 | of environments, including private companies and proprietary code bases. 309 | 310 | 311 | .. <-- start --> 312 | .. todo:: Please review and change the following definitions: 313 | 314 | .. |the repository service| replace:: GitHub 315 | .. |contribute button| replace:: "Create pull request" 316 | 317 | .. _repository: https://github.com/vanderschaarlab/Interpretability 318 | .. _issue tracker: https://github.com/vanderschaarlab/Interpretability/issues 319 | .. <-- end --> 320 | 321 | 322 | .. |virtualenv| replace:: ``virtualenv`` 323 | .. |pre-commit| replace:: ``pre-commit`` 324 | .. |tox| replace:: ``tox`` 325 | 326 | 327 | .. _black: https://pypi.org/project/black/ 328 | .. _CommonMark: https://commonmark.org/ 329 | .. _contribution-guide.org: https://www.contribution-guide.org/ 330 | .. _creating a PR: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request 331 | .. _descriptive commit message: https://chris.beams.io/posts/git-commit 332 | .. _docstrings: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 333 | .. _first-contributions tutorial: https://github.com/firstcontributions/first-contributions 334 | .. _flake8: https://flake8.pycqa.org/en/stable/ 335 | .. _git: https://git-scm.com 336 | .. _GitHub's fork and pull request workflow: https://guides.github.com/activities/forking/ 337 | .. _guide created by FreeCodeCamp: https://github.com/FreeCodeCamp/how-to-contribute-to-open-source 338 | .. _Miniconda: https://docs.conda.io/en/latest/miniconda.html 339 | .. _MyST: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html 340 | .. _other kinds of contributions: https://opensource.guide/how-to-contribute 341 | .. _pre-commit: https://pre-commit.com/ 342 | .. _PyPI: https://pypi.org/ 343 | 
.. _PyScaffold's contributor's guide: https://pyscaffold.org/en/stable/contributing.html 344 | .. _Pytest can drop you: https://docs.pytest.org/en/stable/how-to/failures.html#using-python-library-pdb-with-pytest 345 | .. _Python Software Foundation's Code of Conduct: https://www.python.org/psf/conduct/ 346 | .. _reStructuredText: https://www.sphinx-doc.org/en/master/usage/restructuredtext/ 347 | .. _Sphinx: https://www.sphinx-doc.org/en/master/ 348 | .. _tox: https://tox.wiki/en/stable/ 349 | .. _virtual environment: https://realpython.com/python-virtual-environments-a-primer/ 350 | .. _virtualenv: https://virtualenv.pypa.io/en/stable/ 351 | 352 | .. _GitHub web interface: https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files 353 | .. _GitHub's code editor: https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files -------------------------------------------------------------------------------- /Notebooks/Tutorial_02_implement_simplex_time_series.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","metadata":{},"source":["# Tutorial 2 - SimplEx for Time Series Data\n","\n","In this tutorial we will create a simplex explainer object and use it to explain a test record. The explainer is then saved to disk and can be given to someone else to view in the [Interpretability Suite App](https://vanderschaarlab-demo-interpretabi-interpretability-suite-1uteyn.streamlit.app/).\n","\n","We will be explaining the predictions of a pytorch convolutional neural net that we have trained and saved separately on an engine noise dataset from the IEEE World Congress on Computational Intelligence, 2008. The Interpretability.models module provides a pytorch model for this that is compatible with the trained model `state_dict`s available via the Google Drive link below.\n","\n","### Import the relevant modules"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["# IMPORTS\n","# Standard\n","import os\n","import numpy as np\n","import pathlib\n","# Third Party\n","import torch\n","# Interpretability\n","from interpretability.interpretability_models import simplex_explainer\n","from interpretability.interpretability_models.utils import io\n","from interpretability.models.recurrent_neural_net import ConvNet\n","import sklearn"]},{"cell_type":"markdown","metadata":{},"source":["### Load the data \n","Load the data and split it into the corpus of examples used for explanation and the test examples we will explain. 
This cell will download the data from the `root_url` and save it to a subdirectory of the folder this notebook is run in."]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["# LOADS\n","def load_forda_data():\n","\n"," def readucr(filename):\n"," data = np.loadtxt(filename, delimiter=\"\\t\")\n"," y = data[:, 0]\n"," x = data[:, 1:]\n"," return x, y.astype(int)\n","\n"," root_url = \"https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/\"\n","\n"," x_train, y_train = readucr(root_url + \"FordA_TRAIN.tsv\")\n"," x_test, y_test = readucr(root_url + \"FordA_TEST.tsv\")\n","\n"," x_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))\n"," x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))\n","\n"," idx = np.random.permutation(len(x_train))\n"," x_train = x_train[idx]\n"," y_train = y_train[idx]\n","\n"," y_train[y_train == -1] = 0\n"," y_test[y_test == -1] = 0\n","\n"," return x_train, y_train, x_test, y_test\n","\n","\n","# LOAD data\n","(\n"," X_corpus,\n"," y_corpus,\n"," X_explain,\n"," y_explain,\n",") = load_forda_data()\n","\n","# # Scaling is not required here but purely shown for illustrative purposes \n","# scaler = sklearn.preprocessing.MinMaxScaler()\n","# scaler.fit(X_corpus)\n","# X_corpus, X_explain = scaler.transform(X_corpus), scaler.transform(X_explain)"]},{"cell_type":"markdown","metadata":{},"source":["### Download the trained model from Google Drive\n","\n","You could train your own model using the ConvNet class and load it here, but we have trained one already.\n","\n","Download the model using this link: https://drive.google.com/file/d/173vniHegUSGmdC6fKCLupynRoxEdz9Ko/view?usp=sharing and save it in a location matching the path `TRAINED_MODEL_STATE_PATH` below. The default location is the `\"resources/saved_models\"` folder inside the root Interpretability directory.\n","\n","\n","### Load the model"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["## Load the model\n","model = ConvNet()\n","\n","def load_trained_model(model, trained_model_state_path, device='cpu'):\n"," model.load_state_dict(torch.load(trained_model_state_path, map_location=torch.device(device)))\n"," model.eval()\n"," return model\n","\n","DEVICE = \"cpu\"\n","\n","root_path = pathlib.Path.cwd().parents[0]\n","saved_models_path = root_path / \"resources/saved_models\"\n","TRAINED_MODEL_STATE_PATH = saved_models_path / \"model_cv1_2.pth\"\n","model = load_trained_model(model, TRAINED_MODEL_STATE_PATH, device=DEVICE)"]},
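{"cell_type":"markdown","metadata":{},"source":["If you trained your own `ConvNet` rather than downloading our weights, you can persist its `state_dict` with pytorch's standard saving API so that `load_trained_model` above can find it. The line below is an illustrative sketch only (the training loop itself is omitted here), so it is left commented out:"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# # Illustrative only: save the weights of a model you trained yourself,\n","# # so that load_trained_model above can load them from the same path.\n","# torch.save(model.state_dict(), TRAINED_MODEL_STATE_PATH)"]},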
{"cell_type":"markdown","metadata":{},"source":["### Initialize SimplEx\n","Initialize the explainer object by passing the predictive model and corpus."]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["/home/rob/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:306: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:80.)\n"," return F.conv1d(input, weight, bias, self.stride,\n"]}],"source":["# Initialize SimplEx\n","corpus_size = 100\n","# Create the explainer; it is fit on the test examples in the next step\n","my_explainer = simplex_explainer.SimplexTimeSeriesExplainer(\n"," model,\n"," X_corpus,\n"," y_corpus,\n"," estimator_type=\"classifier\",\n"," feature_names=[\"Engine Noise\"],\n"," corpus_size=corpus_size,\n"," device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":["### Fit the explainer\n","\n","Fit the explainer on the test data. This makes explanations of the test data available in the subsequent step."]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Weight Fitting Epoch: 2000/10000 ; Error: 5.9e+06 ; Regulator: 890 ; Reg Factor: 1\n","Weight Fitting Epoch: 4000/10000 ; Error: 5.67e+06 ; Regulator: 640 ; Reg Factor: 1\n","Weight Fitting Epoch: 6000/10000 ; Error: 5.63e+06 ; Regulator: 616 ; Reg Factor: 1\n","Weight Fitting Epoch: 8000/10000 ; Error: 5.62e+06 ; Regulator: 611 ; Reg Factor: 1\n","Weight Fitting Epoch: 10000/10000 ; Error: 5.62e+06 ; Regulator: 610 ; Reg Factor: 1\n"]}],"source":["my_explainer.fit(X_explain, y_explain)"]},{"cell_type":"markdown","metadata":{},"source":["### Get the explanation\n","Explain any given record in the test set by changing the index, `i`."]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["i = 1\n","explanation = my_explainer.explain(i, baseline=\"median\")"]},{"cell_type":"markdown","metadata":{},"source":["### Plot the explanation\n","\n","The explanation is plotted as a styled df in this notebook, but it can also be viewed in the browser if the `return_type` is set to \"html\"."]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[{"data":{"text/html":["
\n","\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
Engine Noise
(t_max) - 90.435186
(t_max) - 8-0.346502
(t_max) - 7-0.924912
(t_max) - 6-1.208716
(t_max) - 5-1.247996
(t_max) - 4-1.139974
(t_max) - 3-1.041772
(t_max) - 2-1.041772
(t_max) - 1-1.159614
(t_max)-1.375659
\n","
"],"text/plain":[" Engine Noise\n","(t_max) - 9 0.435186\n","(t_max) - 8 -0.346502\n","(t_max) - 7 -0.924912\n","(t_max) - 6 -1.208716\n","(t_max) - 5 -1.247996\n","(t_max) - 4 -1.139974\n","(t_max) - 3 -1.041772\n","(t_max) - 2 -1.041772\n","(t_max) - 1 -1.159614\n","(t_max) -1.375659"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Corpus Example: 0\n","Example Importance: 0.14360299706459045\n"]},{"name":"stderr","output_type":"stream","text":["/home/rob/miniconda3/envs/interp/lib/python3.10/site-packages/interpretability/interpretability_models/simplex_explainer.py:1148: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n"," importance_df_colors = importance_df_colors.applymap(\n"]},{"data":{"text/html":["\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 0123456789
Engine Noise2.8399691.8653680.538502-0.692078-1.398957-1.457668-0.959800-0.1526420.5925160.976485
\n"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Corpus Example: 1\n","Example Importance: 0.08999679982662201\n"]},{"name":"stderr","output_type":"stream","text":["/home/rob/miniconda3/envs/interp/lib/python3.10/site-packages/interpretability/interpretability_models/simplex_explainer.py:1148: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n"," importance_df_colors = importance_df_colors.applymap(\n"]},{"data":{"text/html":["\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 0123456789
Engine Noise0.7990380.5370300.003914-0.632543-1.182462-1.501255-1.521180-1.182462-0.608634-0.014682
\n"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Corpus Example: 2\n","Example Importance: 0.0897858589887619\n"]},{"data":{"text/html":["\n","\n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n"," \n","
 0123456789
Engine Noise-1.730200-2.020920-2.029729-1.791868-1.439481-1.095904-0.787565-0.525918-0.2202220.150048
\n"],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["# Corpus of patients\n","my_explainer.summary_plot(\n"," example_importance_threshold=0.08,\n"," time_steps_to_display=10,\n"," return_type=\"styled_df\",\n"," # rescaler=scaler,\n",")"]},{"cell_type":"markdown","metadata":{},"source":["### Save the explainer to file\n","This file can now be uploaded to the [Interpretability Suite App](https://vanderschaarlab-demo-interpretabi-interpretability-suite-1uteyn.streamlit.app/). This provides a non-programtic interface with which to view the various explanations, allowing you to send the explainer to a colleague who is less fluent in python."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Saving explainer to: /home/rob/Documents/projects/Interpretability/Notebooks/my_new_forda_conv_time_simplex_explainer.p\n"]}],"source":["io.save_explainer(\n"," my_explainer, \"my_new_forda_conv_time_simplex_explainer.p\"\n",")"]}],"metadata":{"kernelspec":{"display_name":"Python 3.8.13 ('interp')","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.14"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"6fd73f071793638ac14baf0ff0f19e5ab81431475f40d47f0df0002312a62017"}}},"nbformat":4,"nbformat_minor":2} 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning Interpretability Methods 2 | 3 | This repository collects different Machine Learning interpretability methods and aims to act 4 | as a reference where the user can select the method best suited for their needs. All the methods 5 | aim to provide an insight into why a machine learning model has made a given prediction. 6 | This is critical because for a model's predictions to be trusted they must be understood. 7 | 8 | 9 | # Table of Contents 10 | 0. [Background](#background) 11 | - [Introductory video](#introductory-video) 12 | 1. [Interface](#Interface) 13 | 2. [Explainers](#explainers) 14 | - [Selecting an Interpretability Method](#selecting-an-interpretability-method) 15 | 3. [Implementation and Notebooks](#implementation-and-notebooks) 16 | - [SimplEx](#simplex) 17 | - [Dynamask](#dynamask) 18 | - [shap](#shap) 19 | - [Symbolic Pursuit](#symbolic-pursuit) 20 | 4. [Explainers By Model Type](#explainers-by-model-type) 21 | - [Tabular Data Model Explainers](#tabular-data-model-explainers) 22 | - [Time Series model Explainers](#time-series-model-explainers) 23 | - [Unsupervised Model Explainers](#unsupervised-model-explainers) 24 | - [Individualized Treatment Effect Explainers](#individualized-treatment-effect-explainers) 25 | 5. [Generalized and Personalized Explainers](#feature-based-and-example-based-explainers) 26 | - [Generalized Explainers](#feature-based-explainers) 27 | - [Personalized Explainers](#example-based-explainers) 28 | - [Concept-based Explainers](#concept-based-explainers) 29 | 30 | # Background 31 | 32 | The Machine Learning (ML) community has produced incredible models for making highly 33 | accurate predictions and classifications across many fields. However, uptake of these models into 34 | settings outside of the ML community faces a key barrier: Interpretability. 
If a human cannot 35 | understand why a decision has been made by a machine, they cannot be reasonably expected to act 36 | on that decision with full confidence, particularly in a high-stakes environment such as medicine. 37 | Therefore making the decisions of "Black-Box" models more transparent is of vital importance. For more 38 | information see this [blog post](https://www.vanderschaar-lab.com/interpretable-machine-learning/). 39 | 40 | This GitHub repository aims to act as a home for interpretability methods, where the state-of-the-art models 41 | can be found for every application. All the linked van der Schaar Lab repositories on this page are pytorch compatible. 42 | Pytorch versions of the other methods are available in public libraries, such as [captum](https://captum.ai/). 43 | 44 | ## Introductory video 45 | 46 | This [video](https://www.youtube.com/watch?v=R-27AiRK1r0) is a quick introduction to our Interpretability Suite. 47 | It discusses why ML interpretability is so important and shows the array of different methods developed by the van der Schaar Lab 48 | that are available on this GitHub page. 49 | 50 | [![Introduction to the Interpretability Suite](images/Short_intro_video_thumbnail.png)](https://www.youtube.com/watch?v=R-27AiRK1r0) 51 | 52 | # Interface 53 | 54 | The Interpretability Suite provides a common python interface for the following interpretability methods: SimplEx, Dynamask, shap, and Symbolic Pursuit. Each of these methods is also included in our [user interface](https://vanderschaarlab-demo-interpretabi-interpretability-suite-1uteyn.streamlit.app/). To guarantee compatibility with the app, please create your explainers using the interpretability_suite_app branch. 55 | 56 | ![The interpretability Suite](images/interpretability_suite_image.png) 57 | *Figure 1: The Interpretability Suite User Interface landing page* 58 | 59 | This user interface not only demonstrates the methods and how they are used on example datasets, but it also gives the user the ability to upload their own explainer to visualize the results. This means that you can save your explainer and give the file to a less python-literate colleague, and they can see the results for themselves simply by dragging and dropping it into the `Upload your own Explainer` tab. 60 | 61 | ![Upload your own Explainer tab](images/user_inter_face_upload.png) 62 | *Figure 2: An example of the `Upload your own Explainer` tab on the user interface from the SimplEx page* 63 | 64 | # Explainers 65 | 66 | Different model architectures can require different interpretability models, or "Explainers". 67 | Below are all the explainers included in this repository, with links to their source code and the papers that introduced them. SimplEx, Dynamask, shap, and Symbolic Pursuit have a common python interface implemented for ease of use (see [Interface](#interface) above and [Implementation and Notebooks](#implementation-and-notebooks) below). But any of the other methods can also be implemented by using the code in the GitHub column of the table below. 
68 | 69 | | Explainer | Affiliation | GitHub | Paper | Date of Paper | 70 | | ----------- | ----------- | ----------- | ----------- | ----------- | 71 | | Concept Activation Regions (CARs) | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [CARs source Code](https://github.com/vanderschaarlab/CARs) | [CARs Paper](https://arxiv.org/abs/2209.11222) | 2022 | 72 | | ITErpretability | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [ITErpretability Source Code](https://github.com/vanderschaarlab/ITErpretability) | [ITErpretability Paper](https://arxiv.org/abs/2206.08363) | 2022 | 73 | | Label-Free XAI | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [Label-Free XAI Source Code](https://github.com/vanderschaarlab/Label-Free-XAI) | [Label-Free XAI Paper](https://arxiv.org/abs/2203.01928) | 2022 | 74 | | SimplEx | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [SimplEx Source Code](https://github.com/vanderschaarlab/Simplex) | [SimplEx Paper](https://papers.nips.cc/paper/2021/hash/65658fde58ab3c2b6e5132a39fae7cb9-Abstract.html) | 2021 | 75 | | Dynamask | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [Dynamask Source Code](https://github.com/vanderschaarlab/Dynamask) | [Dynamask Paper](https://arxiv.org/abs/2106.05303) | 2021 | 76 | | Symbolic Pursuit | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [Symbolic Pursuit Source Code](https://github.com/vanderschaarlab/Symbolic-Pursuit) | [Symbolic Pursuit Paper](https://arxiv.org/abs/2011.08596) | 2020 | 77 | | INVASE | [van der Schaar Lab](https://www.vanderschaar-lab.com/) | [INVASE Source Code](https://github.com/vanderschaarlab/INVASE) | [INVASE Paper](https://openreview.net/forum?id=BJg_roAcK7) | 2019 | 78 | | SHAP | University of Washington | [SHAP Source Code](https://github.com/slundberg/shap) (pytorch implementation: [Captum GradientShap](https://captum.ai/api/gradient_shap.html)) | [SHAP Paper](https://papers.nips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html) | 2017 | 79 | | Integrated gradients | Google | [Integrated Gradient Source Code](https://github.com/ankurtaly/Integrated-Gradients) (pytorch implementation: [Captum Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients)) | [Integrated Gradient paper](https://arxiv.org/abs/1703.01365) | 2017 | 80 | | LIME | University of Washington | [LIME Source Code](https://github.com/marcotcr/lime) (pytorch implementation: [Captum Lime](https://captum.ai/api/lime.html)) | [LIME Paper](https://arxiv.org/abs/1602.04938) | 2016 | 81 | 82 | ## Selecting an Interpretability Method 83 | 84 | Figure 3 shows a flowchart to help with the process of selecting the method that is most appropriate for your project. 85 | 86 | ![method selection flow chart](images/Interpretability_method_flow_diagram.svg) 87 | *Figure 3: Interpretability Method selection flowchart.* 88 | 89 | 90 | # Implementation and Notebooks 91 | 92 | This repository includes a common python interface for the following interpretability methods: SimplEx, Dynamask, shap, and Symbolic Pursuit. The interface exposes the same set of methods for each explainer, so you can use the same python calls in your scripts to set up an explainer for each interpretability method. 
Those methods are: 93 | 94 | - init: Instantiate the explainer class of your choice. 95 | - fit: Performs any training required by the explainer (this is not required for Shap explainers). 96 | - explain: Produces the explanation for the data provided. 97 | - summary_plot: Visualizes the explanation. 98 | 99 | There are also Notebooks in this GitHub repository to demonstrate how to create each explainer object. These explainers can be saved and uploaded into the Interpretability Suite user interface. 100 | 
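For example, the time-series SimplEx tutorial (Tutorial 2) reduces to the sketch below. Here `model` is assumed to be your trained pytorch model, and `X_corpus`/`y_corpus` and `X_explain`/`y_explain` are placeholders for your corpus and test data; the keyword arguments shown are the ones used in that notebook:

```python
import torch

from interpretability.interpretability_models import simplex_explainer
from interpretability.interpretability_models.utils import io

# init: instantiate the explainer of your choice
my_explainer = simplex_explainer.SimplexTimeSeriesExplainer(
    model,  # trained pytorch model; for SimplEx it must expose latent_representation()
    X_corpus,
    y_corpus,
    estimator_type="classifier",
    feature_names=["Engine Noise"],
    corpus_size=100,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# fit: perform the training the explainer needs (skipped for Shap explainers)
my_explainer.fit(X_explain, y_explain)

# explain: explain the test record at a given index
explanation = my_explainer.explain(1, baseline="median")

# summary_plot: visualize the explanation
my_explainer.summary_plot(
    example_importance_threshold=0.08,
    time_steps_to_display=10,
    return_type="styled_df",
)

# the explainer can then be saved for the Interpretability Suite app
io.save_explainer(my_explainer, "my_simplex_explainer.p")
```
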
101 | ## SimplEx 102 | The SimplEx explainer is only compatible with pytorch models. These models must also implement the function latent_representation(). A base class (interpretability.models.base.BlackBox) that your Black-Box model can inherit from is provided in this package, but inheriting from it is not a requirement. The notebook to demonstrate the SimplEx explainer can be found [here](Notebooks/Tutorial_01_implement_simplex.ipynb) for tabular data and [here for time series data](Notebooks/Tutorial_02_implement_simplex_time_series.ipynb). 103 | 104 | ## Dynamask 105 | The Dynamask explainer is only compatible with pytorch models for time series data. The notebook to demonstrate the Dynamask explainer can be found [here](Notebooks/Tutorial_03_implement_dynamask.ipynb). 106 | 107 | ## Shap 108 | The Shap explainer is compatible with pytorch models such as multilayer perceptrons, tree-based models from libraries such as sci-kit learn or xgboost, Kernel models such as those from sci-kit learn, and linear models. The notebook to demonstrate the Shap explainer can be found [here](Notebooks/Tutorial_04_implement_shap.ipynb). It uses a ShapTreeExplainer to explain the predictions of an XGBoost classifier. 109 | 110 | ## Symbolic Pursuit 111 | The Symbolic Pursuit explainer is only compatible with pytorch models and sci-kit learn linear and multilayer perceptron models. The notebook to demonstrate the Symbolic Pursuit explainer can be found [here](Notebooks/Tutorial_05_implement_symbolic_pursuit.ipynb). 112 | 113 | # Explainers By Model Type 114 | 115 | The following sections break the methods down by the type of model that they explain. 116 | 117 | ## Tabular Data Model Explainers 118 | 119 | There are many different static tabular data explainers, giving many different options to choose between. These methods are listed below with the inputs required to use them. 120 | 121 | 122 | | Explainer | Inputs | Notes | 123 | | ----------- | ----------- | ----------- | 124 | | CARs | The latent representation of the concept examples from the predictive model. | It is worth noting that CARs appears in both this and the following section as it can be easily implemented for either tabular data or time series data. | 125 | | SimplEx | The latent representation of records from the predictive model. | It is worth noting that SimplEx appears in both this and the following section as it can be easily implemented for either tabular data or time series data. | 126 | | Symbolic Pursuit | The training data and corresponding model predictions. | This method has the benefit of producing a mathematical expression that approximates the predictive model. This seeks to discover “rules” and “laws” learned by the machine model. | 127 | | INVASE | The training data and corresponding model predictions, with which to train a selector network. | Currently, INVASE has only been implemented for tabular data, but it could be extended to include time series explanations with some further work; see Section 5 of the paper linked above, "Future Work". | 128 | | SHAP | The predictive model and training data from which to subsequently calculate Shapley values. || 129 | | Integrated gradients | The predictive model and its gradients. || 130 | | LIME | The predictive model from which to calculate weights in a local linear model. || 131 | 132 | 133 | 134 | ## Time Series model Explainers 135 | 136 | The following Explainers work with models for making predictions from time series data. 137 | 138 | | Explainer | Inputs | Notes | 139 | | ----------- | ----------- | ----------- | 140 | | Dynamask | The predictive model and its gradients. | Dynamask calculates feature importance at every time step for each feature. This is advantageous over other saliency methods, such as SHAP and Integrated Gradients, which were optimized for static data and only later extended to time series. They are thus static, and hence the temporal context is forgotten when all the time steps are treated as separate features. | 141 | | CARs | The latent representation of the concept examples from the predictive model. | It is worth noting that CARs appears in both this and the previous section as it can be easily implemented for either tabular data or time series data. | 142 | | SimplEx | The latent representation of records from the predictive model. | It is worth noting that SimplEx appears in both this and the previous section as it can be easily implemented for either tabular data or time series data. | 143 | | SHAP | The predictive model and training data from which to subsequently calculate Shapley values. | SHAP and Integrated Gradients can both be used for time series, however their efficacy in this setting has been criticized in [this paper](https://papers.nips.cc/paper/2020/hash/47a3893cc405396a5c30d91320572d6d-Abstract.html). | 144 | | Integrated gradients | The predictive model and its gradients. | SHAP and Integrated Gradients can both be used for time series, however their efficacy in this setting has been criticized in [this paper](https://papers.nips.cc/paper/2020/hash/47a3893cc405396a5c30d91320572d6d-Abstract.html). | 145 | | LIME | The predictive model from which to calculate weights in a local linear model. | Lime has been extended to work with time series data with libraries such as Lime-For-Time. | 146 | 147 | 148 | ## Unsupervised Model Explainers 149 | 150 | The following Explainers work with unsupervised clustering ML models, that is to say those without labelled data in the training set. 151 | 152 | | Explainer | Inputs | 153 | | ----------- | ----------- | 154 | | Label-Free XAI | The predictive model. | 155 | 156 | ## Individualized Treatment Effect Explainers 157 | 158 | The following Explainers work with Individualized Treatment Effects (otherwise known as Conditional Average Treatment Effects). 159 | 160 | | Explainer | Inputs | 161 | | ----------- | ----------- | 162 | | ITErpretability | The conditional average treatment effects (CATE) model. | 163 | 164 | 165 | # Feature-based and Example-based Explainers 166 | 167 | Increased interpretability of a model can be achieved in multiple ways. Feature-based methods 168 | may provide an explanation for a model's predictions in terms of the features that were important 169 | for that decision, e.g. a predicted value yi was given because feature 1 was high and 170 | feature 3 was low for the ith prediction record. Whereas example-based methods may provide their 171 | explanation by showing examples that were important to the prediction, e.g. 
a predicted value 172 | yi was given because the model had previously seen three records with a similar profile and 173 | they all had the same label as the value predicted here. 174 | 175 | SimplEx is worthy of note in this section as it bridges the gap between these two categories by 176 | providing example importances and the features that are important for those examples. It therefore 177 | appears in both sections below. 178 | 179 | ## Feature-based Explainers 180 | 181 | Feature-based explainers have the advantage of specifically telling you the features that are important 182 | for a prediction for a given record. These explainers are useful at several points in the development cycle of black-box models. 183 | They are useful when debugging a model, because they can reveal places where the model is relying 184 | too heavily on fields which you know should be of lesser importance. They are also useful for feature 185 | engineering, as you can directly see which features are being used the most. 186 | 187 | The following explainers provide feature importances for a given prediction. 188 | 189 | 190 | | Explainer | Inputs | Notes | 191 | | ----------- | ----------- | ----------- | 192 | | SimplEx | The latent representation of records from the predictive model. | SimplEx is the first method, to our knowledge, to calculate both feature importance and example importance. It calculates the importance of each feature in each of the examples in the explanation in order to reveal why the model sees the test record as similar to those in the explanation. | 193 | | Dynamask | The predictive model and its gradients. | Dynamask calculates feature importance at every time step for each feature, without treating time steps as independent features. | 194 | | ITErpretability | The conditional average treatment effects (CATE) model. | | 195 | | INVASE | The training data and corresponding model predictions, with which to train a selector network. | | 196 | | Symbolic Pursuit | The training data and corresponding model predictions. | Symbolic Pursuit calculates feature interactions as well as feature importances. It has the benefit of producing a mathematical expression that approximates the predictive model. | 197 | | SHAP | The predictive model and training data from which to subsequently calculate Shapley values. | SHAP calculates feature interactions as well as feature importances. | 198 | | Integrated gradients | The predictive model and its gradients. | | 199 | | LIME | The predictive model from which to calculate weights in a local linear model. | | 200 | 201 | 202 | ## Example-based Explainers 203 | 204 | 205 | Example-based explainers show the user example records that the model sees as being similar. They can be 206 | useful for debugging in two ways: 1) if the explainer reveals that two examples are being viewed by the 207 | model as similar, but the user disagrees, this can be indicative of a model not classifying records correctly; 208 | 2) if the explainer states that the most similar examples (from the model's perspective) are incorrectly classified, that 209 | also casts doubt on the validity of the prediction for the test record. 210 | 211 | The following explainers provide example importances for a given prediction. 212 | 213 | | Explainer | Inputs | Notes | 214 | | ----------- | ----------- | ----------- | 215 | | Label-Free XAI | The predictive model. || 216 | | SimplEx | The latent representation of records from the predictive model. 
| SimplEx is the first method, to our knowledge, to calculate both feature importance and example importance. It calculates the importance of each feature in each of the examples in the explanation in order to reveal why the model sees the test record as similar to those in the explanation. | 217 | 218 | 219 | ## Concept-based Explainers 220 | 221 | 222 | Concept-based explainers explain the predictive model's output in terms of the concepts provided by the user. This has the huge advantage of giving an explanation that is comprised of concepts that the user can understand, even though they are not the input features to the model. For example, if there was a classifier that distinguished between horses and zebras, one could provide examples of the concept of stripes and see how important stripes are to the predicted values. 223 | 224 | The following explainers provide concept importances for a given prediction. 225 | 226 | | Explainer | Inputs | 227 | | ----------- | ----------- | 228 | | CARs | The latent representation of labelled concept examples from the predictive model. | 229 | 230 | -------------------------------------------------------------------------------- /images/Interpretability_method_flow_diagram.svg: --------------------------------------------------------------------------------
[Flowchart asset rendered as Figure 3 in the README. The SVG markup does not survive as plain text; its labels, deduplicated, are: Black-box model; What is the data type?; Time series; Static Data; Individualised Treatment Effects; Supervised or unsupervised?; Supervised; Unsupervised; What does the method use to explain the model?; Features; Examples; User-defined Concepts; Both feature interactions and importances?; Feature Interactions; Feature Importances only; Example-based or feature-based?; Feature; Example; and the suggested methods: Dynamask, Lime, SHAP, Integrated Gradients; SimplEx; CARs; Symbolic Pursuit, SHAP; INVASE, SimplEx, LIME, Integrated Gradients; Label-free Explainability; ITErpretability. Legend: Provided Model (Black-box model) / Suggested Interpretability Method.]
--------------------------------------------------------------------------------