├── docs ├── py-ciu.pdf ├── requirements.txt ├── Makefile ├── index.rst ├── make.bat └── conf.py ├── images ├── titanic_3d.png ├── default_plot.png ├── modified_plot.png ├── titanic_shap.png ├── ames_high_plot.png ├── titanic_family.png ├── titanic_wealth.png ├── ames_basement_plot.png ├── ames_default_plot.png ├── ames_garage_plot.png ├── titanic_influence.png ├── ames_house_cond_plot.png └── titanic_intermediate.png ├── ciu ├── __pycache__ │ └── ciu_object.cpython-39.pyc ├── __init__.py ├── ciuplots.py ├── PerturbationMinMaxEstimator.py └── CIU.py ├── requirements.txt ├── ciu_tests ├── __init__.py ├── heart_disease_rf.py ├── boston_gbm.py ├── iris_lda.py ├── titanic_rf.py └── ames_housing_gbm.py ├── .circleci └── config.yml ├── .readthedocs.yaml ├── LICENSE ├── setup.py ├── .gitignore ├── README.md └── RunTests.ipynb /docs/py-ciu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/docs/py-ciu.pdf -------------------------------------------------------------------------------- /images/titanic_3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_3d.png -------------------------------------------------------------------------------- /images/default_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/default_plot.png -------------------------------------------------------------------------------- /images/modified_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/modified_plot.png -------------------------------------------------------------------------------- /images/titanic_shap.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_shap.png -------------------------------------------------------------------------------- /images/ames_high_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/ames_high_plot.png -------------------------------------------------------------------------------- /images/titanic_family.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_family.png -------------------------------------------------------------------------------- /images/titanic_wealth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_wealth.png -------------------------------------------------------------------------------- /images/ames_basement_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/ames_basement_plot.png -------------------------------------------------------------------------------- /images/ames_default_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/ames_default_plot.png -------------------------------------------------------------------------------- /images/ames_garage_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/ames_garage_plot.png -------------------------------------------------------------------------------- /images/titanic_influence.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_influence.png -------------------------------------------------------------------------------- /images/ames_house_cond_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/ames_house_cond_plot.png -------------------------------------------------------------------------------- /images/titanic_intermediate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/images/titanic_intermediate.png -------------------------------------------------------------------------------- /ciu/__pycache__/ciu_object.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaryFramling/py-ciu/HEAD/ciu/__pycache__/ciu_object.cpython-39.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements for running notebooks 2 | -e . 
3 | matplotlib 4 | numpy 5 | pandas 6 | sklearn 7 | xgboost 8 | scikit_learn 9 | -------------------------------------------------------------------------------- /ciu_tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .heart_disease_rf import get_heart_disease_rf 2 | from .iris_lda import get_iris_test 3 | from .ames_housing_gbm import get_ames_gbm_test 4 | from .boston_gbm import get_boston_gbm_test 5 | from .titanic_rf import get_titanic_rf 6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Defining the exact version will make sure things don't break 2 | numpy==1.26.4 3 | pandas==2.2.0 4 | matplotlib 5 | plotly 6 | sphinx==5.0.2 7 | sphinx_rtd_theme==2.0.0 8 | #scikit-image==0.22.0 9 | #readthedocs-sphinx-search 10 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Python CircleCI 2.0 configuration file 2 | # 3 | # Check https://circleci.com/docs/2.0/language-python/ for more details 4 | # 5 | version: 2 6 | jobs: 7 | build: 8 | docker: 9 | # specify the version you desire here 10 | # use `-browsers` prefix for selenium tests, e.g. 
`3.6.1-browsers` 11 | - image: circleci/python:3.7.4 12 | 13 | # Specify service dependencies here if necessary 14 | # CircleCI maintains a library of pre-built images 15 | # documented at https://circleci.com/docs/2.0/circleci-images/ 16 | # - image: circleci/postgres:9.4 17 | 18 | working_directory: ~/repo 19 | 20 | 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. py-ciu documentation master file, created by 2 | sphinx-quickstart on Thu Nov 23 19:20:07 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to py-ciu's documentation! 7 | ================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | .. automodule:: ciu 14 | :members: 15 | 16 | .. automodule:: ciu.CIU 17 | :members: 18 | 19 | .. automodule:: ciu.PerturbationMinMaxEstimator 20 | :members: 21 | 22 | .. 
automodule:: ciu.ciuplots 23 | :members: 24 | 25 | Indices and tables 26 | ================== 27 | 28 | * :ref:`genindex` 29 | * :ref:`modindex` 30 | * :ref:`search` 31 | -------------------------------------------------------------------------------- /ciu/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This package implements the Contextual Importance and Utility (CIU) method. 3 | 4 | Classes: 5 | - :class:`ciu.CIU`: The CIU class implements the Contextual Importance and Utility method for Explainable AI. 6 | - :class:`ciu.PerturbationMinMaxEstimator.PerturbationMinMaxEstimator`: Class that finds minimal and maximal output values by perturbation of input value(s). This is the default class/method used by `CIU`. 7 | 8 | Functions: 9 | - `ciu.CIU.contrastive_ciu`: Function for calculating contrastive values from two CIU results. 10 | 11 | Example: 12 | :: 13 | 14 | # Example code using the module 15 | import ciu as ciu 16 | CIU = ciu.CIU(model.predict_proba, ['Output Name(s)'], data=X_train) 17 | CIUres = CIU.explain(instance) 18 | print(CIUres) 19 | """ 20 | from .CIU import CIU 21 | 22 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | # formats: 24 | # - pdf 25 | # 
- epub 26 | 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt 30 | # - method: pip 31 | # path: . 32 | # extra_requirements: 33 | # - docs 34 | 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kary Främling 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ciu_tests/heart_disease_rf.py: -------------------------------------------------------------------------------- 1 | 2 | def get_heart_disease_rf(inst_ind=0): 3 | """ 4 | :return: heart disease CIU Object with a Random Forest Classifier 5 | """ 6 | from sklearn.ensemble import RandomForestClassifier 7 | from sklearn.model_selection import train_test_split 8 | import pandas as pd 9 | import numpy as np 10 | import ciu as ciu 11 | from ciu.CIU import CIU 12 | 13 | model = RandomForestClassifier(n_estimators=100) 14 | 15 | df = pd.read_csv("https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data") 16 | df.columns = ["age", "sex", "cp", "trestbps", "chol","fbs", "restecg", 17 | "thalach","exang", "oldpeak","slope", "ca", "thal", "num"] 18 | 19 | df = df.replace({'?':np.nan}).dropna() 20 | 21 | df.loc[df["num"] > 0, "num"] = 1 22 | 23 | # Shortcut here: everything to float 24 | for i in df.columns: 25 | if 'object' in str(df[i].dtypes): 26 | df[i] = df[i].astype(float) 27 | 28 | X = df.drop('num',axis=1) 29 | y = df['num'] 30 | 31 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123) 32 | 33 | model.fit(X_train, y_train) 34 | 35 | CIU = CIU(model.predict_proba, ['No', 'Yes'], data=X_train) 36 | 37 | instance = 
X_test.iloc[[inst_ind]] 38 | 39 | return CIU, model, instance -------------------------------------------------------------------------------- /ciu_tests/boston_gbm.py: -------------------------------------------------------------------------------- 1 | 2 | def get_boston_gbm_test(inst_ind=1): 3 | """ 4 | :return: CIU, XGB and instance 5 | """ 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import xgboost as xgb 10 | from sklearn.model_selection import train_test_split 11 | import ciu as ciu 12 | from ciu.CIU import CIU 13 | 14 | data_url = "http://lib.stat.cmu.edu/datasets/boston" 15 | raw_df = pd.read_csv(data_url, sep="\s+", skiprows=22, header=None) 16 | data = np.hstack([raw_df.values[::2, :], raw_df.values[1::2, :2]]) 17 | target = raw_df.values[1::2, 2] 18 | 19 | data = pd.DataFrame(data) 20 | data.columns = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'] 21 | 22 | #xgb.DMatrix(data=data,label=target) 23 | 24 | X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=123) 25 | xg_reg = xgb.XGBRegressor(objective ='reg:squarederror', colsample_bytree = 0.3, learning_rate = 0.1,max_depth = 5, alpha = 10, n_estimators = 10) 26 | 27 | xg_reg.fit(X_train,y_train) 28 | 29 | out_minmaxs = pd.DataFrame({'mins': [min(y_train)], 'maxs': max(y_train)}) 30 | out_minmaxs.index = ['Price'] 31 | CIU = CIU(xg_reg.predict, ['Price'], data=X_train, out_minmaxs=out_minmaxs) 32 | 33 | instance = X_test.iloc[[inst_ind]] 34 | 35 | return CIU, xg_reg, instance 36 | -------------------------------------------------------------------------------- /ciu_tests/iris_lda.py: -------------------------------------------------------------------------------- 1 | 2 | def get_iris_test(): 3 | """ 4 | :return: iris LDA CIU Object, LDA model, example instance 5 | """ 6 | 7 | import pandas as pd 8 | import numpy as np 9 | from sklearn.model_selection import train_test_split 10 | from 
sklearn.discriminant_analysis import LinearDiscriminantAnalysis 11 | from sklearn import datasets 12 | import ciu as ciu 13 | from ciu.CIU import CIU 14 | 15 | 16 | iris=datasets.load_iris() 17 | 18 | df = pd.DataFrame(data = np.c_[iris['data'], iris['target']], 19 | columns = iris['feature_names'] + ['target']) 20 | df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names) 21 | df.columns = ['s_length', 's_width', 'p_length', 'p_width', 'target', 'species'] 22 | iris_outnames = df['species'].cat.categories.tolist() 23 | X = df[['s_length', 's_width', 'p_length', 'p_width']] 24 | y = df['species'] 25 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=123) 26 | 27 | 28 | model = LinearDiscriminantAnalysis() 29 | model.fit(X_train, y_train) 30 | 31 | #iris_df = df.apply(pd.to_numeric, errors='ignore') 32 | 33 | #Can also be written manually: 34 | test_iris = pd.DataFrame.from_dict({'s_length' : [2.0], 's_width' : [3.2], 'p_length': [1.8], 'p_width' : [2.4]}) 35 | 36 | ciu = CIU(model.predict_proba, iris_outnames, data=X_train) 37 | 38 | return ciu, model, test_iris -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Package metadata 4 | NAME = 'py-ciu' # Package name 5 | DESCRIPTION = 'Python implementation of the Contextual Importance and Utility (CIU) explainable AI method' 6 | VERSION = '0.5.0.1' # Use Semantic Versioning (https://semver.org/) 7 | AUTHOR = 'Kary Främling, Vlad Apopei, others' 8 | EMAIL = 'kary.framling@umu.se' 9 | URL = 'https://github.com/KaryFramling/py-ciu' # Repository URL 10 | 11 | # Define your package's dependencies 12 | INSTALL_REQUIRES = [ 13 | 'matplotlib', 14 | 'numpy', 15 | 'pandas', 16 | 'scikit-learn', 17 | 'xgboost', 18 | ] 19 | 20 | # Long description from README.md 21 | #with open('README.md', 'r') as 
f: 22 | # LONG_DESCRIPTION = f.read() 23 | LONG_DESCRIPTION = 'Please read the README file' 24 | 25 | setup( 26 | name=NAME, 27 | version=VERSION, 28 | description=DESCRIPTION, 29 | long_description=LONG_DESCRIPTION, 30 | long_description_content_type='text/markdown', 31 | author=AUTHOR, 32 | author_email=EMAIL, 33 | url=URL, 34 | packages=find_packages(), 35 | install_requires=INSTALL_REQUIRES, 36 | license='MIT', 37 | classifiers=[ 38 | # 'Development Status :: 3 - Alpha', 39 | # 'Intended Audience :: Developers', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Programming Language :: Python :: 3', 42 | # 'Programming Language :: Python :: 3.7', 43 | # 'Programming Language :: Python :: 3.8', 44 | # 'Programming Language :: Python :: 3.9', 45 | # 'Programming Language :: Python :: 3.10', 46 | # 'Programming Language :: Python :: 3.11', 47 | ], 48 | keywords='Contextual Importance and Utility, CIU, Explainable AI, Explainable Artificial Intelligence', 49 | #project_urls={ 50 | # 'Source': URL, 51 | #}, 52 | ) 53 | -------------------------------------------------------------------------------- /ciu_tests/titanic_rf.py: -------------------------------------------------------------------------------- 1 | def get_titanic_rf(): 2 | """ 3 | Random forest model using the intermediate concepts on the Titanic dataset 4 | :return: Titanic CIU object with intermediate concepts 5 | """ 6 | 7 | import pandas as pd 8 | from sklearn.ensemble import RandomForestClassifier 9 | import ciu as ciu 10 | from ciu.CIU import CIU 11 | 12 | data = pd.read_csv("https://raw.githubusercontent.com/KaryFramling/py-ciu/master/ciu_tests/data/titanic.csv") 13 | data = data.drop(data.columns[0], axis=1) 14 | unused = ['PassengerId','Cabin','Name','Ticket'] 15 | 16 | for col in unused: 17 | data = data.drop(col, axis=1) 18 | 19 | from sklearn.preprocessing import LabelEncoder 20 | data = data.dropna().apply(LabelEncoder().fit_transform) 21 | train = data.drop('Survived', axis=1) 22 | 23 | # 
Create test instance (8-year old boy) 24 | new_passenger = pd.DataFrame.from_dict({"Pclass" : [1], "Sex": [1], "Age": [8.0], "SibSp": [0], "Parch": [0], "Fare": [72.0], "Embarked": [1]}) 25 | 26 | model = RandomForestClassifier(n_estimators=100) 27 | model.fit(train, data.Survived) 28 | 29 | category_mapping = { 30 | 'Sex': ['female','male'], 31 | 'Pclass': list(range(max(data.Pclass))), 32 | 'SibSp': list(range(max(data.SibSp))), 33 | 'Parch': list(range(max(data.Parch))), 34 | 'Embarked': ["Belfast","Cherbourg","Queenstown","Southampton"] 35 | } 36 | 37 | titanic_voc = { 38 | "Wealth":['Pclass', 'Fare'], 39 | "Family":['SibSp', 'Parch'], 40 | "Gender":['Sex'], 41 | "Age":['Age'], 42 | "Embarked":['Embarked'] 43 | } 44 | 45 | CIU_titanic = CIU(model.predict_proba, ['No', 'Yes'], data=train, category_mapping=category_mapping, vocabulary=titanic_voc) 46 | 47 | return CIU_titanic, model, new_passenger -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Pycharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'py-ciu' 21 | copyright = '2023, Kary Främling' 22 | author = 'Kary Främling' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = [ 31 | 'sphinx.ext.duration', 32 | 'sphinx.ext.doctest', 33 | 'sphinx.ext.autodoc', 34 | 'sphinx.ext.autosummary', 35 | ] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ['_templates'] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | # Doing "pip install sphinx_rtd_theme" might be necessary here. 52 | # 53 | #html_theme = 'alabaster' 54 | #html_theme_options = { 55 | # 'navigation_depth': 4, # Set the depth of the navigation menu 56 | # 'collapse_navigation': False, # Set to True to collapse the navigation menu by default 57 | # 'sticky_navigation': True, # Set to True to make the navigation menu sticky 58 | # 'includehidden': False, # Set to True to include hidden TOC entries in the navigation menu 59 | #} 60 | 61 | html_theme = 'sphinx_rtd_theme' 62 | # Theme options are theme-specific and customize the look and feel of a theme 63 | # further. For a list of options available for each theme, see the 64 | # documentation. 
65 | # 66 | html_theme_options = { 67 | #'canonical_url': '', 68 | #"logo_only": True, 69 | "display_version": True, 70 | "prev_next_buttons_location": "bottom", 71 | "style_external_links": False, 72 | #"style_nav_header_background": "#343131", 73 | # Toc options 74 | "collapse_navigation": True, 75 | "sticky_navigation": True, 76 | "navigation_depth": 4, 77 | "includehidden": True, 78 | "titles_only": False, 79 | } 80 | 81 | # Add any paths that contain custom static files (such as style sheets) here, 82 | # relative to this directory. They are copied after the builtin static files, 83 | # so a file named "default.css" will overwrite the builtin "default.css". 84 | html_static_path = ['_static'] 85 | 86 | # Latex options. 87 | #latex_engine = 'pdflatex' 88 | #latex_elements = { 89 | # 'papersize': 'letterpaper', 90 | # 'pointsize': '10pt', 91 | #} -------------------------------------------------------------------------------- /ciu_tests/ames_housing_gbm.py: -------------------------------------------------------------------------------- 1 | def get_ames_gbm_test(): 2 | """ 3 | :return: A CIU object and the list of intermediate concepts used in the example. 
4 | """ 5 | import pandas as pd 6 | import xgboost as xgb 7 | from sklearn.model_selection import train_test_split 8 | import ciu as ciu 9 | from ciu.CIU import CIU 10 | 11 | df = pd.read_csv('https://raw.githubusercontent.com/KaryFramling/py-ciu/master/ciu_tests/data/AmesHousing.csv') 12 | 13 | #Checking for missing data 14 | missing_data_count = df.isnull().sum() 15 | missing_data_percent = df.isnull().sum() / len(df) * 100 16 | 17 | missing_data = pd.DataFrame({ 18 | 'Count': missing_data_count, 19 | 'Percent': missing_data_percent 20 | }) 21 | 22 | missing_data = missing_data[missing_data.Count > 0] 23 | missing_data.sort_values(by='Count', ascending=False, inplace=True) 24 | 25 | #This one has spaces for some reason 26 | df.columns = df.columns.str.replace(' ', '') 27 | 28 | 29 | #Taking care of missing values 30 | from sklearn.impute import SimpleImputer 31 | # Group 1: 32 | group_1 = [ 33 | 'PoolQC', 'MiscFeature', 'Alley', 'Fence', 'FireplaceQu', 'GarageType', 34 | 'GarageFinish', 'GarageQual', 'GarageCond', 'BsmtQual', 'BsmtCond', 35 | 'BsmtExposure', 'BsmtFinType1', 'BsmtFinType2', 'MasVnrType' 36 | ] 37 | df[group_1] = df[group_1].fillna("None") 38 | 39 | # Group 2: 40 | group_2 = [ 41 | 'GarageArea', 'GarageCars', 'BsmtFinSF1', 'BsmtFinSF2', 'BsmtUnfSF', 42 | 'TotalBsmtSF', 'BsmtFullBath', 'BsmtHalfBath', 'MasVnrArea' 43 | ] 44 | 45 | df[group_2] = df[group_2].fillna(0) 46 | 47 | # Group 3: 48 | group_3a = [ 49 | 'Functional', 'MSZoning', 'Electrical', 'KitchenQual', 'Exterior1st', 50 | 'Exterior2nd', 'SaleType', 'Utilities' 51 | ] 52 | 53 | imputer = SimpleImputer(strategy='most_frequent') 54 | df[group_3a] = pd.DataFrame(imputer.fit_transform(df[group_3a]), index=df.index) 55 | 56 | df.LotFrontage = df.LotFrontage.fillna(df.LotFrontage.mean()) 57 | df.GarageYrBlt = df.GarageYrBlt.fillna(df.YearBuilt) 58 | 59 | #Label encoding 60 | from sklearn.preprocessing import LabelEncoder 61 | df = df.apply(LabelEncoder().fit_transform) 62 | 63 | data = 
df.drop(columns=['SalePrice']) 64 | data = data.astype(float) # This is a "quick fix" to make everything into float. Some of these would rather need a category mapping. 65 | target = df.SalePrice 66 | 67 | #Splitting and training 68 | X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.3, random_state=123) 69 | xg_reg = xgb.XGBRegressor(objective ='reg:squarederror', colsample_bytree = 0.5, learning_rate = 0.1, max_depth = 15, alpha = 10) 70 | 71 | xg_reg.fit(X_train,y_train) 72 | 73 | ames_voc = { 74 | "Garage":[c for c in df.columns if 'Garage' in c], 75 | "Basement":[c for c in df.columns if 'Bsmt' in c], 76 | "Lot":list(df.columns[[3,4,7,8,9,10,11]]), 77 | "Access":list(df.columns[[13,14]]), 78 | "House_type":list(df.columns[[1,15,16,21]]), 79 | "House_aesthetics":list(df.columns[[22,23,24,25,26]]), 80 | "House_condition":list(df.columns[[20,18,21,28,19,29]]), 81 | "First_floor_surface":list(df.columns[[43]]), 82 | "Above_ground_living area":[c for c in df.columns if 'GrLivArea' in c] 83 | } 84 | 85 | out_minmaxs = pd.DataFrame({'mins': [min(y_train)], 'maxs': max(y_train)}) 86 | out_minmaxs.index = ['Price'] 87 | ciu = CIU(xg_reg.predict, ['Price'], data=X_train, out_minmaxs=out_minmaxs, vocabulary=ames_voc) 88 | 89 | instance = X_test.iloc[[345]] 90 | 91 | return ciu, xg_reg, instance 92 | 93 | -------------------------------------------------------------------------------- /ciu/ciuplots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import plotly.graph_objects as go 5 | from ciu.CIU import contrastive_ciu 6 | 7 | def ciu_beeswarm(df, xcol='CI', ycol='feature', color_col='norm_invals', legend_title=None, jitter_level=0.5, 8 | palette = ["blue", "red"], opacity=0.8): 9 | """ 10 | Create a beeswarm plot of values.
This can be used for CI, Cinfl, CU or any values in principle 11 | (including Shapley values, LIME values, ...). 12 | 13 | **Remark:** This has not been tested/implemented for non-numerical values, intermediate concepts etc. 14 | (unlike the R version). 15 | 16 | :param df: A "long" CIU result DataFrame, typically produced by a call to :func:`ciu.CIU.CIU.explain_all`. 17 | :type df: DataFrame 18 | :param xcol: Name of column to use for X-axis (numerical). 19 | :type xcol: str 20 | :param ycol: Name of column to use for Y-axis, typically the one that contains feature names. 21 | :type ycol: str 22 | :param color_col: Name of column to use for dot color, typically the one that contains instance/feature values 23 | normalised into the `[0,1]` interval. 24 | :type color_col: str 25 | :param legend_title: Text to use as legend title. If `None`, then `color_col` is used. 26 | :type legend_title: str 27 | :param jitter_level: Level of jitter to use. Each dot gets a random vertical offset in `[0, jitter_level)`. 28 | :type jitter_level: float 29 | :param palette: Color palette to use. The default value is a list with two colors but can probably be 30 | any kind of palette that is accepted by plotly.graph_objects. 31 | :type palette: list 32 | :param opacity: Opacity value to use for dots. 33 | :type opacity: float 34 | 35 | 36 | :return: A plotly.graph_objects Figure.
37 | """ 38 | 39 | # Deal with None parameters 40 | if legend_title is None: 41 | legend_title = color_col 42 | 43 | 44 | fig = go.Figure() 45 | 46 | features = pd.Categorical(df.loc[:,ycol]).categories 47 | nfeatures= len(features) 48 | for i in range(nfeatures): 49 | f = features[i] 50 | inds = np.where(df.loc[:,ycol]==f)[0] 51 | dfs = df.iloc[inds,:] 52 | marker = dict( 53 | size=9, 54 | color=dfs[color_col], 55 | colorscale=palette, 56 | opacity=opacity, 57 | ) 58 | if i == 0: 59 | marker['colorbar'] = dict(title=legend_title) 60 | fig.add_trace(go.Scatter( 61 | x=dfs.loc[:,xcol], 62 | y=i + np.random.rand(len(dfs)) * jitter_level, # one jitter offset per dot of this feature, so len(y) == len(x) 63 | mode='markers', 64 | marker=marker, 65 | name=f, 66 | )) 67 | fig.update_layout(showlegend=False, coloraxis_showscale=True, legend_title_text='My Legend Title') 68 | fig.update_yaxes(tickvals=list(range(len(features))), ticktext=list(features)) 69 | return fig 70 | 71 | def plot_contrastive(ciures1, ciures2, xminmax=None, main=None, figsize=(6, 4), 72 | colors=("firebrick","steelblue"), edgecolors=("#808080","#808080")): 73 | """ 74 | Create a contrastive plot for the two CIU results passed. This is essentially similar to 75 | an influence plot. 76 | 77 | :param ciures1: See :func:`ciu.CIU.contrastive_ciu` 78 | :type ciures1: DataFrame 79 | :param ciures2: See :func:`ciu.CIU.contrastive_ciu` 80 | :type ciures2: DataFrame 81 | :param xminmax: Min/max values to use for X axis. 82 | :type xminmax: array/list 83 | :param main: Main title to use. 84 | :type main: str 85 | :param figsize: Figure size. 86 | :type figsize: array 87 | :param colors: Bar colors to use. 88 | :type colors: array 89 | :param edgecolors: Bar edge colors to use. 90 | :type edgecolors: array 91 | 92 | :return: None. The contrastive bar plot is drawn on a new pyplot figure created with ``plt.subplots``.
93 | """ 94 | contrastive = contrastive_ciu(ciures1, ciures2) 95 | feature_names = ciures1['feature'] 96 | nfeatures = len(feature_names) 97 | 98 | fig, ax = plt.subplots(figsize=figsize) 99 | y_pos = np.arange(nfeatures) 100 | 101 | # cinfl, feature_names = (list(t) for t in zip(*sorted(zip(cinfl, feature_names)))) 102 | 103 | plt.xlabel("ϕ") 104 | for m in range(len(contrastive)): 105 | ax.barh(y_pos[m], contrastive.iloc[m], color=[colors[0] if contrastive.iloc[m] < 0 else colors[1]], 106 | edgecolor=[edgecolors[0] if contrastive.iloc[m] < 0 else edgecolors[1]], zorder=2) 107 | 108 | plt.ylabel("Features") 109 | if xminmax is not None: 110 | ax.set_xlim(xminmax) 111 | if main is not None: 112 | plt.title(main) 113 | 114 | ax.set_facecolor(color="#D9D9D9") 115 | 116 | # Y axis labels 117 | ax.set_yticks(y_pos) 118 | ax.set_yticklabels(feature_names) 119 | ax.grid(which = 'minor') 120 | ax.grid(which='minor', color='white') 121 | ax.grid(which='major', color='white') 122 | 123 | 124 | -------------------------------------------------------------------------------- /ciu/PerturbationMinMaxEstimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from itertools import product 4 | 5 | class PerturbationMinMaxEstimator: 6 | """ 7 | This class is for abstracting the operation of finding minimal and maximal output value(s) 8 | for a given instance and given inputs (input indices). 9 | 10 | PerturbationMinMaxEstimator is mainly meant to be used from CIU, not directly! It is the 11 | default class used by CIU for finding minimal and maximal output values but it can be 12 | replaced by some other class/object that does it in some (presumably) more efficient way. 13 | This can be useful if some model-specific knowledge is available or if there's a reason to do 14 | the sampling in a more in-distribution way.
15 | 16 | The only compulsory method is ``get_minmax_outvals``, which is the method called by CIU with the 17 | parameters ``instance`` and ``indices``. 18 | 19 | :param predictor: The predictor function to call. 20 | :param in_minmaxs: DataFrame with as many rows as features and two columns with min and max 21 | feature values, respectively. 22 | :param nsamples: How many samples to use. 23 | :type nsamples: int 24 | """ 25 | def __init__(self, predictor, in_minmaxs, nsamples): 26 | 27 | self.predictor = predictor 28 | self.in_minmaxs = in_minmaxs 29 | self.nsamples = nsamples 30 | 31 | def get_minmax_outvals(self, instance, indices, category_mapping=None): 32 | """ 33 | Find the minimal and maximal output value(s) that can be obtained by modifying the inputs 34 | ``indices`` of the instance ``instance``. 35 | 36 | :param instance: The instance to generate the perturbed instances for. 37 | :param indices: list of indices for which to generate perturbed values. 38 | :param category_mapping: Optional dict whose keys are names of features to treat as categorical and whose values list their possible values. 39 | :return: Two np.arrays with minimal and maximal output values found for the input or 40 | coalition of inputs in ``indices``. 41 | """ 42 | samples = self._generate_samples(instance, indices, category_mapping) 43 | samples_out = self.predictor(samples) 44 | if samples_out.ndim == 1: 45 | samples_out = samples_out[:, np.newaxis] 46 | maxs = np.amax(samples_out,axis=0) 47 | mins = np.amin(samples_out,axis=0) 48 | 49 | return mins, maxs 50 | 51 | def _generate_samples(self, instance, indices, category_mapping=None): 52 | """ 53 | Generate a DataFrame of perturbed instances for estimating CIU. 54 | 55 | :param instance: The instance to generate the perturbed instances for. 56 | :param indices: list of indices for which to generate perturbed values. 57 | :param category_mapping: See ``get_minmax_outvals``. 58 | :return: DataFrame with the instance itself as first row, followed by the perturbed instances. 59 | """ 60 | samples_to_do = self.nsamples - 1 # We include the instance in the count 61 | 62 | # Separate indices for numeric features and "category" features.
63 | category_indices = [] 64 | fnames = self.in_minmaxs.index 65 | if category_mapping is None: 66 | numeric_indices = indices 67 | else: 68 | numeric_indices = [] 69 | category_values = [] 70 | for i in indices: 71 | if fnames[i] in category_mapping: 72 | category_indices.append(i) 73 | catvals = category_mapping[fnames[i]] 74 | # String category values are replaced by their positions 0..len(catvals)-1. 75 | if isinstance(catvals[0], str): 76 | catvals = list(range(len(catvals))) 77 | category_values.append(catvals) 78 | else: 79 | numeric_indices.append(i) 80 | 81 | # Get categorically perturbed samples. 82 | catmat = None 83 | if len(category_indices) > 0: 84 | # One row per combination of the possible values of the categorical inputs. 85 | catmat = pd.DataFrame(list(product(*category_values)), columns=fnames[category_indices]) 86 | # If the number of value combinations is bigger than the requested number, then 87 | # we need to increase the number accordingly for the numeric features. 88 | if catmat.shape[0] > samples_to_do: 89 | samples_to_do = catmat.shape[0] 90 | 91 | # Get numerically perturbed samples. 92 | # TO-DO: Add rows with all xmin and xmax value combinations for all features. 93 | numvals = None 94 | if len(numeric_indices) > 0: 95 | # First generate all min/max combinations 96 | mins = np.array(self.in_minmaxs.iloc[numeric_indices,0]) 97 | maxs = np.array(self.in_minmaxs.iloc[numeric_indices,1]) 98 | minmaxgrid = pd.DataFrame(list(product(*self.in_minmaxs.values[numeric_indices,:]))) 99 | nrsamples_to_do = max(0, samples_to_do - minmaxgrid.shape[0]) 100 | # Then fill up the rest with random numbers 101 | if nrsamples_to_do > 0: 102 | numvals = np.random.rand(nrsamples_to_do, len(numeric_indices)) 103 | numvals = mins + (maxs - mins)*numvals 104 | numvals = pd.concat([minmaxgrid, pd.DataFrame(numvals)], ignore_index=True) 105 | else: 106 | numvals = minmaxgrid 107 | # This is needed for the case if the number of numeric min/max value combinations 108 | # becomes greater than the requested number of samples.
109 | if numvals.shape[0] > samples_to_do: 110 | samples_to_do = numvals.shape[0] 111 | 112 | # Merge numerical and categorical so that the total number of samples is 113 | # max(self.nsamples - 1, rows_in_categorical, rows_in_numeric_minmax_grid), plus the instance itself. 114 | 115 | # If no numeric values, then we use only the categorical ones that we have 116 | if numvals is None: 117 | samples_to_do = catmat.shape[0] 118 | samples = pd.concat([instance] * samples_to_do, ignore_index=True) 119 | 120 | # We have categorical values. 121 | if catmat is not None: 122 | # Here we may have to expand the categorical values to have same number 123 | # of rows as the numerical. 124 | if numvals is not None and catmat.shape[0] < numvals.shape[0]: 125 | catmat_nrows = catmat.shape[0] 126 | numvals_nrows = numvals.shape[0] 127 | ncopies = numvals_nrows // catmat_nrows 128 | nrows = numvals_nrows % catmat_nrows # rows still missing after the full copies 129 | parts = [pd.concat([catmat]*ncopies), catmat.iloc[np.random.randint(0, catmat_nrows, nrows),:]] # randint's upper bound is exclusive, so every combination can be drawn 130 | catmat = pd.concat(parts) 131 | samples.iloc[:,category_indices] = catmat 132 | # We insert numerical columns if we have some 133 | if numvals is not None: 134 | samples.iloc[:,numeric_indices] = numvals 135 | 136 | return pd.concat([instance, samples]) 137 | 138 | 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # py-ciu 2 | 3 | *Explainable Machine Learning through Contextual Importance and Utility* 4 | 5 | The *py-ciu* library provides methods to generate post-hoc explanations for 6 | machine learning-based classifiers. 7 | 8 | **NOTE: This implementation is the result of a major rewrite made in November/December 2023. This was made in order to make the use of this py-ciu package similar to the [R CIU](https://github.com/KaryFramling/ciu) and [py.ciu.image](https://github.com/KaryFramling/py.ciu.image) (re-written in Nov 2023) packages.
Also, the core CIU development is still done in "R" and some of the functionality present in the R version may not be available in this Python version.** 9 | 10 | The version of py-ciu that was replaced on 12 January 2024 is available in the branch `VersionUntilNov2023`. 11 | 12 | ## Usage 13 | 14 | First install the `py-ciu` package. The recommended approach is to clone this repository and use it from there. The current version has been tested with Python versions 3.11.5 and 3.12.1. More extensive testing against different versions will be performed in the future. 15 | 16 | The other approach is to install it using `pip install py-ciu`, but since the development of CIU is sometimes quite rapid, you should not expect all functionality to be available in that version. 17 | 18 | A quick overview of the use of py-ciu with different datasets and ML models is available as a Jupyter notebook [README_notebook.ipynb](README_notebook.ipynb). All the notebooks have been run in Visual Studio Code, presumably with the latest versions of all packages. Library version mismatches can cause errors or strange results, and such mismatches happen quite frequently for Python libraries. So if your results are not identical to the ones that you see in these notebooks, then you probably have to update some/many libraries in your installation. 19 | 20 | Other notebooks available in the repository are: 21 | 22 | - [BostonTests.ipynb](BostonTests.ipynb): Examples of py-ciu use for the Boston data set. This notebook provides a good overview of CIU capabilities for a regression task. 23 | - [TitanicTests.ipynb](TitanicTests.ipynb): Examples of py-ciu use for the Titanic data set. This notebook provides a good overview of CIU capabilities for a classification task. It also gives a small example of the use of CIU's "intermediate concepts". 24 | - [AmesHousingTests.ipynb](AmesHousingTests.ipynb): Examples of py-ciu use for the Ames housing data set.
This data set has 80 input features and provides a good example of the use of CIU's "intermediate concepts", as well as why they are necessary in order to give "correct" explanations even in the presence of dependencies between features (which is also the case for Titanic). 25 | - [IrisTests.ipynb](IrisTests.ipynb): Examples of py-ciu use for the Iris data set. This notebook includes some "low-level" use of the package that may not be found in the other notebooks. 26 | - [Beeswarms.ipynb](Beeswarms.ipynb): Examples of beeswarm plots as a means of "global explanation" by visualising importance, influence, utility or other values for a set of instances. 27 | - [ContrastiveTests.ipynb](ContrastiveTests.ipynb): Examples of contrastive "Why A and not B?" and "Why not B but rather A?" explanations. 28 | - [RunTests.ipynb](RunTests.ipynb): Notebook for running various tests that are found in the `ciu_tests` directory. 29 | 30 | ## Documentation 31 | 32 | The package has been documented using Sphinx and is available at [https://py-ciu.readthedocs.io/](https://py-ciu.readthedocs.io/). Sphinx HTML documentation can also be generated by doing `cd docs` and then `make html`. Other formats such as PDF can also be produced; see for instance the PDF version at: [docs/py-ciu.pdf](docs/py-ciu.pdf). 33 | 34 | # What is CIU? 35 | 36 | **Remark**: It seems like GitHub Markdown doesn’t correctly show the “{” 37 | and “}” characters in LaTeX equations, whereas they are shown correctly 38 | in RStudio. Therefore, in most cases where there is an $i$ shown in 39 | GitHub, it actually signifies `{i}` and where there is an $I$ it 40 | signifies `{I}`. 41 | 42 | CIU is a model-agnostic method for producing outcome explanations of 43 | results of any “black-box” model `y=f(x)`. CIU directly estimates two 44 | elements of explanation by observing the behaviour of the black-box 45 | model (without creating any “surrogate” model `g` of `f(x)`).
46 | 47 | **Contextual Importance (CI)** answers the question: ***how much can the 48 | result (or the utility of it) change as a function of feature*** $i$ or a 49 | set of features $\{i\}$ jointly, in the context $x$? 50 | 51 | **Contextual Utility (CU)** answers the question: ***how favorable is the 52 | current value*** of feature $i$ (or a set of features $\{i\}$ jointly) for a good 53 | (high-utility) result, in the context $x$? 54 | 55 | CI of one feature or a set of features (jointly) $\{i\}$ compared to a 56 | superset of features $\{I\}$ is defined as 57 | 58 | $$ 59 | \omega_{j,\{i\},\{I\}}(x)=\frac{umax_{j}(x,\{i\})-umin_{j}(x,\{i\})}{umax_{j}(x,\{I\})-umin_{j}(x,\{I\})}, 60 | $$ 61 | 62 | where $\{i\} \subseteq \{I\}$ and $\{I\} \subseteq \{1,\dots,n\}$. $x$ 63 | is the instance/context to be explained and defines the values of input 64 | features that do not belong to $\{i\}$ or $\{I\}$. In practice, CI is 65 | calculated as: 66 | 67 | $$ 68 | \omega_{j,\{i\},\{I\}}(x)= \frac{ymax_{j,\{i\}}(x)-ymin_{j,\{i\}}(x)}{ ymax_{j,\{I\}}(x)-ymin_{j,\{I\}}(x)}, 69 | $$ 70 | 71 | where $ymin_{j}()$ and $ymax_{j}()$ are the minimal and maximal $y_{j}$ 72 | values observed for output $j$. 73 | 74 | CU is defined as 75 | 76 | $$ 77 | CU_{j,\{i\}}(x)=\frac{u_{j}(x)-umin_{j,\{i\}}(x)}{umax_{j,\{i\}}(x)-umin_{j,\{i\}}(x)}. 78 | $$ 79 | 80 | When $u_{j}(y_{j})=Ay_{j}+b$, this can be written as: 81 | 82 | $$ 83 | CU_{j,\{i\}}(x)=\left|\frac{ y_{j}(x)-yumin_{j,\{i\}}(x)}{ymax_{j,\{i\}}(x)-ymin_{j,\{i\}}(x)}\right|, 84 | $$ 85 | 86 | where $yumin=ymin$ if $A$ is positive and $yumin=ymax$ if $A$ is 87 | negative. 
88 | 89 | # Related resources 90 | 91 | The original R implementation can be found at: 92 | 93 | There are also two implementations of CIU for explaining images: 94 | 95 | - Python: 96 | - R: 97 | 98 | Future work on image explanation will presumably focus on the Python version, due to the extensive use of deep neural networks that tend to be implemented mainly for Python. 99 | 100 | ## Authors 101 | * [Kary Främling](https://github.com/KaryFramling) 102 | * [Vlad Apopei](https://github.com/vladapopei/) 103 | * [Timotheus Kampik](https://github.com/TimKam/) 104 | 105 | The first version of py-ciu was mainly implemented by [Timotheus Kampik](https://github.com/TimKam/). The old code is available in the branch "Historical". 106 | 107 | The re-write in 2022 was mainly made by [Vlad Apopei](https://github.com/vladapopei/) and is available in the branch "VersionUntilNov2023". 108 | 109 | 110 | -------------------------------------------------------------------------------- /ciu/CIU.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | from matplotlib import cm, colors, ticker 5 | from .PerturbationMinMaxEstimator import PerturbationMinMaxEstimator 6 | 7 | class CIU: 8 | """ 9 | The CIU class implements the Contextual Importance and Utility method for Explainable AI. 10 | 11 | The method :func:`explain_core` contains all the CIU mathematics. However, it is probably not the method that 12 | would normally be called directly because it estimates CIU for a coalition of inputs (which works 13 | both for individual features and for CIU's *Intermediate Concepts*). It returns a list of DataFrames 14 | with CIU results, where each DataFrame corresponds to the explanation of one output. 
15 | 16 | The methods that would normally be called are :func:`explain` (for individual features), :func:`explain_voc` for 17 | Intermediate Concepts (/coalitions of features), and :func:`explain_all` for a set of instances. These all 18 | return a DataFrame with (presumably) all useful CIU result information (CI, CU, Contextual influence etc.). 19 | 20 | Then there are also various methods for presenting CIU results graphically and textually. Some of these wouldn't 21 | necessarily have to be methods of the `CIU` class but they have been included here as a compromise. 22 | 23 | :param predictor: Model prediction function to be used. 24 | :param [str] out_names: List of names for the model outputs. This parameter is compulsory because 25 | it is used for determining how many outputs there are and initializing ``out_minmaxs`` to 0/1 if 26 | they are not provided as parameters. 27 | :param DataFrame data: Data set to use for inferring min and max input values. Only 28 | needed if ``in_minmaxs`` is not provided. 29 | :param [str] input_names: list of input column names in ``data``. 30 | :param DataFrame in_minmaxs: Pandas DataFrame with columns ``min`` and ``max`` and one row per input. If this parameter 31 | is provided, then ``data`` does not need to be passed. 32 | :param DataFrame out_minmaxs: Pandas DataFrame with columns ``min`` and ``max`` and one row per 33 | model output. If the value is ``None``, then ``out_minmaxs`` is initialized to ``[0,1]`` for 34 | all outputs. In practice this signifies that this parameter is typically not needed for classification 35 | tasks but is necessary to provide or regression tasks. 36 | :param int nsamples: Number of samples to use for estimating CIU of numerical inputs. 37 | :param dict category_mapping: Dictionary that contains names of features that should be dealt with as categories, i.e. 38 | having discrete int/str values. The use of this mapping is strongly recommended for efficiency and accuracy reasons! 
39 | In the "R" implementation such a mapping is not needed because the `factor` column type indicates the columns and 40 | the possible values. The corresponding `Categorical` type doesn't seem to be used consistently in Python ML 41 | libraries so it didn't seem like a good choice to use that for the moment. 42 | :param float neutralCU: Reference/baseline value to use for Contextual influence. 43 | :param [int] output_inds: Default output index/indices to explain. This value doesn't have to be given as a list, it can also 44 | be a single integer (that is automatically converted into a list). 45 | :param dict vocabulary: Vocabulary to use. 46 | :param object minmax_estimator: Object to be used for estimating ymin/ymax values, if something else is to be used than the 47 | default one. 48 | """ 49 | def __init__( 50 | self, 51 | predictor, 52 | out_names, 53 | data=None, 54 | input_names=None, 55 | in_minmaxs=None, 56 | out_minmaxs=None, 57 | nsamples=100, 58 | category_mapping=None, 59 | neutralCU=0.5, 60 | output_inds=[0], 61 | vocabulary=None, 62 | minmax_estimator=None 63 | ): 64 | self.out_names=out_names 65 | if out_minmaxs is None: 66 | self.out_minmaxs = pd.DataFrame({'mins': 0, 'maxs': 1}, index=range(len(out_names))) 67 | self.out_minmaxs.index = out_names 68 | else: 69 | self.out_minmaxs = out_minmaxs 70 | if data is not None: 71 | input_names = list(data.columns) 72 | if in_minmaxs is None: 73 | try: 74 | self.in_minmaxs = pd.DataFrame({'mins': data[input_names].min(), 'maxs': data[input_names].max()}) 75 | except: 76 | print("Logic Error: You must provide either min_max values or a dataset and input names from which they can be inferred.") 77 | raise 78 | else: 79 | self.in_minmaxs=in_minmaxs 80 | input_names = list(self.in_minmaxs.index) 81 | 82 | self.predictor = predictor 83 | self.data = data 84 | self.input_names = input_names 85 | self.nsamples = nsamples 86 | self.category_mapping = category_mapping 87 | self.neutralCU = neutralCU 88 | 
self.output_inds = [output_inds] if isinstance(output_inds, int) else output_inds 89 | self.vocabulary = vocabulary 90 | self.minmax_estimator = minmax_estimator 91 | 92 | # Other instance variables 93 | self.instance = None 94 | self.last_ciu_result = None 95 | 96 | def explain_core(self, coalition_inputs, instance=None, output_inds=None, feature_name=None, nsamples=None, neutralCU = None, 97 | target_inputs=None, out_minmaxs=None, target_concept=None): 98 | """ 99 | Calculate CIU for a coalition of inputs. This is the "core" CIU method with the actual 100 | CIU calculations. All other methods should call this one for doing actual CIU calculations. 101 | 102 | Coalitions of inputs are used for defining CIU's "intermediate concepts". It signifies that all the 103 | inputs in the coalition are perturbed at the same time. 104 | 105 | :param [int] coalition_inputs: list of input indices. 106 | :param DataFrame instance: Instance to be explained. If ``instance=None`` then 107 | the last passed instance is used by default. 108 | :param output_inds: See corresponding parameter of :class:`CIU` constructor method. Default value ``None`` will use 109 | the value given to constructor method. 110 | :param str feature_name: Feature name to use for coalition of inputs (i.e. if more than one input index is given), 111 | instead of the default "Coalition of..." feature name. 112 | :param int nsamples: See corresponding parameter of constructor method. Default value ``None`` will use 113 | the value given to constructor method. 114 | :param float neutralCU: See corresponding parameter of constructor method. Default value ``None`` will use 115 | the value given to constructor method. 116 | :param [int] target_inputs: list of input indices for "target" concept, i.e. a CIU "intermediate concept". 117 | Normally "coalition_inputs" should be a subset of "target_inputs" but that is not a requirement, 118 | mathematically taken. 
Default is None, which signifies that the model outputs (i.e. "all inputs") 119 | are the targets and the "out_minmaxs" values are used for CI calculation. 120 | :param DataFrame out_minmaxs: DataFrame with min/max output values to use instead of the "global" ones. This is used 121 | for implementing Intermediate Concept calculations. The DataFrame must have one row per output and two 122 | columns, preferably named `ymin` and `ymax`. 123 | :param str target_concept: Name of the target concept. This is not used for calculations, it is only for filling up 124 | the ``target_concept`` coliumn of the CIU results. 125 | 126 | :return: A ``list`` of DataFrames with CIU results, one for each output of the model. **Remark:** `explain_core()` 127 | indeed returns a `list`, which is a difference compared to the two other `explain_` methods! 128 | """ 129 | # Deal with parameter values, especially None 130 | if instance is not None: 131 | self.instance = instance 132 | if self.instance is None: 133 | raise ValueError("No instance to explain has been given.") 134 | if nsamples is None: 135 | nsamples = self.nsamples 136 | if neutralCU is None: 137 | neutralCU = self.neutralCU 138 | if output_inds is None: 139 | output_inds = self.output_inds 140 | if isinstance(output_inds, int): 141 | output_inds = [output_inds] 142 | 143 | # Predict current instance. 144 | outvals = self.predictor(self.instance) 145 | # We want to make sure that we have a matrix, not an array. 146 | if outvals.ndim == 1: 147 | outvals = outvals[:, np.newaxis] 148 | 149 | # Abstraction of MinMaxEstimator comes here, i.e. use default one if none defined. 
150 | if self.minmax_estimator is None: 151 | estimator = PerturbationMinMaxEstimator(self.predictor, self.in_minmaxs, nsamples) 152 | else: 153 | estimator = self.minmax_estimator 154 | mins, maxs = estimator.get_minmax_outvals(instance, coalition_inputs, self.category_mapping) 155 | 156 | # If "target_inputs" is given, then we need to get "outmin" and "outmax" values for that 157 | # coalition of inputs, rather than for the final outputs. 158 | if out_minmaxs is not None: 159 | outmins = out_minmaxs.iloc[:,0] 160 | outmaxs = out_minmaxs.iloc[:,1] 161 | elif target_inputs is not None: 162 | target_cius = self.explain_core(target_inputs, instance, output_inds=output_inds, nsamples=nsamples, neutralCU=neutralCU) 163 | all_cius = pd.concat(target_cius) 164 | outmins = all_cius.loc[:,'ymin'] 165 | outmaxs = all_cius.loc[:,'ymax'] 166 | else: 167 | outmins = self.out_minmaxs.iloc[output_inds,0] 168 | outmaxs = self.out_minmaxs.iloc[output_inds,1] 169 | 170 | # Create CIU result for each requested output. 
171 | cius = [] 172 | for i, outi in enumerate(output_inds): 173 | outval = outvals[0,outi] 174 | ci = (maxs[outi] - mins[outi])/(outmaxs.iloc[i] - outmins.iloc[i]) if (outmaxs.iloc[i] - outmins.iloc[i]) != 0 else 0 175 | cu = (outval - mins[outi])/(maxs[outi] - mins[outi]) if (maxs[outi] - mins[outi]) != 0 else 0 176 | cinfl = ci*(cu - neutralCU) 177 | if len(coalition_inputs) == 1: 178 | fname = self.input_names[coalition_inputs[0]] 179 | else: 180 | fname = "Coalition of %i inputs" % len(coalition_inputs) if feature_name is None else feature_name 181 | invals = self.instance.iloc[0,coalition_inputs].values 182 | ciu = pd.DataFrame({'CI': [ci], 'CU': [cu], 'Cinfl': [cinfl], 'outname': [self.out_names[outi]], 'outval': [outval], 183 | 'feature': [fname], 'ymin': [mins[outi]], 'ymax': [maxs[outi]], 184 | 'inputs': [coalition_inputs], 'invals':[invals], 'neutralCU':[neutralCU], 185 | 'target_concept': [target_concept], 'target_inputs': [target_inputs]}) 186 | ciu.index.name = 'Feature' 187 | ciu.index = [[fname]] 188 | cius.append(ciu) 189 | return cius 190 | 191 | def explain(self, instance=None, output_inds=None, input_inds=None, nsamples=None, neutralCU=None, 192 | vocabulary=None, target_concept=None, target_ciu=None): 193 | """ 194 | Determines contextual importance and utility for a given instance (set of input/feature values). 195 | This method calculates CIU values only for individual features (not for Intermediate Concepts / 196 | coalitions of features), so if ``input_inds`` is given, then the returned CIU DataFrame will have 197 | the individual CI, CU etc values. If ``input_inds=None``, then CIU results are returned for all 198 | inputs/features. 199 | 200 | :param DataFrame instance: Instance to be explained. If ``instance=None`` then 201 | the last passed instance is used by default. 202 | :param [int] output_inds: Index of model output to explain. 
Default is None, in which case it is the 203 | ``output_inds`` value given to the :class:`CIU` constructor. This value doesn't have to be 204 | given as a list, it can also be a single integer (that is automatically converted into a list). 205 | :param [int] input_inds: List of input indices to include in explanation. Default is None, which 206 | signifies "all inputs". 207 | :param int nsamples: Number of samples to use. Default is ``None``, which means using the 208 | value of the :class:`CIU` constructor. 209 | :param float neutralCU: Value to use for "neutral CU" in Contextual influence calculation. 210 | Default is ``None`` because this parameter is only intended to temporarily override the value 211 | given to the :class:`CIU` constructor. 212 | :param dict vocabulary: Vocabulary to use. Only needed for overriding the default 213 | vocabulary given to :class:`CIU` constructor and if there's a ``target_concept``. 214 | :param str target_concept: Name of target concept, if the explanation is for an intermediate concept 215 | rather than for the output value. 216 | :param DataFrame target_ciu: If a CIU result already exists for the target_concept, then it can be passed with 217 | this parameter. Doing so avoids extra calculations and also avoids potential noise due to 218 | perturbation randomness in CIU calculations. 219 | 220 | :return: DataFrame with CIU results for the requested output(s). 221 | """ 222 | # Deal with None parameters. 
223 | if vocabulary is None: 224 | vocabulary = self.vocabulary 225 | if output_inds is None: 226 | output_inds = self.output_inds 227 | else: 228 | if isinstance(output_inds, int): 229 | output_inds = [output_inds] 230 | out_minmaxs = None 231 | if target_concept is None: 232 | target_inds = None 233 | if input_inds is None: 234 | input_inds = list(range(len(self.input_names))) 235 | else: 236 | input_inds = [instance.columns.get_loc(col) for col in self.vocabulary[target_concept]] 237 | target_inds = [self.input_names.index(value) for value in vocabulary[target_concept]] 238 | if target_ciu is not None: 239 | out_minmaxs = target_ciu.loc[target_concept,['ymin','ymax']] 240 | 241 | # Do the actual work: call explain_core for every input index. 242 | cius = [] 243 | for i in input_inds: 244 | ciu = self.explain_core([i], instance, output_inds=output_inds, nsamples=nsamples, neutralCU=neutralCU, 245 | target_concept=target_concept, target_inputs=target_inds, out_minmaxs=out_minmaxs) 246 | ciu = pd.concat(ciu) 247 | cius.append(ciu) 248 | 249 | # Memorize last result for direct plotting 250 | ciu = pd.concat(cius) 251 | self.last_ciu_result = ciu 252 | return ciu 253 | 254 | def explain_voc(self, instance=None, output_inds=None, input_concepts=None, nsamples=None, neutralCU=None, 255 | vocabulary=None, target_concept=None, target_ciu=None): 256 | """ 257 | Determines contextual importance and utility for a given instance (set of input/feature values), 258 | using the intermediate concept vocabulary. 259 | 260 | :param DataFrame instance: See :func:`explain`. 261 | :param [int] output_inds: See :func:`explain`. 262 | :param [str] input_concepts: List of concepts to include in the explanation. Default is None, which 263 | signifies "all concepts in the vocabulary". 264 | :param int nsamples: See :func:`explain`. 265 | :param float neutralCU: See :func:`explain`. 266 | :param dict vocabulary: Vocabulary to use. 
Only needed for overriding the default 267 | vocabulary given to :class:`CIU` constructor. 268 | :param str target_concept: See :func:`explain`. 269 | :param DataFrame target_ciu: See :func:`explain`. 270 | 271 | :return: DataFrame with CIU results for the requested output(s). 272 | """ 273 | 274 | # Deal with None parameters. 275 | if vocabulary is None: 276 | vocabulary = self.vocabulary 277 | if output_inds is None: 278 | output_inds = self.output_inds 279 | else: 280 | if isinstance(output_inds, int): 281 | output_inds = [output_inds] 282 | out_minmaxs = None 283 | if target_concept is None: 284 | target_inds = None 285 | else: 286 | target_inds = [self.input_names.index(value) for value in vocabulary[target_concept]] 287 | if target_ciu is not None: 288 | out_minmaxs = target_ciu.loc[target_concept,['ymin','ymax']] 289 | if input_concepts is None: 290 | input_concepts = list(vocabulary.keys()) 291 | 292 | # Do the actual work: call explain_core for every input index. 293 | cius = [] 294 | for ic in input_concepts: 295 | inds = [self.input_names.index(value) for value in vocabulary[ic]] 296 | ciu = self.explain_core(inds, instance, output_inds=output_inds, nsamples=nsamples, neutralCU=neutralCU, feature_name=ic, 297 | target_concept=target_concept, target_inputs=target_inds, out_minmaxs=out_minmaxs) 298 | ciu = pd.concat(ciu) 299 | cius.append(ciu) 300 | 301 | # Memorize last result for direct plotting 302 | ciu = pd.concat(cius) 303 | self.last_ciu_result = ciu 304 | return ciu 305 | 306 | def explain_all(self, data=None, output_inds=None, input_inds=None, nsamples=None, neutralCU=None, 307 | vocabulary=None, target_concept=None, target_ciu=None, do_norm_invals=False): 308 | """ 309 | Do CIU for all instances in `data`. 310 | 311 | :param DataFrame data: DataFrame with all instances to evaluate. 312 | :param [int] output_inds: See :func:`explain`. 313 | :param [int] input_inds: See :func:`explain`. 314 | :param int nsamples: See :func:`explain`. 
315 | :param float neutralCU: See :func:`explain`. 316 | :param dict vocabulary: See :func:`explain`. 317 | :param str target_concept: See :func:`explain`. 318 | :param DataFrame target_ciu: See :func:`explain`. 319 | :param boolean do_norm_invals: Should a column with normalized input values be produced or not? This 320 | can only be done for "basic" features, not for coalitions of features (intermediate concepts) at 321 | least for the moment. It is useful to provide normalized input values for getting more 322 | meaningful beeswarm plots, for instance. 323 | 324 | :return: DataFrame with CIU results of all instances concatenated. 325 | """ 326 | # Deal with None parameters. 327 | if data is None: 328 | if self.data is not None: 329 | data = self.data 330 | else: 331 | raise ValueError("No data provided.") 332 | if output_inds is None: 333 | output_inds = self.output_inds 334 | else: 335 | if isinstance(output_inds, int): 336 | output_inds = [output_inds] 337 | 338 | # Get values that are needed for normalizing input values. 
339 | if do_norm_invals: 340 | minmaxrows = self.in_minmaxs.loc[list(data.columns),:] 341 | mins = np.array(minmaxrows.iloc[:,0]) 342 | maxs = np.array(minmaxrows.iloc[:,1]) 343 | ranges = maxs - mins 344 | 345 | # Go through all the instances (rows) in the data 346 | ciu_res = [] 347 | for i in range(len(data)): 348 | instance = data.iloc[[i]] 349 | ciu = self.explain(instance, output_inds=output_inds, input_inds=input_inds, 350 | nsamples=nsamples, neutralCU=neutralCU, vocabulary=vocabulary, 351 | target_concept=target_concept, target_ciu=target_ciu) 352 | row_name = list(instance.index)[0] 353 | ciu['instance_name'] = [str(row_name)]*len(ciu) 354 | if do_norm_invals: 355 | ciu['norm_invals'] = ciu['invals'] 356 | ciu = ciu.explode('norm_invals',) # Get rid of list 357 | ninvals = np.array(ciu.loc[:,'norm_invals']) 358 | ciu.loc[:,'norm_invals'] = (ninvals - mins)/ranges 359 | ciu_res.append(ciu) 360 | return pd.concat(ciu_res, ignore_index=True) 361 | 362 | #============================================================================================ 363 | # Plotting and textual explanation functions here, which tend to become very long. 364 | #============================================================================================ 365 | 366 | # Input/output plot, with possibility to illustrate CIU. 367 | def plot_input_output(self, instance=None, ind_input=0, output_inds=0, in_min_max_limits=None, 368 | n_points=40, main=None, xlab="x", ylab="y", ylim=0, figsize=(6, 4), 369 | illustrate_CIU=False, legend_location=0, neutralCU=None, 370 | CIU_illustration_colours=("red","green","orange")): 371 | """ 372 | Plot model output(s) value(s) as a function on one input. Works both for numerical and for 373 | categorical inputs. 374 | 375 | :param DataFrame instance: See :func:`explain`. If `None`, then use last instance passed to an `explain_()` method. 376 | :param int ind_input: Index of input to use. 
377 | :param output_inds: Integer value, list of integers or None. If None then all outputs are plotted. 378 | Default: 0. 379 | :type output_inds: int, [int], None 380 | :param [int] in_min_max_limits: Limits to use for input values. If None, the default ones are used. 381 | :param int n_points: Number of x-values to use for numerical inputs. 382 | :param str xlab: X-axis label. 383 | :param str ylab: Y-axis label. 384 | :param ylim: Value limits for y-axis. Can be zero, actual limits or None. Zero signifies that the known 385 | min/max values for the output will be used. ``None`` signifies that no limits are defined and are 386 | auto-determined by ``plt.plot``. If actual limits are given, they are passed to ``plt.ylim`` as such. 387 | Default: zero. 388 | :type ylim: int, (min, max), None 389 | :param (int,int) figsize: Figure size to use. 390 | :param boolean illustrate_CIU: Plot CIU illustration or not? 391 | :param legend_location: See :func:`matplotlib.pyplot.legend` 392 | :param float neutral_CU: Neutral CU value to use for plotting Contextual influence reference value. 393 | :param (str,str) CIU_illustration_colours: Colors to use for CIU illustration, in order: `(ymin,ymax,neutral.CU)`. 394 | 395 | :return: matplotlib.figure.Figure 396 | """ 397 | 398 | # Deal with None parameters and other parameter value arrangements. 399 | if instance is None: 400 | instance = self.instance 401 | if output_inds is None: 402 | output_inds = list(range(len(self.out_names))) 403 | elif type(output_inds) is not list: 404 | output_inds = [output_inds] 405 | if neutralCU is None: 406 | neutralCU = self.neutralCU 407 | 408 | # Check is it's a numeric or categorical input. 409 | fname = self.input_names[ind_input] 410 | if self.category_mapping is None or fname not in self.category_mapping: 411 | input_type = 'N' 412 | else: 413 | input_type = 'C' 414 | 415 | # First deal with "numeric" possibility. 
416 | if input_type == 'N': 417 | if in_min_max_limits is None: 418 | in_min_max_limits = self.in_minmaxs.iloc[ind_input,:] 419 | in_min = in_min_max_limits.iloc[0] 420 | in_max = in_min_max_limits.iloc[1] 421 | interv = (in_max - in_min)/n_points 422 | x = np.arange(in_min, in_max + interv, interv) # Added + interv to ensure that in_max is included in x 423 | m = np.tile(instance, (len(x), 1)) # Used len(x) to get the correct size for m 424 | else: 425 | x = self.category_mapping[fname] 426 | xlabels = x 427 | if isinstance(x[0], str): 428 | x = list(range(len(x))) 429 | m = np.tile(instance, (len(x), 1)) 430 | 431 | m[:,ind_input] = x 432 | y = self.predictor(pd.DataFrame(m, columns=self.input_names)) 433 | if y.ndim == 1: 434 | y = y[:,np.newaxis] 435 | outvals = self.predictor(instance) 436 | if outvals.ndim == 1: 437 | outvals = outvals[:,np.newaxis] 438 | 439 | # Do actual plotting. If None is given, then we plot all outputs 440 | plt_out_names = [self.out_names[i] for i in output_inds] if len(output_inds) > 1 else self.out_names[output_inds[0]] 441 | fig, ax = plt.subplots(figsize=figsize) 442 | 443 | if input_type == 'N': 444 | plt.plot(x, y[:, output_inds], label=plt_out_names) 445 | # circle_radius = 0.5 446 | # plt.scatter(instance.iloc[0,ind_input], cu_val, color='red', marker='o', s=circle_radius**2 * 100) 447 | else: 448 | plt.bar(xlabels, y[:, output_inds[0]], label=plt_out_names) 449 | 450 | # Plot current value(s) as dot(s) 451 | repx = np.repeat(instance.iloc[0,ind_input], len(output_inds)) 452 | plt.scatter(repx, outvals[0,output_inds], color='red', marker='o', label='out') # This radius seems OK 453 | 454 | # Decide on y-limits 455 | if ylim == 0: 456 | ylim = (np.amin(self.out_minmaxs.iloc[[0],0].iloc[0]), np.amax(self.out_minmaxs.iloc[[0],1].iloc[0])) 457 | plt.ylim(ylim) 458 | elif ylim is not None: 459 | plt.ylim(ylim) 460 | 461 | if illustrate_CIU: 462 | y_min = np.amin(y[:, output_inds]) 463 | plt.axhline(y=y_min, 
color=CIU_illustration_colours[0], linestyle='--', label='ymin') 464 | y_max = np.amax(y[:, output_inds]) 465 | plt.axhline(y=y_max, color=CIU_illustration_colours[1], linestyle='--', label='ymax') 466 | if neutralCU is not None: 467 | y_neutral = y_min + neutralCU*(y_max - y_min) 468 | plt.axhline(y=y_neutral, color=CIU_illustration_colours[2], linestyle='--', label='neutral') 469 | 470 | # Legend? 471 | if legend_location is not None: 472 | plt.legend(loc=legend_location) 473 | 474 | # Add titles. 475 | if main is None: 476 | main = 'Output value as a function of feature value' 477 | plt.title(main) 478 | plt.xlabel(self.input_names[ind_input]) 479 | if len(output_inds) == 1: 480 | plt.ylabel(self.out_names[output_inds[0]]) 481 | else: 482 | plt.ylabel('Output values') 483 | return fig 484 | 485 | def plot_ciu(self, ciu_result=None, plot_mode='color_CU', CImax=1.0, 486 | sort='CI', main=None, color_blind=None, figsize=(6, 4), 487 | color_fill_ci='#7fffd44d', color_edge_ci='#66CDAA', 488 | color_fill_cu="#006400cc", color_edge_cu="#006400"): 489 | 490 | """ 491 | The core plotting method for CIU results, which uses both CI and CU values in the explanation. 492 | 493 | :param DataFrame ciu_result: CIU result DataFrame as returned by one of the "explain..." methods. 494 | :param str plot_mode: defines the type plot to use between 'color_CU', 'overlap' and 'combined'. 495 | :param float CImax: Limit CI axis to the given value. 496 | :param str sort: defines the order of the plot bars by the 'CI' (default), 'CU' values or unsorted if None. 497 | :param str main: Plot title. 498 | :param str color_blind: defines accessible color maps to use for the plots, such as 'protanopia', 499 | 'deuteranopia' and 'tritanopia'. 500 | :param str color_edge_ci: defines the hex or named color for the CI edge in the overlap plot mode. 501 | :param str color_fill_ci: defines the hex or named color for the CI fill in the overlap plot mode. 
502 | :param str color_edge_cu: defines the hex or named color for the CU edge in the overlap plot mode. 503 | :param str color_fill_cu: defines the hex or named color for the CU fill in the overlap plot mode. 504 | 505 | :return: matplotlib.figure.Figure 506 | """ 507 | 508 | # Deal with None parameters etc 509 | if ciu_result is None: 510 | if self.last_ciu_result is None: 511 | raise ValueError("No ciu_result given or stored from cal to explain method!") 512 | else: 513 | ciu_result = self.last_ciu_result 514 | 515 | feature_names = ciu_result.feature 516 | ci = ciu_result.CI 517 | cu = ciu_result.CU 518 | nfeatures = len(feature_names) 519 | 520 | fig, ax = plt.subplots(figsize=figsize) 521 | 522 | y_pos = np.arange(nfeatures) 523 | 524 | if sort in ['CI', 'influence']: 525 | ci, cu, feature_names = (list(t) for t in zip(*sorted(zip(ci, cu, feature_names)))) 526 | elif sort == 'CU': 527 | cu, ci, feature_names = (list(t) for t in zip(*sorted(zip(cu, ci, feature_names)))) 528 | 529 | my_norm = colors.Normalize(vmin=0, vmax=1) 530 | nodes = [0.0, 0.5, 1.0] 531 | 532 | # Take care of available color palettes. 
533 | if color_blind is None: 534 | colours = ["red", "yellow", "green"] 535 | elif color_blind == 'protanopia': 536 | colours = ["gray", "yellow", "blue"] 537 | elif color_blind == 'deuteranopia': 538 | colours = ["slategray", "orange", "dodgerblue"] 539 | elif color_blind == 'tritanopia': 540 | colours = ["#ff0066", "#ffe6f2", "#00e6e6"] 541 | cmap1 = colors.LinearSegmentedColormap.from_list("mycmap", list(zip(nodes, colours))) 542 | sm = cm.ScalarMappable(cmap=cmap1, norm=my_norm) 543 | sm.set_array([]) 544 | 545 | if plot_mode == "color_CU": 546 | cbar = plt.colorbar(sm, ax=plt.gca()) 547 | cbar.set_label('CU', rotation=0, labelpad=25) 548 | plt.xlabel("CI") 549 | for m in range(nfeatures): 550 | ax.barh(y_pos[m], ci[m], color=cmap1(my_norm(cu[m])), 551 | edgecolor="#808080", zorder=2) 552 | 553 | if plot_mode == "overlap": 554 | plt.xlabel("CI and relative CU") 555 | for m in range(nfeatures): 556 | ax.barh(y_pos[m], ci[m], color=color_fill_ci, 557 | edgecolor=color_edge_ci, linewidth=1.5, zorder=2) 558 | ax.barh(y_pos[m], cu[m]*ci[m], color=color_fill_cu, 559 | edgecolor=color_edge_cu, linewidth=1.5, zorder=2) 560 | 561 | if plot_mode == "combined": 562 | plt.xlabel("CI and relative CU") 563 | cbar = plt.colorbar(sm, ax=plt.gca()) 564 | cbar.set_label('CU', rotation=0, labelpad=25) 565 | for m in range(nfeatures): 566 | ax.barh(y_pos[m], ci.iloc[m], color="#ffffff66", edgecolor="#808080", zorder=2) 567 | ax.barh(y_pos[m], cu.iloc[m]*ci.iloc[m], color=cmap1(my_norm(cu.iloc[m])), zorder=2) 568 | 569 | plt.ylabel("Features") 570 | ax.set_xlim(0, CImax) 571 | if main is not None: 572 | plt.title(main) 573 | 574 | ax.set_facecolor(color="#D9D9D9") 575 | ax.set_yticks(y_pos) 576 | ax.set_yticklabels(feature_names) 577 | ax.grid(which = 'minor') 578 | ax.grid(which='minor', color='white') 579 | ax.grid(which='major', color='white') 580 | return fig 581 | 582 | def plot_influence(self, ciu_result=None, xminmax=None, main=None, figsize=(6, 4), 
colors=("firebrick","steelblue"), 583 | edgecolors=("#808080","#808080")): 584 | 585 | """ 586 | Plot CIU result as a bar plot using Contextual influence values. 587 | 588 | :param DataFrame ciu_result: CIU result DataFrame as returned by one of the "explain..." methods. 589 | :param (float,float) xminmax: Range to pass to ``xlim``. Default: None. 590 | :param str main: Plot title. 591 | :param (int,int) figsize: Value to pass as ``figsize`` parameter. 592 | :param (str,str) colors: Bar colors to use. First value is for negative influence, second for positive influence. 593 | :param (str,str) edgecolors: Bar edge colors to use. 594 | 595 | :return: matplotlib.figure.Figure 596 | """ 597 | 598 | # Deal with None parameters etc 599 | if ciu_result is None: 600 | if self.last_ciu_result is None: 601 | raise ValueError("No ciu_result given or stored from cal to explain method!") 602 | else: 603 | ciu_result = self.last_ciu_result 604 | 605 | feature_names = ciu_result.feature 606 | cinfl = ciu_result.Cinfl 607 | nfeatures = len(feature_names) 608 | 609 | fig, ax = plt.subplots(figsize=figsize) 610 | 611 | y_pos = np.arange(nfeatures) 612 | 613 | cinfl, feature_names = (list(t) for t in zip(*sorted(zip(cinfl, feature_names)))) 614 | 615 | plt.xlabel("ϕ") 616 | 617 | for m in range(len(cinfl)): 618 | ax.barh(y_pos[m], cinfl[m], color=[colors[0] if cinfl[m] < 0 else colors[1]], 619 | edgecolor=[edgecolors[0] if cinfl[m] < 0 else edgecolors[1]], zorder=2) 620 | 621 | plt.ylabel("Features") 622 | if xminmax is not None: 623 | ax.set_xlim(xminmax) 624 | if main is not None: 625 | plt.title(main) 626 | 627 | ax.set_facecolor(color="#D9D9D9") 628 | 629 | # Y axis labels 630 | ax.set_yticks(y_pos) 631 | ax.set_yticklabels(feature_names) 632 | ax.grid(which = 'minor') 633 | ax.grid(which='minor', color='white') 634 | ax.grid(which='major', color='white') 635 | return fig 636 | 637 | def textual_explanation(self, ciu_result=None, target_ciu=None, thresholds_ci=None, 
thresholds_cu=None, use_markdown_effects=False): 638 | """ 639 | Translate a CIU result into some kind of "natural language" using threshold values for CI and CU. 640 | 641 | :param DataFrame ciu_result: CIU result as returned by one of the "explain..." methods. 642 | :param DataFrame target_ciu: CIU result for the target concept to explain, as returned by one of 643 | the "explain..." methods. 644 | :param dict thresholds_ci: Dictionary containing the labels and ceiling values for CI thresholds. 645 | :param dict thresholds_cu: Dictionary containing the labels and ceiling values for CU thresholds. 646 | :param boolean use_markdown_effects: Produce Markdown codes in the text or not? 647 | 648 | :return: Explanation as `str`. 649 | """ 650 | 651 | # Deal with None parameters etc 652 | if ciu_result is None: 653 | if self.last_ciu_result is None: 654 | raise ValueError("No ciu_result given or stored from call to explain method!") 655 | else: 656 | ciu_result = self.last_ciu_result 657 | 658 | if thresholds_ci is None: 659 | thresholds_ci = { 660 | 'very low importance': 0.20, 661 | 'low importance': 0.40, 662 | 'normal importance': 0.60, 663 | 'high importance': 0.80, 664 | 'very high importance': 1 665 | } 666 | 667 | if thresholds_cu is None: 668 | thresholds_cu = { 669 | 'low utility': 0.25, 670 | 'lower than average utility': 0.5, 671 | 'higher than average utility': 0.75, 672 | 'high utility': 1 673 | } 674 | 675 | if len(thresholds_cu) < 2 or len(thresholds_ci) < 2: 676 | raise ValueError(f"The dictionaries containing the CI/CU thresholds must have at least 2 elements. \ 677 | \nCI dict: {thresholds_ci} \nCU dict: {thresholds_cu}") 678 | 679 | # Definitions for text formatting. 680 | if use_markdown_effects: 681 | BR = "
" 682 | BLD = "**" 683 | ITS = "*" 684 | else: 685 | BR = "\n" 686 | BLD = "" 687 | ITS = "" 688 | 689 | feature_names = ciu_result.loc[:,'feature'] 690 | explanation = [] 691 | 692 | # cu_concept = round(self.cu[target_concept] * 100, 2) 693 | out_name = ciu_result.loc[:,'outname'].iloc[0] 694 | if ciu_result.loc[:,'target_concept'].iloc[0] is None: 695 | outval = ciu_result.loc[:,'outval'].iloc[0] 696 | outmin = self.out_minmaxs.loc[out_name,:].iloc[0] 697 | out_cu = (outval - outmin)/(self.out_minmaxs.loc[out_name,:].iloc[1] - outmin) 698 | out_cu_text = list(thresholds_cu.keys())[self._find_interval(out_cu, thresholds_cu.values())] 699 | explanation.append(f"The explained value is {BLD}{ITS}{out_name}{ITS}{BLD} with the value " \ 700 | f"{outval:.2f} (CU={out_cu:.2f}), which is {BLD}{out_cu_text}{BLD}.{BR}") 701 | else: 702 | target_concept = ciu_result.loc[:,'target_concept'].iloc[0] 703 | if target_ciu is not None: 704 | ci = target_ciu.loc[target_concept,'CI'].iloc[0] 705 | ci_text = list(thresholds_ci.keys())[self._find_interval(ci, thresholds_ci.values())] 706 | cu = target_ciu.loc[target_concept,'CU'].iloc[0] 707 | cu_text = list(thresholds_cu.keys())[self._find_interval(cu, thresholds_cu.values())] 708 | explanation.append(f"The explained value is {BLD}{ITS}{target_concept}{ITS}{BLD} for output "\ 709 | f"{BLD}{ITS}{out_name}{ITS}{BLD}, which has "\ 710 | f"{BLD}{ci_text} (CI={ci:.2f}){BLD} and {BLD}{cu_text} (CU={cu:.2f}){BLD}.{BR}") 711 | else: 712 | explanation.append(f"The explained value is {BLD}{ITS}{target_concept}{ITS}{BLD} for output {BLD}{ITS}{out_name}{ITS}{BLD}.{BR}") 713 | 714 | for feature in list(feature_names): 715 | ci = ciu_result.loc[feature,'CI'].iloc[0] 716 | ci_text = list(thresholds_ci.keys())[self._find_interval(ci, thresholds_ci.values())] 717 | cu = ciu_result.loc[feature,'CU'].iloc[0] 718 | cu_text = list(thresholds_cu.keys())[self._find_interval(cu, thresholds_cu.values())] 719 | fvalue = 
ciu_result.loc[feature,'invals'].iloc[0] 720 | if len(fvalue) == 1: # Coalition or single feature? 721 | fvalue = fvalue[0] 722 | explanation.append(f"Feature {ITS}{feature}{ITS} has {BLD}{ci_text} (CI={ci:.2f}){BLD} " \ 723 | f"and has value(s) {fvalue}, which is {BLD}{cu_text} (CU={cu:.2f}){BLD}{BR}") 724 | 725 | return "".join(explanation) 726 | 727 | def _find_interval(self, value, thresholds): 728 | for i, threshold in enumerate(thresholds): 729 | if value <= threshold: 730 | return i 731 | return len(thresholds) - 1 # We can't allow indices to go beyond this. 732 | 733 | def plot_3D(self, ind_inputs, instance=None, ind_output=0, nbr_pts=(40,40), zlim=None, title="", **kwargs): 734 | """ 735 | Plot output value as a function of two inputs. 736 | 737 | :param [int,int] ind_inputs: indexes for two features to use for the 3D plot. 738 | :param DataFrame instance: instance to use. 739 | :param int ind_output: index of output to plot. Default: 0. 740 | :param (int,int) nbr_pts: number of points to use (both axis). 741 | :param (float, float) zlim: Limits to use for Z axis. 742 | :param str title: Title to use for plot. "" gives default title, None omits title. 743 | :param (int,int) figsize: Values to pass to ``plt.figure()``. 744 | :param float azim: azimuth angle to use. 745 | 746 | :return: matplotlib.figure.Figure 747 | """ 748 | # Deal with None parameters and other parameter value arrangements. 
749 | if instance is None: 750 | instance = self.instance 751 | 752 | # Get input/feature names 753 | fnames = [self.input_names[i] for i in ind_inputs] 754 | 755 | #fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, figsize=(6, 6)) 756 | 757 | # Create a figure and a 3D axis 758 | fig = plt.figure(figsize=kwargs.get('figsize', None)) 759 | ax = fig.add_subplot(111, projection='3d') 760 | 761 | # Generate data points 762 | minmaxs = self.in_minmaxs 763 | x = np.linspace(minmaxs.iloc[ind_inputs[0],0], minmaxs.iloc[ind_inputs[0],1], nbr_pts[0]) 764 | y = np.linspace(minmaxs.iloc[ind_inputs[1],0], minmaxs.iloc[ind_inputs[1],1], nbr_pts[1]) 765 | x, y = np.meshgrid(x, y) 766 | total_npoints = x.shape[0]*x.shape[1] 767 | m = np.tile(instance, (total_npoints, 1)) 768 | m[:,ind_inputs[0]] = x.reshape(total_npoints) 769 | m[:,ind_inputs[1]] = y.reshape(total_npoints) 770 | z = self.predictor(pd.DataFrame(m, columns=self.input_names)) 771 | if z.ndim == 1: 772 | zm = z.reshape(x.shape[0], x.shape[1]) 773 | else: 774 | zm = z[:,ind_output].reshape(x.shape[0], x.shape[1]) 775 | 776 | # Create a 3D surface plot 777 | ax.plot_surface(x, y, zm, color="lightblue", linewidth=1, antialiased=True, zorder=1, alpha=0.8) 778 | 779 | # Adding instance point marker 780 | outvals = self.predictor(instance) 781 | if outvals.ndim == 1: 782 | outvals = outvals[:,np.newaxis] 783 | xp = instance.iloc[0, ind_inputs[0]] 784 | yp = instance.iloc[0, ind_inputs[1]] 785 | ax.scatter(xp, yp, outvals[0,ind_output], color="red", alpha=1, s=100, zorder=3) 786 | 787 | # Add labels 788 | ax.set_xlabel(fnames[0]) 789 | ax.set_ylabel(fnames[1]) 790 | ax.set_zlabel(self.out_names[ind_output]) 791 | 792 | # Final adjustments 793 | if title is not None: 794 | if title == "": 795 | title = f"Prediction for {self.out_names[ind_output]} is {outvals[0,ind_output]:.3f}" 796 | fig.suptitle(title) 797 | azim = kwargs.get('azim', None) 798 | if azim is not None: 799 | ax.azim = azim 800 | ax.set_zlim(zlim) 801 | 
return fig 802 | 803 | def contrastive_ciu(ciures1, ciures2): 804 | """ 805 | Calculate contrastive influence values for two CIU result DataFrames. 806 | 807 | The two DataFrames should have the same features, in the same order. 808 | 809 | :param DataFrame ciures1: CIU result DataFrame of the "focus" instance. 810 | :param DataFrame ciures2: CIU result DataFrame of the "challenger" instance. 811 | 812 | :return: `list` with one influence value per feature/concept. 813 | """ 814 | contrastive = ciures1['CI']*(ciures1['CU'] - ciures2['CU']) 815 | return contrastive 816 | 817 | -------------------------------------------------------------------------------- /RunTests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Notebook for running tests in the `ciu_tests` directory.\n", 8 | "\n", 9 | "This notebook runs tests written in \"pure Python\" for different datasets. The current test functions can also be used elsewhere as a shortcut for loading the data, training a model, creating a CIU object and getting a test instance. " 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Basic imports." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "import numpy as np\n", 27 | "import ciu_tests as ciu_tests" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## Iris" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "data": { 44 | "text/html": [ 45 | "
\n", 46 | "\n", 59 | "\n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
s_length9.887474e-012.119259e-31-4.943737e-01setosa3.254150e-17s_length3.254150e-179.887474e-01[0][2.0]0.5NoneNone
s_width2.782717e-091.169415e-08-1.391358e-09setosa3.254150e-17s_width1.332128e-242.782717e-09[1][3.2]0.5NoneNone
p_length5.536589e-105.877536e-08-2.768294e-10setosa3.254150e-17p_length3.381699e-795.536589e-10[2][1.8]0.5NoneNone
p_width9.999970e-013.015588e-17-4.999985e-01setosa3.254150e-17p_width2.385707e-189.999970e-01[3][2.4]0.5NoneNone
\n", 145 | "
" 146 | ], 147 | "text/plain": [ 148 | " CI CU Cinfl outname outval \\\n", 149 | "s_length 9.887474e-01 2.119259e-31 -4.943737e-01 setosa 3.254150e-17 \n", 150 | "s_width 2.782717e-09 1.169415e-08 -1.391358e-09 setosa 3.254150e-17 \n", 151 | "p_length 5.536589e-10 5.877536e-08 -2.768294e-10 setosa 3.254150e-17 \n", 152 | "p_width 9.999970e-01 3.015588e-17 -4.999985e-01 setosa 3.254150e-17 \n", 153 | "\n", 154 | " feature ymin ymax inputs invals neutralCU \\\n", 155 | "s_length s_length 3.254150e-17 9.887474e-01 [0] [2.0] 0.5 \n", 156 | "s_width s_width 1.332128e-24 2.782717e-09 [1] [3.2] 0.5 \n", 157 | "p_length p_length 3.381699e-79 5.536589e-10 [2] [1.8] 0.5 \n", 158 | "p_width p_width 2.385707e-18 9.999970e-01 [3] [2.4] 0.5 \n", 159 | "\n", 160 | " target_concept target_inputs \n", 161 | "s_length None None \n", 162 | "s_width None None \n", 163 | "p_length None None \n", 164 | "p_width None None " 165 | ] 166 | }, 167 | "metadata": {}, 168 | "output_type": "display_data" 169 | }, 170 | { 171 | "data": { 172 | "text/html": [ 173 | "
\n", 174 | "\n", 187 | "\n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
s_length0.9887430.8141430.310607versicolor0.816231s_length1.125260e-020.999995[0][2.0]0.5NoneNone
s_width0.8815020.7952130.260231versicolor0.816231s_width1.152488e-010.996751[1][3.2]0.5NoneNone
p_length0.9995460.8166010.316458versicolor0.816231p_length2.910413e-170.999546[2][1.8]0.5NoneNone
p_width0.9999970.8162310.316230versicolor0.816231p_width3.002736e-061.000000[3][2.4]0.5NoneNone
\n", 273 | "
" 274 | ], 275 | "text/plain": [ 276 | " CI CU Cinfl outname outval feature \\\n", 277 | "s_length 0.988743 0.814143 0.310607 versicolor 0.816231 s_length \n", 278 | "s_width 0.881502 0.795213 0.260231 versicolor 0.816231 s_width \n", 279 | "p_length 0.999546 0.816601 0.316458 versicolor 0.816231 p_length \n", 280 | "p_width 0.999997 0.816231 0.316230 versicolor 0.816231 p_width \n", 281 | "\n", 282 | " ymin ymax inputs invals neutralCU target_concept \\\n", 283 | "s_length 1.125260e-02 0.999995 [0] [2.0] 0.5 None \n", 284 | "s_width 1.152488e-01 0.996751 [1] [3.2] 0.5 None \n", 285 | "p_length 2.910413e-17 0.999546 [2] [1.8] 0.5 None \n", 286 | "p_width 3.002736e-06 1.000000 [3] [2.4] 0.5 None \n", 287 | "\n", 288 | " target_inputs \n", 289 | "s_length None \n", 290 | "s_width None \n", 291 | "p_length None \n", 292 | "p_width None " 293 | ] 294 | }, 295 | "metadata": {}, 296 | "output_type": "display_data" 297 | }, 298 | { 299 | "data": { 300 | "text/html": [ 301 | "
\n", 302 | "\n", 315 | "\n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
s_length0.1837691.0000000.091885virginica0.183769s_length8.883252e-120.183769[0][2.0]0.5NoneNone
s_width0.8815020.204787-0.260231virginica0.183769s_width3.249199e-030.884751[1][3.2]0.5NoneNone
p_length0.9995460.183399-0.316458virginica0.183769p_length4.537308e-041.000000[2][1.8]0.5NoneNone
p_width0.4628240.397061-0.047643virginica0.183769p_width2.658730e-200.462824[3][2.4]0.5NoneNone
\n", 401 | "
" 402 | ], 403 | "text/plain": [ 404 | " CI CU Cinfl outname outval feature \\\n", 405 | "s_length 0.183769 1.000000 0.091885 virginica 0.183769 s_length \n", 406 | "s_width 0.881502 0.204787 -0.260231 virginica 0.183769 s_width \n", 407 | "p_length 0.999546 0.183399 -0.316458 virginica 0.183769 p_length \n", 408 | "p_width 0.462824 0.397061 -0.047643 virginica 0.183769 p_width \n", 409 | "\n", 410 | " ymin ymax inputs invals neutralCU target_concept \\\n", 411 | "s_length 8.883252e-12 0.183769 [0] [2.0] 0.5 None \n", 412 | "s_width 3.249199e-03 0.884751 [1] [3.2] 0.5 None \n", 413 | "p_length 4.537308e-04 1.000000 [2] [1.8] 0.5 None \n", 414 | "p_width 2.658730e-20 0.462824 [3] [2.4] 0.5 None \n", 415 | "\n", 416 | " target_inputs \n", 417 | "s_length None \n", 418 | "s_width None \n", 419 | "p_length None \n", 420 | "p_width None " 421 | ] 422 | }, 423 | "metadata": {}, 424 | "output_type": "display_data" 425 | } 426 | ], 427 | "source": [ 428 | "from ciu_tests import iris_lda\n", 429 | "\n", 430 | "np.random.seed(24) # We want to always get the same Random Forest model here.\n", 431 | "CIU_iris, iris_lda_model, instance = iris_lda.get_iris_test()\n", 432 | "CIUres_iris = CIU_iris.explain(instance)\n", 433 | "display(CIUres_iris)\n", 434 | "CIUres_iris = CIU_iris.explain(instance, output_inds=1)\n", 435 | "display(CIUres_iris)\n", 436 | "CIUres_iris = CIU_iris.explain(instance, output_inds=2)\n", 437 | "display(CIUres_iris)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "## Boston housing" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 3, 450 | "metadata": {}, 451 | "outputs": [ 452 | { 453 | "data": { 454 | "text/html": [ 455 | "
\n", 456 | "\n", 469 | "\n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " 
\n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
CRIM0.0014050.000000-0.000702Price25.574661CRIM25.57466125.637869[0][0.05735]0.5NoneNone
ZN0.0210480.000000-0.010524Price25.574661ZN25.57466126.521812[1][0.0]0.5NoneNone
INDUS0.0461090.138218-0.016681Price25.574661INDUS25.28787027.362782[2][4.49]0.5NoneNone
CHAS0.0000000.000000-0.000000Price25.574661CHAS25.57466125.574661[3][0.0]0.5NoneNone
NOX0.0253401.0000000.012670Price25.574661NOX24.43434525.574661[4][0.449]0.5NoneNone
RM0.0749850.259624-0.018025Price25.574661RM24.69860128.072943[5][6.63]0.5NoneNone
AGE0.0138941.0000000.006947Price25.574661AGE24.94944625.574661[6][56.1]0.5NoneNone
DIS0.0219460.5687360.001509Price25.574661DIS25.01298726.000570[7][4.4377]0.5NoneNone
RAD0.0000000.000000-0.000000Price25.574661RAD25.57466125.574661[8][3.0]0.5NoneNone
TAX0.0344471.0000000.017224Price25.574661TAX24.02453425.574661[9][247.0]0.5NoneNone
PTRATIO0.0508430.406109-0.004774Price25.574661PTRATIO24.64551226.933443[10][18.5]0.5NoneNone
B0.0238761.0000000.011938Price25.574661B24.50022125.574661[11][392.3]0.5NoneNone
LSTAT0.1108950.6872540.020765Price25.574661LSTAT22.14508227.135347[12][6.53]0.5NoneNone
\n", 699 | "
" 700 | ], 701 | "text/plain": [ 702 | " CI CU Cinfl outname outval feature ymin \\\n", 703 | "CRIM 0.001405 0.000000 -0.000702 Price 25.574661 CRIM 25.574661 \n", 704 | "ZN 0.021048 0.000000 -0.010524 Price 25.574661 ZN 25.574661 \n", 705 | "INDUS 0.046109 0.138218 -0.016681 Price 25.574661 INDUS 25.287870 \n", 706 | "CHAS 0.000000 0.000000 -0.000000 Price 25.574661 CHAS 25.574661 \n", 707 | "NOX 0.025340 1.000000 0.012670 Price 25.574661 NOX 24.434345 \n", 708 | "RM 0.074985 0.259624 -0.018025 Price 25.574661 RM 24.698601 \n", 709 | "AGE 0.013894 1.000000 0.006947 Price 25.574661 AGE 24.949446 \n", 710 | "DIS 0.021946 0.568736 0.001509 Price 25.574661 DIS 25.012987 \n", 711 | "RAD 0.000000 0.000000 -0.000000 Price 25.574661 RAD 25.574661 \n", 712 | "TAX 0.034447 1.000000 0.017224 Price 25.574661 TAX 24.024534 \n", 713 | "PTRATIO 0.050843 0.406109 -0.004774 Price 25.574661 PTRATIO 24.645512 \n", 714 | "B 0.023876 1.000000 0.011938 Price 25.574661 B 24.500221 \n", 715 | "LSTAT 0.110895 0.687254 0.020765 Price 25.574661 LSTAT 22.145082 \n", 716 | "\n", 717 | " ymax inputs invals neutralCU target_concept target_inputs \n", 718 | "CRIM 25.637869 [0] [0.05735] 0.5 None None \n", 719 | "ZN 26.521812 [1] [0.0] 0.5 None None \n", 720 | "INDUS 27.362782 [2] [4.49] 0.5 None None \n", 721 | "CHAS 25.574661 [3] [0.0] 0.5 None None \n", 722 | "NOX 25.574661 [4] [0.449] 0.5 None None \n", 723 | "RM 28.072943 [5] [6.63] 0.5 None None \n", 724 | "AGE 25.574661 [6] [56.1] 0.5 None None \n", 725 | "DIS 26.000570 [7] [4.4377] 0.5 None None \n", 726 | "RAD 25.574661 [8] [3.0] 0.5 None None \n", 727 | "TAX 25.574661 [9] [247.0] 0.5 None None \n", 728 | "PTRATIO 26.933443 [10] [18.5] 0.5 None None \n", 729 | "B 25.574661 [11] [392.3] 0.5 None None \n", 730 | "LSTAT 27.135347 [12] [6.53] 0.5 None None " 731 | ] 732 | }, 733 | "metadata": {}, 734 | "output_type": "display_data" 735 | } 736 | ], 737 | "source": [ 738 | "from ciu_tests import boston_gbm\n", 739 | "\n", 740 | 
"np.random.seed(26) # We want to always get the same Random Forest model here.\n", 741 | "CIU, boston_xgb_model, instance = boston_gbm.get_boston_gbm_test()\n", 742 | "CIUres = CIU.explain(instance)\n", 743 | "display(CIUres)" 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "metadata": {}, 749 | "source": [ 750 | "## Titanic" 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": 4, 756 | "metadata": {}, 757 | "outputs": [ 758 | { 759 | "data": { 760 | "text/html": [ 761 | "
\n", 762 | "\n", 775 | "\n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
Pclass0.0000000.000000-0.000000No0.4Pclass0.4000000.4000[0][1.0]0.5NoneNone
Sex0.3093331.0000000.154667No0.4Sex0.0906670.4000[1][1.0]0.5NoneNone
Age0.6300000.047619-0.285000No0.4Age0.3700001.0000[2][8.0]0.5NoneNone
SibSp0.1600000.187500-0.050000No0.4SibSp0.3700000.5300[3][0.0]0.5NoneNone
Parch0.1200001.0000000.060000No0.4Parch0.2800000.4000[4][0.0]0.5NoneNone
Fare0.1166670.292857-0.024167No0.4Fare0.3658330.4825[5][72.0]0.5NoneNone
Embarked0.0300001.0000000.015000No0.4Embarked0.3700000.4000[6][1.0]0.5NoneNone
\n", 909 | "
" 910 | ], 911 | "text/plain": [ 912 | " CI CU Cinfl outname outval feature ymin \\\n", 913 | "Pclass 0.000000 0.000000 -0.000000 No 0.4 Pclass 0.400000 \n", 914 | "Sex 0.309333 1.000000 0.154667 No 0.4 Sex 0.090667 \n", 915 | "Age 0.630000 0.047619 -0.285000 No 0.4 Age 0.370000 \n", 916 | "SibSp 0.160000 0.187500 -0.050000 No 0.4 SibSp 0.370000 \n", 917 | "Parch 0.120000 1.000000 0.060000 No 0.4 Parch 0.280000 \n", 918 | "Fare 0.116667 0.292857 -0.024167 No 0.4 Fare 0.365833 \n", 919 | "Embarked 0.030000 1.000000 0.015000 No 0.4 Embarked 0.370000 \n", 920 | "\n", 921 | " ymax inputs invals neutralCU target_concept target_inputs \n", 922 | "Pclass 0.4000 [0] [1.0] 0.5 None None \n", 923 | "Sex 0.4000 [1] [1.0] 0.5 None None \n", 924 | "Age 1.0000 [2] [8.0] 0.5 None None \n", 925 | "SibSp 0.5300 [3] [0.0] 0.5 None None \n", 926 | "Parch 0.4000 [4] [0.0] 0.5 None None \n", 927 | "Fare 0.4825 [5] [72.0] 0.5 None None \n", 928 | "Embarked 0.4000 [6] [1.0] 0.5 None None " 929 | ] 930 | }, 931 | "metadata": {}, 932 | "output_type": "display_data" 933 | } 934 | ], 935 | "source": [ 936 | "from ciu_tests import titanic_rf\n", 937 | "\n", 938 | "np.random.seed(26) # We want to always get the same Random Forest model here.\n", 939 | "CIU_titanic, titanic_model, titanic_instance = titanic_rf.get_titanic_rf()\n", 940 | "CIUres_titanic = CIU_titanic.explain(titanic_instance)\n", 941 | "display(CIUres_titanic)" 942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": 5, 947 | "metadata": {}, 948 | "outputs": [ 949 | { 950 | "data": { 951 | "text/html": [ 952 | "
\n", 953 | "\n", 966 | "\n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
Wealth0.2600000.5769230.020000No0.4Wealth0.2500000.51[0, 5][1.0, 72.0]0.5NoneNone
Family0.3000000.5666670.020000No0.4Family0.2300000.53[3, 4][0.0, 0.0]0.5NoneNone
Sex0.3093331.0000000.154667No0.4Sex0.0906670.40[1][1.0]0.5NoneNone
Age0.6300000.047619-0.285000No0.4Age0.3700001.00[2][8.0]0.5NoneNone
Embarked0.0300001.0000000.015000No0.4Embarked0.3700000.40[6][1.0]0.5NoneNone
\n", 1068 | "
" 1069 | ], 1070 | "text/plain": [ 1071 | " CI CU Cinfl outname outval feature ymin \\\n", 1072 | "Wealth 0.260000 0.576923 0.020000 No 0.4 Wealth 0.250000 \n", 1073 | "Family 0.300000 0.566667 0.020000 No 0.4 Family 0.230000 \n", 1074 | "Sex 0.309333 1.000000 0.154667 No 0.4 Sex 0.090667 \n", 1075 | "Age 0.630000 0.047619 -0.285000 No 0.4 Age 0.370000 \n", 1076 | "Embarked 0.030000 1.000000 0.015000 No 0.4 Embarked 0.370000 \n", 1077 | "\n", 1078 | " ymax inputs invals neutralCU target_concept target_inputs \n", 1079 | "Wealth 0.51 [0, 5] [1.0, 72.0] 0.5 None None \n", 1080 | "Family 0.53 [3, 4] [0.0, 0.0] 0.5 None None \n", 1081 | "Sex 0.40 [1] [1.0] 0.5 None None \n", 1082 | "Age 1.00 [2] [8.0] 0.5 None None \n", 1083 | "Embarked 0.40 [6] [1.0] 0.5 None None " 1084 | ] 1085 | }, 1086 | "metadata": {}, 1087 | "output_type": "display_data" 1088 | } 1089 | ], 1090 | "source": [ 1091 | "CIUres_voc_top_titanic = CIU_titanic.explain_voc(titanic_instance, nsamples=1000)\n", 1092 | "display(CIUres_voc_top_titanic)" 1093 | ] 1094 | }, 1095 | { 1096 | "cell_type": "markdown", 1097 | "metadata": {}, 1098 | "source": [ 1099 | "## Ames housing" 1100 | ] 1101 | }, 1102 | { 1103 | "cell_type": "code", 1104 | "execution_count": 6, 1105 | "metadata": {}, 1106 | "outputs": [ 1107 | { 1108 | "data": { 1109 | "text/html": [ 1110 | "
\n", 1111 | "\n", 1124 | "\n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | " \n", 1207 | " \n", 1208 | " \n", 1209 | " \n", 1210 | " \n", 1211 | " \n", 1212 | " \n", 1213 | " \n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | " \n", 1222 | " \n", 1223 | " \n", 1224 | " \n", 1225 | " \n", 1226 | " \n", 1227 | " \n", 1228 | " \n", 1229 | " \n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " 
\n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | " \n", 1293 | " \n", 1294 | " \n", 1295 | " \n", 1296 | " \n", 1297 | " \n", 1298 | " \n", 1299 | " \n", 1300 | " \n", 1301 | " \n", 1302 | " \n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | " \n", 1315 | " \n", 1316 | " \n", 1317 | " \n", 1318 | " \n", 1319 | " \n", 1320 | " \n", 1321 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
Order0.0310890.9754020.014780Price740.222046Order708.957336741.010498[0][1561.0]0.5NoneNone
PID0.0307420.8073100.009447Price740.222046PID714.634338746.329346[1][2700.0]0.5NoneNone
MSSubClass0.0000000.000000-0.000000Price740.222046MSSubClass740.222046740.222046[2][11.0]0.5NoneNone
MSZoning0.0039701.0000000.001985Price740.222046MSZoning736.128540740.222046[3][6.0]0.5NoneNone
LotFrontage0.0358200.7372160.008497Price740.222046LotFrontage712.996338749.926758[4][20.0]0.5NoneNone
..........................................
MiscVal0.0003391.0000000.000169Price740.222046MiscVal739.872620740.222046[76][0.0]0.5NoneNone
MoSold0.0223510.5378450.000846Price740.222046MoSold727.828247750.871704[77][6.0]0.5NoneNone
YrSold0.0108581.0000000.005429Price740.222046YrSold729.027527740.222046[78][2.0]0.5NoneNone
SaleType0.0423021.0000000.021151Price740.222046SaleType696.609192740.222046[79][9.0]0.5NoneNone
SaleCondition0.0063091.0000000.003155Price740.222046SaleCondition733.717407740.222046[80][4.0]0.5NoneNone
\n", 1322 | "

81 rows × 13 columns

\n", 1323 | "
" 1324 | ], 1325 | "text/plain": [ 1326 | " CI CU Cinfl outname outval \\\n", 1327 | "Order 0.031089 0.975402 0.014780 Price 740.222046 \n", 1328 | "PID 0.030742 0.807310 0.009447 Price 740.222046 \n", 1329 | "MSSubClass 0.000000 0.000000 -0.000000 Price 740.222046 \n", 1330 | "MSZoning 0.003970 1.000000 0.001985 Price 740.222046 \n", 1331 | "LotFrontage 0.035820 0.737216 0.008497 Price 740.222046 \n", 1332 | "... ... ... ... ... ... \n", 1333 | "MiscVal 0.000339 1.000000 0.000169 Price 740.222046 \n", 1334 | "MoSold 0.022351 0.537845 0.000846 Price 740.222046 \n", 1335 | "YrSold 0.010858 1.000000 0.005429 Price 740.222046 \n", 1336 | "SaleType 0.042302 1.000000 0.021151 Price 740.222046 \n", 1337 | "SaleCondition 0.006309 1.000000 0.003155 Price 740.222046 \n", 1338 | "\n", 1339 | " feature ymin ymax inputs invals \\\n", 1340 | "Order Order 708.957336 741.010498 [0] [1561.0] \n", 1341 | "PID PID 714.634338 746.329346 [1] [2700.0] \n", 1342 | "MSSubClass MSSubClass 740.222046 740.222046 [2] [11.0] \n", 1343 | "MSZoning MSZoning 736.128540 740.222046 [3] [6.0] \n", 1344 | "LotFrontage LotFrontage 712.996338 749.926758 [4] [20.0] \n", 1345 | "... ... ... ... ... ... \n", 1346 | "MiscVal MiscVal 739.872620 740.222046 [76] [0.0] \n", 1347 | "MoSold MoSold 727.828247 750.871704 [77] [6.0] \n", 1348 | "YrSold YrSold 729.027527 740.222046 [78] [2.0] \n", 1349 | "SaleType SaleType 696.609192 740.222046 [79] [9.0] \n", 1350 | "SaleCondition SaleCondition 733.717407 740.222046 [80] [4.0] \n", 1351 | "\n", 1352 | " neutralCU target_concept target_inputs \n", 1353 | "Order 0.5 None None \n", 1354 | "PID 0.5 None None \n", 1355 | "MSSubClass 0.5 None None \n", 1356 | "MSZoning 0.5 None None \n", 1357 | "LotFrontage 0.5 None None \n", 1358 | "... ... ... ... 
\n", 1359 | "MiscVal 0.5 None None \n", 1360 | "MoSold 0.5 None None \n", 1361 | "YrSold 0.5 None None \n", 1362 | "SaleType 0.5 None None \n", 1363 | "SaleCondition 0.5 None None \n", 1364 | "\n", 1365 | "[81 rows x 13 columns]" 1366 | ] 1367 | }, 1368 | "metadata": {}, 1369 | "output_type": "display_data" 1370 | } 1371 | ], 1372 | "source": [ 1373 | "from ciu_tests import ames_housing_gbm\n", 1374 | "\n", 1375 | "np.random.seed(26) # We want to always get the same Random Forest model here.\n", 1376 | "CIU, ames_xgb_model, ames_instance = ames_housing_gbm.get_ames_gbm_test()\n", 1377 | "CIUres = CIU.explain(ames_instance)\n", 1378 | "display(CIUres)" 1379 | ] 1380 | }, 1381 | { 1382 | "cell_type": "code", 1383 | "execution_count": 7, 1384 | "metadata": {}, 1385 | "outputs": [ 1386 | { 1387 | "data": { 1388 | "text/html": [ 1389 | "
\n", 1390 | "\n", 1403 | "\n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | " \n", 1415 | " \n", 1416 | " \n", 1417 | " \n", 1418 | " \n", 1419 | " \n", 1420 | " \n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | " \n", 1472 | " \n", 1473 | " \n", 1474 | " \n", 1475 | " \n", 1476 | " \n", 1477 | " \n", 1478 | " \n", 1479 | " \n", 1480 | " \n", 1481 | " \n", 1482 | " \n", 1483 | " \n", 1484 | " \n", 1485 | " \n", 1486 | " \n", 1487 | " \n", 1488 | " \n", 1489 | " \n", 1490 | " \n", 1491 | " \n", 1492 | " \n", 1493 | " \n", 1494 | " \n", 1495 | " \n", 1496 | " \n", 1497 | " \n", 1498 | " \n", 1499 | " \n", 1500 | " \n", 1501 | " \n", 1502 | " \n", 1503 | " \n", 1504 | " \n", 1505 | " \n", 1506 | " \n", 1507 | " \n", 1508 | " \n", 1509 | " \n", 1510 | " \n", 1511 | " \n", 1512 | " \n", 1513 | " \n", 1514 | " \n", 1515 | " \n", 1516 | " \n", 1517 | " \n", 1518 | " \n", 1519 | " \n", 1520 | " \n", 1521 | " \n", 1522 | " \n", 1523 | " \n", 1524 | " \n", 1525 | " \n", 1526 | " \n", 1527 | " \n", 1528 | " \n", 1529 | " \n", 1530 | " \n", 1531 | " \n", 1532 | " \n", 1533 | " \n", 1534 | " \n", 1535 | " \n", 1536 | " \n", 1537 | " \n", 1538 | " \n", 1539 | " \n", 1540 | " \n", 1541 | " \n", 1542 | " \n", 1543 | " \n", 1544 | " 
\n", 1545 | " \n", 1546 | " \n", 1547 | " \n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
Garage0.0649930.8228620.020984Price740.222046Garage685.084045752.091675[59, 60, 61, 62, 63, 64, 65][1.0, 103.0, 0.0, 2.0, 224.0, 5.0, 5.0]0.5NoneNone
Basement0.2171351.0000000.108568Price740.222046Basement516.355713740.222046[31, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49][0.0, 5.0, 1.0, 2.0, 941.0, 6.0, 0.0, 81.0, 84...0.5NoneNone
Lot0.0628370.7988250.018777Price740.222046Lot688.470215753.255127[3, 4, 7, 8, 9, 10, 11][6.0, 20.0, 1.0, 0.0, 1.0, 0.0, 4.0]0.5NoneNone
Access0.0662350.9984310.033013Price740.222046Access672.041138740.329163[13, 14][6.0, 2.0]0.5NoneNone
House_type0.0717510.7002950.014371Price740.222046House_type688.417786762.392700[1, 15, 16, 21][2700.0, 2.0, 4.0, 55.0]0.5NoneNone
House_aesthetics0.0562050.9386500.024654Price740.222046House_aesthetics685.829773743.777100[22, 23, 24, 25, 26][3.0, 1.0, 5.0, 5.0, 4.0]0.5NoneNone
House_condition0.4457990.7709610.120794Price740.222046House_condition385.873962845.492798[20, 18, 21, 28, 19, 29][111.0, 6.0, 55.0, 2.0, 4.0, 4.0]0.5NoneNone
Electrical0.0002951.0000000.000147Price740.222046Electrical739.918396740.222046[43][4.0]0.5NoneNone
GrLivArea0.1491490.8942860.058807Price740.222046GrLivArea602.705200756.477905[47][711.0]0.5NoneNone
\n", 1569 | "
" 1570 | ], 1571 | "text/plain": [ 1572 | " CI CU Cinfl outname outval \\\n", 1573 | "Garage 0.064993 0.822862 0.020984 Price 740.222046 \n", 1574 | "Basement 0.217135 1.000000 0.108568 Price 740.222046 \n", 1575 | "Lot 0.062837 0.798825 0.018777 Price 740.222046 \n", 1576 | "Access 0.066235 0.998431 0.033013 Price 740.222046 \n", 1577 | "House_type 0.071751 0.700295 0.014371 Price 740.222046 \n", 1578 | "House_aesthetics 0.056205 0.938650 0.024654 Price 740.222046 \n", 1579 | "House_condition 0.445799 0.770961 0.120794 Price 740.222046 \n", 1580 | "Electrical 0.000295 1.000000 0.000147 Price 740.222046 \n", 1581 | "GrLivArea 0.149149 0.894286 0.058807 Price 740.222046 \n", 1582 | "\n", 1583 | " feature ymin ymax \\\n", 1584 | "Garage Garage 685.084045 752.091675 \n", 1585 | "Basement Basement 516.355713 740.222046 \n", 1586 | "Lot Lot 688.470215 753.255127 \n", 1587 | "Access Access 672.041138 740.329163 \n", 1588 | "House_type House_type 688.417786 762.392700 \n", 1589 | "House_aesthetics House_aesthetics 685.829773 743.777100 \n", 1590 | "House_condition House_condition 385.873962 845.492798 \n", 1591 | "Electrical Electrical 739.918396 740.222046 \n", 1592 | "GrLivArea GrLivArea 602.705200 756.477905 \n", 1593 | "\n", 1594 | " inputs \\\n", 1595 | "Garage [59, 60, 61, 62, 63, 64, 65] \n", 1596 | "Basement [31, 32, 33, 34, 35, 36, 37, 38, 39, 48, 49] \n", 1597 | "Lot [3, 4, 7, 8, 9, 10, 11] \n", 1598 | "Access [13, 14] \n", 1599 | "House_type [1, 15, 16, 21] \n", 1600 | "House_aesthetics [22, 23, 24, 25, 26] \n", 1601 | "House_condition [20, 18, 21, 28, 19, 29] \n", 1602 | "Electrical [43] \n", 1603 | "GrLivArea [47] \n", 1604 | "\n", 1605 | " invals \\\n", 1606 | "Garage [1.0, 103.0, 0.0, 2.0, 224.0, 5.0, 5.0] \n", 1607 | "Basement [0.0, 5.0, 1.0, 2.0, 941.0, 6.0, 0.0, 81.0, 84... 
\n", 1608 | "Lot [6.0, 20.0, 1.0, 0.0, 1.0, 0.0, 4.0] \n", 1609 | "Access [6.0, 2.0] \n", 1610 | "House_type [2700.0, 2.0, 4.0, 55.0] \n", 1611 | "House_aesthetics [3.0, 1.0, 5.0, 5.0, 4.0] \n", 1612 | "House_condition [111.0, 6.0, 55.0, 2.0, 4.0, 4.0] \n", 1613 | "Electrical [4.0] \n", 1614 | "GrLivArea [711.0] \n", 1615 | "\n", 1616 | " neutralCU target_concept target_inputs \n", 1617 | "Garage 0.5 None None \n", 1618 | "Basement 0.5 None None \n", 1619 | "Lot 0.5 None None \n", 1620 | "Access 0.5 None None \n", 1621 | "House_type 0.5 None None \n", 1622 | "House_aesthetics 0.5 None None \n", 1623 | "House_condition 0.5 None None \n", 1624 | "Electrical 0.5 None None \n", 1625 | "GrLivArea 0.5 None None " 1626 | ] 1627 | }, 1628 | "metadata": {}, 1629 | "output_type": "display_data" 1630 | } 1631 | ], 1632 | "source": [ 1633 | "CIUres_voc_top = CIU.explain_voc(ames_instance, nsamples=1000)\n", 1634 | "display(CIUres_voc_top)" 1635 | ] 1636 | }, 1637 | { 1638 | "cell_type": "markdown", 1639 | "metadata": {}, 1640 | "source": [ 1641 | "## Heart disease" 1642 | ] 1643 | }, 1644 | { 1645 | "cell_type": "markdown", 1646 | "metadata": {}, 1647 | "source": [ 1648 | "The target variable has been restricted to only two classes, which are \"no disease\" and \"disease\". The data originally classifies diseases into four different classes. " 1649 | ] 1650 | }, 1651 | { 1652 | "cell_type": "code", 1653 | "execution_count": 8, 1654 | "metadata": {}, 1655 | "outputs": [ 1656 | { 1657 | "name": "stdout", 1658 | "output_type": "stream", 1659 | "text": [ 1660 | " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", 1661 | "106 57.0 1.0 3.0 128.0 229.0 0.0 2.0 150.0 0.0 0.4 \n", 1662 | "\n", 1663 | " slope ca thal \n", 1664 | "106 2.0 1.0 7.0 \n", 1665 | "[[0.41 0.59]]\n" 1666 | ] 1667 | }, 1668 | { 1669 | "data": { 1670 | "text/html": [ 1671 | "
\n", 1672 | "\n", 1685 | "\n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 | " \n", 1743 | " \n", 1744 | " \n", 1745 | " \n", 1746 | " \n", 1747 | " \n", 1748 | " \n", 1749 | " \n", 1750 | " \n", 1751 | " \n", 1752 | " \n", 1753 | " \n", 1754 | " \n", 1755 | " \n", 1756 | " \n", 1757 | " \n", 1758 | " \n", 1759 | " \n", 1760 | " \n", 1761 | " \n", 1762 | " \n", 1763 | " \n", 1764 | " \n", 1765 | " \n", 1766 | " \n", 1767 | " \n", 1768 | " \n", 1769 | " \n", 1770 | " \n", 1771 | " \n", 1772 | " \n", 1773 | " \n", 1774 | " \n", 1775 | " \n", 1776 | " \n", 1777 | " \n", 1778 | " \n", 1779 | " \n", 1780 | " \n", 1781 | " \n", 1782 | " \n", 1783 | " \n", 1784 | " \n", 1785 | " \n", 1786 | " \n", 1787 | " \n", 1788 | " \n", 1789 | " \n", 1790 | " \n", 1791 | " \n", 1792 | " \n", 1793 | " \n", 1794 | " \n", 1795 | " \n", 1796 | " \n", 1797 | " \n", 1798 | " \n", 1799 | " \n", 1800 | " \n", 1801 | " \n", 1802 | " \n", 1803 | " \n", 1804 | " \n", 1805 | " \n", 1806 | " \n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | " \n", 1822 | " \n", 1823 | " \n", 1824 | " \n", 1825 | " \n", 1826 | " 
\n", 1827 | " \n", 1828 | " \n", 1829 | " \n", 1830 | " \n", 1831 | " \n", 1832 | " \n", 1833 | " \n", 1834 | " \n", 1835 | " \n", 1836 | " \n", 1837 | " \n", 1838 | " \n", 1839 | " \n", 1840 | " \n", 1841 | " \n", 1842 | " \n", 1843 | " \n", 1844 | " \n", 1845 | " \n", 1846 | " \n", 1847 | " \n", 1848 | " \n", 1849 | " \n", 1850 | " \n", 1851 | " \n", 1852 | " \n", 1853 | " \n", 1854 | " \n", 1855 | " \n", 1856 | " \n", 1857 | " \n", 1858 | " \n", 1859 | " \n", 1860 | " \n", 1861 | " \n", 1862 | " \n", 1863 | " \n", 1864 | " \n", 1865 | " \n", 1866 | " \n", 1867 | " \n", 1868 | " \n", 1869 | " \n", 1870 | " \n", 1871 | " \n", 1872 | " \n", 1873 | " \n", 1874 | " \n", 1875 | " \n", 1876 | " \n", 1877 | " \n", 1878 | " \n", 1879 | " \n", 1880 | " \n", 1881 | " \n", 1882 | " \n", 1883 | " \n", 1884 | " \n", 1885 | " \n", 1886 | " \n", 1887 | " \n", 1888 | " \n", 1889 | " \n", 1890 | " \n", 1891 | " \n", 1892 | " \n", 1893 | " \n", 1894 | " \n", 1895 | " \n", 1896 | " \n", 1897 | " \n", 1898 | " \n", 1899 | " \n", 1900 | " \n", 1901 | " \n", 1902 | " \n", 1903 | " \n", 1904 | " \n", 1905 | " \n", 1906 | " \n", 1907 | " \n", 1908 | " \n", 1909 | " \n", 1910 | " \n", 1911 | " \n", 1912 | " \n", 1913 | " \n", 1914 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
age0.100.200000-0.030No0.41age0.390.49[0][57.0]0.5NoneNone
sex0.030.000000-0.015No0.41sex0.410.44[1][1.0]0.5NoneNone
cp0.161.0000000.080No0.41cp0.250.41[2][3.0]0.5NoneNone
trestbps0.090.8888890.035No0.41trestbps0.330.42[3][128.0]0.5NoneNone
chol0.110.181818-0.035No0.41chol0.390.50[4][229.0]0.5NoneNone
fbs0.070.000000-0.035No0.41fbs0.410.48[5][0.0]0.5NoneNone
restecg0.090.000000-0.045No0.41restecg0.410.50[6][2.0]0.5NoneNone
thalach0.071.0000000.035No0.41thalach0.340.41[7][150.0]0.5NoneNone
exang0.070.000000-0.035No0.41exang0.410.48[8][0.0]0.5NoneNone
oldpeak0.250.9200000.105No0.41oldpeak0.180.43[9][0.4]0.5NoneNone
slope0.280.035714-0.130No0.41slope0.400.68[10][2.0]0.5NoneNone
ca0.220.136364-0.080No0.41ca0.380.60[11][1.0]0.5NoneNone
thal0.180.000000-0.090No0.41thal0.410.59[12][7.0]0.5NoneNone
\n", 1915 | "
" 1916 | ], 1917 | "text/plain": [ 1918 | " CI CU Cinfl outname outval feature ymin ymax inputs \\\n", 1919 | "age 0.10 0.200000 -0.030 No 0.41 age 0.39 0.49 [0] \n", 1920 | "sex 0.03 0.000000 -0.015 No 0.41 sex 0.41 0.44 [1] \n", 1921 | "cp 0.16 1.000000 0.080 No 0.41 cp 0.25 0.41 [2] \n", 1922 | "trestbps 0.09 0.888889 0.035 No 0.41 trestbps 0.33 0.42 [3] \n", 1923 | "chol 0.11 0.181818 -0.035 No 0.41 chol 0.39 0.50 [4] \n", 1924 | "fbs 0.07 0.000000 -0.035 No 0.41 fbs 0.41 0.48 [5] \n", 1925 | "restecg 0.09 0.000000 -0.045 No 0.41 restecg 0.41 0.50 [6] \n", 1926 | "thalach 0.07 1.000000 0.035 No 0.41 thalach 0.34 0.41 [7] \n", 1927 | "exang 0.07 0.000000 -0.035 No 0.41 exang 0.41 0.48 [8] \n", 1928 | "oldpeak 0.25 0.920000 0.105 No 0.41 oldpeak 0.18 0.43 [9] \n", 1929 | "slope 0.28 0.035714 -0.130 No 0.41 slope 0.40 0.68 [10] \n", 1930 | "ca 0.22 0.136364 -0.080 No 0.41 ca 0.38 0.60 [11] \n", 1931 | "thal 0.18 0.000000 -0.090 No 0.41 thal 0.41 0.59 [12] \n", 1932 | "\n", 1933 | " invals neutralCU target_concept target_inputs \n", 1934 | "age [57.0] 0.5 None None \n", 1935 | "sex [1.0] 0.5 None None \n", 1936 | "cp [3.0] 0.5 None None \n", 1937 | "trestbps [128.0] 0.5 None None \n", 1938 | "chol [229.0] 0.5 None None \n", 1939 | "fbs [0.0] 0.5 None None \n", 1940 | "restecg [2.0] 0.5 None None \n", 1941 | "thalach [150.0] 0.5 None None \n", 1942 | "exang [0.0] 0.5 None None \n", 1943 | "oldpeak [0.4] 0.5 None None \n", 1944 | "slope [2.0] 0.5 None None \n", 1945 | "ca [1.0] 0.5 None None \n", 1946 | "thal [7.0] 0.5 None None " 1947 | ] 1948 | }, 1949 | "metadata": {}, 1950 | "output_type": "display_data" 1951 | }, 1952 | { 1953 | "data": { 1954 | "text/html": [ 1955 | "
\n", 1956 | "\n", 1969 | "\n", 1970 | " \n", 1971 | " \n", 1972 | " \n", 1973 | " \n", 1974 | " \n", 1975 | " \n", 1976 | " \n", 1977 | " \n", 1978 | " \n", 1979 | " \n", 1980 | " \n", 1981 | " \n", 1982 | " \n", 1983 | " \n", 1984 | " \n", 1985 | " \n", 1986 | " \n", 1987 | " \n", 1988 | " \n", 1989 | " \n", 1990 | " \n", 1991 | " \n", 1992 | " \n", 1993 | " \n", 1994 | " \n", 1995 | " \n", 1996 | " \n", 1997 | " \n", 1998 | " \n", 1999 | " \n", 2000 | " \n", 2001 | " \n", 2002 | " \n", 2003 | " \n", 2004 | " \n", 2005 | " \n", 2006 | " \n", 2007 | " \n", 2008 | " \n", 2009 | " \n", 2010 | " \n", 2011 | " \n", 2012 | " \n", 2013 | " \n", 2014 | " \n", 2015 | " \n", 2016 | " \n", 2017 | " \n", 2018 | " \n", 2019 | " \n", 2020 | " \n", 2021 | " \n", 2022 | " \n", 2023 | " \n", 2024 | " \n", 2025 | " \n", 2026 | " \n", 2027 | " \n", 2028 | " \n", 2029 | " \n", 2030 | " \n", 2031 | " \n", 2032 | " \n", 2033 | " \n", 2034 | " \n", 2035 | " \n", 2036 | " \n", 2037 | " \n", 2038 | " \n", 2039 | " \n", 2040 | " \n", 2041 | " \n", 2042 | " \n", 2043 | " \n", 2044 | " \n", 2045 | " \n", 2046 | " \n", 2047 | " \n", 2048 | " \n", 2049 | " \n", 2050 | " \n", 2051 | " \n", 2052 | " \n", 2053 | " \n", 2054 | " \n", 2055 | " \n", 2056 | " \n", 2057 | " \n", 2058 | " \n", 2059 | " \n", 2060 | " \n", 2061 | " \n", 2062 | " \n", 2063 | " \n", 2064 | " \n", 2065 | " \n", 2066 | " \n", 2067 | " \n", 2068 | " \n", 2069 | " \n", 2070 | " \n", 2071 | " \n", 2072 | " \n", 2073 | " \n", 2074 | " \n", 2075 | " \n", 2076 | " \n", 2077 | " \n", 2078 | " \n", 2079 | " \n", 2080 | " \n", 2081 | " \n", 2082 | " \n", 2083 | " \n", 2084 | " \n", 2085 | " \n", 2086 | " \n", 2087 | " \n", 2088 | " \n", 2089 | " \n", 2090 | " \n", 2091 | " \n", 2092 | " \n", 2093 | " \n", 2094 | " \n", 2095 | " \n", 2096 | " \n", 2097 | " \n", 2098 | " \n", 2099 | " \n", 2100 | " \n", 2101 | " \n", 2102 | " \n", 2103 | " \n", 2104 | " \n", 2105 | " \n", 2106 | " \n", 2107 | " \n", 2108 | " \n", 2109 | " \n", 2110 | " 
\n", 2111 | " \n", 2112 | " \n", 2113 | " \n", 2114 | " \n", 2115 | " \n", 2116 | " \n", 2117 | " \n", 2118 | " \n", 2119 | " \n", 2120 | " \n", 2121 | " \n", 2122 | " \n", 2123 | " \n", 2124 | " \n", 2125 | " \n", 2126 | " \n", 2127 | " \n", 2128 | " \n", 2129 | " \n", 2130 | " \n", 2131 | " \n", 2132 | " \n", 2133 | " \n", 2134 | " \n", 2135 | " \n", 2136 | " \n", 2137 | " \n", 2138 | " \n", 2139 | " \n", 2140 | " \n", 2141 | " \n", 2142 | " \n", 2143 | " \n", 2144 | " \n", 2145 | " \n", 2146 | " \n", 2147 | " \n", 2148 | " \n", 2149 | " \n", 2150 | " \n", 2151 | " \n", 2152 | " \n", 2153 | " \n", 2154 | " \n", 2155 | " \n", 2156 | " \n", 2157 | " \n", 2158 | " \n", 2159 | " \n", 2160 | " \n", 2161 | " \n", 2162 | " \n", 2163 | " \n", 2164 | " \n", 2165 | " \n", 2166 | " \n", 2167 | " \n", 2168 | " \n", 2169 | " \n", 2170 | " \n", 2171 | " \n", 2172 | " \n", 2173 | " \n", 2174 | " \n", 2175 | " \n", 2176 | " \n", 2177 | " \n", 2178 | " \n", 2179 | " \n", 2180 | " \n", 2181 | " \n", 2182 | " \n", 2183 | " \n", 2184 | " \n", 2185 | " \n", 2186 | " \n", 2187 | " \n", 2188 | " \n", 2189 | " \n", 2190 | " \n", 2191 | " \n", 2192 | " \n", 2193 | " \n", 2194 | " \n", 2195 | " \n", 2196 | " \n", 2197 | " \n", 2198 | "
CICUCinfloutnameoutvalfeatureyminymaxinputsinvalsneutralCUtarget_concepttarget_inputs
age0.100.8000000.030Yes0.59age0.510.61[0][57.0]0.5NoneNone
sex0.031.0000000.015Yes0.59sex0.560.59[1][1.0]0.5NoneNone
cp0.160.000000-0.080Yes0.59cp0.590.75[2][3.0]0.5NoneNone
trestbps0.090.111111-0.035Yes0.59trestbps0.580.67[3][128.0]0.5NoneNone
chol0.120.8333330.040Yes0.59chol0.490.61[4][229.0]0.5NoneNone
fbs0.071.0000000.035Yes0.59fbs0.520.59[5][0.0]0.5NoneNone
restecg0.091.0000000.045Yes0.59restecg0.500.59[6][2.0]0.5NoneNone
thalach0.070.000000-0.035Yes0.59thalach0.590.66[7][150.0]0.5NoneNone
exang0.071.0000000.035Yes0.59exang0.520.59[8][0.0]0.5NoneNone
oldpeak0.250.080000-0.105Yes0.59oldpeak0.570.82[9][0.4]0.5NoneNone
slope0.280.9642860.130Yes0.59slope0.320.60[10][2.0]0.5NoneNone
ca0.220.8636360.080Yes0.59ca0.400.62[11][1.0]0.5NoneNone
thal0.181.0000000.090Yes0.59thal0.410.59[12][7.0]0.5NoneNone
\n", 2199 | "
" 2200 | ], 2201 | "text/plain": [ 2202 | " CI CU Cinfl outname outval feature ymin ymax inputs \\\n", 2203 | "age 0.10 0.800000 0.030 Yes 0.59 age 0.51 0.61 [0] \n", 2204 | "sex 0.03 1.000000 0.015 Yes 0.59 sex 0.56 0.59 [1] \n", 2205 | "cp 0.16 0.000000 -0.080 Yes 0.59 cp 0.59 0.75 [2] \n", 2206 | "trestbps 0.09 0.111111 -0.035 Yes 0.59 trestbps 0.58 0.67 [3] \n", 2207 | "chol 0.12 0.833333 0.040 Yes 0.59 chol 0.49 0.61 [4] \n", 2208 | "fbs 0.07 1.000000 0.035 Yes 0.59 fbs 0.52 0.59 [5] \n", 2209 | "restecg 0.09 1.000000 0.045 Yes 0.59 restecg 0.50 0.59 [6] \n", 2210 | "thalach 0.07 0.000000 -0.035 Yes 0.59 thalach 0.59 0.66 [7] \n", 2211 | "exang 0.07 1.000000 0.035 Yes 0.59 exang 0.52 0.59 [8] \n", 2212 | "oldpeak 0.25 0.080000 -0.105 Yes 0.59 oldpeak 0.57 0.82 [9] \n", 2213 | "slope 0.28 0.964286 0.130 Yes 0.59 slope 0.32 0.60 [10] \n", 2214 | "ca 0.22 0.863636 0.080 Yes 0.59 ca 0.40 0.62 [11] \n", 2215 | "thal 0.18 1.000000 0.090 Yes 0.59 thal 0.41 0.59 [12] \n", 2216 | "\n", 2217 | " invals neutralCU target_concept target_inputs \n", 2218 | "age [57.0] 0.5 None None \n", 2219 | "sex [1.0] 0.5 None None \n", 2220 | "cp [3.0] 0.5 None None \n", 2221 | "trestbps [128.0] 0.5 None None \n", 2222 | "chol [229.0] 0.5 None None \n", 2223 | "fbs [0.0] 0.5 None None \n", 2224 | "restecg [2.0] 0.5 None None \n", 2225 | "thalach [150.0] 0.5 None None \n", 2226 | "exang [0.0] 0.5 None None \n", 2227 | "oldpeak [0.4] 0.5 None None \n", 2228 | "slope [2.0] 0.5 None None \n", 2229 | "ca [1.0] 0.5 None None \n", 2230 | "thal [7.0] 0.5 None None " 2231 | ] 2232 | }, 2233 | "metadata": {}, 2234 | "output_type": "display_data" 2235 | } 2236 | ], 2237 | "source": [ 2238 | "from ciu_tests import heart_disease_rf\n", 2239 | "\n", 2240 | "np.random.seed(26) # We want to always get the same Random Forest model here.\n", 2241 | "inst_ind = 2 # Instance 0 has no disease, instance 2 (for instance) has higher probability of disease than no disease\n", 2242 | "CIU_hd, hd_model, 
hd_instance = heart_disease_rf.get_heart_disease_rf(inst_ind)\n", 2243 | "print(hd_instance)\n", 2244 | "print(hd_model.predict_proba(hd_instance))\n", 2245 | "CIUres_hd = CIU_hd.explain(hd_instance)\n", 2246 | "display(CIUres_hd)\n", 2247 | "CIUres_hd = CIU_hd.explain(hd_instance, output_inds=1)\n", 2248 | "display(CIUres_hd)" 2249 | ] 2250 | } 2251 | ], 2252 | "metadata": { 2253 | "kernelspec": { 2254 | "display_name": "Python 3 (ipykernel)", 2255 | "language": "python", 2256 | "name": "python3" 2257 | }, 2258 | "language_info": { 2259 | "codemirror_mode": { 2260 | "name": "ipython", 2261 | "version": 3 2262 | }, 2263 | "file_extension": ".py", 2264 | "mimetype": "text/x-python", 2265 | "name": "python", 2266 | "nbconvert_exporter": "python", 2267 | "pygments_lexer": "ipython3", 2268 | "version": "3.12.1" 2269 | } 2270 | }, 2271 | "nbformat": 4, 2272 | "nbformat_minor": 4 2273 | } 2274 | --------------------------------------------------------------------------------