├── .dockerignore ├── .github ├── actions │ └── build-core │ │ └── action.yml └── workflows │ ├── benchmarks-merge.yml │ ├── benchmarks.yml │ ├── publish.yml │ ├── security.yaml │ └── workflow.yml ├── .gitignore ├── .gitmodules ├── .readthedocs.yaml ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── cliff.toml ├── deps.sh ├── docs ├── Makefile ├── _static │ ├── artwork │ │ ├── favicon.png │ │ └── logo.png │ └── css │ │ ├── custom.css │ │ └── fonts │ │ ├── RedHatDisplay-Italic-VariableFont_wght.woff │ │ ├── RedHatDisplay-VariableFont_wght.woff │ │ ├── RedHatMono-Italic-VariableFont_wght.woff │ │ ├── RedHatMono-VariableFont_wght.woff │ │ ├── RedHatText-Italic-VariableFont_wght.woff │ │ └── RedHatText-VariableFont_wght.woff ├── api.rst ├── clean.sh ├── conf.py ├── generated │ ├── trustyai.explainers.CounterfactualExplainer.rst │ ├── trustyai.explainers.CounterfactualResult.rst │ ├── trustyai.explainers.LimeExplainer.rst │ ├── trustyai.explainers.LimeResults.rst │ ├── trustyai.explainers.SHAPExplainer.rst │ ├── trustyai.explainers.SHAPResults.rst │ ├── trustyai.initializer.init.rst │ ├── trustyai.model.Dataset.rst │ ├── trustyai.model.Model.rst │ ├── trustyai.model.counterfactual_prediction.rst │ ├── trustyai.model.feature.rst │ ├── trustyai.model.feature_domain.rst │ ├── trustyai.model.output.rst │ └── trustyai.model.simple_prediction.rst ├── index.rst ├── make.bat ├── python_models.html ├── python_models.ipynb ├── requirements.txt └── tutorial.rst ├── info └── detoxify.md ├── pyproject.toml ├── requirements.txt ├── scripts ├── build.sh └── local.sh ├── src └── trustyai │ ├── __init__.py │ ├── _default_initializer.py │ ├── dep │ ├── .gitkeep │ └── org │ │ └── trustyai │ │ └── .gitkeep │ ├── explainers │ ├── __init__.py │ ├── counterfactuals.py │ ├── explanation_results.py │ ├── extras │ │ ├── tsice.py │ │ ├── tslime.py │ │ └── tssaliency.py │ ├── lime.py │ ├── pdp.py │ └── shap.py │ ├── initializer.py │ ├── language │ ├── __init__.py │ └── detoxify │ │ ├── __init__.py │ │ └── tmarco.py │ ├── local │ └── __init__.py │ ├── metrics │ ├── __init__.py │ ├── distance.py │ ├── fairness │ │ ├── __init__.py │ │ └── group.py │ ├── language.py │ └── saliency.py │ ├── model │ ├── __init__.py │ └── domain.py │ ├── utils │ ├── DataUtils.py │ ├── __init__.py │ ├── _tyrus_info_text.py │ ├── _visualisation.py │ ├── api │ │ └── api.py │ ├── data_conversions.py │ ├── extras │ │ ├── metrics_service.py │ │ ├── models.py │ │ └── timeseries.py │ ├── text.py │ ├── tokenizers.py │ └── tyrus.py │ ├── version.py │ └── visualizations │ ├── __init__.py │ ├── distance.py │ ├── lime.py │ ├── pdp.py │ ├── shap.py │ └── visualization_results.py └── tests ├── benchmarks ├── benchmark.py ├── benchmark_common.py └── xai_benchmark.py ├── extras ├── test_metrics_service.py ├── test_tsice.py ├── test_tslime.py └── test_tssaliency.py ├── general ├── common.py ├── data │ ├── data.csv │ ├── income-biased.zip │ └── income-unbiased.zip ├── models │ └── income-xgd-biased.joblib ├── test_conversions.py ├── test_counterfactualexplainer.py ├── test_dataset.py ├── test_datautils.py ├── test_group_fairness.py ├── test_limeexplainer.py ├── test_metrics_language.py ├── test_model.py ├── test_pdp.py ├── test_prediction.py ├── test_shap.py ├── test_shap_background_generation.py ├── test_tyrus.py └── universal.py └── initialization └── test_initialization.py /.dockerignore: -------------------------------------------------------------------------------- 1 | *.log 2 | __pycache__ 3 | .pytest_cache 4 | 
.ipynb_checkpoints 5 | .Rproj.user 6 | .vscode 7 | .idea 8 | .mypy_cache 9 | build 10 | dist 11 | trustyai.egg-info 12 | *.pyc -------------------------------------------------------------------------------- /.github/actions/build-core/action.yml: -------------------------------------------------------------------------------- 1 | name: Build exp-core JAR 2 | description: Clone and build the TrustyAI-Explainability library (shaded into a single JAR) 3 | runs: 4 | using: "composite" 5 | steps: 6 | - name: Set up JDK 17 7 | uses: actions/setup-java@v2 8 | with: 9 | distribution: 'adopt' 10 | java-version: '17' 11 | - name: Build explainability-core 12 | shell: bash 13 | run: | 14 | git clone https://github.com/trustyai-explainability/trustyai-explainability.git 15 | mvn clean install -DskipTests -f trustyai-explainability/pom.xml -Pshaded -fae -e -nsu 16 | mv trustyai-explainability/explainability-arrow/target/explainability-arrow-*-SNAPSHOT.jar src/trustyai/dep/org/trustyai/ -------------------------------------------------------------------------------- /.github/workflows/benchmarks-merge.yml: -------------------------------------------------------------------------------- 1 | name: TrustyAI Python benchmarks (merge) 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | 9 | permissions: 10 | contents: write 11 | deployments: write 12 | pages: write 13 | pull-requests: write 14 | 15 | jobs: 16 | benchmark: 17 | # runs on every push to main, i.e. whenever a PR is merged 18 | name: Run pytest-benchmark benchmark 19 | runs-on: ubuntu-latest 20 | steps: 21 | - uses: actions/checkout@v2 22 | - uses: actions/setup-python@v2 23 | with: 24 | python-version: 3.8 25 | - uses: actions/setup-java@v2 26 | with: 27 | distribution: "adopt" 28 | java-version: "11" 29 | check-latest: true 30 | - uses: stCarolas/setup-maven@v4 31 | with: 32 | maven-version: 3.8.1 33 | - name: Build explainability-core 34 | uses: ./.github/actions/build-core 35 | - name: Install TrustyAI Python package 36 | run: | 37 | pip install -r requirements-dev.txt 38 | pip install . 
39 | - name: Run benchmark 40 | run: | 41 | pytest tests/benchmarks/benchmark.py --benchmark-json tests/benchmarks/results.json 42 | - name: Store benchmark result 43 | uses: benchmark-action/github-action-benchmark@v1 44 | with: 45 | name: TrustyAI continuous benchmarks 46 | tool: 'pytest' 47 | output-file-path: tests/benchmarks/results.json 48 | github-token: ${{ secrets.GITHUB_TOKEN }} 49 | auto-push: true 50 | gh-pages-branch: gh-pages 51 | alert-threshold: '200%' 52 | comment-on-alert: true 53 | comment-always: true 54 | fail-on-alert: false 55 | alert-comment-cc-users: '@ruivieira' -------------------------------------------------------------------------------- /.github/workflows/benchmarks.yml: -------------------------------------------------------------------------------- 1 | name: TrustyAI Python benchmarks (PR) 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | deployments: write 11 | pages: write 12 | pull-requests: write 13 | 14 | jobs: 15 | benchmark: 16 | name: Run pytest-benchmark benchmark 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v2 20 | - uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.8 23 | - uses: actions/setup-java@v2 24 | with: 25 | distribution: "adopt" 26 | java-version: "11" 27 | check-latest: true 28 | - uses: stCarolas/setup-maven@v4 29 | with: 30 | maven-version: 3.8.1 31 | - name: Build explainability-core 32 | uses: ./.github/actions/build-core 33 | - name: Install TrustyAI Python package 34 | run: | 35 | pip install -r requirements-dev.txt 36 | pip install . 37 | - name: Run benchmark 38 | run: | 39 | pytest tests/benchmarks/benchmark.py --benchmark-json tests/benchmarks/results.json 40 | - name: Benchmark result comment 41 | uses: benchmark-action/github-action-benchmark@v1 42 | with: 43 | name: TrustyAI continuous benchmarks 44 | tool: 'pytest' 45 | output-file-path: tests/benchmarks/results.json 46 | github-token: ${{ secrets.GITHUB_TOKEN }} 47 | auto-push: false 48 | alert-threshold: '200%' 49 | comment-on-alert: true 50 | save-data-file: false 51 | comment-always: true 52 | fail-on-alert: false 53 | alert-comment-cc-users: '@ruivieira' 54 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | on: 3 | release: 4 | types: [ published ] 5 | jobs: 6 | pypi-publish: 7 | name: upload release to PyPI 8 | runs-on: ubuntu-latest 9 | environment: pypi 10 | permissions: 11 | id-token: write 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | with: 16 | fetch-depth: 0 17 | - name: Build explainability-core 18 | uses: ./.github/actions/build-core 19 | - run: python3 -m pip install --upgrade build && python3 -m build 20 | - name: Publish package 21 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.github/workflows/security.yaml: -------------------------------------------------------------------------------- 1 | name: Security 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | jobs: 8 | build: 9 | name: Build 10 | runs-on: ubuntu-20.04 11 | permissions: 12 | contents: read 13 | security-events: write 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | 18 | - name: Trivy scan 19 | uses: aquasecurity/trivy-action@0.28.0 20 | with: 21 | scan-type: 'fs' 22 | format: 'sarif' 23 
| output: 'trivy-results.sarif' 24 | severity: 'MEDIUM,HIGH,CRITICAL' 25 | exit-code: '0' 26 | ignore-unfixed: false 27 | 28 | - name: Update Security tab 29 | uses: github/codeql-action/upload-sarif@v3 30 | with: 31 | sarif_file: 'trivy-results.sarif' -------------------------------------------------------------------------------- /.github/workflows/workflow.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [ 3.8, 3.9, 3.11 ] 11 | java-version: [ 17 ] 12 | maven-version: [ '3.8.6' ] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up JDK + Maven version 16 | uses: s4u/setup-maven-action@v1.4.0 17 | with: 18 | java-version: ${{ matrix.java-version }} 19 | maven-version: ${{ matrix.maven-version }} 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Build explainability-core 25 | uses: ./.github/actions/build-core 26 | - name: Install TrustyAI Python package 27 | run: | 28 | pip install . 29 | pip install ".[dev]" 30 | pip install ".[extras]" 31 | pip install ".[api]" 32 | - name: Lint 33 | run: | 34 | pylint --ignore-imports=yes $(find src/trustyai -type f -name "*.py") 35 | - name: Test with pytest 36 | run: | 37 | pytest -v -s tests/general 38 | pytest -v -s tests/extras 39 | pytest -v -s tests/initialization --forked 40 | - name: Style 41 | run: | 42 | black --check $(find src/trustyai -type f -name "*.py") 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Sphinx stuff 69 | docs/_build 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .DS_Store 131 | .idea 132 | .Rproj.user 133 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/benchmarks/trustyai_xai_bench"] 2 | path = tests/benchmarks/trustyai_xai_bench 3 | url = https://github.com/trustyai-explainability/trustyai_xai_bench 4 | branch = main 5 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | apt_packages : 12 | - maven 13 | tools: 14 | python: "3.9" 15 | jobs: 16 | pre_create_environment: 17 | - rm -f src/trustyai/dep/org/trustyai/* 18 | - git clone https://github.com/trustyai-explainability/trustyai-explainability.git 19 | - mvn clean install -DskipTests -f trustyai-explainability/pom.xml -Pquickly -fae -e -nsu 20 | - mvn clean install -DskipTests -f trustyai-explainability/explainability-arrow/pom.xml -Pshaded -fae -e -nsu 21 | - mv trustyai-explainability/explainability-arrow/target/explainability-arrow-*-SNAPSHOT.jar src/trustyai/dep/org/trustyai/ 22 | 23 | post_build: 24 | - rm -Rf trustyai-explainability 25 | 26 | # install the package 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt 30 | - method: pip 31 | path: . 32 | extra_requirements: 33 | - dev 34 | 35 | # Build documentation in the docs/ directory with Sphinx 36 | sphinx: 37 | configuration: docs/conf.py 38 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guide 2 | 3 | **Want to contribute? Great!** 4 | We try to make it easy, and all contributions, even the smaller ones, are more than welcome. 5 | This includes bug reports, fixes, documentation, examples... 6 | But first, read this page (including the small print at the end). 7 | 8 | ## Legal 9 | 10 | All original contributions to TrustyAI-explainability are licensed under the 11 | [ASL - Apache License](https://www.apache.org/licenses/LICENSE-2.0), 12 | version 2.0 or later, or, if another license is specified as governing the file or directory being 13 | modified, such other license. 
14 | 15 | ## Issues 16 | 17 | Python TrustyAI uses [GitHub to manage and report issues](https://github.com/trustyai-explainability/trustyai-explainability-python/issues). 18 | 19 | If you believe you found a bug, please indicate a way to reproduce it, what you are seeing and what you would expect to see. 20 | Don't forget to indicate your Python TrustyAI, Java, and Maven versions. 21 | 22 | ### Checking an issue is fixed in main 23 | 24 | Sometimes a bug has been fixed in the `main` branch of Python TrustyAI and you want to confirm it is fixed for your own application. 25 | Testing the `main` branch is easy, as you can build Python TrustyAI all by yourself. 26 | 27 | ## Creating a Pull Request (PR) 28 | 29 | To contribute, use GitHub Pull Requests, from your **own** fork. 30 | 31 | - PRs should always be related to an open GitHub issue. If there is none, you should create one. 32 | - Try to fix only one issue per PR. 33 | - Make sure to create a new branch. Usually branches are named after the GitHub ticket they are addressing. E.g. for ticket "issue-XYZ An example issue" your branch should be at least prefixed with `issue-XYZ`. E.g.: 34 | 35 | git checkout -b issue-XYZ 36 | # or 37 | git checkout -b issue-XYZ-my-fix 38 | 39 | - When you submit your PR, make sure to include the ticket ID and its title; e.g., "issue-XYZ An example issue". 40 | - The description of your PR should describe the code you wrote. The issue that is solved should be at least described properly in the corresponding GitHub ticket. 41 | - If your contribution spans multiple repositories, use the same branch name (e.g. `issue-XYZ`). 42 | - If your contribution spans multiple repositories, make sure to list all the related PRs. 43 | 44 | ### Python Coding Guidelines 45 | 46 | PRs will be checked against `black` and `pylint` before passing the CI. 47 | 48 | You can run the same checks locally (`black --check` and `pylint`, run over the files in `src/trustyai` as in the CI workflow) to guarantee the PR will pass them. 49 | 50 | ### Requirements for Dependencies 51 | 52 | Any dependency used in the project must fulfill these hard requirements: 53 | 54 | - The dependency must have **an Apache 2.0 compatible license**. 55 | - Good: BSD, MIT, Apache 2.0 56 | - Avoid: EPL, LGPL 57 | - Especially LGPL is a last resort and should be abstracted away or contained behind an SPI. 58 | - Test scope dependencies pose no problem if they are EPL or LGPL. 59 | - Forbidden: no license, GPL, AGPL, proprietary license, field of use restrictions ("this software shall be used for good, not evil"), ... 60 | - Even test scope dependencies cannot use these licenses. 61 | - To check a license's ASL compatibility, please visit these links: [Similarity in terms to the Apache License 2.0](http://www.apache.org/legal/resolved.html#category-a) 62 | and [How should so-called "Weak Copyleft" Licenses be handled](http://www.apache.org/legal/resolved.html#category-b) 63 | 64 | - The dependency shall be **available on PyPI**. 65 | - Why? 66 | - Build reproducibility. Any repository server we use must still be running years from now. 67 | - Build speed. More repositories slow down the build. 68 | - Build reliability. A repository server that is temporarily down can break builds. 69 | 70 | - **Do not release the dependency yourself** (by building it from source). 71 | - Why? Because it's not an official release made by the project's release team. 72 | - A release must be 100% reproducible. 73 | - A release must be reliable (sometimes the release team does specific things you might not be able to reproduce). 
74 | 75 | - **The sources are publicly available** 76 | - We may need to rebuild the dependency from sources ourselves in the future. This may be in the rare case when 77 | the dependency is no longer maintained, but we need to fix a specific CVE there. 78 | - Make sure the dependency's pom.xml contains a link to the source repository (`scm` tag). 79 | 80 | - The dependency needs to use **a reasonable build system** 81 | - Since we may need to rebuild the dependency from sources, we also need to make sure it is easily buildable. 82 | Maven or Gradle are acceptable as build systems. 83 | 84 | - Only use dependencies with **an active community**. 85 | - Check for activity in the last year through [Open Hub](https://www.openhub.net). 86 | 87 | - Less is more: **fewer dependencies are better**. Bloat is bad. 88 | - Try to use existing dependencies if the functionality is available in those dependencies. 89 | - For example: use Apache Commons Math instead of Colt if Apache Commons Math is already a dependency. 90 | 91 | There are currently a few dependencies which violate some of these rules. They should be properly commented with a 92 | warning explaining why they are needed. 93 | If you want to add a dependency that violates any of the rules above, get approval from the project leads. 94 | 95 | ### Tests and Documentation 96 | 97 | Don't forget to include tests in your pull requests, as well as documentation (reference documentation, ...). 98 | Guides and reference documentation should be submitted to the [Python TrustyAI examples repository](https://github.com/trustyai-explainability/trustyai-explainability-python-examples). 99 | If you are contributing a new feature, we strongly advise submitting an example. 100 | 101 | ### Code Reviews and Continuous Integration 102 | 103 | All submissions, including those by project members, need to be reviewed by others before being merged. Our CI, GitHub Actions, should successfully execute your PR, marking the GitHub check as green. 104 | 105 | ## Feature Proposals 106 | 107 | If you would like to see some feature in Python TrustyAI, just open a feature request and tell us what you would like to see. 108 | Alternatively, you can propose it during the [TrustyAI community meeting](https://github.com/trustyai-explainability/community). 109 | 110 | Great feature proposals should include a short **Description** of the feature, the **Motivation** that makes that feature necessary, and the **Goals** that are achieved by realizing it. If the feature is deemed worthy, then an Epic will be created. 111 | 112 | ## The small print 113 | 114 | This is an open source project, so please act responsibly, be nice and polite, and enjoy! 115 | 116 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft src 2 | prune tests 3 | prune docs 4 | prune .github 5 | 6 | global-exclude *~ *.py[cod] *.so *.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![version](https://img.shields.io/badge/version-0.6.0-green) [![Tests](https://github.com/trustyai-python/module/actions/workflows/workflow.yml/badge.svg)](https://github.com/trustyai-python/examples/actions/workflows/workflow.yml) 2 | 3 | # python-trustyai 4 | 5 | Python bindings to [TrustyAI](https://kogito.kie.org/trustyai/)'s explainability library. 
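As a rough sketch of the kind of workflow the bindings enable (the model and data below are illustrative stand-ins, and exact signatures may differ between releases; see the example notebooks and the API reference for tested, end-to-end code), explaining a model prediction looks roughly like this:

```python
import numpy as np

from trustyai.model import Model
from trustyai.explainers import LimeExplainer

# a toy prediction function standing in for any real model
def predict(x: np.ndarray) -> np.ndarray:
    return np.sum(x, axis=1, keepdims=True)

model = Model(predict)  # wrap the Python function as a TrustyAI model
input_point = np.array([[1.0, 2.0, 3.0]])  # a single input point to explain

explainer = LimeExplainer()
explanation = explainer.explain(
    inputs=input_point, outputs=predict(input_point), model=model
)
print(explanation.as_dataframe())  # per-feature saliency scores
```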
6 | 7 | ## Setup 8 | 9 | ### PyPI 10 | 11 | Install from PyPI with 12 | 13 | ```shell 14 | pip install trustyai 15 | ``` 16 | 17 | To install additional experimental features, also use 18 | 19 | ```shell 20 | pip install trustyai[extras] 21 | ``` 22 | 23 | ### Local 24 | 25 | The minimum dependencies can be installed (from the root directory) with 26 | 27 | ```shell 28 | pip install . 29 | ``` 30 | 31 | If running the examples or developing, also install the development dependencies: 32 | 33 | ```shell 34 | pip install '.[dev]' 35 | ``` 36 | 37 | ### Docker 38 | 39 | Alternatively, create a container image and run it using 40 | 41 | ```shell 42 | $ docker build -f Dockerfile -t python-trustyai:latest . 43 | $ docker run --rm -it -p 8888:8888 python-trustyai:latest 44 | ``` 45 | 46 | The Jupyter server will be available at `localhost:8888`. 47 | 48 | ### Binder 49 | 50 | You can also run the example Jupyter notebooks 51 | using `mybinder.org`: [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/trustyai-python/trustyai-explainability-python-examples/main?labpath=examples) 52 | 53 | ## Documentation 54 | 55 | Check out the [ReadTheDocs page](https://trustyai-explainability-python.readthedocs.io/en/latest/) for API references 56 | and examples. 57 | 58 | ## Getting started 59 | 60 | ### Examples 61 | 62 | There are several working examples available in the [examples](https://github.com/trustyai-explainability/trustyai-explainability-python-examples/tree/main/examples) repository. 63 | 64 | ## Contributing 65 | 66 | Please see the [CONTRIBUTING.md](CONTRIBUTING.md) file for instructions on how to contribute to this project. -------------------------------------------------------------------------------- /cliff.toml: -------------------------------------------------------------------------------- 1 | [changelog] 2 | # changelog header 3 | header = """ 4 | # Changelog\n 5 | All notable changes to this project will be documented in this file.\n 6 | """ 7 | # template for the changelog body 8 | body = """ 9 | {% if version %}\ 10 | ## [{{ version }}] - {{ timestamp | date(format="%Y-%m-%d") }} 11 | {% else %}\ 12 | ## [unreleased] 13 | {% endif %}\ 14 | {% for group, commits in commits | group_by(attribute="group") %} 15 | ### {{ group | upper_first }} 16 | {% for commit in commits %} 17 | - {{ commit.message | upper_first }}\ 18 | {% endfor %} 19 | {% endfor %}\n 20 | """ 21 | # remove the leading and trailing whitespace from the template 22 | trim = true 23 | # changelog footer 24 | footer = "" 25 | 26 | [git] 27 | # parse the commits based on https://www.conventionalcommits.org 28 | conventional_commits = true 29 | # filter out the commits that are not conventional 30 | filter_unconventional = false 31 | # process each line of a commit as an individual commit 32 | split_commits = false 33 | # regex for preprocessing the commit messages 34 | commit_preprocessors = [ 35 | # { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/orhun/git-cliff/issues/${2}))"}, # replace issue numbers 36 | ] 37 | # regex for parsing and grouping commits 38 | commit_parsers = [ 39 | { message = "^feat", group = "Features" }, 40 | { message = "^fix", group = "Bug Fixes" }, 41 | { message = "^doc", group = "Documentation" }, 42 | { message = "^perf", group = "Performance" }, 43 | { message = "^refactor", group = "Refactor" }, 44 | { message = "^style", group = "Styling" }, 45 | { message = "^test", group = "Testing" }, 46 | { message = "^chore\\(release\\): prepare for", 
skip = true }, 47 | { message = "^chore", group = "Miscellaneous Tasks" }, 48 | { body = ".*security", group = "Security" }, 49 | ] 50 | # protect breaking changes from being skipped due to matching a skipping commit_parser 51 | protect_breaking_commits = false 52 | # filter out the commits that are not matched by commit parsers 53 | filter_commits = false 54 | # glob pattern for matching git tags 55 | tag_pattern = "[0-9]*" 56 | # regex for skipping tags 57 | skip_tags = "v0.1.0-beta.1" 58 | # regex for ignoring tags 59 | ignore_tags = "" 60 | # sort the tags topologically 61 | topo_order = false 62 | # sort the commits inside sections by oldest/newest order 63 | sort_commits = "oldest" 64 | # limit the number of commits included in the changelog. 65 | # limit_commits = 42 66 | -------------------------------------------------------------------------------- /deps.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | TRUSTY_VERSION="1.12.0.Final" 4 | 5 | mvn org.apache.maven.plugins:maven-dependency-plugin:2.10:get \ 6 | -DremoteRepositories=https://repository.sonatype.org/content/repositories/central \ 7 | -Dartifact=org.kie.kogito:explainability-core:$TRUSTY_VERSION \ 8 | -Dmaven.repo.local=dep -q 9 | 10 | # We also need the test JARs in order to get the test models 11 | wget -O ./dep/org/kie/kogito/explainability-core/$TRUSTY_VERSION/explainability-core-$TRUSTY_VERSION-tests.jar \ 12 | https://repo1.maven.org/maven2/org/kie/kogito/explainability-core/$TRUSTY_VERSION/explainability-core-$TRUSTY_VERSION-tests.jar -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/artwork/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/artwork/favicon.png -------------------------------------------------------------------------------- /docs/_static/artwork/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/artwork/logo.png -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | @import url("theme.css"); 2 | 3 | 4 | @font-face { 5 | font-family: "Red Hat Text"; 6 | src: url("fonts/RedHatText-VariableFont_wght.woff"); 7 | } 8 | 9 | @font-face { 10 | font-family: "Red Hat Display"; 11 | src: url("fonts/RedHatDisplay-VariableFont_wght.woff"); 12 | } 13 | 14 | @font-face { 15 | font-family: "Red Hat Mono"; 16 | src: url("fonts/RedHatMono-VariableFont_wght.woff"); 17 | } 18 | 19 | body { 20 | font-family: "Red Hat Text", sans-serif !important; 21 | } 22 | 23 | h1, h2, h3, h4, h5, h6 { 24 | font-family: "Red Hat Display", sans-serif !important; 25 | } 26 | 27 | .rst-content code, .rst-content tt { 28 | font-family: "Red Hat Mono", monospace !important; 29 | } 30 | 31 | .wy-side-nav-search { 32 | background-color: #343131 !important; 33 | } 34 | 35 | html.writer-html4 .rst-content dl:not(.docutils)>dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt{ 36 | border-top: 3px solid #e06666; 37 | background: #e0666633; 38 | color: #a64d79; 39 | } 40 | -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatDisplay-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatDisplay-Italic-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatDisplay-VariableFont_wght.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatDisplay-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatMono-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatMono-Italic-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatMono-VariableFont_wght.woff: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatMono-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatText-Italic-VariableFont_wght.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatText-Italic-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/RedHatText-VariableFont_wght.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/docs/_static/css/fonts/RedHatText-VariableFont_wght.woff -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: trustyai 2 | 3 | API Reference 4 | ============= 5 | This page contains the API reference for public objects and functions within TrustyAI. See the 6 | example notebooks for usage guides and tutorials. 7 | 8 | trustyai.initializer 9 | -------------------- 10 | Initializing The JVM 11 | ########################## 12 | .. currentmodule:: trustyai.initializer 13 | .. model_api: 14 | .. autosummary:: 15 | :toctree: generated/ 16 | 17 | init 18 | 19 | 20 | trustyai.model 21 | -------------- 22 | Feature and Output Objects 23 | ########################## 24 | .. currentmodule:: trustyai.model 25 | .. model_api: 26 | .. autosummary:: 27 | :toctree: generated/ 28 | 29 | feature 30 | feature_domain 31 | output 32 | 33 | Data Objects 34 | ############ 35 | .. autosummary:: 36 | :toctree: generated/ 37 | 38 | Dataset 39 | 40 | Model Classes 41 | ############# 42 | .. autosummary:: 43 | :toctree: generated/ 44 | 45 | Model 46 | 47 | trustyai.explainers 48 | ------------------- 49 | LIME 50 | #### 51 | .. currentmodule:: trustyai.explainers 52 | .. explainers_api: 53 | .. autosummary:: 54 | :toctree: generated/ 55 | 56 | LimeExplainer 57 | LimeResults 58 | 59 | SHAP 60 | #### 61 | .. autosummary:: 62 | :toctree: generated/ 63 | 64 | SHAPExplainer 65 | BackgroundGenerator 66 | SHAPResults 67 | 68 | Counterfactuals 69 | ############### 70 | .. autosummary:: 71 | :toctree: generated/ 72 | 73 | CounterfactualExplainer 74 | CounterfactualResult 75 | 76 | trustyai.utils 77 | -------------- 78 | .. currentmodule:: trustyai.utils.tyrus 79 | .. utils_api: 80 | .. autosummary:: 81 | :toctree: generated/ 82 | 83 | Tyrus 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /docs/clean.sh: -------------------------------------------------------------------------------- 1 | rm generated/* 2 | rm -r _build/ -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath('../src/trustyai/')) 17 | 18 | import sphinx_rtd_theme 19 | 20 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'TrustyAI' 25 | copyright = '2023, Rob Geada, Tommaso Teofili, Rui Vieira, Rebecca Whitworth, Daniele Zonca' 26 | author = 'Rob Geada, Tommaso Teofili, Rui Vieira, Rebecca Whitworth, Daniele Zonca' 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.autosummary', 36 | 'sphinx.ext.autosectionlabel', 37 | 'sphinx_rtd_theme', 38 | 'sphinx.ext.mathjax', 39 | 'numpydoc' 40 | ] 41 | 42 | autodoc_default_options = { 43 | 'members': True, 44 | 'inherited-members': True 45 | } 46 | autosummary_generate = True 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # List of patterns, relative to source directory, that match files and 52 | # directories to ignore when looking for source files. 53 | # This pattern also affects html_static_path and html_extra_path. 54 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 55 | 56 | # -- Options for HTML output ------------------------------------------------- 57 | 58 | # The theme to use for HTML and HTML Help pages. See the documentation for 59 | # a list of builtin themes. 60 | # 61 | html_theme = 'sphinx_rtd_theme' 62 | html_static_path = ['_static'] 63 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 64 | html_theme_options = { 65 | 'logo_only': True, 66 | 'style_nav_header_background': '#343131', 67 | } 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 71 | html_static_path = ['_static'] 72 | html_css_files = ['css/custom.css'] 73 | 74 | html_favicon = '_static/artwork/favicon.png' 75 | html_logo = '_static/artwork/logo.png' 76 | # numpydoc settings 77 | numpydoc_show_class_members = False 78 | 79 | 80 | def setup(app): 81 | import trustyai 82 | from trustyai.model import Model 83 | from trustyai.utils.tyrus import Tyrus 84 | Model.__name__ = "Model" 85 | Tyrus.__name__ = "Tyrus" -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.CounterfactualExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.CounterfactualExplainer 2 | =========================================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: CounterfactualExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. 
rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~CounterfactualExplainer.__init__ 17 | ~CounterfactualExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.CounterfactualResult.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.CounterfactualResult 2 | ======================================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: CounterfactualResult 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~CounterfactualResult.__init__ 17 | ~CounterfactualResult.as_dataframe 18 | ~CounterfactualResult.as_html 19 | ~CounterfactualResult.plot 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~CounterfactualResult.proposed_features_array 30 | ~CounterfactualResult.proposed_features_dataframe 31 | 32 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.LimeExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.LimeExplainer 2 | ================================= 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: LimeExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~LimeExplainer.__init__ 17 | ~LimeExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.LimeResults.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.LimeResults 2 | =============================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: LimeResults 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~LimeResults.__init__ 17 | ~LimeResults.as_dataframe 18 | ~LimeResults.as_html 19 | ~LimeResults.map 20 | ~LimeResults.plot 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.SHAPExplainer.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.SHAPExplainer 2 | ================================= 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: SHAPExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~SHAPExplainer.__init__ 17 | ~SHAPExplainer.explain 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /docs/generated/trustyai.explainers.SHAPResults.rst: -------------------------------------------------------------------------------- 1 | trustyai.explainers.SHAPResults 2 | =============================== 3 | 4 | .. currentmodule:: trustyai.explainers 5 | 6 | .. autoclass:: SHAPResults 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. 
autosummary:: 15 | 16 | ~SHAPResults.__init__ 17 | ~SHAPResults.as_dataframe 18 | ~SHAPResults.as_html 19 | ~SHAPResults.candlestick_plot 20 | ~SHAPResults.get_fnull 21 | ~SHAPResults.get_saliencies 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /docs/generated/trustyai.initializer.init.rst: -------------------------------------------------------------------------------- 1 | trustyai.initializer.init 2 | ========================= 3 | 4 | .. currentmodule:: trustyai.initializer 5 | 6 | .. autofunction:: init -------------------------------------------------------------------------------- /docs/generated/trustyai.model.Dataset.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.Dataset 2 | ====================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autoclass:: Dataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Dataset.__init__ 17 | ~Dataset.df_to_prediction_object 18 | ~Dataset.from_df 19 | ~Dataset.from_numpy 20 | ~Dataset.numpy_to_prediction_object 21 | ~Dataset.prediction_object_to_numpy 22 | ~Dataset.prediction_object_to_pandas 23 | 24 | 25 | 26 | 27 | 28 | .. rubric:: Attributes 29 | 30 | .. autosummary:: 31 | 32 | ~Dataset.data 33 | ~Dataset.inputs 34 | ~Dataset.outputs 35 | 36 | -------------------------------------------------------------------------------- /docs/generated/trustyai.model.Model.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.Model 2 | ==================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autoclass:: Model 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Model.__init__ 17 | ~Model.equals 18 | ~Model.hashCode 19 | ~Model.predictAsync 20 | ~Model.toString 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/generated/trustyai.model.counterfactual_prediction.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.counterfactual\_prediction 2 | ========================================= 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: counterfactual_prediction -------------------------------------------------------------------------------- /docs/generated/trustyai.model.feature.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.feature 2 | ====================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: feature -------------------------------------------------------------------------------- /docs/generated/trustyai.model.feature_domain.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.feature\_domain 2 | ============================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: feature_domain -------------------------------------------------------------------------------- /docs/generated/trustyai.model.output.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.output 2 | ===================== 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. 
autofunction:: output -------------------------------------------------------------------------------- /docs/generated/trustyai.model.simple_prediction.rst: -------------------------------------------------------------------------------- 1 | trustyai.model.simple\_prediction 2 | ================================= 3 | 4 | .. currentmodule:: trustyai.model 5 | 6 | .. autofunction:: simple_prediction -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. TrustyAI documentation master file, created by 2 | sphinx-quickstart on Tue Jul 12 11:47:01 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to TrustyAI's documentation! 7 | ==================================== 8 | Red Hat's TrustyAI-Python library provides XAI explanations of decision services and 9 | predictive models for both enterprise and data science use-cases. 10 | 11 | This library is designed to provide a set of Python bindings to the main 12 | `TrustyAI Java toolkit `_, to allow 13 | for easier access to the toolkit in data science and prototyping use cases. This means the library 14 | benefits from both the speed of Java as well as the ease-of-use of Python; our whitepaper shows that 15 | the TrustyAI-Python LIME and SHAP explainers can run faster than the official implementations. 16 | 17 | Installation 18 | ============ 19 | ``pip install trustyai`` 20 | 21 | Tutorial and Examples 22 | ===================== 23 | To get started, check out the :ref:`tutorial`. For more usage examples, see the example notebooks: 24 | 25 | * `LIME `_ 26 | * `SHAP `_ 27 | * `Counterfactuals `_ 28 | 29 | GitHub Repos 30 | ============ 31 | * `TrustyAI Python `_ 32 | * `TrustyAI Python Examples `_ 33 | * `TrustyAI Java `_ 34 | 35 | Paper 36 | ===== 37 | `TrustyAI Explainability Toolkit `_, 2022 38 | 39 | Contents 40 | ======== 41 | .. toctree:: 42 | tutorial 43 | api 44 | 45 | Indices and tables 46 | ================== 47 | 48 | * :ref:`genindex` 49 | * :ref:`modindex` 50 | * :ref:`search` 51 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme -------------------------------------------------------------------------------- /info/detoxify.md: -------------------------------------------------------------------------------- 1 | # TrustyAI-Detoxify 2 | Algorithms and tools for detecting and fixing hate speech, abuse and profanity in content generated by Large Language Models (_LLMs_). The source code is located in the [trustyai.language.detoxify](https://github.com/trustyai-explainability/trustyai-explainability-python/tree/main/src/trustyai/language/detoxify) module. 3 | 4 | ## T-MaRCo 5 | 6 | T-MaRCo is an extension of the work [Detoxifying Text with MaRCo: Controllable Revision with Experts and Anti-Experts 7 | ](https://arxiv.org/abs/2212.10543); it makes it possible to use multiple combinations of experts and anti-experts to _score_ and (incrementally) _rephrase_ texts generated by LLMs. 8 | 9 | In addition to that, it can integrate _rephrasing_ with the base model's _self-reflection_ capabilities (see the papers [Towards Mitigating Hallucination in Large Language Models via Self-Reflection 10 | ](https://arxiv.org/abs/2310.06271) and [N-Critics: Self-Refinement of Large Language Models with Ensemble of Critics 11 | ](https://arxiv.org/abs/2310.18679)). 12 | 13 | T-MaRCo hence provides the following features: 14 | * content *scoring*: providing a _disagreement score_ for each input token; high disagreement is often attached to toxic content. 15 | * content *masking*: providing a masked version of the input content, where all tokens that are considered toxic are replaced with a mask token. 16 | * content *redirection*: providing a non-toxic "regenerated" version of the original content. 
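T-MaRCo's heavier dependencies (`transformers`, `torch`, `datasets`, etc.) are packaged as the `detoxify` optional extra in this repository's `pyproject.toml`, so before running the example below they can presumably be installed with:

```shell
pip install trustyai[detoxify]
```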
17 | 18 | How to use T-MaRCo: 19 | ```python 20 | from trustyai.language.detoxify import TMaRCo 21 | 22 | # instantiate T-MaRCo 23 | tmarco = TMaRCo(expert_weights=[-1, 2]) 24 | 25 | # load pretrained anti-expert and expert models 26 | tmarco.load_models(["trustyai/gminus", "trustyai/gplus"]) 27 | 28 | # pick up some text generated by an LLM 29 | text = "Stand by me, just as long as you fucking stand by me" 30 | 31 | # generate T-MaRCo disagreement scores 32 | scores = tmarco.score([text]) # '[0.78664607 0.06622718 0.02403926 5.331921 0.49842355 0.46609956 0.22441313 0.43487906 0.51990145 1.9062967 0.64200985 0.30269763 1.7964466 ]' 33 | 34 | # mask tokens having high disagreement scores 35 | masked_text = tmarco.mask([text], scores=scores) # 'Stand by me just as long as you stand by' 36 | 37 | # rephrase masked tokens 38 | rephrased = tmarco.rephrase([text], [masked_text]) # 'Stand by me and just as long as you want stand by me' 39 | 40 | # combine rephrasing with the base model's self-reflection capabilities 41 | reflected = tmarco.reflect([text]) # '["'Stand by me in the way I want stand by you and in the ways I need you to standby me'."]' 42 | 43 | ``` 44 | 45 | Pretrained T-MaRCo models are available in the [TrustyAI HuggingFace space](https://huggingface.co/trustyai) at https://huggingface.co/trustyai/gminus and https://huggingface.co/trustyai/gplus. 46 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "trustyai" 3 | version = "0.6.1" 4 | description = "Python bindings to the TrustyAI explainability library." 5 | authors = [{ name = "Rui Vieira", email = "rui@redhat.com" }] 6 | license = { text = "Apache License Version 2.0" } 7 | readme = "README.md" 8 | requires-python = ">=3.8" 9 | 10 | keywords = ["trustyai", "xai", "explainability", "ml"] 11 | 12 | classifiers = [ 13 | "License :: OSI Approved :: Apache Software License", 14 | "Development Status :: 4 - Beta", 15 | "Intended Audience :: Developers", 16 | "Intended Audience :: Science/Research", 17 | "Programming Language :: Java", 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 21 | "Topic :: Software Development :: Libraries :: Java Libraries", 22 | ] 23 | 24 | dependencies = [ 25 | "Jpype1==1.5.0", 26 | "pyarrow==17.0.0", 27 | "matplotlib~=3.6.3", 28 | "pandas~=1.5.3", 29 | "numpy~=1.24.1", 30 | "jupyter-bokeh~=3.0.5", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | dev = [ 35 | "JPype1==1.5.0", 36 | "black~=22.12.0", 37 | "click==8.0.4", 38 | "joblib~=1.2.0", 39 | "jupyterlab~=3.5.3", 40 | "numpydoc==1.5.0", 41 | "pylint==2.15.6", 42 | "pytest~=7.2.1", 43 | "pytest-benchmark==4.0.0", 44 | "pytest-forked~=1.6.0", 45 | "scikit-learn~=1.2.1", 46 | "setuptools", 47 | "twine==3.4.2", 48 | "wheel~=0.38.4", 49 | "xgboost==1.4.2", 50 | ] 51 | extras = ["aix360[default,tsice,tslime,tssaliency]==0.3.0"] 52 | 53 | detoxify = [ 54 | "transformers~=4.36.2", 55 | "datasets", 56 | "scipy~=1.12.0", 57 | "torch~=2.2.1", 58 | "iter-tools", 59 | "evaluate", 60 | "trl", 61 | ] 62 | 63 | api = ["kubernetes"] 64 | 65 | [project.urls] 66 | homepage = "https://github.com/trustyai-explainability/trustyai-explainability-python" 67 | documentation = "https://trustyai-explainability-python.readthedocs.io/en/latest/" 68 | repository = 
"https://github.com/trustyai-explainability/trustyai-explainability-python" 69 | 70 | [build-system] 71 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 72 | build-backend = "setuptools.build_meta" 73 | 74 | [tool.setuptools] 75 | package-dir = { "" = "src" } 76 | 77 | [tool.pytest.ini_options] 78 | log_cli = true 79 | addopts = '-m="not block_plots"' 80 | markers = [ 81 | "block_plots: Test plots will block execution of subsequent tests until closed", 82 | ] 83 | 84 | [tool.setuptools.packages.find] 85 | where = ["src"] 86 | 87 | [tool.setuptools_scm] 88 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | JPype1==1.4.1 2 | matplotlib==3.6.3 3 | pandas==1.2.5 4 | pyarrow==14.0.1 5 | jupyter-bokeh~=3.0.5 -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Red Hat, Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | 18 | ROOT_DIR=$(git rev-parse --show-toplevel) 19 | TMP_DIR=$(mktemp -d) 20 | 21 | EXP_CORE="trustyai-explainability" 22 | 23 | EXP_CORE_DEST="${TMP_DIR}/${EXP_CORE}" 24 | if [ ! -d "${EXP_CORE_DEST}" ] 25 | then 26 | echo "Cloning trustyai-explainability into ${EXP_CORE_DEST}" 27 | git clone --branch main https://github.com/${EXP_CORE}/${EXP_CORE}.git "${EXP_CORE_DEST}" 28 | echo "Copying JARs from ${EXP_CORE_DEST} into ${ROOT_DIR}/dep/org/trustyai/" 29 | mvn install package -DskipTests -f "${EXP_CORE_DEST}"/pom.xml -Pshaded 30 | mv "${EXP_CORE_DEST}"/explainability-arrow/target/explainability-arrow-*.jar "${ROOT_DIR}"/src/trustyai/dep/org/trustyai/ 31 | else 32 | echo "Directory ${EXP_CORE_DEST} already exists. Please delete it or move it." 33 | exit 1 34 | fi 35 | 36 | if [[ "$VIRTUAL_ENV" != "" ]] 37 | then 38 | pip install "${ROOT_DIR}" --force 39 | else 40 | echo "Not in a virtualenv. Installation not recommended." 41 | exit 1 42 | fi 43 | -------------------------------------------------------------------------------- /scripts/local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2022 Red Hat, Inc. and/or its affiliates 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | set -e
17 | 
18 | ROOT_DIR=$(git rev-parse --show-toplevel)
19 | 
20 | EXP_CORE_DEST=$1
21 | 
22 | if [[ "$EXP_CORE_DEST" == "" ]]
23 | then
24 | EXP_CORE_DEST="../trustyai-explainability"
25 | echo "No argument provided, building trustyai-explainability from ${EXP_CORE_DEST}"
26 | else
27 | echo "Building trustyai-explainability from ${EXP_CORE_DEST}"
28 | fi
29 | 
30 | echo "Copying JARs from ${EXP_CORE_DEST} into ${ROOT_DIR}/src/trustyai/dep/org/trustyai/"
31 | mvn install package -DskipTests -f "${EXP_CORE_DEST}"/pom.xml -Pshaded
32 | mv "${EXP_CORE_DEST}"/explainability-arrow/target/explainability-arrow-*.jar "${ROOT_DIR}"/src/trustyai/dep/org/trustyai/
33 | 
34 | 
35 | if [[ "$VIRTUAL_ENV" != "" ]]
36 | then
37 | pip install "${ROOT_DIR}" --force
38 | else
39 | echo "Not in a virtualenv. Installation not recommended."
40 | exit 1
41 | fi
42 | 
-------------------------------------------------------------------------------- /src/trustyai/__init__.py: --------------------------------------------------------------------------------
1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value
2 | # pylint: disable = invalid-name, R0801, duplicate-code
3 | """Main TrustyAI Python bindings"""
4 | import os
5 | import logging
6 | 
7 | # flag tracking whether the JVM has been initialized
8 | import warnings
9 | from .version import __version__
10 | 
11 | TRUSTYAI_IS_INITIALIZED = False
12 | 
13 | if os.getenv("PYTHON_TRUSTY_DEBUG") == "1":
14 | _LOGGING_LEVEL = logging.DEBUG
15 | else:
16 | _LOGGING_LEVEL = logging.WARN
17 | 
18 | logging.basicConfig(level=_LOGGING_LEVEL)
19 | 
20 | 
21 | def init():
22 | """Deprecated manual initializer for the JVM. This function has been replaced by
23 | automatic initialization when importing the components of the module that require
24 | JVM access, or by manual user initialization via :func:`trustyai.initializer.init`."""
25 | warnings.warn(
26 | "trustyai.init() is now deprecated; the trustyai library will now "
27 | + "automatically initialize. 
For manual initialization options, see " 28 | + "trustyai.initializer.init()" 29 | ) 30 | -------------------------------------------------------------------------------- /src/trustyai/_default_initializer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value, invalid-name, R0801 2 | """The default initializer""" 3 | import trustyai 4 | from trustyai import initializer # pylint: disable=no-name-in-module 5 | 6 | if not trustyai.TRUSTYAI_IS_INITIALIZED: 7 | trustyai.TRUSTYAI_IS_INITIALIZED = initializer.init() 8 | -------------------------------------------------------------------------------- /src/trustyai/dep/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/src/trustyai/dep/.gitkeep -------------------------------------------------------------------------------- /src/trustyai/dep/org/trustyai/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/src/trustyai/dep/org/trustyai/.gitkeep -------------------------------------------------------------------------------- /src/trustyai/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Explainers module""" 2 | # pylint: disable=duplicate-code 3 | from .counterfactuals import CounterfactualResult, CounterfactualExplainer 4 | from .lime import LimeExplainer, LimeResults 5 | from .shap import SHAPExplainer, SHAPResults, BackgroundGenerator 6 | from .pdp import PDPExplainer 7 | -------------------------------------------------------------------------------- /src/trustyai/explainers/explanation_results.py: -------------------------------------------------------------------------------- 1 | """Generic class for Explanation and Saliency results""" 2 | from abc import ABC, abstractmethod 3 | 4 | import pandas as pd 5 | from pandas.io.formats.style import Styler 6 | 7 | 8 | class ExplanationResults(ABC): 9 | """Abstract class for explanation visualisers""" 10 | 11 | @abstractmethod 12 | def as_dataframe(self) -> pd.DataFrame: 13 | """Display explanation result as a dataframe""" 14 | 15 | @abstractmethod 16 | def as_html(self) -> Styler: 17 | """Visualise the styled dataframe""" 18 | 19 | 20 | # pylint: disable=too-few-public-methods 21 | class SaliencyResults(ExplanationResults): 22 | """Abstract class for saliency visualisers""" 23 | 24 | @abstractmethod 25 | def saliency_map(self): 26 | """Return the Saliencies as a dictionary, keyed by output name""" 27 | -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tsice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSICEExplainer from aix360. 
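TSICE (Time Series Individual Conditional Expectation) perturbs an input window and inspects how the model's forecast responds to each derived feature.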
3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | # pylint: disable=too-many-arguments,import-error 6 | from typing import Callable, List, Optional, Union 7 | 8 | from aix360.algorithms.tsice import TSICEExplainer as TSICEExplainerAIX 9 | from aix360.algorithms.tsutils.tsperturbers import TSPerturber 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | from sklearn.linear_model import LinearRegression 14 | 15 | from trustyai.explainers.explanation_results import ExplanationResults 16 | 17 | 18 | class TSICEResults(ExplanationResults): 19 | """Wraps TSICE results. This object is returned by the :class:`~TSICEExplainer`, 20 | and provides a variety of methods to visualize and interact with the explanation. 21 | """ 22 | 23 | def __init__(self, explanation): 24 | self.explanation = explanation 25 | 26 | def as_dataframe(self) -> pd.DataFrame: 27 | """Returns the explanation as a pandas dataframe.""" 28 | # Initialize an empty DataFrame 29 | dataframe = pd.DataFrame() 30 | 31 | # Loop through each feature_name and each key in data_x 32 | for key in self.explanation["data_x"]: 33 | for i, feature in enumerate(self.explanation["feature_names"]): 34 | dataframe[f"{key}-{feature}"] = [ 35 | val[0] for val in self.explanation["feature_values"][i] 36 | ] 37 | 38 | # Add "total_impact" as a column 39 | dataframe["total_impact"] = self.explanation["total_impact"] 40 | return dataframe 41 | 42 | def as_html(self) -> pd.io.formats.style.Styler: 43 | """Returns the explanation as an HTML table.""" 44 | dataframe = self.as_dataframe() 45 | return dataframe.style 46 | 47 | def plot_forecast(self, variable): # pylint: disable=too-many-locals 48 | """Plots the explanation. 49 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py""" 50 | forecast_horizon = self.explanation["current_forecast"].shape[0] 51 | original_ts = pd.DataFrame( 52 | data={variable: self.explanation["data_x"][variable]} 53 | ) 54 | perturbations = [d for d in self.explanation["perturbations"] if variable in d] 55 | 56 | # Generate a list of keys 57 | keys = list(self.explanation["data_x"].keys()) 58 | # Find the index of the given key 59 | key = keys.index(variable) 60 | forecasts_on_perturbations = [ 61 | arr[:, key : key + 1] 62 | for arr in self.explanation["forecasts_on_perturbations"] 63 | ] 64 | 65 | new_perturbations = [] 66 | new_timestamps = [] 67 | pred_ts = [] 68 | 69 | original_ts.index.freq = pd.infer_freq(original_ts.index) 70 | for i in range(1, forecast_horizon + 1): 71 | new_timestamps.append(original_ts.index[-1] + (i * original_ts.index.freq)) 72 | 73 | for perturbation in perturbations: 74 | new_perturbations.append(pd.DataFrame(perturbation)) 75 | 76 | for forecast in forecasts_on_perturbations: 77 | pred_ts.append(pd.DataFrame(forecast, index=new_timestamps)) 78 | 79 | current_forecast = self.explanation["current_forecast"][:, key : key + 1] 80 | pred_original_ts = pd.DataFrame(current_forecast, index=new_timestamps) 81 | 82 | _, axis = plt.subplots() 83 | 84 | # Plot perturbed time series 85 | axis = self._plot_timeseries( 86 | new_perturbations, 87 | color="lightgreen", 88 | axis=axis, 89 | name="perturbed timeseries samples", 90 | ) 91 | 92 | # Plot original time series 93 | axis = self._plot_timeseries( 94 | original_ts, color="green", axis=axis, name="input/original timeseries" 95 | ) 96 | 97 | # Plot varying forecast range 98 | axis = self._plot_timeseries( 99 | pred_ts, color="lightblue", axis=axis, name="forecast on perturbed 
samples" 100 | ) 101 | 102 | # Plot original forecast 103 | axis = self._plot_timeseries( 104 | pred_original_ts, color="blue", axis=axis, name="original forecast" 105 | ) 106 | 107 | # Set labels and title 108 | axis.set_xlabel("Timestamp") 109 | axis.set_ylabel(variable) 110 | axis.set_title("Time-Series Individual Conditional Expectation (TSICE)") 111 | 112 | axis.legend() 113 | 114 | # Display the plot 115 | plt.show() 116 | 117 | def _plot_timeseries( 118 | self, timeseries, color="green", axis=None, name="time series" 119 | ): 120 | showlegend = True 121 | if isinstance(timeseries, dict): 122 | data = timeseries 123 | if isinstance(color, str): 124 | color = {k: color for k in data} 125 | elif isinstance(timeseries, list): 126 | data = {} 127 | for k, ts_data in enumerate(timeseries): 128 | data[k] = ts_data 129 | if isinstance(color, str): 130 | color = {k: color for k in data} 131 | else: 132 | data = {} 133 | data["default"] = timeseries 134 | color = {"default": color} 135 | 136 | if axis is None: 137 | _, axis = plt.subplots() 138 | 139 | first = True 140 | for key, _timeseries in data.items(): 141 | if not first: 142 | showlegend = False 143 | 144 | self._add_timeseries( 145 | axis, _timeseries, color=color[key], showlegend=showlegend, name=name 146 | ) 147 | first = False 148 | 149 | return axis 150 | 151 | def _add_timeseries( 152 | self, axis, timeseries, color="green", name="time series", showlegend=False 153 | ): 154 | timestamps = timeseries.index 155 | axis.plot( 156 | timestamps, 157 | timeseries[timeseries.columns[0]], 158 | color=color, 159 | label=(name if showlegend else "_nolegend_"), 160 | ) 161 | 162 | def plot_impact(self, feature_per_row=2): 163 | """Plot the impace. 164 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tsice/plots.py""" 165 | 166 | n_row = int(np.ceil(len(self.explanation["feature_names"]) / feature_per_row)) 167 | feat_values = np.array(self.explanation["feature_values"]) 168 | 169 | fig, axs = plt.subplots(n_row, feature_per_row, figsize=(15, 15)) 170 | axs = axs.ravel() # Flatten the axs to iterate over it 171 | 172 | for i, feat in enumerate(self.explanation["feature_names"]): 173 | x_feat = feat_values[i, :, 0] 174 | trend_fit = LinearRegression() 175 | trend_line = trend_fit.fit( 176 | x_feat.reshape(-1, 1), self.explanation["signed_impact"] 177 | ) 178 | x_trend = np.linspace(min(x_feat), max(x_feat), 101) 179 | y_trend = trend_line.predict(x_trend[..., np.newaxis]) 180 | 181 | # Scatter plot 182 | axs[i].scatter(x=x_feat, y=self.explanation["signed_impact"], color="blue") 183 | # Line plot 184 | axs[i].plot( 185 | x_trend, 186 | y_trend, 187 | color="green", 188 | label="correlation between forecast and observed feature", 189 | ) 190 | # Reference line 191 | current_value = self.explanation["current_feature_values"][i][0] 192 | axs[i].axvline( 193 | x=current_value, 194 | color="firebrick", 195 | linestyle="--", 196 | label="current value", 197 | ) 198 | 199 | axs[i].set_xlabel(feat) 200 | axs[i].set_ylabel("Δ forecast") 201 | 202 | # Display the legend on the first subplot 203 | axs[0].legend() 204 | 205 | fig.suptitle("Impact of Derived Variable On The Forecast", fontsize=16) 206 | plt.tight_layout() 207 | plt.subplots_adjust(top=0.95) 208 | plt.show() 209 | 210 | 211 | class TSICEExplainer(TSICEExplainerAIX): 212 | """ 213 | Wrapper for TSICEExplainer from aix360. 
214 | """ 215 | 216 | def __init__( 217 | self, 218 | model: Callable, 219 | input_length: int, 220 | forecast_lookahead: int, 221 | n_variables: int = 1, 222 | n_exogs: int = 0, 223 | n_perturbations: int = 25, 224 | features_to_analyze: Optional[List[str]] = None, 225 | perturbers: Optional[List[Union[TSPerturber, dict]]] = None, 226 | explanation_window_start: Optional[int] = None, 227 | explanation_window_length: int = 10, 228 | ): 229 | super().__init__( 230 | forecaster=model, 231 | input_length=input_length, 232 | forecast_lookahead=forecast_lookahead, 233 | n_variables=n_variables, 234 | n_exogs=n_exogs, 235 | n_perturbations=n_perturbations, 236 | features_to_analyze=features_to_analyze, 237 | perturbers=perturbers, 238 | explanation_window_start=explanation_window_start, 239 | explanation_window_length=explanation_window_length, 240 | ) 241 | 242 | def explain(self, inputs, outputs=None, **kwargs) -> TSICEResults: 243 | """ 244 | Explain the model's prediction on X. 245 | """ 246 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 247 | return TSICEResults(_explanation) 248 | -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tslime.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSLIME from aix360. 3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | 6 | from typing import Callable, List, Union 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from aix360.algorithms.tslime import TSLimeExplainer as TSLimeExplainerAIX 11 | from aix360.algorithms.tslime.surrogate import LinearSurrogateModel 12 | from pandas.io.formats.style import Styler 13 | import matplotlib.pyplot as plt 14 | 15 | from trustyai.explainers.explanation_results import ExplanationResults 16 | from trustyai.utils.extras.timeseries import TSPerturber 17 | 18 | 19 | class TSSLIMEResults(ExplanationResults): 20 | """Wraps TSLimeExplainer results. This object is returned by the :class:`~TSLimeExplainer`, 21 | and provides a variety of methods to visualize and interact with the explanation. 22 | """ 23 | 24 | def __init__(self, explanation): 25 | self.explanation = explanation 26 | 27 | def as_dataframe(self) -> pd.DataFrame: 28 | """Returns the weights as a pandas dataframe.""" 29 | return pd.DataFrame(self.explanation["history_weights"]) 30 | 31 | def as_html(self) -> Styler: 32 | """Returns the explanation as an HTML table.""" 33 | dataframe = self.as_dataframe() 34 | return dataframe.style 35 | 36 | def plot(self): 37 | """Plot TSLime explanation for the time-series instance. 
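The plot overlays the input series with the normalized LIME weights, one bar per timestep.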
Based on 38 | https://github.com/Trusted-AI/AIX360/blob/master/examples/tslime/tslime_univariate_demo.ipynb""" 39 | relevant_history = self.explanation["history_weights"].shape[0] 40 | input_data = self.explanation["input_data"] 41 | relevant_df = input_data[-relevant_history:] 42 | 43 | plt.figure(layout="constrained") 44 | plt.plot(relevant_df, label="Input Time Series", marker="o") 45 | plt.gca().invert_yaxis() 46 | 47 | normalized_weights = ( 48 | self.explanation["history_weights"] 49 | / np.mean(np.abs(self.explanation["history_weights"])) 50 | ).flatten() 51 | 52 | plt.bar( 53 | input_data.index[-relevant_history:], 54 | normalized_weights, 55 | 0.4, 56 | label="TSLime Weights (Normalized)", 57 | color="red", 58 | ) 59 | plt.axhline(y=0, color="r", linestyle="-", alpha=0.4) 60 | plt.title("Time Series Lime Explanation Plot") 61 | plt.legend(bbox_to_anchor=(1.25, 1.0), loc="upper right") 62 | plt.show() 63 | 64 | 65 | class TSLimeExplainer(TSLimeExplainerAIX): 66 | """ 67 | Wrapper for TSLimeExplainer from aix360. 68 | """ 69 | 70 | def __init__( # pylint: disable=too-many-arguments 71 | self, 72 | model: Callable, 73 | input_length: int, 74 | n_perturbations: int = 2000, 75 | relevant_history: int = None, 76 | perturbers: List[Union[TSPerturber, dict]] = None, 77 | local_interpretable_model: LinearSurrogateModel = None, 78 | random_seed: int = None, 79 | ): 80 | super().__init__( 81 | model=model, 82 | input_length=input_length, 83 | n_perturbations=n_perturbations, 84 | relevant_history=relevant_history, 85 | perturbers=perturbers, 86 | local_interpretable_model=local_interpretable_model, 87 | random_seed=random_seed, 88 | ) 89 | 90 | def explain(self, inputs, outputs=None, **kwargs) -> TSSLIMEResults: 91 | """ 92 | Explain the model's prediction on X. 93 | """ 94 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 95 | return TSSLIMEResults(_explanation) 96 | -------------------------------------------------------------------------------- /src/trustyai/explainers/extras/tssaliency.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wrapper module for TSSaliencyExplainer from aix360. 3 | Original at https://github.com/Trusted-AI/AIX360/ 4 | """ 5 | 6 | from typing import Callable, List 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from aix360.algorithms.tssaliency import TSSaliencyExplainer as TSSaliencyExplainerAIX 11 | from pandas.io.formats.style import Styler 12 | import matplotlib.pyplot as plt 13 | 14 | from trustyai.explainers.explanation_results import ExplanationResults 15 | 16 | 17 | class TSSaliencyResults(ExplanationResults): 18 | """Wraps TSSaliency results. This object is returned by the :class:`~TSSaliencyExplainer`, 19 | and provides a variety of methods to visualize and interact with the explanation. 
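
    For example (a sketch; assumes ``results`` was returned by :func:`~TSSaliencyExplainer.explain`):

    .. code-block:: python

        df = results.as_dataframe()      # per-feature saliency scores, as a DataFrame
        results.plot(index=0, cpos=1)    # saliency heatmap over the input series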
20 | """ 21 | 22 | def __init__(self, explanation): 23 | self.explanation = explanation 24 | 25 | def as_dataframe(self) -> pd.DataFrame: 26 | saliencies = self.explanation["saliency"].reshape(-1) 27 | return pd.DataFrame(saliencies, columns=self.explanation["feature_names"]) 28 | 29 | def as_html(self) -> Styler: 30 | """Returns the explanation as an HTML table.""" 31 | dataframe = self.as_dataframe() 32 | return dataframe.style 33 | 34 | def plot(self, index: int, cpos, window: int = None): 35 | """Plot tssaliency explanation for the test point 36 | Based on https://github.com/Trusted-AI/AIX360/blob/master/examples/tssaliency""" 37 | if window: 38 | scores = ( 39 | np.convolve( 40 | self.explanation["saliency"].flatten(), np.ones(window), mode="same" 41 | ) 42 | / window 43 | ) 44 | else: 45 | scores = self.explanation["saliency"] 46 | 47 | vmax = np.max(np.abs(self.explanation["saliency"])) 48 | 49 | plt.figure(layout="constrained") 50 | plt.imshow( 51 | scores[np.newaxis, :], aspect="auto", cmap="seismic", vmin=-vmax, vmax=vmax 52 | ) 53 | plt.colorbar() 54 | plt.plot(self.explanation["input_data"]) 55 | instance = self.explanation["instance_prediction"] 56 | plt.title( 57 | "Time Series Saliency Explanation Plot for test point" 58 | f" i={index} with P(Y={cpos})= {instance}" 59 | ) 60 | plt.show() 61 | 62 | 63 | class TSSaliencyExplainer(TSSaliencyExplainerAIX): 64 | """ 65 | Wrapper for TSSaliencyExplainer from aix360. 66 | """ 67 | 68 | def __init__( # pylint: disable=too-many-arguments 69 | self, 70 | model: Callable, 71 | input_length: int, 72 | feature_names: List[str], 73 | base_value: List[float] = None, 74 | n_samples: int = 50, 75 | gradient_samples: int = 25, 76 | gradient_function: Callable = None, 77 | random_seed: int = 22, 78 | ): 79 | super().__init__( 80 | model=model, 81 | input_length=input_length, 82 | feature_names=feature_names, 83 | base_value=base_value, 84 | n_samples=n_samples, 85 | gradient_samples=gradient_samples, 86 | gradient_function=gradient_function, 87 | random_seed=random_seed, 88 | ) 89 | 90 | def explain(self, inputs, outputs=None, **kwargs) -> TSSaliencyResults: 91 | """ 92 | Explain the model's prediction on X. 
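
        Parameters
        ----------
        inputs:
            the time-series instance to explain
        outputs:
            optional target values, forwarded as ``y`` to the underlying aix360 explainer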
93 | """ 94 | _explanation = super().explain_instance(inputs, y=outputs, **kwargs) 95 | return TSSaliencyResults(_explanation) 96 | -------------------------------------------------------------------------------- /src/trustyai/explainers/pdp.py: -------------------------------------------------------------------------------- 1 | """Explainers.pdp module""" 2 | import math 3 | import pandas as pd 4 | from pandas.io.formats.style import Styler 5 | 6 | from jpype import ( 7 | JImplements, 8 | JOverride, 9 | ) 10 | 11 | # pylint: disable = import-error 12 | from org.kie.trustyai.explainability.global_ import pdp 13 | 14 | # pylint: disable = import-error 15 | from org.kie.trustyai.explainability.model import ( 16 | PredictionProvider, 17 | PredictionInputsDataDistribution, 18 | PredictionOutput, 19 | Output, 20 | Type, 21 | Value, 22 | ) 23 | 24 | from trustyai.utils.data_conversions import ManyInputsUnionType, many_inputs_convert 25 | 26 | from .explanation_results import ExplanationResults 27 | 28 | 29 | class PDPResults(ExplanationResults): 30 | """ 31 | Results class for Partial Dependence Plots 32 | """ 33 | 34 | def __init__(self, pdp_graphs): 35 | self.pdp_graphs = pdp_graphs 36 | 37 | def as_dataframe(self) -> pd.DataFrame: 38 | """ 39 | Returns 40 | ------- 41 | a pd.DataFrame with input values and feature name as 42 | columns and marginal feature outputs as rows 43 | """ 44 | pdp_series_list = [] 45 | for pdp_graph in self.pdp_graphs: 46 | inputs = [self._to_plottable(x) for x in pdp_graph.getX()] 47 | outputs = [self._to_plottable(y) for y in pdp_graph.getY()] 48 | pdp_dict = dict(zip(inputs, outputs)) 49 | pdp_dict["feature"] = "" + str(pdp_graph.getFeature().getName()) 50 | pdp_series = pd.Series(index=inputs + ["feature"], data=pdp_dict) 51 | pdp_series_list.append(pdp_series) 52 | pdp_df = pd.DataFrame(pdp_series_list) 53 | return pdp_df 54 | 55 | def as_html(self) -> Styler: 56 | """ 57 | Returns 58 | ------- 59 | Style object from the PDP pd.DataFrame (see as_dataframe) 60 | """ 61 | return self.as_dataframe().style 62 | 63 | @staticmethod 64 | def _to_plottable(datum: Value): 65 | plottable = datum.asNumber() 66 | if math.isnan(plottable): 67 | plottable = str(datum.asString()) 68 | return plottable 69 | 70 | 71 | # pylint: disable = too-few-public-methods 72 | class PDPExplainer: 73 | """ 74 | Partial Dependence Plot explainer. 
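    Estimates how the model's output changes, on average, as a single feature is varied across
    its range. A minimal sketch (assumes ``model`` is a :class:`PredictionProvider` and ``data``
    a list of prediction inputs):

    .. code-block:: python

        explainer = PDPExplainer()
        results = explainer.explain(model=model, data=data, num_outputs=1)
        results.as_dataframe()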
75 | See https://christophm.github.io/interpretable-ml-book/pdp.html 76 | """ 77 | 78 | def __init__(self, config=None): 79 | if config is None: 80 | config = pdp.PartialDependencePlotConfig() 81 | self._explainer = pdp.PartialDependencePlotExplainer(config) 82 | 83 | def explain( 84 | self, model: PredictionProvider, data: ManyInputsUnionType, num_outputs: int = 1 85 | ) -> PDPResults: 86 | """ 87 | Parameters 88 | ---------- 89 | model: PredictionProvider 90 | the model to explain 91 | data: ManyInputsUnionType 92 | the data used to calculate the PDP 93 | num_outputs: int 94 | the number of outputs to calculate the PDP for 95 | 96 | Returns 97 | ------- 98 | pdp_results: PDPResults 99 | the partial dependence plots associated to the model outputs 100 | """ 101 | metadata = _PredictionProviderMetadata(many_inputs_convert(data), num_outputs) 102 | pdp_graphs = self._explainer.explainFromMetadata(model, metadata) 103 | return PDPResults(pdp_graphs) 104 | 105 | 106 | @JImplements( 107 | "org.kie.trustyai.explainability.model.PredictionProviderMetadata", deferred=True 108 | ) 109 | class _PredictionProviderMetadata: 110 | """ 111 | Implementation of org.kie.trustyai.explainability.model.PredictionProviderMetadata interface 112 | """ 113 | 114 | def __init__(self, data: list, size: int): 115 | """ 116 | Parameters 117 | ---------- 118 | data: ManyInputsUnionType 119 | the data 120 | size: int 121 | the size of the model output 122 | """ 123 | self.data = PredictionInputsDataDistribution(data) 124 | outputs = [] 125 | for _ in range(size): 126 | outputs.append(Output("", Type.UNDEFINED)) 127 | self.pred_out = PredictionOutput(outputs) 128 | 129 | # pylint: disable = invalid-name 130 | @JOverride 131 | def getDataDistribution(self): 132 | """ 133 | Returns 134 | -------- 135 | the underlying data distribution 136 | """ 137 | return self.data 138 | 139 | # pylint: disable = invalid-name 140 | @JOverride 141 | def getInputShape(self): 142 | """ 143 | Returns 144 | -------- 145 | a PredictionInput from the underlying distribution 146 | """ 147 | return self.data.sample() 148 | 149 | # pylint: disable = invalid-name, missing-final-newline 150 | @JOverride 151 | def getOutputShape(self): 152 | """ 153 | Returns 154 | -------- 155 | a PredictionOutput 156 | """ 157 | return self.pred_out 158 | -------------------------------------------------------------------------------- /src/trustyai/initializer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, import-outside-toplevel, dangerous-default-value, invalid-name, R0801 2 | # pylint: disable = deprecated-module 3 | """Main TrustyAI Python bindings""" 4 | from distutils.sysconfig import get_python_lib 5 | import glob 6 | import logging 7 | import os 8 | from pathlib import Path 9 | import site 10 | from typing import List 11 | import uuid 12 | import warnings 13 | 14 | import jpype 15 | import jpype.imports 16 | from jpype import _jcustomizer, _jclass 17 | 18 | DEFAULT_ARGS = ( 19 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 20 | # see https://arrow.apache.org/docs/java/install.html#java-compatibility 21 | "-Dorg.slf4j.simpleLogger.defaultLogLevel=error", 22 | ) 23 | 24 | 25 | def _get_default_path(): 26 | try: 27 | default_dep_path = os.path.join(site.getsitepackages()[0], "trustyai", "dep") 28 | except AttributeError: 29 | default_dep_path = os.path.join(get_python_lib(), "trustyai", "dep") 30 | 31 | core_deps = [ 32 | 
f"{default_dep_path}/org/trustyai/explainability-arrow-999-SNAPSHOT.jar", 33 | ] 34 | 35 | return core_deps, default_dep_path 36 | 37 | 38 | def init(*args, path=None): 39 | """init(*args, path=JAVA_DEPENDENCIES) 40 | 41 | Manually initialize the JVM. If you would like to manually specify the Java libraries to be 42 | imported, for example if you want to use a different version of the Trusty Explainability 43 | library than is bundled by default, you can do so by calling :func:`init`. If this is not 44 | manually called, trustyai will use the default set of libraries and automatically initialize 45 | itself when necessary. 46 | 47 | Parameters 48 | ---------- 49 | args: list 50 | List of args to be passed to ``jpype.startJVM``. See the 51 | `JPype manual `_ 52 | for more details. 53 | path: list[str] 54 | List of jar files to add the Java class path. By default, this will add the necessary 55 | dependencies of the TrustyAI Java library. 56 | """ 57 | # Launch the JVM 58 | try: 59 | # get default dependencies 60 | if path is None: 61 | path, default_dep_path = _get_default_path() 62 | logging.debug("Checking for dependencies in %s", default_dep_path) 63 | 64 | # check the classpath 65 | for jar_path in path: 66 | if "*" not in jar_path: 67 | jar_path_exists = Path(jar_path).exists() 68 | else: 69 | jar_path_exists = any( 70 | Path(fp).exists() for fp in glob.glob(jar_path) if ".jar" in fp 71 | ) 72 | if jar_path_exists: 73 | logging.debug("JAR %s found.", jar_path) 74 | else: 75 | logging.error("JAR %s not found.", jar_path) 76 | 77 | _args = args + DEFAULT_ARGS 78 | jpype.startJVM(*_args, classpath=path) 79 | 80 | from java.lang import Thread 81 | 82 | if not Thread.isAttached: 83 | jpype.attachThreadToJVM() 84 | 85 | from java.util import UUID 86 | 87 | @_jcustomizer.JConversion("java.util.List", exact=List) 88 | def _JListConvert(_, py_list: List): 89 | return _jclass.JClass("java.util.Arrays").asList(py_list) 90 | 91 | @_jcustomizer.JConversion("java.util.UUID", instanceof=uuid.UUID) 92 | def _JUUIDConvert(_, obj): 93 | return UUID.fromString(str(obj)) 94 | 95 | except OSError: 96 | print("JVM already initialized") 97 | warnings.warn("JVM already initialized") 98 | 99 | return True 100 | -------------------------------------------------------------------------------- /src/trustyai/language/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/src/trustyai/language/__init__.py -------------------------------------------------------------------------------- /src/trustyai/language/detoxify/__init__.py: -------------------------------------------------------------------------------- 1 | """Language detoxification module.""" 2 | from trustyai.language.detoxify.tmarco import TMaRCo 3 | -------------------------------------------------------------------------------- /src/trustyai/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/src/trustyai/local/__init__.py -------------------------------------------------------------------------------- /src/trustyai/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, invalid-name, wrong-import-order, no-name-in-module 2 | 
"""General model classes""" 3 | from trustyai import _default_initializer # pylint: disable=unused-import 4 | from org.kie.trustyai.metrics.explainability import ( 5 | ExplainabilityMetrics as _ExplainabilityMetrics, 6 | ) 7 | 8 | ExplainabilityMetrics = _ExplainabilityMetrics 9 | -------------------------------------------------------------------------------- /src/trustyai/metrics/distance.py: -------------------------------------------------------------------------------- 1 | """"Distance metrics""" 2 | # pylint: disable = import-error 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Union, Callable 5 | 6 | from org.kie.trustyai.metrics.language.distance import ( 7 | Levenshtein as _Levenshtein, 8 | LevenshteinResult as _LevenshteinResult, 9 | LevenshteinCounters as _LevenshteinCounters, 10 | ) 11 | from opennlp.tools.tokenize import Tokenizer 12 | import numpy as np 13 | from trustyai import _default_initializer # pylint: disable=unused-import 14 | 15 | 16 | @dataclass 17 | class LevenshteinCounters: 18 | """LevenshteinCounters Counters""" 19 | 20 | substitutions: int 21 | insertions: int 22 | deletions: int 23 | correct: int 24 | 25 | @staticmethod 26 | def convert(result: _LevenshteinCounters): 27 | """Converts a Java LevenshteinCounters to a Python LevenshteinCounters""" 28 | return LevenshteinCounters( 29 | substitutions=result.getSubstitutions(), 30 | insertions=result.getInsertions(), 31 | deletions=result.getDeletions(), 32 | correct=result.getCorrect(), 33 | ) 34 | 35 | 36 | @dataclass 37 | class LevenshteinResult: 38 | """Levenshtein Result""" 39 | 40 | distance: float 41 | counters: LevenshteinCounters 42 | matrix: np.ndarray 43 | reference: List[str] 44 | hypothesis: List[str] 45 | 46 | @staticmethod 47 | def convert(result: _LevenshteinResult): 48 | """Converts a Java LevenshteinResult to a Python LevenshteinResult""" 49 | distance = result.getDistance() 50 | counters = LevenshteinCounters.convert(result.getCounters()) 51 | data = result.getDistanceMatrix().getData() 52 | numpy_array = np.array(data)[1:, 1:] 53 | reference = result.getReferenceTokens() 54 | hypothesis = result.getHypothesisTokens() 55 | 56 | return LevenshteinResult( 57 | distance=distance, 58 | counters=counters, 59 | matrix=numpy_array, 60 | reference=reference, 61 | hypothesis=hypothesis, 62 | ) 63 | 64 | 65 | def levenshtein( 66 | reference: str, 67 | hypothesis: str, 68 | tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None, 69 | ) -> LevenshteinResult: 70 | """Calculate Levenshtein distance between two strings""" 71 | if not tokenizer: 72 | return LevenshteinResult.convert( 73 | _Levenshtein.calculateToken(reference, hypothesis) 74 | ) 75 | if isinstance(tokenizer, Tokenizer): 76 | return LevenshteinResult.convert( 77 | _Levenshtein.calculateToken(reference, hypothesis, tokenizer) 78 | ) 79 | if callable(tokenizer): 80 | tokenized_reference = tokenizer(reference) 81 | tokenized_hypothesis = tokenizer(hypothesis) 82 | return LevenshteinResult.convert( 83 | _Levenshtein.calculateToken(tokenized_reference, tokenized_hypothesis) 84 | ) 85 | 86 | raise ValueError("Unsupported tokenizer") 87 | -------------------------------------------------------------------------------- /src/trustyai/metrics/fairness/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/src/trustyai/metrics/fairness/__init__.py
-------------------------------------------------------------------------------- /src/trustyai/metrics/fairness/group.py: --------------------------------------------------------------------------------
1 | """Group fairness metrics"""
2 | # pylint: disable = import-error
3 | from typing import List, Optional, Any, Union
4 | 
5 | import numpy as np
6 | import pandas as pd
7 | from jpype import JInt
8 | from org.kie.trustyai.metrics.fairness.group import (
9 | DisparateImpactRatio,
10 | GroupStatisticalParityDifference,
11 | GroupAverageOddsDifference,
12 | GroupAveragePredictiveValueDifference,
13 | )
14 | 
15 | from trustyai.model import Value, PredictionProvider, Model
16 | from trustyai.utils.data_conversions import (
17 | OneOutputUnionType,
18 | one_output_convert,
19 | to_trusty_dataframe,
20 | )
21 | 
22 | ColumSelector = Union[List[int], List[str]]
23 | 
24 | 
25 | def _column_selector_to_index(columns: ColumSelector, dataframe: pd.DataFrame):
26 | """Convert a list of column names or indices into a list of Java column indices"""
27 | if len(columns) == 0:
28 | raise ValueError("Must specify at least one column")
29 | 
30 | if isinstance(columns[0], str): # passing column names
31 | columns = dataframe.columns.get_indexer(columns)
32 | indices = [JInt(c) for c in columns] # Java casting
33 | return indices
34 | 
35 | 
36 | def statistical_parity_difference(
37 | privileged: Union[pd.DataFrame, np.ndarray],
38 | unprivileged: Union[pd.DataFrame, np.ndarray],
39 | favorable: OneOutputUnionType,
40 | outputs: Optional[List[int]] = None,
41 | feature_names: Optional[List[str]] = None,
42 | ) -> float:
43 | """Calculate Statistical Parity Difference between privileged and unprivileged dataframes"""
44 | favorable_prediction_object = one_output_convert(favorable)
45 | return GroupStatisticalParityDifference.calculate(
46 | to_trusty_dataframe(
47 | data=privileged, outputs=outputs, feature_names=feature_names
48 | ),
49 | to_trusty_dataframe(
50 | data=unprivileged, outputs=outputs, feature_names=feature_names
51 | ),
52 | favorable_prediction_object.outputs,
53 | )
54 | 
55 | 
56 | # pylint: disable = line-too-long, too-many-arguments
57 | def statistical_parity_difference_model(
58 | samples: Union[pd.DataFrame, np.ndarray],
59 | model: Union[PredictionProvider, Model],
60 | privilege_columns: ColumSelector,
61 | privilege_values: List[Any],
62 | favorable: OneOutputUnionType,
63 | feature_names: Optional[List[str]] = None,
64 | ) -> float:
65 | """Calculate Statistical Parity Difference using a samples dataframe and a model"""
66 | favorable_prediction_object = one_output_convert(favorable)
67 | _privilege_values = [Value(v) for v in privilege_values]
68 | _jsamples = to_trusty_dataframe(
69 | data=samples, no_outputs=True, feature_names=feature_names
70 | )
71 | return GroupStatisticalParityDifference.calculate(
72 | _jsamples,
73 | model,
74 | _column_selector_to_index(privilege_columns, samples),
75 | _privilege_values,
76 | favorable_prediction_object.outputs,
77 | )
78 | 
79 | 
80 | def disparate_impact_ratio(
81 | privileged: Union[pd.DataFrame, np.ndarray],
82 | unprivileged: Union[pd.DataFrame, np.ndarray],
83 | favorable: OneOutputUnionType,
84 | outputs: Optional[List[int]] = None,
85 | feature_names: Optional[List[str]] = None,
86 | ) -> float:
87 | """Calculate Disparate Impact Ratio between privileged and 
unprivileged dataframes"""
88 | favorable_prediction_object = one_output_convert(favorable)
89 | return DisparateImpactRatio.calculate(
90 | to_trusty_dataframe(
91 | data=privileged, outputs=outputs, feature_names=feature_names
92 | ),
93 | to_trusty_dataframe(
94 | data=unprivileged, outputs=outputs, feature_names=feature_names
95 | ),
96 | favorable_prediction_object.outputs,
97 | )
98 | 
99 | 
100 | # pylint: disable = line-too-long
101 | def disparate_impact_ratio_model(
102 | samples: Union[pd.DataFrame, np.ndarray],
103 | model: Union[PredictionProvider, Model],
104 | privilege_columns: ColumSelector,
105 | privilege_values: List[Any],
106 | favorable: OneOutputUnionType,
107 | feature_names: Optional[List[str]] = None,
108 | ) -> float:
109 | """Calculate Disparate Impact Ratio using a samples dataframe and a model"""
110 | favorable_prediction_object = one_output_convert(favorable)
111 | _privilege_values = [Value(v) for v in privilege_values]
112 | _jsamples = to_trusty_dataframe(
113 | data=samples, no_outputs=True, feature_names=feature_names
114 | )
115 | return DisparateImpactRatio.calculate(
116 | _jsamples,
117 | model,
118 | _column_selector_to_index(privilege_columns, samples),
119 | _privilege_values,
120 | favorable_prediction_object.outputs,
121 | )
122 | 
123 | 
124 | # pylint: disable = too-many-arguments
125 | def average_odds_difference(
126 | test: Union[pd.DataFrame, np.ndarray],
127 | truth: Union[pd.DataFrame, np.ndarray],
128 | privilege_columns: ColumSelector,
129 | privilege_values: OneOutputUnionType,
130 | positive_class: List[Any],
131 | outputs: Optional[List[int]] = None,
132 | feature_names: Optional[List[str]] = None,
133 | ) -> float:
134 | """Calculate Average Odds Difference between two dataframes"""
135 | if test.shape != truth.shape:
136 | raise ValueError(
137 | f"Dataframes have different shapes ({test.shape} and {truth.shape})"
138 | )
139 | _privilege_values = [Value(v) for v in privilege_values]
140 | _positive_class = [Value(v) for v in positive_class]
141 | # determine privileged columns
142 | _privilege_columns = _column_selector_to_index(privilege_columns, test)
143 | return GroupAverageOddsDifference.calculate(
144 | to_trusty_dataframe(data=test, outputs=outputs, feature_names=feature_names),
145 | to_trusty_dataframe(data=truth, outputs=outputs, feature_names=feature_names),
146 | _privilege_columns,
147 | _privilege_values,
148 | _positive_class,
149 | )
150 | 
151 | 
152 | def average_odds_difference_model(
153 | samples: Union[pd.DataFrame, np.ndarray],
154 | model: Union[PredictionProvider, Model],
155 | privilege_columns: ColumSelector,
156 | privilege_values: List[Any],
157 | positive_class: List[Any],
158 | feature_names: Optional[List[str]] = None,
159 | ) -> float:
160 | """Calculate Average Odds Difference for a sample dataframe using the provided model"""
161 | _jsamples = to_trusty_dataframe(
162 | data=samples, no_outputs=True, feature_names=feature_names
163 | )
164 | _privilege_values = [Value(v) for v in privilege_values]
165 | _positive_class = [Value(v) for v in positive_class]
166 | # determine privileged columns
167 | _privilege_columns = _column_selector_to_index(privilege_columns, samples)
168 | return GroupAverageOddsDifference.calculate(
169 | _jsamples, model, _privilege_columns, _privilege_values, _positive_class
170 | )
171 | 
172 | 
173 | def average_predictive_value_difference(
174 | test: Union[pd.DataFrame, np.ndarray],
175 | truth: Union[pd.DataFrame, np.ndarray],
176 | privilege_columns: ColumSelector,
177 | privilege_values: 
List[Any],
178 | positive_class: List[Any],
179 | outputs: Optional[List[int]] = None,
180 | feature_names: Optional[List[str]] = None,
181 | ) -> float:
182 | """Calculate Average Predictive Value Difference between two dataframes"""
183 | if test.shape != truth.shape:
184 | raise ValueError(
185 | f"Dataframes have different shapes ({test.shape} and {truth.shape})"
186 | )
187 | _privilege_values = [Value(v) for v in privilege_values]
188 | _positive_class = [Value(v) for v in positive_class]
189 | _privilege_columns = _column_selector_to_index(privilege_columns, test)
190 | return GroupAveragePredictiveValueDifference.calculate(
191 | to_trusty_dataframe(data=test, outputs=outputs, feature_names=feature_names),
192 | to_trusty_dataframe(data=truth, outputs=outputs, feature_names=feature_names),
193 | _privilege_columns,
194 | _privilege_values,
195 | _positive_class,
196 | )
197 | 
198 | 
199 | # pylint: disable = line-too-long
200 | def average_predictive_value_difference_model(
201 | samples: Union[pd.DataFrame, np.ndarray],
202 | model: Union[PredictionProvider, Model],
203 | privilege_columns: ColumSelector,
204 | privilege_values: List[Any],
205 | positive_class: List[Any],
206 | ) -> float:
207 | """Calculate Average Predictive Value Difference for a sample dataframe using the provided model"""
208 | _jsamples = to_trusty_dataframe(samples, no_outputs=True)
209 | _privilege_values = [Value(v) for v in privilege_values]
210 | _positive_class = [Value(v) for v in positive_class]
211 | # determine privileged columns
212 | _privilege_columns = _column_selector_to_index(privilege_columns, samples)
213 | return GroupAveragePredictiveValueDifference.calculate(
214 | _jsamples, model, _privilege_columns, _privilege_values, _positive_class
215 | )
216 | 
-------------------------------------------------------------------------------- /src/trustyai/metrics/language.py: --------------------------------------------------------------------------------
1 | """Language metrics"""
2 | # pylint: disable = import-error
3 | from dataclasses import dataclass
4 | 
5 | from typing import List, Optional, Union, Callable
6 | 
7 | from org.kie.trustyai.metrics.language.levenshtein import (
8 | WordErrorRate as _WordErrorRate,
9 | ErrorRateResult as _ErrorRateResult,
10 | )
11 | from opennlp.tools.tokenize import Tokenizer
12 | from trustyai import _default_initializer # pylint: disable=unused-import
13 | 
14 | from .distance import LevenshteinCounters
15 | 
16 | 
17 | @dataclass
18 | class ErrorRateResult:
19 | """Word Error Rate Result"""
20 | 
21 | value: float
22 | alignment_counters: LevenshteinCounters
23 | 
24 | @staticmethod
25 | def convert(result: _ErrorRateResult):
26 | """Converts a Java ErrorRateResult to a Python ErrorRateResult"""
27 | value = result.getValue()
28 | alignment_counters = result.getAlignmentCounters()
29 | return ErrorRateResult(
30 | value=value,
31 | alignment_counters=alignment_counters,
32 | )
33 | 
34 | 
35 | def word_error_rate(
36 | reference: str,
37 | hypothesis: str,
38 | tokenizer: Optional[Union[Tokenizer, Callable[[str], List[str]]]] = None,
39 | ) -> ErrorRateResult:
40 | """Calculate Word Error Rate between reference and hypothesis strings"""
41 | if not tokenizer:
42 | _wer = _WordErrorRate()
43 | elif isinstance(tokenizer, Tokenizer):
44 | _wer = _WordErrorRate(tokenizer)
45 | elif callable(tokenizer):
46 | tokenized_reference = tokenizer(reference)
47 | tokenized_hypothesis = tokenizer(hypothesis)
48 | _wer = _WordErrorRate()
49 | return ErrorRateResult.convert(
50 | 
_wer.calculate(tokenized_reference, tokenized_hypothesis) 51 | ) 52 | else: 53 | raise ValueError("Unsupported tokenizer") 54 | return ErrorRateResult.convert(_wer.calculate(reference, hypothesis)) 55 | -------------------------------------------------------------------------------- /src/trustyai/metrics/saliency.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error 2 | """Saliency evaluation metrics""" 3 | from typing import Union 4 | 5 | from org.apache.commons.lang3.tuple import ( 6 | Pair as _Pair, 7 | ) 8 | 9 | from org.kie.trustyai.explainability.model import ( 10 | PredictionInput, 11 | PredictionInputsDataDistribution, 12 | ) 13 | from org.kie.trustyai.explainability.local import LocalExplainer 14 | 15 | from jpype import JObject 16 | 17 | from trustyai.model import simple_prediction, PredictionProvider 18 | from trustyai.explainers import SHAPExplainer, LimeExplainer 19 | 20 | from . import ExplainabilityMetrics 21 | 22 | 23 | def impact_score( 24 | model: PredictionProvider, 25 | pred_input: PredictionInput, 26 | explainer: Union[LimeExplainer, SHAPExplainer], 27 | k: int, 28 | is_model_callable: bool = False, 29 | ): 30 | """ 31 | Parameters 32 | ---------- 33 | model: trustyai.PredictionProvider 34 | the model used to generate predictions 35 | pred_input: trustyai.PredictionInput 36 | the input to the model 37 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 38 | the explainer to evaluate 39 | k: int 40 | the number of top important features 41 | is_model_callable: bool 42 | whether to directly use model function call or use the predict method 43 | 44 | Returns 45 | ------- 46 | :float: 47 | impact score metric 48 | """ 49 | if is_model_callable: 50 | output = model(pred_input) 51 | else: 52 | output = model.predict([pred_input])[0].outputs 53 | pred = simple_prediction(pred_input, output) 54 | explanation = explainer.explain(inputs=pred_input, outputs=output, model=model) 55 | saliency = list(explanation.saliency_map().values())[0] 56 | top_k_features = saliency.getTopFeatures(k) 57 | return ExplainabilityMetrics.impactScore(model, pred, top_k_features) 58 | 59 | 60 | def mean_impact_score( 61 | explainer: Union[LimeExplainer, SHAPExplainer], 62 | model: PredictionProvider, 63 | data: list, 64 | is_model_callable=False, 65 | k=2, 66 | ): 67 | """ 68 | Parameters 69 | ---------- 70 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 71 | the explainer to evaluate 72 | model: trustyai.PredictionProvider 73 | the model used to generate predictions 74 | data: list[list[trustyai.model.Feature]] 75 | the inputs to calculate the metric for 76 | is_model_callable: bool 77 | whether to directly use model function call or use the predict method 78 | k: int 79 | the number of top important features 80 | 81 | Returns 82 | ------- 83 | :float: 84 | the mean impact score metric across all inputs 85 | """ 86 | m_is = 0 87 | for features in data: 88 | m_is += impact_score( 89 | model, features, explainer, k, is_model_callable=is_model_callable 90 | ) 91 | return m_is / len(data) 92 | 93 | 94 | def classification_fidelity( 95 | explainer: Union[LimeExplainer, SHAPExplainer], 96 | model: PredictionProvider, 97 | inputs: list, 98 | is_model_callable: bool = False, 99 | ): 100 | """ 101 | Parameters 102 | ---------- 103 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer] 104 | the explainer to evaluate 105 | model: 
trustyai.PredictionProvider
106 | the model used to generate predictions
107 | inputs: list[list[trustyai.model.Feature]]
108 | the inputs to calculate the metric for
109 | is_model_callable: bool
110 | whether to directly use model function call or use the predict method
111 | 
112 | Returns
113 | -------
114 | :float:
115 | the classification fidelity metric
116 | """
117 | pairs = []
118 | for c_input in inputs:
119 | if is_model_callable:
120 | output = model(c_input)
121 | else:
122 | output = model.predict([c_input])[0].outputs
123 | explanation = explainer.explain(inputs=c_input, outputs=output, model=model)
124 | saliency = list(explanation.saliency_map().values())[0]
125 | pairs.append(_Pair.of(saliency, simple_prediction(c_input, output)))
126 | return ExplainabilityMetrics.classificationFidelity(pairs)
127 | 
128 | 
129 | # pylint: disable = too-many-arguments
130 | def local_saliency_f1(
131 | output_name: str,
132 | model: PredictionProvider,
133 | explainer: Union[LimeExplainer, SHAPExplainer],
134 | distribution: PredictionInputsDataDistribution,
135 | k: int,
136 | chunk_size: int,
137 | ):
138 | """
139 | Parameters
140 | ----------
141 | output_name: str
142 | the name of the output to calculate the metric for
143 | model: trustyai.PredictionProvider
144 | the model used to generate predictions
145 | explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer,
146 | trustyai.explainers.LocalExplainer]
147 | the explainer to evaluate
148 | distribution: org.kie.trustyai.explainability.model.PredictionInputsDataDistribution
149 | the data distribution to fetch the inputs from
150 | k: int
151 | the number of top important features
152 | chunk_size: int
153 | the chunk of inputs to fetch from the distribution
154 | 
155 | Returns
156 | -------
157 | :float:
158 | the local saliency f1 metric
159 | """
160 | if not isinstance(explainer, LocalExplainer):
161 | # pylint: disable = protected-access
162 | local_explainer = JObject(explainer._explainer, LocalExplainer)
163 | else:
164 | local_explainer = explainer
165 | return ExplainabilityMetrics.getLocalSaliencyF1(
166 | output_name, model, local_explainer, distribution, k, chunk_size
167 | )
168 | 
-------------------------------------------------------------------------------- /src/trustyai/model/domain.py: --------------------------------------------------------------------------------
1 | # pylint: disable = import-error
2 | """Conversion method between Python and TrustyAI Java types"""
3 | from typing import Optional, Tuple, List, Union
4 | 
5 | from jpype import _jclass
6 | 
7 | from org.kie.trustyai.explainability.model.domain import (
8 | FeatureDomain,
9 | NumericalFeatureDomain,
10 | CategoricalFeatureDomain,
11 | CategoricalNumericalFeatureDomain,
12 | ObjectFeatureDomain,
13 | EmptyFeatureDomain,
14 | )
15 | 
16 | 
17 | def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDomain]:
18 | r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a
19 | particular feature, which is useful when constraining a counterfactual explanation to ensure it
20 | only recovers valid inputs. For example, if we had a feature that described a person's age, we
21 | might want to constrain it to the range [0, 125] to ensure the counterfactual explanation
22 | doesn't return unlikely ages such as -5 or 715.
23 | 
24 | Parameters
25 | ----------
26 | values : Optional[Union[Tuple, List]]
27 | The valid values of the feature. 
If ``values`` takes the form of: 28 | 29 | * **A tuple of floats or integers**: The feature domain will be a continuous range from 30 | ``values[0]`` to ``values[1]``. 31 | * **A list of floats or integers**: The feature domain will be a *numeric* categorical, 32 | where `values` contains all possible valid feature values. 33 | * **A list of strings**: The feature domain will be a *string* categorical, where ``values`` 34 | contains all possible valid feature values. 35 | * **A list of objects**: The feature domain will be an *object* categorical, where 36 | ``values`` contains all possible valid feature values. These may present an issue if the 37 | objects are not natively Java serializable. 38 | 39 | Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held 40 | fixed during the counterfactual explanation. 41 | 42 | Returns 43 | ------- 44 | :class:`FeatureDomain` 45 | A Java :class:`FeatureDomain` object, to be used in the :func:`~trustyai.model.feature` 46 | function. 47 | 48 | """ 49 | if not values: 50 | domain = EmptyFeatureDomain.create() 51 | else: 52 | if isinstance(values, tuple): 53 | assert isinstance(values[0], (float, int)) and isinstance( 54 | values[1], (float, int) 55 | ) 56 | assert len(values) == 2, ( 57 | "Tuples passed as domain values must only contain" 58 | " two values that define the (minimum, maximum) of the domain" 59 | ) 60 | domain = NumericalFeatureDomain.create(values[0], values[1]) 61 | 62 | elif isinstance(values, list): 63 | java_array = _jclass.JClass("java.util.Arrays").asList(values) 64 | if isinstance(values[0], bool) and isinstance(values[1], bool): 65 | domain = ObjectFeatureDomain.create(java_array) 66 | elif isinstance(values[0], (float, int)) and isinstance( 67 | values[1], (float, int) 68 | ): 69 | domain = CategoricalNumericalFeatureDomain.create(java_array) 70 | elif isinstance(values[0], str): 71 | domain = CategoricalFeatureDomain.create(java_array) 72 | else: 73 | domain = ObjectFeatureDomain.create(java_array) 74 | 75 | else: 76 | domain = EmptyFeatureDomain.create() 77 | return domain 78 | -------------------------------------------------------------------------------- /src/trustyai/utils/DataUtils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = invalid-name, import-error 2 | """DataUtils module""" 3 | from org.kie.trustyai.explainability.utils import DataUtils as du 4 | 5 | getMean = du.getMean 6 | getStdDev = du.getStdDev 7 | gaussianKernel = du.gaussianKernel 8 | euclideanDistance = du.euclideanDistance 9 | hammingDistance = du.hammingDistance 10 | doublesToFeatures = du.doublesToFeatures 11 | exponentialSmoothingKernel = du.exponentialSmoothingKernel 12 | generateRandomDataDistribution = du.generateRandomDataDistribution 13 | 14 | 15 | def generateData(mean, stdDeviation, size, jrandom): 16 | """Generate data""" 17 | return list(du.generateData(mean, stdDeviation, size, jrandom)) 18 | 19 | 20 | def perturbFeatures(originalFeatures, perturbationContext): 21 | """Perform perturbations on a fixed number of features in the given input.""" 22 | return du.perturbFeatures(originalFeatures, perturbationContext) 23 | 24 | 25 | def getLinearizedFeatures(originalFeatures): 26 | """Transform a list of eventually composite / nested features into a 27 | flat list of non composite / non nested features.""" 28 | return du.getLinearizedFeatures(originalFeatures) 29 | 30 | 31 | def sampleWithReplacement(values, sampleSize, jrandom): 32 | """Sample (with 
replacement) from a list of values.""" 33 | return du.sampleWithReplacement(values, sampleSize, jrandom) 34 | -------------------------------------------------------------------------------- /src/trustyai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = import-error, invalid-name, wrong-import-order 2 | """General model classes""" 3 | 4 | from jpype._jproxy import _createJProxy, _createJProxyDeferred 5 | from trustyai import _default_initializer 6 | 7 | from org.kie.trustyai.explainability import Config as _Config 8 | from org.kie.trustyai.explainability.utils.models import TestModels as _TestModels 9 | 10 | TestModels = _TestModels 11 | Config = _Config 12 | 13 | 14 | def JImplementsWithDocstring(*interfaces, deferred=False, **kwargs): 15 | """JPype's JImplements decorator overwrites the docstring of any annotated functions. This 16 | is a quick hack to preserve docstrings across the jproxy process.""" 17 | if deferred: 18 | 19 | def JProxyCreator(cls): 20 | proxy_class = _createJProxyDeferred(cls, *interfaces, **kwargs) 21 | proxy_class.__doc__ = cls.__doc__ 22 | proxy_class.__name__ = cls.__name__ 23 | return proxy_class 24 | 25 | else: 26 | 27 | def JProxyCreator(cls): 28 | proxy_class = _createJProxy(cls, *interfaces, **kwargs) 29 | proxy_class.__doc__ = cls.__doc__ 30 | proxy_class.__name__ = cls.__name__ 31 | return proxy_class 32 | 33 | return JProxyCreator 34 | -------------------------------------------------------------------------------- /src/trustyai/utils/_tyrus_info_text.py: -------------------------------------------------------------------------------- 1 | # pylint: disable = consider-using-f-string 2 | """Info text used in Tyrus visualization explainer info""" 3 | from trustyai.utils._visualisation import bold_red_html, bold_green_html 4 | 5 | LIME_TEXT = """ 6 |
7 | What is LIME?
8 | 
9 | 
10 | LIME (Local Interpretable Model-agnostic Explanations) explanations answer the following question:
11 | "Which features were most important to the predicted {{0}}?"
12 | LIME does this by providing per-feature saliencies, numeric weights that describe how strongly each feature contributed to the model’s output.
13 | 
14 | 
15 | 
16 | In this plot, each horizontal bar represents a feature's saliency: features with positive importance to the predicted {{0}} are marked in {}, while
17 | features with negative importance are marked in {}. The larger the bar, the more important the feature was to the output.
18 | 
19 | 
20 | 
21 | To see how TrustyAI's LIME works, check out the documentation!
22 | 
23 | 
24 | """.format( 25 | bold_green_html("green"), bold_red_html("red") 26 | ) 27 | 28 | SHAP_TEXT = """ 29 |
30 | What is SHAP?
31 | 
32 | SHAP (SHapley Additive exPlanations) explanations answer the following question:
33 | “By how much did each feature contribute to the predicted {{}}?”
34 | 
35 | 
36 | SHAP does this by providing SHAP values, which form an additive explanation of the model output; a receipt for the model’s output.
37 | SHAP will produce a list of per-feature contributions, the sum of which will equal the model's output.
38 | To operate, SHAP also needs access to a background dataset, a set of representative input datapoints that captures
39 | the model’s “normal behavior”. All SHAP values are comparisons against this background data, i.e.,
40 | "By how much did each feature of this input contribute to the output, as compared to the background inputs?"
41 | 
42 | 
43 | 
44 | In this plot, the dotted horizontal line shows the average model output over the background, the starting
45 | "baseline comparison" mark for a SHAP explanation. Then, each vertical bar or candle describes how
46 | each feature {} or {} its contribution to the model's predicted output, marked by the solid horizontal line. The larger
47 | the feature's contribution, the larger the bar.
48 | 
49 | 
50 | 
51 | To see how TrustyAI's SHAP works, check out the documentation!
52 | 
53 | 
54 | """.format( 55 | bold_green_html("adds"), bold_red_html("subtracts") 56 | ) 57 | 58 | CF_TEXT = """ 59 |
60 | What is a Counterfactual?
61 |
62 |
63 | Counterfactuals represent alternate, "what-if" scenarios; what other possible values of {0}
64 | can be attained by modifying the input?
65 |
66 |
67 |
68 | This plot shows all of the counterfactuals
69 | produced during the computation of the LIME and SHAP explanations. On the x-axis are
70 | novel counterfactual values for {0}, while the y-axis shows the number of features that were changed to produce
71 | that particular value. Hover over individual points to see the exact changes to the original input
72 | necessary to produce the displayed counterfactual value of {0}.
73 |
74 |
75 |
76 | To see how TrustyAI's Counterfactual Explainer works, check out the documentation!
77 |
78 |
79 | """ 80 | -------------------------------------------------------------------------------- /src/trustyai/utils/_visualisation.py: -------------------------------------------------------------------------------- 1 | """Visualiser utilies for explainer results""" 2 | # pylint: disable = consider-using-f-string 3 | 4 | 5 | # HTML FORMAT FUNCTIONS ============================================================================ 6 | def bold_green_html(content): 7 | """Format the content string as a bold, green html object""" 8 | return '{}'.format( 9 | DEFAULT_STYLE["positive_primary_colour"], content 10 | ) 11 | 12 | 13 | def bold_red_html(content): 14 | """Format the content string as a bold, red html object""" 15 | return '{}'.format( 16 | DEFAULT_STYLE["negative_primary_colour"], content 17 | ) 18 | 19 | 20 | def output_html(content): 21 | """Format the content string as a bold object in TrustyAI purple, used for 22 | Tyrus output displays""" 23 | return '{}'.format(content) 24 | 25 | 26 | def feature_html(content): 27 | """Format the content string as a bold object in black, used for 28 | Tyrus feature displays""" 29 | return '{}'.format(content) 30 | 31 | 32 | DEFAULT_STYLE = { 33 | "positive_primary_colour": "#13ba3c", 34 | "positive_primary_colour_faded": "#88dc9d", 35 | "negative_primary_colour": "#ee0000", 36 | "negative_primary_colour_faded": "#f67f7f", 37 | "neutral_primary_colour": "#ffffff", 38 | } 39 | 40 | DEFAULT_RC_PARAMS = { 41 | "patch.linewidth": 0.5, 42 | "patch.facecolor": "348ABD", 43 | "patch.edgecolor": "EEEEEE", 44 | "patch.antialiased": True, 45 | "font.size": 10.0, 46 | "axes.facecolor": "DDDDDD", 47 | "axes.edgecolor": "white", 48 | "axes.linewidth": 1, 49 | "axes.grid": True, 50 | "axes.titlesize": "x-large", 51 | "axes.labelsize": "large", 52 | "axes.labelcolor": "black", 53 | "axes.axisbelow": True, 54 | "text.color": "black", 55 | "xtick.color": "black", 56 | "xtick.direction": "out", 57 | "ytick.color": "black", 58 | "ytick.direction": "out", 59 | "legend.facecolor": "ffffff", 60 | "grid.color": "white", 61 | "grid.linestyle": "-", # solid line 62 | "figure.figsize": (16, 9), 63 | "figure.dpi": 100, 64 | "figure.facecolor": "ffffff", 65 | "figure.edgecolor": "777777", 66 | "savefig.bbox": "tight", 67 | } 68 | -------------------------------------------------------------------------------- /src/trustyai/utils/api/api.py: -------------------------------------------------------------------------------- 1 | """ 2 | Server module 3 | """ 4 | 5 | # pylint: disable = import-error, too-few-public-methods, assignment-from-no-return 6 | __SUCCESSFUL_IMPORT = True 7 | 8 | try: 9 | from kubernetes import config, dynamic 10 | from kubernetes.dynamic.exceptions import ResourceNotFoundError 11 | from kubernetes.client import api_client 12 | 13 | except ImportError as e: 14 | print( 15 | "Warning: api dependencies not found. 
" 16 | "Dependencies can be installed with 'pip install trustyai[api]" 17 | ) 18 | __SUCCESSFUL_IMPORT = False 19 | 20 | if __SUCCESSFUL_IMPORT: 21 | 22 | class TrustyAIApi: 23 | """ 24 | Gets TrustyAI service information 25 | """ 26 | 27 | def __init__(self): 28 | try: 29 | k8s_client = config.load_incluster_config() 30 | except config.ConfigException: 31 | k8s_client = config.load_kube_config() 32 | self.dyn_client = dynamic.DynamicClient( 33 | api_client.ApiClient(configuration=k8s_client) 34 | ) 35 | 36 | def get_service_route(self, name: str, namespace: str): 37 | """ 38 | Gets routes for services under a specified namespace 39 | """ 40 | route_api = self.dyn_client.resources.get(api_version="v1", kind="Route") 41 | try: 42 | service = route_api.get(name=name, namespace=namespace) 43 | return f"https://{service.spec.host}" 44 | except ResourceNotFoundError: 45 | return f"Error accessing service {name} in namespace {namespace}." 46 | -------------------------------------------------------------------------------- /src/trustyai/utils/extras/metrics_service.py: -------------------------------------------------------------------------------- 1 | """Python client for TrustyAI metrics""" 2 | 3 | from typing import List 4 | import json 5 | import datetime as dt 6 | import pandas as pd 7 | import requests 8 | import matplotlib.pyplot as plt 9 | 10 | from trustyai.utils.api.api import TrustyAIApi 11 | 12 | 13 | def json_to_df(data_path: str, batch_list: List[int]) -> pd.DataFrame: 14 | """ 15 | Converts batched data in json files to a single pandas DataFrame 16 | """ 17 | final_df = pd.DataFrame() 18 | for batch in batch_list: 19 | file = data_path + f"{batch}.json" 20 | with open(file, encoding="utf8") as train_file: 21 | batch_data = json.load(train_file)["inputs"][0] 22 | batch_df = pd.DataFrame.from_dict(batch_data["data"]).T 23 | final_df = pd.concat([final_df, batch_df]) 24 | return final_df 25 | 26 | 27 | def df_to_json(final_df: pd.DataFrame, name: str, json_file: str) -> None: 28 | """ 29 | Converts pandas DataFrame to json file 30 | """ 31 | inputs = [ 32 | { 33 | "name": name, 34 | "shape": list(final_df.shape), 35 | "datatype": "FP64", 36 | "data": final_df.values.tolist(), 37 | } 38 | ] 39 | data_dict = {"inputs": inputs} 40 | with open(json_file, "w", encoding="utf8") as outfile: 41 | json.dump(data_dict, outfile) 42 | 43 | 44 | class TrustyAIMetricsService: 45 | """ 46 | Executes and returns queries from TrustyAI service on ODH 47 | """ 48 | 49 | def __init__(self, token: str, namespace: str, verify=True): 50 | """ 51 | :param token: OpenShift login token 52 | :param namespace: model namespace 53 | :param verify: enable SSL verification for requests 54 | """ 55 | self.token = token 56 | self.namespace = namespace 57 | self.trusty_url = TrustyAIApi().get_service_route( 58 | name="trustyai-service", namespace=self.namespace 59 | ) 60 | self.thanos_url = TrustyAIApi().get_service_route( 61 | name="thanos-querier", namespace="openshift-monitoring" 62 | ) 63 | self.headers = { 64 | "Authorization": "Bearer " + token, 65 | "Content-Type": "application/json", 66 | } 67 | self.verify = verify 68 | 69 | def upload_payload_data(self, json_file: str, timeout=5) -> None: 70 | """ 71 | Uploads data to TrustyAI service 72 | """ 73 | with open(json_file, "r", encoding="utf8") as file: 74 | response = requests.post( 75 | f"{self.trusty_url}/data/upload", 76 | data=file, 77 | headers=self.headers, 78 | verify=self.verify, 79 | timeout=timeout, 80 | ) 81 | if response.status_code == 200: 82 | 
print("Data sucessfully uploaded to TrustyAI service") 83 | else: 84 | print(f"Error {response.status_code}: {response.reason}") 85 | 86 | def get_model_metadata(self, timeout=5): 87 | """ 88 | Retrieves model data from TrustyAI 89 | """ 90 | response = requests.get( 91 | f"{self.trusty_url}/info", 92 | headers=self.headers, 93 | verify=self.verify, 94 | timeout=timeout, 95 | ) 96 | if response.status_code == 200: 97 | model_metadata = json.loads(response.text) 98 | return model_metadata 99 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 100 | 101 | def label_data_fields(self, payload: str, timeout=5): 102 | """ 103 | Assigns feature names to model input data 104 | """ 105 | 106 | def print_name_mapping(self): 107 | response = requests.get( 108 | f"{self.trusty_url}/info", 109 | headers=self.headers, 110 | verify=self.verify, 111 | timeout=timeout, 112 | ) 113 | name_mapping = json.loads(response.text)[0] 114 | for key, val in name_mapping["data"]["inputSchema"]["nameMapping"].items(): 115 | print(f"{key} -> {val}") 116 | 117 | response = requests.get( 118 | f"{self.trusty_url}/info", 119 | headers=self.headers, 120 | verify=self.verify, 121 | timeout=timeout, 122 | ) 123 | input_data_fields = list( 124 | json.loads(response.text)[0]["data"]["inputSchema"]["items"].keys() 125 | ) 126 | input_mapping_keys = list(payload["inputMapping"].keys()) 127 | if len(list(set(input_mapping_keys) - set(input_data_fields))) == 0: 128 | response = requests.post( 129 | f"{self.trusty_url}/info/names", 130 | json=payload, 131 | headers=self.headers, 132 | verify=self.verify, 133 | timeout=timeout, 134 | ) 135 | if response.status_code == 200: 136 | print_name_mapping(self) 137 | return response.text 138 | print(f"Error {response.status_code}: {response.reason}") 139 | raise ValueError("Field does not exist") 140 | 141 | def get_metric_request( 142 | self, payload: str, metric: str, reoccuring: bool, timeout=5 143 | ): 144 | """ 145 | Retrieve or schedule a metric request 146 | """ 147 | if reoccuring: 148 | response = requests.post( 149 | f"{self.trusty_url}/metrics/{metric}/request", 150 | json=payload, 151 | headers=self.headers, 152 | verify=self.verify, 153 | timeout=timeout, 154 | ) 155 | else: 156 | response = requests.post( 157 | f"{self.trusty_url}/metrics/{metric}", 158 | json=payload, 159 | headers=self.headers, 160 | verify=self.verify, 161 | timeout=timeout, 162 | ) 163 | if response.status_code == 200: 164 | return response.text 165 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 166 | 167 | def upload_data_to_model(self, model_name: str, json_file: str, timeout=5): 168 | """ 169 | Sends an inference request to the model 170 | """ 171 | model_route = TrustyAIApi().get_service_route( 172 | name=model_name, namespace=self.namespace 173 | ) 174 | with open(json_file, encoding="utf8") as batch_file: 175 | response = requests.post( 176 | url=f"https://{model_route}/infer", 177 | data=batch_file, 178 | headers=self.headers, 179 | verify=self.verify, 180 | timeout=timeout, 181 | ) 182 | if response.status_code == 200: 183 | return response.text 184 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 185 | 186 | def get_metric_data(self, metric: str, time_interval: List[str], timeout=5): 187 | """ 188 | Retrives metric data for a specific range in time for each subcategory in data field 189 | """ 190 | metric_df = pd.DataFrame() 191 | for subcategory in list( 192 | 
self.get_model_metadata()[0]["data"]["inputSchema"]["nameMapping"].values() 193 | ): 194 | params = { 195 | "query": f"{metric}{{subcategory='{subcategory}'}}{time_interval}" 196 | } 197 | 198 | response = requests.get( 199 | f"{self.thanos_url}/api/v1/query?", 200 | params=params, 201 | headers=self.headers, 202 | verify=self.verify, 203 | timeout=timeout, 204 | ) 205 | if response.status_code == 200: 206 | if "timestamp" in metric_df.columns: 207 | pass 208 | else: 209 | metric_df["timestamp"] = [ 210 | item[0] 211 | for item in json.loads(response.text)["data"]["result"][0][ 212 | "values" 213 | ] 214 | ] 215 | metric_df[subcategory] = [ 216 | item[1] 217 | for item in json.loads(response.text)["data"]["result"][0]["values"] 218 | ] 219 | else: 220 | raise RuntimeError(f"Error {response.status_code}: {response.reason}") 221 | 222 | metric_df["timestamp"] = metric_df["timestamp"].apply( 223 | lambda epoch: dt.datetime.fromtimestamp(epoch).strftime("%Y-%m-%d %H:%M:%S") 224 | ) 225 | return metric_df 226 | 227 | @staticmethod 228 | def plot_metric(metric_df: pd.DataFrame, metric: str): 229 | """ 230 | Plots a line for each subcategory in the pandas DataFrame returned by get_metric_request 231 | with the timestamp on x-axis and specified metric on the y-axis 232 | """ 233 | plt.figure(figsize=(12, 5)) 234 | for col in metric_df.columns[1:]: 235 | plt.plot(metric_df["timestamp"], metric_df[col]) 236 | plt.xlabel("timestamp") 237 | plt.ylabel(metric) 238 | plt.xticks(rotation=45) 239 | plt.legend(metric_df.columns[1:]) 240 | plt.tight_layout() 241 | plt.show() 242 | -------------------------------------------------------------------------------- /src/trustyai/utils/extras/models.py: -------------------------------------------------------------------------------- 1 | """AIX360 model wrappers""" 2 | from aix360.algorithms.tsutils.model_wrappers import * # pylint: disable=wildcard-import,unused-wildcard-import 3 | -------------------------------------------------------------------------------- /src/trustyai/utils/extras/timeseries.py: -------------------------------------------------------------------------------- 1 | """Extra time series utilities.""" 2 | from aix360.algorithms.tsutils.tsframe import tsFrame # pylint: disable=unused-import 3 | from aix360.algorithms.tsutils.tsperturbers import * # pylint: disable=wildcard-import,unused-wildcard-import 4 | -------------------------------------------------------------------------------- /src/trustyai/utils/text.py: -------------------------------------------------------------------------------- 1 | """Utility methods for text data handling""" 2 | from typing import List, Callable 3 | 4 | from jpype import _jclass 5 | 6 | 7 | def tokenizer(function: Callable[[str], List[str]]): 8 | """Post-process outputs of a Python tokenizer function""" 9 | 10 | def wrapper(_input: str): 11 | return _jclass.JClass("java.util.Arrays").asList(function(_input)) 12 | 13 | return wrapper 14 | -------------------------------------------------------------------------------- /src/trustyai/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | """"Default tokenizers for TrustyAI.""" 2 | # pylint: disable = import-error 3 | 4 | from org.apache.commons.text import StringTokenizer as _StringTokenizer 5 | from opennlp.tools.tokenize import SimpleTokenizer as _SimpleTokenizer 6 | 7 | CommonsStringTokenizer = _StringTokenizer 8 | OpenNLPTokenizer = _SimpleTokenizer 9 | 
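10 | # Editor's sketch (hypothetical, not part of the source tree): the
11 | # `tokenizer` wrapper from trustyai.utils.text adapts a Python tokenizer
12 | # so that it returns a java.util.List, as the Java-side explainers expect.
13 | # Assumes the JVM has already been initialized (e.g. by importing trustyai).
14 | #
15 | # from trustyai.utils.text import tokenizer
16 | #
17 | # @tokenizer
18 | # def whitespace_tokens(text):
19 | #     return text.split()
20 | #
21 | # java_tokens = whitespace_tokens("TrustyAI wraps Java explainers")  # java.util.List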
-------------------------------------------------------------------------------- /src/trustyai/version.py: -------------------------------------------------------------------------------- 1 | """TrustyAI version""" 2 | __version__ = "0.6.1" 3 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/__init__.py: -------------------------------------------------------------------------------- 1 | """Generates visualization according to explanation type""" 2 | # pylint: disable=import-error, wrong-import-order, protected-access, missing-final-newline 3 | from typing import Union, Optional 4 | 5 | from bokeh.io import show 6 | 7 | from trustyai.explainers import SHAPResults, LimeResults, pdp 8 | from trustyai.metrics.distance import LevenshteinResult 9 | from trustyai.visualizations.visualization_results import VisualizationResults 10 | from trustyai.visualizations.shap import SHAPViz 11 | from trustyai.visualizations.lime import LimeViz 12 | from trustyai.visualizations.pdp import PDPViz 13 | from trustyai.visualizations.distance import DistanceViz 14 | 15 | 16 | def get_viz(explanations) -> VisualizationResults: 17 | """ 18 | Get visualization according to the explanation method 19 | """ 20 | if isinstance(explanations, SHAPResults): 21 | return SHAPViz() 22 | if isinstance(explanations, LimeResults): 23 | return LimeViz() 24 | if isinstance(explanations, pdp.PDPResults): 25 | return PDPViz() 26 | if isinstance(explanations, LevenshteinResult): 27 | return DistanceViz() 28 | raise ValueError("Explanation method unknown") 29 | 30 | 31 | def plot( 32 | explanations: Union[SHAPResults, LimeResults, pdp.PDPResults, LevenshteinResult], 33 | output_name: Optional[str] = None, 34 | render_bokeh: bool = False, 35 | block: bool = True, 36 | call_show: bool = True, 37 | ) -> None: 38 | """ 39 | Plot the found feature saliencies. 40 | 41 | Parameters 42 | ---------- 43 | explanations: Union[LimeResults, SHAPResults, PDPResults, LevenshteinResult] 44 | the explanation result to plot 45 | output_name : str 46 | (default= `None`) The name of the output to be explainer. If `None`, all outputs will 47 | be displayed 48 | render_bokeh : bool 49 | (default= `False`) If true, render plot in bokeh, otherwise use matplotlib. 50 | block: bool 51 | (default= `True`) Whether displaying the plot blocks subsequent code execution 52 | call_show: bool 53 | (default= 'True') Whether plt.show() will be called by default at the end of the 54 | plotting function. If `False`, the plot will be returned to the user for further 55 | editing. 
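    Example (editor's sketch; assumes ``shap_results`` is a :class:`SHAPResults`
    produced by a SHAP explainer elsewhere)::

        from trustyai.visualizations import plot
        plot(shap_results)                                  # one plot per output
        plot(shap_results, output_name="income", render_bokeh=True)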
56 | """ 57 | viz = get_viz(explanations) 58 | 59 | if isinstance(explanations, pdp.PDPResults): 60 | viz.plot(explanations, output_name) 61 | elif isinstance(explanations, LevenshteinResult): 62 | viz.plot(explanations) 63 | elif output_name is None: 64 | for output_name_iterator in explanations.saliency_map().keys(): 65 | if render_bokeh: 66 | show(viz._get_bokeh_plot(explanations, output_name_iterator)) 67 | else: 68 | viz._matplotlib_plot( 69 | explanations, output_name_iterator, block, call_show 70 | ) 71 | else: 72 | if render_bokeh: 73 | show(viz._get_bokeh_plot(explanations, output_name)) 74 | else: 75 | viz._matplotlib_plot(explanations, output_name, block, call_show) 76 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/distance.py: -------------------------------------------------------------------------------- 1 | """Visualizations.distance module""" 2 | # pylint: disable = import-error, too-few-public-methods, line-too-long, missing-final-newline 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | class DistanceViz: 8 | """Visualizes Levenshtein distance""" 9 | 10 | def plot(self, explanations): 11 | """Plot the Levenshtein distance matrix""" 12 | cmap = plt.cm.viridis 13 | 14 | _, axes = plt.subplots() 15 | cax = axes.imshow(explanations.matrix, cmap=cmap, interpolation="nearest") 16 | 17 | plt.colorbar(cax) 18 | 19 | axes.set_xticks(np.arange(len(explanations.reference))) 20 | axes.set_yticks(np.arange(len(explanations.hypothesis))) 21 | axes.set_xticklabels(explanations.reference) 22 | axes.set_yticklabels(explanations.hypothesis) 23 | 24 | plt.setp( 25 | axes.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" 26 | ) 27 | 28 | nrows, ncols = explanations.matrix.shape 29 | for i in range(nrows): 30 | for j in range(ncols): 31 | color = ( 32 | "white" 33 | if explanations.matrix[i, j] < explanations.matrix.max() / 2 34 | else "black" 35 | ) 36 | axes.text( 37 | j, 38 | i, 39 | int(explanations.matrix[i, j]), 40 | ha="center", 41 | va="center", 42 | color=color, 43 | ) 44 | 45 | plt.show() 46 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/lime.py: -------------------------------------------------------------------------------- 1 | """Visualizations.lime module""" 2 | # pylint: disable = import-error, too-few-public-methods, consider-using-f-string, missing-final-newline 3 | import matplotlib.pyplot as plt 4 | import matplotlib as mpl 5 | from bokeh.models import ColumnDataSource, HoverTool 6 | from bokeh.plotting import figure 7 | import pandas as pd 8 | 9 | from trustyai.utils._visualisation import ( 10 | DEFAULT_STYLE as ds, 11 | DEFAULT_RC_PARAMS as drcp, 12 | bold_red_html, 13 | bold_green_html, 14 | output_html, 15 | feature_html, 16 | ) 17 | from trustyai.visualizations.visualization_results import VisualizationResults 18 | 19 | 20 | class LimeViz(VisualizationResults): 21 | """Visualizes LIME results.""" 22 | 23 | def _matplotlib_plot( 24 | self, explanations, output_name: str, block=True, call_show=True 25 | ) -> None: 26 | """Plot the LIME saliencies.""" 27 | with mpl.rc_context(drcp): 28 | dictionary = {} 29 | for feature_importance in ( 30 | explanations.saliency_map().get(output_name).getPerFeatureImportance() 31 | ): 32 | dictionary[ 33 | feature_importance.getFeature().name 34 | ] = feature_importance.getScore() 35 | 36 | colours = [ 37 | ds["negative_primary_colour"] 38 | if i < 0 39 | else 
ds["positive_primary_colour"] 40 | for i in dictionary.values() 41 | ] 42 | plt.title(f"LIME: Feature Importances to {output_name}") 43 | plt.barh( 44 | range(len(dictionary)), 45 | dictionary.values(), 46 | align="center", 47 | color=colours, 48 | ) 49 | plt.yticks(range(len(dictionary)), list(dictionary.keys())) 50 | plt.tight_layout() 51 | 52 | if call_show: 53 | plt.show(block=block) 54 | 55 | def _get_bokeh_plot(self, explanations, output_name): 56 | lime_data_source = pd.DataFrame( 57 | [ 58 | { 59 | "feature": str(pfi.getFeature().getName()), 60 | "saliency": pfi.getScore(), 61 | } 62 | for pfi in explanations.saliency_map()[ 63 | output_name 64 | ].getPerFeatureImportance() 65 | ] 66 | ) 67 | lime_data_source["color"] = lime_data_source["saliency"].apply( 68 | lambda x: ds["positive_primary_colour"] 69 | if x >= 0 70 | else ds["negative_primary_colour"] 71 | ) 72 | lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply( 73 | lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x)) 74 | ) 75 | 76 | lime_data_source["color_faded"] = lime_data_source["saliency"].apply( 77 | lambda x: ds["positive_primary_colour_faded"] 78 | if x >= 0 79 | else ds["negative_primary_colour_faded"] 80 | ) 81 | source = ColumnDataSource(lime_data_source) 82 | htool = HoverTool( 83 | name="bars", 84 | tooltips="
LIME
{} saliency to {}: @saliency_colored".format( 85 | feature_html("@feature"), output_html(output_name) 86 | ), 87 | ) 88 | bokeh_plot = figure( 89 | sizing_mode="stretch_both", 90 | title="Lime Feature Importances", 91 | y_range=lime_data_source["feature"], 92 | tools=[htool], 93 | ) 94 | bokeh_plot.hbar( 95 | y="feature", 96 | left=0, 97 | right="saliency", 98 | fill_color="color_faded", 99 | line_color="color", 100 | hover_color="color", 101 | color="color", 102 | height=0.75, 103 | name="bars", 104 | source=source, 105 | ) 106 | bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000") 107 | bokeh_plot.xaxis.axis_label = "Saliency Value" 108 | bokeh_plot.yaxis.axis_label = "Feature" 109 | return bokeh_plot 110 | 111 | def _get_bokeh_plot_dict(self, explanations): 112 | return { 113 | output_name: self._get_bokeh_plot(explanations, output_name) 114 | for output_name in explanations.saliency_map().keys() 115 | } 116 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/pdp.py: -------------------------------------------------------------------------------- 1 | """Visualizations.pdp module""" 2 | # pylint: disable = import-error, wrong-import-order, too-few-public-methods, missing-final-newline 3 | # pylint: disable = protected-access 4 | import matplotlib.pyplot as plt 5 | 6 | from trustyai.explainers.pdp import PDPResults 7 | 8 | 9 | class PDPViz: 10 | """Visualizes PDP graphs""" 11 | 12 | def plot(self, explanations, output_name=None, block=True, call_show=True) -> None: 13 | """ 14 | Parameters 15 | ---------- 16 | explanations: pdp.PDPResults 17 | the partial dependence plots associated to the model outputs 18 | output_name: str 19 | name of the output to be plotted 20 | Default to None 21 | block: bool 22 | whether the plotting operation 23 | should be blocking or not 24 | call_show: bool 25 | (default= 'True') Whether plt.show() will be called by default at the end of 26 | the plotting function. If `False`, the plot will be returned to the user for 27 | further editing. 
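        Example (editor's sketch; assumes ``pdp_results`` is a
        :class:`~trustyai.explainers.pdp.PDPResults` computed elsewhere)::

            PDPViz().plot(pdp_results, output_name="output-1", block=False)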
28 | """ 29 | pdp_graphs = explanations.pdp_graphs 30 | fig, axs = plt.subplots(len(pdp_graphs), constrained_layout=True) 31 | p_idx = 0 32 | for pdp_graph in pdp_graphs: 33 | if output_name is not None and output_name != str( 34 | pdp_graph.getOutput().getName() 35 | ): 36 | continue 37 | fig.suptitle(str(pdp_graph.getOutput().getName())) 38 | pdp_x = [] 39 | for i in range(len(pdp_graph.getX())): 40 | pdp_x.append(PDPResults._to_plottable(pdp_graph.getX()[i])) 41 | pdp_y = [] 42 | for i in range(len(pdp_graph.getY())): 43 | pdp_y.append(PDPResults._to_plottable(pdp_graph.getY()[i])) 44 | axs[p_idx].plot(pdp_x, pdp_y) 45 | axs[p_idx].set_title( 46 | str(pdp_graph.getFeature().getName()), loc="left", fontsize="small" 47 | ) 48 | axs[p_idx].grid() 49 | p_idx += 1 50 | fig.supylabel("Partial Dependence Plot") 51 | if call_show: 52 | plt.show(block=block) 53 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/shap.py: -------------------------------------------------------------------------------- 1 | """Visualizations.shap module""" 2 | # pylint: disable = import-error, consider-using-f-string, too-few-public-methods, missing-final-newline 3 | import matplotlib.pyplot as plt 4 | import matplotlib as mpl 5 | from bokeh.models import ColumnDataSource, HoverTool 6 | from bokeh.plotting import figure 7 | import pandas as pd 8 | import numpy as np 9 | 10 | from trustyai.utils._visualisation import ( 11 | DEFAULT_STYLE as ds, 12 | DEFAULT_RC_PARAMS as drcp, 13 | bold_red_html, 14 | bold_green_html, 15 | output_html, 16 | feature_html, 17 | ) 18 | from trustyai.visualizations.visualization_results import VisualizationResults 19 | 20 | 21 | class SHAPViz(VisualizationResults): 22 | """Visualizes SHAP results.""" 23 | 24 | def _matplotlib_plot( 25 | self, explanations, output_name=None, block=True, call_show=True 26 | ) -> None: 27 | """Visualize the SHAP explanation of each output as a set of candlestick plots, 28 | one per output.""" 29 | with mpl.rc_context(drcp): 30 | shap_values = [ 31 | pfi.getScore() 32 | for pfi in explanations.saliency_map()[ 33 | output_name 34 | ].getPerFeatureImportance()[:-1] 35 | ] 36 | feature_names = [ 37 | str(pfi.getFeature().getName()) 38 | for pfi in explanations.saliency_map()[ 39 | output_name 40 | ].getPerFeatureImportance()[:-1] 41 | ] 42 | fnull = explanations.get_fnull()[output_name] 43 | prediction = fnull + sum(shap_values) 44 | 45 | if call_show: 46 | plt.figure() 47 | pos = fnull 48 | for j, shap_value in enumerate(shap_values): 49 | color = ( 50 | ds["negative_primary_colour"] 51 | if shap_value < 0 52 | else ds["positive_primary_colour"] 53 | ) 54 | width = 0.9 55 | if j > 0: 56 | plt.plot([j - 0.5, j + width / 2 * 0.99], [pos, pos], color=color) 57 | plt.bar(j, height=shap_value, bottom=pos, color=color, width=width) 58 | pos += shap_values[j] 59 | 60 | if j != len(shap_values) - 1: 61 | plt.plot([j - width / 2 * 0.99, j + 0.5], [pos, pos], color=color) 62 | 63 | plt.axhline( 64 | fnull, 65 | color="#444444", 66 | linestyle="--", 67 | zorder=0, 68 | label="Background Value", 69 | ) 70 | plt.axhline(prediction, color="#444444", zorder=0, label="Prediction") 71 | plt.legend() 72 | 73 | ticksize = np.diff(plt.gca().get_yticks())[0] 74 | plt.ylim( 75 | plt.gca().get_ylim()[0] - ticksize / 2, 76 | plt.gca().get_ylim()[1] + ticksize / 2, 77 | ) 78 | plt.xticks(np.arange(len(feature_names)), feature_names) 79 | plt.ylabel(explanations.saliency_map()[output_name].getOutput().getName()) 80 | 
plt.xlabel("Feature SHAP Value") 81 | plt.title(f"SHAP: Feature Contributions to {output_name}") 82 | if call_show: 83 | plt.show(block=block) 84 | 85 | def _get_bokeh_plot(self, explanations, output_name): 86 | fnull = explanations.get_fnull()[output_name] 87 | 88 | # create dataframe of plot values 89 | data_source = pd.DataFrame( 90 | [ 91 | { 92 | "feature": str(pfi.getFeature().getName()), 93 | "saliency": pfi.getScore(), 94 | } 95 | for pfi in explanations.saliency_map()[ 96 | output_name 97 | ].getPerFeatureImportance()[:-1] 98 | ] 99 | ) 100 | prediction = fnull + data_source["saliency"].sum() 101 | 102 | data_source["color"] = data_source["saliency"].apply( 103 | lambda x: ds["positive_primary_colour"] 104 | if x >= 0 105 | else ds["negative_primary_colour"] 106 | ) 107 | data_source["color_faded"] = data_source["saliency"].apply( 108 | lambda x: ds["positive_primary_colour_faded"] 109 | if x >= 0 110 | else ds["negative_primary_colour_faded"] 111 | ) 112 | data_source["index"] = data_source.index 113 | data_source["saliency_text"] = data_source["saliency"].apply( 114 | lambda x: (bold_red_html if x <= 0 else bold_green_html)("{:.2f}".format(x)) 115 | ) 116 | data_source["bottom"] = pd.Series( 117 | [fnull] + data_source["saliency"].iloc[0:-1].tolist() 118 | ).cumsum() 119 | data_source["top"] = data_source["bottom"] + data_source["saliency"] 120 | 121 | # create hovertools 122 | htool_fnull = HoverTool( 123 | name="fnull", 124 | tooltips=("
SHAP
Baseline {}: {}").format( 125 | output_name, output_html("{:.2f}".format(fnull)) 126 | ), 127 | line_policy="interp", 128 | ) 129 | htool_pred = HoverTool( 130 | name="pred", 131 | tooltips=("
SHAP
Predicted {}: {}").format( 132 | output_name, output_html("{:.2f}".format(prediction)) 133 | ), 134 | line_policy="interp", 135 | ) 136 | htool_bars = HoverTool( 137 | name="bars", 138 | tooltips="
SHAP
{} contributions to {}: @saliency_text".format( 139 | feature_html("@feature"), output_html(output_name) 140 | ), 141 | ) 142 | 143 | # create plot 144 | bokeh_plot = figure( 145 | sizing_mode="stretch_both", 146 | title="SHAP Feature Contributions", 147 | x_range=data_source["feature"], 148 | tools=[htool_pred, htool_fnull, htool_bars], 149 | ) 150 | 151 | # add fnull and background lines 152 | line_data_source = ColumnDataSource( 153 | pd.DataFrame( 154 | [ 155 | {"x": 0, "pred": prediction}, 156 | {"x": len(data_source), "pred": prediction}, 157 | ] 158 | ) 159 | ) 160 | fnull_data_source = ColumnDataSource( 161 | pd.DataFrame( 162 | [{"x": 0, "fnull": fnull}, {"x": len(data_source), "fnull": fnull}] 163 | ) 164 | ) 165 | 166 | bokeh_plot.line( 167 | x="x", 168 | y="fnull", 169 | line_color="#999", 170 | hover_line_color="#333", 171 | line_width=2, 172 | hover_line_width=4, 173 | line_dash="dashed", 174 | name="fnull", 175 | source=fnull_data_source, 176 | ) 177 | bokeh_plot.line( 178 | x="x", 179 | y="pred", 180 | line_color="#999", 181 | hover_line_color="#333", 182 | line_width=2, 183 | hover_line_width=4, 184 | name="pred", 185 | source=line_data_source, 186 | ) 187 | 188 | # create candlestick plot lines 189 | bokeh_plot.line( 190 | x=[0.5, 1], 191 | y=data_source.iloc[0]["top"], 192 | color=data_source.iloc[0]["color"], 193 | ) 194 | for i in range(1, len(data_source)): 195 | # bar left line 196 | bokeh_plot.line( 197 | x=[i, i + 0.5], 198 | y=data_source.iloc[i]["bottom"], 199 | color=data_source.iloc[i]["color"], 200 | ) 201 | # bar right line 202 | if i != len(data_source) - 1: 203 | bokeh_plot.line( 204 | x=[i + 0.5, i + 1], 205 | y=data_source.iloc[i]["top"], 206 | color=data_source.iloc[i]["color"], 207 | ) 208 | 209 | # create candles 210 | bokeh_plot.vbar( 211 | x="feature", 212 | bottom="bottom", 213 | top="top", 214 | hover_color="color", 215 | color="color_faded", 216 | width=0.75, 217 | name="bars", 218 | source=data_source, 219 | ) 220 | bokeh_plot.yaxis.axis_label = str(output_name) 221 | return bokeh_plot 222 | 223 | def _get_bokeh_plot_dict(self, explanations): 224 | return { 225 | decision: self._get_bokeh_plot(explanations, decision) 226 | for decision in explanations.saliency_map().keys() 227 | } 228 | -------------------------------------------------------------------------------- /src/trustyai/visualizations/visualization_results.py: -------------------------------------------------------------------------------- 1 | """Generic class for Visualization results""" 2 | # pylint: disable = import-error, too-few-public-methods, line-too-long, missing-final-newline 3 | from abc import ABC, abstractmethod 4 | from typing import Dict 5 | 6 | import bokeh.models 7 | 8 | 9 | class VisualizationResults(ABC): 10 | """Abstract class for visualization results""" 11 | 12 | @abstractmethod 13 | def _matplotlib_plot( 14 | self, explanations, output_name: str, block: bool, call_show: bool 15 | ) -> None: 16 | """Plot the saliencies of a particular output in matplotlib""" 17 | 18 | @abstractmethod 19 | def _get_bokeh_plot(self, explanations, output_name: str) -> bokeh.models.Plot: 20 | """Get a bokeh plot visualizing the saliencies of a particular output""" 21 | 22 | @abstractmethod 23 | def _get_bokeh_plot_dict(self, explanations) -> Dict[str, bokeh.models.Plot]: 24 | """Get a dictionary containing visualizations of the saliencies of all outputs, 25 | keyed by output name""" 26 | return { 27 | output_name: self._get_bokeh_plot(explanations, output_name) 28 | for output_name in 
explanations.saliency_map().keys() 29 | } 30 | -------------------------------------------------------------------------------- /tests/benchmarks/benchmark.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """Common methods and models for tests""" 3 | import os 4 | import sys 5 | import pytest 6 | import time 7 | import numpy as np 8 | 9 | from trustyai.explainers import LimeExplainer, SHAPExplainer 10 | from trustyai.model import feature, PredictionInput 11 | from trustyai.utils import TestModels 12 | from trustyai.metrics.saliency import mean_impact_score, classification_fidelity, local_saliency_f1 13 | 14 | from org.kie.trustyai.explainability.model import ( 15 | PredictionInputsDataDistribution, 16 | ) 17 | 18 | myPath = os.path.dirname(os.path.abspath(__file__)) 19 | sys.path.insert(0, myPath + "/../general/") 20 | 21 | import test_counterfactualexplainer as tcf 22 | 23 | @pytest.mark.benchmark( 24 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 25 | ) 26 | def test_counterfactual_match(benchmark): 27 | """Counterfactual match""" 28 | benchmark(tcf.test_counterfactual_match) 29 | 30 | 31 | @pytest.mark.benchmark( 32 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 33 | ) 34 | def test_non_empty_input(benchmark): 35 | """Counterfactual non-empty input""" 36 | benchmark(tcf.test_non_empty_input) 37 | 38 | 39 | @pytest.mark.benchmark( 40 | group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 41 | ) 42 | def test_counterfactual_match_python_model(benchmark): 43 | """Counterfactual match (Python model)""" 44 | benchmark(tcf.test_counterfactual_match_python_model) 45 | 46 | 47 | @pytest.mark.benchmark( 48 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 49 | ) 50 | def test_sumskip_lime_impact_score_at_2(benchmark): 51 | no_of_features = 10 52 | np.random.seed(0) 53 | explainer = LimeExplainer() 54 | model = TestModels.getSumSkipModel(0) 55 | data = [] 56 | for i in range(100): 57 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)]) 58 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 59 | benchmark(mean_impact_score, explainer, model, data) 60 | 61 | 62 | @pytest.mark.benchmark( 63 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 64 | ) 65 | def test_sumskip_shap_impact_score_at_2(benchmark): 66 | no_of_features = 10 67 | np.random.seed(0) 68 | background = [] 69 | for i in range(10): 70 | background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)])) 71 | explainer = SHAPExplainer(background, samples=10000) 72 | model = TestModels.getSumSkipModel(0) 73 | data = [] 74 | for i in range(100): 75 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)]) 76 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 77 | benchmark(mean_impact_score, explainer, model, data) 78 | 79 | 80 | @pytest.mark.benchmark( 81 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 82 | ) 83 | def test_sumthreshold_lime_impact_score_at_2(benchmark): 84 | no_of_features = 10 85 | np.random.seed(0) 86 | explainer = LimeExplainer() 87 | center = 100.0 88 | epsilon = 10.0 89 | model = 
TestModels.getSumThresholdModel(center, epsilon) 90 | data = [] 91 | for i in range(100): 92 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 93 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 94 | benchmark(mean_impact_score, explainer, model, data) 95 | 96 | 97 | @pytest.mark.benchmark( 98 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 99 | ) 100 | def test_sumthreshold_shap_impact_score_at_2(benchmark): 101 | no_of_features = 10 102 | np.random.seed(0) 103 | background = [] 104 | for i in range(100): 105 | background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 106 | explainer = SHAPExplainer(background, samples=10000) 107 | center = 100.0 108 | epsilon = 10.0 109 | model = TestModels.getSumThresholdModel(center, epsilon) 110 | data = [] 111 | for i in range(100): 112 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 113 | benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data) 114 | benchmark(mean_impact_score, explainer, model, data) 115 | 116 | 117 | @pytest.mark.benchmark( 118 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 119 | ) 120 | def test_lime_fidelity(benchmark): 121 | no_of_features = 10 122 | np.random.seed(0) 123 | explainer = LimeExplainer() 124 | model = TestModels.getEvenSumModel(0) 125 | data = [] 126 | for i in range(100): 127 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]) 128 | benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data) 129 | benchmark(classification_fidelity, explainer, model, data) 130 | 131 | 132 | @pytest.mark.benchmark( 133 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 134 | ) 135 | def test_shap_fidelity(benchmark): 136 | no_of_features = 10 137 | np.random.seed(0) 138 | background = [] 139 | for i in range(10): 140 | background.append(PredictionInput( 141 | [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in 142 | range(no_of_features)])) 143 | explainer = SHAPExplainer(background, samples=10000) 144 | model = TestModels.getEvenSumModel(0) 145 | data = [] 146 | for i in range(100): 147 | data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in 148 | range(no_of_features)]) 149 | benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data) 150 | benchmark(classification_fidelity, explainer, model, data) 151 | 152 | 153 | @pytest.mark.benchmark( 154 | group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 155 | ) 156 | def test_lime_local_saliency_f1(benchmark): 157 | no_of_features = 10 158 | np.random.seed(0) 159 | explainer = LimeExplainer() 160 | model = TestModels.getEvenSumModel(0) 161 | output_name = "sum-even-but0" 162 | data = [] 163 | for i in range(100): 164 | data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 165 | distribution = PredictionInputsDataDistribution(data) 166 | benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10) 167 | benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 
10) 168 | 169 | 170 | @pytest.mark.benchmark( 171 | group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True 172 | ) 173 | def test_shap_local_saliency_f1(benchmark): 174 | no_of_features = 10 175 | np.random.seed(0) 176 | background = [] 177 | for i in range(10): 178 | background.append(PredictionInput( 179 | [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in 180 | range(no_of_features)])) 181 | explainer = SHAPExplainer(background, samples=10000) 182 | model = TestModels.getEvenSumModel(0) 183 | output_name = "sum-even-but0" 184 | data = [] 185 | for i in range(100): 186 | data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])) 187 | distribution = PredictionInputsDataDistribution(data) 188 | benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10) 189 | benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 10) -------------------------------------------------------------------------------- /tests/benchmarks/benchmark_common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/tests/benchmarks/benchmark_common.py -------------------------------------------------------------------------------- /tests/benchmarks/xai_benchmark.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from trustyai_xai_bench import run_benchmark_config 3 | 4 | 5 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 6 | def test_level_0(benchmark): 7 | # ~4.5 min 8 | result = benchmark(run_benchmark_config, 0) 9 | benchmark.extra_info['runs'] = result.to_dict('records') 10 | 11 | 12 | @pytest.mark.skip(reason="full diagnostic benchmark, ~2 hour runtime") 13 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 14 | def test_level_1(benchmark): 15 | result = benchmark(run_benchmark_config, 1) 16 | benchmark.extra_info['runs'] = result.to_dict('records') 17 | 18 | 19 | @pytest.mark.skip(reason="very thorough benchmark, >>2 hour runtime") 20 | @pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False) 21 | def test_level_2(benchmark): 22 | result = benchmark(run_benchmark_config, 2) 23 | benchmark.extra_info['runs'] = result.to_dict('records') -------------------------------------------------------------------------------- /tests/extras/test_metrics_service.py: -------------------------------------------------------------------------------- 1 | """Test suite for TrustyAI metrics service data conversions""" 2 | import json 3 | import os 4 | import random 5 | import unittest 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from trustyai.utils.extras.metrics_service import ( 10 | json_to_df, 11 | df_to_json 12 | ) 13 | 14 | def generate_json_data(batch_list, data_path): 15 | for batch in batch_list: 16 | data = { 17 | "inputs": [ 18 | {"name": "test_data_input", 19 | "shape": [1, 100], 20 | "datatype": "FP64", 21 | "data": [random.uniform(a=100, b=200) for i in range(100)] 22 | } 23 | ] 24 | } 25 | for batch in batch_list: 26 | with open(data_path + f"{batch}.json", 'w', encoding="utf-8") as f: 27 | json.dump(data, f, ensure_ascii=False) 28 | 29 | 30 | def generate_test_df(): 31 | data = { 32 | '0': np.random.uniform(low=100, high=200, size=100), 33 | 
'1': np.random.uniform(low=5000, high=10000, size=100), 34 | '2': np.random.uniform(low=100, high=200, size=100), 35 | '3': np.random.uniform(low=5000, high=10000, size=100), 36 | '4': np.random.uniform(low=5000, high=10000, size=100) 37 | } 38 | return pd.DataFrame(data=data) 39 | 40 | 41 | class TestMetricsService(unittest.TestCase): 42 | def setUp(self): 43 | self.df = generate_test_df() 44 | self.data_path = "data/" 45 | if not os.path.exists(self.data_path): 46 | os.mkdir("data/") 47 | self.batch_list = list(range(0, 5)) 48 | 49 | def test_json_to_df(self): 50 | """Test json data to pandas dataframe conversion""" 51 | generate_json_data(batch_list=self.batch_list, data_path=self.data_path) 52 | df = json_to_df(self.data_path, self.batch_list) 53 | n_rows, n_cols = 0, 0 54 | for batch in self.batch_list: 55 | file = self.data_path + f"{batch}.json" 56 | with open(file, encoding="utf8") as f: 57 | data = json.load(f)["inputs"][0] 58 | n_rows += data["shape"][0] 59 | n_cols = data["shape"][1] 60 | self.assertEqual(df.shape, (n_rows, n_cols)) 61 | 62 | 63 | def test_df_to_json(self): 64 | """Test pandas dataframe to json data conversion""" 65 | df = generate_test_df() 66 | name = 'test_data_input' 67 | json_file = 'data/test.json' 68 | df_to_json(df, name, json_file) 69 | with open(json_file, encoding="utf8") as f: 70 | data = json.load(f)["inputs"][0] 71 | n_rows = data["shape"][0] 72 | n_cols = data["shape"][1] 73 | self.assertEqual(df.shape, (n_rows, n_cols)) 74 | 75 | if __name__ == "__main__": 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /tests/extras/test_tsice.py: -------------------------------------------------------------------------------- 1 | """ Tests for :py:mod:`aix360.algorithms.tsice.TSICEExplainer`. 2 | Original: https://github.com/Trusted-AI/AIX360/blob/master/tests/tsice/test_tsice.py 3 | """ 4 | import unittest 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.model_selection import train_test_split 8 | from sklearn.ensemble import RandomForestRegressor 9 | from aix360.algorithms.tsutils.tsframe import tsFrame 10 | from aix360.datasets import SunspotDataset 11 | from aix360.algorithms.tsutils.tsperturbers import BlockBootstrapPerturber 12 | from trustyai.explainers.extras.tsice import TSICEExplainer 13 | 14 | 15 | # transform a time series dataset into a supervised learning dataset 16 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 17 | class RandomForestUniVariateForecaster: 18 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 19 | self.n_past = n_past 20 | self.n_future = n_future 21 | self.model = RandomForestRegressor(**RFparams) 22 | 23 | def fit(self, X): 24 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 25 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 26 | self.model = self.model.fit(trainX, trainy) 27 | return self 28 | 29 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 30 | 1 if type(data) is list else data.shape[1] 31 | df = pd.DataFrame(data) 32 | cols = list() 33 | 34 | # input sequence (t-n, ... t-1) 35 | for i in range(n_in, 0, -1): 36 | cols.append(df.shift(i)) 37 | # forecast sequence (t, t+1, ... 
t+n) 38 | for i in range(0, n_out): 39 | cols.append(df.shift(-i)) 40 | # put it all together 41 | agg = pd.concat(cols, axis=1) 42 | # drop rows with NaN values 43 | if dropnan: 44 | agg.dropna(inplace=True) 45 | return agg.values 46 | 47 | def predict(self, X): 48 | row = X[-self.n_past:].flatten() 49 | y_pred = self.model.predict(np.asarray([row])) 50 | return y_pred 51 | 52 | 53 | class TestTSICEExplainer(unittest.TestCase): 54 | def setUp(self): 55 | # load data 56 | df, schema = SunspotDataset().load_data() 57 | ts = tsFrame( 58 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 59 | ) 60 | 61 | (self.ts_train, self.ts_test) = train_test_split( 62 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 63 | ) 64 | 65 | def test_tsice_with_range(self): 66 | # load model 67 | input_length = 24 68 | forecast_horizon = 4 69 | forecaster = RandomForestUniVariateForecaster( 70 | n_past=input_length, n_future=forecast_horizon 71 | ) 72 | 73 | forecaster.fit(self.ts_train.iloc[-200:]) 74 | 75 | # initialize/fit explainer 76 | observation_length = 12 77 | explainer = TSICEExplainer( 78 | model=forecaster.predict, 79 | explanation_window_start=10, 80 | explanation_window_length=observation_length, 81 | features_to_analyze=[ 82 | "mean", "std" # analyze mean metric from recent time series of lengh 83 | ], 84 | perturbers=[ 85 | BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2), 86 | ], 87 | input_length=input_length, 88 | forecast_lookahead=forecast_horizon, 89 | n_perturbations=30, 90 | ) 91 | 92 | # compute explanations 93 | explanation = explainer.explain( 94 | inputs=self.ts_test.iloc[:80], 95 | ) 96 | 97 | # validate explanation structure 98 | self.assertIn("data_x", explanation.explanation) 99 | self.assertIn("feature_names", explanation.explanation) 100 | self.assertIn("feature_values", explanation.explanation) 101 | self.assertIn("signed_impact", explanation.explanation) 102 | self.assertIn("total_impact", explanation.explanation) 103 | self.assertIn("current_forecast", explanation.explanation) 104 | self.assertIn("current_feature_values", explanation.explanation) 105 | self.assertIn("perturbations", explanation.explanation) 106 | self.assertIn("forecasts_on_perturbations", explanation.explanation) 107 | 108 | def test_tsice_with_latest(self): 109 | # load model 110 | input_length = 24 111 | forecast_horizon = 4 112 | forecaster = RandomForestUniVariateForecaster( 113 | n_past=input_length, n_future=forecast_horizon 114 | ) 115 | 116 | forecaster.fit(self.ts_train.iloc[-200:]) 117 | 118 | # initialize/fit explainer 119 | observation_length = 12 120 | explainer = TSICEExplainer( 121 | model=forecaster.predict, 122 | explanation_window_start=None, 123 | explanation_window_length=observation_length, 124 | features_to_analyze=[ 125 | "mean", # analyze mean metric from recent time series of lengh 126 | "median", # analyze median metric from recent time series of lengh 127 | "std", # analyze std metric from recent time series of lengh 128 | "max_variation", # analyze max_variation metric from recent time series of lengh 129 | "min", 130 | "max", 131 | "range", 132 | "intercept", 133 | "trend", 134 | "rsquared", 135 | ], 136 | perturbers=[ 137 | BlockBootstrapPerturber(window_length=5, block_length=5, block_swap=2), 138 | dict( 139 | type="frequency", 140 | window_length=5, 141 | truncate_frequencies=5, 142 | block_length=4, 143 | ), 144 | dict(type="moving-average", window_length=5, lag=5, block_length=4), 145 | dict(type="impute", 
block_length=4), 146 | dict(type="shift", block_length=4), 147 | ], 148 | input_length=input_length, 149 | forecast_lookahead=forecast_horizon, 150 | n_perturbations=20, 151 | ) 152 | 153 | # compute explanations 154 | explanation = explainer.explain( 155 | inputs=self.ts_test.iloc[:80], 156 | ) 157 | 158 | # validate explanation structure 159 | self.assertIn("data_x", explanation.explanation) 160 | self.assertIn("feature_names", explanation.explanation) 161 | self.assertIn("feature_values", explanation.explanation) 162 | self.assertIn("signed_impact", explanation.explanation) 163 | self.assertIn("total_impact", explanation.explanation) 164 | self.assertIn("current_forecast", explanation.explanation) 165 | self.assertIn("current_feature_values", explanation.explanation) 166 | self.assertIn("perturbations", explanation.explanation) 167 | self.assertIn("forecasts_on_perturbations", explanation.explanation) 168 | 169 | 170 | if __name__ == "__main__": 171 | unittest.main() 172 | -------------------------------------------------------------------------------- /tests/extras/test_tslime.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import numpy as np 4 | import pandas as pd 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.ensemble import RandomForestRegressor 7 | from trustyai.utils.extras.timeseries import tsFrame 8 | from aix360.datasets import SunspotDataset 9 | from trustyai.explainers.extras.tslime import TSLimeExplainer 10 | from trustyai.utils.extras.timeseries import BlockBootstrapPerturber 11 | 12 | 13 | # transform a time series dataset into a supervised learning dataset 14 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 15 | class RandomForestUniVariateForecaster: 16 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 17 | self.n_past = n_past 18 | self.n_future = n_future 19 | self.model = RandomForestRegressor(**RFparams) 20 | 21 | def fit(self, X): 22 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 23 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 24 | self.model = self.model.fit(trainX, trainy) 25 | return self 26 | 27 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 28 | n_vars = 1 if type(data) is list else data.shape[1] 29 | df = pd.DataFrame(data) 30 | cols = list() 31 | 32 | # input sequence (t-n, ... t-1) 33 | for i in range(n_in, 0, -1): 34 | cols.append(df.shift(i)) 35 | # forecast sequence (t, t+1, ... 
t+n) 36 | for i in range(0, n_out): 37 | cols.append(df.shift(-i)) 38 | # put it all together 39 | agg = pd.concat(cols, axis=1) 40 | # drop rows with NaN values 41 | if dropnan: 42 | agg.dropna(inplace=True) 43 | return agg.values 44 | 45 | def predict(self, X): 46 | row = X[-self.n_past:].flatten() 47 | y_pred = self.model.predict(np.asarray([row])) 48 | return y_pred 49 | 50 | 51 | class TestTSLimeExplainer(unittest.TestCase): 52 | def setUp(self): 53 | # load data 54 | df, schema = SunspotDataset().load_data() 55 | ts = tsFrame( 56 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 57 | ) 58 | 59 | (self.ts_train, self.ts_test) = train_test_split( 60 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 61 | ) 62 | 63 | def test_tslime(self): 64 | # load model 65 | input_length = 24 66 | forecast_horizon = 4 67 | forecaster = RandomForestUniVariateForecaster( 68 | n_past=input_length, n_future=forecast_horizon 69 | ) 70 | 71 | forecaster.fit(self.ts_train.iloc[-200:]) 72 | 73 | # initialize/fit explainer 74 | 75 | relevant_history = 12 76 | explainer = TSLimeExplainer( 77 | model=forecaster.predict, 78 | input_length=input_length, 79 | relevant_history=relevant_history, 80 | perturbers=[ 81 | BlockBootstrapPerturber( 82 | window_length=min(4, input_length - 1), block_length=2, block_swap=2 83 | ), 84 | ], 85 | n_perturbations=10, 86 | random_seed=22, 87 | ) 88 | 89 | # compute explanations 90 | test_window = self.ts_test.iloc[:input_length] 91 | explanation = explainer.explain(test_window) 92 | 93 | # validate explanation structure 94 | self.assertIn("input_data", explanation.explanation) 95 | self.assertIn("history_weights", explanation.explanation) 96 | self.assertIn("x_perturbations", explanation.explanation) 97 | self.assertIn("y_perturbations", explanation.explanation) 98 | self.assertIn("model_prediction", explanation.explanation) 99 | self.assertIn("surrogate_prediction", explanation.explanation) 100 | 101 | self.assertEqual(explanation.explanation["history_weights"].shape[0], relevant_history) 102 | -------------------------------------------------------------------------------- /tests/extras/test_tssaliency.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.ensemble import RandomForestRegressor 6 | 7 | from aix360.datasets import SunspotDataset 8 | from trustyai.explainers.extras.tssaliency import TSSaliencyExplainer 9 | from trustyai.utils.extras.timeseries import tsFrame 10 | 11 | 12 | # transform a time series dataset into a supervised learning dataset 13 | # below sample forecaster is from: https://machinelearningmastery.com/random-forest-for-time-series-forecasting/ 14 | class RandomForestUniVariateForecaster: 15 | def __init__(self, n_past=4, n_future=1, RFparams={"n_estimators": 250}): 16 | self.n_past = n_past 17 | self.n_future = n_future 18 | self.model = RandomForestRegressor(**RFparams) 19 | 20 | def fit(self, X): 21 | train = self._series_to_supervised(X, n_in=self.n_past, n_out=self.n_future) 22 | trainX, trainy = train[:, : -self.n_future], train[:, -self.n_future:] 23 | self.model = self.model.fit(trainX, trainy) 24 | return self 25 | 26 | def _series_to_supervised(self, data, n_in=1, n_out=1, dropnan=True): 27 | n_vars = 1 if type(data) is list else data.shape[1] 28 | df = pd.DataFrame(data) 29 | cols = list() 30 | 31 | # input sequence (t-n, ... 
t-1) 32 | for i in range(n_in, 0, -1): 33 | cols.append(df.shift(i)) 34 | # forecast sequence (t, t+1, ... t+n) 35 | for i in range(0, n_out): 36 | cols.append(df.shift(-i)) 37 | # put it all together 38 | agg = pd.concat(cols, axis=1) 39 | # drop rows with NaN values 40 | if dropnan: 41 | agg.dropna(inplace=True) 42 | return agg.values 43 | 44 | def predict(self, X): 45 | row = X[-self.n_past:].flatten() 46 | y_pred = self.model.predict(np.asarray([row])) 47 | return y_pred 48 | 49 | 50 | class TestTSSaliencyExplainer(unittest.TestCase): 51 | def setUp(self): 52 | # load data 53 | df, schema = SunspotDataset().load_data() 54 | ts = tsFrame( 55 | df, timestamp_column=schema["timestamp"], columns=schema["targets"] 56 | ) 57 | 58 | (self.ts_train, self.ts_test) = train_test_split( 59 | ts, shuffle=False, stratify=None, test_size=0.15, train_size=None 60 | ) 61 | 62 | def test_tssaliency(self): 63 | # load model 64 | input_length = 48 65 | forecast_horizon = 10 66 | forecaster = RandomForestUniVariateForecaster( 67 | n_past=input_length, n_future=forecast_horizon 68 | ) 69 | 70 | forecaster.fit(self.ts_train.iloc[-200:]) 71 | 72 | # initialize/fit explainer 73 | 74 | explainer = TSSaliencyExplainer( 75 | model=forecaster.predict, 76 | input_length=input_length, 77 | feature_names=self.ts_train.columns.tolist(), 78 | n_samples=2, 79 | gradient_samples=50, 80 | ) 81 | 82 | # compute explanations 83 | test_window = self.ts_test.iloc[:input_length] 84 | explanation = explainer.explain(test_window) 85 | 86 | # validate explanation structure 87 | self.assertIn("input_data", explanation.explanation) 88 | self.assertIn("feature_names", explanation.explanation) 89 | self.assertIn("saliency", explanation.explanation) 90 | self.assertIn("timestamps", explanation.explanation) 91 | self.assertIn("base_value", explanation.explanation) 92 | self.assertIn("instance_prediction", explanation.explanation) 93 | self.assertIn("base_value_prediction", explanation.explanation) 94 | 95 | self.assertEqual(explanation.explanation["saliency"].shape, test_window.shape) 96 | -------------------------------------------------------------------------------- /tests/general/common.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=R0801 2 | """Common methods and models for tests""" 3 | import os 4 | import sys 5 | from typing import Optional, List 6 | 7 | import numpy as np 8 | import pandas as pd # pylint: disable=unused-import 9 | 10 | myPath = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.insert(0, myPath + "/../../src") 12 | 13 | from trustyai.model import ( 14 | FeatureFactory, 15 | ) 16 | 17 | 18 | def mock_feature(value, name='f-num'): 19 | """Create a mock numerical feature""" 20 | return FeatureFactory.newNumericalFeature(name, value) 21 | 22 | 23 | def sum_skip_model(inputs: np.ndarray) -> np.ndarray: 24 | """SumSkip test model""" 25 | return np.sum(inputs[:, [i for i in range(inputs.shape[1]) if i != 5]], 1) 26 | 27 | 28 | def create_random_dataframe(weights: Optional[List[float]] = None): 29 | """Create a simple random Pandas dataframe""" 30 | from sklearn.datasets import make_classification 31 | if not weights: 32 | weights = [0.9, 0.1] 33 | 34 | X, y = make_classification(n_samples=5000, n_features=2, n_informative=2, n_redundant=0, n_repeated=0, n_classes=2, 35 | n_clusters_per_class=2, class_sep=2, flip_y=0, weights=weights, 36 | random_state=23) 37 | 38 | return pd.DataFrame({ 39 | 'x1': X[:, 0], 40 | 'x2': X[:, 1], 41 | 'y': y 42 | }) 43 | 
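# ---------------------------------------------------------------------------
# Illustrative usage of the helpers above (a sketch for readers of this dump,
# not part of the repository): create_random_dataframe yields an imbalanced
# two-feature classification frame, and sum_skip_model sums every column of a
# batch except column index 5. The shapes below are chosen only for demonstration.
if __name__ == "__main__":
    frame = create_random_dataframe(weights=[0.5, 0.5])
    print(frame["y"].value_counts())       # roughly balanced classes with these weights

    batch = np.arange(12.0).reshape(2, 6)  # six columns, so column 5 is skipped
    print(sum_skip_model(batch))           # -> [10. 40.]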
-------------------------------------------------------------------------------- /tests/general/data/income-biased.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/tests/general/data/income-biased.zip -------------------------------------------------------------------------------- /tests/general/data/income-unbiased.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/tests/general/data/income-unbiased.zip -------------------------------------------------------------------------------- /tests/general/models/income-xgd-biased.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trustyai-explainability/trustyai-explainability-python/d16ff9e8c3b37b781414a0fc59887faf7875bf20/tests/general/models/income-xgd-biased.joblib -------------------------------------------------------------------------------- /tests/general/test_dataset.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801 2 | """Test suite for the Dataset structure""" 3 | 4 | from common import * 5 | 6 | from java.util import Random 7 | from pytest import approx 8 | import pandas as pd 9 | import numpy as np 10 | import uuid 11 | 12 | from trustyai.model import Dataset, Type 13 | 14 | 15 | jrandom = Random() 16 | jrandom.setSeed(0) 17 | 18 | def generate_test_df(): 19 | data = { 20 | 'x1': np.random.uniform(low=100, high=200, size=100), 21 | 'x2': np.random.uniform(low=5000, high=10000, size=100), 22 | 'x3': [str(uuid.uuid4()) for _ in range(100)], 23 | 'x4': np.random.randint(low=0, high=42, size=100), 24 | 'select': np.random.choice(a=[False, True], size=100) 25 | } 26 | return pd.DataFrame(data=data) 27 | 28 | def generate_test_array(): 29 | return np.random.rand(100, 5) 30 | 31 | 32 | def test_no_output(): 33 | """Checks whether we have an output when specifying none""" 34 | df = generate_test_df() 35 | dataset = Dataset.from_df(df) 36 | outputs = dataset.outputs[0].outputs 37 | assert len(outputs) == 1 38 | assert outputs[0].name == 'select' 39 | 40 | def test_outputs(): 41 | """Checks whether we have the correct specified outputs""" 42 | df = generate_test_df() 43 | dataset = Dataset.from_df(df, outputs=["x2", "x3"]) 44 | outputs = dataset.outputs[0].outputs 45 | assert len(outputs) == 2 46 | assert outputs[0].name == 'x2' and outputs[1].name == 'x3' 47 | 48 | def test_shape(): 49 | """Checks whether we have the correct shape""" 50 | df = generate_test_df() 51 | dataset = Dataset.from_df(df, outputs=["x4"]) 52 | assert len(dataset.outputs) == 100 53 | assert len(dataset.inputs) == 100 54 | assert len(dataset.data) == 100 55 | 56 | assert len(dataset.inputs[0].features) == 4 57 | assert len(dataset.outputs[0].outputs) == 1 58 | 59 | def test_types(): 60 | """Checks whether we have the correct shape""" 61 | df = generate_test_df() 62 | dataset = Dataset.from_df(df, outputs=["x4"]) 63 | features = dataset.inputs[0].features 64 | assert features[0].type == Type.NUMBER and features[0].name == 'x1' 65 | assert features[1].type == Type.NUMBER and features[1].name == 'x2' 66 | assert features[2].type == Type.CATEGORICAL and 
features[2].name == 'x3' 67 | assert features[3].type == Type.BOOLEAN and features[3].name == 'select' 68 | outputs = dataset.outputs[0].outputs 69 | assert outputs[0].type == Type.NUMBER and outputs[0].name == 'x4' 70 | 71 | def test_array_no_output(): 72 | """Checks whether we have an output when specifying none""" 73 | array = generate_test_array() 74 | dataset = Dataset.from_numpy(array) 75 | outputs = dataset.outputs[0].outputs 76 | assert len(outputs) == 1 77 | assert outputs[0].name == 'output-0' 78 | 79 | def test_array_outputs(): 80 | """Checks whether we have the correct specified outputs""" 81 | array = generate_test_array() 82 | dataset = Dataset.from_numpy(array, outputs=[1, 2]) 83 | outputs = dataset.outputs[0].outputs 84 | assert len(outputs) == 2 85 | assert outputs[0].name == 'output-0' and outputs[1].name == 'output-1' 86 | 87 | def test_array_shape(): 88 | """Checks whether we have the correct shape""" 89 | array = generate_test_array() 90 | dataset = Dataset.from_numpy(array, outputs=[4]) 91 | assert len(dataset.outputs) == 100 92 | assert len(dataset.inputs) == 100 93 | assert len(dataset.data) == 100 94 | 95 | assert len(dataset.inputs[0].features) == 4 96 | assert len(dataset.outputs[0].outputs) == 1 -------------------------------------------------------------------------------- /tests/general/test_datautils.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Data utils test suite""" 3 | from common import * 4 | 5 | from pytest import approx 6 | import random 7 | 8 | from trustyai.utils import DataUtils 9 | from trustyai.model import FeatureFactory 10 | from java.util import Random 11 | 12 | jrandom = Random() 13 | 14 | 15 | def test_get_mean(): 16 | """Test GetMean""" 17 | data = [2, 4, 3, 5, 1] 18 | assert DataUtils.getMean(data) == approx(3, 1e-6) 19 | 20 | 21 | def test_get_std_dev(): 22 | """Test GetStdDev""" 23 | data = [2, 4, 3, 5, 1] 24 | assert DataUtils.getStdDev(data, 3) == approx(1.41, 1e-2) 25 | 26 | 27 | def test_gaussian_kernel(): 28 | """Test Gaussian Kernel""" 29 | x = 0.0 30 | k = DataUtils.gaussianKernel(x, 0, 1) 31 | assert k == approx(0.398, 1e-2) 32 | x = 0.218 33 | k = DataUtils.gaussianKernel(x, 0, 1) 34 | assert k == approx(0.389, 1e-2) 35 | 36 | 37 | def test_euclidean_distance(): 38 | """Test Euclidean distance""" 39 | x = [1, 1] 40 | y = [2, 3] 41 | distance = DataUtils.euclideanDistance(x, y) 42 | assert approx(distance, 1e-3) == 2.236 43 | 44 | 45 | def test_hamming_distance_double(): 46 | """Test Hamming distance for doubles""" 47 | x = [2, 1] 48 | y = [2, 3] 49 | distance = DataUtils.hammingDistance(x, y) 50 | assert distance == approx(1, 1e-1) 51 | 52 | 53 | def test_hamming_distance_string(): 54 | """Test Hamming distance for strings""" 55 | x = "test1" 56 | y = "test2" 57 | distance = DataUtils.hammingDistance(x, y) 58 | assert distance == approx(1, 1e-1) 59 | 60 | 61 | def test_doubles_to_features(): 62 | """Test doubles to features""" 63 | inputs = [1 if i % 2 == 0 else 0 for i in range(10)] 64 | features = DataUtils.doublesToFeatures(inputs) 65 | assert features is not None 66 | assert len(features) == 10 67 | for f in features: 68 | assert f is not None 69 | assert f.getName() is not None 70 | assert f.getValue() is not None 71 | 72 | 73 | def test_exponential_smoothing_kernel(): 74 | """Test exponential smoothing kernel""" 75 | x = 0.218 76 | k = DataUtils.exponentialSmoothingKernel(x, 2) 77 | 
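    # the expected value is consistent with a Gaussian-style kernel,
    # exp(-x**2 / (2 * width**2)) = exp(-0.218**2 / 8) ~= 0.994 (an assumption
    # about the underlying Java implementation, noted only to motivate the number)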
assert k == approx(0.994, 1e-3) 78 | 79 | 80 | # def test_perturb_features_empty(): 81 | # """Test perturb empty features""" 82 | # features = [] 83 | # perturbationContext = PerturbationContext(jrandom, 0) 84 | # newFeatures = DataUtils.perturbFeatures(features, perturbationContext) 85 | # assert newFeatures is not None 86 | # assert len(features) == newFeatures.size() 87 | 88 | 89 | def test_random_distribution_generation(): 90 | """Test random distribution generation""" 91 | dataDistribution = DataUtils.generateRandomDataDistribution(10, 10, jrandom) 92 | assert dataDistribution is not None 93 | assert dataDistribution.asFeatureDistributions() is not None 94 | for featureDistribution in dataDistribution.asFeatureDistributions(): 95 | assert featureDistribution is not None 96 | 97 | 98 | def test_linearized_numeric_features(): 99 | """Test linearised numeric features""" 100 | f = FeatureFactory.newNumericalFeature("f-num", 1.0) 101 | features = [f] 102 | linearizedFeatures = DataUtils.getLinearizedFeatures(features) 103 | assert len(features) == linearizedFeatures.size() 104 | 105 | 106 | def test_sample_with_replacement(): 107 | """Test sample with replacement""" 108 | emptyValues = [] 109 | emptySamples = DataUtils.sampleWithReplacement(emptyValues, 1, jrandom) 110 | assert emptySamples is not None 111 | assert emptySamples.size() == 0 112 | 113 | values = DataUtils.generateData(0, 1, 100, jrandom) 114 | sampleSize = 10 115 | samples = DataUtils.sampleWithReplacement(values, sampleSize, jrandom) 116 | assert samples is not None 117 | assert samples.size() == sampleSize 118 | assert samples[random.randint(0, sampleSize - 1)] in values 119 | 120 | largerSampleSize = 300 121 | largerSamples = DataUtils.sampleWithReplacement(values, largerSampleSize, jrandom) 122 | assert largerSamples is not None 123 | assert largerSampleSize == largerSamples.size() 124 | assert largerSamples[random.randint(0, largerSampleSize - 1)] in largerSamples 125 | -------------------------------------------------------------------------------- /tests/general/test_limeexplainer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code 2 | """LIME explainer test suite""" 3 | 4 | from common import * 5 | 6 | import pytest 7 | 8 | from trustyai.explainers import LimeExplainer 9 | from trustyai.utils import TestModels 10 | from trustyai.model import feature, Model, simple_prediction 11 | from trustyai.metrics import ExplainabilityMetrics 12 | from trustyai.visualizations import plot 13 | 14 | from org.kie.trustyai.explainability.local import ( 15 | LocalExplanationException, 16 | ) 17 | 18 | 19 | def mock_features(n_features: int): 20 | return [mock_feature(i, f"f-num{i}") for i in range(n_features)] 21 | 22 | 23 | def test_empty_prediction(): 24 | """Check if the explanation returned is not null""" 25 | lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1) 26 | inputs = [] 27 | model = TestModels.getSumSkipModel(0) 28 | outputs = model.predict([inputs])[0].outputs 29 | with pytest.raises(LocalExplanationException): 30 | lime_explainer.explain(inputs=inputs, outputs=outputs, model=model) 31 | 32 | 33 | def test_non_empty_input(): 34 | """Test for non-empty input""" 35 | lime_explainer = LimeExplainer(seed=0, samples=10, perturbations=1) 36 | features = [feature(name=f"f-num{i}", value=i, dtype="number") for i in range(4)] 37 | 38 | model = TestModels.getSumSkipModel(0) 39 | outputs = 
model.predict([features])[0].outputs 40 | saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model) 41 | assert saliency_map is not None 42 | 43 | 44 | def test_sparse_balance(): # pylint: disable=too-many-locals 45 | """Test sparse balance""" 46 | for n_features in range(1, 4): 47 | lime_explainer_no_penalty = LimeExplainer(samples=100, penalise_sparse_balance=False) 48 | 49 | features = mock_features(n_features) 50 | 51 | model = TestModels.getSumSkipModel(0) 52 | outputs = model.predict([features])[0].outputs 53 | 54 | saliency_map_no_penalty = lime_explainer_no_penalty.explain( 55 | inputs=features, outputs=outputs, model=model 56 | ).saliency_map() 57 | 58 | assert saliency_map_no_penalty is not None 59 | 60 | decision_name = "sum-but0" 61 | saliency_no_penalty = saliency_map_no_penalty.get(decision_name) 62 | 63 | lime_explainer = LimeExplainer(samples=100, penalise_sparse_balance=True) 64 | 65 | saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model).saliency_map() 66 | assert saliency_map is not None 67 | 68 | saliency = saliency_map.get(decision_name) 69 | 70 | for i in range(len(features)): 71 | score = saliency.getPerFeatureImportance().get(i).getScore() 72 | score_no_penalty = ( 73 | saliency_no_penalty.getPerFeatureImportance().get(i).getScore() 74 | ) 75 | assert abs(score) <= abs(score_no_penalty) 76 | 77 | 78 | def test_normalized_weights(): 79 | """Test normalized weights""" 80 | lime_explainer = LimeExplainer(normalise_weights=True, perturbations=2, samples=10) 81 | n_features = 4 82 | features = mock_features(n_features) 83 | model = TestModels.getSumSkipModel(0) 84 | outputs = model.predict([features])[0].outputs 85 | 86 | saliency_map = lime_explainer.explain(inputs=features, outputs=outputs, model=model).saliency_map() 87 | assert saliency_map is not None 88 | 89 | decision_name = "sum-but0" 90 | saliency = saliency_map.get(decision_name) 91 | per_feature_importance = saliency.getPerFeatureImportance() 92 | for feature_importance in per_feature_importance: 93 | assert -3.0 < feature_importance.getScore() < 3.0 94 | 95 | 96 | def lime_plots(block): 97 | """Test normalized weights""" 98 | lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10) 99 | n_features = 15 100 | features = mock_features(n_features) 101 | model = TestModels.getSumSkipModel(0) 102 | outputs = model.predict([features])[0].outputs 103 | 104 | explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model) 105 | plot(explanation, block=block) 106 | plot(explanation, block=block, render_bokeh=True) 107 | plot(explanation, block=block, output_name="sum-but0") 108 | plot(explanation, block=block, output_name="sum-but0", render_bokeh=True) 109 | 110 | 111 | @pytest.mark.block_plots 112 | def test_lime_plots_blocking(): 113 | lime_plots(True) 114 | 115 | 116 | def test_lime_plots(): 117 | lime_plots(False) 118 | 119 | 120 | def test_lime_v2(): 121 | np.random.seed(0) 122 | data = pd.DataFrame(np.random.rand(1, 5)).values 123 | 124 | model_weights = np.random.rand(5) 125 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 126 | model = Model(predict_function) 127 | 128 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 129 | explanation = explainer.explain(inputs=data, outputs=model(data), model=model) 130 | 131 | for score in explanation.as_dataframe()["output-0"]['Saliency']: 132 | assert score != 0 133 | 134 | 
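    # as_dataframe() returns one dataframe per model output; the checks below
    # verify the schema and that every feature index 0..4 appears in each frame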
for out_name, df in explanation.as_dataframe().items(): 135 | assert "Feature" in df 136 | assert "output" in out_name 137 | assert all([x in str(df) for x in "01234"]) 138 | 139 | 140 | def test_impact_score(): 141 | np.random.seed(0) 142 | data = pd.DataFrame(np.random.rand(1, 5)) 143 | model_weights = np.random.rand(5) 144 | predict_function = lambda x: np.dot(x.values, model_weights) 145 | model = Model(predict_function, dataframe_input=True) 146 | output = model(data) 147 | pred = simple_prediction(data, output) 148 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 149 | explanation = explainer.explain(inputs=data, outputs=output, model=model) 150 | saliency = list(explanation.saliency_map().values())[0] 151 | top_features_t = saliency.getTopFeatures(2) 152 | impact = ExplainabilityMetrics.impactScore(model, pred, top_features_t) 153 | assert impact > 0 154 | return impact 155 | 156 | 157 | def test_lime_as_html(): 158 | np.random.seed(0) 159 | data = np.random.rand(1, 5) 160 | 161 | model_weights = np.random.rand(5) 162 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 163 | 164 | model = Model(predict_function, disable_arrow=True) 165 | 166 | explainer = LimeExplainer() 167 | explainer.explain(inputs=data, outputs=model(data), model=model) 168 | assert True 169 | 170 | explanation = explainer.explain(inputs=data, outputs=model(data), model=model) 171 | for score in explanation.as_dataframe()["output-0"]['Saliency']: 172 | assert score != 0 173 | 174 | 175 | def test_lime_numpy(): 176 | np.random.seed(0) 177 | data = np.random.rand(101, 5) 178 | model_weights = np.random.rand(5) 179 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 180 | fnames = ['f{}'.format(x) for x in "abcde"] 181 | onames = ['o{}'.format(x) for x in "12"] 182 | model = Model(predict_function, 183 | feature_names=fnames, 184 | output_names=onames 185 | ) 186 | 187 | explainer = LimeExplainer() 188 | explanation = explainer.explain(inputs=data[0], outputs=model(data[0]), model=model) 189 | 190 | for oname in onames: 191 | assert oname in explanation.as_dataframe().keys() 192 | for fname in fnames: 193 | assert fname in explanation.as_dataframe()[oname]['Feature'].values 194 | 195 | -------------------------------------------------------------------------------- /tests/general/test_metrics_language.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code, unused-import 2 | """Language metrics test suite""" 3 | 4 | from common import * 5 | from trustyai.metrics.language import word_error_rate 6 | import math 7 | 8 | tolerance = 1e-4 9 | 10 | REFERENCES = [ 11 | "This is the test reference, to which I will compare alignment against.", 12 | "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce euismod tortor massa, nec euismod sapien laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. 
Maecenas sodales sollicitudin leo a faucibus.", 13 | "The quick red fox jumped over the lazy brown dog"] 14 | 15 | INPUTS = [ 16 | "I'm a hypothesis reference, from which the aligner will compare against.", 17 | "Lorem ipsum sit amet, consectetur adipiscing elit. Curabitur condimentum velit id velit posuere dictum. Fusce blandit euismod tortor massa, nec euismod sapien blandit laoreet non. Donec vulputate mi velit, eu ultricies nibh iaculis vel. Aenean posuere urna nec sapien consectetur, vitae porttitor sapien finibus. Duis nec libero convallis lectus pharetra blandit ut ac odio. Vivamus nec dui quis sem convallis pulvinar. Maecenas sodales sollicitudin leo a faucibus.", 18 | "dog brown lazy the over jumped fox red quick The"] 19 | 20 | 21 | def test_default_tokenizer(): 22 | """Test default tokenizer""" 23 | results = [4 / 7, 1 / 26, 1] 24 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 25 | wer = word_error_rate(reference, hypothesis).value 26 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 27 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 28 | 29 | 30 | def test_commons_stringtokenizer(): 31 | """Test Apache Commons StringTokenizer""" 32 | from trustyai.utils.tokenizers import CommonsStringTokenizer 33 | results = [8 / 12., 3 / 66., 1.0] 34 | 35 | def tokenizer(text: str) -> List[str]: 36 | return CommonsStringTokenizer(text).getTokenList() 37 | 38 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 39 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 40 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 41 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 42 | 43 | 44 | def test_opennlp_tokenizer(): 45 | """Test Apache OpenNLP tokenizer""" 46 | from trustyai.utils.tokenizers import OpenNLPTokenizer 47 | results = [9 / 14., 3 / 78., 1.0] 48 | tokenizer = OpenNLPTokenizer() 49 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 50 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 51 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 52 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}." 53 | 54 | 55 | def test_python_tokenizer(): 56 | """Test pure Python whitespace tokenizer""" 57 | 58 | results = [3 / 4., 3 / 66., 1.0] 59 | 60 | def tokenizer(text: str) -> List[str]: 61 | return text.split(" ") 62 | 63 | for i, (reference, hypothesis) in enumerate(zip(REFERENCES, INPUTS)): 64 | wer = word_error_rate(reference, hypothesis, tokenizer=tokenizer).value 65 | assert math.isclose(wer, results[i], rel_tol=tolerance), \ 66 | f"WER for {reference}, {hypothesis} was {wer}, expected ~{results[i]}."
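

# Illustrative sketch (an assumption for demonstration, not one of the suite's
# cases): word_error_rate accepts any callable that maps a string to a list of
# tokens, so a punctuation-aware tokenizer can be plugged in using only the
# standard library. The regex and the helper name below are hypothetical.
import re

def _regex_tokenizer(text: str) -> List[str]:
    return re.findall(r"[\w']+", text)

# example call; the resulting score depends on the tokenizer, as the tests above show:
# wer = word_error_rate(REFERENCES[0], INPUTS[0], tokenizer=_regex_tokenizer).value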
67 | -------------------------------------------------------------------------------- /tests/general/test_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Test model provider interface""" 3 | 4 | from common import * 5 | from trustyai.model import Model, Dataset, feature 6 | 7 | import pytest 8 | 9 | from trustyai.utils.data_conversions import numpy_to_prediction_object 10 | 11 | 12 | def test_basic_model(): 13 | """Test basic model""" 14 | 15 | model = Model(lambda x: x, output_names=['a', 'b', 'c', 'd', 'e']) 16 | features = numpy_to_prediction_object(np.arange(0, 100).reshape(20, 5), feature) 17 | result = model.predictAsync(features).get() 18 | assert len(result[0].outputs) == 5 19 | 20 | 21 | def test_cast_output(): 22 | np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'], disable_arrow=True) 23 | np2df = Model(lambda x: pd.DataFrame(x), disable_arrow=True) 24 | df2np = Model(lambda x: x.sum(1).values, 25 | dataframe_input=True, 26 | output_names=['sum'], 27 | disable_arrow=True) 28 | df2df = Model(lambda x: x, dataframe_input=True, disable_arrow=True) 29 | 30 | pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature) 31 | 32 | for m in [np2np, np2df, df2df, df2np]: 33 | output_val = m.predictAsync(pis).get() 34 | assert len(output_val) == 25 35 | 36 | 37 | def test_cast_output_arrow(): 38 | np2np = Model(lambda x: np.sum(x, 1), output_names=['sum']) 39 | np2df = Model(lambda x: pd.DataFrame(x)) 40 | df2np = Model(lambda x: x.sum(1).values, dataframe_input=True, output_names=['sum']) 41 | df2df = Model(lambda x: x, dataframe_input=True) 42 | pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature) 43 | 44 | for m in [np2np, np2df, df2df, df2np]: 45 | m._set_arrow(pis[0]) 46 | output_val = m.predictAsync(pis).get() 47 | assert len(output_val) == 25 48 | 49 | -------------------------------------------------------------------------------- /tests/general/test_pdp.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """PDP test suite""" 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from sklearn.datasets import make_classification 8 | from trustyai.explainers import PDPExplainer 9 | from trustyai.model import Model 10 | from trustyai.utils import TestModels 11 | from trustyai.visualizations import plot 12 | 13 | 14 | def create_random_df(): 15 | X, _ = make_classification(n_samples=5000, n_features=5, n_classes=2, 16 | n_clusters_per_class=2, class_sep=2, flip_y=0, random_state=23) 17 | 18 | return pd.DataFrame({ 19 | 'x1': X[:, 0], 20 | 'x2': X[:, 1], 21 | 'x3': X[:, 2], 22 | 'x4': X[:, 3], 23 | 'x5': X[:, 4], 24 | }) 25 | 26 | 27 | def test_pdp_sumskip(): 28 | """Test PDP with sum skip model on random generated data""" 29 | 30 | df = create_random_df() 31 | model = TestModels.getSumSkipModel(0) 32 | pdp_explainer = PDPExplainer() 33 | pdp_results = pdp_explainer.explain(model, df) 34 | assert pdp_results is not None 35 | assert pdp_results.as_dataframe() is not None 36 | 37 | 38 | def test_pdp_sumthreshold(): 39 | """Test PDP with sum threshold model on random generated data""" 40 | 41 | df = create_random_df() 42 | model = TestModels.getLinearThresholdModel([0.1, 0.2, 0.3, 0.4, 0.5], 0) 43 | pdp_explainer = PDPExplainer() 44 | pdp_results = 
pdp_explainer.explain(model, df) 45 | assert pdp_results is not None 46 | assert pdp_results.as_dataframe() is not None 47 | 48 | 49 | def pdp_plots(block): 50 | """Test PDP plots""" 51 | np.random.seed(0) 52 | data = pd.DataFrame(np.random.rand(101, 5)) 53 | 54 | model_weights = np.random.rand(5) 55 | predict_function = lambda x: np.stack([np.dot(x.values, model_weights), 2 * np.dot(x.values, model_weights)], -1) 56 | model = Model(predict_function, dataframe_input=True) 57 | pdp_explainer = PDPExplainer() 58 | explanation = pdp_explainer.explain(model, data) 59 | 60 | plot(explanation, block=block) 61 | plot(explanation, block=block, output_name='output-0') 62 | 63 | 64 | @pytest.mark.block_plots 65 | def test_pdp_plots_blocking(): 66 | pdp_plots(True) 67 | 68 | 69 | def test_pdp_plots(): 70 | pdp_plots(False) 71 | -------------------------------------------------------------------------------- /tests/general/test_prediction.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name 2 | """Test model provider interface""" 3 | 4 | from common import * 5 | from trustyai.model import simple_prediction, counterfactual_prediction, feature, output 6 | from trustyai.utils.data_conversions import numpy_to_prediction_object 7 | import pytest 8 | 9 | # test that predictions are created correctly from numpy arrays 10 | def test_predictions_numpy(): 11 | input_values = np.arange(5) 12 | output_values = np.arange(2) 13 | 14 | pred = simple_prediction(input_values, output_values) 15 | assert len(pred.getInput().getFeatures()) == 5 16 | 17 | pred = counterfactual_prediction(input_values, output_values) 18 | assert len(pred.getInput().getFeatures()) == 5 19 | 20 | 21 | # test that predictions are created correctly from dataframe 22 | def test_predictions_pandas(): 23 | input_values = pd.DataFrame(np.arange(5).reshape(1, -1), columns=list("abcde")) 24 | output_values = pd.DataFrame(np.arange(2).reshape(1, -1), columns=list("xy")) 25 | 26 | pred = simple_prediction(input_values, output_values) 27 | assert len(pred.getInput().getFeatures()) == 5 28 | assert pred.getInput().getFeatures()[0].getName() == "a" 29 | 30 | pred = counterfactual_prediction(input_values, output_values) 31 | assert pred.getInput().getFeatures()[0].getName() == "a" 32 | assert len(pred.getInput().getFeatures()) == 5 33 | 34 | 35 | # test that predictions are created correctly from prediction input + outputs 36 | def test_prediction_pi(): 37 | input_values = numpy_to_prediction_object(np.arange(5).reshape(1, -1), feature)[0] 38 | output_values = numpy_to_prediction_object(np.arange(2).reshape(1, -1), output)[0] 39 | 40 | pred = simple_prediction(input_values, output_values) 41 | assert len(pred.getInput().getFeatures()) == 5 42 | 43 | pred = counterfactual_prediction(input_values, output_values) 44 | assert len(pred.getInput().getFeatures()) == 5 45 | 46 | 47 | # test that predictions are created correctly from feature+output lists 48 | def test_prediction_featurelist(): 49 | input_values = numpy_to_prediction_object( 50 | np.arange(5).reshape(1, -1), feature 51 | )[0].getFeatures() 52 | output_values = numpy_to_prediction_object( 53 | np.arange(2).reshape(1, -1), output 54 | )[0].getOutputs() 55 | 56 | pred = simple_prediction(input_values, output_values) 57 | assert len(pred.getInput().getFeatures()) == 5 58 | 59 | pred = counterfactual_prediction(input_values, output_values) 60 | assert
len(pred.getInput().getFeatures()) == 5 61 | -------------------------------------------------------------------------------- /tests/general/test_shap.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=import-error, wrong-import-position, wrong-import-order, duplicate-code, unused-import 2 | """SHAP explainer test suite""" 3 | 4 | from common import * 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | np.random.seed(0) 11 | 12 | import pytest 13 | from trustyai.explainers import SHAPExplainer 14 | from trustyai.model import feature, Model 15 | from trustyai.utils.data_conversions import numpy_to_prediction_object 16 | from trustyai.utils import TestModels 17 | from trustyai.visualizations import plot 18 | 19 | 20 | def test_no_variance_one_output(): 21 | """Check if the explanation returned is not null""" 22 | model = TestModels.getSumSkipModel(0) 23 | 24 | background = np.array([[1.0, 2.0, 3.0] for _ in range(2)]) 25 | prediction_outputs = model.predictAsync(numpy_to_prediction_object(background, feature)).get() 26 | shap_explainer = SHAPExplainer(background=background) 27 | for i in range(2): 28 | explanation = shap_explainer.explain(inputs=background[i], outputs=prediction_outputs[i].outputs, model=model) 29 | for _, saliency in explanation.saliency_map().items(): 30 | for feature_importance in saliency.getPerFeatureImportance()[:-1]: 31 | assert feature_importance.getScore() == 0.0 32 | 33 | 34 | def test_shap_arrow(): 35 | """Basic SHAP/Arrow test""" 36 | np.random.seed(0) 37 | data = pd.DataFrame(np.random.rand(101, 5)) 38 | background = data.iloc[:100] 39 | to_explain = data.iloc[100:101] 40 | 41 | model_weights = np.random.rand(5) 42 | predict_function = lambda x: np.dot(x.values, model_weights) 43 | 44 | model = Model(predict_function, dataframe_input=True) 45 | shap_explainer = SHAPExplainer(background=background) 46 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 47 | 48 | 49 | answers = [-.152, -.114, 0.00304, .0525, -.0725] 50 | for _, saliency in explanation.saliency_map().items(): 51 | for i, feature_importance in enumerate(saliency.getPerFeatureImportance()[:-1]): 52 | assert answers[i] - 1e-2 <= feature_importance.getScore() <= answers[i] + 1e-2 53 | 54 | 55 | def shap_plots(block): 56 | """Test SHAP plots""" 57 | np.random.seed(0) 58 | data = pd.DataFrame(np.random.rand(101, 5)) 59 | background = data.iloc[:100] 60 | to_explain = data.iloc[100:101] 61 | 62 | model_weights = np.random.rand(5) 63 | predict_function = lambda x: np.stack([np.dot(x.values, model_weights), 2 * np.dot(x.values, model_weights)], -1) 64 | model = Model(predict_function, dataframe_input=True) 65 | shap_explainer = SHAPExplainer(background=background) 66 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 67 | 68 | plot(explanation, block=block) 69 | plot(explanation, block=block, render_bokeh=True) 70 | plot(explanation, block=block, output_name='output-0') 71 | plot(explanation, block=block, output_name='output-0', render_bokeh=True) 72 | 73 | 74 | @pytest.mark.block_plots 75 | def test_shap_plots_blocking(): 76 | shap_plots(block=True) 77 | 78 | 79 | def test_shap_plots(): 80 | shap_plots(block=False) 81 | 82 | 83 | def test_shap_as_df(): 84 | np.random.seed(0) 85 | data = pd.DataFrame(np.random.rand(101, 5)) 86 | background = data.iloc[:100].values 87 | to_explain = data.iloc[100:101].values 88 | 89 | 
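    # two-output linear model: the second output is exactly twice the first,
    # so both SHAP dataframes below should expose the same feature rows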
model_weights = np.random.rand(5) 90 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 91 | 92 | model = Model(predict_function, disable_arrow=True) 93 | 94 | shap_explainer = SHAPExplainer(background=background) 95 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 96 | 97 | for out_name, df in explanation.as_dataframe().items(): 98 | assert "Mean Background Value" in df 99 | assert "output" in out_name 100 | assert all([x in str(df) for x in "01234"]) 101 | 102 | 103 | def test_shap_as_html(): 104 | np.random.seed(0) 105 | data = pd.DataFrame(np.random.rand(101, 5)) 106 | background = data.iloc[:100].values 107 | to_explain = data.iloc[100:101].values 108 | 109 | model_weights = np.random.rand(5) 110 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 111 | 112 | model = Model(predict_function, disable_arrow=True) 113 | 114 | shap_explainer = SHAPExplainer(background=background) 115 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 116 | assert True 117 | 118 | 119 | def test_shap_numpy(): 120 | np.random.seed(0) 121 | data = np.random.rand(101, 5) 122 | model_weights = np.random.rand(5) 123 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 124 | fnames = ['f{}'.format(x) for x in "abcde"] 125 | onames = ['o{}'.format(x) for x in "12"] 126 | model = Model(predict_function, 127 | feature_names=fnames, 128 | output_names=onames 129 | ) 130 | 131 | shap_explainer = SHAPExplainer(background=data[1:]) 132 | explanation = shap_explainer.explain(inputs=data[0], outputs=model(data[0]), model=model) 133 | 134 | for oname in onames: 135 | assert oname in explanation.as_dataframe().keys() 136 | for fname in fnames: 137 | assert fname in explanation.as_dataframe()[oname]['Feature'].values 138 | 139 | 140 | # deliberately make strange plot to test pre and post-function plot editing 141 | def test_shap_edit_plot(): 142 | np.random.seed(0) 143 | data = pd.DataFrame(np.random.rand(101, 5)) 144 | background = data.iloc[:100].values 145 | to_explain = data.iloc[100:101].values 146 | 147 | model_weights = np.random.rand(5) 148 | predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1) 149 | 150 | model = Model(predict_function, disable_arrow=True) 151 | 152 | shap_explainer = SHAPExplainer(background=background) 153 | explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model) 154 | 155 | plt.figure(figsize=(32,2)) 156 | plot(explanation, call_show=False) 157 | plt.ylim(0, 123) 158 | plt.show() 159 | 160 | -------------------------------------------------------------------------------- /tests/general/test_shap_background_generation.py: -------------------------------------------------------------------------------- 1 | """SHAP background generation test suite""" 2 | 3 | import pytest 4 | import numpy as np 5 | import math 6 | 7 | from trustyai.explainers.shap import BackgroundGenerator 8 | from trustyai.model import Model, feature_domain 9 | from trustyai.utils.data_conversions import prediction_object_to_numpy 10 | 11 | 12 | def test_random_generation(): 13 | """Test that random sampling recovers samples from distribution""" 14 | seed = 0 15 | np.random.seed(seed) 16 | data = np.random.rand(100, 5) 17 | background_ta = BackgroundGenerator(data).sample(5) 18 | background = 
prediction_object_to_numpy(background_ta) 19 | 20 | assert len(background) == 5 21 | for row in background: 22 | assert row in data 23 | 24 | 25 | def test_kmeans_generation(): 26 | """Test that k-means recovers centroids of well-clustered data""" 27 | 28 | seed = 0 29 | clusters = 5 30 | np.random.seed(seed) 31 | 32 | data = [] 33 | ground_truth = [] 34 | for cluster in range(clusters): 35 | data.append(np.random.rand(100 // clusters, 5) + cluster * 10) 36 | ground_truth.append(np.array([cluster * 10] * 5)) 37 | data = np.vstack(data) 38 | ground_truth = np.vstack(ground_truth) 39 | background_ta = BackgroundGenerator(data).kmeans(clusters) 40 | background = prediction_object_to_numpy(background_ta) 41 | 42 | assert len(background) == 5 43 | for row in background: 44 | ground_truth_idx = math.floor(row[0] / 10) 45 | assert np.linalg.norm(row - ground_truth[ground_truth_idx]) < 2.5 46 | 47 | 48 | def test_counterfactual_generation_single_goal(): 49 | """Test that cf background meets requirements""" 50 | seed = 0 51 | np.random.seed(seed) 52 | data = np.random.rand(100, 5) 53 | model = Model(lambda x: x.sum(1)) 54 | goal = np.array([1.0]) 55 | 56 | # check that undomained backgrounds are caught 57 | attribute_error_thrown = False 58 | try: 59 | BackgroundGenerator(data).counterfactual(goal, model, 10,) 60 | except AttributeError: 61 | attribute_error_thrown = True 62 | assert attribute_error_thrown 63 | 64 | domains = [feature_domain((-10, 10)) for _ in range(5)] 65 | background_ta = BackgroundGenerator(data, domains, seed)\ 66 | .counterfactual(goal, model, 5, step_count=5000, timeout_seconds=2) 67 | background = prediction_object_to_numpy(background_ta) 68 | 69 | for row in background: 70 | assert np.linalg.norm(goal - model(row.reshape(1, -1))) < .01 71 | 72 | 73 | def test_counterfactual_generation_multi_goal(): 74 | """Test that cf background meets requirements for multiple goals""" 75 | 76 | seed = 0 77 | np.random.seed(seed) 78 | data = np.random.rand(100, 5) 79 | model = Model(lambda x: x.sum(1)) 80 | goals = np.arange(1, 10).reshape(-1, 1) 81 | domains = [feature_domain((-10, 10)) for _ in range(5)] 82 | background_ta = BackgroundGenerator(data, domains, seed)\ 83 | .counterfactual(goals, model, 1, step_count=5000, timeout_seconds=2, chain=True) 84 | background = prediction_object_to_numpy(background_ta) 85 | 86 | for i, goal in enumerate(goals): 87 | assert np.linalg.norm(goal - model(background[i:i+1])) < goal[0]/100 88 | -------------------------------------------------------------------------------- /tests/general/test_tyrus.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | from trustyai.model import Model 5 | from trustyai.utils.tyrus import Tyrus 6 | import numpy as np 7 | import pandas as pd 8 | 9 | import os 10 | 11 | 12 | def test_tyrus_series(): 13 | # define data 14 | data = pd.DataFrame(np.random.rand(101, 5), columns=list('ABCDE')) 15 | 16 | # define model 17 | def predict_function(x): 18 | return pd.DataFrame( 19 | np.stack( 20 | [x.sum(1), x.std(1), np.linalg.norm(x, axis=1)]).T, 21 | columns= ['Sum', 'StdDev', 'L2 Norm']) 22 | 23 | predictions = predict_function(data) 24 | 25 | model = Model(predict_function, dataframe_input=True) 26 | 27 | # create Tyrus instance 28 | tyrus = Tyrus( 29 | model, 30 | data.iloc[100], 31 | predictions.iloc[100], 32 | background=data.iloc[:100], 33 | filepath=os.getcwd() 34 | ) 35 | 36 | # launch dashboard 37 | tyrus.run() 38 | 39 | # see if dashboard html exists 40 | 
assert "trustyai_dashboard.html" in os.listdir() 41 | 42 | # cleanup 43 | os.remove("trustyai_dashboard.html") 44 | 45 | 46 | def test_tyrus_numpy(): 47 | # define data 48 | data = np.random.rand(101, 5) 49 | 50 | # define model 51 | def predict_function(x): 52 | return np.stack([x.sum(1), x.std(1), np.linalg.norm(x, axis=1)]).T 53 | 54 | predictions = predict_function(data) 55 | 56 | model = Model(predict_function, dataframe_input=False) 57 | 58 | # create Tyrus instance 59 | tyrus = Tyrus( 60 | model, 61 | data[100], 62 | predictions[100], 63 | background=data[:100] 64 | ) 65 | 66 | # launch dashboard 67 | tyrus.run() 68 | 69 | # see if dashboard html exists 70 | assert "trustyai_dashboard.html" in os.listdir(tyrus.filepath) 71 | -------------------------------------------------------------------------------- /tests/general/universal.py: -------------------------------------------------------------------------------- 1 | # General Setup 2 | from trustyai.model import Model, simple_prediction, counterfactual_prediction 3 | from trustyai.explainers import * 4 | 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | 10 | np.random.seed(0) 11 | 12 | @pytest.mark.skip("redundant") 13 | def test_all_explainers(): 14 | # universal setup ============================================================================== 15 | data = pd.DataFrame(np.random.rand(1, 5)) 16 | model_weights = np.random.rand(5) 17 | predict_function = lambda x: np.dot(x.values, model_weights) 18 | model = Model(predict_function, dataframe_input=True, arrow=True) 19 | prediction = simple_prediction(input_features=data, outputs=model(data)) 20 | 21 | # SHAP ========================================================================================= 22 | background = pd.DataFrame(np.zeros([100, 5])) 23 | shap_explainer = SHAPExplainer(background=background) 24 | explanation = shap_explainer.explain(prediction, model) 25 | 26 | for score in explanation.as_dataframe()['SHAP Value'].iloc[1:-1]: 27 | assert score > 0 28 | 29 | # LIME ========================================================================================= 30 | explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False) 31 | explanation = explainer.explain(prediction, model) 32 | for score in explanation.as_dataframe()["output-0_score"]: 33 | assert score > 0 34 | 35 | # Counterfactual =============================================================================== 36 | features = [feature(str(k), "number", v, domain=(-10., 10.)) for k, v in data.iloc[0].items()] 37 | goal = np.array([[0]]) 38 | cf_prediction = counterfactual_prediction(input_features=features, outputs=goal) 39 | explainer = CounterfactualExplainer(steps=10_000) 40 | explanation = explainer.explain(cf_prediction, model) 41 | result_output = model(explanation.get_proposed_features_as_pandas()) 42 | assert result_output < .01 43 | assert result_output > -.01 44 | -------------------------------------------------------------------------------- /tests/initialization/test_initialization.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | 4 | import pytest 5 | from multiprocessing import Process, Value 6 | import sys 7 | 8 | 9 | # slightly hacky functions to make sure the test process does not see the trustyai initialization 10 | # from commons.py 11 | def test_manual_initializer_process(): 12 | import trustyai 13 | from trustyai import initializer 14 | initial_state = 
trustyai.TRUSTYAI_IS_INITIALIZED 15 | initializer.init(path=initializer._get_default_path()[0]) 16 | 17 | # test imports work 18 | from trustyai.explainers import SHAPExplainer 19 | 20 | # test initialization is set 21 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 22 | assert initial_state == False 23 | assert final_state == True 24 | 25 | 26 | def test_default_initializer_process_mod(): 27 | import trustyai 28 | initial_state = trustyai.TRUSTYAI_IS_INITIALIZED 29 | import trustyai.model 30 | 31 | # test initialization is set 32 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 33 | assert initial_state == False 34 | assert final_state == True 35 | 36 | 37 | def test_default_initializer_process_exp(): 38 | import trustyai 39 | initial_state = trustyai.TRUSTYAI_IS_INITIALIZED 40 | import trustyai.explainers 41 | 42 | # test initialization is set 43 | final_state = trustyai.TRUSTYAI_IS_INITIALIZED 44 | assert initial_state == False 45 | assert final_state == True 46 | --------------------------------------------------------------------------------
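Illustrative sketch, not repository code: the initialization tests above rely on running in an interpreter that has not yet touched the JVM, which is why test_initialization.py imports multiprocessing. One possible way to enforce that isolation is sketched below; the helper name and the use of the "spawn" context are assumptions.

from multiprocessing import get_context


def run_isolated(test_fn):
    """Run test_fn in a fresh 'spawn' process so that no prior trustyai
    initialization leaks into its view of TRUSTYAI_IS_INITIALIZED."""
    ctx = get_context("spawn")          # "spawn" starts from a clean interpreter
    proc = ctx.Process(target=test_fn)  # test_fn must be a module-level function
    proc.start()
    proc.join()
    assert proc.exitcode == 0, f"{test_fn.__name__} failed in its subprocess"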