├── src └── trustyai │ ├── dep │ ├── .gitkeep │ └── org │ │ └── trustyai │ │ └── .gitkeep │ ├── language │ ├── __init__.py │ └── detoxify │ │ └── __init__.py │ ├── local │ └── __init__.py │ ├── metrics │ ├── fairness │ │ ├── __init__.py │ │ └── group.py │ ├── __init__.py │ ├── language.py │ ├── distance.py │ └── saliency.py │ ├── version.py │ ├── utils │ ├── extras │ │ ├── models.py │ │ ├── timeseries.py │ │ └── metrics_service.py │ ├── tokenizers.py │ ├── text.py │ ├── __init__.py │ ├── DataUtils.py │ ├── api │ │ └── api.py │ ├── _visualisation.py │ └── _tyrus_info_text.py │ ├── explainers │ ├── __init__.py │ ├── explanation_results.py │ ├── extras │ │ ├── tssaliency.py │ │ ├── tslime.py │ │ └── tsice.py │ └── pdp.py │ ├── _default_initializer.py │ ├── __init__.py │ ├── visualizations │ ├── visualization_results.py │ ├── distance.py │ ├── pdp.py │ ├── __init__.py │ ├── lime.py │ └── shap.py │ ├── model │ └── domain.py │ └── initializer.py ├── docs ├── requirements.txt ├── clean.sh ├── _static │ ├── css │ │ ├── fonts │ │ │ ├── RedHatMono-VariableFont_wght.woff │ │ │ ├── RedHatText-VariableFont_wght.woff │ │ │ ├── RedHatDisplay-VariableFont_wght.woff │ │ │ ├── RedHatMono-Italic-VariableFont_wght.woff │ │ │ ├── RedHatText-Italic-VariableFont_wght.woff │ │ │ └── RedHatDisplay-Italic-VariableFont_wght.woff │ │ └── custom.css │ └── artwork │ │ ├── logo.png │ │ └── favicon.png ├── generated │ ├── trustyai.model.feature.rst │ ├── trustyai.model.output.rst │ ├── trustyai.initializer.init.rst │ ├── trustyai.model.feature_domain.rst │ ├── trustyai.model.simple_prediction.rst │ ├── trustyai.model.counterfactual_prediction.rst │ ├── trustyai.explainers.LimeExplainer.rst │ ├── trustyai.explainers.SHAPExplainer.rst │ ├── trustyai.model.Model.rst │ ├── trustyai.explainers.CounterfactualExplainer.rst │ ├── trustyai.explainers.LimeResults.rst │ ├── trustyai.explainers.SHAPResults.rst │ ├── trustyai.model.Dataset.rst │ └── trustyai.explainers.CounterfactualResult.rst ├── Makefile ├── make.bat ├── api.rst ├── index.rst └── conf.py ├── tests ├── benchmarks │ ├── benchmark_common.py │ ├── xai_benchmark.py │ └── benchmark.py ├── general │ ├── data │ │ ├── income-biased.zip │ │ └── income-unbiased.zip │ ├── models │ │ ├── income-xgd-biased.joblib │ │ └── credit-bias-model-clean.joblib │ ├── common.py │ ├── test_tyrus.py │ ├── test_model.py │ ├── universal.py │ ├── test_pdp.py │ ├── test_prediction.py │ ├── test_shap_background_generation.py │ ├── test_dataset.py │ ├── test_metrics_language.py │ ├── test_datautils.py │ ├── test_shap.py │ └── test_limeexplainer.py ├── initialization │ └── test_initialization.py └── extras │ ├── test_metrics_service.py │ ├── test_tssaliency.py │ ├── test_tslime.py │ └── test_tsice.py ├── MANIFEST.in ├── requirements.txt ├── .dockerignore ├── .gitmodules ├── deps.sh ├── .github ├── workflows │ ├── publish.yml │ ├── workflow.yml │ ├── benchmarks.yml │ ├── benchmarks-merge.yml │ └── security.yaml └── actions │ └── build-core │ └── action.yml ├── .readthedocs.yaml ├── scripts ├── local.sh └── build.sh ├── README.md ├── cliff.toml ├── .gitignore ├── pyproject.toml ├── info └── detoxify.md └── CONTRIBUTING.md /src/trustyai/dep/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/trustyai/language/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/trustyai/local/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme -------------------------------------------------------------------------------- /src/trustyai/dep/org/trustyai/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_common.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/trustyai/metrics/fairness/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/clean.sh: -------------------------------------------------------------------------------- 1 | rm generated/* 2 | rm -r _build/ -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatMono-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatText-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatDisplay-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatMono-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatText-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatDisplay-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/trustyai/version.py: -------------------------------------------------------------------------------- 1 | """TrustyAI version""" 2 | 3 | __version__ = "0.6.3" 4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft src 2 | prune tests 3 | prune docs 4 | prune .github 5 | 6 | global-exclude *~ *.py[cod] *.so *.sh -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Jpype1==1.5.0 2 | pyarrow>=20.0.0 3 | matplotlib~=3.10.3 4 | pandas>=2.1.0 5 | numpy>=1.26.4 6 | jupyter-bokeh~=4.0.5 -------------------------------------------------------------------------------- 
/docs/_static/artwork/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/docs/_static/artwork/logo.png -------------------------------------------------------------------------------- /src/trustyai/language/detoxify/__init__.py: -------------------------------------------------------------------------------- 1 | """Language detoxification module.""" 2 | 3 | from trustyai.language.detoxify.tmarco import TMaRCo 4 | -------------------------------------------------------------------------------- /docs/_static/artwork/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/docs/_static/artwork/favicon.png -------------------------------------------------------------------------------- /tests/general/data/income-biased.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/tests/general/data/income-biased.zip -------------------------------------------------------------------------------- /docs/generated/trustyai.model.feature.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.feature 2 | ====================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: feature -------------------------------------------------------------------------------- /docs/generated/trustyai.model.output.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.output 2 | ===================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: output -------------------------------------------------------------------------------- /tests/general/data/income-unbiased.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/tests/general/data/income-unbiased.zip -------------------------------------------------------------------------------- /docs/generated/trustyai.initializer.init.rst: -------------------------------------------------------------------------------- 1 | trustyai.initializer.init 2 | ========================= 3 | 4 | .. currentmodule:: trustyai.initializer 5 | 6 | .. 
autofunction:: init -------------------------------------------------------------------------------- /tests/general/models/income-xgd-biased.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/tests/general/models/income-xgd-biased.joblib -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | *.log 2 | __pycache__ 3 | .pytest_cache 4 | .pynb_checkpoints 5 | .Rproj.user 6 | .vscode 7 | .idea 8 | .mypy_cache 9 | build 10 | dist 11 | trustyai.egg-info 12 | *.pyc -------------------------------------------------------------------------------- /src/trustyai/utils/extras/models.py: -------------------------------------------------------------------------------- 1 | """AIX360 model wrappers""" 2 | 3 | from aix360.algorithms.tsutils.model_wrappers import * # pylint: disable=wildcard-import,unused-wildcard-import 4 | -------------------------------------------------------------------------------- /tests/general/models/credit-bias-model-clean.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/HEAD/tests/general/models/credit-bias-model-clean.joblib -------------------------------------------------------------------------------- /docs/generated/trustyai.model.feature_domain.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.feature\_domain 2 | ============================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: feature_domain -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/benchmarks/trustyai_xai_bench"] 2 | path = tests/benchmarks/trustyai_xai_bench 3 | url = https://github.com/trustyai-explainability/trustyai_xai_bench 4 | branch = main 5 | -------------------------------------------------------------------------------- /docs/generated/trustyai.model.simple_prediction.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.simple\_prediction 2 | ================================= 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: simple_prediction -------------------------------------------------------------------------------- /docs/generated/trustyai.model.counterfactual_prediction.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.counterfactual\_prediction 2 | ========================================= 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. 
autofunction:: counterfactual_prediction -------------------------------------------------------------------------------- /src/trustyai/utils/extras/timeseries.py: -------------------------------------------------------------------------------- 1 | """Extra time series utilities.""" 2 | 3 | from aix360.algorithms.tsutils.tsframe import tsFrame # pylint: disable=unused-import 4 | from aix360.algorithms.tsutils.tsperturbers import * # pylint: disable=wildcard-import,unused-wildcard-import 5 | -------------------------------------------------------------------------------- /src/trustyai/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Explainers module""" 2 | 3 | # pylint: disable=duplicate-code 4 | from .counterfactuals import CounterfactualResult, CounterfactualExplainer 5 | from .lime import LimeExplainer, LimeResults 6 | from .shap import SHAPExplainer, SHAPResults, BackgroundGenerator 7 | from .pdp import PDPExplainer 8 | -------------------------------------------------------------------------------- /src/trustyai/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | """ "Default tokenizers for TrustyAI.""" 2 | 3 | # pylint: disable = import-error 4 | 5 | from org.apache.commons.text import StringTokenizer as _StringTokenizer 6 | from opennlp.tools.tokenize import SimpleTokenizer as _SimpleTokenizer 7 | 8 | CommonsStringTokenizer = _StringTokenizer 9 | OpenNLPTokenizer = _SimpleTokenizer 10 | -------------------------------------------------------------------------------- /src/trustyai/_default_initializer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value, invalid-name, R0801 2 | """The default initializer""" 3 | import trustyai 4 | from trustyai import initializer # pylint: disable=no-name-in-module 5 | 6 | if not trustyai.TRUSTYAI_IS_INITIALIZED: 7 | trustyai.TRUSTYAI_IS_INITIALIZED = initializer.init() 8 | -------------------------------------------------------------------------------- /src/trustyai/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, invalid-name, wrong-import-order, no-name-in-module 2 | """General model classes""" 3 | from trustyai import _default_initializer # pylint: disable=unused-import 4 | from org.kie.trustyai.metrics.explainability import ( 5 | ExplainabilityMetrics as _ExplainabilityMetrics, 6 | ) 7 | 8 | ExplainabilityMetrics = _ExplainabilityMetrics 9 | -------------------------------------------------------------------------------- /src/trustyai/utils/text.py: -------------------------------------------------------------------------------- 1 | """Utility methods for text data handling""" 2 | 3 | from typing import List, Callable 4 | 5 | from jpype import _jclass 6 | 7 | 8 | def tokenizer(function: Callable[[str], List[str]]): 9 | """Post-process outputs of a Python tokenizer function""" 10 | 11 | def wrapper(_input: str): 12 | return _jclass.JClass("java.util.Arrays").asList(function(_input)) 13 | 14 | return wrapper 15 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.LimeExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.LimeExplainer 2 | ================================= 3 | 4 | .. 
currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: LimeExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~LimeExplainer.__init__ 17 | ~LimeExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.SHAPExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.SHAPExplainer 2 | ================================= 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: SHAPExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~SHAPExplainer.__init__ 17 | ~SHAPExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.model.Model.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.Model 2 | ==================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autoclass:: Model 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Model.__init__ 17 | ~Model.equals 18 | ~Model.hashCode 19 | ~Model.predictAsync 20 | ~Model.toString 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.CounterfactualExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.CounterfactualExplainer 2 | =========================================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: CounterfactualExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~CounterfactualExplainer.__init__ 17 | ~CounterfactualExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.LimeResults.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.LimeResults 2 | =============================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: LimeResults 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~LimeResults.__init__ 17 | ~LimeResults.as_dataframe 18 | ~LimeResults.as_html 19 | ~LimeResults.map 20 | ~LimeResults.plot 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.SHAPResults.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.SHAPResults 2 | =============================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: SHAPResults 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. 
autosummary:: 15 | 16 | ~SHAPResults.__init__ 17 | ~SHAPResults.as_dataframe 18 | ~SHAPResults.as_html 19 | ~SHAPResults.candlestick_plot 20 | ~SHAPResults.get_fnull 21 | ~SHAPResults.get_saliencies 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | TRUSTY_VERSION="1.12.0.Final" 4 | 5 | mvn org.apache.maven.plugins:maven-dependency-plugin:2.10:get \ 6 | -DremoteRepositories=https://repository.sonatype.org/content/repositories/central \ 7 | -Dartifact=org.kie.kogito:explainability-core:$TRUSTY_VERSION \ 8 | -Dmaven.repo.local=dep -q 9 | 10 | # We also need the test JARs in order to get the test models 11 | wget -O ./dep/org/kie/kogito/explainability-core/$TRUSTY_VERSION/explainability-core-$TRUSTY_VERSION-tests.jar \ 12 | https://repo1.maven.org/maven2/org/kie/kogito/explainability-core/$TRUSTY_VERSION/explainability-core-$TRUSTY_VERSION-tests.jar -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | on: 3 | release: 4 | types: [ published ] 5 | jobs: 6 | pypi-publish: 7 | name: upload release to PyPI 8 | runs-on: ubuntu-latest 9 | environment: pypi 10 | permissions: 11 | id-token: write 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | with: 16 | fetch-depth: 0 17 | - name: Build explainability-core 18 | uses: ./.github/actions/build-core 19 | - run: python3 -m pip install --upgrade build && python3 -m build 20 | - name: Publish package 21 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/actions/build-core/action.yml: -------------------------------------------------------------------------------- 1 | name: Build exp-core JAR 2 | description: Clone and build TrustyAI-Explainability library (shaded in a single JAR) 3 | runs: 4 | using: "composite" 5 | steps: 6 | - name: Set up JDK 17 7 | uses: actions/setup-java@v2 8 | with: 9 | distribution: 'adopt' 10 | java-version: '17' 11 | - name: Build explainability-core 12 | shell: bash 13 | run: | 14 | git clone https://github.com/trustyai-explainability/trustyai-explainability.git 15 | mvn clean install -DskipTests -f trustyai-explainability/pom.xml -Pshaded -fae -e -nsu 16 | mv trustyai-explainability/explainability-arrow/target/explainability-arrow-*-SNAPSHOT.jar src/trustyai/dep/org/trustyai/ -------------------------------------------------------------------------------- /docs/generated/trustyai.model.Dataset.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.Dataset 2 | ====================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autoclass:: Dataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Dataset.__init__ 17 | ~Dataset.df_to_prediction_object 18 | ~Dataset.from_df 19 | ~Dataset.from_numpy 20 | ~Dataset.numpy_to_prediction_object 21 | ~Dataset.prediction_object_to_numpy 22 | ~Dataset.prediction_object_to_pandas 23 | 24 | 25 | 26 | 27 | 28 | .. rubric:: Attributes 29 | 30 | .. autosummary:: 31 | 32 | ~Dataset.data 33 | ~Dataset.inputs 34 | ~Dataset.outputs 35 | 36 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.CounterfactualResult.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.CounterfactualResult 2 | ======================================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: CounterfactualResult 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~CounterfactualResult.__init__ 17 | ~CounterfactualResult.as_dataframe 18 | ~CounterfactualResult.as_html 19 | ~CounterfactualResult.plot 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. 
autosummary:: 28 | 29 | ~CounterfactualResult.proposed_features_array 30 | ~CounterfactualResult.proposed_features_dataframe 31 | 32 | -------------------------------------------------------------------------------- /src/trustyai/explainers/explanation_results.py: -------------------------------------------------------------------------------- 1 | """Generic class for Explanation and Saliency results""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import pandas as pd 6 | from pandas.io.formats.style import Styler 7 | 8 | 9 | class ExplanationResults(ABC): 10 | """Abstract class for explanation visualisers""" 11 | 12 | @abstractmethod 13 | def as_dataframe(self) -> pd.DataFrame: 14 | """Display explanation result as a dataframe""" 15 | 16 | @abstractmethod 17 | def as_html(self) -> Styler: 18 | """Visualise the styled dataframe""" 19 | 20 | 21 | # pylint: disable=too-few-public-methods 22 | class SaliencyResults(ExplanationResults): 23 | """Abstract class for saliency visualisers""" 24 | 25 | @abstractmethod 26 | def saliency_map(self): 27 | """Return the Saliencies as a dictionary, keyed by output name""" 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/benchmarks/xai_benchmark.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from trustyai_xai_bench import run_benchmark_config 3 | 4 | 5 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 6 | def test_level_0(benchmark): 7 | # ~4.5 min 8 | result = benchmark(run_benchmark_config, 0) 9 | benchmark.extra_info['runs'] = result.to_dict('records') 10 | 11 | 12 | @pytest.mark.skip(reason="full diagnostic benchmark, ~2 hour runtime") 13 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 14 | def test_level_1(benchmark): 15 | result = benchmark(run_benchmark_config, 1) 16 | benchmark.extra_info['runs'] = result.to_dict('records') 17 | 18 | 19 | @pytest.mark.skip(reason="very thorough benchmark, >>2 hour runtime") 20 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 21 | def test_level_2(benchmark): 22 | result = benchmark(run_benchmark_config, 2) 23 | benchmark.extra_info['runs'] = result.to_dict('records') -------------------------------------------------------------------------------- /src/trustyai/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value 2 | # pylint: disable = invalid-name, R0801, duplicate-code 3 | """Main TrustyAI Python bindings""" 4 | import os 5 | import logging 6 | 7 | # set initialized env variable to 0 8 | import warnings 9 | from .version import __version__ 10 | 11 | TRUSTYAI_IS_INITIALIZED = False 12 | 13 | if os.getenv("PYTHON_TRUSTY_DEBUG") == "1": 14 | _LOGGING_LEVEL = logging.DEBUG 15 | else: 16 | _LOGGING_LEVEL = logging.WARN 17 | 18 | logging.basicConfig(level=_LOGGING_LEVEL) 19 | 20 | 21 | def init(): 22 | """Deprecated manual initializer for the JVM. This function has been replaced by 23 | automatic initialization when importing the components of the module that require 24 | JVM access, or by manual user initialization via :func:`trustyai`initializer.init`. 25 | """ 26 | warnings.warn( 27 | "trustyai.init() is now deprecated; the trustyai library will now " 28 | + "automatically initialize. 
For manual initialization options, see " 29 | + "trustyai.initializer.init()" 30 | ) 31 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | 4 | @font-face { 5 | font-family: "Red Hat Text", sans-serif; !important; 6 | src: url("fonts/RedHatText-VariableFont_wght.woff"); 7 | } 8 | 9 | @font-face { 10 | font-family: "Red Hat Display", sans-serif; !important; 11 | src: url("fonts/RedHatDisplay-VariableFont_wght.woff"); 12 | } 13 | 14 | @font-face { 15 | font-family: "Red Hat Mono", sans-serif; !important; 16 | src: url("fonts/RedHatMono-VariableFont_wght.woff"); 17 | } 18 | 19 | body { 20 | font-family: "Red Hat Text", sans-serif; !important; 21 | } 22 | 23 | h1, h2, h3, h4, h5, h6 { 24 | font-family: "Red Hat Display", sans-serif; !important; 25 | } 26 | 27 | .rst-content code, .rst-content tt { 28 | font-family: "Red Hat Mono", sans-serif; !important; 29 | } 30 | 31 | .wy-side-nav-search { 32 | background-color: #343131 !important; 33 | } 34 | 35 | html.writer-html4 .rst-content dl:not(.docutils)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt{ 36 | border-top: 3px solid #e06666; 37 | background: #e0666633; 38 | color: #a64d79; 39 | } 40 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/visualization_results.py: -------------------------------------------------------------------------------- 1 | """Generic class for Visualization results""" 2 | 3 | # pylint: disable = import-error, too-few-public-methods, line-too-long, missing-final-newline 4 | from abc import ABC, abstractmethod 5 | from typing import Dict 6 | 7 | import bokeh.models 8 | 9 | 10 | class VisualizationResults(ABC): 11 | """Abstract class for visualization results""" 12 | 13 | @abstractmethod 14 | def _matplotlib_plot( 15 | self, explanations, output_name: str, block: bool, call_show: bool 16 | ) -> None: 17 | """Plot the saliencies of a particular output in matplotlib""" 18 | 19 | @abstractmethod 20 | def _get_bokeh_plot(self, explanations, output_name: str) -> bokeh.models.Plot: 21 | """Get a bokeh plot visualizing the saliencies of a particular output""" 22 | 23 | @abstractmethod 24 | def _get_bokeh_plot_dict(self, explanations) -> Dict[str, bokeh.models.Plot]: 25 | """Get a dictionary containing visualizations of the saliencies of all outputs, 26 | keyed by output name""" 27 | return { 28 | output_name: self._get_bokeh_plot(explanations, output_name) 29 | for output_name in explanations.saliency_map().keys() 30 | } 31 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | apt_packages : 12 | - maven 13 | tools: 14 | python: "3.9" 15 | jobs: 16 | pre_create_environment: 17 | - rm -f src/trustyai/dep/org/trustyai/* 18 | - git clone https://github.com/trustyai-explainability/trustyai-explainability.git 19 | - mvn clean install -DskipTests -f trustyai-explainability/pom.xml -Pquickly -fae -e -nsu 20 | - mvn 
clean install -DskipTests -f trustyai-explainability/explainability-arrow/pom.xml -Pshaded -fae -e -nsu 21 | - mv trustyai-explainability/explainability-arrow/target/explainability-arrow-*-SNAPSHOT.jar src/trustyai/dep/org/trustyai/ 22 | 23 | post_build: 24 | - rm -Rf trustyai-explainability 25 | 26 | # install the package 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt 30 | - method: pip 31 | path: . 32 | extra_requirements: 33 | - dev 34 | 35 | # Build documentation in the docs/ directory with Sphinx 36 | sphinx: 37 | configuration: docs/conf.py 38 | -------------------------------------------------------------------------------- /src/trustyai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, invalid-name, wrong-import-order 2 | """General model classes""" 3 | 4 | from jpype._jproxy import _createJProxy, _createJProxyDeferred 5 | from trustyai import _default_initializer 6 | 7 | from org.kie.trustyai.explainability import Config as _Config 8 | from org.kie.trustyai.explainability.utils.models import TestModels as _TestModels 9 | 10 | TestModels = _TestModels 11 | Config = _Config 12 | 13 | 14 | def JImplementsWithDocstring(*interfaces, deferred=False, **kwargs): 15 | """JPype's JImplements decorator overwrites the docstring of any annotated functions. This 16 | is a quick hack to preserve docstrings across the jproxy process.""" 17 | if deferred: 18 | 19 | def JProxyCreator(cls): 20 | proxy_class = _createJProxyDeferred(cls, *interfaces, **kwargs) 21 | proxy_class.__doc__ = cls.__doc__ 22 | proxy_class.__name__ = cls.__name__ 23 | return proxy_class 24 | 25 | else: 26 | 27 | def JProxyCreator(cls): 28 | proxy_class = _createJProxy(cls, *interfaces, **kwargs) 29 | proxy_class.__doc__ = cls.__doc__ 30 | proxy_class.__name__ = cls.__name__ 31 | return proxy_class 32 | 33 | return JProxyCreator 34 | -------------------------------------------------------------------------------- /src/trustyai/utils/DataUtils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = invalid-name, import-error 2 | """DataUtils module""" 3 | from org.kie.trustyai.explainability.utils import DataUtils as du 4 | 5 | getMean = du.getMean 6 | getStdDev = du.getStdDev 7 | gaussianKernel = du.gaussianKernel 8 | euclideanDistance = du.euclideanDistance 9 | hammingDistance = du.hammingDistance 10 | doublesToFeatures = du.doublesToFeatures 11 | exponentialSmoothingKernel = du.exponentialSmoothingKernel 12 | generateRandomDataDistribution = du.generateRandomDataDistribution 13 | 14 | 15 | def generateData(mean, stdDeviation, size, jrandom): 16 | """Generate data""" 17 | return list(du.generateData(mean, stdDeviation, size, jrandom)) 18 | 19 | 20 | def perturbFeatures(originalFeatures, perturbationContext): 21 | """Perform perturbations on a fixed number of features in the given input.""" 22 | return du.perturbFeatures(originalFeatures, perturbationContext) 23 | 24 | 25 | def getLinearizedFeatures(originalFeatures): 26 | """Transform a list of eventually composite / nested features into a 27 | flat list of non composite / non nested features.""" 28 | return du.getLinearizedFeatures(originalFeatures) 29 | 30 | 31 | def sampleWithReplacement(values, sampleSize, jrandom): 32 | """Sample (with replacement) from a list of values.""" 33 | return du.sampleWithReplacement(values, sampleSize, jrandom) 34 | 
-------------------------------------------------------------------------------- /tests/general/common.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """Common methods and models for tests""" 3 | import os 4 | import sys 5 | from typing import Optional, List 6 | 7 | import numpy as np 8 | import pandas as pd # pylint: disable=unused-import 9 | 10 | myPath = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.insert(0, myPath + "/../../src") 12 | 13 | from trustyai.model import ( 14 | FeatureFactory, 15 | ) 16 | 17 | 18 | def mock_feature(value, name='f-num'): 19 | """Create a mock numerical feature""" 20 | return FeatureFactory.newNumericalFeature(name, value) 21 | 22 | 23 | def sum_skip_model(inputs: np.ndarray) -> np.ndarray: 24 | """SumSkip test model""" 25 | return np.sum(inputs[:, [i for i in range(inputs.shape[1]) if i != 5]], 1) 26 | 27 | 28 | def create_random_dataframe(weights: Optional[List[float]] = None): 29 | """Create a simple random Pandas dataframe""" 30 | from sklearn.datasets import make_classification 31 | if not weights: 32 | weights = [0.9, 0.1] 33 | 34 | X, y = make_classification(n_samples=5000, n_features=2, n_informative=2, n_redundant=0, n_repeated=0, n_classes=2, 35 | n_clusters_per_class=2, class_sep=2, flip_y=0, weights=weights, 36 | random_state=23) 37 | 38 | return pd.DataFrame({ 39 | 'x1': X[:, 0], 40 | 'x2': X[:, 1], 41 | 'y': y 42 | }) 43 | -------------------------------------------------------------------------------- /tests/initialization/test_initialization.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | 4 | import pytest 5 | from multiprocessing import Process, Value 6 | import sys 7 | 8 | 9 | # slightly hacky functions to make sure the test process does not see the trustyai initialization 10 | # from commons.py 11 | def test_manual_initializer_process(): 12 | import trustyai 13 | from trustyai import initializer 14 | initial_state = trustyai.TRUSTYAI_IS_INITIALIZED 15 | initializer.init(path=initializer._get_default_path()[0]) 16 | 17 | # test imports work 18 | from trustyai.explainers import SHAPExplainer 19 | 20 | # test initialization is set 21 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 22 | assert initial_state == False 23 | assert final_state == True 24 | 25 | 26 | def test_default_initializer_process_mod(): 27 | import trustyai 28 | initial_state = trustyai.TRUSTYAI_IS_INITIALIZED 29 | import trustyai.model 30 | 31 | # test initialization is set 32 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 33 | assert initial_state == False 34 | assert final_state == True 35 | 36 | 37 | def test_default_initializer_process_exp(): 38 | import trustyai 39 | initial_state = trustyai.TRUSTYAI_IS_INITIALIZED 40 | import trustyai.explainers 41 | 42 | # test initialization is set 43 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 44 | assert initial_state == False 45 | assert final_state == True 46 | -------------------------------------------------------------------------------- /scripts/local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Red Hat, Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | 18 | ROOT_DIR=$(git rev-parse --show-toplevel) 19 | 20 | EXP_CORE_DEST=$1 21 | 22 | if [[ "$EXP_CORE_DEST" == "" ]] 23 | then 24 | EXP_CORE_DEST="../trustyai-explainability" 25 | echo "No argument provided, building trustyai-explainability from ${EXP_CORE_DEST}" 26 | else 27 | echo "Building trustyai-explainability from ${EXP_CORE_DEST}" 28 | fi 29 | 30 | echo "Copying JARs from ${EXP_CORE_DEST} into ${ROOT_DIR}/dep/org/trustyai/" 31 | mvn install package -DskipTests -f "${EXP_CORE_DEST}"/pom.xml -Pshaded 32 | mv "${EXP_CORE_DEST}"/explainability-arrow/target/explainability-arrow-*.jar "${ROOT_DIR}"/src/trustyai/dep/org/trustyai/ 33 | 34 | 35 | if [[ "$VIRTUAL_ENV" != "" ]] 36 | then 37 | pip install "${ROOT_DIR}" --force 38 | else 39 | echo "Not in a virtualenv. Installation not recommended." 40 | exit 1 41 | fi 42 | -------------------------------------------------------------------------------- /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [ '3.10', '3.11', '3.12' ] 11 | java-version: [ 17 ] 12 | maven-version: [ '3.8.6' ] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up JDK + Maven version 16 | uses: s4u/setup-maven-action@v1.4.0 17 | with: 18 | java-version: ${{ matrix.java-version }} 19 | maven-version: ${{ matrix.maven-version }} 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Build explainability-core 25 | uses: ./.github/actions/build-core 26 | - name: Install TrustyAI Python package 27 | run: | 28 | pip install --upgrade pip 29 | pip install . 30 | pip install ".[dev]" 31 | pip install ".[api]" 32 | # Extras extra removed; keep AIX360-only deps out of the default install 33 | - name: Lint 34 | run: | 35 | pylint --ignore-imports=yes $(find src/trustyai -type f -name "*.py" | grep -v "/extras/") 36 | - name: Test with pytest 37 | run: | 38 | pytest -v -s tests/general 39 | # pytest -v -s tests/extras # Extras-only deps not installed in CI 40 | pytest -v -s tests/initialization --forked 41 | - name: Style 42 | run: | 43 | black --check $(find src/trustyai -type f -name "*.py") 44 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Red Hat, Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | 18 | ROOT_DIR=$(git rev-parse --show-toplevel) 19 | TMP_DIR=$(mktemp -d) 20 | 21 | EXP_CORE="trustyai-explainability" 22 | 23 | EXP_CORE_DEST="${TMP_DIR}/${EXP_CORE}" 24 | if [ ! -d "${EXP_CORE_DEST}" ] 25 | then 26 | echo "Cloning trustyai-explainability into ${EXP_CORE_DEST}" 27 | git clone --branch main https://github.com/${EXP_CORE}/${EXP_CORE}.git "${EXP_CORE_DEST}" 28 | echo "Copying JARs from ${EXP_CORE_DEST} into ${ROOT_DIR}/dep/org/trustyai/" 29 | mvn install package -DskipTests -f "${EXP_CORE_DEST}"/pom.xml -Pshaded 30 | mv "${EXP_CORE_DEST}"/explainability-arrow/target/explainability-arrow-*.jar "${ROOT_DIR}"/src/trustyai/dep/org/trustyai/ 31 | else 32 | echo "Directory ${EXP_CORE_DEST} already exists. Please delete it or move it." 33 | exit 1 34 | fi 35 | 36 | if [[ "$VIRTUAL_ENV" != "" ]] 37 | then 38 | pip install "${ROOT_DIR}" --force 39 | else 40 | echo "Not in a virtualenv. Installation not recommended." 41 | exit 1 42 | fi 43 | -------------------------------------------------------------------------------- /src/trustyai/utils/api/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Server module 3 | """ 4 | 5 | # pylint: disable = import-error, too-few-public-methods, assignment-from-no-return 6 | __SUCCESSFUL_IMPORT = True 7 | 8 | try: 9 | from kubernetes import config, dynamic 10 | from kubernetes.dynamic.exceptions import ResourceNotFoundError 11 | from kubernetes.client import api_client 12 | 13 | except ImportError as e: 14 | print( 15 | "Warning: api dependencies not found. " 16 | "Dependencies can be installed with 'pip install trustyai[api]" 17 | ) 18 | __SUCCESSFUL_IMPORT = False 19 | 20 | if __SUCCESSFUL_IMPORT: 21 | 22 | class TrustyAIApi: 23 | """ 24 | Gets TrustyAI service information 25 | """ 26 | 27 | def __init__(self): 28 | try: 29 | k8s_client = config.load_incluster_config() 30 | except config.ConfigException: 31 | k8s_client = config.load_kube_config() 32 | self.dyn_client = dynamic.DynamicClient( 33 | api_client.ApiClient(configuration=k8s_client) 34 | ) 35 | 36 | def get_service_route(self, name: str, namespace: str): 37 | """ 38 | Gets routes for services under a specified namespace 39 | """ 40 | route_api = self.dyn_client.resources.get(api_version="v1", kind="Route") 41 | try: 42 | service = route_api.get(name=name, namespace=namespace) 43 | return f"https://{service.spec.host}" 44 | except ResourceNotFoundError: 45 | return f"Error accessing service {name} in namespace {namespace}." 
46 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/distance.py: -------------------------------------------------------------------------------- 1 | """Visualizations.distance module""" 2 | 3 | # pylint: disable = import-error, too-few-public-methods, line-too-long, missing-final-newline 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class DistanceViz: 9 | """Visualizes Levenshtein distance""" 10 | 11 | def plot(self, explanations): 12 | """Plot the Levenshtein distance matrix""" 13 | cmap = plt.cm.viridis # pylint: disable=no-member 14 | 15 | _, axes = plt.subplots() 16 | cax = axes.imshow(explanations.matrix, cmap=cmap, interpolation="nearest") 17 | 18 | plt.colorbar(cax) 19 | 20 | axes.set_xticks(np.arange(len(explanations.reference))) 21 | axes.set_yticks(np.arange(len(explanations.hypothesis))) 22 | axes.set_xticklabels(explanations.reference) 23 | axes.set_yticklabels(explanations.hypothesis) 24 | 25 | plt.setp( 26 | axes.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" 27 | ) 28 | 29 | nrows, ncols = explanations.matrix.shape 30 | for i in range(nrows): 31 | for j in range(ncols): 32 | color = ( 33 | "white" 34 | if explanations.matrix[i, j] < explanations.matrix.max() / 2 35 | else "black" 36 | ) 37 | axes.text( 38 | j, 39 | i, 40 | int(explanations.matrix[i, j]), 41 | ha="center", 42 | va="center", 43 | color=color, 44 | ) 45 | 46 | plt.show() 47 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: trustyai 2 | 3 | API Reference 4 | ============= 5 | This page contains the API reference for public objects and function within TrustyAI. See the 6 | (example notebooks) for usage guides and tutorials. 7 | 8 | trustyai.initializer 9 | -------------------- 10 | Initializing The JVM 11 | ########################## 12 | .. currentmodule:: trustyai.initializer 13 | .. model_api: 14 | .. autosummary:: 15 | :toctree: generated/ 16 | 17 | init 18 | 19 | 20 | trustyai.model 21 | -------------- 22 | Feature and Output Objects 23 | ########################## 24 | .. currentmodule:: trustyai.model 25 | .. model_api: 26 | .. autosummary:: 27 | :toctree: generated/ 28 | 29 | feature 30 | feature_domain 31 | output 32 | 33 | Data Objects 34 | ############ 35 | .. autosummary:: 36 | :toctree: generated/ 37 | 38 | Dataset 39 | 40 | Model Classes 41 | ############# 42 | .. autosummary:: 43 | :toctree: generated/ 44 | 45 | Model 46 | 47 | trustyai.explainers 48 | ------------------- 49 | LIME 50 | #### 51 | .. currentmodule:: trustyai.explainers 52 | .. explainers_api: 53 | .. autosummary:: 54 | :toctree: generated/ 55 | 56 | LimeExplainer 57 | LimeResults 58 | 59 | SHAP 60 | #### 61 | .. autosummary:: 62 | :toctree: generated/ 63 | 64 | SHAPExplainer 65 | BackgroundGenerator 66 | SHAPResults 67 | 68 | Counterfactuals 69 | ############### 70 | .. autosummary:: 71 | :toctree: generated/ 72 | 73 | CounterfactualExplainer 74 | CounterfactualResult 75 | 76 | trustyai.utils 77 | -------------- 78 | .. currentmodule:: trustyai.utils.tyrus 79 | .. utils_api: 80 | .. 
autosummary:: 81 | :toctree: generated/ 82 | 83 | Tyrus 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- 1 | name: TrustyAI Python benchmarks (PR) 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | deployments: write 11 | pages: write 12 | pull-requests: write 13 | 14 | jobs: 15 | benchmark: 16 | name: Run pytest-benchmark benchmark 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v2 20 | - uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.8 23 | - uses: actions/setup-java@v2 24 | with: 25 | distribution: "adopt" 26 | java-version: "11" 27 | check-latest: true 28 | - uses: stCarolas/setup-maven@v4 29 | with: 30 | maven-version: 3.8.1 31 | - name: Build explainability-core 32 | uses: ./.github/actions/build-core 33 | - name: Install TrustyAI Python package 34 | run: | 35 | pip install -r requirements-dev.txt 36 | pip install . 37 | - name: Run benchmark 38 | run: | 39 | pytest tests/benchmarks/benchmark.py --benchmark-json tests/benchmarks/results.json 40 | - name: Benchmark result comment 41 | uses: benchmark-action/github-action-benchmark@v1 42 | with: 43 | name: TrustyAI continuous benchmarks 44 | tool: 'pytest' 45 | output-file-path: tests/benchmarks/results.json 46 | github-token: ${{ secrets.GITHUB_TOKEN }} 47 | auto-push: false 48 | alert-threshold: '200%' 49 | comment-on-alert: true 50 | save-data-file: false 51 | comment-always: true 52 | fail-on-alert: false 53 | alert-comment-cc-users: '@ruivieira' 54 | -------------------------------------------------------------------------------- /.github/workflows/benchmarks-merge.yml: -------------------------------------------------------------------------------- 1 | name: TrustyAI Python benchmarks (merge) 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | 9 | permissions: 10 | contents: write 11 | deployments: write 12 | pages: write 13 | pull-requests: write 14 | 15 | jobs: 16 | benchmark: 17 | if: github.event.pull_request.merged == 'true' 18 | name: Run pytest-benchmark benchmark 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v2 22 | - uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.8 25 | - uses: actions/setup-java@v2 26 | with: 27 | distribution: "adopt" 28 | java-version: "11" 29 | check-latest: true 30 | - uses: stCarolas/setup-maven@v4 31 | with: 32 | maven-version: 3.8.1 33 | - name: Build explainability-core 34 | uses: ./.github/actions/build-core 35 | - name: Install TrustyAI Python package 36 | run: | 37 | pip install -r requirements-dev.txt 38 | pip install . 
39 | - name: Run benchmark 40 | run: | 41 | pytest tests/benchmarks/benchmark.py --benchmark-json tests/benchmarks/results.json 42 | - name: Store benchmark result 43 | uses: benchmark-action/github-action-benchmark@v1 44 | with: 45 | name: TrustyAI continuous benchmarks 46 | tool: 'pytest' 47 | output-file-path: tests/benchmarks/results.json 48 | github-token: ${{ secrets.GITHUB_TOKEN }} 49 | auto-push: true 50 | gh-pages-branch: gh-pages 51 | alert-threshold: '200%' 52 | comment-on-alert: true 53 | comment-always: true 54 | fail-on-alert: false 55 | alert-comment-cc-users: '@ruivieira' -------------------------------------------------------------------------------- /tests/general/test_tyrus.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from trustyai.model import Model 5 | from trustyai.utils.tyrus import Tyrus 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import os 10 | 11 | 12 | def test_tyrus_series(): 13 | # define data 14 | data = pd.DataFrame(np.random.rand(101, 5), columns=list('ABCDE')) 15 | 16 | # define model 17 | def predict_function(x): 18 | return pd.DataFrame( 19 | np.stack( 20 | [x.sum(1), x.std(1), np.linalg.norm(x, axis=1)]).T, 21 | columns= ['Sum', 'StdDev', 'L2 Norm']) 22 | 23 | predictions = predict_function(data) 24 | 25 | model = Model(predict_function, dataframe_input=True) 26 | 27 | # create Tyrus instance 28 | tyrus = Tyrus( 29 | model, 30 | data.iloc[100], 31 | predictions.iloc[100], 32 | background=data.iloc[:100], 33 | filepath=os.getcwd() 34 | ) 35 | 36 | # launch dashboard 37 | tyrus.run() 38 | 39 | # see if dashboard html exists 40 | assert "trustyai_dashboard.html" in os.listdir() 41 | 42 | # cleanup 43 | os.remove("trustyai_dashboard.html") 44 | 45 | 46 | def test_tyrus_numpy(): 47 | # define data 48 | data = np.random.rand(101, 5) 49 | 50 | # define model 51 | def predict_function(x): 52 | return np.stack([x.sum(1), x.std(1), np.linalg.norm(x, axis=1)]).T 53 | 54 | predictions = predict_function(data) 55 | 56 | model = Model(predict_function, dataframe_input=False) 57 | 58 | # create Tyrus instance 59 | tyrus = Tyrus( 60 | model, 61 | data[100], 62 | predictions[100], 63 | background=data[:100] 64 | ) 65 | 66 | # launch dashboard 67 | tyrus.run() 68 | 69 | # see if dashboard html exists 70 | assert "trustyai_dashboard.html" in os.listdir(tyrus.filepath) 71 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![version](https://img.shields.io/badge/version-0.6.3-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml) 2 | 3 | # python-trustyai 4 | 5 | Python bindings to [TrustyAI](https://kogito.kie.org/trustyai/)'s explainability library. 6 | 7 | ## Setup 8 | 9 | ### PyPi 10 | 11 | Install from PyPi with 12 | 13 | ```shell 14 | pip install trustyai 15 | ``` 16 | 17 | ### Local 18 | 19 | The minimum dependencies can be installed (from the root directory) with 20 | 21 | ```shell 22 | pip install . 23 | ``` 24 | 25 | If running the examples or developing, also install the development dependencies: 26 | 27 | ```shell 28 | pip install '.[dev]' 29 | ``` 30 | 31 | ### Docker 32 | 33 | Alternatively create a container image and run it using 34 | 35 | ```shell 36 | $ docker build -f Dockerfile -t python-trustyai:latest . 
37 | $ docker run --rm -it -p 8888:8888 python-trustyai:latest 38 | ``` 39 | 40 | The Jupyter server will be available at `localhost:8888`. 41 | 42 | ### Binder 43 | 44 | You can also run the example Jupyter notebooks 45 | using `mybinder.org`: [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/trustyai-python/trustyai-explainability-python-examples/main?labpath=examples) 46 | 47 | ## Documentation 48 | 49 | Check out the [ReadTheDocs page](https://trustyai-explainability-python.readthedocs.io/en/latest/) for API references 50 | and examples. 51 | 52 | ## Getting started 53 | 54 | ### Examples 55 | 56 | There are several working examples available in the [examples](https://github.com/trustyai-explainability/trustyai-explainability-python-examples/tree/main/examples) repository. 57 | 58 | ## Contributing 59 | 60 | Please see the [CONTRIBUTING.md](CONTRIBUTING.md) file for instructions on how to contribute to this project. 61 | -------------------------------------------------------------------------------- /tests/general/test_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Test model provider interface""" 3 | 4 | from common import * 5 | from trustyai.model import Model, Dataset, feature 6 | 7 | import pytest 8 | 9 | from trustyai.utils.data_conversions import numpy_to_prediction_object 10 | 11 | 12 | def test_basic_model(): 13 | """Test basic model""" 14 | 15 | model = Model(lambda x: x, output_names=['a', 'b', 'c', 'd', 'e']) 16 | features = numpy_to_prediction_object(np.arange(0, 100).reshape(20, 5), feature) 17 | result = model.predictAsync(features).get() 18 | assert len(result[0].outputs) == 5 19 | 20 | 21 | def test_cast_output(): 22 | np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'], disable_arrow=True) 23 | np2df = Model(lambda x: pd.DataFrame(x), disable_arrow=True) 24 | df2np = Model(lambda x: x.sum(1).values, 25 | dataframe_input=True, 26 | output_names=['sum'], 27 | disable_arrow=True) 28 | df2df = Model(lambda x: x, dataframe_input=True, disable_arrow=True) 29 | 30 | pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature) 31 | 32 | for m in [np2np, np2df, df2df, df2np]: 33 | output_val = m.predictAsync(pis).get() 34 | assert len(output_val) == 25 35 | 36 | 37 | def test_cast_output_arrow(): 38 | np2np = Model(lambda x: np.sum(x, 1), output_names=['sum']) 39 | np2df = Model(lambda x: pd.DataFrame(x)) 40 | df2np = Model(lambda x: x.sum(1).values, dataframe_input=True, output_names=['sum']) 41 | df2df = Model(lambda x: x, dataframe_input=True) 42 | pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature) 43 | 44 | for m in [np2np, np2df, df2df, df2np]: 45 | m._set_arrow(pis[0]) 46 | output_val = m.predictAsync(pis).get() 47 | assert len(output_val) == 25 48 | 49 | -------------------------------------------------------------------------------- /src/trustyai/metrics/language.py: -------------------------------------------------------------------------------- 1 | """ "Language metrics""" 2 | 3 | # pylint: disable = import-error 4 | from dataclasses import dataclass 5 | 6 | from typing import List, Optional, Union, Callable 7 | 8 | from org.kie.trustyai.metrics.language.levenshtein import ( 9 | WordErrorRate as _WordErrorRate, 10 | ErrorRateResult as _ErrorRateResult, 11 | ) 12 | from opennlp.tools.tokenize import Tokenizer 13 | from trustyai import 
_default_initializer # pylint: disable=unused-import 14 | 15 | from .distance import LevenshteinCounters 16 | 17 | 18 | @dataclass 19 | class ErrorRateResult: 20 | """Word Error Rate Result""" 21 | 22 | value: float 23 | alignment_counters: LevenshteinCounters 24 | 25 | @staticmethod 26 | def convert(result: _ErrorRateResult): 27 | """Converts a Java ErrorRateResult to a Python ErrorRateResult""" 28 | value = result.getValue() 29 | alignment_counters = result.getAlignmentCounters() 30 | return ErrorRateResult( 31 | value=value, 32 | alignment_counters=alignment_counters, 33 | ) 34 | 35 | 36 | def word_error_rate( 37 | reference: str, 38 | hypothesis: str, 39 | tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None, 40 | ) -> ErrorRateResult: 41 | """Calculate Word Error Rate between reference and hypothesis strings""" 42 | if not tokenizer: 43 | _wer = _WordErrorRate() 44 | elif isinstance(tokenizer, Tokenizer): 45 | _wer = _WordErrorRate(tokenizer) 46 | elif callable(tokenizer): 47 | tokenized_reference = tokenizer(reference) 48 | tokenized_hypothesis = tokenizer(hypothesis) 49 | _wer = _WordErrorRate() 50 | return ErrorRateResult.convert( 51 | _wer.calculate(tokenized_reference, tokenized_hypothesis) 52 | ) 53 | else: 54 | raise ValueError("Unsupported tokenizer") 55 | return ErrorRateResult.convert(_wer.calculate(reference, hypothesis)) 56 | -------------------------------------------------------------------------------- /tests/general/universal.py: -------------------------------------------------------------------------------- 1 | # General Setup 2 | from trustyai.model import Model, simple_prediction, counterfactual_prediction 3 | from trustyai.explainers import * 4 | 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | 10 | np.random.seed(0) 11 | 12 | @pytest.mark.skip("redundant") 13 | def test_all_explainers(): 14 | # universal setup ============================================================================== 15 | data = pd.DataFrame(np.random.rand(1, 5)) 16 | model_weights = np.random.rand(5) 17 | predict_function = lambda x: np.dot(x.values, model_weights) 18 | model = Model(predict_function, dataframe_input=True, arrow=True) 19 | prediction = simple_prediction(input_features=data, outputs=model(data)) 20 | 21 | # SHAP ========================================================================================= 22 | background = pd.DataFrame(np.zeros([100, 5])) 23 | shap_explainer = SHAPExplainer(background=background) 24 | explanation = shap_explainer.explain(prediction, model) 25 | 26 | for score in explanation.as_dataframe()['SHAP Value'].iloc[1:-1]: 27 | assert score > 0 28 | 29 | # LIME ========================================================================================= 30 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 31 | explanation = explainer.explain(prediction, model) 32 | for score in explanation.as_dataframe()["output-0_score"]: 33 | assert score > 0 34 | 35 | # Counterfactual =============================================================================== 36 | features = [feature(str(k), "number", v, domain=(-10., 10.)) for k, v in data.iloc[0].items()] 37 | goal = np.array([[0]]) 38 | cf_prediction = counterfactual_prediction(input_features=features, outputs=goal) 39 | explainer = CounterfactualExplainer(steps=10_000) 40 | explanation = explainer.explain(cf_prediction, model) 41 | result_output = model(explanation.get_proposed_features_as_pandas()) 42 | assert result_output < .01 
43 | assert result_output > -.01 44 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/pdp.py: -------------------------------------------------------------------------------- 1 | """Visualizations.pdp module""" 2 | 3 | # pylint: disable = import-error, wrong-import-order, too-few-public-methods, missing-final-newline 4 | # pylint: disable = protected-access 5 | import matplotlib.pyplot as plt 6 | 7 | from trustyai.explainers.pdp import PDPResults 8 | 9 | 10 | class PDPViz: 11 | """Visualizes PDP graphs""" 12 | 13 | def plot(self, explanations, output_name=None, block=True, call_show=True) -> None: 14 | """ 15 | Parameters 16 | ---------- 17 | explanations: pdp.PDPResults 18 | the partial dependence plots associated to the model outputs 19 | output_name: str 20 | name of the output to be plotted 21 | Defaults to None 22 | block: bool 23 | whether the plotting operation 24 | should be blocking or not 25 | call_show: bool 26 | (default= 'True') Whether plt.show() will be called by default at the end of 27 | the plotting function. If `False`, the plot will be returned to the user for 28 | further editing. 29 | """ 30 | pdp_graphs = explanations.pdp_graphs 31 | fig, axs = plt.subplots(len(pdp_graphs), constrained_layout=True) 32 | p_idx = 0 33 | for pdp_graph in pdp_graphs: 34 | if output_name is not None and output_name != str( 35 | pdp_graph.getOutput().getName() 36 | ): 37 | continue 38 | fig.suptitle(str(pdp_graph.getOutput().getName())) 39 | pdp_x = [] 40 | for i in range(len(pdp_graph.getX())): 41 | pdp_x.append(PDPResults._to_plottable(pdp_graph.getX()[i])) 42 | pdp_y = [] 43 | for i in range(len(pdp_graph.getY())): 44 | pdp_y.append(PDPResults._to_plottable(pdp_graph.getY()[i])) 45 | axs[p_idx].plot(pdp_x, pdp_y) 46 | axs[p_idx].set_title( 47 | str(pdp_graph.getFeature().getName()), loc="left", fontsize="small" 48 | ) 49 | axs[p_idx].grid() 50 | p_idx += 1 51 | fig.supylabel("Partial Dependence Plot") 52 | if call_show: 53 | plt.show(block=block) 54 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TrustyAI documentation master file, created by 2 | sphinx-quickstart on Tue Jul 12 11:47:01 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TrustyAI's documentation! 7 | ==================================== 8 | Red Hat's TrustyAI-Python library provides XAI explanations of decision services and 9 | predictive models for both enterprise and data science use-cases. 10 | 11 | This library is designed to provide a set of Python bindings to the main 12 | `TrustyAI Java toolkit `_, to allow 13 | for easier access to the toolkit in data science and prototyping use cases. This means the library 14 | benefits from both the speed of Java and the ease of use of Python; our whitepaper shows that 15 | the TrustyAI-Python LIME and SHAP explainers can run faster than the official implementations. 16 | 17 | Installation 18 | ============ 19 | ``pip install trustyai`` 20 | 21 | Tutorial and Examples 22 | ===================== 23 | To get started, check out the :ref:`tutorial`.
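For a quick first look, the following minimal sketch (adapted from this repository's test suite; the toy model, random data, and sample sizes are purely illustrative) wraps a prediction function in a ``Model``, builds a prediction, and explains it with LIME:

.. code-block:: python

    import numpy as np
    import pandas as pd

    from trustyai.model import Model, simple_prediction
    from trustyai.explainers import LimeExplainer

    # toy model: a weighted sum over five random inputs
    data = pd.DataFrame(np.random.rand(1, 5))
    weights = np.random.rand(5)
    model = Model(lambda x: np.dot(x.values, weights), dataframe_input=True)

    # wrap a single input/output pair as a prediction and explain it
    prediction = simple_prediction(input_features=data, outputs=model(data))
    explainer = LimeExplainer(samples=100)
    explanation = explainer.explain(prediction, model)
    print(explanation.as_dataframe())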
For more usage examples, see the example notebooks: 24 | 25 | * `LIME `_ 26 | * `SHAP `_ 27 | * `Counterfactuals `_ 28 | 29 | GitHub Repos 30 | ============ 31 | * `TrustyAI Python `_ 32 | * `TrustyAI Python Examples `_ 33 | * `TrustyAI Java `_ 34 | 35 | Paper 36 | ===== 37 | `TrustyAI Explainability Toolkit `_, 2022 38 | 39 | Contents 40 | ======== 41 | .. toctree:: 42 | tutorial 43 | api 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | -------------------------------------------------------------------------------- /tests/general/test_pdp.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """PDP test suite""" 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from sklearn.datasets import make_classification 8 | from trustyai.explainers import PDPExplainer 9 | from trustyai.model import Model 10 | from trustyai.utils import TestModels 11 | from trustyai.visualizations import plot 12 | 13 | 14 | def create_random_df(): 15 | X, _ = make_classification(n_samples=5000, n_features=5, n_classes=2, 16 | n_clusters_per_class=2, class_sep=2, flip_y=0, random_state=23) 17 | 18 | return pd.DataFrame({ 19 | 'x1': X[:, 0], 20 | 'x2': X[:, 1], 21 | 'x3': X[:, 2], 22 | 'x4': X[:, 3], 23 | 'x5': X[:, 4], 24 | }) 25 | 26 | 27 | def test_pdp_sumskip(): 28 | """Test PDP with sum skip model on randomly generated data""" 29 | 30 | df = create_random_df() 31 | model = TestModels.getSumSkipModel(0) 32 | pdp_explainer = PDPExplainer() 33 | pdp_results = pdp_explainer.explain(model, df) 34 | assert pdp_results is not None 35 | assert pdp_results.as_dataframe() is not None 36 | 37 | 38 | def test_pdp_sumthreshold(): 39 | """Test PDP with sum threshold model on randomly generated data""" 40 | 41 | df = create_random_df() 42 | model = TestModels.getLinearThresholdModel([0.1, 0.2, 0.3, 0.4, 0.5], 0) 43 | pdp_explainer = PDPExplainer() 44 | pdp_results = pdp_explainer.explain(model, df) 45 | assert pdp_results is not None 46 | assert pdp_results.as_dataframe() is not None 47 | 48 | 49 | def pdp_plots(block): 50 | """Test PDP plots""" 51 | np.random.seed(0) 52 | data = pd.DataFrame(np.random.rand(101, 5)) 53 | 54 | model_weights = np.random.rand(5) 55 | predict_function = lambda x: np.stack([np.dot(x.values, model_weights), 2 * np.dot(x.values, model_weights)], -1) 56 | model = Model(predict_function, dataframe_input=True) 57 | pdp_explainer = PDPExplainer() 58 | explanation = pdp_explainer.explain(model, data) 59 | 60 | plot(explanation, block=block) 61 | plot(explanation, block=block, output_name='output-0') 62 | 63 | 64 | @pytest.mark.block_plots 65 | def test_pdp_plots_blocking(): 66 | pdp_plots(True) 67 | 68 | 69 | def test_pdp_plots(): 70 | pdp_plots(False) 71 | -------------------------------------------------------------------------------- /src/trustyai/utils/_visualisation.py: -------------------------------------------------------------------------------- 1 | """Visualiser utilities for explainer results""" 2 | 3 | # pylint: disable = consider-using-f-string 4 | 5 | 6 | # HTML FORMAT FUNCTIONS ============================================================================ 7 | def bold_green_html(content): 8 | """Format the content string as a bold, green html object""" 9 | return '<b style="color:{}">{}</b>'.format( 10 | DEFAULT_STYLE["positive_primary_colour"], content 11 | ) 12 | 13 | 14 | def
bold_red_html(content): 15 | """Format the content string as a bold, red html object""" 16 | return '<b style="color:{}">{}</b>'.format( 17 | DEFAULT_STYLE["negative_primary_colour"], content 18 | ) 19 | 20 | 21 | def output_html(content): 22 | """Format the content string as a bold object in TrustyAI purple, used for 23 | Tyrus output displays""" 24 | return '{}'.format(content) 25 | 26 | 27 | def feature_html(content): 28 | """Format the content string as a bold object in black, used for 29 | Tyrus feature displays""" 30 | return '{}'.format(content) 31 | 32 | 33 | DEFAULT_STYLE = { 34 | "positive_primary_colour": "#13ba3c", 35 | "positive_primary_colour_faded": "#88dc9d", 36 | "negative_primary_colour": "#ee0000", 37 | "negative_primary_colour_faded": "#f67f7f", 38 | "neutral_primary_colour": "#ffffff", 39 | } 40 | 41 | DEFAULT_RC_PARAMS = { 42 | "patch.linewidth": 0.5, 43 | "patch.facecolor": "348ABD", 44 | "patch.edgecolor": "EEEEEE", 45 | "patch.antialiased": True, 46 | "font.size": 10.0, 47 | "axes.facecolor": "DDDDDD", 48 | "axes.edgecolor": "white", 49 | "axes.linewidth": 1, 50 | "axes.grid": True, 51 | "axes.titlesize": "x-large", 52 | "axes.labelsize": "large", 53 | "axes.labelcolor": "black", 54 | "axes.axisbelow": True, 55 | "text.color": "black", 56 | "xtick.color": "black", 57 | "xtick.direction": "out", 58 | "ytick.color": "black", 59 | "ytick.direction": "out", 60 | "legend.facecolor": "ffffff", 61 | "grid.color": "white", 62 | "grid.linestyle": "-", # solid line 63 | "figure.figsize": (16, 9), 64 | "figure.dpi": 100, 65 | "figure.facecolor": "ffffff", 66 | "figure.edgecolor": "777777", 67 | "savefig.bbox": "tight", 68 | } 69 | -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | [changelog] 2 | # changelog header 3 | header = """ 4 | # Changelog\n 5 | All notable changes to this project will be documented in this file.\n 6 | """ 7 | # template for the changelog body 8 | body = """ 9 | {% if version %}\ 10 | ## [{{ version }}] - {{ timestamp | date(format="%Y-%m-%d") }} 11 | {% else %}\ 12 | ## [unreleased] 13 | {% endif %}\ 14 | {% for group, commits in commits | group_by(attribute="group") %} 15 | ### {{ group | upper_first }} 16 | {% for commit in commits %} 17 | - {{ commit.message | upper_first }}\ 18 | {% endfor %} 19 | {% endfor %}\n 20 | """ 21 | # remove the leading and trailing whitespace from the template 22 | trim = true 23 | # changelog footer 24 | footer = "" 25 | 26 | [git] 27 | # parse the commits based on https://www.conventionalcommits.org 28 | conventional_commits = true 29 | # filter out the commits that are not conventional 30 | filter_unconventional = false 31 | # process each line of a commit as an individual commit 32 | split_commits = false 33 | # regex for preprocessing the commit messages 34 | commit_preprocessors = [ 35 | # { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"}, # replace issue numbers 36 | ] 37 | # regex for parsing and grouping commits 38 | commit_parsers = [ 39 | { message = "^feat", group = "Features" }, 40 | { message = "^fix", group = "Bug Fixes" }, 41 | { message = "^doc", group = "Documentation" }, 42 | { message = "^perf", group = "Performance" }, 43 | { message = "^refactor", group = "Refactor" }, 44 | { message = "^style", group = "Styling" }, 45 | { message = "^test", group = "Testing" }, 46 | { message = "^chore\\(release\\): prepare for", skip = true }, 47 | {
message = "^chore", group = "Miscellaneous Tasks" }, 48 | { body = ".*security", group = "Security" }, 49 | ] 50 | # protect breaking changes from being skipped due to matching a skipping commit_parser 51 | protect_breaking_commits = false 52 | # filter out the commits that are not matched by commit parsers 53 | filter_commits = false 54 | # glob pattern for matching git tags 55 | tag_pattern = "[0-9]*" 56 | # regex for skipping tags 57 | skip_tags = "v0.1.0-beta.1" 58 | # regex for ignoring tags 59 | ignore_tags = "" 60 | # sort the tags topologically 61 | topo_order = false 62 | # sort the commits inside sections by oldest/newest order 63 | sort_commits = "oldest" 64 | # limit the number of commits included in the changelog. 65 | # limit_commits = 42 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Sphinx stuff 69 | docs/_build 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .DS_Store 131 | .idea 132 | .Rproj.user 133 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "trustyai" 3 | version = "0.6.3" 4 | description = "Python bindings to the TrustyAI explainability library." 5 | authors = [{ name = "Rui Vieira", email = "rui@redhat.com" }] 6 | license = { text = "Apache License Version 2.0" } 7 | readme = "README.md" 8 | requires-python = ">=3.10" 9 | 10 | keywords = ["trustyai", "xai", "explainability", "ml"] 11 | 12 | classifiers = [ 13 | "License :: OSI Approved :: Apache Software License", 14 | "Development Status :: 4 - Beta", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Science/Research", 17 | "Programming Language :: Java", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3.12", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Software Development :: Libraries :: Java Libraries", 23 | ] 24 | 25 | dependencies = [ 26 | "Jpype1==1.5.0", 27 | "pyarrow>=20.0.0", 28 | "matplotlib~=3.10.3", 29 | "pandas>=2.1.0", 30 | "numpy>=1.26.4", 31 | "jupyter-bokeh~=4.0.5", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "JPype1==1.5.0", 37 | "black~=25.11", 38 | "click==8.0.4", 39 | "joblib~=1.2.0", 40 | "jupyterlab~=4.4.4", 41 | "numpydoc==1.5.0", 42 | "pylint==3.2.0", 43 | "pytest~=7.2.1", 44 | "pytest-benchmark==4.0.0", 45 | "pytest-forked~=1.6.0", 46 | "scikit-learn~=1.7.0", 47 | "setuptools", 48 | "twine==3.4.2", 49 | "wheel~=0.38.4", 50 | "xgboost~=3.0.2", 51 | ] 52 | extras = ["aix360[default,tsice,tslime,tssaliency]==0.3.0"] 53 | 54 | detoxify = [ 55 | "transformers~=4.36.2", 56 | "datasets", 57 | "scipy~=1.12.0", 58 | "torch~=2.2.1", 59 | "evaluate", 60 | "trl", 61 | ] 62 | 63 | api = ["kubernetes"] 64 | 65 | [project.urls] 66 | homepage = "https://github.com/trustyai-explainability/trustyai-explainability-python" 67 | documentation = "https://trustyai-explainability-python.readthedocs.io/en/latest/" 68 | repository = "https://github.com/trustyai-explainability/trustyai-explainability-python" 69 | 70 | [build-system] 71 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 72 | build-backend = "setuptools.build_meta" 73 | 74 | [tool.setuptools] 75 | package-dir = { "" = "src" } 76 | 77 | [tool.pytest.ini_options] 78 | log_cli = true 79 | addopts = '-m="not block_plots"' 80 | markers = [ 81 | "block_plots: Test plots will block execution of subsequent tests until closed", 82 | ] 83 | 84 | [tool.setuptools.packages.find] 85 | where = ["src"] 86 | 87 | [tool.setuptools_scm] 88 | -------------------------------------------------------------------------------- /tests/general/test_prediction.py: 
-------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Test model provider interface""" 3 | 4 | from common import * 5 | from trustyai.model import simple_prediction, counterfactual_prediction,feature, output 6 | from trustyai.utils.data_conversions import numpy_to_prediction_object 7 | import pytest 8 | 9 | # test that predictions are created correctly from numpy arrays 10 | def test_predictions_numpy(): 11 | input_values = np.arange(5) 12 | output_values = np.arange(2) 13 | 14 | pred = simple_prediction(input_values, output_values) 15 | assert len(pred.getInput().getFeatures()) == 5 16 | 17 | pred = counterfactual_prediction(input_values, output_values) 18 | assert len(pred.getInput().getFeatures()) == 5 19 | 20 | 21 | # test that predictions are created correctly from dataframe 22 | def test_predictions_pandas(): 23 | input_values = pd.DataFrame(np.arange(5).reshape(1, -1), columns=list("abcde")) 24 | output_values = pd.DataFrame(np.arange(2).reshape(1, -1), columns=list("xy")) 25 | 26 | pred = simple_prediction(input_values, output_values) 27 | assert len(pred.getInput().getFeatures()) == 5 28 | assert pred.getInput().getFeatures()[0].getName() == "a" 29 | 30 | pred = counterfactual_prediction(input_values, output_values) 31 | assert pred.getInput().getFeatures()[0].getName() == "a" 32 | assert len(pred.getInput().getFeatures()) == 5 33 | 34 | 35 | # test that predictions are created correctly from prediction input + outputs 36 | def test_prediction_pi(): 37 | input_values = numpy_to_prediction_object(np.arange(5).reshape(1, -1), feature)[0] 38 | output_values = numpy_to_prediction_object(np.arange(2).reshape(1, -1), output)[0] 39 | 40 | pred = simple_prediction(input_values, output_values) 41 | assert len(pred.getInput().getFeatures()) == 5 42 | 43 | pred = counterfactual_prediction(input_values, output_values) 44 | assert len(pred.getInput().getFeatures()) == 5 45 | 46 | 47 | # test that predictions are created correctly from feature+output lists 48 | def test_prediction_featurelist(): 49 | input_values = numpy_to_prediction_object( 50 | np.arange(5).reshape(1, -1), feature 51 | )[0].getFeatures() 52 | output_values = numpy_to_prediction_object( 53 | np.arange(2).reshape(1, -1), output 54 | )[0].getOutputs() 55 | 56 | pred = simple_prediction(input_values, output_values) 57 | assert len(pred.getInput().getFeatures()) == 5 58 | 59 | pred = counterfactual_prediction(input_values, output_values) 60 | assert len(pred.getInput().getFeatures()) == 5 61 | -------------------------------------------------------------------------------- /tests/extras/test_metrics_service.py: -------------------------------------------------------------------------------- 1 | """Test suite for TrustyAI metrics service data conversions""" 2 | import json 3 | import os 4 | import random 5 | import unittest 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from trustyai.utils.extras.metrics_service import ( 10 | json_to_df, 11 | df_to_json 12 | ) 13 | 14 | def generate_json_data(batch_list, data_path): 15 | for batch in batch_list: 16 | data = { 17 | "inputs": [ 18 | {"name": "test_data_input", 19 | "shape": [1, 100], 20 | "datatype": "FP64", 21 | "data": [random.uniform(a=100, b=200) for i in range(100)] 22 | } 23 | ] 24 | } 25 | for batch in batch_list: 26 | with open(data_path + f"{batch}.json", 'w', encoding="utf-8") as f: 27 | json.dump(data, f, ensure_ascii=False) 28 | 29 | 30 | 
def generate_test_df(): 31 | data = { 32 | '0': np.random.uniform(low=100, high=200, size=100), 33 | '1': np.random.uniform(low=5000, high=10000, size=100), 34 | '2': np.random.uniform(low=100, high=200, size=100), 35 | '3': np.random.uniform(low=5000, high=10000, size=100), 36 | '4': np.random.uniform(low=5000, high=10000, size=100) 37 | } 38 | return pd.DataFrame(data=data) 39 | 40 | 41 | class TestMetricsService(unittest.TestCase): 42 | def setUp(self): 43 | self.df = generate_test_df() 44 | self.data_path = "data/" 45 | if not os.path.exists(self.data_path): 46 | os.mkdir("data/") 47 | self.batch_list = list(range(0, 5)) 48 | 49 | def test_json_to_df(self): 50 | """Test json data to pandas dataframe conversion""" 51 | generate_json_data(batch_list=self.batch_list, data_path=self.data_path) 52 | df = json_to_df(self.data_path, self.batch_list) 53 | n_rows, n_cols = 0, 0 54 | for batch in self.batch_list: 55 | file = self.data_path + f"{batch}.json" 56 | with open(file, encoding="utf8") as f: 57 | data = json.load(f)["inputs"][0] 58 | n_rows += data["shape"][0] 59 | n_cols = data["shape"][1] 60 | self.assertEqual(df.shape, (n_rows, n_cols)) 61 | 62 | 63 | def test_df_to_json(self): 64 | """Test pandas dataframe to json data conversion""" 65 | df = generate_test_df() 66 | name = 'test_data_input' 67 | json_file = 'data/test.json' 68 | df_to_json(df, name, json_file) 69 | with open(json_file, encoding="utf8") as f: 70 | data = json.load(f)["inputs"][0] 71 | n_rows = data["shape"][0] 72 | n_cols = data["shape"][1] 73 | self.assertEqual(df.shape, (n_rows, n_cols)) 74 | 75 | if __name__ == "__main__": 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /info/detoxify.md: -------------------------------------------------------------------------------- 1 | # TrustyAI-Detoxify 2 | Algorithms and tools for detecting and fixing hate speech, abuse and profanity in content generated by Large Language Models (_LLMs_). The source code is located in the [trustyai.language.detoxify](https://github.com/trustyai-explainability/trustyai-explainability-python/tree/main/src/trustyai/language/detoxify) module. 3 | 4 | ## T-MaRCo 5 | 6 | T-MaRCo is an extension of the work [Detoxifying Text with MaRCo: Controllable Revision with Experts and Anti-Experts 7 | ](https://arxiv.org/abs/2212.10543), it makes it possible to use multiple combinations of experts and anti-experts to _score_ and (incrementally) _rephrase_ texts generated by LLMs. 8 | 9 | In addition to that, it can integrate _rephrasing_ with the base model _self-reflection_ capabilities (see papers [Towards Mitigating Hallucination in Large Language Models via Self-Reflection 10 | ](https://arxiv.org/abs/2310.06271) and [N-Critics: Self-Refinement of Large Language Models with Ensemble of Critics 11 | ](https://arxiv.org/abs/2310.18679)). 12 | 13 | T-MaRCo hence provides the following features: 14 | * content *scoring*: providing a _disagreement score_ for each input token; high disagreement is often attached to toxic content. 15 | * content *masking*: providing a masked version of the input content, where all tokens that are consired toxic are replaced with the `` token. 16 | * content *redirection*: providing a non-toxic "regenerated" version of the original content. 
17 | 18 | How to use T-MaRCo: 19 | ```python 20 | from trustyai.detoxify import TMaRCo 21 | 22 | # instantiate T-MaRCo 23 | tmarco = TMaRCo(expert_weights=[-1, 2]) 24 | 25 | # load pretrained anti-expert and expert models 26 | tmarco.load_models(["trustyai/gminus", "trustyai/gplus"]) 27 | 28 | # pick up some text generated by a LLM 29 | text = "Stand by me, just as long as you fucking stand by me" 30 | 31 | # generate T-MaRCo disagreement scores 32 | scores = tmarco.score([text]) # '[0.78664607 0.06622718 0.02403926 5.331921 0.49842355 0.46609956 0.22441313 0.43487906 0.51990145 1.9062967 0.64200985 0.30269763 1.7964466 ]' 33 | 34 | # mask tokens having high disagreement scores 35 | masked_text = tmarco.mask([text], scores=scores) # 'Stand by me just as long as you stand by' 36 | 37 | # rephrase masked tokens 38 | rephrased = tmarco.rephrase([text], [masked_text]) # 'Stand by me and just as long as you want stand by me'' 39 | 40 | # combine rephrasing and a base model self-reflection capabilities 41 | reflected = tmarco.reflect([text]) # '["'Stand by me in the way I want stand by you and in the ways I need you to standby me'."]' 42 | 43 | ``` 44 | 45 | T-MaRCo Pretrained models are available under [TrustyAI HuggingFace space](https://huggingface.co/trustyai) at https://huggingface.co/trustyai/gminus and https://huggingface.co/trustyai/gplus. 46 | -------------------------------------------------------------------------------- /src/trustyai/metrics/distance.py: -------------------------------------------------------------------------------- 1 | """ "Distance metrics""" 2 | 3 | # pylint: disable = import-error 4 | from dataclasses import dataclass 5 | from typing import List, Optional, Union, Callable 6 | 7 | from org.kie.trustyai.metrics.language.distance import ( 8 | Levenshtein as _Levenshtein, 9 | LevenshteinResult as _LevenshteinResult, 10 | LevenshteinCounters as _LevenshteinCounters, 11 | ) 12 | from opennlp.tools.tokenize import Tokenizer 13 | import numpy as np 14 | from trustyai import _default_initializer # pylint: disable=unused-import 15 | 16 | 17 | @dataclass 18 | class LevenshteinCounters: 19 | """LevenshteinCounters Counters""" 20 | 21 | substitutions: int 22 | insertions: int 23 | deletions: int 24 | correct: int 25 | 26 | @staticmethod 27 | def convert(result: _LevenshteinCounters): 28 | """Converts a Java LevenshteinCounters to a Python LevenshteinCounters""" 29 | return LevenshteinCounters( 30 | substitutions=result.getSubstitutions(), 31 | insertions=result.getInsertions(), 32 | deletions=result.getDeletions(), 33 | correct=result.getCorrect(), 34 | ) 35 | 36 | 37 | @dataclass 38 | class LevenshteinResult: 39 | """Levenshtein Result""" 40 | 41 | distance: float 42 | counters: LevenshteinCounters 43 | matrix: np.ndarray 44 | reference: List[str] 45 | hypothesis: List[str] 46 | 47 | @staticmethod 48 | def convert(result: _LevenshteinResult): 49 | """Converts a Java LevenshteinResult to a Python LevenshteinResult""" 50 | distance = result.getDistance() 51 | counters = LevenshteinCounters.convert(result.getCounters()) 52 | data = result.getDistanceMatrix().getData() 53 | numpy_array = np.array(data)[1:, 1:] 54 | reference = result.getReferenceTokens() 55 | hypothesis = result.getHypothesisTokens() 56 | 57 | return LevenshteinResult( 58 | distance=distance, 59 | counters=counters, 60 | matrix=numpy_array, 61 | reference=reference, 62 | hypothesis=hypothesis, 63 | ) 64 | 65 | 66 | def levenshtein( 67 | reference: str, 68 | hypothesis: str, 69 | tokenizer: 
Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None, 70 | ) -> LevenshteinResult: 71 | """Calculate Levenshtein distance between two strings""" 72 | if not tokenizer: 73 | return LevenshteinResult.convert( 74 | _Levenshtein.calculateToken(reference, hypothesis) 75 | ) 76 | if isinstance(tokenizer, Tokenizer): 77 | return LevenshteinResult.convert( 78 | _Levenshtein.calculateToken(reference, hypothesis, tokenizer) 79 | ) 80 | if callable(tokenizer): 81 | tokenized_reference = tokenizer(reference) 82 | tokenized_hypothesis = tokenizer(hypothesis) 83 | return LevenshteinResult.convert( 84 | _Levenshtein.calculateToken(tokenized_reference, tokenized_hypothesis) 85 | ) 86 | 87 | raise ValueError("Unsupported tokenizer") 88 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath('../src/trustyai/')) 17 | 18 | import sphinx_rtd_theme 19 | 20 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'TrustyAI' 25 | copyright = '2023, Rob Geada, Tommaso Teofili, Rui Vieira, Rebecca Whitworth, Daniele Zonca' 26 | author = 'Rob Geada, Tommaso Teofili, Rui Vieira, Rebecca Whitworth, Daniele Zonca' 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.autosummary', 36 | 'sphinx.ext.autosectionlabel', 37 | 'sphinx_rtd_theme', 38 | 'sphinx.ext.mathjax', 39 | 'numpydoc' 40 | ] 41 | 42 | autodoc_default_options = { 43 | 'members': True, 44 | 'inherited-members': True 45 | } 46 | autosummary_generate = True 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # List of patterns, relative to source directory, that match files and 52 | # directories to ignore when looking for source files. 53 | # This pattern also affects html_static_path and html_extra_path. 54 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 55 | 56 | # -- Options for HTML output ------------------------------------------------- 57 | 58 | # The theme to use for HTML and HTML Help pages. See the documentation for 59 | # a list of builtin themes. 
60 | # 61 | html_theme = 'sphinx_rtd_theme' 62 | html_static_path = ['_static'] 63 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 64 | html_theme_options = { 65 | 'logo_only': True, 66 | 'style_nav_header_background': '#343131', 67 | } 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 71 | html_static_path = ['_static'] 72 | html_css_files = ['css/custom.css'] 73 | 74 | html_favicon = '_static/artwork/favicon.png' 75 | html_logo = '_static/artwork/logo.png' 76 | # numpydoc settings 77 | numpydoc_show_class_members = False 78 | 79 | 80 | def setup(app): 81 | import trustyai 82 | from trustyai.model import Model 83 | from trustyai.utils.tyrus import Tyrus 84 | Model.__name__ = "Model" 85 | Tyrus.__name__ = "Tyrus" -------------------------------------------------------------------------------- /src/trustyai/visualizations/__init__.py: -------------------------------------------------------------------------------- 1 | """Generates visualization according to explanation type""" 2 | 3 | # pylint: disable=import-error, wrong-import-order, protected-access, missing-final-newline 4 | from typing import Union, Optional 5 | 6 | from bokeh.io import show 7 | 8 | from trustyai.explainers import SHAPResults, LimeResults, pdp 9 | from trustyai.metrics.distance import LevenshteinResult 10 | from trustyai.visualizations.visualization_results import VisualizationResults 11 | from trustyai.visualizations.shap import SHAPViz 12 | from trustyai.visualizations.lime import LimeViz 13 | from trustyai.visualizations.pdp import PDPViz 14 | from trustyai.visualizations.distance import DistanceViz 15 | 16 | 17 | def get_viz(explanations) -> VisualizationResults: 18 | """ 19 | Get visualization according to the explanation method 20 | """ 21 | if isinstance(explanations, SHAPResults): 22 | return SHAPViz() 23 | if isinstance(explanations, LimeResults): 24 | return LimeViz() 25 | if isinstance(explanations, pdp.PDPResults): 26 | return PDPViz() 27 | if isinstance(explanations, LevenshteinResult): 28 | return DistanceViz() 29 | raise ValueError("Explanation method unknown") 30 | 31 | 32 | def plot( 33 | explanations: Union[SHAPResults, LimeResults, pdp.PDPResults, LevenshteinResult], 34 | output_name: Optional[str] = None, 35 | render_bokeh: bool = False, 36 | block: bool = True, 37 | call_show: bool = True, 38 | ) -> None: 39 | """ 40 | Plot the found feature saliencies. 41 | 42 | Parameters 43 | ---------- 44 | explanations: Union[LimeResults, SHAPResults, PDPResults, LevenshteinResult] 45 | the explanation result to plot 46 | output_name : str 47 | (default= `None`) The name of the output to be explainer. If `None`, all outputs will 48 | be displayed 49 | render_bokeh : bool 50 | (default= `False`) If true, render plot in bokeh, otherwise use matplotlib. 51 | block: bool 52 | (default= `True`) Whether displaying the plot blocks subsequent code execution 53 | call_show: bool 54 | (default= 'True') Whether plt.show() will be called by default at the end of the 55 | plotting function. If `False`, the plot will be returned to the user for further 56 | editing. 
57 | """ 58 | viz = get_viz(explanations) 59 | 60 | if isinstance(explanations, pdp.PDPResults): 61 | viz.plot(explanations, output_name) 62 | elif isinstance(explanations, LevenshteinResult): 63 | viz.plot(explanations) 64 | elif output_name is None: 65 | for output_name_iterator in explanations.saliency_map().keys(): 66 | if render_bokeh: 67 | show(viz._get_bokeh_plot(explanations, output_name_iterator)) 68 | else: 69 | viz._matplotlib_plot( 70 | explanations, output_name_iterator, block, call_show 71 | ) 72 | else: 73 | if render_bokeh: 74 | show(viz._get_bokeh_plot(explanations, output_name)) 75 | else: 76 | viz._matplotlib_plot(explanations, output_name, block, call_show) 77 | -------------------------------------------------------------------------------- /tests/general/test_shap_background_generation.py: -------------------------------------------------------------------------------- 1 | """SHAP background generation test suite""" 2 | 3 | import pytest 4 | import numpy as np 5 | import math 6 | 7 | from trustyai.explainers.shap import BackgroundGenerator 8 | from trustyai.model import Model, feature_domain 9 | from trustyai.utils.data_conversions import prediction_object_to_numpy 10 | 11 | 12 | def test_random_generation(): 13 | """Test that random sampling recovers samples from distribution""" 14 | seed = 0 15 | np.random.seed(seed) 16 | data = np.random.rand(100, 5) 17 | background_ta = BackgroundGenerator(data).sample(5) 18 | background = prediction_object_to_numpy(background_ta) 19 | 20 | assert len(background) == 5 21 | for row in background: 22 | assert row in data 23 | 24 | 25 | def test_kmeans_generation(): 26 | """Test that k-means recovers centroids of well-clustered data""" 27 | 28 | seed = 0 29 | clusters = 5 30 | np.random.seed(seed) 31 | 32 | data = [] 33 | ground_truth = [] 34 | for cluster in range(clusters): 35 | data.append(np.random.rand(100 // clusters, 5) + cluster * 10) 36 | ground_truth.append(np.array([cluster * 10] * 5)) 37 | data = np.vstack(data) 38 | ground_truth = np.vstack(ground_truth) 39 | background_ta = BackgroundGenerator(data).kmeans(clusters) 40 | background = prediction_object_to_numpy(background_ta) 41 | 42 | assert len(background) == 5 43 | for row in background: 44 | ground_truth_idx = math.floor(row[0] / 10) 45 | assert np.linalg.norm(row - ground_truth[ground_truth_idx]) < 2.5 46 | 47 | 48 | def test_counterfactual_generation_single_goal(): 49 | """Test that cf background meets requirements""" 50 | seed = 0 51 | np.random.seed(seed) 52 | data = np.random.rand(100, 5) 53 | model = Model(lambda x: x.sum(1)) 54 | goal = np.array([1.0]) 55 | 56 | # check that undomained backgrounds are caught 57 | attribute_error_thrown = False 58 | try: 59 | BackgroundGenerator(data).counterfactual(goal, model, 10,) 60 | except AttributeError: 61 | attribute_error_thrown = True 62 | assert attribute_error_thrown 63 | 64 | domains = [feature_domain((-10, 10)) for _ in range(5)] 65 | background_ta = BackgroundGenerator(data, domains, seed)\ 66 | .counterfactual(goal, model, 5, step_count=5000, timeout_seconds=2) 67 | background = prediction_object_to_numpy(background_ta) 68 | 69 | for row in background: 70 | assert np.linalg.norm(goal - model(row.reshape(1, -1))) < .01 71 | 72 | 73 | def test_counterfactual_generation_multi_goal(): 74 | """Test that cf background meets requirements for multiple goals""" 75 | 76 | seed = 0 77 | np.random.seed(seed) 78 | data = np.random.rand(100, 5) 79 | model = Model(lambda x: x.sum(1)) 80 | goals = np.arange(1, 10).reshape(-1, 
1) 81 | domains = [feature_domain((-10, 10)) for _ in range(5)] 82 | background_ta = BackgroundGenerator(data, domains, seed)\ 83 | .counterfactual(goals, model, 1, step_count=5000, timeout_seconds=2, chain=True) 84 | background = prediction_object_to_numpy(background_ta) 85 | 86 | for i, goal in enumerate(goals): 87 | assert np.linalg.norm(goal - model(background[i:i+1])) < goal[0]/100 88 | -------------------------------------------------------------------------------- /.github/workflows/security.yaml: -------------------------------------------------------------------------------- 1 | name: Security Scan 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | jobs: 11 | trivy-scan: 12 | name: Trivy 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | security-events: write 17 | actions: read 18 | 19 | steps: 20 | - name: Checkout code 21 | uses: actions/checkout@v4 22 | 23 | - name: Set up Python 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: '3.11' 27 | 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | if [ -f pyproject.toml ]; then 32 | pip install -e ".[dev]" 33 | fi 34 | 35 | - name: Run Trivy vulnerability scan 36 | uses: aquasecurity/trivy-action@0.28.0 37 | with: 38 | scan-type: 'fs' 39 | scan-ref: '.' 40 | format: 'sarif' 41 | output: 'trivy-results.sarif' 42 | severity: 'CRITICAL,HIGH,MEDIUM,LOW' 43 | exit-code: '0' 44 | - name: Check for critical and high vulnerabilities 45 | uses: aquasecurity/trivy-action@0.28.0 46 | with: 47 | scan-type: 'fs' 48 | scan-ref: '.' 49 | format: 'table' 50 | severity: 'CRITICAL,HIGH' 51 | exit-code: '1' 52 | continue-on-error: true 53 | 54 | - name: Upload Trivy scan results to Security tab 55 | uses: github/codeql-action/upload-sarif@v3 56 | if: always() 57 | with: 58 | sarif_file: 'trivy-results.sarif' 59 | category: 'trivy-security-scan' 60 | 61 | bandit-scan: 62 | name: Bandit 63 | runs-on: ubuntu-latest 64 | permissions: 65 | security-events: write 66 | actions: read 67 | contents: read 68 | checks: write 69 | 70 | steps: 71 | - uses: actions/checkout@v4 72 | 73 | - name: Set up Python 74 | uses: actions/setup-python@v5 75 | with: 76 | python-version: "3.11" 77 | cache: "pip" 78 | 79 | - name: Create virtual environment 80 | run: | 81 | python -m pip install --upgrade pip 82 | python -m venv .venv 83 | 84 | - name: Install dependencies 85 | run: | 86 | source .venv/bin/activate 87 | pip install -e ".[dev]" 88 | 89 | - name: Install Bandit 90 | run: | 91 | source .venv/bin/activate 92 | pip install bandit[sarif] 93 | 94 | - name: Run Bandit Security Scan 95 | uses: PyCQA/bandit-action@v1 96 | with: 97 | targets: "." 98 | exclude: "tests" 99 | 100 | - name: Upload SARIF results to Security tab 101 | if: github.ref == 'refs/heads/main' 102 | uses: github/codeql-action/upload-sarif@v3 103 | with: 104 | sarif_file: results.sarif 105 | category: bandit-security-scan 106 | continue-on-error: true 107 | 108 | - name: Upload SARIF as artifact 109 | uses: actions/upload-artifact@v4 110 | with: 111 | name: bandit-sarif-results 112 | path: results.sarif 113 | retention-days: 30 114 | continue-on-error: true -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tssaliency.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSSaliencyExplainer from aix360. 
3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | 6 | from typing import Callable, List 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from aix360.algorithms.tssaliency import TSSaliencyExplainer as TSSaliencyExplainerAIX 11 | from pandas.io.formats.style import Styler 12 | import matplotlib.pyplot as plt 13 | 14 | from trustyai.explainers.explanation_results import ExplanationResults 15 | 16 | 17 | class TSSaliencyResults(ExplanationResults): 18 | """Wraps TSSaliency results. This object is returned by the :class:`~TSSaliencyExplainer`, 19 | and provides a variety of methods to visualize and interact with the explanation. 20 | """ 21 | 22 | def __init__(self, explanation): 23 | self.explanation = explanation 24 | 25 | def as_dataframe(self) -> pd.DataFrame: 26 | saliencies = self.explanation["saliency"].reshape(-1) 27 | return pd.DataFrame(saliencies, columns=self.explanation["feature_names"]) 28 | 29 | def as_html(self) -> Styler: 30 | """Returns the explanation as an HTML table.""" 31 | dataframe = self.as_dataframe() 32 | return dataframe.style 33 | 34 | def plot(self, index: int, cpos, window: int = None): 35 | """Plot tssaliency explanation for the test point 36 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tssaliency""" 37 | if window: 38 | scores = ( 39 | np.convolve( 40 | self.explanation["saliency"].flatten(), np.ones(window), mode="same" 41 | ) 42 | / window 43 | ) 44 | else: 45 | scores = self.explanation["saliency"] 46 | 47 | vmax = np.max(np.abs(self.explanation["saliency"])) 48 | 49 | plt.figure(layout="constrained") 50 | plt.imshow( 51 | scores[np.newaxis, :], aspect="auto", cmap="seismic", vmin=-vmax, vmax=vmax 52 | ) 53 | plt.colorbar() 54 | plt.plot(self.explanation["input_data"]) 55 | instance = self.explanation["instance_prediction"] 56 | plt.title( 57 | "Time Series Saliency Explanation Plot for test point" 58 | f" i={index} with P(Y={cpos})= {instance}" 59 | ) 60 | plt.show() 61 | 62 | 63 | class TSSaliencyExplainer(TSSaliencyExplainerAIX): 64 | """ 65 | Wrapper for TSSaliencyExplainer from aix360. 66 | """ 67 | 68 | def __init__( # pylint: disable=too-many-arguments 69 | self, 70 | model: Callable, 71 | input_length: int, 72 | feature_names: List[str], 73 | base_value: List[float] = None, 74 | n_samples: int = 50, 75 | gradient_samples: int = 25, 76 | gradient_function: Callable = None, 77 | random_seed: int = 22, 78 | ): 79 | super().__init__( 80 | model=model, 81 | input_length=input_length, 82 | feature_names=feature_names, 83 | base_value=base_value, 84 | n_samples=n_samples, 85 | gradient_samples=gradient_samples, 86 | gradient_function=gradient_function, 87 | random_seed=random_seed, 88 | ) 89 | 90 | def explain(self, inputs, outputs=None, **kwargs) -> TSSaliencyResults: 91 | """ 92 | Explain the model's prediction on X. 
93 | """ 94 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 95 | return TSSaliencyResults(_explanation) 96 | -------------------------------------------------------------------------------- /src/trustyai/model/domain.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error 2 | """Conversion method between Python and TrustyAI Java types""" 3 | from typing import Optional, Tuple, List, Union 4 | 5 | from jpype import _jclass 6 | 7 | from org.kie.trustyai.explainability.model.domain import ( 8 | FeatureDomain, 9 | NumericalFeatureDomain, 10 | CategoricalFeatureDomain, 11 | CategoricalNumericalFeatureDomain, 12 | ObjectFeatureDomain, 13 | EmptyFeatureDomain, 14 | ) 15 | 16 | 17 | def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDomain]: 18 | r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a 19 | particular feature, which is useful when constraining a counterfactual explanation to ensure it 20 | only recovers valid inputs. For example, if we had a feature that described a person's age, we 21 | might want to constrain it to the range [0, 125] to ensure the counterfactual explanation 22 | doesn't return unlikely ages such as -5 or 715. 23 | 24 | Parameters 25 | ---------- 26 | values : Optional[Union[Tuple, List]] 27 | The valid values of the feature. If ``values`` takes the form of: 28 | 29 | * **A tuple of floats or integers**: The feature domain will be a continuous range from 30 | ``values[0]`` to ``values[1]``. 31 | * **A list of floats or integers**: The feature domain will be a *numeric* categorical, 32 | where `values` contains all possible valid feature values. 33 | * **A list of strings**: The feature domain will be a *string* categorical, where ``values`` 34 | contains all possible valid feature values. 35 | * **A list of objects**: The feature domain will be an *object* categorical, where 36 | ``values`` contains all possible valid feature values. These may present an issue if the 37 | objects are not natively Java serializable. 38 | 39 | Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held 40 | fixed during the counterfactual explanation. 41 | 42 | Returns 43 | ------- 44 | :class:`FeatureDomain` 45 | A Java :class:`FeatureDomain` object, to be used in the :func:`~trustyai.model.feature` 46 | function. 
47 | 48 | """ 49 | if not values: 50 | domain = EmptyFeatureDomain.create() 51 | else: 52 | if isinstance(values, tuple): 53 | assert isinstance(values[0], (float, int)) and isinstance( 54 | values[1], (float, int) 55 | ) 56 | assert len(values) == 2, ( 57 | "Tuples passed as domain values must only contain" 58 | " two values that define the (minimum, maximum) of the domain" 59 | ) 60 | domain = NumericalFeatureDomain.create(values[0], values[1]) 61 | 62 | elif isinstance(values, list): 63 | java_array = _jclass.JClass("java.util.Arrays").asList(values) 64 | if isinstance(values[0], bool) and isinstance(values[1], bool): 65 | domain = ObjectFeatureDomain.create(java_array) 66 | elif isinstance(values[0], (float, int)) and isinstance( 67 | values[1], (float, int) 68 | ): 69 | domain = CategoricalNumericalFeatureDomain.create(java_array) 70 | elif isinstance(values[0], str): 71 | domain = CategoricalFeatureDomain.create(java_array) 72 | else: 73 | domain = ObjectFeatureDomain.create(java_array) 74 | 75 | else: 76 | domain = EmptyFeatureDomain.create() 77 | return domain 78 | -------------------------------------------------------------------------------- /tests/general/test_dataset.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801 2 | """Test suite for the Dataset structure""" 3 | 4 | from common import * 5 | 6 | from java.util import Random 7 | from pytest import approx 8 | import pandas as pd 9 | import numpy as np 10 | import uuid 11 | 12 | from trustyai.model import Dataset, Type 13 | 14 | 15 | jrandom = Random() 16 | jrandom.setSeed(0) 17 | 18 | def generate_test_df(): 19 | data = { 20 | 'x1': np.random.uniform(low=100, high=200, size=100), 21 | 'x2': np.random.uniform(low=5000, high=10000, size=100), 22 | 'x3': [str(uuid.uuid4()) for _ in range(100)], 23 | 'x4': np.random.randint(low=0, high=42, size=100), 24 | 'select': np.random.choice(a=[False, True], size=100) 25 | } 26 | return pd.DataFrame(data=data) 27 | 28 | def generate_test_array(): 29 | return np.random.rand(100, 5) 30 | 31 | 32 | def test_no_output(): 33 | """Checks whether we have an output when specifying none""" 34 | df = generate_test_df() 35 | dataset = Dataset.from_df(df) 36 | outputs = dataset.outputs[0].outputs 37 | assert len(outputs) == 1 38 | assert outputs[0].name == 'select' 39 | 40 | def test_outputs(): 41 | """Checks whether we have the correct specified outputs""" 42 | df = generate_test_df() 43 | dataset = Dataset.from_df(df, outputs=["x2", "x3"]) 44 | outputs = dataset.outputs[0].outputs 45 | assert len(outputs) == 2 46 | assert outputs[0].name == 'x2' and outputs[1].name == 'x3' 47 | 48 | def test_shape(): 49 | """Checks whether we have the correct shape""" 50 | df = generate_test_df() 51 | dataset = Dataset.from_df(df, outputs=["x4"]) 52 | assert len(dataset.outputs) == 100 53 | assert len(dataset.inputs) == 100 54 | assert len(dataset.data) == 100 55 | 56 | assert len(dataset.inputs[0].features) == 4 57 | assert len(dataset.outputs[0].outputs) == 1 58 | 59 | def test_types(): 60 | """Checks whether we have the correct shape""" 61 | df = generate_test_df() 62 | dataset = Dataset.from_df(df, outputs=["x4"]) 63 | features = dataset.inputs[0].features 64 | assert features[0].type == Type.NUMBER and features[0].name == 'x1' 65 | assert features[1].type == Type.NUMBER and features[1].name == 'x2' 66 | assert features[2].type == Type.CATEGORICAL and features[2].name == 'x3' 67 | assert 
features[3].type == Type.BOOLEAN and features[3].name == 'select' 68 | outputs = dataset.outputs[0].outputs 69 | assert outputs[0].type == Type.NUMBER and outputs[0].name == 'x4' 70 | 71 | def test_array_no_output(): 72 | """Checks whether we have an output when specifying none""" 73 | array = generate_test_array() 74 | dataset = Dataset.from_numpy(array) 75 | outputs = dataset.outputs[0].outputs 76 | assert len(outputs) == 1 77 | assert outputs[0].name == 'output-0' 78 | 79 | def test_array_outputs(): 80 | """Checks whether we have the correct specified outputs""" 81 | array = generate_test_array() 82 | dataset = Dataset.from_numpy(array, outputs=[1, 2]) 83 | outputs = dataset.outputs[0].outputs 84 | assert len(outputs) == 2 85 | assert outputs[0].name == 'output-0' and outputs[1].name == 'output-1' 86 | 87 | def test_array_shape(): 88 | """Checks whether we have the correct shape""" 89 | array = generate_test_array() 90 | dataset = Dataset.from_numpy(array, outputs=[4]) 91 | assert len(dataset.outputs) == 100 92 | assert len(dataset.inputs) == 100 93 | assert len(dataset.data) == 100 94 | 95 | assert len(dataset.inputs[0].features) == 4 96 | assert len(dataset.outputs[0].outputs) == 1 -------------------------------------------------------------------------------- /tests/general/test_metrics_language.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code, unused-import 2 | """Language metrics test suite""" 3 | 4 | from common import * 5 | from trustyai.metrics.language import word_error_rate 6 | import math 7 | 8 | tolerance = 1e-4 9 | 10 | REFERENCES = [ 11 | "This is the test reference, to which I will compare alignment against.", 12 | "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce euismod tortor massa, nec euismod sapien laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. Maecenas sodales sollicitudin leo a faucibus.", 13 | "The quick red fox jumped over the lazy brown dog"] 14 | 15 | INPUTS = [ 16 | "I'm a hypothesis reference, from which the aligner will compare against.", 17 | "Lorem ipsum sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce blandit euismod tortor massa, nec euismod sapien blandit laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. Maecenas sodales sollicitudin leo a faucibus.", 18 | "dog brown lazy the over jumped fox red quick The"] 19 | 20 | 21 | def test_default_tokenizer(): 22 | """Test default tokenizer""" 23 | results = [4 / 7, 1 / 26, 1] 24 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 25 | wer = word_error_rate(reference, hypothesis).value 26 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 27 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 
28 | 29 | 30 | def test_commons_stringtokenizer(): 31 | """Test Apache Commons StringTokenizer""" 32 | from trustyai.utils.tokenizers import CommonsStringTokenizer 33 | results = [8 / 12., 3 / 66., 1.0] 34 | 35 | def tokenizer(text: str) -> List[str]: 36 | return CommonsStringTokenizer(text).getTokenList() 37 | 38 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 39 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 40 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 41 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 42 | 43 | 44 | def test_opennlp_tokenizer(): 45 | """Test Apache Commons StringTokenizer""" 46 | from trustyai.utils.tokenizers import OpenNLPTokenizer 47 | results = [9 / 14., 3 / 78., 1.0] 48 | tokenizer = OpenNLPTokenizer() 49 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 50 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 51 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 52 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 53 | 54 | 55 | def test_python_tokenizer(): 56 | """Test pure Python whitespace tokenizer""" 57 | 58 | results = [3 / 4., 3 / 66., 1.0] 59 | 60 | def tokenizer(text: str) -> List[str]: 61 | return text.split(" ") 62 | 63 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 64 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 65 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 66 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 67 | -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tslime.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSLIME from aix360. 3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | 6 | from typing import Callable, List, Union 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from aix360.algorithms.tslime import TSLimeExplainer as TSLimeExplainerAIX 11 | from aix360.algorithms.tslime.surrogate import LinearSurrogateModel 12 | from pandas.io.formats.style import Styler 13 | import matplotlib.pyplot as plt 14 | 15 | from trustyai.explainers.explanation_results import ExplanationResults 16 | from trustyai.utils.extras.timeseries import TSPerturber 17 | 18 | 19 | class TSSLIMEResults(ExplanationResults): 20 | """Wraps TSLimeExplainer results. This object is returned by the :class:`~TSLimeExplainer`, 21 | and provides a variety of methods to visualize and interact with the explanation. 22 | """ 23 | 24 | def __init__(self, explanation): 25 | self.explanation = explanation 26 | 27 | def as_dataframe(self) -> pd.DataFrame: 28 | """Returns the weights as a pandas dataframe.""" 29 | return pd.DataFrame(self.explanation["history_weights"]) 30 | 31 | def as_html(self) -> Styler: 32 | """Returns the explanation as an HTML table.""" 33 | dataframe = self.as_dataframe() 34 | return dataframe.style 35 | 36 | def plot(self): 37 | """Plot TSLime explanation for the time-series instance. 
Based on 38 | https://github.com/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_univariate_demo.ipynb 39 | """ 40 | relevant_history = self.explanation["history_weights"].shape[0] 41 | input_data = self.explanation["input_data"] 42 | relevant_df = input_data[-relevant_history:] 43 | 44 | plt.figure(layout="constrained") 45 | plt.plot(relevant_df, label="Input Time Series", marker="o") 46 | plt.gca().invert_yaxis() 47 | 48 | normalized_weights = ( 49 | self.explanation["history_weights"] 50 | / np.mean(np.abs(self.explanation["history_weights"])) 51 | ).flatten() 52 | 53 | plt.bar( 54 | input_data.index[-relevant_history:], 55 | normalized_weights, 56 | 0.4, 57 | label="TSLime Weights (Normalized)", 58 | color="red", 59 | ) 60 | plt.axhline(y=0, color="r", linestyle="-", alpha=0.4) 61 | plt.title("Time Series Lime Explanation Plot") 62 | plt.legend(bbox_to_anchor=(1.25, 1.0), loc="upper right") 63 | plt.show() 64 | 65 | 66 | class TSLimeExplainer(TSLimeExplainerAIX): 67 | """ 68 | Wrapper for TSLimeExplainer from aix360. 69 | """ 70 | 71 | def __init__( # pylint: disable=too-many-arguments 72 | self, 73 | model: Callable, 74 | input_length: int, 75 | n_perturbations: int = 2000, 76 | relevant_history: int = None, 77 | perturbers: List[Union[TSPerturber, dict]] = None, 78 | local_interpretable_model: LinearSurrogateModel = None, 79 | random_seed: int = None, 80 | ): 81 | super().__init__( 82 | model=model, 83 | input_length=input_length, 84 | n_perturbations=n_perturbations, 85 | relevant_history=relevant_history, 86 | perturbers=perturbers, 87 | local_interpretable_model=local_interpretable_model, 88 | random_seed=random_seed, 89 | ) 90 | 91 | def explain(self, inputs, outputs=None, **kwargs) -> TSSLIMEResults: 92 | """ 93 | Explain the model's prediction on X. 94 | """ 95 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 96 | return TSSLIMEResults(_explanation) 97 | -------------------------------------------------------------------------------- /tests/extras/test_tssaliency.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.ensemble import RandomForestRegressor 6 | 7 | from aix360.datasets import SunspotDataset 8 | from trustyai.explainers.extras.tssaliency import TSSaliencyExplainer 9 | from trustyai.utils.extras.timeseries import tsFrame 10 | 11 | 12 | # transform a time series dataset into a supervised learning dataset 13 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 14 | class RandomForestUniVariateForecaster: 15 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 16 | self.n_past = n_past 17 | self.n_future = n_future 18 | self.model = RandomForestRegressor(**RFparams) 19 | 20 | def fit(self, X): 21 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 22 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 23 | self.model = self.model.fit(trainX, trainy) 24 | return self 25 | 26 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 27 | n_vars = 1 if type(data) is list else data.shape[1] 28 | df = pd.DataFrame(data) 29 | cols = list() 30 | 31 | # input sequence (t-n, ... t-1) 32 | for i in range(n_in, 0, -1): 33 | cols.append(df.shift(i)) 34 | # forecast sequence (t, t+1, ... 
t+n) 35 | for i in range(0, n_out): 36 | cols.append(df.shift(-i)) 37 | # put it all together 38 | agg = pd.concat(cols, axis=1) 39 | # drop rows with NaN values 40 | if dropnan: 41 | agg.dropna(inplace=True) 42 | return agg.values 43 | 44 | def predict(self, X): 45 | row = X[-self.n_past:].flatten() 46 | y_pred = self.model.predict(np.asarray([row])) 47 | return y_pred 48 | 49 | 50 | class TestTSSaliencyExplainer(unittest.TestCase): 51 | def setUp(self): 52 | # load data 53 | df, schema = SunspotDataset().load_data() 54 | ts = tsFrame( 55 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 56 | ) 57 | 58 | (self.ts_train, self.ts_test) = train_test_split( 59 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 60 | ) 61 | 62 | def test_tssaliency(self): 63 | # load model 64 | input_length = 48 65 | forecast_horizon = 10 66 | forecaster = RandomForestUniVariateForecaster( 67 | n_past=input_length, n_future=forecast_horizon 68 | ) 69 | 70 | forecaster.fit(self.ts_train.iloc[-200:]) 71 | 72 | # initialize/fit explainer 73 | 74 | explainer = TSSaliencyExplainer( 75 | model=forecaster.predict, 76 | input_length=input_length, 77 | feature_names=self.ts_train.columns.tolist(), 78 | n_samples=2, 79 | gradient_samples=50, 80 | ) 81 | 82 | # compute explanations 83 | test_window = self.ts_test.iloc[:input_length] 84 | explanation = explainer.explain(test_window) 85 | 86 | # validate explanation structure 87 | self.assertIn("input_data", explanation.explanation) 88 | self.assertIn("feature_names", explanation.explanation) 89 | self.assertIn("saliency", explanation.explanation) 90 | self.assertIn("timestamps", explanation.explanation) 91 | self.assertIn("base_value", explanation.explanation) 92 | self.assertIn("instance_prediction", explanation.explanation) 93 | self.assertIn("base_value_prediction", explanation.explanation) 94 | 95 | self.assertEqual(explanation.explanation["saliency"].shape, test_window.shape) 96 | -------------------------------------------------------------------------------- /src/trustyai/initializer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value, invalid-name, R0801 2 | # pylint: disable = deprecated-module 3 | """Main TrustyAI Python bindings""" 4 | try: 5 | from distutils.sysconfig import get_python_lib 6 | except ImportError: 7 | # distutils is deprecated and removed in Python 3.12+ 8 | # Use sysconfig instead 9 | import sysconfig 10 | 11 | def get_python_lib(): 12 | """Fallback implementation of get_python_lib using sysconfig.""" 13 | return sysconfig.get_path("purelib") 14 | 15 | 16 | import glob 17 | import logging 18 | import os 19 | from pathlib import Path 20 | import site 21 | from typing import List 22 | import uuid 23 | import warnings 24 | 25 | import jpype 26 | import jpype.imports 27 | from jpype import _jcustomizer, _jclass 28 | 29 | DEFAULT_ARGS = ( 30 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 31 | # see https://arrow.apache.org/docs/java/install.html#java-compatibility 32 | "-Dorg.slf4j.simpleLogger.defaultLogLevel=error", 33 | ) 34 | 35 | 36 | def _get_default_path(): 37 | try: 38 | default_dep_path = os.path.join(site.getsitepackages()[0], "trustyai", "dep") 39 | except AttributeError: 40 | default_dep_path = os.path.join(get_python_lib(), "trustyai", "dep") 41 | 42 | core_deps = [ 43 | f"{default_dep_path}/org/trustyai/explainability-arrow-999-SNAPSHOT.jar", 44 | ] 45 | 46 | return core_deps, 
default_dep_path 47 | 48 | 49 | def init(*args, path=None): 50 | """init(*args, path=JAVA_DEPENDENCIES) 51 | 52 | Manually initialize the JVM. If you would like to manually specify the Java libraries to be 53 | imported, for example if you want to use a different version of the Trusty Explainability 54 | library than is bundled by default, you can do so by calling :func:`init`. If this is not 55 | manually called, trustyai will use the default set of libraries and automatically initialize 56 | itself when necessary. 57 | 58 | Parameters 59 | ---------- 60 | args: list 61 | List of args to be passed to ``jpype.startJVM``. See the 62 | `JPype manual `_ 63 | for more details. 64 | path: list[str] 65 | List of jar files to add the Java class path. By default, this will add the necessary 66 | dependencies of the TrustyAI Java library. 67 | """ 68 | # Launch the JVM 69 | try: 70 | # get default dependencies 71 | if path is None: 72 | path, default_dep_path = _get_default_path() 73 | logging.debug("Checking for dependencies in %s", default_dep_path) 74 | 75 | # check the classpath 76 | for jar_path in path: 77 | if "*" not in jar_path: 78 | jar_path_exists = Path(jar_path).exists() 79 | else: 80 | jar_path_exists = any( 81 | Path(fp).exists() for fp in glob.glob(jar_path) if ".jar" in fp 82 | ) 83 | if jar_path_exists: 84 | logging.debug("JAR %s found.", jar_path) 85 | else: 86 | logging.error("JAR %s not found.", jar_path) 87 | 88 | _args = args + DEFAULT_ARGS 89 | jpype.startJVM(*_args, classpath=path) 90 | 91 | from java.lang import Thread 92 | 93 | if not Thread.isAttached: 94 | jpype.attachThreadToJVM() 95 | 96 | from java.util import UUID 97 | 98 | @_jcustomizer.JConversion("java.util.List", exact=List) 99 | def _JListConvert(_, py_list: List): 100 | return _jclass.JClass("java.util.Arrays").asList(py_list) 101 | 102 | @_jcustomizer.JConversion("java.util.UUID", instanceof=uuid.UUID) 103 | def _JUUIDConvert(_, obj): 104 | return UUID.fromString(str(obj)) 105 | 106 | except OSError: 107 | print("JVM already initialized") 108 | warnings.warn("JVM already initialized") 109 | 110 | return True 111 | -------------------------------------------------------------------------------- /tests/extras/test_tslime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.ensemble import RandomForestRegressor 7 | from trustyai.utils.extras.timeseries import tsFrame 8 | from aix360.datasets import SunspotDataset 9 | from trustyai.explainers.extras.tslime import TSLimeExplainer 10 | from trustyai.utils.extras.timeseries import BlockBootstrapPerturber 11 | 12 | 13 | # transform a time series dataset into a supervised learning dataset 14 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 15 | class RandomForestUniVariateForecaster: 16 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 17 | self.n_past = n_past 18 | self.n_future = n_future 19 | self.model = RandomForestRegressor(**RFparams) 20 | 21 | def fit(self, X): 22 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 23 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 24 | self.model = self.model.fit(trainX, trainy) 25 | return self 26 | 27 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 28 | n_vars = 1 if 
type(data) is list else data.shape[1] 29 | df = pd.DataFrame(data) 30 | cols = list() 31 | 32 | # input sequence (t-n, ... t-1) 33 | for i in range(n_in, 0, -1): 34 | cols.append(df.shift(i)) 35 | # forecast sequence (t, t+1, ... t+n) 36 | for i in range(0, n_out): 37 | cols.append(df.shift(-i)) 38 | # put it all together 39 | agg = pd.concat(cols, axis=1) 40 | # drop rows with NaN values 41 | if dropnan: 42 | agg.dropna(inplace=True) 43 | return agg.values 44 | 45 | def predict(self, X): 46 | row = X[-self.n_past:].flatten() 47 | y_pred = self.model.predict(np.asarray([row])) 48 | return y_pred 49 | 50 | 51 | class TestTSLimeExplainer(unittest.TestCase): 52 | def setUp(self): 53 | # load data 54 | df, schema = SunspotDataset().load_data() 55 | ts = tsFrame( 56 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 57 | ) 58 | 59 | (self.ts_train, self.ts_test) = train_test_split( 60 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 61 | ) 62 | 63 | def test_tslime(self): 64 | # load model 65 | input_length = 24 66 | forecast_horizon = 4 67 | forecaster = RandomForestUniVariateForecaster( 68 | n_past=input_length, n_future=forecast_horizon 69 | ) 70 | 71 | forecaster.fit(self.ts_train.iloc[-200:]) 72 | 73 | # initialize/fit explainer 74 | 75 | relevant_history = 12 76 | explainer = TSLimeExplainer( 77 | model=forecaster.predict, 78 | input_length=input_length, 79 | relevant_history=relevant_history, 80 | perturbers=[ 81 | BlockBootstrapPerturber( 82 | window_length=min(4, input_length - 1), block_length=2, block_swap=2 83 | ), 84 | ], 85 | n_perturbations=10, 86 | random_seed=22, 87 | ) 88 | 89 | # compute explanations 90 | test_window = self.ts_test.iloc[:input_length] 91 | explanation = explainer.explain(test_window) 92 | 93 | # validate explanation structure 94 | self.assertIn("input_data", explanation.explanation) 95 | self.assertIn("history_weights", explanation.explanation) 96 | self.assertIn("x_perturbations", explanation.explanation) 97 | self.assertIn("y_perturbations", explanation.explanation) 98 | self.assertIn("model_prediction", explanation.explanation) 99 | self.assertIn("surrogate_prediction", explanation.explanation) 100 | 101 | self.assertEqual(explanation.explanation["history_weights"].shape[0], relevant_history) 102 | -------------------------------------------------------------------------------- /src/trustyai/utils/_tyrus_info_text.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = consider-using-f-string 2 | """Info text used in Tyrus visualization explainer info""" 3 | from trustyai.utils._visualisation import bold_red_html, bold_green_html 4 | 5 | LIME_TEXT = """ 6 |
7 | What is LIME? 8 | 9 | 10 | LIME (Local Interpretable Model-agnostic Explanations) explanations answer the following question: 11 | "Which features were most important to the predicted {{0}}?" 12 | LIME does this by providing per-feature saliencies, numeric weights that describe how strongly each feature contributed to the model's output. 13 | 14 | 15 | 16 | In this plot, each horizontal bar represents a feature's saliency: features with positive importance to the predicted {{0}} are marked in {}, while 17 | features with negative importance are marked in {}. The larger the bar, the more important the feature was to the output. 18 | 19 | 20 | 21 | To see how TrustyAI's LIME works, check out the documentation! 22 | 23 |
24 | """.format( 25 | bold_green_html("green"), bold_red_html("red") 26 | ) 27 | 28 | SHAP_TEXT = """ 29 |
30 | What is SHAP? 31 | 32 | SHAP (SHapley Additive exPlanations) explanations answer the following question: 33 | "By how much did each feature contribute to the predicted {{}}?" 34 | 35 | 36 | SHAP does this by providing SHAP values, which form an additive explanation of the model output; a receipt for the model's output. 37 | SHAP will produce a list of per-feature contributions, the sum of which will equal the model's output. 38 | To operate, SHAP also needs access to a background dataset, a set of representative input datapoints that captures 39 | the model's "normal behavior". All SHAP values are comparisons against this background data, i.e., 40 | "By how much did each feature of this input contribute to the output, as compared to the background inputs?" 41 | 42 | 43 | 44 | In this plot, the dotted horizontal line shows the average model output over the background, the starting 45 | "baseline comparison" mark for a SHAP explanation. Then, each vertical bar or candle describes how 46 | each feature {} or {} its contribution to the model's predicted output, marked by the solid horizontal line. The larger 47 | the feature's contribution, the larger the bar. 48 | 49 | 50 | 51 | To see how TrustyAI's SHAP works, check out the documentation! 52 | 53 |
54 | """.format( 55 | bold_green_html("adds"), bold_red_html("subtracts") 56 | ) 57 | 58 | CF_TEXT = """ 59 |
60 | What is a Counterfactual? 61 | 62 | 63 | Counterfactuals represent alternate, "what-if" scenarios: what other possible values of {0} 64 | can be attained by modifying the input? 65 | 66 | 67 | 68 | This plot shows all of the counterfactuals 69 | produced during the computation of the LIME and SHAP explanations. On the x-axis are 70 | novel counterfactual values for {0}, while the y-axis shows the number of features that were changed to produce 71 | that particular value. Hover over individual points to see the exact changes to the original input 72 | necessary to produce the displayed counterfactual value of {0}. 73 | 74 | 75 | 76 | To see how TrustyAI's Counterfactual Explainer works, check out the documentation! 77 | 78 |
79 | """ 80 | -------------------------------------------------------------------------------- /tests/general/test_datautils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Data utils test suite""" 3 | from common import * 4 | 5 | from pytest import approx 6 | import random 7 | 8 | from trustyai.utils import DataUtils 9 | from trustyai.model import FeatureFactory 10 | from java.util import Random 11 | 12 | jrandom = Random() 13 | 14 | 15 | def test_get_mean(): 16 | """Test GetMean""" 17 | data = [2, 4, 3, 5, 1] 18 | assert DataUtils.getMean(data) == approx(3, 1e-6) 19 | 20 | 21 | def test_get_std_dev(): 22 | """Test GetStdDev""" 23 | data = [2, 4, 3, 5, 1] 24 | assert DataUtils.getStdDev(data, 3) == approx(1.41, 1e-2) 25 | 26 | 27 | def test_gaussian_kernel(): 28 | """Test Gaussian Kernel""" 29 | x = 0.0 30 | k = DataUtils.gaussianKernel(x, 0, 1) 31 | assert k == approx(0.398, 1e-2) 32 | x = 0.218 33 | k = DataUtils.gaussianKernel(x, 0, 1) 34 | assert k == approx(0.389, 1e-2) 35 | 36 | 37 | def test_euclidean_distance(): 38 | """Test Euclidean distance""" 39 | x = [1, 1] 40 | y = [2, 3] 41 | distance = DataUtils.euclideanDistance(x, y) 42 | assert approx(distance, 1e-3) == 2.236 43 | 44 | 45 | def test_hamming_distance_double(): 46 | """Test Hamming distance for doubles""" 47 | x = [2, 1] 48 | y = [2, 3] 49 | distance = DataUtils.hammingDistance(x, y) 50 | assert distance == approx(1, 1e-1) 51 | 52 | 53 | def test_hamming_distance_string(): 54 | """Test Hamming distance for strings""" 55 | x = "test1" 56 | y = "test2" 57 | distance = DataUtils.hammingDistance(x, y) 58 | assert distance == approx(1, 1e-1) 59 | 60 | 61 | def test_doubles_to_features(): 62 | """Test doubles to features""" 63 | inputs = [1 if i % 2 == 0 else 0 for i in range(10)] 64 | features = DataUtils.doublesToFeatures(inputs) 65 | assert features is not None 66 | assert len(features) == 10 67 | for f in features: 68 | assert f is not None 69 | assert f.getName() is not None 70 | assert f.getValue() is not None 71 | 72 | 73 | def test_exponential_smoothing_kernel(): 74 | """Test exponential smoothing kernel""" 75 | x = 0.218 76 | k = DataUtils.exponentialSmoothingKernel(x, 2) 77 | assert k == approx(0.994, 1e-3) 78 | 79 | 80 | # def test_perturb_features_empty(): 81 | # """Test perturb empty features""" 82 | # features = [] 83 | # perturbationContext = PerturbationContext(jrandom, 0) 84 | # newFeatures = DataUtils.perturbFeatures(features, perturbationContext) 85 | # assert newFeatures is not None 86 | # assert len(features) == newFeatures.size() 87 | 88 | 89 | def test_random_distribution_generation(): 90 | """Test random distribution generation""" 91 | dataDistribution = DataUtils.generateRandomDataDistribution(10, 10, jrandom) 92 | assert dataDistribution is not None 93 | assert dataDistribution.asFeatureDistributions() is not None 94 | for featureDistribution in dataDistribution.asFeatureDistributions(): 95 | assert featureDistribution is not None 96 | 97 | 98 | def test_linearized_numeric_features(): 99 | """Test linearised numeric features""" 100 | f = FeatureFactory.newNumericalFeature("f-num", 1.0) 101 | features = [f] 102 | linearizedFeatures = DataUtils.getLinearizedFeatures(features) 103 | assert len(features) == linearizedFeatures.size() 104 | 105 | 106 | def test_sample_with_replacement(): 107 | """Test sample with replacement""" 108 | emptyValues = [] 109 | emptySamples = 
DataUtils.sampleWithReplacement(emptyValues, 1, jrandom) 110 | assert emptySamples is not None 111 | assert emptySamples.size() == 0 112 | 113 | values = DataUtils.generateData(0, 1, 100, jrandom) 114 | sampleSize = 10 115 | samples = DataUtils.sampleWithReplacement(values, sampleSize, jrandom) 116 | assert samples is not None 117 | assert samples.size() == sampleSize 118 | assert samples[random.randint(0, sampleSize - 1)] in values 119 | 120 | largerSampleSize = 300 121 | largerSamples = DataUtils.sampleWithReplacement(values, largerSampleSize, jrandom) 122 | assert largerSamples is not None 123 | assert largerSampleSize == largerSamples.size() 124 | assert largerSamples[random.randint(0, largerSampleSize - 1)] in largerSamples 125 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/lime.py: -------------------------------------------------------------------------------- 1 | """Visualizations.lime module""" 2 | 3 | # pylint: disable = import-error, too-few-public-methods, consider-using-f-string, missing-final-newline 4 | import matplotlib.pyplot as plt 5 | import matplotlib as mpl 6 | from bokeh.models import ColumnDataSource, HoverTool 7 | from bokeh.plotting import figure 8 | import pandas as pd 9 | 10 | from trustyai.utils._visualisation import ( 11 | DEFAULT_STYLE as ds, 12 | DEFAULT_RC_PARAMS as drcp, 13 | bold_red_html, 14 | bold_green_html, 15 | output_html, 16 | feature_html, 17 | ) 18 | from trustyai.visualizations.visualization_results import VisualizationResults 19 | 20 | 21 | class LimeViz(VisualizationResults): 22 | """Visualizes LIME results.""" 23 | 24 | def _matplotlib_plot( 25 | self, explanations, output_name: str, block=True, call_show=True 26 | ) -> None: 27 | """Plot the LIME saliencies.""" 28 | with mpl.rc_context(drcp): 29 | dictionary = {} 30 | for feature_importance in ( 31 | explanations.saliency_map().get(output_name).getPerFeatureImportance() 32 | ): 33 | dictionary[feature_importance.getFeature().name] = ( 34 | feature_importance.getScore() 35 | ) 36 | 37 | colours = [ 38 | ( 39 | ds["negative_primary_colour"] 40 | if i < 0 41 | else ds["positive_primary_colour"] 42 | ) 43 | for i in dictionary.values() 44 | ] 45 | plt.title(f"LIME: Feature Importances to {output_name}") 46 | plt.barh( 47 | range(len(dictionary)), 48 | dictionary.values(), 49 | align="center", 50 | color=colours, 51 | ) 52 | plt.yticks(range(len(dictionary)), list(dictionary.keys())) 53 | plt.tight_layout() 54 | 55 | if call_show: 56 | plt.show(block=block) 57 | 58 | def _get_bokeh_plot(self, explanations, output_name): 59 | lime_data_source = pd.DataFrame( 60 | [ 61 | { 62 | "feature": str(pfi.getFeature().getName()), 63 | "saliency": pfi.getScore(), 64 | } 65 | for pfi in explanations.saliency_map()[ 66 | output_name 67 | ].getPerFeatureImportance() 68 | ] 69 | ) 70 | lime_data_source["color"] = lime_data_source["saliency"].apply( 71 | lambda x: ( 72 | ds["positive_primary_colour"] 73 | if x >= 0 74 | else ds["negative_primary_colour"] 75 | ) 76 | ) 77 | lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply( 78 | lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x)) 79 | ) 80 | 81 | lime_data_source["color_faded"] = lime_data_source["saliency"].apply( 82 | lambda x: ( 83 | ds["positive_primary_colour_faded"] 84 | if x >= 0 85 | else ds["negative_primary_colour_faded"] 86 | ) 87 | ) 88 | source = ColumnDataSource(lime_data_source) 89 | htool = HoverTool( 90 | name="bars", 91 | tooltips="
LIME
{} saliency to {}: @saliency_colored".format( 92 | feature_html("@feature"), output_html(output_name) 93 | ), 94 | ) 95 | bokeh_plot = figure( 96 | sizing_mode="stretch_both", 97 | title="Lime Feature Importances", 98 | y_range=lime_data_source["feature"], 99 | tools=[htool], 100 | ) 101 | bokeh_plot.hbar( 102 | y="feature", 103 | left=0, 104 | right="saliency", 105 | fill_color="color_faded", 106 | line_color="color", 107 | hover_color="color", 108 | color="color", 109 | height=0.75, 110 | name="bars", 111 | source=source, 112 | ) 113 | bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000") 114 | bokeh_plot.xaxis.axis_label = "Saliency Value" 115 | bokeh_plot.yaxis.axis_label = "Feature" 116 | return bokeh_plot 117 | 118 | def _get_bokeh_plot_dict(self, explanations): 119 | return { 120 | output_name: self._get_bokeh_plot(explanations, output_name) 121 | for output_name in explanations.saliency_map().keys() 122 | } 123 | -------------------------------------------------------------------------------- /src/trustyai/explainers/pdp.py: -------------------------------------------------------------------------------- 1 | """Explainers.pdp module""" 2 | 3 | import math 4 | import pandas as pd 5 | from pandas.io.formats.style import Styler 6 | 7 | from jpype import ( 8 | JImplements, 9 | JOverride, 10 | ) 11 | 12 | # pylint: disable = import-error 13 | from org.kie.trustyai.explainability.global_ import pdp 14 | 15 | # pylint: disable = import-error 16 | from org.kie.trustyai.explainability.model import ( 17 | PredictionProvider, 18 | PredictionInputsDataDistribution, 19 | PredictionOutput, 20 | Output, 21 | Type, 22 | Value, 23 | ) 24 | 25 | from trustyai.utils.data_conversions import ManyInputsUnionType, many_inputs_convert 26 | 27 | from .explanation_results import ExplanationResults 28 | 29 | 30 | class PDPResults(ExplanationResults): 31 | """ 32 | Results class for Partial Dependence Plots 33 | """ 34 | 35 | def __init__(self, pdp_graphs): 36 | self.pdp_graphs = pdp_graphs 37 | 38 | def as_dataframe(self) -> pd.DataFrame: 39 | """ 40 | Returns 41 | ------- 42 | a pd.DataFrame with input values and feature name as 43 | columns and marginal feature outputs as rows 44 | """ 45 | pdp_series_list = [] 46 | for pdp_graph in self.pdp_graphs: 47 | inputs = [self._to_plottable(x) for x in pdp_graph.getX()] 48 | outputs = [self._to_plottable(y) for y in pdp_graph.getY()] 49 | pdp_dict = dict(zip(inputs, outputs)) 50 | pdp_dict["feature"] = "" + str(pdp_graph.getFeature().getName()) 51 | pdp_series = pd.Series(index=inputs + ["feature"], data=pdp_dict) 52 | pdp_series_list.append(pdp_series) 53 | pdp_df = pd.DataFrame(pdp_series_list) 54 | return pdp_df 55 | 56 | def as_html(self) -> Styler: 57 | """ 58 | Returns 59 | ------- 60 | Style object from the PDP pd.DataFrame (see as_dataframe) 61 | """ 62 | return self.as_dataframe().style 63 | 64 | @staticmethod 65 | def _to_plottable(datum: Value): 66 | plottable = datum.asNumber() 67 | if math.isnan(plottable): 68 | plottable = str(datum.asString()) 69 | return plottable 70 | 71 | 72 | # pylint: disable = too-few-public-methods 73 | class PDPExplainer: 74 | """ 75 | Partial Dependence Plot explainer. 
76 | See https://christophm.github.io/interpretable-ml-book/pdp.html 77 | """ 78 | 79 | def __init__(self, config=None): 80 | if config is None: 81 | config = pdp.PartialDependencePlotConfig() 82 | self._explainer = pdp.PartialDependencePlotExplainer(config) 83 | 84 | def explain( 85 | self, model: PredictionProvider, data: ManyInputsUnionType, num_outputs: int = 1 86 | ) -> PDPResults: 87 | """ 88 | Parameters 89 | ---------- 90 | model: PredictionProvider 91 | the model to explain 92 | data: ManyInputsUnionType 93 | the data used to calculate the PDP 94 | num_outputs: int 95 | the number of outputs to calculate the PDP for 96 | 97 | Returns 98 | ------- 99 | pdp_results: PDPResults 100 | the partial dependence plots associated to the model outputs 101 | """ 102 | metadata = _PredictionProviderMetadata(many_inputs_convert(data), num_outputs) 103 | pdp_graphs = self._explainer.explainFromMetadata(model, metadata) 104 | return PDPResults(pdp_graphs) 105 | 106 | 107 | @JImplements( 108 | "org.kie.trustyai.explainability.model.PredictionProviderMetadata", deferred=True 109 | ) 110 | class _PredictionProviderMetadata: 111 | """ 112 | Implementation of org.kie.trustyai.explainability.model.PredictionProviderMetadata interface 113 | """ 114 | 115 | def __init__(self, data: list, size: int): 116 | """ 117 | Parameters 118 | ---------- 119 | data: ManyInputsUnionType 120 | the data 121 | size: int 122 | the size of the model output 123 | """ 124 | self.data = PredictionInputsDataDistribution(data) 125 | outputs = [] 126 | for _ in range(size): 127 | outputs.append(Output("", Type.UNDEFINED)) 128 | self.pred_out = PredictionOutput(outputs) 129 | 130 | # pylint: disable = invalid-name 131 | @JOverride 132 | def getDataDistribution(self): 133 | """ 134 | Returns 135 | -------- 136 | the underlying data distribution 137 | """ 138 | return self.data 139 | 140 | # pylint: disable = invalid-name 141 | @JOverride 142 | def getInputShape(self): 143 | """ 144 | Returns 145 | -------- 146 | a PredictionInput from the underlying distribution 147 | """ 148 | return self.data.sample() 149 | 150 | # pylint: disable = invalid-name, missing-final-newline 151 | @JOverride 152 | def getOutputShape(self): 153 | """ 154 | Returns 155 | -------- 156 | a PredictionOutput 157 | """ 158 | return self.pred_out 159 | -------------------------------------------------------------------------------- /src/trustyai/metrics/saliency.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error 2 | """Saliency evaluation metrics""" 3 | from typing import Union 4 | 5 | from org.apache.commons.lang3.tuple import ( 6 | Pair as _Pair, 7 | ) 8 | 9 | from org.kie.trustyai.explainability.model import ( 10 | PredictionInput, 11 | PredictionInputsDataDistribution, 12 | ) 13 | from org.kie.trustyai.explainability.local import LocalExplainer 14 | 15 | from jpype import JObject 16 | 17 | from trustyai.model import simple_prediction, PredictionProvider 18 | from trustyai.explainers import SHAPExplainer, LimeExplainer 19 | 20 | from . 
import ExplainabilityMetrics 21 | 22 | 23 | def impact_score( 24 | model: PredictionProvider, 25 | pred_input: PredictionInput, 26 | explainer: Union[LimeExplainer, SHAPExplainer], 27 | k: int, 28 | is_model_callable: bool = False, 29 | ): 30 | """ 31 | Parameters 32 | ---------- 33 | model: trustyai.PredictionProvider 34 | the model used to generate predictions 35 | pred_input: trustyai.PredictionInput 36 | the input to the model 37 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 38 | the explainer to evaluate 39 | k: int 40 | the number of top important features 41 | is_model_callable: bool 42 | whether to directly use model function call or use the predict method 43 | 44 | Returns 45 | ------- 46 | :float: 47 | impact score metric 48 | """ 49 | if is_model_callable: 50 | output = model(pred_input) 51 | else: 52 | output = model.predict([pred_input])[0].outputs 53 | pred = simple_prediction(pred_input, output) 54 | explanation = explainer.explain(inputs=pred_input, outputs=output, model=model) 55 | saliency = list(explanation.saliency_map().values())[0] 56 | top_k_features = saliency.getTopFeatures(k) 57 | return ExplainabilityMetrics.impactScore(model, pred, top_k_features) 58 | 59 | 60 | def mean_impact_score( 61 | explainer: Union[LimeExplainer, SHAPExplainer], 62 | model: PredictionProvider, 63 | data: list, 64 | is_model_callable=False, 65 | k=2, 66 | ): 67 | """ 68 | Parameters 69 | ---------- 70 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 71 | the explainer to evaluate 72 | model: trustyai.PredictionProvider 73 | the model used to generate predictions 74 | data: list[list[trustyai.model.Feature]] 75 | the inputs to calculate the metric for 76 | is_model_callable: bool 77 | whether to directly use model function call or use the predict method 78 | k: int 79 | the number of top important features 80 | 81 | Returns 82 | ------- 83 | :float: 84 | the mean impact score metric across all inputs 85 | """ 86 | m_is = 0 87 | for features in data: 88 | m_is += impact_score( 89 | model, features, explainer, k, is_model_callable=is_model_callable 90 | ) 91 | return m_is / len(data) 92 | 93 | 94 | def classification_fidelity( 95 | explainer: Union[LimeExplainer, SHAPExplainer], 96 | model: PredictionProvider, 97 | inputs: list, 98 | is_model_callable: bool = False, 99 | ): 100 | """ 101 | Parameters 102 | ---------- 103 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 104 | the explainer to evaluate 105 | model: trustyai.PredictionProvider 106 | the model used to generate predictions 107 | inputs: list[list[trustyai.model.Feature]] 108 | the inputs to calculate the metric for 109 | is_model_callable: bool 110 | whether to directly use model function call or use the predict method 111 | 112 | Returns 113 | ------- 114 | :float: 115 | the classification fidelity metric 116 | """ 117 | pairs = [] 118 | for c_input in inputs: 119 | if is_model_callable: 120 | output = model(c_input) 121 | else: 122 | output = model.predict([c_input])[0].outputs 123 | explanation = explainer.explain(inputs=c_input, outputs=output, model=model) 124 | saliency = list(explanation.saliency_map().values())[0] 125 | pairs.append(_Pair.of(saliency, simple_prediction(c_input, output))) 126 | return ExplainabilityMetrics.classificationFidelity(pairs) 127 | 128 | 129 | # pylint: disable = too-many-arguments 130 | def local_saliency_f1( 131 | output_name: str, 132 | model: PredictionProvider, 133 | 
explainer: Union[LimeExplainer, SHAPExplainer], 134 | distribution: PredictionInputsDataDistribution, 135 | k: int, 136 | chunk_size: int, 137 | ): 138 | """ 139 | Parameters 140 | ---------- 141 | output_name: str 142 | the name of the output to calculate the metric for 143 | model: trustyai.PredictionProvider 144 | the model used to generate predictions 145 | explainer: Union[trustyai.explainers.LIMEExplainer, trustyai.explainers.SHAPExplainer, 146 | trustyai.explainers.LocalExplainer] 147 | the explainer to evaluate 148 | distribution: org.kie.trustyai.explainability.model.PredictionInputsDataDistribution 149 | the data distribution to fetch the inputs from 150 | k: int 151 | the number of top important features 152 | chunk_size: int 153 | the chunk of inputs to fetch fro the distribution 154 | 155 | Returns 156 | ------- 157 | :float: 158 | the local saliency f1 metric 159 | """ 160 | if not isinstance(explainer, LocalExplainer): 161 | # pylint: disable = protected-access 162 | local_explainer = JObject(explainer._explainer, LocalExplainer) 163 | else: 164 | local_explainer = explainer 165 | return ExplainabilityMetrics.getLocalSaliencyF1( 166 | output_name, model, local_explainer, distribution, k, chunk_size 167 | ) 168 | -------------------------------------------------------------------------------- /tests/general/test_shap.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code, unused-import 2 | """SHAP explainer test suite""" 3 | 4 | from common import * 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | np.random.seed(0) 11 | 12 | import pytest 13 | from trustyai.explainers import SHAPExplainer 14 | from trustyai.model import feature, Model 15 | from trustyai.utils.data_conversions import numpy_to_prediction_object 16 | from trustyai.utils import TestModels 17 | from trustyai.visualizations import plot 18 | 19 | 20 | def test_no_variance_one_output(): 21 | """Check if the explanation returned is not null""" 22 | model = TestModels.getSumSkipModel(0) 23 | 24 | background = np.array([[1.0, 2.0, 3.0] for _ in range(2)]) 25 | prediction_outputs = model.predictAsync(numpy_to_prediction_object(background, feature)).get() 26 | shap_explainer = SHAPExplainer(background=background) 27 | for i in range(2): 28 | explanation = shap_explainer.explain(inputs=background[i], outputs=prediction_outputs[i].outputs, model=model) 29 | for _, saliency in explanation.saliency_map().items(): 30 | for feature_importance in saliency.getPerFeatureImportance()[:-1]: 31 | assert feature_importance.getScore() == 0.0 32 | 33 | 34 | def test_shap_arrow(): 35 | """Basic SHAP/Arrow test""" 36 | np.random.seed(0) 37 | data = pd.DataFrame(np.random.rand(101, 5)) 38 | background = data.iloc[:100] 39 | to_explain = data.iloc[100:101] 40 | 41 | model_weights = np.random.rand(5) 42 | predict_function = lambda x: np.dot(x.values, model_weights) 43 | 44 | model = Model(predict_function, dataframe_input=True) 45 | shap_explainer = SHAPExplainer(background=background) 46 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 47 | 48 | 49 | answers = [-.152, -.114, 0.00304, .0525, -.0725] 50 | for _, saliency in explanation.saliency_map().items(): 51 | for i, feature_importance in enumerate(saliency.getPerFeatureImportance()[:-1]): 52 | assert answers[i] - 1e-2 <= feature_importance.getScore() <= answers[i] + 1e-2 
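# Note: SHAP explanations are additive: for each output, the per-feature SHAP values plus the
# mean model output over the background dataset should sum to (approximately) the model's
# prediction for the explained row. The hard-coded `answers` above are therefore tied to the
# fixed seed, background slice and model weights used in this test.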
53 | 54 | 55 | def shap_plots(block): 56 | """Test SHAP plots""" 57 | np.random.seed(0) 58 | data = pd.DataFrame(np.random.rand(101, 5)) 59 | background = data.iloc[:100] 60 | to_explain = data.iloc[100:101] 61 | 62 | model_weights = np.random.rand(5) 63 | predict_function = lambda x: np.stack([np.dot(x.values, model_weights), 2 * np.dot(x.values, model_weights)], -1) 64 | model = Model(predict_function, dataframe_input=True) 65 | shap_explainer = SHAPExplainer(background=background) 66 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 67 | 68 | plot(explanation, block=block) 69 | plot(explanation, block=block, render_bokeh=True) 70 | plot(explanation, block=block, output_name='output-0') 71 | plot(explanation, block=block, output_name='output-0', render_bokeh=True) 72 | 73 | 74 | @pytest.mark.block_plots 75 | def test_shap_plots_blocking(): 76 | shap_plots(block=True) 77 | 78 | 79 | def test_shap_plots(): 80 | shap_plots(block=False) 81 | 82 | 83 | def test_shap_as_df(): 84 | np.random.seed(0) 85 | data = pd.DataFrame(np.random.rand(101, 5)) 86 | background = data.iloc[:100].values 87 | to_explain = data.iloc[100:101].values 88 | 89 | model_weights = np.random.rand(5) 90 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 91 | 92 | model = Model(predict_function, disable_arrow=True) 93 | 94 | shap_explainer = SHAPExplainer(background=background) 95 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 96 | 97 | for out_name, df in explanation.as_dataframe().items(): 98 | assert "Mean Background Value" in df 99 | assert "output" in out_name 100 | assert all([x in str(df) for x in "01234"]) 101 | 102 | 103 | def test_shap_as_html(): 104 | np.random.seed(0) 105 | data = pd.DataFrame(np.random.rand(101, 5)) 106 | background = data.iloc[:100].values 107 | to_explain = data.iloc[100:101].values 108 | 109 | model_weights = np.random.rand(5) 110 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 111 | 112 | model = Model(predict_function, disable_arrow=True) 113 | 114 | shap_explainer = SHAPExplainer(background=background) 115 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 116 | assert True 117 | 118 | 119 | def test_shap_numpy(): 120 | np.random.seed(0) 121 | data = np.random.rand(101, 5) 122 | model_weights = np.random.rand(5) 123 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 124 | fnames = ['f{}'.format(x) for x in "abcde"] 125 | onames = ['o{}'.format(x) for x in "12"] 126 | model = Model(predict_function, 127 | feature_names=fnames, 128 | output_names=onames 129 | ) 130 | 131 | shap_explainer = SHAPExplainer(background=data[1:]) 132 | explanation = shap_explainer.explain(inputs=data[0], outputs=model(data[0]), model=model) 133 | 134 | for oname in onames: 135 | assert oname in explanation.as_dataframe().keys() 136 | for fname in fnames: 137 | assert fname in explanation.as_dataframe()[oname]['Feature'].values 138 | 139 | 140 | # deliberately make strange plot to test pre and post-function plot editing 141 | def test_shap_edit_plot(): 142 | np.random.seed(0) 143 | data = pd.DataFrame(np.random.rand(101, 5)) 144 | background = data.iloc[:100].values 145 | to_explain = data.iloc[100:101].values 146 | 147 | model_weights = np.random.rand(5) 148 | predict_function = lambda x: np.stack([np.dot(x, 
model_weights), 2 * np.dot(x, model_weights)], -1) 149 | 150 | model = Model(predict_function, disable_arrow=True) 151 | 152 | shap_explainer = SHAPExplainer(background=background) 153 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 154 | 155 | plt.figure(figsize=(32,2)) 156 | plot(explanation, call_show=False) 157 | plt.ylim(0, 123) 158 | plt.show() 159 | 160 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guide 2 | 3 | **Want to contribute? Great!** 4 | We try to make it easy, and all contributions, even the smaller ones, are more than welcome. 5 | This includes bug reports, fixes, documentation, examples... 6 | But first, read this page (including the small print at the end). 7 | 8 | ## Legal 9 | 10 | All original contributions to TrustyAI-explainability are licensed under the 11 | [ASL - Apache License](https://www.apache.org/licenses/LICENSE-2.0), 12 | version 2.0 or later, or, if another license is specified as governing the file or directory being 13 | modified, such other license. 14 | 15 | ## Issues 16 | 17 | Python TrustyAI uses [GitHub to manage and report issues](https://github.com/trustyai-explainability/trustyai-explainability-python/issues). 18 | 19 | If you believe you found a bug, please indicate a way to reproduce it, what you are seeing and what you would expect to see. 20 | Don't forget to indicate your Python TrustyAI, Java, and Maven version. 21 | 22 | ### Checking an issue is fixed in main 23 | 24 | Sometimes a bug has been fixed in the `main` branch of Python TrustyAI and you want to confirm it is fixed for your own application. 25 | Testing the `main` branch is easy and you can build Python TrustyAI all by yourself. 26 | 27 | ## Creating a Pull Request (PR) 28 | 29 | To contribute, use GitHub Pull Requests, from your **own** fork. 30 | 31 | - PRs should be always related to an open GitHub issue. If there is none, you should create one. 32 | - Try to fix only one issue per PR. 33 | - Make sure to create a new branch. Usually branches are named after the GitHub ticket they are addressing. E.g. for ticket "issue-XYZ An example issue" your branch should be at least prefixed with `FAI-XYZ`. E.g.: 34 | 35 | git checkout -b issue-XYZ 36 | # or 37 | git checkout -b issue-XYZ-my-fix 38 | 39 | - When you submit your PR, make sure to include the ticket ID, and its title; e.g., "issue-XYZ An example issue". 40 | - The description of your PR should describe the code you wrote. The issue that is solved should be at least described properly in the corresponding GitHub ticket. 41 | - If your contribution spans across multiple repositories, use the same branch name (e.g. `issue-XYZ`). 42 | - If your contribution spans across multiple repositories, make sure to list all the related PRs. 43 | 44 | ### Python Coding Guidelines 45 | 46 | PRs will be checked against `black` and `pylint` before passing the CI. 47 | 48 | You can perform these checks locally to guarantee the PR passes these checks. 49 | 50 | ### Requirements for Dependencies 51 | 52 | Any dependency used in the project must fulfill these hard requirements: 53 | 54 | - The dependency must have **an Apache 2.0 compatible license**. 55 | - Good: BSD, MIT, Apache 2.0 56 | - Avoid: EPL, LGPL 57 | - Especially LGPL is a last resort and should be abstracted away or contained behind an SPI. 
58 | - Test scope dependencies pose no problem if they are EPL or LPGL. 59 | - Forbidden: no license, GPL, AGPL, proprietary license, field of use restrictions ("this software shall be used for good, not evil"), ... 60 | - Even test scope dependencies cannot use these licenses. 61 | - To check the ALS compatibility license please visit these links:[Similarity in terms to the Apache License 2.0](http://www.apache.org/legal/resolved.html#category-a)  62 | [How should so-called "Weak Copyleft" Licenses be handled](http://www.apache.org/legal/resolved.html#category-b) 63 | 64 | - The dependency shall be **available in PyPi**. 65 | - Why? 66 | - Build reproducibility. Any repository server we use, must still run in future from now. 67 | - Build speed. More repositories slow down the build. 68 | - Build reliability. A repository server that is temporarily down can break builds. 69 | 70 | - **Do not release the dependency yourself** (by building it from source). 71 | - Why? Because it's not an official release, by the official release guys. 72 | - A release must be 100% reproducible. 73 | - A release must be reliable (sometimes the release person does specific things you might not reproduce). 74 | 75 | - **The sources are publicly available** 76 | - We may need to rebuild the dependency from sources ourselves in future. This may be in the rare case when 77 | the dependency is no longer maintained, but we need to fix a specific CVE there. 78 | - Make sure the dependency's pom.xml contains link to the source repository (`scm` tag). 79 | 80 | - The dependency needs to use **reasonable build system** 81 | - Since we may need to rebuild the dependency from sources, we also need to make sure it is easily buildable. 82 | Maven or Gradle are acceptable as build systems. 83 | 84 | - Only use dependencies with **an active community**. 85 | - Check for activity in the last year through [Open Hub](https://www.openhub.net). 86 | 87 | - Less is more: **less dependencies is better**. Bloat is bad. 88 | - Try to use existing dependencies if the functionality is available in those dependencies 89 | - For example: use Apache Commons Math instead of Colt if Apache Commons Math is already a dependency 90 | 91 | There are currently a few dependencies which violate some of these rules. They should be properly commented with a 92 | warning and explaining why are needed 93 | If you want to add a dependency that violates any of the rules above, get approval from the project leads. 94 | 95 | ### Tests and Documentation 96 | 97 | Don't forget to include tests in your pull requests, and documentation (reference documentation, ...). 98 | Guides and reference documentation should be submitted to the [Python TrustyAI examples repository](https://github.com/trustyai-explainability/trustyai-explainability-python-examples). 99 | If you are contributing a new feature, we strongly advise submitting an example. 100 | 101 | ### Code Reviews and Continuous Integration 102 | 103 | All submissions, including those by project members, need to be reviewed by others before being merged. Our CI, GitHub Actions, should successfully execute your PR, marking the GitHub check as green. 104 | 105 | ## Feature Proposals 106 | 107 | If you would like to see some feature in Python TrustyAI, just open a feature request and tell us what you would like to see. 108 | Alternatively, you propose it during the [TrustyAI community meeting](https://github.com/trustyai-explainability/community). 
109 | 110 | Great feature proposals should include a short **Description** of the feature, the **Motivation** that makes that feature necessary and the **Goals** that are achieved by realizing it. If the feature is deemed worthy, then an Epic will be created. 111 | 112 | ## The small print 113 | 114 | This project is an open source project, please act responsibly, be nice, polite and enjoy! 115 | 116 | -------------------------------------------------------------------------------- /tests/extras/test_tsice.py: -------------------------------------------------------------------------------- 1 | """ Tests for :py:mod:`aix360.algorithms.tsice.TSICEExplainer`. 2 | Original: https://github.com/Trusted-AI/AIX360/blob/master/tests/tsice/test_tsice.py 3 | """ 4 | import unittest 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.ensemble import RandomForestRegressor 9 | from aix360.algorithms.tsutils.tsframe import tsFrame 10 | from aix360.datasets import SunspotDataset 11 | from aix360.algorithms.tsutils.tsperturbers import BlockBootstrapPerturber 12 | from trustyai.explainers.extras.tsice import TSICEExplainer 13 | 14 | 15 | # transform a time series dataset into a supervised learning dataset 16 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 17 | class RandomForestUniVariateForecaster: 18 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 19 | self.n_past = n_past 20 | self.n_future = n_future 21 | self.model = RandomForestRegressor(**RFparams) 22 | 23 | def fit(self, X): 24 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 25 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 26 | self.model = self.model.fit(trainX, trainy) 27 | return self 28 | 29 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 30 | 1 if type(data) is list else data.shape[1] 31 | df = pd.DataFrame(data) 32 | cols = list() 33 | 34 | # input sequence (t-n, ... t-1) 35 | for i in range(n_in, 0, -1): 36 | cols.append(df.shift(i)) 37 | # forecast sequence (t, t+1, ... 
t+n) 38 | for i in range(0, n_out): 39 | cols.append(df.shift(-i)) 40 | # put it all together 41 | agg = pd.concat(cols, axis=1) 42 | # drop rows with NaN values 43 | if dropnan: 44 | agg.dropna(inplace=True) 45 | return agg.values 46 | 47 | def predict(self, X): 48 | row = X[-self.n_past:].flatten() 49 | y_pred = self.model.predict(np.asarray([row])) 50 | return y_pred 51 | 52 | 53 | class TestTSICEExplainer(unittest.TestCase): 54 | def setUp(self): 55 | # load data 56 | df, schema = SunspotDataset().load_data() 57 | ts = tsFrame( 58 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 59 | ) 60 | 61 | (self.ts_train, self.ts_test) = train_test_split( 62 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 63 | ) 64 | 65 | def test_tsice_with_range(self): 66 | # load model 67 | input_length = 24 68 | forecast_horizon = 4 69 | forecaster = RandomForestUniVariateForecaster( 70 | n_past=input_length, n_future=forecast_horizon 71 | ) 72 | 73 | forecaster.fit(self.ts_train.iloc[-200:]) 74 | 75 | # initialize/fit explainer 76 | observation_length = 12 77 | explainer = TSICEExplainer( 78 | model=forecaster.predict, 79 | explanation_window_start=10, 80 | explanation_window_length=observation_length, 81 | features_to_analyze=[ 82 | "mean", "std" # analyze mean metric from recent time series of lengh 83 | ], 84 | perturbers=[ 85 | BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2), 86 | ], 87 | input_length=input_length, 88 | forecast_lookahead=forecast_horizon, 89 | n_perturbations=30, 90 | ) 91 | 92 | # compute explanations 93 | explanation = explainer.explain( 94 | inputs=self.ts_test.iloc[:80], 95 | ) 96 | 97 | # validate explanation structure 98 | self.assertIn("data_x", explanation.explanation) 99 | self.assertIn("feature_names", explanation.explanation) 100 | self.assertIn("feature_values", explanation.explanation) 101 | self.assertIn("signed_impact", explanation.explanation) 102 | self.assertIn("total_impact", explanation.explanation) 103 | self.assertIn("current_forecast", explanation.explanation) 104 | self.assertIn("current_feature_values", explanation.explanation) 105 | self.assertIn("perturbations", explanation.explanation) 106 | self.assertIn("forecasts_on_perturbations", explanation.explanation) 107 | 108 | def test_tsice_with_latest(self): 109 | # load model 110 | input_length = 24 111 | forecast_horizon = 4 112 | forecaster = RandomForestUniVariateForecaster( 113 | n_past=input_length, n_future=forecast_horizon 114 | ) 115 | 116 | forecaster.fit(self.ts_train.iloc[-200:]) 117 | 118 | # initialize/fit explainer 119 | observation_length = 12 120 | explainer = TSICEExplainer( 121 | model=forecaster.predict, 122 | explanation_window_start=None, 123 | explanation_window_length=observation_length, 124 | features_to_analyze=[ 125 | "mean", # analyze mean metric from recent time series of lengh 126 | "median", # analyze median metric from recent time series of lengh 127 | "std", # analyze std metric from recent time series of lengh 128 | "max_variation", # analyze max_variation metric from recent time series of lengh 129 | "min", 130 | "max", 131 | "range", 132 | "intercept", 133 | "trend", 134 | "rsquared", 135 | ], 136 | perturbers=[ 137 | BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2), 138 | dict( 139 | type="frequency", 140 | window_length=5, 141 | truncate_frequencies=5, 142 | block_length=4, 143 | ), 144 | dict(type="moving-average", window_length=5, lag=5, block_length=4), 145 | dict(type="impute", 
block_length=4), 146 | dict(type="shift", block_length=4), 147 | ], 148 | input_length=input_length, 149 | forecast_lookahead=forecast_horizon, 150 | n_perturbations=20, 151 | ) 152 | 153 | # compute explanations 154 | explanation = explainer.explain( 155 | inputs=self.ts_test.iloc[:80], 156 | ) 157 | 158 | # validate explanation structure 159 | self.assertIn("data_x", explanation.explanation) 160 | self.assertIn("feature_names", explanation.explanation) 161 | self.assertIn("feature_values", explanation.explanation) 162 | self.assertIn("signed_impact", explanation.explanation) 163 | self.assertIn("total_impact", explanation.explanation) 164 | self.assertIn("current_forecast", explanation.explanation) 165 | self.assertIn("current_feature_values", explanation.explanation) 166 | self.assertIn("perturbations", explanation.explanation) 167 | self.assertIn("forecasts_on_perturbations", explanation.explanation) 168 | 169 | 170 | if __name__ == "__main__": 171 | unittest.main() 172 | -------------------------------------------------------------------------------- /tests/general/test_limeexplainer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code 2 | """LIME explainer test suite""" 3 | 4 | from common import * 5 | 6 | import pytest 7 | 8 | from trustyai.explainers import LimeExplainer 9 | from trustyai.utils import TestModels 10 | from trustyai.model import feature, Model, simple_prediction 11 | from trustyai.metrics import ExplainabilityMetrics 12 | from trustyai.visualizations import plot 13 | 14 | from org.kie.trustyai.explainability.local import ( 15 | LocalExplanationException, 16 | ) 17 | 18 | 19 | def mock_features(n_features: int): 20 | return [mock_feature(i, f"f-num{i}") for i in range(n_features)] 21 | 22 | 23 | def test_empty_prediction(): 24 | """Check if the explanation returned is not null""" 25 | lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1) 26 | inputs = [] 27 | model = TestModels.getSumSkipModel(0) 28 | outputs = model.predict([inputs])[0].outputs 29 | with pytest.raises(LocalExplanationException): 30 | lime_explainer.explain(inputs=inputs, outputs=outputs, model=model) 31 | 32 | 33 | def test_non_empty_input(): 34 | """Test for non-empty input""" 35 | lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1) 36 | features = [feature(name=f"f-num{i}", value=i, dtype="number") for i in range(4)] 37 | 38 | model = TestModels.getSumSkipModel(0) 39 | outputs = model.predict([features])[0].outputs 40 | saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model) 41 | assert saliency_map is not None 42 | 43 | 44 | def test_sparse_balance(): # pylint: disable=too-many-locals 45 | """Test sparse balance""" 46 | for n_features in range(1, 4): 47 | lime_explainer_no_penalty = LimeExplainer(samples=100, penalise_sparse_balance=False) 48 | 49 | features = mock_features(n_features) 50 | 51 | model = TestModels.getSumSkipModel(0) 52 | outputs = model.predict([features])[0].outputs 53 | 54 | saliency_map_no_penalty = lime_explainer_no_penalty.explain( 55 | inputs=features, outputs=outputs, model=model 56 | ).saliency_map() 57 | 58 | assert saliency_map_no_penalty is not None 59 | 60 | decision_name = "sum-but0" 61 | saliency_no_penalty = saliency_map_no_penalty.get(decision_name) 62 | 63 | lime_explainer = LimeExplainer(samples=100, penalise_sparse_balance=True) 64 | 65 | saliency_map = 
lime_explainer.explain(inputs=features, outputs=outputs, model=model).saliency_map() 66 | assert saliency_map is not None 67 | 68 | saliency = saliency_map.get(decision_name) 69 | 70 | for i in range(len(features)): 71 | score = saliency.getPerFeatureImportance().get(i).getScore() 72 | score_no_penalty = ( 73 | saliency_no_penalty.getPerFeatureImportance().get(i).getScore() 74 | ) 75 | assert abs(score) <= abs(score_no_penalty) 76 | 77 | 78 | def test_normalized_weights(): 79 | """Test normalized weights""" 80 | lime_explainer = LimeExplainer(normalise_weights=True, perturbations=2, samples=10) 81 | n_features = 4 82 | features = mock_features(n_features) 83 | model = TestModels.getSumSkipModel(0) 84 | outputs = model.predict([features])[0].outputs 85 | 86 | saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model).saliency_map() 87 | assert saliency_map is not None 88 | 89 | decision_name = "sum-but0" 90 | saliency = saliency_map.get(decision_name) 91 | per_feature_importance = saliency.getPerFeatureImportance() 92 | for feature_importance in per_feature_importance: 93 | assert -3.0 < feature_importance.getScore() < 3.0 94 | 95 | 96 | def lime_plots(block): 97 | """Test normalized weights""" 98 | lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10) 99 | n_features = 15 100 | features = mock_features(n_features) 101 | model = TestModels.getSumSkipModel(0) 102 | outputs = model.predict([features])[0].outputs 103 | 104 | explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model) 105 | plot(explanation, block=block) 106 | plot(explanation, block=block, render_bokeh=True) 107 | plot(explanation, block=block, output_name="sum-but0") 108 | plot(explanation, block=block, output_name="sum-but0", render_bokeh=True) 109 | 110 | 111 | @pytest.mark.block_plots 112 | def test_lime_plots_blocking(): 113 | lime_plots(True) 114 | 115 | 116 | def test_lime_plots(): 117 | lime_plots(False) 118 | 119 | 120 | def test_lime_v2(): 121 | np.random.seed(0) 122 | data = pd.DataFrame(np.random.rand(1, 5)).values 123 | 124 | model_weights = np.random.rand(5) 125 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 126 | model = Model(predict_function) 127 | 128 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 129 | explanation = explainer.explain(inputs=data, outputs=model(data), model=model) 130 | 131 | for score in explanation.as_dataframe()["output-0"]['Saliency']: 132 | assert score != 0 133 | 134 | for out_name, df in explanation.as_dataframe().items(): 135 | assert "Feature" in df 136 | assert "output" in out_name 137 | assert all([x in str(df) for x in "01234"]) 138 | 139 | 140 | def test_impact_score(): 141 | np.random.seed(0) 142 | data = pd.DataFrame(np.random.rand(1, 5)) 143 | model_weights = np.random.rand(5) 144 | predict_function = lambda x: np.dot(x.values, model_weights) 145 | model = Model(predict_function, dataframe_input=True) 146 | output = model(data) 147 | pred = simple_prediction(data, output) 148 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 149 | explanation = explainer.explain(inputs=data, outputs=output, model=model) 150 | saliency = list(explanation.saliency_map().values())[0] 151 | top_features_t = saliency.getTopFeatures(2) 152 | impact = ExplainabilityMetrics.impactScore(model, pred, top_features_t) 153 | assert impact > 0 154 | return impact 155 | 156 | 157 | def 
test_lime_as_html(): 158 | np.random.seed(0) 159 | data = np.random.rand(1, 5) 160 | 161 | model_weights = np.random.rand(5) 162 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 163 | 164 | model = Model(predict_function, disable_arrow=True) 165 | 166 | explainer = LimeExplainer() 167 | explainer.explain(inputs=data, outputs=model(data), model=model) 168 | assert True 169 | 170 | explanation = explainer.explain(inputs=data, outputs=model(data), model=model) 171 | for score in explanation.as_dataframe()["output-0"]['Saliency']: 172 | assert score != 0 173 | 174 | 175 | def test_lime_numpy(): 176 | np.random.seed(0) 177 | data = np.random.rand(101, 5) 178 | model_weights = np.random.rand(5) 179 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 180 | fnames = ['f{}'.format(x) for x in "abcde"] 181 | onames = ['o{}'.format(x) for x in "12"] 182 | model = Model(predict_function, 183 | feature_names=fnames, 184 | output_names=onames 185 | ) 186 | 187 | explainer = LimeExplainer() 188 | explanation = explainer.explain(inputs=data[0], outputs=model(data[0]), model=model) 189 | 190 | for oname in onames: 191 | assert oname in explanation.as_dataframe().keys() 192 | for fname in fnames: 193 | assert fname in explanation.as_dataframe()[oname]['Feature'].values 194 | 195 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """Common methods and models for tests""" 3 | import os 4 | import sys 5 | import pytest 6 | import time 7 | import numpy as np 8 | 9 | from trustyai.explainers import LimeExplainer, SHAPExplainer 10 | from trustyai.model import feature, PredictionInput 11 | from trustyai.utils import TestModels 12 | from trustyai.metrics.saliency import mean_impact_score, classification_fidelity, local_saliency_f1 13 | 14 | from org.kie.trustyai.explainability.model import ( 15 | PredictionInputsDataDistribution, 16 | ) 17 | 18 | myPath = os.path.dirname(os.path.abspath(__file__)) 19 | sys.path.insert(0, myPath + "/../general/") 20 | 21 | import test_counterfactualexplainer as tcf 22 | 23 | @pytest.mark.benchmark( 24 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 25 | ) 26 | def test_counterfactual_match(benchmark): 27 | """Counterfactual match""" 28 | benchmark(tcf.test_counterfactual_match) 29 | 30 | 31 | @pytest.mark.benchmark( 32 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 33 | ) 34 | def test_non_empty_input(benchmark): 35 | """Counterfactual non-empty input""" 36 | benchmark(tcf.test_non_empty_input) 37 | 38 | 39 | @pytest.mark.benchmark( 40 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 41 | ) 42 | def test_counterfactual_match_python_model(benchmark): 43 | """Counterfactual match (Python model)""" 44 | benchmark(tcf.test_counterfactual_match_python_model) 45 | 46 | 47 | @pytest.mark.benchmark( 48 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 49 | ) 50 | def test_sumskip_lime_impact_score_at_2(benchmark): 51 | no_of_features = 10 52 | np.random.seed(0) 53 | explainer = LimeExplainer() 54 | model = TestModels.getSumSkipModel(0) 55 | data = [] 56 | for i in range(100): 57 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for 
i in range(no_of_features)]) 58 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 59 | benchmark(mean_impact_score, explainer, model, data) 60 | 61 | 62 | @pytest.mark.benchmark( 63 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 64 | ) 65 | def test_sumskip_shap_impact_score_at_2(benchmark): 66 | no_of_features = 10 67 | np.random.seed(0) 68 | background = [] 69 | for i in range(10): 70 | background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)])) 71 | explainer = SHAPExplainer(background, samples=10000) 72 | model = TestModels.getSumSkipModel(0) 73 | data = [] 74 | for i in range(100): 75 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)]) 76 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 77 | benchmark(mean_impact_score, explainer, model, data) 78 | 79 | 80 | @pytest.mark.benchmark( 81 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 82 | ) 83 | def test_sumthreshold_lime_impact_score_at_2(benchmark): 84 | no_of_features = 10 85 | np.random.seed(0) 86 | explainer = LimeExplainer() 87 | center = 100.0 88 | epsilon = 10.0 89 | model = TestModels.getSumThresholdModel(center, epsilon) 90 | data = [] 91 | for i in range(100): 92 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 93 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 94 | benchmark(mean_impact_score, explainer, model, data) 95 | 96 | 97 | @pytest.mark.benchmark( 98 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 99 | ) 100 | def test_sumthreshold_shap_impact_score_at_2(benchmark): 101 | no_of_features = 10 102 | np.random.seed(0) 103 | background = [] 104 | for i in range(100): 105 | background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 106 | explainer = SHAPExplainer(background, samples=10000) 107 | center = 100.0 108 | epsilon = 10.0 109 | model = TestModels.getSumThresholdModel(center, epsilon) 110 | data = [] 111 | for i in range(100): 112 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 113 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 114 | benchmark(mean_impact_score, explainer, model, data) 115 | 116 | 117 | @pytest.mark.benchmark( 118 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 119 | ) 120 | def test_lime_fidelity(benchmark): 121 | no_of_features = 10 122 | np.random.seed(0) 123 | explainer = LimeExplainer() 124 | model = TestModels.getEvenSumModel(0) 125 | data = [] 126 | for i in range(100): 127 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 128 | benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data) 129 | benchmark(classification_fidelity, explainer, model, data) 130 | 131 | 132 | @pytest.mark.benchmark( 133 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 134 | ) 135 | def test_shap_fidelity(benchmark): 136 | no_of_features = 10 137 | np.random.seed(0) 138 | background = [] 139 | for i in range(10): 140 | background.append(PredictionInput( 141 | 
[feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in 142 | range(no_of_features)])) 143 | explainer = SHAPExplainer(background, samples=10000) 144 | model = TestModels.getEvenSumModel(0) 145 | data = [] 146 | for i in range(100): 147 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in 148 | range(no_of_features)]) 149 | benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data) 150 | benchmark(classification_fidelity, explainer, model, data) 151 | 152 | 153 | @pytest.mark.benchmark( 154 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 155 | ) 156 | def test_lime_local_saliency_f1(benchmark): 157 | no_of_features = 10 158 | np.random.seed(0) 159 | explainer = LimeExplainer() 160 | model = TestModels.getEvenSumModel(0) 161 | output_name = "sum-even-but0" 162 | data = [] 163 | for i in range(100): 164 | data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 165 | distribution = PredictionInputsDataDistribution(data) 166 | benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10) 167 | benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 10) 168 | 169 | 170 | @pytest.mark.benchmark( 171 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 172 | ) 173 | def test_shap_local_saliency_f1(benchmark): 174 | no_of_features = 10 175 | np.random.seed(0) 176 | background = [] 177 | for i in range(10): 178 | background.append(PredictionInput( 179 | [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in 180 | range(no_of_features)])) 181 | explainer = SHAPExplainer(background, samples=10000) 182 | model = TestModels.getEvenSumModel(0) 183 | output_name = "sum-even-but0" 184 | data = [] 185 | for i in range(100): 186 | data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 187 | distribution = PredictionInputsDataDistribution(data) 188 | benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10) 189 | benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 10) -------------------------------------------------------------------------------- /src/trustyai/visualizations/shap.py: -------------------------------------------------------------------------------- 1 | """Visualizations.shap module""" 2 | 3 | # pylint: disable = import-error, consider-using-f-string, too-few-public-methods, missing-final-newline 4 | import matplotlib.pyplot as plt 5 | import matplotlib as mpl 6 | from bokeh.models import ColumnDataSource, HoverTool 7 | from bokeh.plotting import figure 8 | import pandas as pd 9 | import numpy as np 10 | 11 | from trustyai.utils._visualisation import ( 12 | DEFAULT_STYLE as ds, 13 | DEFAULT_RC_PARAMS as drcp, 14 | bold_red_html, 15 | bold_green_html, 16 | output_html, 17 | feature_html, 18 | ) 19 | from trustyai.visualizations.visualization_results import VisualizationResults 20 | 21 | 22 | class SHAPViz(VisualizationResults): 23 | """Visualizes SHAP results.""" 24 | 25 | def _matplotlib_plot( 26 | self, explanations, output_name=None, block=True, call_show=True 27 | ) -> None: 28 | """Visualize the SHAP explanation of each output as a set of candlestick plots, 29 | one per output.""" 30 | with 
mpl.rc_context(drcp): 31 | shap_values = [ 32 | pfi.getScore() 33 | for pfi in explanations.saliency_map()[ 34 | output_name 35 | ].getPerFeatureImportance()[:-1] 36 | ] 37 | feature_names = [ 38 | str(pfi.getFeature().getName()) 39 | for pfi in explanations.saliency_map()[ 40 | output_name 41 | ].getPerFeatureImportance()[:-1] 42 | ] 43 | fnull = explanations.get_fnull()[output_name] 44 | prediction = fnull + sum(shap_values) 45 | 46 | if call_show: 47 | plt.figure() 48 | pos = fnull 49 | for j, shap_value in enumerate(shap_values): 50 | color = ( 51 | ds["negative_primary_colour"] 52 | if shap_value < 0 53 | else ds["positive_primary_colour"] 54 | ) 55 | width = 0.9 56 | if j > 0: 57 | plt.plot([j - 0.5, j + width / 2 * 0.99], [pos, pos], color=color) 58 | plt.bar(j, height=shap_value, bottom=pos, color=color, width=width) 59 | pos += shap_values[j] 60 | 61 | if j != len(shap_values) - 1: 62 | plt.plot([j - width / 2 * 0.99, j + 0.5], [pos, pos], color=color) 63 | 64 | plt.axhline( 65 | fnull, 66 | color="#444444", 67 | linestyle="--", 68 | zorder=0, 69 | label="Background Value", 70 | ) 71 | plt.axhline(prediction, color="#444444", zorder=0, label="Prediction") 72 | plt.legend() 73 | 74 | ticksize = np.diff(plt.gca().get_yticks())[0] 75 | plt.ylim( 76 | plt.gca().get_ylim()[0] - ticksize / 2, 77 | plt.gca().get_ylim()[1] + ticksize / 2, 78 | ) 79 | plt.xticks(np.arange(len(feature_names)), feature_names) 80 | plt.ylabel(explanations.saliency_map()[output_name].getOutput().getName()) 81 | plt.xlabel("Feature SHAP Value") 82 | plt.title(f"SHAP: Feature Contributions to {output_name}") 83 | if call_show: 84 | plt.show(block=block) 85 | 86 | def _get_bokeh_plot(self, explanations, output_name): 87 | fnull = explanations.get_fnull()[output_name] 88 | 89 | # create dataframe of plot values 90 | data_source = pd.DataFrame( 91 | [ 92 | { 93 | "feature": str(pfi.getFeature().getName()), 94 | "saliency": pfi.getScore(), 95 | } 96 | for pfi in explanations.saliency_map()[ 97 | output_name 98 | ].getPerFeatureImportance()[:-1] 99 | ] 100 | ) 101 | prediction = fnull + data_source["saliency"].sum() 102 | 103 | data_source["color"] = data_source["saliency"].apply( 104 | lambda x: ( 105 | ds["positive_primary_colour"] 106 | if x >= 0 107 | else ds["negative_primary_colour"] 108 | ) 109 | ) 110 | data_source["color_faded"] = data_source["saliency"].apply( 111 | lambda x: ( 112 | ds["positive_primary_colour_faded"] 113 | if x >= 0 114 | else ds["negative_primary_colour_faded"] 115 | ) 116 | ) 117 | data_source["index"] = data_source.index 118 | data_source["saliency_text"] = data_source["saliency"].apply( 119 | lambda x: (bold_red_html if x <= 0 else bold_green_html)("{:.2f}".format(x)) 120 | ) 121 | data_source["bottom"] = pd.Series( 122 | [fnull] + data_source["saliency"].iloc[0:-1].tolist() 123 | ).cumsum() 124 | data_source["top"] = data_source["bottom"] + data_source["saliency"] 125 | 126 | # create hovertools 127 | htool_fnull = HoverTool( 128 | name="fnull", 129 | tooltips=("
<h3>SHAP</h3>
Baseline {}: {}").format( 130 | output_name, output_html("{:.2f}".format(fnull)) 131 | ), 132 | line_policy="interp", 133 | ) 134 | htool_pred = HoverTool( 135 | name="pred", 136 | tooltips=("
<h3>SHAP</h3>
Predicted {}: {}").format( 137 | output_name, output_html("{:.2f}".format(prediction)) 138 | ), 139 | line_policy="interp", 140 | ) 141 | htool_bars = HoverTool( 142 | name="bars", 143 | tooltips="
<h3>SHAP</h3>
{} contributions to {}: @saliency_text".format( 144 | feature_html("@feature"), output_html(output_name) 145 | ), 146 | ) 147 | 148 | # create plot 149 | bokeh_plot = figure( 150 | sizing_mode="stretch_both", 151 | title="SHAP Feature Contributions", 152 | x_range=data_source["feature"], 153 | tools=[htool_pred, htool_fnull, htool_bars], 154 | ) 155 | 156 | # add fnull and background lines 157 | line_data_source = ColumnDataSource( 158 | pd.DataFrame( 159 | [ 160 | {"x": 0, "pred": prediction}, 161 | {"x": len(data_source), "pred": prediction}, 162 | ] 163 | ) 164 | ) 165 | fnull_data_source = ColumnDataSource( 166 | pd.DataFrame( 167 | [{"x": 0, "fnull": fnull}, {"x": len(data_source), "fnull": fnull}] 168 | ) 169 | ) 170 | 171 | bokeh_plot.line( 172 | x="x", 173 | y="fnull", 174 | line_color="#999", 175 | hover_line_color="#333", 176 | line_width=2, 177 | hover_line_width=4, 178 | line_dash="dashed", 179 | name="fnull", 180 | source=fnull_data_source, 181 | ) 182 | bokeh_plot.line( 183 | x="x", 184 | y="pred", 185 | line_color="#999", 186 | hover_line_color="#333", 187 | line_width=2, 188 | hover_line_width=4, 189 | name="pred", 190 | source=line_data_source, 191 | ) 192 | 193 | # create candlestick plot lines 194 | bokeh_plot.line( 195 | x=[0.5, 1], 196 | y=data_source.iloc[0]["top"], 197 | color=data_source.iloc[0]["color"], 198 | ) 199 | for i in range(1, len(data_source)): 200 | # bar left line 201 | bokeh_plot.line( 202 | x=[i, i + 0.5], 203 | y=data_source.iloc[i]["bottom"], 204 | color=data_source.iloc[i]["color"], 205 | ) 206 | # bar right line 207 | if i != len(data_source) - 1: 208 | bokeh_plot.line( 209 | x=[i + 0.5, i + 1], 210 | y=data_source.iloc[i]["top"], 211 | color=data_source.iloc[i]["color"], 212 | ) 213 | 214 | # create candles 215 | bokeh_plot.vbar( 216 | x="feature", 217 | bottom="bottom", 218 | top="top", 219 | hover_color="color", 220 | color="color_faded", 221 | width=0.75, 222 | name="bars", 223 | source=data_source, 224 | ) 225 | bokeh_plot.yaxis.axis_label = str(output_name) 226 | return bokeh_plot 227 | 228 | def _get_bokeh_plot_dict(self, explanations): 229 | return { 230 | decision: self._get_bokeh_plot(explanations, decision) 231 | for decision in explanations.saliency_map().keys() 232 | } 233 | -------------------------------------------------------------------------------- /src/trustyai/metrics/fairness/group.py: -------------------------------------------------------------------------------- 1 | """Group fairness metrics""" 2 | 3 | # pylint: disable = import-error 4 | from typing import List, Optional, Any, Union 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from jpype import JInt 9 | from org.kie.trustyai.metrics.fairness.group import ( 10 | DisparateImpactRatio, 11 | GroupStatisticalParityDifference, 12 | GroupAverageOddsDifference, 13 | GroupAveragePredictiveValueDifference, 14 | ) 15 | 16 | from trustyai.model import Value, PredictionProvider, Model 17 | from trustyai.utils.data_conversions import ( 18 | OneOutputUnionType, 19 | one_output_convert, 20 | to_trusty_dataframe, 21 | ) 22 | 23 | ColumSelector = Union[List[int], List[str]] 24 | 25 | 26 | def _column_selector_to_index(columns: ColumSelector, dataframe: pd.DataFrame): 27 | """Returns a list of input and output indices, given an index size and output indices""" 28 | if len(columns) == 0: 29 | raise ValueError("Must specify at least one column") 30 | 31 | if isinstance(columns[0], str): # passing column 32 | columns = dataframe.columns.get_indexer(columns) 33 | indices = [JInt(c) 
for c in columns] # Java casting 34 | return indices 35 | 36 | 37 | def statistical_parity_difference( 38 | privileged: Union[pd.DataFrame, np.ndarray], 39 | unprivileged: Union[pd.DataFrame, np.ndarray], 40 | favorable: OneOutputUnionType, 41 | outputs: Optional[List[int]] = None, 42 | feature_names: Optional[List[str]] = None, 43 | ) -> float: 44 | """Calculate Statistical Parity Difference between privileged and unprivileged dataframes""" 45 | favorable_prediction_object = one_output_convert(favorable) 46 | return GroupStatisticalParityDifference.calculate( 47 | to_trusty_dataframe( 48 | data=privileged, outputs=outputs, feature_names=feature_names 49 | ), 50 | to_trusty_dataframe( 51 | data=unprivileged, outputs=outputs, feature_names=feature_names 52 | ), 53 | favorable_prediction_object.outputs, 54 | ) 55 | 56 | 57 | # pylint: disable = line-too-long, too-many-arguments 58 | def statistical_parity_difference_model( 59 | samples: Union[pd.DataFrame, np.ndarray], 60 | model: Union[PredictionProvider, Model], 61 | privilege_columns: ColumSelector, 62 | privilege_values: List[Any], 63 | favorable: OneOutputUnionType, 64 | feature_names: Optional[List[str]] = None, 65 | ) -> float: 66 | """Calculate Statistical Parity Difference using a samples dataframe and a model""" 67 | favorable_prediction_object = one_output_convert(favorable) 68 | _privilege_values = [Value(v) for v in privilege_values] 69 | _jsamples = to_trusty_dataframe( 70 | data=samples, no_outputs=True, feature_names=feature_names 71 | ) 72 | return GroupStatisticalParityDifference.calculate( 73 | _jsamples, 74 | model, 75 | _column_selector_to_index(privilege_columns, samples), 76 | _privilege_values, 77 | favorable_prediction_object.outputs, 78 | ) 79 | 80 | 81 | def disparate_impact_ratio( 82 | privileged: Union[pd.DataFrame, np.ndarray], 83 | unprivileged: Union[pd.DataFrame, np.ndarray], 84 | favorable: OneOutputUnionType, 85 | outputs: Optional[List[int]] = None, 86 | feature_names: Optional[List[str]] = None, 87 | ) -> float: 88 | """Calculate Disparate Impact Ration between privileged and unprivileged dataframes""" 89 | favorable_prediction_object = one_output_convert(favorable) 90 | return DisparateImpactRatio.calculate( 91 | to_trusty_dataframe( 92 | data=privileged, outputs=outputs, feature_names=feature_names 93 | ), 94 | to_trusty_dataframe( 95 | data=unprivileged, outputs=outputs, feature_names=feature_names 96 | ), 97 | favorable_prediction_object.outputs, 98 | ) 99 | 100 | 101 | # pylint: disable = line-too-long 102 | def disparate_impact_ratio_model( 103 | samples: Union[pd.DataFrame, np.ndarray], 104 | model: Union[PredictionProvider, Model], 105 | privilege_columns: ColumSelector, 106 | privilege_values: List[Any], 107 | favorable: OneOutputUnionType, 108 | feature_names: Optional[List[str]] = None, 109 | ) -> float: 110 | """Calculate Disparate Impact Ration using a samples dataframe and a model""" 111 | favorable_prediction_object = one_output_convert(favorable) 112 | _privilege_values = [Value(v) for v in privilege_values] 113 | _jsamples = to_trusty_dataframe( 114 | data=samples, no_outputs=True, feature_names=feature_names 115 | ) 116 | return DisparateImpactRatio.calculate( 117 | _jsamples, 118 | model, 119 | _column_selector_to_index(privilege_columns, samples), 120 | _privilege_values, 121 | favorable_prediction_object.outputs, 122 | ) 123 | 124 | 125 | # pylint: disable = too-many-arguments 126 | def average_odds_difference( 127 | test: Union[pd.DataFrame, np.ndarray], 128 | truth: 
Union[pd.DataFrame, np.ndarray], 129 | privilege_columns: ColumSelector, 130 | privilege_values: OneOutputUnionType, 131 | positive_class: List[Any], 132 | outputs: Optional[List[int]] = None, 133 | feature_names: Optional[List[str]] = None, 134 | ) -> float: 135 | """Calculate Average Odds between two dataframes""" 136 | if test.shape != truth.shape: 137 | raise ValueError( 138 | f"Dataframes have different shapes ({test.shape} and {truth.shape})" 139 | ) 140 | _privilege_values = [Value(v) for v in privilege_values] 141 | _positive_class = [Value(v) for v in positive_class] 142 | # determine privileged columns 143 | _privilege_columns = _column_selector_to_index(privilege_columns, test) 144 | return GroupAverageOddsDifference.calculate( 145 | to_trusty_dataframe(data=test, outputs=outputs, feature_names=feature_names), 146 | to_trusty_dataframe(data=truth, outputs=outputs, feature_names=feature_names), 147 | _privilege_columns, 148 | _privilege_values, 149 | _positive_class, 150 | ) 151 | 152 | 153 | def average_odds_difference_model( 154 | samples: Union[pd.DataFrame, np.ndarray], 155 | model: Union[PredictionProvider, Model], 156 | privilege_columns: ColumSelector, 157 | privilege_values: List[Any], 158 | positive_class: List[Any], 159 | feature_names: Optional[List[str]] = None, 160 | ) -> float: 161 | """Calculate Average Odds for a sample dataframe using the provided model""" 162 | _jsamples = to_trusty_dataframe( 163 | data=samples, no_outputs=True, feature_names=feature_names 164 | ) 165 | _privilege_values = [Value(v) for v in privilege_values] 166 | _positive_class = [Value(v) for v in positive_class] 167 | # determine privileged columns 168 | _privilege_columns = _column_selector_to_index(privilege_columns, samples) 169 | return GroupAverageOddsDifference.calculate( 170 | _jsamples, model, _privilege_columns, _privilege_values, _positive_class 171 | ) 172 | 173 | 174 | def average_predictive_value_difference( 175 | test: Union[pd.DataFrame, np.ndarray], 176 | truth: Union[pd.DataFrame, np.ndarray], 177 | privilege_columns: ColumSelector, 178 | privilege_values: List[Any], 179 | positive_class: List[Any], 180 | outputs: Optional[List[int]] = None, 181 | feature_names: Optional[List[str]] = None, 182 | ) -> float: 183 | """Calculate Average Predictive Value Difference between two dataframes""" 184 | if test.shape != truth.shape: 185 | raise ValueError( 186 | f"Dataframes have different shapes ({test.shape} and {truth.shape})" 187 | ) 188 | _privilege_values = [Value(v) for v in privilege_values] 189 | _positive_class = [Value(v) for v in positive_class] 190 | _privilege_columns = _column_selector_to_index(privilege_columns, test) 191 | return GroupAveragePredictiveValueDifference.calculate( 192 | to_trusty_dataframe(data=test, outputs=outputs, feature_names=feature_names), 193 | to_trusty_dataframe(data=truth, outputs=outputs, feature_names=feature_names), 194 | _privilege_columns, 195 | _privilege_values, 196 | _positive_class, 197 | ) 198 | 199 | 200 | # pylint: disable = line-too-long 201 | def average_predictive_value_difference_model( 202 | samples: Union[pd.DataFrame, np.ndarray], 203 | model: Union[PredictionProvider, Model], 204 | privilege_columns: ColumSelector, 205 | privilege_values: List[Any], 206 | positive_class: List[Any], 207 | ) -> float: 208 | """Calculate Average Predictive Value Difference for a sample dataframe using the provided model""" 209 | _jsamples = to_trusty_dataframe(samples, no_outputs=True) 210 | _privilege_values = [Value(v) for v in 
privilege_values] 211 | _positive_class = [Value(v) for v in positive_class] 212 | # determine privileged columns 213 | _privilege_columns = _column_selector_to_index(privilege_columns, samples) 214 | return GroupAveragePredictiveValueDifference.calculate( 215 | _jsamples, model, _privilege_columns, _privilege_values, _positive_class 216 | ) 217 | -------------------------------------------------------------------------------- /src/trustyai/utils/extras/metrics_service.py: -------------------------------------------------------------------------------- 1 | """Python client for TrustyAI metrics""" 2 | 3 | from typing import List 4 | import json 5 | import datetime as dt 6 | import pandas as pd 7 | import requests 8 | import matplotlib.pyplot as plt 9 | 10 | from trustyai.utils.api.api import TrustyAIApi 11 | 12 | 13 | def json_to_df(data_path: str, batch_list: List[int]) -> pd.DataFrame: 14 | """ 15 | Converts batched data in json files to a single pandas DataFrame 16 | """ 17 | final_df = pd.DataFrame() 18 | for batch in batch_list: 19 | file = data_path + f"{batch}.json" 20 | with open(file, encoding="utf8") as train_file: 21 | batch_data = json.load(train_file)["inputs"][0] 22 | batch_df = pd.DataFrame.from_dict(batch_data["data"]).T 23 | final_df = pd.concat([final_df, batch_df]) 24 | return final_df 25 | 26 | 27 | def df_to_json(final_df: pd.DataFrame, name: str, json_file: str) -> None: 28 | """ 29 | Converts pandas DataFrame to json file 30 | """ 31 | inputs = [ 32 | { 33 | "name": name, 34 | "shape": list(final_df.shape), 35 | "datatype": "FP64", 36 | "data": final_df.values.tolist(), 37 | } 38 | ] 39 | data_dict = {"inputs": inputs} 40 | with open(json_file, "w", encoding="utf8") as outfile: 41 | json.dump(data_dict, outfile) 42 | 43 | 44 | class TrustyAIMetricsService: 45 | """ 46 | Executes and returns queries from TrustyAI service on ODH 47 | """ 48 | 49 | def __init__(self, token: str, namespace: str, verify=True): 50 | """ 51 | :param token: OpenShift login token 52 | :param namespace: model namespace 53 | :param verify: enable SSL verification for requests 54 | """ 55 | self.token = token 56 | self.namespace = namespace 57 | self.trusty_url = TrustyAIApi().get_service_route( 58 | name="trustyai-service", namespace=self.namespace 59 | ) 60 | self.thanos_url = TrustyAIApi().get_service_route( 61 | name="thanos-querier", namespace="openshift-monitoring" 62 | ) 63 | self.headers = { 64 | "Authorization": "Bearer " + token, 65 | "Content-Type": "application/json", 66 | } 67 | self.verify = verify 68 | 69 | def upload_payload_data(self, json_file: str, timeout=5) -> None: 70 | """ 71 | Uploads data to TrustyAI service 72 | """ 73 | with open(json_file, "r", encoding="utf8") as file: 74 | response = requests.post( 75 | f"{self.trusty_url}/data/upload", 76 | data=file, 77 | headers=self.headers, 78 | verify=self.verify, 79 | timeout=timeout, 80 | ) 81 | if response.status_code == 200: 82 | print("Data sucessfully uploaded to TrustyAI service") 83 | else: 84 | print(f"Error {response.status_code}: {response.reason}") 85 | 86 | def get_model_metadata(self, timeout=5): 87 | """ 88 | Retrieves model data from TrustyAI 89 | """ 90 | response = requests.get( 91 | f"{self.trusty_url}/info", 92 | headers=self.headers, 93 | verify=self.verify, 94 | timeout=timeout, 95 | ) 96 | if response.status_code == 200: 97 | model_metadata = json.loads(response.text) 98 | return model_metadata 99 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 100 | 101 | def 
label_data_fields(self, payload: str, timeout=5): 102 | """ 103 | Assigns feature names to model input data 104 | """ 105 | 106 | def print_name_mapping(self): 107 | response = requests.get( 108 | f"{self.trusty_url}/info", 109 | headers=self.headers, 110 | verify=self.verify, 111 | timeout=timeout, 112 | ) 113 | name_mapping = json.loads(response.text)[0] 114 | for key, val in name_mapping["data"]["inputSchema"]["nameMapping"].items(): 115 | print(f"{key} -> {val}") 116 | 117 | response = requests.get( 118 | f"{self.trusty_url}/info", 119 | headers=self.headers, 120 | verify=self.verify, 121 | timeout=timeout, 122 | ) 123 | input_data_fields = list( 124 | json.loads(response.text)[0]["data"]["inputSchema"]["items"].keys() 125 | ) 126 | input_mapping_keys = list(payload["inputMapping"].keys()) 127 | if len(list(set(input_mapping_keys) - set(input_data_fields))) == 0: 128 | response = requests.post( 129 | f"{self.trusty_url}/info/names", 130 | json=payload, 131 | headers=self.headers, 132 | verify=self.verify, 133 | timeout=timeout, 134 | ) 135 | if response.status_code == 200: 136 | print_name_mapping(self) 137 | return response.text 138 | print(f"Error {response.status_code}: {response.reason}") 139 | raise ValueError("Field does not exist") 140 | 141 | def get_metric_request( 142 | self, payload: str, metric: str, reoccuring: bool, timeout=5 143 | ): 144 | """ 145 | Retrieve or schedule a metric request 146 | """ 147 | if reoccuring: 148 | response = requests.post( 149 | f"{self.trusty_url}/metrics/{metric}/request", 150 | json=payload, 151 | headers=self.headers, 152 | verify=self.verify, 153 | timeout=timeout, 154 | ) 155 | else: 156 | response = requests.post( 157 | f"{self.trusty_url}/metrics/{metric}", 158 | json=payload, 159 | headers=self.headers, 160 | verify=self.verify, 161 | timeout=timeout, 162 | ) 163 | if response.status_code == 200: 164 | return response.text 165 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 166 | 167 | def upload_data_to_model(self, model_name: str, json_file: str, timeout=5): 168 | """ 169 | Sends an inference request to the model 170 | """ 171 | model_route = TrustyAIApi().get_service_route( 172 | name=model_name, namespace=self.namespace 173 | ) 174 | with open(json_file, encoding="utf8") as batch_file: 175 | response = requests.post( 176 | url=f"https://{model_route}/infer", 177 | data=batch_file, 178 | headers=self.headers, 179 | verify=self.verify, 180 | timeout=timeout, 181 | ) 182 | if response.status_code == 200: 183 | return response.text 184 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 185 | 186 | def get_metric_data(self, metric: str, time_interval: List[str], timeout=5): 187 | """ 188 | Retrives metric data for a specific range in time for each subcategory in data field 189 | """ 190 | metric_df = pd.DataFrame() 191 | for subcategory in list( 192 | self.get_model_metadata()[0]["data"]["inputSchema"]["nameMapping"].values() 193 | ): 194 | params = { 195 | "query": f"{metric}{{subcategory='{subcategory}'}}{time_interval}" 196 | } 197 | 198 | response = requests.get( 199 | f"{self.thanos_url}/api/v1/query?", 200 | params=params, 201 | headers=self.headers, 202 | verify=self.verify, 203 | timeout=timeout, 204 | ) 205 | if response.status_code == 200: 206 | if "timestamp" in metric_df.columns: 207 | pass 208 | else: 209 | metric_df["timestamp"] = [ 210 | item[0] 211 | for item in json.loads(response.text)["data"]["result"][0][ 212 | "values" 213 | ] 214 | ] 215 | metric_df[subcategory] = [ 216 
| item[1] 217 | for item in json.loads(response.text)["data"]["result"][0]["values"] 218 | ] 219 | else: 220 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 221 | 222 | metric_df["timestamp"] = metric_df["timestamp"].apply( 223 | lambda epoch: dt.datetime.fromtimestamp(epoch).strftime("%Y-%m-%d %H:%M:%S") 224 | ) 225 | return metric_df 226 | 227 | @staticmethod 228 | def plot_metric(metric_df: pd.DataFrame, metric: str): 229 | """ 230 | Plots a line for each subcategory in the pandas DataFrame returned by get_metric_request 231 | with the timestamp on x-axis and specified metric on the y-axis 232 | """ 233 | plt.figure(figsize=(12, 5)) 234 | for col in metric_df.columns[1:]: 235 | plt.plot(metric_df["timestamp"], metric_df[col]) 236 | plt.xlabel("timestamp") 237 | plt.ylabel(metric) 238 | plt.xticks(rotation=45) 239 | plt.legend(metric_df.columns[1:]) 240 | plt.tight_layout() 241 | plt.show() 242 | -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tsice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSICEExplainer from aix360. 3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | 6 | # pylint: disable=too-many-arguments,import-error 7 | from typing import Callable, List, Optional, Union 8 | 9 | from aix360.algorithms.tsice import TSICEExplainer as TSICEExplainerAIX 10 | from aix360.algorithms.tsutils.tsperturbers import TSPerturber 11 | import pandas as pd 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from sklearn.linear_model import LinearRegression 15 | 16 | from trustyai.explainers.explanation_results import ExplanationResults 17 | 18 | 19 | class TSICEResults(ExplanationResults): 20 | """Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`, 21 | and provides a variety of methods to visualize and interact with the explanation. 22 | """ 23 | 24 | def __init__(self, explanation): 25 | self.explanation = explanation 26 | 27 | def as_dataframe(self) -> pd.DataFrame: 28 | """Returns the explanation as a pandas dataframe.""" 29 | # Initialize an empty DataFrame 30 | dataframe = pd.DataFrame() 31 | 32 | # Loop through each feature_name and each key in data_x 33 | for key in self.explanation["data_x"]: 34 | for i, feature in enumerate(self.explanation["feature_names"]): 35 | dataframe[f"{key}-{feature}"] = [ 36 | val[0] for val in self.explanation["feature_values"][i] 37 | ] 38 | 39 | # Add "total_impact" as a column 40 | dataframe["total_impact"] = self.explanation["total_impact"] 41 | return dataframe 42 | 43 | def as_html(self) -> pd.io.formats.style.Styler: 44 | """Returns the explanation as an HTML table.""" 45 | dataframe = self.as_dataframe() 46 | return dataframe.style 47 | 48 | def plot_forecast(self, variable): # pylint: disable=too-many-locals 49 | """Plots the explanation. 
50 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py 51 | """ 52 | forecast_horizon = self.explanation["current_forecast"].shape[0] 53 | original_ts = pd.DataFrame( 54 | data={variable: self.explanation["data_x"][variable]} 55 | ) 56 | perturbations = [d for d in self.explanation["perturbations"] if variable in d] 57 | 58 | # Generate a list of keys 59 | keys = list(self.explanation["data_x"].keys()) 60 | # Find the index of the given key 61 | key = keys.index(variable) 62 | forecasts_on_perturbations = [ 63 | arr[:, key : key + 1] 64 | for arr in self.explanation["forecasts_on_perturbations"] 65 | ] 66 | 67 | new_perturbations = [] 68 | new_timestamps = [] 69 | pred_ts = [] 70 | 71 | original_ts.index.freq = pd.infer_freq(original_ts.index) 72 | for i in range(1, forecast_horizon + 1): 73 | new_timestamps.append(original_ts.index[-1] + (i * original_ts.index.freq)) 74 | 75 | for perturbation in perturbations: 76 | new_perturbations.append(pd.DataFrame(perturbation)) 77 | 78 | for forecast in forecasts_on_perturbations: 79 | pred_ts.append(pd.DataFrame(forecast, index=new_timestamps)) 80 | 81 | current_forecast = self.explanation["current_forecast"][:, key : key + 1] 82 | pred_original_ts = pd.DataFrame(current_forecast, index=new_timestamps) 83 | 84 | _, axis = plt.subplots() 85 | 86 | # Plot perturbed time series 87 | axis = self._plot_timeseries( 88 | new_perturbations, 89 | color="lightgreen", 90 | axis=axis, 91 | name="perturbed timeseries samples", 92 | ) 93 | 94 | # Plot original time series 95 | axis = self._plot_timeseries( 96 | original_ts, color="green", axis=axis, name="input/original timeseries" 97 | ) 98 | 99 | # Plot varying forecast range 100 | axis = self._plot_timeseries( 101 | pred_ts, color="lightblue", axis=axis, name="forecast on perturbed samples" 102 | ) 103 | 104 | # Plot original forecast 105 | axis = self._plot_timeseries( 106 | pred_original_ts, color="blue", axis=axis, name="original forecast" 107 | ) 108 | 109 | # Set labels and title 110 | axis.set_xlabel("Timestamp") 111 | axis.set_ylabel(variable) 112 | axis.set_title("Time-Series Individual Conditional Expectation (TSICE)") 113 | 114 | axis.legend() 115 | 116 | # Display the plot 117 | plt.show() 118 | 119 | def _plot_timeseries( 120 | self, timeseries, color="green", axis=None, name="time series" 121 | ): 122 | showlegend = True 123 | if isinstance(timeseries, dict): 124 | data = timeseries 125 | if isinstance(color, str): 126 | color = {k: color for k in data} 127 | elif isinstance(timeseries, list): 128 | data = {} 129 | for k, ts_data in enumerate(timeseries): 130 | data[k] = ts_data 131 | if isinstance(color, str): 132 | color = {k: color for k in data} 133 | else: 134 | data = {} 135 | data["default"] = timeseries 136 | color = {"default": color} 137 | 138 | if axis is None: 139 | _, axis = plt.subplots() 140 | 141 | first = True 142 | for key, _timeseries in data.items(): 143 | if not first: 144 | showlegend = False 145 | 146 | self._add_timeseries( 147 | axis, _timeseries, color=color[key], showlegend=showlegend, name=name 148 | ) 149 | first = False 150 | 151 | return axis 152 | 153 | def _add_timeseries( 154 | self, axis, timeseries, color="green", name="time series", showlegend=False 155 | ): 156 | timestamps = timeseries.index 157 | axis.plot( 158 | timestamps, 159 | timeseries[timeseries.columns[0]], 160 | color=color, 161 | label=(name if showlegend else "_nolegend_"), 162 | ) 163 | 164 | def plot_impact(self, feature_per_row=2): 165 | """Plot the impace. 
166 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py 167 | """ 168 | 169 | n_row = int(np.ceil(len(self.explanation["feature_names"]) / feature_per_row)) 170 | feat_values = np.array(self.explanation["feature_values"]) 171 | 172 | fig, axs = plt.subplots(n_row, feature_per_row, figsize=(15, 15)) 173 | axs = axs.ravel() # Flatten the axs to iterate over it 174 | 175 | for i, feat in enumerate(self.explanation["feature_names"]): 176 | x_feat = feat_values[i, :, 0] 177 | trend_fit = LinearRegression() 178 | trend_line = trend_fit.fit( 179 | x_feat.reshape(-1, 1), self.explanation["signed_impact"] 180 | ) 181 | x_trend = np.linspace(min(x_feat), max(x_feat), 101) 182 | y_trend = trend_line.predict(x_trend[..., np.newaxis]) 183 | 184 | # Scatter plot 185 | axs[i].scatter(x=x_feat, y=self.explanation["signed_impact"], color="blue") 186 | # Line plot 187 | axs[i].plot( 188 | x_trend, 189 | y_trend, 190 | color="green", 191 | label="correlation between forecast and observed feature", 192 | ) 193 | # Reference line 194 | current_value = self.explanation["current_feature_values"][i][0] 195 | axs[i].axvline( 196 | x=current_value, 197 | color="firebrick", 198 | linestyle="--", 199 | label="current value", 200 | ) 201 | 202 | axs[i].set_xlabel(feat) 203 | axs[i].set_ylabel("Δ forecast") 204 | 205 | # Display the legend on the first subplot 206 | axs[0].legend() 207 | 208 | fig.suptitle("Impact of Derived Variable On The Forecast", fontsize=16) 209 | plt.tight_layout() 210 | plt.subplots_adjust(top=0.95) 211 | plt.show() 212 | 213 | 214 | class TSICEExplainer(TSICEExplainerAIX): 215 | """ 216 | Wrapper for TSICEExplainer from aix360. 217 | """ 218 | 219 | def __init__( 220 | self, 221 | model: Callable, 222 | input_length: int, 223 | forecast_lookahead: int, 224 | n_variables: int = 1, 225 | n_exogs: int = 0, 226 | n_perturbations: int = 25, 227 | features_to_analyze: Optional[List[str]] = None, 228 | perturbers: Optional[List[Union[TSPerturber, dict]]] = None, 229 | explanation_window_start: Optional[int] = None, 230 | explanation_window_length: int = 10, 231 | ): 232 | super().__init__( 233 | forecaster=model, 234 | input_length=input_length, 235 | forecast_lookahead=forecast_lookahead, 236 | n_variables=n_variables, 237 | n_exogs=n_exogs, 238 | n_perturbations=n_perturbations, 239 | features_to_analyze=features_to_analyze, 240 | perturbers=perturbers, 241 | explanation_window_start=explanation_window_start, 242 | explanation_window_length=explanation_window_length, 243 | ) 244 | 245 | def explain(self, inputs, outputs=None, **kwargs) -> TSICEResults: 246 | """ 247 | Explain the model's prediction on X. 248 | """ 249 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 250 | return TSICEResults(_explanation) 251 | --------------------------------------------------------------------------------