├── .github └── workflows │ ├── linting-and-type-check.yml │ ├── publish.yml │ ├── test-package-and-comment.yml │ └── test-python-package.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── bus_757.png ├── id_curve.png ├── ladybird_301.png ├── ladybird_301_raw.png ├── ladybird_rm.png ├── lizard_subs.png ├── occluded_bus_rm.png ├── peacock.jpg ├── peacock_05.png ├── peacock_84_00.png ├── peacock_84_01.png ├── peacock_84_02.png ├── peacock_comp.png ├── peacock_exp.png ├── rex-assumptions-768x259.png ├── rex-structure-600x129.png ├── rex_logo.png ├── rex_logo.svg ├── spectrum.png ├── tandem.jpg ├── tandem_contrastive.png ├── tandem_exp.png ├── tandem_multi_1.png ├── tandem_multi_2.png └── tandem_multi_3.png ├── docs ├── Makefile ├── _static │ └── rex_logo.png ├── background.md ├── command_line.md ├── conf.py ├── config.md ├── contrastive.md ├── index.md ├── index.rst ├── make.bat ├── multiple.md ├── notebooks │ └── intro.md └── script.md ├── example.rex.toml ├── poetry.lock ├── pyproject.toml ├── rex_xai ├── __init__.py ├── explanation │ ├── evaluation.py │ ├── explanation.py │ ├── multi_explanation.py │ └── rex.py ├── input │ ├── config.py │ ├── input_data.py │ └── onnx.py ├── mutants │ ├── box.py │ ├── distributions.py │ ├── mutant.py │ └── occlusions.py ├── output │ ├── database.py │ └── visualisation.py ├── responsibility │ ├── prediction.py │ ├── resp_maps.py │ └── responsibility.py ├── rex_wrapper.py └── utils │ ├── _utils.py │ └── logger.py ├── scripts ├── pytorch.py ├── spectral_pytorch.py └── timm_script.py ├── shell └── _rex └── tests ├── conftest.py ├── scripts ├── pytorch_3d.py ├── pytorch_resnet50.py └── pytorch_swin_v2_t.py ├── snapshot_tests ├── __snapshots__ │ ├── _explanation_onnx_test.ambr │ ├── _explanation_test.ambr │ ├── explanation_test.ambr │ ├── load_preprocess_test.ambr │ └── spectral_test.ambr ├── _explanation_onnx_test.py ├── _explanation_test.py ├── explanation_test.py ├── 
load_preprocess_test.py └── spectral_test.py ├── test_data ├── 004_0002.jpg ├── 2008_000033.jpg ├── DoublePeakClass 0 Mean 1.npy ├── DoublePeakClass 0 Mean.npy ├── DoublePeakClass 1 Mean.npy ├── DoublePeakClass 2 Mean.npy ├── ILSVRC2012_val_00047302.JPEG ├── TCGA_DU_7018_19911220_14.tif ├── bike.jpg ├── dog.jpg ├── dog_hide.jpg ├── ladybird.jpg ├── lizard.jpg ├── peacock.jpg ├── positive193.npy ├── rex-test-all-config.toml ├── spectrum_class_DNA.npy ├── spectrum_class_noDNA.npy ├── starfish.jpg ├── tennis.jpg └── testimage.png └── unit_tests ├── box_test.py ├── cmd_args_test.py ├── config_test.py ├── data_test.py ├── database_test.py ├── multi-explanation_test.py ├── preprocessing_test.py ├── validate_args_test.py └── visualisation_test.py /.github/workflows/linting-and-type-check.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install ReX and carry out linting and type checking 2 | 3 | name: Linting and type checking 4 | 5 | on: pull_request 6 | 7 | jobs: 8 | lint: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | python-version: ["3.13"] 14 | 15 | permissions: 16 | pull-requests: write 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Install poetry 21 | run: pipx install poetry 22 | 23 | - name: Set up Python, install and cache dependencies 24 | uses: actions/setup-python@v5 25 | with: 26 | python-version: ${{matrix.python-version}} 27 | cache: poetry 28 | - run: poetry install --with dev 29 | - run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH 30 | 31 | - name: Lint with Ruff 32 | run: | 33 | ruff check --output-format=github . 
34 | 35 | - name: Run pyright with reviewdog 36 | uses: jordemort/action-pyright@e85f3910971e8bd8cec27d8c7235d1f99825e570 37 | with: 38 | github_token: ${{ secrets.GITHUB_TOKEN }} 39 | reporter: github-pr-review 40 | 41 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | test: 13 | name: Install and test package 14 | uses: ./.github/workflows/test-python-package.yml 15 | with: 16 | python-version: "3.13" 17 | 18 | build: 19 | name: Build package 20 | needs: 21 | - test 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout code 25 | uses: actions/checkout@v4 26 | 27 | - name: Install poetry 28 | run: pipx install poetry 29 | 30 | - name: Set up Python 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: "3.13" 34 | cache: poetry 35 | 36 | - name: Package project 37 | run: poetry build 38 | 39 | - name: Store the distribution packages 40 | uses: actions/upload-artifact@v4 41 | with: 42 | name: python-package-distributions 43 | path: dist/ 44 | 45 | install_and_check_version: 46 | name: Install and check Version Number 47 | needs: 48 | - build 49 | runs-on: ubuntu-latest 50 | steps: 51 | - name: Download all the dists 52 | uses: actions/download-artifact@v4 53 | with: 54 | name: python-package-distributions 55 | path: dist/ 56 | 57 | - name: Install 58 | run: pip install dist/*.tar.gz 59 | 60 | - name: Check version number 61 | run: | 62 | PYTHON_VERSION=`ReX --version` 63 | echo "PYTHON_VERSION=${PYTHON_VERSION}" 64 | GIT_VERSION=$GITHUB_REF_NAME 65 | echo "GIT_VERSION=${GIT_VERSION}" # NB that Github version should have a 'v' prefix 66 | if [ "v$PYTHON_VERSION" != "$GIT_VERSION" ]; then exit 1; fi 67 | echo "VERSION=${GIT_VERSION}" >> $GITHUB_OUTPUT 68 | 
69 | pypi-publish: 70 | name: Upload release to PyPI 71 | needs: 72 | - install_and_check_version 73 | runs-on: ubuntu-latest 74 | environment: 75 | name: pypi 76 | url: https://pypi.org/project/rex_xai 77 | permissions: 78 | id-token: write 79 | steps: 80 | - name: Download all the dists 81 | uses: actions/download-artifact@v4 82 | with: 83 | name: python-package-distributions 84 | path: dist/ 85 | 86 | - name: Publish package distributions to PyPI 87 | uses: pypa/gh-action-pypi-publish@release/v1 88 | -------------------------------------------------------------------------------- /.github/workflows/test-package-and-comment.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install ReX and run tests, using the highest and lowest Python versions we support 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Install, run tests, and report test coverage 5 | 6 | on: 7 | pull_request: 8 | 9 | 10 | jobs: 11 | run-tests: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.10", "3.13"] 16 | uses: ./.github/workflows/test-python-package.yml 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | 20 | coverage_report: 21 | needs: 22 | - run-tests 23 | runs-on: ubuntu-latest 24 | permissions: 25 | pull-requests: write 26 | 27 | steps: 28 | - name: Download coverage report 29 | uses: actions/download-artifact@v4 30 | with: 31 | name: coverage-report-py3.13 32 | 33 | - name: Pytest coverage comment 34 | uses: MishaKav/pytest-coverage-comment@v1.1.52 35 | if: ${{ !github.event.pull_request.head.repo.fork }} 36 | with: 37 | pytest-xml-coverage-path: ./coverage.xml 38 | -------------------------------------------------------------------------------- /.github/workflows/test-python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install 
the project and run tests 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Install and run tests 5 | 6 | on: 7 | workflow_call: 8 | inputs: 9 | python-version: 10 | required: true 11 | type: string 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v4 20 | 21 | - name: Install poetry 22 | run: pipx install poetry 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{inputs.python-version}} 28 | cache: poetry 29 | 30 | - name: Install project and dependencies 31 | run: poetry install --with dev 32 | 33 | - name: Update PATH 34 | run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH 35 | 36 | - name: Cache model files for testing 37 | uses: actions/cache@v3 38 | env: 39 | cache-name: cache-model-files 40 | with: 41 | path: ~/.cache/cached_path 42 | key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('tests/conftest.py') }} 43 | 44 | - name: Test with pytest 45 | run: | 46 | pytest --junitxml=pytest.xml --cov-report=xml:coverage.xml --cov=rex_xai tests/unit_tests tests/snapshot_tests 47 | 48 | - name: Upload coverage report 49 | uses: actions/upload-artifact@v4 50 | with: 51 | name: coverage-report-py${{ inputs.python-version }} 52 | path: ./coverage.xml 53 | 54 | 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.log 3 | tmp/ 4 | yolo/ 5 | typings/ 6 | results/ 7 | scripts/ 8 | dist/ 9 | ReX.egg-info/ 10 | local/ 11 | runs/ 12 | local/ 13 | runs/ 14 | benchmarks/ 15 | rex.toml 16 | rgb_rex.toml 17 | l_rex.toml 18 | *.org 19 | *.onnx 20 | 21 | test_*.jpg 22 | ReX_*.jpg 23 | causal_*.jpg 24 | driver.py 25 | data_dump 26 | deepcover.toml 27 | *.egg 28 | build 29 | htmlcov 30 | __pycache__/ 31 | one_image/ 
32 | data 33 | *.pickle 34 | *.swp 35 | *.db 36 | *.onnx 37 | *.h5 38 | *.hdf5 39 | *.pt 40 | *.json.gz 41 | .coverage 42 | rex_ai.egg-info/* 43 | *.pt 44 | *.json 45 | .coverage 46 | 47 | docs/_build/ 48 | .jupyter 49 | .ipynb_checkpoints/ 50 | *.ipynb 51 | 52 | .DS_Store 53 | .AppleDouble 54 | .LSOverride 55 | 56 | # Icon must end with two \r 57 | Icon 58 | 59 | # Thumbnails 60 | ._* 61 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | jobs: 13 | post_create_environment: 14 | # Install poetry 15 | # https://python-poetry.org/docs/#installing-manually 16 | - python -m pip install poetry 17 | post_install: 18 | # Install package and all dependencies 19 | # https://python-poetry.org/docs/managing-dependencies/#dependency-groups 20 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --with dev 21 | 22 | # Build documentation in the "docs/" directory with Sphinx 23 | sphinx: 24 | configuration: docs/conf.py 25 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 26 | # builder: "dirhtml" 27 | # Fail on all warnings to avoid broken references 28 | # fail_on_warning: true 29 | 30 | # Optionally build your docs in additional formats such as PDF and ePub 31 | # formats: 32 | # - pdf 33 | # - epub 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## Setting up a local development environment 4 | 5 | 
While you can install ReX through various methods, for development purposes we use [poetry](https://python-poetry.org/) for installation, development dependency management, and building the package. Install poetry following [its installation instructions](https://python-poetry.org/docs/) - it's recommended to install using `pipx`. 6 | 7 | Clone this repo and `cd` into it. 8 | 9 | Install ReX and its dependencies by running `poetry install`. 10 | This will install the versions of the dependencies given in `poetry.lock`. 11 | This ensures that the development environment is consistent for different people. 12 | 13 | The development dependencies (for generating documentation, linting, and running tests) are marked as optional, so to install these you will need to run instead `poetry install --with dev`. 14 | 15 | There are also some additional optional dependencies that are only required for working with 3D data. 16 | You can install these using `poetry install --extras 3D`. 17 | 18 | N.B. that poetry by default creates its own virtual environment for the project. 19 | However if you run `poetry install` in an activated virtual environment, it will detect and respect this. 20 | See the [poetry docs](https://python-poetry.org/docs/basic-usage/#using-your-virtual-environment) for more information. 21 | 22 | ## Testing 23 | 24 | We use [pytest](https://docs.pytest.org/en/stable/index.html) with the [pytest-cov](https://github.com/pytest-dev/pytest-cov) plugin. 25 | 26 | Run the tests by running `pytest`, which will automatically run all files of the form `test_*.py` or `*_test.py` in the current directory and its subdirectories. 27 | Run `pytest --cov=rex_xai tests/` to get a coverage report printed to the terminal. 28 | See [pytest-cov's documentation](https://pytest-cov.readthedocs.io/en/latest/) for additional reporting options. 
29 | 30 | As the end-to-end tests which run the whole ReX pipeline can take a while to run, we have split the tests into two sub-directories: `tests/unit_tests/` and `tests/long_tests/`. 31 | During development you may wish to only run the faster unit tests. 32 | You can do this by specifying the directory: `pytest tests/unit_tests/` 33 | Both sets of tests are run by GitHub Actions upon a pull request. 34 | 35 | ### Updating snapshots 36 | 37 | Most of the end-to-end tests which run the whole ReX pipeline are 'snapshot tests' using the [syrupy](https://github.com/syrupy-project/syrupy) package. 38 | These tests involve comparing an object returned by the function under test to a previously saved 'snapshot' of that object. 39 | This can help identify unintentional changes in results that are introduced by new development. 40 | Note that snapshots are based on the text representation of an object, so don't necessarily capture *all* results you may care about. 41 | 42 | If a snapshot test fails, follow the steps below to confirm if the changes are expected or not and update the snapshots: 43 | 44 | * Run `pytest -vv` to see a detailed comparison of changes compared to the snapshot 45 | * Check whether these are expected or not. For example, if you have added an additional parameter in the `CausalArgs` class, you expect that parameter value to be missing from the snapshot. 46 | * If you only see expected differences in the snapshot, you can update the snapshots to match the new results by running `pytest --snapshot-update` and commit the updated files. 47 | 48 | ## Generating documentation with Sphinx 49 | 50 | Docs are automatically built on PRs and on updates to the repo's default branch, and are available at <https://rex-xai.readthedocs.io/en/latest/>. 
51 | 52 | To build documentation locally using Sphinx and sphinx-autoapi: 53 | 54 | ```sh 55 | cd docs/ 56 | make html 57 | ``` 58 | 59 | This will automatically generate documentation based on the code and docstrings and produce html files in `docs/_build/html`. 60 | 61 | ### Docstring style 62 | 63 | We prefer [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) docstrings, and use `sphinx.ext.napoleon` to parse them. 64 | 65 | ### Working with notebooks 66 | 67 | The [introduction to working with ReX interactively](https://rex-xai.readthedocs.io/en/latest/notebooks/intro.html) is written as a Jupyter notebook in markdown format. 68 | We use [MyST-NB](https://myst-nb.readthedocs.io/en/latest/index.html) to compile the notebook into html as part of the documentation. 69 | To more easily work with the notebook locally, you can use [Jupytext](https://jupytext.readthedocs.io/en/latest/) to generate an .ipynb notebook from the markdown file, edit the tutorial, and then convert the edited notebook back into md. 70 | 71 | ```sh 72 | # convert to .ipynb 73 | jupytext docs/notebooks/intro.md --to ipynb 74 | # convert back to markdown 75 | jupytext docs/notebooks/intro.ipynb --to myst 76 | ``` 77 | 78 | Markdown format allows much clearer diffs when tracking the notebook with version control, so please don't add the .ipynb files to version control. 79 | 80 | ## Code linting and formatting 81 | 82 | This project uses [ruff](https://docs.astral.sh/ruff/) for code linting and formatting, to ensure a consistent code style and identify issues like unused imports. 83 | Install by running `poetry install --with dev`. 84 | 85 | Run the linter on all files in the current working directory with `ruff check`. 86 | Ruff can automatically fix some issues if you run `ruff check --fix`. 87 | 88 | Run `ruff format` to automatically format all files in the current working directory. 
89 | Run `ruff format --diff` to get a preview of any changes that would be made. 90 | 91 | ## Type checking 92 | 93 | We use [Pyright](https://microsoft.github.io/pyright/#/) for type checking. 94 | You can [install](https://microsoft.github.io/pyright/#/installation) the command line tool and/or an extension for your favourite editor. 95 | Upon a pull request, a check is run that identifies Pyright errors/warnings in the lines that have been added in the PR. 96 | A review comment will be left for each change. 97 | Ideally, no new errors/warnings will be introduced in a PR, but this is not an enforced requirement to merge. 98 | 99 | ## GitHub Actions 100 | 101 | We use GitHub Actions to automatically run certain checks upon pull requests, and to automate releasing a new ReX version to PyPI. 102 | 103 | On a pull request, the following workflows run: 104 | 105 | * linting and type checking 106 | * installing the package and running tests (using Python 3.10 and 3.13) 107 | * test coverage is also measured 108 | * the docs are also built by ReadTheDocs (separate from GitHub Actions) 109 | 110 | When a new [release](https://docs.github.com/en/repositories/releasing-projects-on-github/about-releases) is created, the following workflows run: 111 | 112 | * installing the package and running tests (using Python 3.13) 113 | * building the package 114 | * checking that the installed package version matches the release tag 115 | * uploading the release to PyPI 116 | 117 | ## Publishing the package on PyPI 118 | 119 | To publish a new ReX version to PyPI, create a [release](https://docs.github.com/en/repositories/releasing-projects-on-github/about-releases). 120 | This will trigger a set of GitHub Actions workflows, which will run tests, check for version number consistency, and then publish the package to PyPI. 121 | 122 | When creating the release, typically the target branch should be `main`. 
123 | The target branch should contain all the commits you want to be included in the new release. 124 | 125 | The release should be associated with a tag that has the form "vX.Y.Z" - note the "v" prefix! 126 | This can be a new tag that is created for the most recent commit at the time of the release, or can be a pre-existing tag. 127 | 128 | Give the release a title - this can just be the version number. 129 | 130 | Write some release notes explaining the changes incorporated in this release. 131 | GitHub offers the option to [automatically generate release notes](https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes) based on PRs merged in since the last release, which can be a good starting point. 132 | Here is [one possible example](https://gist.github.com/andreasonny83/24c733ae50cadf00fcf83bc8beaa8e6a) of how release notes can be structured, to give some ideas of what to include. 133 | 134 | The release can be saved as a draft. 135 | When you are ready, use the "Publish release" button to publish the release and trigger the GitHub Actions workflow that will publish it to PyPI. 136 | 137 | If there are any issues with the workflow, it can also be re-run manually. 138 | Navigate to the workflow in the Actions tab of the repo and use the "Run workflow" button to run it manually (after fixing any known issues). 
139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 ReX-XAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReX: Causal *R*esponsibility *EX*planations for image classifiers 2 | 3 | ![ReX logo with dinosaur](https://raw.githubusercontent.com/ReX-XAI/ReX/main/assets/rex_logo.png "ReX Logo with dinosaur") 4 | 5 | 6 | 7 | [![Docs](https://readthedocs.org/projects/rex-xai/badge/?version=latest)](https://rex-xai.readthedocs.io/en/latest/) 8 | [![Tests](https://github.com/ReX-XAI/ReX/actions/workflows/test-package-and-comment.yml/badge.svg)](https://github.com/ReX-XAI/ReX/actions/workflows/test-package-and-comment.yml) 9 | [![License](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/ReX-XAI/ReX/blob/main/LICENSE) 10 | 11 | 12 | 13 | --- 14 | 15 | ReX is a causal explainability tool for image classifiers. It also works on tabular and 3D data. 16 | 17 | Given an input image and a classifier, ReX calculates a causal responsibility map across the data and identifies a minimal, sufficient, explanation. 18 | 19 | ![ladybird](https://raw.githubusercontent.com/ReX-XAI/ReX/main/tests/test_data/ladybird.jpg "Original Image") ![responsibility map](https://raw.githubusercontent.com/ReX-XAI/ReX/main/assets/ladybird_rm.png "Responsibility Map") ![minimal explanation](https://raw.githubusercontent.com/ReX-XAI/ReX/main/assets/ladybird_301.png "Explanation") 20 | 21 | ReX is black-box, that is, agnostic to the internal structure of the classifier. 22 | ReX finds single explanations, non-contiguous explanations (for partially obscured images), multiple independent explanations, contrastive explanations and lots of other things! 23 | It has a host of options and parameters, allowing you to fine tune it to your data. 24 | 25 | For background information and detailed usage instructions, see our [documentation](https://rex-xai.readthedocs.io/en/latest/). 
26 | 27 | 28 | 29 | ## Installation 30 | 31 | ReX can be installed using `pip`. 32 | We recommend creating a virtual environment to install ReX. 33 | ReX has been tested using versions of Python >= 3.10. 34 | The following instructions assume `conda`: 35 | 36 | ```bash 37 | conda create -n rex python=3.13 38 | conda activate rex 39 | pip install rex_xai 40 | ``` 41 | 42 | This should install an executable `rex` in your path. 43 | 44 | > **Note:** 45 | > 46 | > By default, `onnxruntime` will be installed. 47 | > If you wish to use a GPU, you should uninstall `onnxruntime` and install `onnxruntime-gpu` instead. 48 | > You can alternatively clone the project and edit the `pyproject.toml` to read "onnxruntime-gpu >= 1.17.0" rather than "onnxruntime >= 1.17.0". 49 | 50 | If you want to use ReX with 3D data, you will need to install some optional extra dependencies: 51 | 52 | ```bash 53 | pip install 'rex_xai[3D]' 54 | ``` 55 | 56 | 57 | 58 | ## Feedback 59 | 60 | Bug reports, questions, and suggestions for enhancements are welcome - please [check the GitHub Issues](https://github.com/ReX-XAI/ReX/issues) to see if there is already a relevant issue, or [open a new one](https://github.com/ReX-XAI/ReX/issues/new)! 61 | 62 | ## How to Contribute 63 | 64 | Your contributions are highly valued and welcomed. To get started, please review the guidelines outlined in the [CONTRIBUTING.md](/CONTRIBUTING.md) file. We look forward to your participation! 
65 | -------------------------------------------------------------------------------- /assets/bus_757.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/bus_757.png -------------------------------------------------------------------------------- /assets/id_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/id_curve.png -------------------------------------------------------------------------------- /assets/ladybird_301.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_301.png -------------------------------------------------------------------------------- /assets/ladybird_301_raw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_301_raw.png -------------------------------------------------------------------------------- /assets/ladybird_rm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_rm.png -------------------------------------------------------------------------------- /assets/lizard_subs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/lizard_subs.png -------------------------------------------------------------------------------- /assets/occluded_bus_rm.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/occluded_bus_rm.png -------------------------------------------------------------------------------- /assets/peacock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock.jpg -------------------------------------------------------------------------------- /assets/peacock_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_05.png -------------------------------------------------------------------------------- /assets/peacock_84_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_00.png -------------------------------------------------------------------------------- /assets/peacock_84_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_01.png -------------------------------------------------------------------------------- /assets/peacock_84_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_02.png -------------------------------------------------------------------------------- /assets/peacock_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_comp.png -------------------------------------------------------------------------------- 
/assets/peacock_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_exp.png -------------------------------------------------------------------------------- /assets/rex-assumptions-768x259.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex-assumptions-768x259.png -------------------------------------------------------------------------------- /assets/rex-structure-600x129.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex-structure-600x129.png -------------------------------------------------------------------------------- /assets/rex_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex_logo.png -------------------------------------------------------------------------------- /assets/rex_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 73 | -------------------------------------------------------------------------------- /assets/spectrum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/spectrum.png -------------------------------------------------------------------------------- /assets/tandem.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem.jpg 
-------------------------------------------------------------------------------- /assets/tandem_contrastive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_contrastive.png -------------------------------------------------------------------------------- /assets/tandem_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_exp.png -------------------------------------------------------------------------------- /assets/tandem_multi_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_1.png -------------------------------------------------------------------------------- /assets/tandem_multi_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_2.png -------------------------------------------------------------------------------- /assets/tandem_multi_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_3.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/rex_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/docs/_static/rex_logo.png -------------------------------------------------------------------------------- /docs/background.md: -------------------------------------------------------------------------------- 1 | # Background information 2 | 3 | ReX is a causal explainability tool for image classifiers. 4 | ReX is black-box, that is, agnostic to the internal structure of the classifier. 5 | We assume that we can modify the inputs and send them to the classifier, observing the output. 6 | ReX outperforms other tools on [single explanations](https://www.hanachockler.com/eccv/), [non-contiguous explanations](https://www.hanachockler.com/iccv2021/) (for partially obscured images), and [multiple explanations](http://www.hanachockler.com/multirex/). 7 | 8 | ![ReX organisation](../assets/rex-structure-600x129.png) 9 | 10 | ## Assumptions 11 | 12 | ReX works on the assumption that if we can intervene on the inputs to a model and observe changes in its outputs, we can use this information to reason about the way the DNN makes its decisions. 
13 | 14 | ![ReX assumptions](../assets/rex-assumptions-768x259.png) 15 | 16 | ## Presentations about ReX 17 | 18 | * [Attacking your black box classifier with ReX](https://www.hanachockler.com/rex-2/) 19 | * [Causal Explanations For Image Classifiers](https://www.hanachockler.com/hana-chockler-causal-xai-workshop-102023/) 20 | 21 | ## Papers 22 | 23 | 1. [Causal Explanations for Image Classifiers](https://arxiv.org/pdf/2411.08875). Under review. This paper introduces the tool ReX. 24 | 2. [Multiple Different Black Box Explanations for Image Classifiers](http://www.hanachockler.com/multirex/). Under review. This paper introduces MULTI-ReX for multiple explanations. 25 | 3. [3D ReX: Causal Explanations in 3D Neuroimaging Classification](https://arxiv.org/pdf/2502.12181). Presented at [Imageomics-AAAI-25](https://sites.google.com/view/imageomics-aaai-25/home?authuser=0). 3D explanations for neuroimaging. 26 | 4. [Explanations for Occluded Images](http://www.hanachockler.com/iccv2021/). In ICCV’21. This paper introduces causality for image classifier explanations. Note: the tool is called DC-Causal in this paper. 27 | 5. [Explaining Image Classifiers using Statistical Fault Localization](http://www.hanachockler.com/eccv/). In ECCV’20. The first paper on ReX. Note: the tool is called DeepCover in this paper. 28 | -------------------------------------------------------------------------------- /docs/command_line.md: -------------------------------------------------------------------------------- 1 | # Command line usage 2 | 3 | 4 | 5 | ```{argparse} 6 | :module: rex_xai.config 7 | :func: cmdargs_parser 8 | :prog: ReX 9 | ``` 10 | 11 | 12 | 13 | ## Model formats 14 | 15 | ### Onnx 16 | 17 | ReX natively understands onnx files. Train or download a model (e.g. 
[Resnet50](https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet50-v1-7.onnx)) and, from this directory, run: 18 | 19 | ```bash 20 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --output dog_exp.jpg 21 | ``` 22 | 23 | ### Pytorch 24 | 25 | ReX also works with PyTorch, but you will need to write some custom code to provide ReX with the prediction function and model shape, as well as preprocess the input data. 26 | See the sample scripts in `scripts/`. 27 | 28 | ```bash 29 | rex tests/test_data/dog.jpg --script scripts/pytorch.py -vv --output dog_exp.jpg 30 | ``` 31 | 32 | ## Saving output in a database 33 | 34 | To store all output in a sqlite database, use: 35 | 36 | ```bash 37 | rex --model -db 38 | ``` 39 | 40 | ReX will create the db if it does not already exist. 41 | It will append to any db with the given name, so be careful not to use the same database if you are restarting an experiment. 42 | ReX comes with a number of database functions which allow you to load it as a pandas dataframe for analysis. 43 | 44 | ## Config 45 | 46 | ReX looks for the config file `rex.toml` in the current working directory and then `$HOME/.config/rex.toml` on unix-like systems. 47 | 48 | If you want to use a custom location, use: 49 | 50 | ```bash 51 | rex --model --config 52 | ``` 53 | 54 | An example config file is included in the repo as `example.rex.toml`. 55 | Rename this to `rex.toml` if you wish to use it. ReX will ignore the file `example.rex.toml` if you do not rename it. 56 | 57 | ### Overriding the config 58 | 59 | Some options from the config file can be overridden at the command line when calling ReX. 60 | In particular, you can change the number of iterations of the algorithm: 61 | 62 | ```bash 63 | rex --model --iters 5 64 | ``` 65 | 66 | ## Preprocessing 67 | 68 | Input data should be transformed in the same way the model's training data was transformed before training.
69 | For PyTorch models, you should specify the preprocessing steps in the custom script. 70 | See the sample scripts in `scripts/` for examples of using models provided by the [torchvision](https://pytorch.org/vision/stable/index.html) and [timm](https://huggingface.co/docs/timm/index) packages. 71 | You can also write a script to load an `onnx` model as well if you have custom preprocessing to do. 72 | 73 | Otherwise, ReX tries to make reasonable guesses for image preprocessing. 74 | This includes resizing the image to match that needed for the model, converting it to a PyTorch tensor, and normalising the data. This can be controlled in `rex.toml`. 75 | In the event that the model input is multi-channel and the image is greyscale, then ReX will convert the image to pseudo-RGB. 76 | If you want more control over the conversion, you can do the conversion yourself and pass in the converted image. 77 | 78 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder.
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "ReX" 10 | copyright = "2024, David Kelly" 11 | author = "David Kelly, Liz Ing-Simmons and other contributors" 12 | 13 | # -- General configuration --------------------------------------------------- 14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 15 | 16 | extensions = [ 17 | "autoapi.extension", 18 | "sphinx.ext.autodoc.typehints", 19 | "sphinx.ext.napoleon", 20 | "sphinx.ext.intersphinx", 21 | 'sphinxarg.ext', 22 | "myst_nb" 23 | ] 24 | 25 | templates_path = ["_templates"] 26 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 27 | 28 | # -- Options for HTML output ------------------------------------------------- 29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 30 | 31 | html_theme = "alabaster" 32 | html_sidebars = { 33 | "**": [ 34 | "about.html", 35 | "searchfield.html", 36 | "navigation.html", 37 | "relations.html", 38 | "donate.html", 39 | ] 40 | } 41 | html_static_path = ["_static"] 42 | html_theme_options = { 43 | "logo": "rex_logo.png" 44 | } 45 | 46 | # -- AutoAPI ----------------------------------------------------------------- 47 | # https://sphinx-autoapi.readthedocs.io/en/latest/ 48 | autoapi_dirs = ["../rex_xai/"] 49 | autodoc_typehints = "description" 50 | 51 | # -- Intersphinx -------------------------------------------------------------- 52 | # https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html 53 | intersphinx_mapping = { 54 | "python": ("https://docs.python.org/3", None), 55 | "torch": ("https://pytorch.org/docs/stable/", None), 56 | "sqla": ("https://docs.sqlalchemy.org/en/latest/", None), 
57 | } 58 | 59 | # -- MyST -------------------------------------------------------------- 60 | # https://myst-parser.readthedocs.io/en/latest/ 61 | myst_enable_extensions = [ 62 | "attrs_inline" 63 | ] 64 | 65 | nb_execution_timeout = 300 66 | -------------------------------------------------------------------------------- /docs/contrastive.md: -------------------------------------------------------------------------------- 1 | # Contrastive Explanations 2 | 3 | This page describes **contrastive** explanations. These are *necessary* and *sufficient* explanations. 4 | A normal explanation from ReX is only sufficient. 5 | 6 | From the command line, the basic call is 7 | 8 | ```bash 9 | rex --script pytorch.py --contrastive 10 | ``` 11 | 12 | ## Example 13 | 14 | This image is a tandem bike 15 | 16 | ```{image} ../assets/tandem.jpg 17 | :alt: Tandem Bike 18 | :align: center 19 | ``` 20 | 21 | 22 | If we run 23 | 24 | ```bash 25 | rex tandem.jpg --script ../tests/scripts/pytorch_resnet50.py --vv --output tandem_exp.png 26 | ``` 27 | 28 | we get 29 | 30 | ```{image} ../assets/tandem_exp.png 31 | :alt: Tandem Bike Explanation 32 | :scale: 200% 33 | :align: center 34 | ``` 35 | 36 | 37 | 38 | Passing the highlighted pixels to the model, against the baseline defined in `rex.toml`, is enough to get the classification `tandem`. 39 | However, if we remove these highlighted pixels and leave the rest alone, we still get `tandem`. Why is that? Because these pixels are *sufficient* 40 | to get `tandem`, but they aren't *necessary*. There must be at least one more, independent, explanation for `tandem` in the image. 41 | 42 | We can try to find other sufficient explanations using `--multi`. 
43 | 44 | ```bash 45 | rex tandem.jpg --script ../tests/scripts/pytorch_resnet50.py --multi --vv --output tandem_exp.png 46 | ``` 47 | ![tandem1](../assets/tandem_multi_1.png) 48 | ![tandem2](../assets/tandem_multi_2.png) 49 | ![tandem3](../assets/tandem_multi_3.png) 50 | 51 | Each one of these is, by itself, sufficient for `tandem`. By default, `contrastive` uses combinations from the set of found sufficient explanations to 52 | find a combination which is also *necessary*: if we remove these pixels then we no longer have a `tandem`, if we have only these pixels, we have `tandem`. 53 | This explanation is both sufficient to get `tandem` and necessary to get `tandem`. 54 | 55 | ```{image} ../assets/tandem_contrastive.png 56 | :alt: Tandem Bike Contrastive Explanation 57 | :scale: 200% 58 | :align: center 59 | ``` 60 | 61 | ## Notes 62 | 63 | `contrastive` uses multiple explanations under the hood and then tests all combinations of discovered explanations. As a result, it can fail to find a contrastive explanation, 64 | especially if there are more sufficient explanations than discovered by `--multi`. In this case, ReX quits with an error message. 65 | 66 | `contrastive` takes an optional `int` at the command line, indicating how many spotlights to launch. This defaults to `10`, just like with multiple explanations. 67 | 68 | The contrastive algorithm is an instance of the [set packing](https://en.wikipedia.org/wiki/Set_packing) problem. As such, in the worst case scenario, it can be quite expensive to compute. 69 | 70 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # ReX: Causal Responsibility Explanations for image classifiers 2 | 3 | **ReX** is a causal explainability tool for image classifiers. 4 | ReX is black-box, that is, agnostic to the internal structure of the classifier.
5 | We assume that we can modify the inputs and send them to the classifier, observing the output. 6 | ReX provides sufficient, minimal single explanations, non-contiguous explanations (for partially obscured images), multiple explanations 7 | and contrastive explanations (sufficient, necessary and minimal). 8 | 9 | ```{image} ../assets/rex-structure-600x129.png 10 | :alt: ReX organisation 11 | :width: 600px 12 | :align: center 13 | ``` 14 | 15 | For more information and links to the papers, see the [background](background) page. 16 | 17 | ```{include} ../README.md 18 | :start-after: 19 | :end-before: 20 | ``` 21 | 22 | ## Quickstart 23 | 24 | ReX requires as input an image and a model. 25 | ReX natively understands onnx files. Train or download a model (e.g. [Resnet50](https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet50-v1-7.onnx)) and, from this directory, run: 26 | 27 | ```bash 28 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --output dog_exp.jpg 29 | ``` 30 | 31 | To view an interactive plot for the responsibility map, run:: 32 | 33 | ```bash 34 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --surface 35 | ``` 36 | 37 | To save the extracted explanation to a file: 38 | 39 | ```bash 40 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx --output dog_exp.jpg 41 | ``` 42 | 43 | ReX also works with PyTorch, but you will need to write some custom code to provide ReX with the prediction function and model shape, as well as preprocess the input data. 44 | See the sample scripts in `tests/scripts/`. 
45 | 46 | ```bash 47 | rex tests/test_data/dog.jpg --script tests/scripts/pytorch_resnet50.py -vv --output dog_exp.jpg 48 | ``` 49 | 50 | Other options: 51 | 52 | ```bash 53 | # with spatial search rather than the default global search 54 | rex --model --strategy spatial 55 | 56 | # to run multiple explanations 57 | rex --model --multi 58 | 59 | # to view a responsibility landscape heatmap 60 | rex --model --heatmap 61 | 62 | # to save a responsibility landscape surface plot 63 | rex --model --surface 64 | ``` 65 | 66 | ReX configuration is mainly handled via a config file; some options can also be set on the command line. 67 | ReX looks for the config file `rex.toml` in the current working directory and then `$HOME/.config/rex.toml` on unix-like systems. 68 | 69 | If you want to use a custom location, use: 70 | 71 | ```bash 72 | rex --model --config 73 | ``` 74 | 75 | An example config file is included in the repo as `example.rex.toml`. 76 | Rename this to `rex.toml` if you wish to use it. 77 | 78 | ## Command line usage 79 | 80 | ```{include} command_line.md 81 | :start-after: 82 | :end-before: 83 | ``` 84 | 85 | ## Examples 86 | 87 | ### Explanation 88 | 89 | An explanation for a ladybird. This explanation was produced with 20 iterations, using the default masking colour (0). The minimal, sufficient explanation itself 90 | is pretty printed using the settings in `[rex.visual]` in `rex.toml`. 91 | 92 | ![ladybird](../tests/test_data/ladybird.jpg "Original Image") ![responsibility map](../assets/ladybird_rm.png "Responsibility Map") ![minimal explanation](../assets/ladybird_301.png "Explanation") 93 | 94 | Setting `raw = true` in `rex.toml` produces the image which was actually classified by the model. 
95 | 96 | ![ladybird raw](../assets/ladybird_301_raw.png) 97 | 98 | ### Multiple Explanations 99 | 100 | ```bash 101 | rex tests/test_data/peacock.jpg --model resnet50-v1-7.onnx --strategy multi --output peacock.png 102 | ``` 103 | 104 | The number of explanations found depends on the model and some of the settings in `rex.toml`. 105 | 106 | ![peacock](../tests/test_data/peacock.jpg){w=200px} ![peacock 1](../assets/peacock_84_00.png) ![peacock 2](../assets/peacock_84_01.png) ![peacock 3](../assets/peacock_84_02.png) 107 | 108 | ### Occluded Images 109 | 110 | ![occluded bus](../tests/test_data/occluded_bus.jpg) 111 | 112 | ![occluded_bus_rm](../assets/occluded_bus_rm.png) 113 | 114 | ![bus_explanation](../assets/bus_757.png) 115 | 116 | ### Explanation Quality 117 | 118 | ```bash 119 | rex tests/test_data/ladybird.jpg --script tests/scripts/pytorch_resnet50.py --analyse 120 | 121 | INFO:ReX:area 0.000399, entropy difference 6.751189, insertion curve 0.964960, deletion curve 0.046096 122 | ``` 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | ```{toctree} 136 | :maxdepth: 2 137 | :caption: Contents: 138 | background.md 139 | command_line.md 140 | notebooks/intro 141 | config.md 142 | multiple.md 143 | contrastive.md 144 | ``` 145 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/docs/index.rst -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/multiple.md: -------------------------------------------------------------------------------- 1 | # Multiple Explanations 2 | 3 | This page describes **multiple** explanations, see [Multiple Different Black Box Explanations for Image Classifiers](http://www.hanachockler.com/multirex/). 4 | 5 | ## Example 6 | 7 | An image classification may have more than one sufficient explanation. Take this image of a peacock 8 | 9 | ```{image} ../assets/peacock.jpg 10 | :alt: Peacock 11 | :align: center 12 | ``` 13 | 14 | The global explanation is: 15 | 16 | 17 | ```{image} ../assets/peacock_exp.png 18 | :alt: Peacock Explanation 19 | :scale: 120% 20 | :align: center 21 | ``` 22 | 23 | But it's very likely that there's more than one. This small part of the tail is enough to get the classification `peacock`, but there are many 24 | other possible sources of information that match that classification. ReX can try to find them. 25 | 26 | ReX searches the responsibility map for sufficient explanations. It does this by launched `spotlights` which explore the space, using the responsibility 27 | as a guide. 
How many spotlights are launched is a parameter (by default: 10) and is set with the `--multi` flag; `--multi` takes an optional 28 | integer argument. 29 | 30 | ```bash 31 | rex peacock.jpg --script ../tests/scripts/pytorch_resnet50.py --multi 5 --vv --output peacock_exp.png 32 | ``` 33 | we get 34 | 35 | ```{image} ../assets/peacock_comp.png 36 | :alt: Peacock Explanation 37 | :align: center 38 | :scale: 120% 39 | ``` 40 | 41 | ReX has found 4 distinct, non-overlapping explanations. The original global explanation is still there, but we also have 3 other explanations. 42 | Two of these explanations (highlighted in white and red respectively) are [*disjoint* explanations](https://arxiv.org/pdf/2411.08875). 43 | 44 | ## Overlap 45 | 46 | The peacock shows 4 non-overlapping explanations, but this is a parameter. We can set the allowed degree of overlap by changing 47 | `permitted_overlap` in the [config](explanation_multi). This sets the [dice coefficient](https://en.wikipedia.org/wiki/Dice-S%C3%B8rensen_coefficient) 48 | of the explanations. 49 | 50 | If we set `permitted_overlap = 0.5` 51 | 52 | ```bash 53 | rex peacock.jpg --script ../tests/scripts/pytorch_resnet50.py --multi 10 --vv --output peacock_exp.png 54 | ``` 55 | 56 | ```{image} ../assets/peacock_05.png 57 | :alt: Peacock Explanation Overlap 58 | :align: center 59 | :scale: 120% 60 | ``` 61 | 62 | ## Notes 63 | Multi-ReX has many options and parameters, see [config](explanation_multi) for the complete list. 64 | 65 | The `spotlight` requires an objective function to guide its search of the responsibility landscape. By default this is `none`: if 66 | the spotlight fails to find an explanation in one location, it takes a random jump to another. Alternatively, `mean` moves the spotlight 67 | in the direction of the greater mean responsibility. 
68 | 69 | -------------------------------------------------------------------------------- /docs/script.md: -------------------------------------------------------------------------------- 1 | ## Script usage 2 | ReX can take in scripts that define the model behaviour, the preprocessing for the model and how the model's output can be interpreted by ReX. This is to allow the users to provide custom preprocessing/models to ReX. 3 | 4 | As outlined in the [command line section](command_line.md), the user can pass in the script using the `--script` argument. 5 | 6 | ```bash 7 | rex imgs/dog.jpg --script scripts/pytorch.py -vv --output dog_exp.jpg 8 | ``` 9 | 10 | ### Contents of the python script 11 | There are three main components to the script: 12 | - A preprocess function which takes in the following parameters and returns a Data object: 13 | - path: The path to the image 14 | - shape: The shape of the model input 15 | - device: The device the data is on e.g. "cuda" 16 | - mode: The mode of the data e.g. "RGB", "L", "voxel" 17 | - A function that calls the model called prediction_function that takes in the following parameters and returns a list of Prediction objects: 18 | - mutants: Mutants created by ReX to run inference on 19 | - target: The target class , default None 20 | - raw: Whether to return the raw output (e.g. the probability of the classification) or not, default False 21 | - binary_threshold: The threshold for binary classification e.g. 0.5 , default None 22 | - Model shape function that returns the shape of the model input 23 | - Any other helper functions that are needed for the above functions 24 | 25 | #### Preprocessing function 26 | 27 | The preprocessing function is responsible for loading the image and transforming it to the correct shape for the model. 
28 | 29 | ```python 30 | def preprocess(path, shape, device, mode) -> Data: 31 | ``` 32 | The **key steps** in the preprocess function are: 33 | - Load the data from the path 34 | - Transform the data to requirements of the model 35 | - Return a Data object 36 | 37 | 38 | The function should return a Data object. The Data object contains the following fields: 39 | - `input` -> The raw input 40 | - `data` -> The transformed input for the model 41 | - `model_shape` -> The shape of the model input 42 | - `device` -> The device the data is on e.g. "cuda" 43 | - `mode` -> The mode of the data e.g. "RGB", "L", "voxel" 44 | - `process` -> A boolean that indicates whether the data mode should be accessed or not 45 | - `model_height` -> The height of the model input 46 | - `model_width` -> The width of the model input 47 | - `model_height` -> The height of the model input 48 | - `model_channels` -> The number of channels in the model input 49 | - `transposed` -> Whether the data is transposed or not 50 | - `model_order` -> The order of the model input e.g. "first" or "last" 51 | - `background` -> The value of the background of the image e.g. 0 or 255 ... etc. For a range of values, use a tuple e.g. (0, 255) 52 | - `context` -> The context of the image e.g. the specific background like a beach or a road that can be used as an occlusion if specified as mask value 53 | 54 | The Data object can be initialised with the `input`, `model_shape`, `device` and optionally the `mode` and `process`.: 55 | Example: 56 | ```python 57 | data = Data(input, model_shape, device, mode="voxel", process=False) 58 | # Set the other attributes of the Data object separately like so 59 | data.model_height = 224 60 | ``` 61 | 62 | 63 | #### Prediction function 64 | 65 | The prediction function is responsible for running inference on the model, processing and returning the output. 
66 | 67 | ```python 68 | def prediction_function(mutants, target=None, raw=False, binary_threshold=None): 69 | ``` 70 | 71 | **Parameters**: 72 | - `mutants` -> A list of mutants to run inference on? 73 | - `target` -> The target class 74 | - `raw` -> Whether to return the raw output (e.g. the probability of the classification) or not 75 | - `binary_threshold` -> The threshold for binary classification e.g. 0.5 76 | 77 | **Returns**: 78 | - A list of Prediction objects or a float if raw is True 79 | 80 | The Prediction object contains the following fields: 81 | - `classification` -> The classification of the mutant: Optional[int] 82 | - `confidence` -> The confidence of the classification: Optional[float] 83 | - `bounding_box` -> The bounding box for the classification: Optional[NDArray] 84 | - `target` -> The target class: Optional[int] 85 | - `target_confidence` -> The confidence of the target class: Optional[float] 86 | 87 | #### Model shape function 88 | 89 | The model shape function is responsible for returning the shape of the model input. 90 | 91 | ```python 92 | def model_shape() -> []: 93 | ``` 94 | **Example:** 95 | ```python 96 | def model_shape(): 97 | return ["N", 3, 224, 224] 98 | ``` 99 | 100 | --- 101 | Example scripts can be found in the `tests/scripts` and `scripts` directory. 
-------------------------------------------------------------------------------- /example.rex.toml: -------------------------------------------------------------------------------- 1 | [rex] 2 | # masking value for mutations, can be either an integer, float or 3 | # one of the following built-in occlusions 'spectral', 'min', 'mean' 4 | # mask_value = 0 5 | 6 | # random seed, only set for reproducibility 7 | # seed = 42 8 | 9 | # whether to use gpu or not, defaults to true 10 | # gpu = true 11 | 12 | # batch size for the model 13 | # batch_size = 64 14 | 15 | [rex.onnx] 16 | # means for min-max normalization 17 | # means = [0.485, 0.456, 0.406] 18 | 19 | # stds = [0.229, 0.224, 0.225] 20 | 21 | # binary model confidence threshold. Anything >= threshold will be classified as 1, otherwise 0 22 | # binary_threshold = 0.5 23 | 24 | # norm = 255.0 25 | 26 | # intra_op_num_threads = 8 27 | 28 | # inter_op_num_threads = 8 29 | 30 | # ort_logger = 3 31 | 32 | [rex.visual] 33 | # whether to show progress bar in the terminal, defaults to true 34 | # progress_bar = false 35 | 36 | # resize the explanation to the size of the original image. 
This uses cubic interpolation and will not be as visually accurate as not resizing, defaults to false 37 | # resize = true 38 | 39 | # include classification and confidence information in title of plot, defaults to true 40 | # info = false 41 | 42 | # produce unvarnished image with actual masking value, defaults to false 43 | # raw = false 44 | 45 | # pretty printing colour for explanations, defaults to 200 46 | # colour = 100 47 | 48 | # matplotlib colourscheme for responsibility map plotting, defaults to 'magma' 49 | # heatmap_colours = 'coolwarm' 50 | 51 | # alpha blend for main image, defaults to 0.2 (PIL Image.blend parameter) 52 | # alpha = 0.2 53 | 54 | # overlay a 10*10 grid on an explanation, defaults to false 55 | # grid = false 56 | 57 | # mark quickshift segmentation on image 58 | # mark_segments = false 59 | 60 | # multi_style explanations, either or 61 | # multi_style = "composite" 62 | 63 | [causal] 64 | # maximum depth of tree, defaults to 10, note that search can actually go beyond this number on occasion, as the 65 | # check only occurs at the end of an iteration 66 | # tree_depth = 30 67 | 68 | # limit on number of combinations to consider, defaults to none. 69 | # It is **not** the total work done by ReX over all iterations. Leaving the search limit at none 70 | # can potentially be very expensive. 71 | # search_limit = 1000 72 | 73 | # number of times to run the algorithm, defaults to 20 74 | # iters = 30 75 | 76 | # minimum child size, in pixels 77 | # min_box_size = 10 78 | 79 | # remove passing mutants which have a confidence less than the given value.
Defaults to 0.0 (meaning all mutants are considered) 80 | # confidence_filter = 0.5 81 | 82 | # whether to weight responsibility by prediction confidence, defaults to false 83 | # weighted = false 84 | 85 | # queue_style = "intersection" | "area" | "all" | "dc", defaults to "area" 86 | # queue_style = "area" 87 | 88 | # maximum number of things to hold in search queue, either an integer or 'all' 89 | # queue_len = 1 90 | 91 | # concentrate: weight responsibility by tree depth of passing partition. Defaults to false 92 | # concentrate = true 93 | 94 | # subtract responsibility for non-target classifications. Defaults to false 95 | # negative_responsibility = true 96 | 97 | [causal.distribution] 98 | # distribution for splitting the box, defaults to uniform. Possible choices are 'uniform' | 'binom' | 'betabinom' | 'adaptive' 99 | # distribution = 'uniform' 100 | 101 | # blend = 0.5 102 | 103 | # supplemental arguments for distribution creation, these are ignored if the distribution does not take any parameters 104 | # dist_args = [1.1, 1.1] 105 | 106 | [explanation] 107 | # iterate through pixel ranking in chunks, defaults to causal.min_box_size 108 | # chunk_size = 10 109 | 110 | # causal explanations are minimal by definition, but very small explanations might have very low confidence. ReX will keep 111 | # looking for an explanation of confidence greater than or equal to the model confidence on the original input. 112 | # This is especially useful to reduce errors due to floating point imprecision when batching calls to the model. 113 | # Defaults to 0.0, with maximum value 1.0.
114 | # minimum_confidence_threshold = 0.1 115 | 116 | [explanation.spatial] 117 | # initial search radius 118 | # initial_radius = 25 119 | 120 | # increment to change radius 121 | # radius_eta = 0.2 122 | 123 | # number of times to expand before quitting, defaults to 4 124 | # no_expansions = 4 125 | 126 | [explanation.multi] 127 | # multi method (just spotlight so far) 128 | # method = 'spotlight' 129 | 130 | # no of spotlights to launch 131 | # spotlights = 10 132 | 133 | # default size of spotlight 134 | # spotlight_size = 24 135 | 136 | # decrease spotlight by this amount 137 | # spotlight_eta = 0.2 138 | 139 | # maximum number of random steps that a spotlight can make before quitting 140 | # max_spotlight_budget = 40 141 | 142 | # objective function for spotlight search. Possible options 'mean' | 'max' | "none" 143 | # objective_function = 'none' 144 | 145 | # permitted_overlap = 0.5 146 | 147 | [explanation.evaluation] 148 | 149 | # normalise insertion/deletion curves by confidence of original data, defaults to true 150 | # normalise_curves = true 151 | 152 | # insertion/deletion curve step size 153 | # insertion_step = 100 154 | 155 | 156 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "rex-xai" 3 | version = "0.3.1" 4 | description = "causal Responsibility-based eXplanations of black-box-classifiers" 5 | authors = [ 6 | { name = "David Kelly", email = "dkellino@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | 11 | dependencies = [ 12 | "numpy==1.26.4", 13 | "scipy>=1.10", 14 | "imutils>=0.5.4", 15 | "toml>=0.10", 16 | "anytree>=2.8.0", 17 | "fastcache>=1.1.0", 18 | "tqdm>=4.65.0", 19 | "sqlalchemy>=2.0.16", 20 | "matplotlib>=3.7.1", 21 | "onnxruntime>=1.18.0", 22 | "scikit-image>=0.21.0", 23 | "pandas>=2.2.0", 24 | "pillow>=10.3.0", 25 | "torch>=2.6.0" 26 | ] 27 | 28 | 
[project.optional-dependencies] 29 | 30 | 3D = [ 31 | "nibabel>=5.2.1", 32 | "kaleido==0.2.1", 33 | "plotly>=5.4.0", 34 | "dash>=2.1.0" 35 | ] 36 | 37 | [project.urls] 38 | homepage = "https://rex-xai.readthedocs.io//" 39 | repository = "https://github.com/ReX-XAI/ReX" 40 | documentation = "https://rex-xai.readthedocs.io/" 41 | "Bug Tracker" = "https://github.com/ReX-XAI/ReX/issues" 42 | 43 | [tool.poetry.scripts] 44 | ReX = "rex_xai.rex_wrapper:main" 45 | 46 | [build-system] 47 | requires = ["poetry-core"] 48 | build-backend = "poetry.core.masonry.api" 49 | 50 | [tool.poetry.group.dev] 51 | optional = true 52 | 53 | [tool.poetry.group.dev.dependencies] 54 | ruff = "^0.6.8" 55 | pytest = "^8.3.3" 56 | sphinx = "^8.0.2" 57 | myst-parser = "^4.0.0" 58 | sphinx-autoapi = "^3.3.2" 59 | pyright = "^1.1.383" 60 | pytest-cov = "^5.0.0" 61 | syrupy = "^4.7.2" 62 | torchvision = "^0.21.0" 63 | pytest-sugar = "^1.0.0" 64 | cached-path = "^1.6.3" 65 | sphinx-argparse = "^0.5.2" 66 | myst-nb = "^1.1.2" 67 | jupytext = "^1.16.7" 68 | plotly = "^5.4.0" 69 | dash = "^2.1.0" 70 | kaleido = "0.2.1" 71 | 72 | [tool.pyright] 73 | include = ["rex_xai"] 74 | exclude = ["scripts"] 75 | reportMissingTypeStubs = false 76 | -------------------------------------------------------------------------------- /rex_xai/__init__.py: -------------------------------------------------------------------------------- 1 | # pylintt: disable=invalid-name 2 | -------------------------------------------------------------------------------- /rex_xai/explanation/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from typing import Tuple 3 | import numpy as np 4 | import torch as tt 5 | from scipy.integrate import simpson 6 | from scipy.signal import periodogram 7 | from skimage.measure import shannon_entropy 8 | 9 | from rex_xai.explanation.explanation import Explanation 10 | from rex_xai.utils._utils import get_map_locations 11 | 
from rex_xai.mutants.mutant import _apply_to_data
from rex_xai.utils._utils import set_boolean_mask_value, xlogx


class Evaluation:
    """Computes evaluation metrics for a finished :class:`Explanation`:
    mask-size ratio, spectral entropy of the responsibility map, Shannon
    entropy loss, and insertion/deletion curves."""

    # TODO does this need to be an object? Probably not...
    # TODO consider inheritance from Explanation object
    def __init__(self, explanation: Explanation) -> None:
        # Explanation under evaluation; every metric below reads from it.
        self.explanation = explanation

    def ratio(self) -> float:
        """Returns percentage of data required for sufficient explanation"""
        if (
            self.explanation.explanation is None
            or self.explanation.data.model_channels is None
        ):
            raise ValueError("Invalid Explanation object")

        # NOTE(review): .item() only succeeds on a one-element tensor; the
        # broad except deliberately falls back to the raw mask otherwise.
        try:
            final_mask = self.explanation.final_mask.squeeze().item()  # type: ignore
        except Exception:
            final_mask = self.explanation.final_mask

        # torch path first; TypeError falls back to numpy for array masks.
        try:
            return (
                tt.count_nonzero(final_mask)  # type: ignore
                / final_mask.size  # type: ignore
            ).item()
        except TypeError:
            return (
                np.count_nonzero(final_mask)  # type: ignore
                / final_mask.size  # type: ignore
            )

    def spectral_entropy(self) -> Tuple[float, float]:
        """
        This code is a simplified version of
        https://github.com/raphaelvallat/antropy/blob/master/src/antropy/entropy.py

        Returns (entropy of the normalised power spectral density of the
        responsibility map, maximum possible entropy log2(#bins)).
        """
        _, psd = periodogram(self.explanation.target_map)
        # normalise the power spectral density into a probability distribution
        psd_norm = psd / psd.sum()
        ent = -np.sum(xlogx(psd_norm))
        # reference maximum: log2 of the number of frequency bins
        if len(psd_norm.shape) == 2:
            max_ent = np.log2(len(psd_norm[0]))
        else:
            max_ent = np.log2(len(psd_norm))
        return ent, max_ent

    def entropy_loss(self):
        """Return (shannon entropy of the raw input, shannon entropy of the
        explanation tensor)."""
        img = np.array(self.explanation.data.input)
        assert self.explanation.explanation is not None
        exp = shannon_entropy(self.explanation.explanation.detach().cpu().numpy())

        return shannon_entropy(img), exp

    def insertion_deletion_curve(self, prediction_func, normalise=False):
        """Compute insertion and deletion AUCs.

        Pixels/locations are ranked by responsibility (get_map_locations) and
        inserted into an empty mask / deleted from a full mask in chunks of
        `args.insertion_step`; model confidence for the target class is
        recorded after each chunk and integrated with Simpson's rule.

        :param prediction_func: batched model callable accepting raw=True
        :param normalise: divide both AUCs by target confidence * iters * step
        :returns: (insertion_auc, deletion_auc)
        """
        insertion_curve = []
        deletion_curve = []

        assert self.explanation.data.target is not None
        assert self.explanation.data.target.confidence is not None

        assert self.explanation.data.data is not None
        # boolean masks without the batch dim; insertion starts empty,
        # deletion starts full
        insertion_mask = tt.zeros(
            self.explanation.data.data.squeeze(0).shape, dtype=tt.bool
        ).to(self.explanation.data.device)
        deletion_mask = tt.ones(
            self.explanation.data.data.squeeze(0).shape, dtype=tt.bool
        ).to(self.explanation.data.device)
        im = []
        dm = []

        step = self.explanation.args.insertion_step
        ranking = get_map_locations(map=self.explanation.target_map)
        iters = len(ranking) // step

        for i in range(0, len(ranking), step):
            # flip the next `step` highest-responsibility locations
            chunk = ranking[i : i + step]
            for _, loc in chunk:
                set_boolean_mask_value(
                    insertion_mask,
                    self.explanation.data.mode,
                    self.explanation.data.model_order,
                    loc,
                )
                set_boolean_mask_value(
                    deletion_mask,
                    self.explanation.data.mode,
                    self.explanation.data.model_order,
                    loc,
                    val=False,
                )
            # snapshot the masked data after this chunk (0 = mask fill value)
            im.append(
                _apply_to_data(insertion_mask, self.explanation.data, 0).squeeze(0)
            )
            dm.append(
                _apply_to_data(deletion_mask, self.explanation.data, 0).squeeze(0)
            )

            # flush a full batch through the model
            if len(im) == self.explanation.args.batch_size:
                self.__batch(im, dm, prediction_func, insertion_curve, deletion_curve)
                im = []
                dm = []

        # flush any trailing partial batch
        if im != [] and dm != []:
            self.__batch(im, dm, prediction_func, insertion_curve, deletion_curve)

        i_auc = simpson(insertion_curve, dx=step)
        d_auc = simpson(deletion_curve, dx=step)

        if normalise:
            const = self.explanation.data.target.confidence * iters * step
            i_auc /= const
            d_auc /= const

        return i_auc, d_auc

    # # def sensitivity(self):
    # #     pass

    # # def infidelity(self):
    # #     pass

    def __batch(self, im, dm, prediction_func, insertion_curve, deletion_curve):
        """Run one batch of insertion (im) and deletion (dm) snapshots through
        the model, appending the target-class confidence of each item to the
        respective curve (lists are mutated in place)."""
        assert self.explanation.data.target is not None
        ip = prediction_func(tt.stack(im).to(self.explanation.data.device), raw=True)
        dp = prediction_func(tt.stack(dm).to(self.explanation.data.device), raw=True)
        for p in range(0, ip.shape[0]):
            insertion_curve.append(
                ip[p, self.explanation.data.target.classification].item()
            )  # type: ignore
            deletion_curve.append(
                dp[p, self.explanation.data.target.classification].item()
            )  # type: ignore
def save(self, path, mask=None, multi=None, multi_style=None, clauses=None):
    """Save the found explanations to disk.

    multi_style selects the layout: 'contrastive' delegates to the parent
    save with the final (contrastive) mask; 'separate' writes one file per
    explanation (path gets an _<i> suffix); 'composite' overlays a clause of
    explanations into one image.

    NOTE(review): the `mask` and `multi` parameters are never read — `mask`
    is immediately shadowed by the loop variable in the 'separate' branch;
    confirm whether they are kept for signature compatibility with the
    parent class.
    """
    if multi_style is None:
        multi_style = self.args.multi_style
    if multi_style == "contrastive":
        super().save(path, mask=self.final_mask)
    # NOTE(review): this `if` is not chained to the 'contrastive' branch
    # above; with the current values only one branch can match, but an
    # `elif` would make the mutual exclusion explicit.
    if multi_style == "separate":
        logger.info("saving explanations in multiple different files")
        for i, mask in enumerate(self.explanations):
            name, ext = os.path.splitext(path)
            exp_path = f"{name}_{i}{ext}"
            super().save(exp_path, mask=mask)
    elif multi_style == "composite":
        logger.info("using composite style to save explanations")
        if clauses is None:
            # no clause given: compose every explanation found
            clause = range(0, len(self.explanations))
            save_multi_explanation(
                self.explanations, self.data, self.args, clause=clause, path=path
            )
        else:
            name, ext = os.path.splitext(path)
            new_name = f"{name}_{clauses}{ext}"
            save_multi_explanation(
                self.explanations,
                self.data,
                self.args,
                clause=clauses,
                path=new_name,
            )

def show(self, path=None, multi_style=None, clauses=None):
    """Render explanations to in-memory images; plot a grid when more than
    one image results, otherwise return the single image.

    NOTE(review): every explanation is rendered once unconditionally and
    then rendered AGAIN in the 'separate' branch, so 'separate' produces
    each image twice and 'composite' mixes individual images with the
    composite — confirm whether the first loop should be removed.
    NOTE(review): returns None when a grid is plotted (len(outs) > 1).
    """
    if multi_style is None:
        multi_style = self.args.multi_style
    outs = []

    for mask in self.explanations:
        out = save_image(mask, self.data, self.args, path=None)
        outs.append(out)

    if multi_style == "separate":
        for mask in self.explanations:
            out = save_image(mask, self.data, self.args, path=None)
            outs.append(out)

    elif multi_style == "composite":
        if clauses is None:
            # compose all explanations into a single image
            clause = tuple([i for i in range(len(self.explanations))])
            out = save_multi_explanation(
                self.explanations, self.data, self.args, clause=clause, path=None
            )
            outs.append(out)
        else:
            for clause in clauses:
                out = save_multi_explanation(
                    self.explanations,
                    self.data,
                    self.args,
                    clause=clause,
                    path=None,
                )
                outs.append(out)

    if len(outs) > 1:
        plot_image_grid(outs)
    else:
        return outs[0]

def extract(self, method=None):
    """Populate self.explanations / self.explanation_confidences: first the
    global-max explanation, then args.spotlights - 1 spotlight searches.
    self.blank() resets state between searches.

    NOTE(review): the `method` parameter is unused here.
    """
    self.blank()
    # we start with the global max explanation
    logger.info("spotlight number 1 (global max)")
    conf = self._Explanation__global()  # type: ignore
    if self.final_mask is not None:
        self.explanations.append(self.final_mask)
        self.explanation_confidences.append(conf)
    self.blank()

    for i in range(0, self.args.spotlights - 1):
        logger.info("spotlight number %d", i + 2)
        conf = self.spotlight_search()
        if self.final_mask is not None:
            self.explanations.append(self.final_mask)
            self.explanation_confidences.append(conf)
        self.blank()
    logger.info(
        "ReX has found a total of %d explanations via spotlight search",
        len(self.explanations),
    )

def __dice(self, d1, d2):
    """calculates dice coefficient between two numpy arrays of the same dimensions

    NOTE(review): despite the docstring, the implementation uses torch ops
    (tt.logical_and / .item()), so d1 and d2 are expected to be boolean
    torch tensors — confirm and update the docstring accordingly.
    """
    d_sum = d1.sum() + d2.sum()
    if d_sum == 0:
        return 0
    intersection = tt.logical_and(d1, d2)
    return np.abs((2.0 * intersection.sum() / d_sum).item())
def contrastive(self, clauses):
    """Search the given clauses for a subset of explanations that is both
    sufficient (masking everything else preserves the classification) and
    necessary (occluding it changes the classification).

    On success stores the combined mask in self.final_mask and returns the
    winning subset; when no counterfactual exists, logs a warning and
    terminates the process.
    """
    for clause in clauses:
        for subset in powerset(clause, reverse=False):
            # union of the selected explanation masks
            mask = sum([self.explanations[x] for x in subset])
            mask = mask.to(tt.bool)  # type: ignore
            # keep only the masked region / occlude only the masked region
            sufficient = tt.where(mask, self.data.data, self.data.mask_value)  # type: ignore
            counterfactual = tt.where(mask, self.data.mask_value, self.data.data)  # type: ignore
            ps = self.prediction_func(sufficient)[0]
            pn = self.prediction_func(counterfactual)[0]

            if (
                ps.classification == self.data.target.classification  # type: ignore
                and pn.classification != self.data.target.classification  # type: ignore
            ):
                logger.info(
                    "found sufficient and necessary explanation of class %d, %d with confidence %f",
                    ps.classification,
                    pn.classification,
                    pn.confidence,
                )
                self.final_mask = mask
                return subset
    logger.warning(
        "ReX is unable to find a counterfactual, so not producing an output. Exiting here..."
    )
    # FIX: was `exit()` — the site-module interactive helper, which is not
    # guaranteed to exist (e.g. under `python -S` or embedded interpreters).
    # Raising SystemExit is the portable equivalent (exit() merely raises it).
    raise SystemExit


def __random_step_from(self, origin, height, width, step=5):
    """Take one random diagonal step of size *step* away from *origin*.

    FIX: the parameters were named (width, height) but the only call site
    passes (model_height, model_width) positionally, and *origin* is a
    (row, col) pair from __random_location — the values lined up but every
    name was swapped, inviting a genuine bug on the next edit. Renamed so
    the first coordinate/limit is the row/height axis; runtime behaviour
    (including RNG consumption order) is unchanged.
    """
    r, c = origin
    # flip a coin to move the row coordinate back (0) or forward (1)
    r_dir = np.random.randint(0, 2)
    r = r - step if r_dir == 0 else r + step
    if r < 0:
        r = 0
    if r > height:
        r = height

    # flip a coin to move the column coordinate back (0) or forward (1)
    c_dir = np.random.randint(0, 2)
    c = c - step if c_dir == 0 else c + step
    if c < 0:
        c = 0
    if c > width:
        c = width
    logger.debug(f"trying new location: moving from {origin} to {(r, c)}")
    return (r, c)


def __random_location(self):
    """Sample a uniformly random (row, col) position over the model grid."""
    assert self.data.model_width is not None
    assert self.data.model_height is not None
    origin = random_coords(
        Distribution.Uniform,
        self.data.model_width * self.data.model_height,
    )

    # unravel over (height, width): result is (row, col)
    return np.unravel_index(origin, (self.data.model_height, self.data.model_width))  # type: ignore


def spotlight_search(self, origin=None):
    """Run a single spotlight search, returning the confidence of the found
    explanation (or the last attempt's confidence when the budget runs out).

    With objective function 'none' each failed probe restarts at a fresh
    random location; otherwise the spotlight takes random steps, accepting
    moves while the local responsibility does not decrease.
    """
    if origin is None:
        centre = self.__random_location()
    else:
        centre = origin

    ret, resp, conf = self._Explanation__spatial(  # type: ignore
        centre=centre, expansion_limit=self.args.no_expansions
    )

    steps = 0
    while ret == SpatialSearch.NotFound and steps < self.args.max_spotlight_budget:
        if self.args.spotlight_objective_function == "none":
            centre = self.__random_location()
            ret, resp, conf = self._Explanation__spatial(  # type: ignore
                centre=centre, expansion_limit=self.args.no_expansions
            )
        else:
            new_resp = 0.0
            while new_resp < resp:
                # centre is (row, col), so pass (height, width) limits
                centre = self.__random_step_from(
                    centre,
                    self.data.model_height,
                    self.data.model_width,
                    step=self.args.spotlight_step,
                )
                ret, new_resp, conf = self._Explanation__spatial(  # type: ignore
                    centre=centre, expansion_limit=self.args.no_expansions
                )
                if ret == SpatialSearch.Found:
                    return conf
            ret, resp, conf = self._Explanation__spatial(  # type: ignore
                centre=centre, expansion_limit=self.args.no_expansions
            )
        steps += 1
    return conf


#!/usr/bin/env python
from typing import Optional
import numpy as np
import torch as tt

from enum import Enum

from rex_xai.mutants.occlusions import spectral_occlusion, context_occlusion
from rex_xai.responsibility.prediction import Prediction
from rex_xai.utils.logger import logger
from rex_xai.utils._utils import ReXDataError

# Inference setups distinguished when dispatching model calls.
Setup = Enum("Setup", ["ONNXMPS", "ONNX", "PYTORCH"])


def _guess_mode(input):
    """Infer the data modality when the caller does not specify one: PIL
    images report their own `.mode`; 4-d arrays are treated as 'voxel',
    anything else with a shape as 'spectral'.

    NOTE(review): returns None implicitly when the input has neither
    `.mode` nor `.shape` — confirm callers handle that.
    """
    if hasattr(input, "mode"):
        return input.mode
    if hasattr(input, "shape"):
        if len(input.shape) == 4:
            return "voxel"
        else:
            return "spectral"
def set_height(self, h: int):
    """Setter for the model input height."""
    self.model_height = h

def set_width(self, w: int):
    """Setter for the model input width."""
    self.model_width = w

def set_channels(self, c=None):
    """Setter for the model channel count (None clears it)."""
    self.model_channels = c

def __repr__(self) -> str:
    # summary of shape/layout metadata, plus the target prediction when set
    data_info = f"Data: {self.mode}, {self.model_shape}, {self.model_height}, {self.model_width}, {self.model_channels}, {self.model_order}"
    if self.target is not None:
        target_info = repr(self.target)
        data_info = data_info + "\n\t Target:" + target_info
    return data_info

def set_classification(self, cl):
    """Record the classification assigned to this data."""
    self.classification = cl

def match_data_to_model_shape(self):
    """
    a PIL image has the form H * W * C, so
    if the model takes C * H * W we need to transpose self.data to
    get it into the correct form for the model to consume
    This function does *not* add in the batch channel at the beginning

    NOTE(review): try_unsqueeze (called last) does prepend a batch dim via
    unsqueeze(0)/expand_dims(0), which appears to contradict the sentence
    above — confirm which statement is current.
    """
    assert self.data is not None
    if self.mode == "RGB" and self.model_order == "first":
        self.data = self.data.transpose(2, 0, 1)  # type: ignore
        self.transposed = True
    if self.mode in ("tabular", "spectral"):
        self.data = self.generic_tab_preprocess()
    if self.mode == "voxel":
        pass
    self.data = self.try_unsqueeze()

def generic_tab_preprocess(self):
    """Convert tabular/spectral input to a float32 tensor on self.device and
    pad leading singleton dims until its rank is one short of model_shape."""
    if isinstance(self.input, np.ndarray):
        self.data = self.input.astype("float32")
        arr = tt.from_numpy(self.data).to(self.device)
    else:
        arr = self.input
    for _ in range(len(self.model_shape) - len(arr.shape)):
        arr = arr.unsqueeze(0)
    return arr

def load_data(self, astype="float32"):
    """Resize the PIL input to the model grid, convert to a tensor on
    self.device, and shape it for the model.

    NOTE(review): PIL's Image.resize takes (width, height) but this passes
    (model_height, model_width) — harmless for square inputs, swapped for
    rectangular ones; confirm against actual model shapes.
    """
    img = self.input.resize((self.model_height, self.model_width))
    img = np.array(img).astype(astype)
    self.data = img
    self.match_data_to_model_shape()
    self.data = tt.from_numpy(self.data).to(self.device)

def _normalise_rgb_data(self, means, stds, norm):
    """Normalise RGB data in place: optional division by *norm*, then
    per-channel mean subtraction and std division, honouring channel order.

    :raises ReXDataError: when the model does not expect 3 channels
    """
    assert self.data is not None
    if self.model_channels != 3:
        raise ReXDataError(
            f"expected RGB data, but got data with the shape {self.model_shape}"
        )

    normed_data = self.data
    if norm is not None:
        normed_data /= norm

    # channels-first layout: normalise along dim 1
    if self.model_order == "first":
        if means is not None:
            for i, m in enumerate(means):
                normed_data[:, i, :, :] = normed_data[:, i, :, :] - m
        if stds is not None:
            for i, s in enumerate(stds):
                normed_data[:, i, :, :] = normed_data[:, i, :, :] / s

    # channels-last layout: normalise along the trailing dim
    if self.model_order == "last":
        if means is not None:
            for i, m in enumerate(means):
                normed_data[:, :, i] = normed_data[:, :, i] - m
        if stds is not None:
            for i, s in enumerate(stds):
                normed_data[:, :, i] = normed_data[:, :, i] / s

    return normed_data

def try_unsqueeze(self):
    """Pad self.data with singleton dims (channel side chosen by
    model_order) until one short of model_shape, then prepend a batch dim.
    Works for both torch tensors and numpy arrays; returns the result
    without assigning it."""
    out = self.data
    if self.model_order == "first":
        dim = 0
    else:
        dim = -1
    if isinstance(self.data, tt.Tensor):
        for _ in range(len(self.model_shape) - len(self.data.shape) - 1):
            out = tt.unsqueeze(out, dim=dim)  # type: ignore
        out = tt.unsqueeze(out, dim=0)  # type: ignore
    else:
        for _ in range(len(self.model_shape) - len(self.data.shape) - 1):  # type: ignore
            out = np.expand_dims(out, axis=dim)  # type: ignore
        out = np.expand_dims(out, axis=0)  # type: ignore
    return out

def generic_image_preprocess(
    self,
    means=None,
    stds=None,
    astype="float32",
    norm: Optional[float] = 255.0,
):
    """Load the image and, for RGB data, apply mean/std normalisation.

    NOTE(review): the bare `self.try_unsqueeze()` below discards its return
    value, so it is a no-op — load_data() has already unsqueezed via
    match_data_to_model_shape, and assigning here would ADD a second batch
    dim, so do not "fix" this by assigning without checking callers.
    NOTE(review): the "L" (greyscale) branch calls _normalise_rgb_data,
    which raises unless model_channels == 3 — confirm intended.
    """
    self.load_data(astype=astype)

    if self.mode == "RGB" and self.data is not None:
        self.data = self._normalise_rgb_data(means, stds, norm)
        self.try_unsqueeze()
    if self.mode == "L":
        self.data = self._normalise_rgb_data(means, stds, norm)

def __get_shape(self):
    """returns height, width, channels, order, depth for the model"""
    if self.mode == "spectral":
        # an array of the form (h, w), so no channel info or order or depth
        if len(self.model_shape) == 2:
            return self.model_shape[0], self.model_shape[1], 1, None, None
        # an array of the form (batch, h, w), so no channel info or order or depth
        if len(self.model_shape) == 3:
            return self.model_shape[1], self.model_shape[2], 1, None, None
    if self.mode == "RGB":
        if len(self.model_shape) == 4:
            _, a, b, c = self.model_shape
            # a small leading dim (1/3/4) is taken to be the channel axis
            if a in (1, 3, 4):
                return b, c, a, "first", None
            else:
                return a, b, c, "last", None
    if self.mode == "voxel":
        if len(self.model_shape) == 4:
            _, w, h, d = self.model_shape  # If batch is present
            return w, h, None, None, d
        else:
            w, h, d = self.model_shape
            return w, h, None, None, d

    raise ReXDataError(
        f"Incompatible 'mode' {self.mode} and 'model_shape' ({self.model_shape}), cannot get valid shape of Data object so exiting here"
    )
#!/usr/bin/env python

"""onnx model management"""

from typing import Optional, Union, List
import sys
import os
import torch as tt
import platform
from scipy.special import softmax
import numpy as np

import onnxruntime as ort
from onnxruntime import InferenceSession
from rex_xai.responsibility.prediction import Prediction, from_pytorch_tensor
from rex_xai.input.input_data import Setup

from rex_xai.utils.logger import logger


class OnnxRunner:
    """Wraps an onnxruntime InferenceSession and exposes a prediction
    callable via gen_prediction_function, routing inference through the CPU
    or through device-resident IO bindings depending on setup/device."""

    def __init__(self, session: InferenceSession, setup: Setup, device) -> None:
        self.session = session
        # cache the single input/output signature of the model
        self.input_shape = session.get_inputs()[0].shape
        self.output_shape = session.get_outputs()[0].shape
        self.input_name = session.get_inputs()[0].name
        self.output_name = session.get_outputs()[0].name
        self.setup: Setup = setup
        self.device: str = device

    def run_on_cpu(
        self,
        tensors: Union[tt.Tensor, List[tt.Tensor]],
        target: Optional[Prediction],
        raw: bool,
        binary_threshold: Optional[float] = None,
    ):
        """Convert a pytorch tensor, or list of tensors, to numpy arrays on the cpu for onnx inference.

        Returns raw softmaxed confidences when raw=True, otherwise a list of
        Prediction objects (one per batch row). Fatal errors terminate the
        process.

        NOTE(review): when raw=True the function returns inside the batch
        loop, i.e. only the first row's confidences — confirm whether raw
        mode is ever used with batch > 1.
        NOTE(review): the tensor_size == 1 branch calls .detach() on
        `tensors`, which fails if `tensors` is a *list* whose first element
        has batch dim 1 — confirm the callers' shapes.
        """
        # check if it's a single tensor or a list of tensors
        if isinstance(tensors, list):
            tensor_size = tensors[0].shape[0]
        else:
            tensor_size = tensors.shape[0]

        if tensor_size == 1:
            tensors = tensors.detach().cpu().numpy()  # type: ignore
        else:
            tensors = np.stack([t.detach().cpu().numpy() for t in tensors])  # type: ignore

        preds = []

        try:
            prediction = self.session.run(None, {self.input_name: tensors})[0]
            for i in range(0, prediction.shape[0]):
                confidences = softmax(prediction[i])
                if raw:
                    # pad confidences back up to the model's output rank
                    # (NOTE(review): this inner loop reuses/shadows `i`)
                    for i in range(len(self.output_shape) - len(confidences.shape)):
                        confidences = np.expand_dims(confidences, axis=0)
                    return confidences
                if binary_threshold is not None:
                    # binary classifier: threshold the first logit's confidence
                    if confidences[0] >= binary_threshold:
                        classification = 1
                    else:
                        classification = 0
                    tc = confidences[0]
                else:
                    classification = np.argmax(confidences)
                    if target is not None:
                        tc = confidences[target.classification]
                    else:
                        tc = None
                preds.append(
                    Prediction(
                        classification,
                        confidences[classification],
                        None,
                        target=target,
                        target_confidence=tc,
                    )
                )

            return preds
        except Exception as e:
            logger.fatal(e)
            sys.exit(-1)

    def run_with_data_on_device(
        self,
        tensors,
        device,
        tsize,
        binary_threshold,
        raw=False,
        device_id=0,
        target=None,
    ):
        """Run inference with zero-copy IO binding: the input stays on
        *device* and the output is written into a preallocated torch tensor.

        NOTE(review): `tsize` and `binary_threshold` are accepted but never
        used here. For a *list* of tensors only tensors[0].data_ptr() is
        bound while the shape claims the whole batch — this presumes the
        list elements are laid out contiguously in one allocation; confirm.
        """
        # input_shape = self.session.get_inputs()[0].shape # Gets the shape of the input (e.g [batch_size, 3, 224, 224])
        batch_size = len(tensors) if isinstance(tensors, list) else tensors.shape[0]

        if isinstance(tensors, list):
            tensors = [m.contiguous() for m in tensors]
            shape = tuple(
                [batch_size] + list(self.input_shape)[1:]
            )  # batch_size + remaining input shape
            ptr = tensors[0].data_ptr()

        else:
            tensors = tensors.contiguous()
            shape = tuple(tensors.shape)
            ptr = tensors.data_ptr()

        binding = self.session.io_binding()
        binding.bind_input(
            name=self.input_name,
            device_type=device,
            device_id=device_id,
            element_type=np.float32,
            shape=shape,
            buffer_ptr=ptr,
        )

        output_shape = [batch_size] + list(self.output_shape[1:])

        # preallocate the output buffer on-device and bind it
        z_tensor = tt.empty(output_shape, dtype=tt.float32, device=device).contiguous()

        binding.bind_output(
            name=self.output_name,
            device_type=device,
            device_id=device_id,
            element_type=np.float32,
            shape=tuple(z_tensor.shape),
            buffer_ptr=z_tensor.data_ptr(),
        )

        self.session.run_with_iobinding(binding)
        if raw:
            return z_tensor
        return from_pytorch_tensor(z_tensor, target=target)

    def gen_prediction_function(self):
        """Return (prediction_callable, input_shape) appropriate for this
        runner's device.

        NOTE(review): returns None implicitly for any device other than
        'cpu'/'cuda' (and non-ONNXMPS setups) — confirm callers never hit
        that path.
        """
        # mps cannot use io binding (see get_prediction_function), so both
        # cpu and onnx-on-mps go through the cpu path
        if self.device == "cpu" or self.setup == Setup.ONNXMPS:
            return (
                lambda tensor,
                target=None,
                raw=False,
                binary_threshold=None: self.run_on_cpu(
                    tensor, target, raw, binary_threshold
                ),
                self.input_shape,
            )
        if self.device == "cuda":
            return (
                lambda tensor,
                target=None,
                device=self.device,
                raw=False,
                binary_threshold=None: self.run_with_data_on_device(
                    tensor,
                    device,
                    len(tensor),
                    binary_threshold,
                    raw=raw,
                    target=target,
                ),
                self.input_shape,
            )


def get_prediction_function(args):
    """Build an onnxruntime session from args (model path, gpu flag, thread
    counts, logger level) and return its prediction function + input shape.

    NOTE(review): the `_, ext = os.path.splitext(...)` result in the Darwin
    branch is unused; the cuda and cpu branches both set Setup.PYTORCH even
    though this is the onnx path — confirm the Setup semantics.
    """
    # def get_prediction_function(model_path, gpu: bool, logger_level=3):
    sess_options = ort.SessionOptions()

    ort.set_default_logger_severity(args.ort_logger)

    # are we (trying to) run on the gpu?
    if args.gpu:
        logger.info("using gpu for onnx inference session")
        if platform.uname().system == "Darwin":
            # note this is only true for M+ chips
            providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
            sess_options.intra_op_num_threads = args.intra_op_num_threads
            sess_options.inter_op_num_threads = args.inter_op_num_threads
            device = "mps"
            _, ext = os.path.splitext(os.path.basename(args.model))
            # for the moment, onnx does not seem to support data copying on mps, so we fall back to
            # copying data to the cpu for inference
            setup = Setup.ONNXMPS
        else:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
            device = "cuda"
            setup = Setup.PYTORCH

        # set up sesson with gpu providers
        sess = ort.InferenceSession(
            args.model, sess_options=sess_options, providers=providers
        )  # type: ignore
    # are we running on cpu?
    else:
        logger.info("using cpu for onnx inference session")
        providers = ["CPUExecutionProvider"]
        sess = ort.InferenceSession(
            args.model, sess_options=sess_options, providers=providers
        )
        device = "cpu"
        setup = Setup.PYTORCH

    onnx_session = OnnxRunner(sess, setup, device)

    return onnx_session.gen_prediction_function()


"""distributions module"""

from typing import Optional, Tuple
from enum import Enum
import numpy as np
from scipy.stats import binom, betabinom
from rex_xai.utils.logger import logger

# Sampling distributions available for mutant box splitting.
Distribution = Enum("Distribution", ["Binomial", "Uniform", "BetaBinomial", "Adaptive"])
def _blend(dist, alpha, base):
    """Mix *base*'s pmf with *dist* as (1 - alpha) * pmf + alpha * dist,
    renormalised to sum to 1."""
    base_pmf = np.array([base.pmf(x) for x in range(0, len(dist))])
    mixture = (1.0 - alpha) * base_pmf + alpha * dist
    return mixture / np.sum(mixture)


def _2d_adaptive(map, args: Tuple[int, int, int, int], alpha=0.0, base=None) -> int:
    """Draw a flat index from *map* restricted to the window
    rows args[0]:args[1], cols args[2]:args[3], weighted by responsibility.

    Falls back to a uniform draw over the window when the map is missing or
    has no positive mass. (*alpha* and *base* are currently unused.)
    """
    # no usable responsibility information -> uniform over the window
    if map is None or not np.max(map) > 0.0:
        return np.random.randint(1, (args[1] - args[0]) * (args[3] - args[2]))

    window = map[args[0] : args[1], args[2] : args[3]]
    # flatten() copies, so normalising in place cannot corrupt `map`
    weights = window.flatten()
    weights /= np.sum(weights)
    return np.random.choice(np.arange(0, len(weights)), p=weights)


def str2distribution(d: str) -> Distribution:
    """converts string into Distribution enum"""
    lookup = {
        "binom": Distribution.Binomial,
        "uniform": Distribution.Uniform,
        "betabinom": Distribution.BetaBinomial,
        "adaptive": Distribution.Adaptive,
    }
    try:
        return lookup[d]
    except KeyError:
        # unknown name: warn and fall back to the default distribution
        logger.warning(
            "Invalid distribution '%s', reverting to default value Distribution.Uniform",
            d,
        )
        return Distribution.Uniform
Distribution.Uniform: 78 | return np.random.randint(1, args[0]) # type: ignore 79 | 80 | if d == Distribution.Binomial: 81 | start, stop, *dist_args = args[0] 82 | return binom(stop - start - 1, dist_args).rvs() + start 83 | 84 | if d == Distribution.BetaBinomial: 85 | return _betabinom2d(args[1], args[2], args[3][0], args[3][1]) 86 | 87 | return None 88 | except ValueError: 89 | return None 90 | -------------------------------------------------------------------------------- /rex_xai/mutants/mutant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from typing import List, Optional 3 | import numpy as np 4 | import torch as tt 5 | from PIL import Image # type: ignore 6 | 7 | try: 8 | from anytree.cachedsearch import find 9 | except ImportError: 10 | from anytree.search import find 11 | 12 | import matplotlib.pyplot as plt 13 | from rex_xai.mutants.box import Box 14 | from rex_xai.responsibility.prediction import Prediction 15 | from rex_xai.utils.logger import logger 16 | from rex_xai.input.input_data import Data 17 | from rex_xai.utils._utils import add_boundaries, set_boolean_mask_value 18 | 19 | __combinations = [ 20 | [ 21 | 0, 22 | ], 23 | [ 24 | 1, 25 | ], 26 | [ 27 | 2, 28 | ], 29 | [ 30 | 3, 31 | ], 32 | [0, 1], 33 | [0, 2], 34 | [0, 3], 35 | [1, 2], 36 | [1, 3], 37 | [2, 3], 38 | [0, 1, 2], 39 | [0, 1, 3], 40 | [0, 2, 3], 41 | [1, 2, 3], 42 | ] 43 | 44 | 45 | def _apply_to_data(mask, data: Data, masking_func): 46 | if isinstance(masking_func, (float, int)): 47 | res = tt.where(mask, data.data, masking_func) # type: ignore 48 | return res 49 | if callable(masking_func): 50 | return masking_func(mask, data.data) 51 | 52 | logger.warning("applying default masking value of 0") 53 | return tt.where(mask, data.data, 0) # type: ignore 54 | 55 | 56 | def get_combinations(): 57 | return __combinations 58 | 59 | 60 | class Mutant: 61 | def __init__(self, data: Data, static, active, masking_func) -> 
None: 62 | self.shape = tuple( 63 | data.model_shape[1:] 64 | ) # the first element of shape is the batch information, so we drop that 65 | self.mode = data.mode 66 | self.channels: int = ( 67 | data.model_channels if data.model_channels is not None else 1 68 | ) 69 | self.order = data.model_order 70 | self.mask = tt.zeros(self.shape, dtype=tt.bool, device=data.device) 71 | self.static = static 72 | self.active = active 73 | self.prediction: Optional[Prediction] = None 74 | self.passing = False 75 | self.masking_func = masking_func 76 | self.depth = 0 77 | 78 | def __repr__(self) -> str: 79 | return f"ACTIVE: {self.active}, PREDICTION: {self.prediction}, PASSING: {self.passing}" 80 | 81 | def get_name(self): 82 | return self.active 83 | 84 | def update_status(self, target): 85 | if self.prediction is not None: 86 | if target.classification == self.prediction.classification: 87 | self.passing = True 88 | 89 | def get_length(self): 90 | return len(self.active.split("_")) 91 | 92 | def get_active_boxes(self): 93 | return self.active.split("_") 94 | 95 | def area(self) -> int: 96 | """Return the total area *not* concealed by the mutant.""" 97 | tensor = tt.count_nonzero(self.mask) 98 | if tensor.numel() == 0 or tensor is None: 99 | return 0 100 | else: 101 | return int(tensor.item()) // self.channels 102 | 103 | def set_static_mask_regions(self, names, search_tree): 104 | for name in names: 105 | box = find(search_tree, lambda node: node.name == name) 106 | if box is not None: 107 | self.depth = max(self.depth, box.depth) 108 | self.set_mask_region_to_true(box) 109 | 110 | def set_active_mask_regions(self, boxes: List[Box]): 111 | for box in boxes: 112 | self.depth = max(self.depth, box.depth) 113 | self.set_mask_region_to_true(box) 114 | 115 | def set_mask_region_to_true(self, box: Box): 116 | set_boolean_mask_value(self.mask, self.mode, self.order, box) 117 | 118 | def apply_to_data(self, data: Data): 119 | return _apply_to_data(self.mask, data, self.masking_func) 
120 | 121 | def save_mutant(self, data: Data, name=None, segs=None): 122 | if data.mode == "RGB": 123 | m = np.array(data.input.resize((data.model_height, data.model_width))) 124 | mask = self.mask.squeeze().detach().cpu().numpy() 125 | if data.transposed: 126 | # if transposed, we have C * H * W, so change that to H * W * C 127 | m = np.where(mask, m.transpose((2, 0, 1)), 0) 128 | m = m.transpose((1, 2, 0)) 129 | else: 130 | # TODO m = m.transpose((0, 2, 1)) 131 | m = np.where(mask, m, 255) 132 | # draw on the segment_mask, if available 133 | if segs is not None: 134 | m = add_boundaries(m, segs) 135 | img = Image.fromarray(m, data.mode) 136 | if name is not None: 137 | img.save(name) 138 | else: 139 | img.save(f"{self.get_name()}.png") 140 | # spectral or time series data 141 | if data.mode == "spectral": 142 | m = self.apply_to_data(data) 143 | fig = plt.figure() 144 | ax = fig.add_subplot(111) 145 | ax.plot(m[0][0].detach().cpu().numpy()) 146 | plt.savefig(f"{self.get_name()}.png") 147 | # 3d image 148 | if data.mode == "voxel": 149 | # TODO 150 | logger.info("saving 3d mutants is not yet implemented") 151 | pass 152 | -------------------------------------------------------------------------------- /rex_xai/mutants/occlusions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch as tt 3 | import numpy as np 4 | from scipy.ndimage import gaussian_filter 5 | 6 | 7 | def __split_groups(neg_mask): 8 | # for some reason, it's much faster to do this on the cpu with numpy 9 | # than it is to use tensor_split 10 | return np.split(neg_mask, np.where(np.diff(neg_mask) > 1)[0] + 1) 11 | 12 | 13 | def spectral_occlusion(mask: tt.Tensor, data: tt.Tensor, noise=0.02, device="cpu"): 14 | """Linear interpolated occlusion for spectral data, with optional added noise. 
def __split_groups(neg_mask):
    # for some reason, it's much faster to do this on the cpu with numpy
    # than it is to use tensor_split
    return np.split(neg_mask, np.where(np.diff(neg_mask) > 1)[0] + 1)


def spectral_occlusion(mask: tt.Tensor, data: tt.Tensor, noise=0.02, device="cpu"):
    """Linear interpolated occlusion for spectral data, with optional added noise.

    @param mask: boolean valued NDArray
    @param data: data to be occluded, shape (1, 1, length)
    @param noise: parameter for optional gaussian noise.
    Set to 0.0 if you want simple linear interpolation

    @return torch.Tensor
    """
    neg_mask = tt.where(mask == 0)[0]
    split = __split_groups(neg_mask.detach().cpu().numpy())

    # strangely, this all seems to run faster if we do it on the cpu.
    # TODO Needs further investigation
    local_data = np.copy(data.detach().cpu().numpy())

    for s in split:
        # a gap of length <= 1 cannot be interpolated: skip it. The previous
        # early `return` here abandoned every remaining masked group,
        # returning partially-occluded data
        if len(s) <= 1:
            continue
        start = s[0]
        stop = s[-1]
        dstart = data[:, :, start][0][0].item()
        dstop = data[:, :, stop][0][0].item()
        interp = tt.linspace(dstart, dstop, stop - start)
        if noise > 0.0:
            interp += np.random.normal(0, noise, len(interp))

        local_data[0, 0, start:stop] = interp

    return tt.from_numpy(local_data).to(device)


# Occlusions such as beach, sky, and other context-based occlusions


# Medical-based occlusions could be a CT scan of a healthy patient
def context_occlusion(mask: tt.Tensor, data: tt.Tensor, context: tt.Tensor, noise=0.5):
    """Context based occlusion with optional added noise.

    @param mask: boolean valued NDArray
    @param data: data to be occluded
    @param context: data to be used as occlusion e.g. CT scan of a healthy patient or a road
    @param noise: sigma for an optional gaussian filter applied to the context.
        NOTE(review): this *smooths* the context rather than adding noise,
        despite the parameter name — confirm intent. Set to 0.0 for no noise

    @return torch.Tensor
    """
    if noise > 0.0:
        context = tt.tensor(gaussian_filter(context, sigma=noise), dtype=tt.float32)
    return tt.where(mask == 0, context, data)
# --- rex_xai/output/database.py ---
#!/usr/bin/env python
from datetime import datetime
import zlib
import torch as tt
import sqlalchemy as sa
from sqlalchemy import Boolean, Float, String, create_engine
from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from ast import literal_eval

import pandas as pd
import numpy as np

from rex_xai.utils.logger import logger
from rex_xai.input.config import CausalArgs, Strategy
from rex_xai.explanation.explanation import Explanation
from rex_xai.explanation.multi_explanation import MultiExplanation


def _dataframe(db, table):
    """Read the whole <table> of the sqlite database at <db> into pandas."""
    return pd.read_sql_table(table, f"sqlite:///{db}")


def _to_numpy(buffer, shape, dtype):
    """Decompress a zlib blob (written by NumpyType) back into an ndarray."""
    return np.frombuffer(zlib.decompress(buffer), dtype=dtype).reshape(shape)


def db_to_pandas(db, dtype=np.float32, table="rex", process=True):
    """for interactive use: load a ReX db, optionally decoding the compressed
    responsibility/explanation columns back into ndarrays"""
    df = _dataframe(db, table=table)

    if process:
        df["responsibility"] = df.apply(
            lambda row: _to_numpy(
                row["responsibility"], literal_eval(row["responsibility_shape"]), dtype
            ),
            axis=1,
        )
        df["explanation"] = df.apply(
            lambda row: _to_numpy(
                row["explanation"], literal_eval(row["explanation_shape"]), np.bool_
            ),
            axis=1,
        )

    return df


def __multi_update(
    db,
    explanation,
    classification,
    target,
    target_map,
    final_mask,
    time_taken,
    multi_no,
):
    """Persist sub-explanation <multi_no> of a MultiExplanation."""
    if isinstance(final_mask, tt.Tensor):
        final_mask = final_mask.detach().cpu().numpy()
    add_to_database(
        db,
        explanation.args,
        classification,
        target.confidence,
        target_map,
        final_mask,
        explanation.explanation_confidences[multi_no],
        time_taken,
        explanation.run_stats["total_passing"],
        explanation.run_stats["total_failing"],
        explanation.run_stats["max_depth_reached"],
        explanation.run_stats["avg_box_size"],
        multi=True,
        multi_no=multi_no,
    )


def update_database(
    db,
    explanation: Explanation | MultiExplanation,  # type: ignore
    time_taken=None,
    multi=False,
    clauses=None,
):
    """Write an (multi-)explanation to <db>.

    When <multi> is set, only the sub-explanations listed in <clauses>
    (all of them when clauses is None) are stored.
    """
    target_map = explanation.target_map

    if isinstance(target_map, tt.Tensor):
        target_map = target_map.detach().cpu().numpy()

    target = explanation.data.target
    if target is None:
        logger.warning("unable to update database as target is None")
        return
    classification = int(target.classification)  # type: ignore

    if not multi:
        final_mask = explanation.final_mask
        if explanation.final_mask is None:
            logger.warning("unable to update database as explanation is empty")
            return
        if isinstance(final_mask, tt.Tensor):
            final_mask = final_mask.detach().cpu().numpy()

        explanation_confidence = explanation.explanation_confidence

        add_to_database(
            db,
            explanation.args,
            classification,
            target.confidence,
            target_map,
            final_mask,
            explanation_confidence,
            time_taken,
            explanation.run_stats["total_passing"],
            explanation.run_stats["total_failing"],
            explanation.run_stats["max_depth_reached"],
            explanation.run_stats["avg_box_size"],
        )
        return

    if type(explanation) is not MultiExplanation:
        logger.warning(
            "unable to update database, multi=True is only valid for MultiExplanation objects"
        )
        return

    for c, final_mask in enumerate(explanation.explanations):
        # filter on <clauses> when given; the previous version duplicated the
        # entire __multi_update call in both branches
        if clauses is not None and c not in clauses:
            logger.warning("ignoring %s", c)
            continue
        __multi_update(
            db,
            explanation,
            classification,
            target,
            target_map,
            final_mask,
            time_taken,
            c,
        )


def add_to_database(
    db,
    args: CausalArgs,
    target,
    confidence,
    responsibility,
    explanation,
    explanation_confidence,
    time_taken,
    passing,
    failing,
    depth_reached,
    avg_box_size,
    multi=False,
    multi_no=None,
):
    """Build a DataBaseEntry from the run results and commit it to <db>."""
    # primary key: hash of the current wall-clock time (plus the
    # sub-explanation number for multi explanations)
    if multi:
        id = hash(str(datetime.now().time()) + str(multi_no))
    else:
        id = hash(str(datetime.now().time()))

    # renamed from `object`, which shadowed the builtin
    entry = DataBaseEntry(
        id,
        args.path,
        target,
        confidence,
        responsibility,
        responsibility.shape,
        explanation,
        explanation.shape,
        explanation_confidence,
        time_taken,
        depth_reached=depth_reached,
        avg_box_size=avg_box_size,
        tree_depth=args.tree_depth,
        search_limit=args.search_limit,
        iters=args.iters,
        min_size=args.min_box_size,
        distribution=str(args.distribution),
        distribution_args=str(args.distribution_args),
    )
    entry.multi = multi
    entry.multi_no = multi_no
    entry.passing = passing
    entry.failing = failing
    entry.total_work = passing + failing
    entry.method = str(args.strategy)
    if args.strategy == Strategy.Spatial:
        entry.spatial_radius = args.spatial_initial_radius
        entry.spatial_eta = args.spatial_radius_eta
    if args.strategy == Strategy.MultiSpotlight:
        entry.spotlights = args.spotlights
        entry.spotlight_size = args.spotlight_size
        entry.spotlight_eta = args.spotlight_eta
        entry.obj_function = args.spotlight_objective_function

    db.add(entry)
    db.commit()
class Base(DeclarativeBase):
    """Declarative base for the ReX schema."""

    pass


class NumpyType(sa.types.TypeDecorator):
    """Store ndarrays as zlib-compressed binary blobs.

    NOTE(review): process_result_value returns the raw compressed bytes;
    decompression is left to readers (see db_to_pandas / _to_numpy) —
    confirm this asymmetry is intentional.
    """

    impl = sa.types.LargeBinary

    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is not None:
            return zlib.compress(value, 9)

    def process_result_value(self, value, dialect):
        return value


class DataBaseEntry(Base):
    """One row per (sub-)explanation run."""

    __tablename__ = "rex"
    id = Column(Integer, primary_key=True)
    path = Column(Unicode(100))
    target = Column(Integer)
    confidence = Column(Float)
    time = Column(Float)
    responsibility = Column(NumpyType)
    responsibility_shape = Column(Unicode)
    total_work = Column(Integer)
    passing = Column(Integer)
    failing = Column(Integer)
    explanation = Column(NumpyType)
    explanation_shape = Column(Unicode)
    explanation_confidence = Column(Float)
    multi = Column(Boolean)
    multi_no = Column(Integer)

    # causal specific columns
    depth_reached = Column(Integer)
    avg_box_size = Column(Float)
    tree_depth = Column(Integer)
    search_limit_per_iter = Column(Integer)
    iters = Column(Integer)
    min_size = Column(Integer)
    distribution = Column(String)
    distribution_args = Column(String)

    # explanation specific columns
    spatial_radius = Column(Integer)
    spatial_eta = Column(Float)

    # spotlight columns
    method = Column(String)
    spotlights = Column(Integer)
    spotlight_size = Column(Integer)
    spotlight_eta = Column(Float)
    obj_function = Column(String)

    def __init__(
        self,
        id,
        path,
        target,
        confidence,
        responsibility,
        responsibility_shape,
        explanation,
        explanation_shape,
        explanation_confidence,
        time_taken,
        passing=None,
        failing=None,
        total_work=None,
        multi=False,
        multi_no=None,
        depth_reached=None,
        avg_box_size=None,
        tree_depth=None,
        search_limit=None,
        iters=None,
        min_size=None,
        distribution=None,
        distribution_args=None,
        initial_radius=None,
        radius_eta=None,
        method=None,
        spotlights=0,
        spotlight_size=0,
        spotlight_eta=0.0,
        obj_function=None,
    ):
        self.id = id
        self.path = path
        self.target = target
        self.confidence = confidence
        self.responsibility = responsibility
        self.responsibility_shape = str(responsibility_shape)
        self.explanation = explanation
        self.explanation_shape = str(explanation_shape)
        self.explanation_confidence = explanation_confidence
        self.time = time_taken
        self.total_work = total_work
        self.passing = passing
        self.failing = failing
        # multi status
        self.multi = multi
        self.multi_no = multi_no
        # causal
        self.depth_reached = depth_reached
        self.avg_box_size = avg_box_size
        self.tree_depth = tree_depth
        # assign to the *mapped* column: the previous code wrote to the
        # unmapped attribute `search_limit`, so the value never reached the db
        self.search_limit_per_iter = search_limit
        self.iters = iters
        self.min_size = min_size
        self.distribution = distribution
        self.distribution_args = distribution_args
        # spatial
        self.spatial_radius = initial_radius
        self.spatial_eta = radius_eta
        self.method = method
        # spotlights
        self.spotlights = spotlights
        self.spotlight_size = spotlight_size
        self.spotlight_eta = spotlight_eta
        self.obj_function = obj_function


def initialise_rex_db(name, echo=False):
    """Create (if needed) the sqlite db at <name> and return a Session."""
    engine = create_engine(f"sqlite:///{name}", echo=echo)
    Base.metadata.create_all(engine, tables=[DataBaseEntry.__table__], checkfirst=True)  # type: ignore
    Session = sessionmaker(bind=engine)
    s = Session()
    return s
class Prediction:
    """A model classification for a mutant, together with the model's
    confidence in the original target class."""

    def __init__(
        self,
        pred=None,
        conf=None,
        box=None,
        target=None,
        target_confidence=None,
    ) -> None:
        self.classification: Optional[int] = pred
        self.confidence: Optional[float] = conf
        self.bounding_box: Optional[NDArray] = box
        self.target: Optional[int] = None if target is None else target.classification
        self.target_confidence: Optional[float] = target_confidence

    def __repr__(self) -> str:
        # ff() renders None safely: the previous f"{self.confidence:.5f}"
        # raised TypeError for an empty-but-passing prediction
        conf = ff(self.confidence, ".5f")
        if self.bounding_box is None:
            if self.is_passing():
                return f"FOUND_CLASS: {self.classification}, CONF: {conf}"
            if self.target is None:
                return f"FOUND_CLASS: {self.classification}, FOUND_CONF: {conf}, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a"
            return f"FOUND_CLASS: {self.classification}, FOUND_CONF: {conf}, TARGET_CLASS: {self.target}, TARGET_CONFIDENCE: {ff(self.target_confidence, '.5f')}"

        return f"CLASS: {self.classification}, CONF: {conf}, TARGET_CLASS: {self.target}, TARGET_CONFIDENCE: {ff(self.target_confidence, '.5f')}, BOUNDING_BOX: {self.bounding_box}"

    def get_class(self):
        """Return the predicted class label (may be None)."""
        return self.classification

    def is_empty(self):
        """True when no classification/confidence has been recorded."""
        return self.classification is None or self.confidence is None

    def is_passing(self):
        """True when the prediction matches the original target class."""
        return self.target == self.classification
def from_pytorch_tensor(tensor, target=None, binary_threshold=None) -> List[Prediction]:
    """Turn a (batch x classes) logit tensor into one Prediction per row."""
    # TODO get this to handle binary models
    softmax_tensor = F.softmax(tensor, dim=1)
    prediction_scores, pred_labels = tt.topk(softmax_tensor, 1)
    predictions = []
    for i, (ps, pl) in enumerate(zip(prediction_scores, pred_labels)):
        p = Prediction(pl.item(), ps.item())
        if target is not None:
            # NOTE(review): this stores the whole target Prediction object,
            # whereas Prediction.__init__ stores target.classification —
            # confirm which convention is intended before relying on is_passing()
            p.target = target
            p.target_confidence = softmax_tensor[i, target.classification].item()
        predictions.append(p)

    return predictions


# --- rex_xai/responsibility/resp_maps.py ---
class ResponsibilityMaps:
    """Per-classification responsibility maps plus per-classification usage
    counts."""

    def __init__(self) -> None:
        self.maps = {}
        self.counts = {}

    def __repr__(self) -> str:
        return str(self.counts)

    def get(self, k, increment=False):
        """Return the map for classification <k>, or None when absent.

        When <increment> is set, also bump the usage count for <k>."""
        try:
            if increment:
                self.counts[k] += 1  # type: ignore
            return self.maps[k]
        except KeyError:
            return

    def new_map(self, k: int, height, width, depth=None):
        """Create a fresh zeroed map for <k> (3d when <depth> is given)."""
        if depth is not None:
            self.maps[k] = np.zeros((height, width, depth), dtype="float32")
        else:
            self.maps[k] = np.zeros((height, width), dtype="float32")
        self.counts[k] = 1

    def items(self):
        return self.maps.items()

    def keys(self):
        return self.maps.keys()

    def len(self):
        return len(self.maps)

    def merge(self, maps):
        """Merge <maps> into self, adding where a classification already exists.

        NOTE(review): usage counts are not updated here — confirm intended.
        """
        for k, v in maps.items():
            if np.max(v) == 0:
                # an all-zero map contributes nothing: skip it. The previous
                # `break` also discarded every *remaining* map, in arbitrary
                # dict order
                continue
            if k in self.maps:
                self.maps[k] += v
            else:
                self.maps[k] = v

    def negative_responsibility(self, target):
        """Subtract every other classification's responsibility from <target>'s."""
        for k, v in self.maps.items():
            if k != target:
                logger.debug(
                    f"subtracting responsibility for class {k} from class {target}"
                )
                self.maps[target] = self.maps[target] - v  # type: ignore

    def responsibility(self, mutant: Mutant, args: CausalArgs):
        """Responsibility of each of the (up to 4) quadrants active in <mutant>.

        Box names are assumed to end in the quadrant digit 0-3."""
        responsibility = np.zeros(4, dtype=np.float32)
        parts = mutant.get_active_boxes()
        r = 1 / len(parts)
        for p in parts:
            i = np.uint(p[-1])
            if (
                args.weighted
                and mutant.prediction is not None
                and mutant.prediction.confidence is not None
            ):
                responsibility[i] += r * mutant.prediction.confidence
            else:
                responsibility[i] += r
        return responsibility

    def update_maps(
        self, mutants: List[Mutant], args: CausalArgs, data: Data, search_tree
    ):
        """Update the different responsibility maps with all passing mutants
        @params mutants: list of mutants
        @params args: causal args
        @params data: data
        @params search_tree: tree of boxes

        Mutates in place, does not return a value
        """

        for mutant in mutants:
            r = self.responsibility(mutant, args)

            k = None
            # check that there is a prediction value
            if mutant.prediction is not None:
                k = mutant.prediction.classification
            # if there's no prediction value, raise an exception
            if k is None:
                raise ReXMapError("the provided mutant has no known classification")
            # check if k has been seen before and has a map. If k is new, make a new map
            if k not in self.maps:
                self.new_map(k, data.model_height, data.model_width, data.model_depth)

            # get the responsibility map for k
            resp_map = self.get(k, increment=True)
            if resp_map is None:
                raise ValueError(
                    f"unable to open or generate a responsibility map for classification {k}"
                )

            # we only increment responsibility for active boxes, not static boxes
            for box_name in mutant.get_active_boxes():
                box: Optional[Box] = find(search_tree, lambda n: n.name == box_name)
                if box is not None and box.area() > 0:
                    index = np.uint(box_name[-1])
                    local_r = r[index]
                    if args.concentrate:
                        local_r *= box.depth
                    # Don't delete this code just yet as this is an alternative (less brutal)
                    # local_r *= 1.0 / box.area()
                    # scaling strategy that needs further investigation
                    # scale = depth - 1
                    # local_r = 2**(local_r * scale)

                    if data.mode == "spectral":
                        section = resp_map[0, box.col_start : box.col_stop]
                    elif data.mode == "RGB":
                        section = resp_map[
                            box.row_start : box.row_stop,
                            box.col_start : box.col_stop,
                        ]
                    elif data.mode == "voxel":
                        section = resp_map[
                            box.row_start : box.row_stop,
                            box.col_start : box.col_stop,
                            box.depth_start : box.depth_stop,
                        ]
                    else:
                        logger.warning("not yet implemented")
                        raise NotImplementedError

                    section += local_r
            self.maps[k] = resp_map

    def subset(self, id):
        """Discard every map/count except <id>'s (kept even when absent, as None)."""
        m = self.maps.get(id)
        c = self.counts.get(id)
        self.maps = {id: m}
        self.counts = {id: c}
def main():
    """main entry point to ReX cmdline tool"""
    args = get_all_args()
    validate_args(args)
    set_log_level(args.verbosity, logger)

    device = get_device(args.gpu)

    logger.debug("running ReX with the following args:\n %s", args)

    db = None
    if args.db is not None:
        db = initialise_rex_db(args.db)

    explanation(args, device, db)


# --- rex_xai/utils/_utils.py ---
Strategy = Enum("Strategy", ["Global", "Spatial", "MultiSpotlight", "Contrastive"])

Queue = Enum("Queue", ["Area", "All", "Intersection", "DC"])

SpatialSearch = Enum("SpatialSearch", ["NotFound", "Found"])


def one_d_permute(tensor):
    """Randomly permute a 1-d tensor; returns (permuted, permutation)."""
    perm = tt.randperm(len(tensor))
    return tensor[perm], perm


def powerset(r, reverse=True):
    """Every non-empty subset of <r>, largest first when <reverse>."""
    ps = list(chain.from_iterable(combinations(r, lim) for lim in range(1, len(r) + 1)))
    if reverse:
        return reversed(ps)
    else:
        return ps


def clause_area(clause, areas: Dict) -> int:
    """Sum of the areas of every box named in <clause>."""
    tot = 0
    for c in clause:
        tot += areas[c]
    return tot


class ReXError(Exception):
    """Root of the ReX exception hierarchy."""

    pass


class ReXTomlError(ReXError):
    def __init__(self, message) -> None:
        self.message = message
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"ReXTomlError: {self.message}"


class ReXPathError(ReXError):
    def __init__(self, message) -> None:
        self.message = message
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"ReXPathError: no such file exists at {self.message}"


class ReXScriptError(ReXError):
    def __init__(self, message) -> None:
        self.message = message
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"ReXScriptError: {self.message}"


class ReXDataError(ReXError):
    def __init__(self, message) -> None:
        self.message = message
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"ReXDataError: {self.message}"


class ReXMapError(ReXError):
    def __init__(self, message) -> None:
        self.message = message
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"ReXMapError: {self.message}"


def xlogx(ps):
    """Elementwise p * log2(p), with 0 * log2(0) defined as 0."""
    f = np.vectorize(_xlogx)
    return f(ps)


def _xlogx(p):
    if p == 0.0:
        return 0.0
    else:
        return p * np.log2(p)


def add_boundaries(
    img: Union[NDArray, tt.Tensor], segs: NDArray, colour=None
) -> NDArray:
    """Draw segment boundaries from <segs> onto <img>, returning uint8 pixels."""
    if colour is None:
        m = mark_boundaries(img, segs, mode="thick")
    else:
        m = mark_boundaries(img, segs, colour, mode="thick")
    m *= 255  # type: ignore
    m = m.astype(np.uint8)
    return m


def get_device(gpu: bool):
    """Pick the torch device: mps or cuda when <gpu> is requested and
    actually available, otherwise cpu."""
    if tt.backends.mps.is_available() and gpu:
        return tt.device("mps")
    # tt.device("cuda") is always truthy, so the previous check claimed cuda
    # even on machines without one; test real availability instead
    if tt.cuda.is_available() and gpu:
        return tt.device("cuda")
    if gpu:
        logger.warning("gpu not available")
    return tt.device("cpu")
def get_map_locations(map, reverse=True):
    """Return (responsibility, index) pairs for every cell of *map*.

    Accepts a numpy array or a torch tensor (tensors are detached and moved
    to the CPU first). Each index is the multi-dimensional coordinate tuple
    of the cell. Pairs are sorted by responsibility value, descending by
    default (``reverse=True``).
    """
    if isinstance(map, tt.Tensor):
        map = map.detach().cpu().numpy()
    # pair each cell's value with its n-dimensional coordinates
    pairs = [
        (value, np.unravel_index(flat_index, map.shape))
        for flat_index, value in enumerate(np.nditer(map))
    ]
    return sorted(pairs, reverse=reverse)
def set_log_level(i: int, rex_logger):
    """Set the logging level on *rex_logger* from a verbosity count.

    0 -> CRITICAL, 1 -> WARNING, 2 -> INFO, any other value -> DEBUG.
    """
    levels = {
        0: logging.CRITICAL,
        1: logging.WARNING,
        2: logging.INFO,
    }
    rex_logger.setLevel(levels.get(i, logging.DEBUG))
def prediction_function(mutants, target=None, raw=False, binary_threshold=False):
    """Classify a batch of mutants with the torchvision ResNet-50.

    When ``raw`` is True, return the full softmax distribution (used when
    computing insertion/deletion curves); otherwise convert the logits into
    Prediction objects via ``from_pytorch_tensor``.
    """
    # inference only: skipping autograd bookkeeping makes this faster
    with tt.no_grad():
        logits = model(mutants)
        if raw:
            return F.softmax(logits, dim=1)
        # from_pytorch_tensor turns raw logits into Prediction objects;
        # substitute your own conversion here if needed
        return from_pytorch_tensor(logits, target=target)
def preprocess(path, shape, device, mode) -> Data:
    """Load a 1D spectrum from a .npy file and wrap it in a Data object.

    The spectrum is given batch and channel dimensions, producing a
    (1, 1, 1356) float tensor on *device*.

    BUG FIX: the original hard-coded the 'mps' device in both the tensor
    move and the Data constructor, which fails on any non-Apple machine;
    the `device` argument supplied by ReX is now used instead.
    """
    spectrum = tt.from_numpy(np.load(path)).float().unsqueeze(0).unsqueeze(0).to(device)
    # NOTE(review): (1, 1356) is kept from the original rather than the
    # `shape` argument (["N", 1, 1356]) -- confirm Data accepts either form.
    data = Data(spectrum, (1, 1356), device, mode="spectral", process=True)
    data.set_width(1356)
    data.set_height(1)
    return data
def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
    """Classify a batch of mutants with the timm ResNet-50.

    When ``raw`` is True, return the full softmax distribution (used when
    computing insertion/deletion curves); otherwise convert the logits into
    Prediction objects.
    """
    with tt.no_grad():  # inference only; no gradients needed
        tensor = model(mutants)
        if raw:
            return F.softmax(tensor, dim=1)
        # BUG FIX: `target` was accepted but silently dropped; forward it so
        # target-class confidences are reported, consistent with the
        # equivalent scripts/pytorch.py example.
        return from_pytorch_tensor(tensor, target=target)
@pytest.fixture
def snapshot_explanation(snapshot):
    """Syrupy snapshot pre-configured for comparing Explanation objects.

    Excludes attributes that legitimately differ between runs or machines
    (function pointers, filesystem paths, the model handle) or are too large
    to store in a snapshot (arrays), and serialises CausalArgs instances as
    named tuples so those exclusions also apply to that custom class.
    """
    return snapshot.with_defaults(
        exclude=props(
            "obj_function",  # pointer to function that will differ between runs
            "spotlight_objective_function",  # pointer to function that will differ between runs
            "script",  # path that differs between systems
            "script_location",  # path that differs between systems
            "model",
            "target_map",  # large array
            "final_mask",  # large array
            "explanation"  # large array
        ),
        matcher=path_type(
            types=(CausalArgs,),
            replacer=lambda data, _: AmberDataSerializer.object_as_named_tuple( #type: ignore
                data
            ), # needed to allow exclude to work for custom classes
        )
    )
@pytest.fixture
def exp_custom(data_custom, args_custom, prediction_func):
    """An (unextracted) Explanation built with the custom ResNet-50 script.

    Predicts the target for the preprocessed data, computes the
    responsibility maps, and wraps everything in an Explanation object;
    ``extract`` has not yet been called (see the exp_extracted fixture).
    """
    # the target class must be known before responsibility can be attributed
    data_custom.target = predict_target(data_custom, prediction_func)
    maps, run_stats = calculate_responsibility(
        data_custom, args_custom, prediction_func
    )
    exp = Explanation(maps, prediction_func, data_custom, args_custom, run_stats)

    return exp
@pytest.fixture
def data_2d():
    """A 64x64 single-channel Data object on the CPU.

    BUG FIX: the original passed ``np.arange(1, 64, 64)`` -- i.e.
    start=1, stop=64, step=64 -- which yields the one-element array [1],
    not a 2D input matching ``model_shape=[1, 64, 64]``. A zero-filled
    (1, 64, 64) array is used instead, mirroring the ``data_3d`` fixture.
    TODO(review): confirm no test relied on the old one-element input.
    """
    return Data(
        input=np.zeros((1, 64, 64), dtype=np.float32),
        model_shape=[1, 64, 64],
        device="cpu"
    )
def preprocess(path, shape, device, mode) -> Data:
    """
    The preprocessing function is executed before the
    model is called. It is used to prepare the input.

    Args:
        path: str | torch.Tensor
            The path to the input data, or an already-loaded volume tensor
        shape: tuple
            The shape of the input data
        device: str
            Where the data should be loaded: "cpu" or "cuda"
        mode: str
            The mode of the input data: "voxel" for 3D data

    Returns:
        Data: The input data object, carrying the resized volume and the
        metadata ReX needs (mode, model_height/width/depth, transposed).
    """
    transform = Resize(spatial_size=(64, 64, 64))
    # BUG FIX: the original used `path is str` / `path is Tensor`, which
    # tests identity against the *type objects* and is False for every
    # actual argument, so this function always raised ValueError.
    if isinstance(path, str):
        volume = LoadImage()(path)
        transformed_volume = transform(volume)
    elif isinstance(path, tt.Tensor):
        # NOTE(review): the original compared against the onnxruntime
        # flatbuffers Tensor class; torch.Tensor is almost certainly the
        # intended type for an in-memory volume -- confirm.
        transformed_volume = transform(path)
    else:
        raise ValueError("Invalid input type")
    data = Data(transformed_volume, shape, device, mode=mode, process=False)
    data.mode = "voxel"
    data.model_shape = shape
    data.model_height = 64
    data.model_width = 64
    data.model_depth = 64
    data.transposed = True

    return data
def preprocess(path, shape, device, mode) -> Data:
    """Load an image from *path* and prepare it for ResNet-50.

    Wraps the RGB image in a Data object and attaches the normalised
    (1, 3, 224, 224) tensor that the model actually consumes.
    """
    image = Image.open(path).convert("RGB")

    # standard ImageNet preprocessing: resize, tensorise, normalise
    to_model_input = T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    data = Data(image, shape, device, mode='RGB')
    # the batched, device-resident tensor fed to the model
    data.data = to_model_input(image).unsqueeze(0).to(device)  # type: ignore

    return data
def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
    """Classify a batch of mutants with swin_v2_t.

    Returns the full softmax distribution when ``raw`` is True, otherwise
    Prediction objects built by ``from_pytorch_tensor``.

    NOTE(review): unlike tests/scripts/pytorch_resnet50.py, ``target`` is
    accepted but not forwarded to from_pytorch_tensor -- confirm this is
    intentional (forwarding it could change committed snapshots).
    """
    with tt.no_grad():  # inference only
        logits = model(mutants)
        if raw:
            return F.softmax(logits, dim=1)
        return from_pytorch_tensor(logits)
function: . > 8 | ResponsibilityMaps: {209: 1} 9 | run statistics: {'total_passing': 41, 'total_failing': 169, 'max_depth_reached': 4, 'avg_box_size': 2747.41667} (5 dp) 10 | explanation: of shape torch.Size([3, 224, 224]) 11 | final mask: of shape (3, 224, 224) 12 | explanation confidence: 0.12158 (5 dp) 13 | # --- 14 | # name: test__explanation_snapshot.1 15 | 7409862900317702363 16 | # --- 17 | # name: test_extract_analyze[Strategy.Global] 18 | 976237796511694519 19 | # --- 20 | # name: test_extract_analyze[Strategy.Global].1 21 | dict({ 22 | 'area': 0.1654, 23 | 'deletion_curve': 0.0259, 24 | 'entropy': 5.3764, 25 | 'insertion_curve': 0.503, 26 | }) 27 | # --- 28 | # name: test_extract_analyze[Strategy.Spatial] 29 | 976237796511694519 30 | # --- 31 | # name: test_extract_analyze[Strategy.Spatial].1 32 | dict({ 33 | 'area': 0.1654, 34 | 'deletion_curve': 0.0259, 35 | 'entropy': 5.3764, 36 | 'insertion_curve': 0.503, 37 | }) 38 | # --- 39 | -------------------------------------------------------------------------------- /tests/snapshot_tests/__snapshots__/_explanation_test.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test__explanation_snapshot[1] 3 | Explanation: 4 | CausalArgs: 5 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 6 | Target:FOUND_CLASS: 207, FOUND_CONF: 0.20462, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 7 | prediction function: 8 | ResponsibilityMaps: {207: 1} 9 | run statistics: {'total_passing': 84, 'total_failing': 126, 'max_depth_reached': 5, 'avg_box_size': 183.5625} (5 dp) 10 | explanation: of shape torch.Size([3, 224, 224]) 11 | final mask: of shape (3, 224, 224) 12 | explanation confidence: 0.01513 (5 dp) 13 | # --- 14 | # name: test__explanation_snapshot[1].1 15 | 2537636458340732753 16 | # --- 17 | # name: test__explanation_snapshot[64] 18 | Explanation: 19 | CausalArgs: 20 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 21 | 
Target:FOUND_CLASS: 207, FOUND_CONF: 0.20462, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 22 | prediction function: 23 | ResponsibilityMaps: {207: 1} 24 | run statistics: {'total_passing': 84, 'total_failing': 126, 'max_depth_reached': 5, 'avg_box_size': 183.5625} (5 dp) 25 | explanation: of shape torch.Size([3, 224, 224]) 26 | final mask: of shape (3, 224, 224) 27 | explanation confidence: 0.01513 (5 dp) 28 | # --- 29 | # name: test__explanation_snapshot[64].1 30 | 2537636458340732753 31 | # --- 32 | # name: test__explanation_snapshot_diff_model_shape[1] 33 | Explanation: 34 | CausalArgs: 35 | Data: Data: RGB, ['N', 3, 256, 256], 256, 256, 3, first 36 | Target:FOUND_CLASS: 215, FOUND_CONF: 0.48740, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 37 | prediction function: 38 | ResponsibilityMaps: {215: 1} 39 | run statistics: {'total_passing': 60, 'total_failing': 136, 'max_depth_reached': 5, 'avg_box_size': 262.0} (5 dp) 40 | explanation: of shape torch.Size([3, 256, 256]) 41 | final mask: of shape (3, 256, 256) 42 | explanation confidence: 0.03730 (5 dp) 43 | # --- 44 | # name: test__explanation_snapshot_diff_model_shape[1].1 45 | -5471936325618413956 46 | # --- 47 | # name: test__explanation_snapshot_diff_model_shape[64] 48 | Explanation: 49 | CausalArgs: 50 | Data: Data: RGB, ['N', 3, 256, 256], 256, 256, 3, first 51 | Target:FOUND_CLASS: 215, FOUND_CONF: 0.48740, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 52 | prediction function: 53 | ResponsibilityMaps: {215: 1} 54 | run statistics: {'total_passing': 60, 'total_failing': 136, 'max_depth_reached': 5, 'avg_box_size': 262.0} (5 dp) 55 | explanation: of shape torch.Size([3, 256, 256]) 56 | final mask: of shape (3, 256, 256) 57 | explanation confidence: 0.03730 (5 dp) 58 | # --- 59 | # name: test__explanation_snapshot_diff_model_shape[64].1 60 | -5471936325618413956 61 | # --- 62 | # name: test_extract_analyze[Strategy.Global] 63 | -2185390445883109608 64 | # --- 65 | # name: test_extract_analyze[Strategy.Global].1 66 | 
dict({ 67 | 'area': 0.0112, 68 | 'deletion_curve': 1.0199, 69 | 'entropy': 7.1353, 70 | 'insertion_curve': 1.2399, 71 | }) 72 | # --- 73 | # name: test_extract_analyze[Strategy.Spatial] 74 | -2185390445883109608 75 | # --- 76 | # name: test_extract_analyze[Strategy.Spatial].1 77 | dict({ 78 | 'area': 0.0112, 79 | 'deletion_curve': 1.0199, 80 | 'entropy': 7.1353, 81 | 'insertion_curve': 1.2399, 82 | }) 83 | # --- 84 | # name: test_multiexplanation_snapshot 85 | MultiExplanation: 86 | CausalArgs: 87 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 88 | Target:FOUND_CLASS: 84, FOUND_CONF: 0.51611, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 89 | prediction function: 90 | ResponsibilityMaps: {84: 1} 91 | run statistics: {'total_passing': 217, 'total_failing': 329, 'max_depth_reached': 6, 'avg_box_size': 87.1} (5 dp) 92 | explanations: 5 explanations of and shape torch.Size([3, 224, 224]) 93 | explanation confidences: [0.00628, 0.04587, 0.00387, 0.00446, 0.07882] (5 dp) 94 | # --- 95 | # name: test_multiexplanation_snapshot.1 96 | list([ 97 | tuple( 98 | 1, 99 | 2, 100 | 3, 101 | 4, 102 | ), 103 | tuple( 104 | 0, 105 | 1, 106 | 3, 107 | 4, 108 | ), 109 | ]) 110 | # --- 111 | # name: test_multiexplanation_snapshot.2 112 | 7802689268791639071 113 | # --- 114 | # name: test_multiexplanation_snapshot.3 115 | 6035056368165721418 116 | # --- 117 | # name: test_multiexplanation_snapshot.4 118 | -5651945023692276586 119 | # --- 120 | # name: test_multiexplanation_snapshot.5 121 | 2488684738058746373 122 | # --- 123 | # name: test_multiexplanation_snapshot.6 124 | 3892690372276336173 125 | # --- 126 | -------------------------------------------------------------------------------- /tests/snapshot_tests/__snapshots__/explanation_test.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args1] 3 | 3477467290869834009 4 | # --- 5 | # name: 
test_calculate_responsibility[Distribution.BetaBinomial-dist_args2] 6 | -8754720476463405946 7 | # --- 8 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args3] 9 | -200837265459686415 10 | # --- 11 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args4] 12 | -6449526703691359130 13 | # --- 14 | # name: test_calculate_responsibility[Distribution.Uniform-dist_args0] 15 | -1720903338186101365 16 | # --- 17 | -------------------------------------------------------------------------------- /tests/snapshot_tests/__snapshots__/load_preprocess_test.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test_preprocess 3 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 4 | # --- 5 | # name: test_preprocess_custom 6 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 7 | # --- 8 | # name: test_preprocess_spectral_mask_on_image_returns_warning 9 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first 10 | # --- 11 | -------------------------------------------------------------------------------- /tests/snapshot_tests/__snapshots__/spectral_test.ambr: -------------------------------------------------------------------------------- 1 | # serializer version: 1 2 | # name: test__explanation_snapshot[1] 3 | Explanation: 4 | CausalArgs: 5 | Data: Data: spectral, ['batch_size', 1, 1356], 1, 1356, 1, None 6 | Target:FOUND_CLASS: 1, FOUND_CONF: 0.99817, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 7 | prediction function: . 
> 8 | ResponsibilityMaps: {1: 1} 9 | run statistics: {'total_passing': 56, 'total_failing': 0, 'max_depth_reached': 2, 'avg_box_size': 22.875} (5 dp) 10 | explanation: of shape torch.Size([1, 1356]) 11 | final mask: of shape (1, 1356) 12 | explanation confidence: 0.99817 (5 dp) 13 | # --- 14 | # name: test__explanation_snapshot[1].1 15 | 3711909154719731854 16 | # --- 17 | # name: test__explanation_snapshot[64] 18 | Explanation: 19 | CausalArgs: 20 | Data: Data: spectral, ['batch_size', 1, 1356], 1, 1356, 1, None 21 | Target:FOUND_CLASS: 1, FOUND_CONF: 0.99817, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a 22 | prediction function: . > 23 | ResponsibilityMaps: {1: 1} 24 | run statistics: {'total_passing': 56, 'total_failing': 0, 'max_depth_reached': 2, 'avg_box_size': 22.875} (5 dp) 25 | explanation: of shape torch.Size([1, 1356]) 26 | final mask: of shape (1, 1356) 27 | explanation confidence: 0.99817 (5 dp) 28 | # --- 29 | # name: test__explanation_snapshot[64].1 30 | 3711909154719731854 31 | # --- 32 | -------------------------------------------------------------------------------- /tests/snapshot_tests/_explanation_onnx_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rex_xai.input.config import Strategy 3 | from rex_xai.explanation.rex import _explanation, analyze, get_prediction_func_from_args 4 | 5 | 6 | def test__explanation_snapshot(args_onnx, cpu_device, snapshot_explanation): 7 | prediction_func, model_shape = get_prediction_func_from_args(args_onnx) 8 | exp = _explanation(args_onnx, model_shape, prediction_func, cpu_device, db=None) 9 | 10 | assert exp == snapshot_explanation 11 | assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation 12 | 13 | 14 | @pytest.mark.parametrize("strategy", [Strategy.Global, Strategy.Spatial]) 15 | def test_extract_analyze(exp_onnx, strategy, snapshot): 16 | exp_onnx.extract(strategy) 17 | results = analyze(exp_onnx, "RGB") 18 | 
results_rounded = {k: round(v, 4) for k, v in results.items() if v is not None} 19 | 20 | assert hash(tuple(exp_onnx.final_mask.reshape(-1).tolist())) == snapshot 21 | assert results_rounded == snapshot 22 | -------------------------------------------------------------------------------- /tests/snapshot_tests/_explanation_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rex_xai.explanation.rex import _explanation, analyze 3 | from rex_xai.utils._utils import Strategy 4 | 5 | @pytest.mark.parametrize("batch_size", [1, 64]) 6 | def test__explanation_snapshot( 7 | args_custom, model_shape, prediction_func, cpu_device, batch_size, snapshot_explanation 8 | ): 9 | args_custom.batch_size = batch_size 10 | exp = _explanation(args_custom, model_shape, prediction_func, cpu_device, db=None) 11 | 12 | assert exp == snapshot_explanation 13 | assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation 14 | 15 | 16 | @pytest.mark.parametrize("batch_size", [1, 64]) 17 | def test__explanation_snapshot_diff_model_shape( 18 | args_torch_swin_v2_t, 19 | model_shape_swin_v2_t, 20 | prediction_func_swin_v2_t, 21 | cpu_device, 22 | batch_size, 23 | snapshot_explanation 24 | ): 25 | args_torch_swin_v2_t.batch_size = batch_size 26 | 27 | exp = _explanation( 28 | args_torch_swin_v2_t, 29 | model_shape_swin_v2_t, 30 | prediction_func_swin_v2_t, 31 | cpu_device, 32 | db=None, 33 | ) 34 | 35 | assert exp == snapshot_explanation 36 | assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation 37 | 38 | 39 | @pytest.mark.parametrize("strategy", [Strategy.Global, Strategy.Spatial]) 40 | def test_extract_analyze(exp_custom, strategy, snapshot): 41 | exp_custom.extract(strategy) 42 | results = analyze(exp_custom, "RGB") 43 | results_rounded = {k: round(v, 4) for k, v in results.items() if v is not None} 44 | 45 | assert hash(tuple(exp_custom.final_mask.reshape(-1).tolist())) == snapshot 46 
| assert results_rounded == snapshot 47 | 48 | 49 | def test_multiexplanation_snapshot( 50 | args_multi, model_shape, prediction_func, cpu_device, snapshot_explanation 51 | ): 52 | exp = _explanation(args_multi, model_shape, prediction_func, cpu_device, db=None) 53 | clauses = exp.separate_by(0.0) 54 | 55 | assert exp == snapshot_explanation 56 | assert clauses == snapshot_explanation 57 | for explanation in exp.explanations: 58 | assert hash(tuple(explanation.reshape(-1).tolist())) == snapshot_explanation 59 | -------------------------------------------------------------------------------- /tests/snapshot_tests/explanation_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rex_xai.mutants.distributions import Distribution 3 | from rex_xai.explanation.rex import calculate_responsibility, predict_target 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "distribution,dist_args", 8 | [ 9 | (Distribution.Uniform, []), 10 | (Distribution.BetaBinomial, [1, 1]), 11 | (Distribution.BetaBinomial, [0.5, 0.5]), 12 | (Distribution.BetaBinomial, [1, 0.5]), 13 | (Distribution.BetaBinomial, [0.5, 1]), 14 | ], 15 | ) 16 | def test_calculate_responsibility( 17 | data_custom, args_custom, prediction_func, distribution, dist_args, snapshot 18 | ): 19 | args_custom.distribution = distribution 20 | if dist_args: 21 | args_custom.distribution_args = dist_args 22 | data_custom.target = predict_target(data_custom, prediction_func) 23 | maps, _ = calculate_responsibility(data_custom, args_custom, prediction_func) 24 | target_map = maps.get(data_custom.target.classification) 25 | 26 | assert hash(tuple(target_map.reshape(-1))) == snapshot 27 | -------------------------------------------------------------------------------- /tests/snapshot_tests/load_preprocess_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rex_xai.utils._utils import ReXDataError 3 | from 
rex_xai.explanation.rex import ( 4 | load_and_preprocess_data, 5 | try_preprocess, 6 | get_prediction_func_from_args 7 | ) 8 | 9 | 10 | def test_preprocess(args, model_shape, cpu_device, snapshot): 11 | data = load_and_preprocess_data(model_shape, cpu_device, args) 12 | assert data == snapshot 13 | 14 | 15 | def test_preprocess_custom(args_custom, model_shape, cpu_device, snapshot): 16 | data = load_and_preprocess_data(model_shape, cpu_device, args_custom) 17 | assert data == snapshot 18 | 19 | 20 | def test_preprocess_spectral_mask_on_image_returns_warning( 21 | args, model_shape, cpu_device, snapshot, caplog 22 | ): 23 | args.mask_value = "spectral" 24 | data = try_preprocess(args, model_shape, device=cpu_device) 25 | 26 | assert args.mask_value == 0 27 | assert ( 28 | caplog.records[0].msg 29 | == "spectral is not suitable for images. Changing mask_value to 0" 30 | ) 31 | assert data == snapshot 32 | 33 | 34 | def test_preprocess_npy(args, DNA_model, cpu_device, snapshot, caplog): 35 | args.path = "tests/test_data/DoublePeakClass 0 Mean.npy" 36 | args.model = DNA_model 37 | _, model_shape = get_prediction_func_from_args(args) 38 | 39 | try_preprocess(args, model_shape, device=cpu_device) 40 | 41 | assert caplog.records[0].msg == "we do not generically handle this datatype" 42 | 43 | args.mode = "tabular" 44 | with pytest.raises(ReXDataError): 45 | try_preprocess(args, model_shape, device=cpu_device) 46 | 47 | def test_preprocess_incompatible_shapes(args, model_shape, cpu_device, caplog): 48 | args.path = "tests/test_data/DoublePeakClass 0 Mean.npy" 49 | args.mode = "tabular" 50 | 51 | with pytest.raises(ReXDataError): 52 | try_preprocess(args, model_shape, device=cpu_device) 53 | 54 | -------------------------------------------------------------------------------- /tests/snapshot_tests/spectral_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from cached_path import cached_path 3 | from 
rex_xai.input.config import Strategy 4 | from rex_xai.explanation.rex import _explanation, get_prediction_func_from_args 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def DNA_model(): 9 | DNA_model_path = cached_path( 10 | "https://github.com/ReX-XAI/models/raw/6f66a5c0e1480411436be828ee8312e72f0035e1/spectral/simple_DNA_model.onnx" 11 | ) 12 | return DNA_model_path 13 | 14 | 15 | @pytest.fixture 16 | def args_spectral(args, DNA_model): 17 | args.model = DNA_model 18 | args.path = "tests/test_data/spectrum_class_DNA.npy" 19 | args.mode = "spectral" 20 | args.mask_value = "spectral" 21 | args.seed = 15 22 | args.strategy = Strategy.Global 23 | 24 | return args 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "batch_size", [1, 64] 29 | ) 30 | def test__explanation_snapshot(args_spectral, cpu_device, batch_size, snapshot_explanation): 31 | args_spectral.batch_size = batch_size 32 | prediction_func, model_shape = get_prediction_func_from_args(args_spectral) 33 | exp = _explanation(args_spectral, model_shape, prediction_func, cpu_device, db=None) 34 | 35 | assert exp == snapshot_explanation 36 | assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation 37 | -------------------------------------------------------------------------------- /tests/test_data/004_0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/004_0002.jpg -------------------------------------------------------------------------------- /tests/test_data/2008_000033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/2008_000033.jpg -------------------------------------------------------------------------------- /tests/test_data/DoublePeakClass 0 Mean 1.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 0 Mean 1.npy -------------------------------------------------------------------------------- /tests/test_data/DoublePeakClass 0 Mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 0 Mean.npy -------------------------------------------------------------------------------- /tests/test_data/DoublePeakClass 1 Mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 1 Mean.npy -------------------------------------------------------------------------------- /tests/test_data/DoublePeakClass 2 Mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 2 Mean.npy -------------------------------------------------------------------------------- /tests/test_data/ILSVRC2012_val_00047302.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/ILSVRC2012_val_00047302.JPEG -------------------------------------------------------------------------------- /tests/test_data/TCGA_DU_7018_19911220_14.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/TCGA_DU_7018_19911220_14.tif -------------------------------------------------------------------------------- 
/tests/test_data/bike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/bike.jpg -------------------------------------------------------------------------------- /tests/test_data/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/dog.jpg -------------------------------------------------------------------------------- /tests/test_data/dog_hide.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/dog_hide.jpg -------------------------------------------------------------------------------- /tests/test_data/ladybird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/ladybird.jpg -------------------------------------------------------------------------------- /tests/test_data/lizard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/lizard.jpg -------------------------------------------------------------------------------- /tests/test_data/peacock.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/peacock.jpg -------------------------------------------------------------------------------- /tests/test_data/positive193.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/positive193.npy -------------------------------------------------------------------------------- /tests/test_data/rex-test-all-config.toml: -------------------------------------------------------------------------------- 1 | [rex] 2 | # masking value for mutations, can be either an integer, float or 3 | # one of the following built-in occlusions 'spectral', 'min', 'mean' 4 | mask_value = "mean" 5 | 6 | # random seed, only set for reproducibility 7 | seed = 42 8 | 9 | # whether to use gpu or not, defaults to true 10 | gpu = false 11 | 12 | # batch size for the model 13 | batch_size = 32 14 | 15 | [rex.onnx] 16 | # means for min-max normalization 17 | means = [0.485, 0.456, 0.406] 18 | 19 | # standard devs for min-max normalization 20 | stds = [0.229, 0.224, 0.225] 21 | 22 | # binary model confidence threshold. Anything >= threshold will be classified as 1, otherwise 0 23 | binary_threshold = 0.5 24 | 25 | # norm 26 | norm = 1.0 27 | 28 | [rex.visual] 29 | 30 | # include classification and confidence information in title of plot, defaults to true 31 | info = false 32 | 33 | # pretty printing colour for explanations, defaults to 200 34 | colour = 150 35 | 36 | # alpha blend for main image, defaults to 0.2 (PIL Image.blend parameter) 37 | alpha = 0.1 38 | 39 | # produce unvarnished image with actual masking value, defaults to false 40 | raw = true 41 | 42 | # resize the explanation to the size of the original image.
This uses cubic interpolation and will not be as visually accurate as not resizing, defaults to false 43 | resize = true 44 | 45 | # whether to show progress bar in the terminal, defaults to true 46 | progress_bar = false 47 | 48 | # overlay a 10*10 grid on an explanation, defaults to false 49 | grid = true 50 | 51 | # mark quickshift segmentation on image 52 | mark_segments = true 53 | 54 | # matplotlib colourscheme for responsibility map plotting, defaults to 'magma' 55 | heatmap_colours = 'viridis' 56 | 57 | # multi_style explanations, either "separate" or "composite" 58 | multi_style = "separate" 59 | 60 | [causal] 61 | # maximum depth of tree, defaults to 10, note that search can actually go beyond this number on occasion, as the 62 | # check only occurs at the end of an iteration 63 | tree_depth = 5 64 | 65 | # limit on number of combinations to consider, defaults to none. 66 | # It is **not** the total work done by ReX over all iterations. Leaving the search limit at none 67 | # can potentially be very expensive. 68 | search_limit = 1000 69 | 70 | # number of times to run the algorithm, defaults to 20 71 | iters = 30 72 | 73 | # minimum child size, in pixels 74 | min_box_size = 20 75 | 76 | # remove passing mutants which have a confidence less than this value. Defaults to 0.0 (meaning all mutants are considered) 77 | confidence_filter = 0.5 78 | 79 | # whether to weight responsibility by prediction confidence, defaults to false 80 | weighted = true 81 | 82 | # queue_style = "intersection" | "area" | "all" | "dc", defaults to "area" 83 | queue_style = "intersection" 84 | 85 | # maximum number of things to hold in search queue, either an integer or 'all' 86 | queue_len = 2 87 | 88 | # concentrate: weight responsibility by size and depth of passing partition. Defaults to false 89 | concentrate = true 90 | 91 | [causal.distribution] 92 | # distribution for splitting the box, defaults to uniform.
Possible choices are 'uniform' | 'binom' | 'betabinom' | 'adaptive' 93 | distribution = 'betabinom' 94 | 95 | # blend one of the above distributions with the responsibility map treated as a distribution 96 | blend = 0.5 97 | 98 | # supplemental arguments for distribution creation, these are ignored if the distribution does not take any parameters 99 | distribution_args = [1.1, 1.1] 100 | 101 | [explanation] 102 | # iterate through pixel ranking in chunks, defaults to causal.min_box_size 103 | chunk_size = 16 104 | 105 | [explanation.spatial] 106 | # initial search radius 107 | spatial_initial_radius = 20 108 | 109 | # increment to change radius 110 | spatial_radius_eta = 0.1 111 | 112 | # number of times to expand before quitting, defaults to 4 113 | no_expansions = 1 114 | 115 | [explanation.multi] 116 | # multi method (just spotlight so far) 117 | strategy = 'spotlight' 118 | 119 | # no of spotlights to launch 120 | spotlights = 5 121 | 122 | # default size of spotlight 123 | spotlight_size = 10 124 | 125 | # decrease spotlight by this amount 126 | spotlight_eta = 0.5 127 | 128 | spotlight_step = 10 129 | 130 | # maximum number of random steps that a spotlight can make before quitting 131 | max_spotlight_budget = 30 132 | 133 | # objective function for spotlight search. Possible options 'mean' | 'max' | "none", defaults to "none" 134 | spotlight_objective_function = 'mean' 135 | 136 | # how much overlap to allow between different explanations.
This is the dice coefficient, so 137 | # 0.0 means no permitted overlap, and 1 total overlap permitted 138 | permitted_overlap = 0.1 139 | 140 | [explanation.evaluation] 141 | 142 | # insertion/deletion curve step size 143 | insertion_step = 50 144 | 145 | # normalise insertion/deletion curves by confidence of original data 146 | normalise_curves = false 147 | -------------------------------------------------------------------------------- /tests/test_data/spectrum_class_DNA.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/spectrum_class_DNA.npy -------------------------------------------------------------------------------- /tests/test_data/spectrum_class_noDNA.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/spectrum_class_noDNA.npy -------------------------------------------------------------------------------- /tests/test_data/starfish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/starfish.jpg -------------------------------------------------------------------------------- /tests/test_data/tennis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/tennis.jpg -------------------------------------------------------------------------------- /tests/test_data/testimage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/testimage.png 
-------------------------------------------------------------------------------- /tests/unit_tests/box_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | 4 | from rex_xai.mutants.box import box_dimensions 5 | 6 | 7 | def test_data(data_3d, data_2d): 8 | assert data_3d.model_shape == [1, 64, 64, 64] 9 | assert data_3d.mode == "voxel" 10 | assert data_2d.model_shape == [1, 64, 64] 11 | assert data_2d.mode == "spectral" 12 | 13 | def test_initialise_tree_3d(box_3d): 14 | assert box_3d.depth_start == 0 15 | assert box_3d.depth_stop == 64 16 | 17 | assert box_dimensions(box_3d) == (0, 64, 0, 64, 0, 64) 18 | 19 | assert ( 20 | box_3d.__repr__() 21 | == "Box < name: R, row_start: 0, row_stop: 64, col_start: 0, col_stop: 64, depth_start: 0, depth_stop: 64, volume: 262144" 22 | ) 23 | 24 | assert box_3d.shape() == (64, 64, 64) 25 | assert box_3d.area() == 262144 26 | assert box_3d.corners() == (0, 64, 0, 64, 0, 64) 27 | 28 | 29 | def test_initialise_tree_2d(box_2d): 30 | # Depth does not exist as attribute for 2D boxes 31 | assert box_2d.depth_start is None 32 | assert box_2d.depth_stop is None 33 | 34 | assert box_dimensions(box_2d) == (0, 64, 0, 64) 35 | 36 | assert ( 37 | box_2d.__repr__() 38 | == "Box < name: R, row_start: 0, row_stop: 64, col_start: 0, col_stop: 64, area: 4096" 39 | ) 40 | 41 | assert box_2d.shape() == (64, 64) 42 | assert box_2d.area() == 4096 43 | 44 | assert ( 45 | box_2d.corners() == (0, 64, 0, 64) 46 | ) # TODO: Box_dimensions has the same functionality as corners, should we remove one of them. 
47 | 48 | def test_spawn_children_3d(box_3d, resp_map_3d): 49 | # Set seed 50 | np.random.seed(24) 51 | children_3d = box_3d.spawn_children(min_size=20, mode="voxel", map=resp_map_3d) 52 | assert len(children_3d) == 4 53 | assert children_3d[0].area() < 262144 54 | 55 | total_area_3d = ( 56 | children_3d[0].area() 57 | + children_3d[1].area() 58 | + children_3d[2].area() 59 | + children_3d[3].area() 60 | ) 61 | assert total_area_3d == 262144 62 | 63 | volumes = [1536, 2560, 96768, 161280] 64 | 65 | # Check splitting of boxes 66 | row_split, col_split = 1, 24 67 | 68 | row_starts = [0, 0, row_split, row_split] 69 | row_stops = [row_split, row_split, 64, 64] 70 | col_starts = [0, col_split, 0, col_split] 71 | col_stops = [col_split, 64, col_split, 64] 72 | 73 | for i in range(4): 74 | assert children_3d[i].name == f"R:{i}" 75 | # Not split through depth 76 | assert children_3d[i].depth_start == 0 77 | assert children_3d[i].depth_stop == 64 78 | 79 | assert children_3d[i].area() == volumes[i] 80 | 81 | assert children_3d[i].row_start == row_starts[i] 82 | assert children_3d[i].row_stop == row_stops[i] 83 | assert children_3d[i].col_start == col_starts[i] 84 | assert children_3d[i].col_stop == col_stops[i] 85 | 86 | 87 | 88 | def test_spawn_children_2d(box_2d, resp_map_2d): 89 | # Set seed 90 | np.random.seed(24) 91 | children_2d = box_2d.spawn_children(min_size=2, mode="RGB", map=resp_map_2d) 92 | assert len(children_2d) == 4 93 | assert children_2d[0].area() < 4096 94 | 95 | total_area_2d = ( 96 | children_2d[0].area() 97 | + children_2d[1].area() 98 | + children_2d[2].area() 99 | + children_2d[3].area() 100 | ) 101 | assert total_area_2d == 4096 102 | print(children_2d) 103 | areas = [210, 174, 2030, 1682] 104 | 105 | # Check splitting of boxes 106 | row_split, col_split = 6, 35 107 | 108 | row_starts = [0, 0, row_split, row_split] 109 | row_stops = [row_split, row_split, 64, 64] 110 | col_starts = [0, col_split, 0, col_split] 111 | col_stops = [col_split, 64, 
col_split, 64] 112 | 113 | for i in range(4): 114 | assert children_2d[i].name == f"R:{i}" 115 | assert children_2d[i].area() == areas[i] 116 | 117 | assert children_2d[i].row_start == row_starts[i] 118 | assert children_2d[i].row_stop == row_stops[i] 119 | assert children_2d[i].col_start == col_starts[i] 120 | assert children_2d[i].col_stop == col_stops[i] 121 | 122 | 123 | -------------------------------------------------------------------------------- /tests/unit_tests/cmd_args_test.py: -------------------------------------------------------------------------------- 1 | from types import ModuleType 2 | 3 | import pytest 4 | from rex_xai.utils._utils import Strategy 5 | from rex_xai.input.config import CausalArgs, cmdargs_parser, process_cmd_args, shared_args 6 | 7 | 8 | @pytest.fixture 9 | def non_default_cmd_args(): 10 | args_list = [ 11 | "filename.jpg", 12 | "--output", 13 | "output_path.jpg", 14 | "--config", 15 | "path/to/rex.toml", 16 | "--processed", 17 | "--script", 18 | "tests/scripts/pytorch_resnet50.py", 19 | "-vv", 20 | "--surface", 21 | "surface_path.jpg", 22 | "--heatmap", 23 | "heatmap_path.jpg", 24 | "--model", 25 | "path/to/model.onnx", 26 | "--strategy", 27 | "multi", 28 | "--database", 29 | "path/to/database.db", 30 | "--multi", 31 | "5", 32 | "--iters", 33 | "10", 34 | "--analyse", 35 | "--mode", 36 | "RGB", 37 | ] 38 | parser = cmdargs_parser() 39 | cmd_args = parser.parse_args(args_list) 40 | 41 | return cmd_args 42 | 43 | 44 | def test_process_cmd_args(non_default_cmd_args): 45 | args = CausalArgs() 46 | process_cmd_args(non_default_cmd_args, args) 47 | 48 | assert isinstance(args.script, ModuleType) 49 | assert args.path == non_default_cmd_args.filename 50 | assert args.strategy == Strategy.MultiSpotlight 51 | assert args.iters == int(non_default_cmd_args.iters) 52 | assert args.analyse 53 | assert args.spotlights == int(non_default_cmd_args.multi) 54 | 55 | 56 | def test_process_shared_args(non_default_cmd_args): 57 | args = CausalArgs() 
58 | shared_args(non_default_cmd_args, args) 59 | 60 | assert args.config_location == non_default_cmd_args.config 61 | assert args.model == non_default_cmd_args.model 62 | assert args.surface == non_default_cmd_args.surface 63 | assert args.heatmap == non_default_cmd_args.heatmap 64 | assert args.output == non_default_cmd_args.output 65 | assert args.verbosity == non_default_cmd_args.verbose 66 | assert args.db == non_default_cmd_args.database 67 | assert args.mode == non_default_cmd_args.mode 68 | assert args.processed == non_default_cmd_args.processed 69 | 70 | 71 | def test_quiet_overrides_verbose(): 72 | cmd_args_list = ["filename.jpg", "-vv", "--quiet"] 73 | parser = cmdargs_parser() 74 | cmd_args = parser.parse_args(cmd_args_list) 75 | args = CausalArgs() 76 | shared_args(cmd_args, args) 77 | 78 | assert args.verbosity == 0 79 | 80 | 81 | def test_contrastive(): 82 | cmd_args_list = ["filename.jpg", "--contrastive", "5"] 83 | parser = cmdargs_parser() 84 | cmd_args = parser.parse_args(cmd_args_list) 85 | args = CausalArgs() 86 | process_cmd_args(cmd_args, args) 87 | 88 | assert args.strategy == Strategy.Contrastive 89 | assert args.spotlights == int(cmd_args.contrastive) 90 | 91 | 92 | def test_spectral(): 93 | cmd_args_list = ["filename.jpg", "--spectral"] 94 | parser = cmdargs_parser() 95 | cmd_args = parser.parse_args(cmd_args_list) 96 | args = CausalArgs() 97 | shared_args(cmd_args, args) 98 | 99 | assert args.mode == "spectral" 100 | -------------------------------------------------------------------------------- /tests/unit_tests/config_test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pytest 4 | from rex_xai.utils._utils import Queue, Strategy 5 | from rex_xai.input.config import CausalArgs, process_config_dict, read_config_file 6 | from rex_xai.mutants.distributions import Distribution 7 | 8 | 9 | @pytest.fixture 10 | def non_default_args(): 11 | non_default_args = CausalArgs() 12 
| # rex 13 | non_default_args.mask_value = "mean" 14 | non_default_args.seed = 42 15 | non_default_args.gpu = False 16 | non_default_args.batch_size = 32 17 | # rex.onnx 18 | non_default_args.means = [0.485, 0.456, 0.406] 19 | non_default_args.stds = [0.229, 0.224, 0.225] 20 | non_default_args.binary_threshold = 0.5 21 | non_default_args.norm = 1.0 22 | # rex.visual 23 | non_default_args.info = False 24 | non_default_args.colour = 150 25 | non_default_args.alpha = 0.1 26 | non_default_args.raw = True 27 | non_default_args.resize = True 28 | non_default_args.progress_bar = False 29 | non_default_args.grid = True 30 | non_default_args.mark_segments = True 31 | non_default_args.heatmap_colours = "viridis" 32 | non_default_args.multi_style = "separate" 33 | # causal 34 | non_default_args.tree_depth = 5 35 | non_default_args.search_limit = 1000 36 | non_default_args.iters = 30 37 | non_default_args.min_box_size = 20 38 | non_default_args.confidence_filter = 0.5 39 | non_default_args.weighted = True 40 | non_default_args.queue_style = Queue.Intersection 41 | non_default_args.queue_len = 2 42 | non_default_args.concentrate = True 43 | # causal.distribution 44 | non_default_args.distribution = Distribution.BetaBinomial 45 | non_default_args.blend = 0.5 46 | non_default_args.distribution_args = [1.1, 1.1] 47 | # explanation 48 | non_default_args.chunk_size = 16 49 | non_default_args.spatial_initial_radius = 20 50 | non_default_args.spatial_radius_eta = 0.1 51 | non_default_args.no_expansions = 1 52 | # explanation.multi 53 | non_default_args.strategy = Strategy.MultiSpotlight 54 | non_default_args.spotlights = 5 55 | non_default_args.spotlight_size = 10 56 | non_default_args.spotlight_eta = 0.5 57 | non_default_args.spotlight_step = 10 58 | non_default_args.max_spotlight_budget = 30 59 | non_default_args.spotlight_objective_function = "mean" 60 | non_default_args.permitted_overlap = 0.1 61 | # explanation.evaluation 62 | non_default_args.insertion_step = 50 63 | 
non_default_args.normalise_curves = False 64 | 65 | return non_default_args 66 | 67 | 68 | def test_process_config_dict(non_default_args): 69 | args = CausalArgs() 70 | config_dict = read_config_file("tests/test_data/rex-test-all-config.toml") 71 | 72 | process_config_dict(config_dict, args) 73 | 74 | assert vars(args) == vars(non_default_args) 75 | 76 | 77 | def test_process_config_dict_empty(): 78 | args = CausalArgs() 79 | config_dict = {} 80 | orig_args = copy.deepcopy(args) 81 | 82 | process_config_dict(config_dict, args) 83 | 84 | assert vars(args) == vars(orig_args) 85 | 86 | 87 | def test_process_config_dict_invalid_arg(caplog): 88 | args = CausalArgs() 89 | config_dict = {"explanation": {"chunk": 10}} 90 | 91 | process_config_dict(config_dict, args) 92 | assert ( 93 | caplog.records[0].message == "Invalid or misplaced parameter 'chunk', skipping!" 94 | ) 95 | 96 | 97 | def test_process_config_dict_invalid_distribution(caplog): 98 | args = CausalArgs() 99 | config_dict = { 100 | "causal": {"distribution": {"distribution": "an-invalid-distribution"}} 101 | } 102 | 103 | process_config_dict(config_dict, args) 104 | 105 | assert args.distribution == Distribution.Uniform 106 | assert ( 107 | caplog.records[0].message 108 | == "Invalid distribution 'an-invalid-distribution', reverting to default value Distribution.Uniform" 109 | ) 110 | 111 | 112 | def test_process_config_dict_uniform_distribution(): 113 | args = CausalArgs() 114 | config_dict = { 115 | "causal": { 116 | "distribution": {"distribution": "uniform", "distribution_args": [0.0, 0.0]} 117 | } 118 | } 119 | 120 | process_config_dict(config_dict, args) 121 | 122 | assert args.distribution == Distribution.Uniform 123 | assert args.distribution_args is None 124 | 125 | 126 | def test_process_config_dict_distribution_args(): 127 | args = CausalArgs() 128 | config_dict = { 129 | "causal": { 130 | "distribution": { 131 | "distribution": "betabinom", 132 | "distribution_args": [0.0, 0.0], 133 | } 134 | } 135 
| } 136 | 137 | process_config_dict(config_dict, args) 138 | 139 | assert args.distribution == Distribution.BetaBinomial 140 | assert args.distribution_args == [0.0, 0.0] 141 | 142 | 143 | def test_process_config_dict_queue_style(): 144 | args = CausalArgs() 145 | config_dict = {"causal": {"queue_style": "all"}} 146 | 147 | process_config_dict(config_dict, args) 148 | assert args.queue_style == Queue.All 149 | 150 | 151 | def test_process_config_dict_queue_style_upper(): 152 | args = CausalArgs() 153 | config_dict = {"causal": {"queue_style": "ALL"}} 154 | 155 | process_config_dict(config_dict, args) 156 | assert args.queue_style == Queue.All 157 | 158 | 159 | def test_process_config_dict_queue_style_invalid(caplog): 160 | args = CausalArgs() 161 | config_dict = {"causal": {"queue_style": "an-invalid-queue-style"}} 162 | 163 | process_config_dict(config_dict, args) 164 | assert args.queue_style == Queue.Area 165 | assert ( 166 | caplog.records[0].message 167 | == "Invalid queue style 'an-invalid-queue-style', reverting to default value Queue.Area" 168 | ) 169 | 170 | 171 | def test_process_config_dict_strategy(): 172 | args = CausalArgs() 173 | config_dict = {"explanation": {"multi": {"strategy": "spotlight"}}} 174 | 175 | process_config_dict(config_dict, args) 176 | assert args.strategy == Strategy.MultiSpotlight 177 | 178 | 179 | def test_process_config_dict_strategy_invalid(caplog): 180 | args = CausalArgs() 181 | config_dict = {"explanation": {"multi": {"strategy": "an-invalid-strategy"}}} 182 | 183 | process_config_dict(config_dict, args) 184 | assert ( 185 | caplog.records[0].message 186 | == "Invalid strategy 'an-invalid-strategy', reverting to default value Strategy.Global" 187 | ) 188 | assert args.strategy == Strategy.Global 189 | -------------------------------------------------------------------------------- /tests/unit_tests/data_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 
import numpy as np 4 | from rex_xai.explanation.rex import load_and_preprocess_data 5 | from rex_xai.input.input_data import Data 6 | 7 | tab = np.arange(0, 1, 999) 8 | voxel = np.random.rand(1, 64, 64, 64) 9 | 10 | 11 | def test_data(): 12 | data = Data(input=tab, model_shape=[1, 999], device="cpu") 13 | assert data.model_shape == [1, 999] 14 | assert data.mode == "spectral" 15 | 16 | 17 | def test_set_mask_value(args_custom, model_shape, cpu_device, caplog): 18 | data = load_and_preprocess_data(model_shape, cpu_device, args_custom) 19 | data.set_mask_value("spectral") 20 | 21 | assert ( 22 | caplog.records[0].msg 23 | == "Mask value 'spectral' can only be used if mode is also 'spectral', using default mask value 0 instead" 24 | ) 25 | assert data.mask_value == 0 26 | 27 | 28 | def test_3D_data(): 29 | data = Data(input=voxel, model_shape=[1, 64, 64, 64], device="cpu") 30 | assert data.model_shape == [1, 64, 64, 64] 31 | assert data.mode == "voxel" 32 | -------------------------------------------------------------------------------- /tests/unit_tests/database_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from rex_xai.output.database import db_to_pandas, initialise_rex_db, update_database 5 | 6 | 7 | @pytest.fixture 8 | def db(tmp_path): 9 | p = tmp_path / "rex.db" 10 | db = initialise_rex_db(p) 11 | return db 12 | 13 | 14 | def test_update_database(exp_extracted, tmp_path): 15 | p = tmp_path / "rex.db" 16 | db = initialise_rex_db(p) 17 | update_database(db, exp_extracted) 18 | assert os.path.exists(p) 19 | assert os.stat(p).st_size > 0 20 | 21 | 22 | def test_update_database_no_target(exp_extracted, db, caplog): 23 | exp_extracted.data.target = None 24 | update_database(db, exp_extracted) 25 | assert caplog.records[0].message == "unable to update database as target is None" 26 | 27 | 28 | def test_update_database_no_exp(exp_extracted, db, caplog): 29 | exp_extracted.final_mask = 
None 30 | update_database(db, exp_extracted) 31 | assert ( 32 | caplog.records[0].message == "unable to update database as explanation is empty" 33 | ) 34 | 35 | 36 | def test_read_db(exp_extracted, tmp_path): 37 | p = tmp_path / "rex.db" 38 | db = initialise_rex_db(p) 39 | update_database(db, exp_extracted) 40 | df = db_to_pandas(p) 41 | assert df.shape == (1, 30) 42 | 43 | 44 | def test_no_multi(exp_extracted, caplog): 45 | update_database(db, exp_extracted, multi=True) 46 | assert ( 47 | caplog.records[0].message 48 | == "unable to update database, multi=True is only valid for MultiExplanation objects" 49 | ) 50 | 51 | 52 | def test_update_database_multiexp(exp_multi, tmp_path): 53 | p = tmp_path / "rex.db" 54 | db = initialise_rex_db(p) 55 | update_database(db, exp_multi) 56 | assert os.path.exists(p) 57 | assert os.stat(p).st_size > 0 58 | 59 | 60 | def test_read_database_multiexp(exp_multi, tmp_path): 61 | p = tmp_path / "rex.db" 62 | db = initialise_rex_db(p) 63 | update_database(db, exp_multi, multi=True) 64 | 65 | df = db_to_pandas(p) 66 | assert df.shape == (len(exp_multi.explanations), 30) 67 | -------------------------------------------------------------------------------- /tests/unit_tests/multi-explanation_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import pytest 6 | from rex_xai.utils._utils import Strategy 7 | from rex_xai.explanation.rex import calculate_responsibility 8 | from rex_xai.explanation.explanation import Explanation 9 | from rex_xai.explanation.multi_explanation import MultiExplanation 10 | 11 | 12 | @pytest.mark.parametrize("spotlights", [5, 10]) 13 | def test_multiexplanation(data_multi, args_multi, prediction_func, spotlights, caplog): 14 | args_multi.spotlights = spotlights 15 | 16 | maps, run_stats = calculate_responsibility(data_multi, args_multi, prediction_func) 17 | 18 | exp = Explanation(maps, prediction_func, 
data_multi, args_multi, run_stats) 19 | exp.extract(method=Strategy.Global) 20 | 21 | multi_exp = MultiExplanation( 22 | maps, prediction_func, data_multi, args_multi, run_stats 23 | ) 24 | caplog.set_level(logging.INFO) 25 | multi_exp.extract(Strategy.MultiSpotlight) 26 | 27 | n_exp = 0 28 | for record in caplog.records: 29 | print(record) 30 | if "found an explanation" in record.message: 31 | n_exp += 1 32 | 33 | assert ( 34 | caplog.records[-1].message 35 | == f"ReX has found a total of {n_exp} explanations via spotlight search" 36 | ) 37 | assert n_exp == len(multi_exp.explanations) 38 | assert len(multi_exp.explanations) <= spotlights # always true 39 | assert ( 40 | len(multi_exp.explanations) == spotlights 41 | ) # not always true but is for this data/parameters 42 | assert np.array_equal( 43 | multi_exp.explanations[0].detach().cpu().numpy(), exp.final_mask 44 | ) # first explanation is global explanation 45 | 46 | 47 | def test_multiexplanation_save_composite(exp_multi, tmp_path): 48 | clauses = exp_multi.separate_by(0.0) 49 | 50 | p = tmp_path / "exp.png" 51 | exp_multi.save(path=p, multi_style="composite", clauses=None) 52 | 53 | assert os.path.exists(p) 54 | assert os.stat(p).st_size > 0 55 | 56 | for c in clauses: 57 | exp_multi.save(path=p, multi_style="composite", clauses=c) 58 | clause_path = tmp_path / f"exp_{c}.png" 59 | assert os.path.exists(clause_path) 60 | assert os.stat(clause_path).st_size > 0 61 | 62 | 63 | def test_multiexplanation_save_separate(exp_multi, tmp_path): 64 | p = tmp_path / "exp.png" 65 | exp_multi.save(path=p, multi_style="separate") 66 | 67 | for i in range(len(exp_multi.explanations)): 68 | exp_path = tmp_path / f"exp_{i}.png" 69 | assert os.path.exists(exp_path) 70 | assert os.stat(exp_path).st_size > 0 71 | -------------------------------------------------------------------------------- /tests/unit_tests/preprocessing_test.py: -------------------------------------------------------------------------------- 1 | import 
import pytest
from rex_xai.input.config import validate_args
from rex_xai.explanation.rex import (
    predict_target,
    try_preprocess,
)


def test_preprocess_nii_notimplemented(args, model_shape, cpu_device, caplog):
    """Nifti input returns NotImplemented and logs an explanatory message."""
    args.path = "tests/test_data/dog.nii"
    data = try_preprocess(args, model_shape, device=cpu_device)

    assert data == NotImplemented
    assert caplog.records[0].msg == "we do not (yet) handle nifti files generically"


def test_predict_target(data, prediction_func):
    """The fixture model classifies the fixture image as class 207."""
    target = predict_target(data, prediction_func)

    assert target.classification == 207
    assert target.confidence == pytest.approx(0.253237, abs=2.5e-6)


def test_validate_args(args):
    """A missing input path is rejected with FileNotFoundError."""
    args.path = None  # type: ignore
    with pytest.raises(FileNotFoundError):
        validate_args(args)


def test_preprocess_rgba(args, model_shape, prediction_func, cpu_device, caplog):
    """RGBA images are converted to 3-channel RGB, with a log message."""
    args.path = "assets/rex_logo.png"
    data = try_preprocess(args, model_shape, device=cpu_device)
    predict_target(data, prediction_func)

    assert caplog.records[0].msg == "RGBA input image provided, converting to RGB"
    assert data.mode == "RGB"
    assert data.input.mode == "RGB"
    assert data.data is not None
    assert data.data.shape[1] == 3  # batch, channels, height, width


# --------------------------------------------------------------------------------
# tests/unit_tests/validate_args_test.py
# --------------------------------------------------------------------------------
import pytest
from rex_xai.utils._utils import ReXTomlError
from rex_xai.input.config import CausalArgs, validate_args


def test_no_path(args):
    """A missing input path is rejected with FileNotFoundError."""
    args.path = None  # type: ignore
    with pytest.raises(FileNotFoundError):
        validate_args(args)


def test_blend_invalid(caplog):
    """blend must lie in [0.0, 1.0]."""
    args = CausalArgs()
    args.blend = 20
    with pytest.raises(ReXTomlError):
        validate_args(args)
assert ( 18 | caplog.records[0].message 19 | == "Invalid value '20': must be between 0.0 and 1.0" 20 | ) 21 | 22 | 23 | def test_permitted_overlap_invalid(caplog): 24 | args = CausalArgs() 25 | args.permitted_overlap = -5 26 | with pytest.raises(ReXTomlError): 27 | validate_args(args) 28 | assert ( 29 | caplog.records[0].message 30 | == "Invalid value '-5': must be between 0.0 and 1.0" 31 | ) 32 | 33 | 34 | def test_iters_invalid(caplog): 35 | args = CausalArgs() 36 | args.iters = 0 37 | with pytest.raises(ReXTomlError): 38 | validate_args(args) 39 | assert caplog.records[0].message == "Invalid value '0': must be more than 0.0" 40 | 41 | 42 | def test_raw_invalid(caplog): 43 | args = CausalArgs() 44 | args.raw = 100 # type: ignore 45 | with pytest.raises(ReXTomlError): 46 | validate_args(args) 47 | assert caplog.records[0].message == "Invalid value '100': must be boolean" 48 | 49 | 50 | def test_multi_style_invalid(caplog): 51 | args = CausalArgs() 52 | args.multi_style = "an-invalid-style" 53 | with pytest.raises(ReXTomlError): 54 | validate_args(args) 55 | assert ( 56 | caplog.records[0].message 57 | == "Invalid value 'an-invalid-style' for multi_style, must be 'composite' or 'separate'" 58 | ) 59 | 60 | 61 | def test_queue_len_invalid(caplog): 62 | args = CausalArgs() 63 | args.queue_len = 7.5 # type: ignore 64 | with pytest.raises(ReXTomlError): 65 | validate_args(args) 66 | assert ( 67 | caplog.records[0].message 68 | == "Invalid value '7.5' for queue_len, must be 'all' or an integer" 69 | ) 70 | 71 | 72 | def test_distribution_args_invalid(caplog): 73 | args = CausalArgs() 74 | args.distribution_args = 1 # type: ignore 75 | with pytest.raises(ReXTomlError): 76 | validate_args(args) 77 | assert caplog.records[0].message == "distribution args must be length 2, not 1" 78 | 79 | args.distribution_args = [0, -1] 80 | with pytest.raises(ReXTomlError): 81 | validate_args(args) 82 | assert ( 83 | caplog.records[0].message 84 | == "All values in distribution args must 
be more than zero" 85 | ) 86 | 87 | 88 | def test_colour_map_invalid(caplog): 89 | args = CausalArgs() 90 | args.heatmap_colours = "RedBlue" 91 | with pytest.raises(ReXTomlError): 92 | validate_args(args) 93 | assert ( 94 | caplog.records[0].message 95 | == "Invalid colourmap 'RedBlue', must be a valid matplotlib colourmap" 96 | ) 97 | -------------------------------------------------------------------------------- /tests/unit_tests/visualisation_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch as tt 4 | from rex_xai.output.visualisation import save_image, voxel_plot 5 | 6 | from rex_xai.input.config import CausalArgs 7 | 8 | 9 | def test_surface(exp_extracted, tmp_path): 10 | p = tmp_path / "surface.png" 11 | exp_extracted.surface_plot(path=p) 12 | 13 | assert os.path.exists(p) 14 | assert os.stat(p).st_size > 0 15 | 16 | 17 | def test_heatmap(exp_extracted, tmp_path): 18 | p = tmp_path / "heatmap.png" 19 | exp_extracted.heatmap_plot(path=p) 20 | 21 | assert os.path.exists(p) 22 | assert os.stat(p).st_size > 0 23 | 24 | 25 | def test_save_exp(exp_extracted, tmp_path): 26 | p = tmp_path / "exp.png" 27 | exp_extracted.save(path=p) 28 | 29 | assert os.path.exists(p) 30 | assert os.stat(p).st_size > 0 31 | 32 | def test_save_image_3d(data_3d): 33 | # Explanation mask for the voxel data - random values of 0s and 1s 34 | explanation = tt.zeros((1, 64, 64, 64), dtype=tt.bool, device="cpu") 35 | explanation[0, 32:64, 32:64, 32:64] = 1 36 | args = CausalArgs() 37 | args.mode = "voxel" 38 | args.output = "test.png" 39 | save_image(explanation, data_3d, args, path=args.output) 40 | assert os.path.exists(args.output) 41 | assert os.path.getsize(args.output) > 0 42 | 43 | os.remove(args.output) 44 | 45 | 46 | def test_voxel_plot(data_3d, resp_map_3d): 47 | args = CausalArgs() 48 | print(data_3d) 49 | # Create a cube in data 50 | voxel_plot(args, resp_map_3d, data_3d, path="test.png") 51 | for i in 
["x", "y", "z"]: 52 | assert os.path.exists(f"test_{i}_slice.png") 53 | assert os.path.getsize(f"test_{i}_slice.png") > 0 54 | os.remove(f"test_{i}_slice.png") 55 | --------------------------------------------------------------------------------