├── .github
└── workflows
│ ├── linting-and-type-check.yml
│ ├── publish.yml
│ ├── test-package-and-comment.yml
│ └── test-python-package.yml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
├── bus_757.png
├── id_curve.png
├── ladybird_301.png
├── ladybird_301_raw.png
├── ladybird_rm.png
├── lizard_subs.png
├── occluded_bus_rm.png
├── peacock.jpg
├── peacock_05.png
├── peacock_84_00.png
├── peacock_84_01.png
├── peacock_84_02.png
├── peacock_comp.png
├── peacock_exp.png
├── rex-assumptions-768x259.png
├── rex-structure-600x129.png
├── rex_logo.png
├── rex_logo.svg
├── spectrum.png
├── tandem.jpg
├── tandem_contrastive.png
├── tandem_exp.png
├── tandem_multi_1.png
├── tandem_multi_2.png
└── tandem_multi_3.png
├── docs
├── Makefile
├── _static
│ └── rex_logo.png
├── background.md
├── command_line.md
├── conf.py
├── config.md
├── contrastive.md
├── index.md
├── index.rst
├── make.bat
├── multiple.md
├── notebooks
│ └── intro.md
└── script.md
├── example.rex.toml
├── poetry.lock
├── pyproject.toml
├── rex_xai
├── __init__.py
├── explanation
│ ├── evaluation.py
│ ├── explanation.py
│ ├── multi_explanation.py
│ └── rex.py
├── input
│ ├── config.py
│ ├── input_data.py
│ └── onnx.py
├── mutants
│ ├── box.py
│ ├── distributions.py
│ ├── mutant.py
│ └── occlusions.py
├── output
│ ├── database.py
│ └── visualisation.py
├── responsibility
│ ├── prediction.py
│ ├── resp_maps.py
│ └── responsibility.py
├── rex_wrapper.py
└── utils
│ ├── _utils.py
│ └── logger.py
├── scripts
├── pytorch.py
├── spectral_pytorch.py
└── timm_script.py
├── shell
└── _rex
└── tests
├── conftest.py
├── scripts
├── pytorch_3d.py
├── pytorch_resnet50.py
└── pytorch_swin_v2_t.py
├── snapshot_tests
├── __snapshots__
│ ├── _explanation_onnx_test.ambr
│ ├── _explanation_test.ambr
│ ├── explanation_test.ambr
│ ├── load_preprocess_test.ambr
│ └── spectral_test.ambr
├── _explanation_onnx_test.py
├── _explanation_test.py
├── explanation_test.py
├── load_preprocess_test.py
└── spectral_test.py
├── test_data
├── 004_0002.jpg
├── 2008_000033.jpg
├── DoublePeakClass 0 Mean 1.npy
├── DoublePeakClass 0 Mean.npy
├── DoublePeakClass 1 Mean.npy
├── DoublePeakClass 2 Mean.npy
├── ILSVRC2012_val_00047302.JPEG
├── TCGA_DU_7018_19911220_14.tif
├── bike.jpg
├── dog.jpg
├── dog_hide.jpg
├── ladybird.jpg
├── lizard.jpg
├── peacock.jpg
├── positive193.npy
├── rex-test-all-config.toml
├── spectrum_class_DNA.npy
├── spectrum_class_noDNA.npy
├── starfish.jpg
├── tennis.jpg
└── testimage.png
└── unit_tests
├── box_test.py
├── cmd_args_test.py
├── config_test.py
├── data_test.py
├── database_test.py
├── multi-explanation_test.py
├── preprocessing_test.py
├── validate_args_test.py
└── visualisation_test.py
/.github/workflows/linting-and-type-check.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install ReX and carry out linting and type checking
2 |
3 | name: Linting and type checking
4 |
5 | on: pull_request
6 |
7 | jobs:
8 | lint:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | fail-fast: false
12 | matrix:
13 | python-version: ["3.13"]
14 |
15 | permissions:
16 | pull-requests: write
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Install poetry
21 | run: pipx install poetry
22 |
23 | - name: Set up Python, install and cache dependencies
24 | uses: actions/setup-python@v5
25 | with:
26 | python-version: ${{matrix.python-version}}
27 | cache: poetry
28 | - run: poetry install --with dev
29 | - run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH
30 |
31 | - name: Lint with Ruff
32 | run: |
33 | ruff check --output-format=github .
34 |
35 | - name: Run pyright with reviewdog
36 | uses: jordemort/action-pyright@e85f3910971e8bd8cec27d8c7235d1f99825e570
37 | with:
38 | github_token: ${{ secrets.GITHUB_TOKEN }}
39 | reporter: github-pr-review
40 |
41 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | test:
13 | name: Install and test package
14 | uses: ./.github/workflows/test-python-package.yml
15 | with:
16 | python-version: "3.13"
17 |
18 | build:
19 | name: Build package
20 | needs:
21 | - test
22 | runs-on: ubuntu-latest
23 | steps:
24 | - name: Checkout code
25 | uses: actions/checkout@v4
26 |
27 | - name: Install poetry
28 | run: pipx install poetry
29 |
30 | - name: Set up Python
31 | uses: actions/setup-python@v5
32 | with:
33 | python-version: "3.13"
34 | cache: poetry
35 |
36 | - name: Package project
37 | run: poetry build
38 |
39 | - name: Store the distribution packages
40 | uses: actions/upload-artifact@v4
41 | with:
42 | name: python-package-distributions
43 | path: dist/
44 |
45 | install_and_check_version:
46 | name: Install and check Version Number
47 | needs:
48 | - build
49 | runs-on: ubuntu-latest
50 | steps:
51 | - name: Download all the dists
52 | uses: actions/download-artifact@v4
53 | with:
54 | name: python-package-distributions
55 | path: dist/
56 |
57 | - name: Install
58 | run: pip install dist/*.tar.gz
59 |
60 | - name: Check version number
61 | run: |
62 | PYTHON_VERSION=`ReX --version`
63 | echo "PYTHON_VERSION=${PYTHON_VERSION}"
64 | GIT_VERSION=$GITHUB_REF_NAME
65 | echo "GIT_VERSION=${GIT_VERSION}" # NB that Github version should have a 'v' prefix
66 | if [ "v$PYTHON_VERSION" != "$GIT_VERSION" ]; then exit 1; fi
67 | echo "VERSION=${GIT_VERSION}" >> $GITHUB_OUTPUT
68 |
69 | pypi-publish:
70 | name: Upload release to PyPI
71 | needs:
72 | - install_and_check_version
73 | runs-on: ubuntu-latest
74 | environment:
75 | name: pypi
76 | url: https://pypi.org/project/rex_xai
77 | permissions:
78 | id-token: write
79 | steps:
80 | - name: Download all the dists
81 | uses: actions/download-artifact@v4
82 | with:
83 | name: python-package-distributions
84 | path: dist/
85 |
86 | - name: Publish package distributions to PyPI
87 | uses: pypa/gh-action-pypi-publish@release/v1
88 |
--------------------------------------------------------------------------------
/.github/workflows/test-package-and-comment.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install ReX and run tests, using the highest and lowest Python versions we support
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Install, run tests, and report test coverage
5 |
6 | on:
7 | pull_request:
8 |
9 |
10 | jobs:
11 | run-tests:
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | python-version: ["3.10", "3.13"]
16 | uses: ./.github/workflows/test-python-package.yml
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 |
20 | coverage_report:
21 | needs:
22 | - run-tests
23 | runs-on: ubuntu-latest
24 | permissions:
25 | pull-requests: write
26 |
27 | steps:
28 | - name: Download coverage report
29 | uses: actions/download-artifact@v4
30 | with:
31 | name: coverage-report-py3.13
32 |
33 | - name: Pytest coverage comment
34 | uses: MishaKav/pytest-coverage-comment@v1.1.52
35 | if: ${{ !github.event.pull_request.head.repo.fork }}
36 | with:
37 | pytest-xml-coverage-path: ./coverage.xml
38 |
--------------------------------------------------------------------------------
/.github/workflows/test-python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install the project and run tests
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Install and run tests
5 |
6 | on:
7 | workflow_call:
8 | inputs:
9 | python-version:
10 | required: true
11 | type: string
12 |
13 | jobs:
14 | test:
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v4
20 |
21 | - name: Install poetry
22 | run: pipx install poetry
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: ${{inputs.python-version}}
28 | cache: poetry
29 |
30 | - name: Install project and dependencies
31 | run: poetry install --with dev
32 |
33 | - name: Update PATH
34 | run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH
35 |
36 | - name: Cache model files for testing
37 | uses: actions/cache@v3
38 | env:
39 | cache-name: cache-model-files
40 | with:
41 | path: ~/.cache/cached_path
42 | key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('tests/conftest.py') }}
43 |
44 | - name: Test with pytest
45 | run: |
46 | pytest --junitxml=pytest.xml --cov-report=xml:coverage.xml --cov=rex_xai tests/unit_tests tests/snapshot_tests
47 |
48 | - name: Upload coverage report
49 | uses: actions/upload-artifact@v4
50 | with:
51 | name: coverage-report-py${{ inputs.python-version }}
52 | path: ./coverage.xml
53 |
54 |
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.log
3 | tmp/
4 | yolo/
5 | typings/
6 | results/
7 | scripts/
8 | dist/
9 | ReX.egg-info/
10 | local/
11 | runs/
14 | benchmarks/
15 | rex.toml
16 | rgb_rex.toml
17 | l_rex.toml
18 | *.org
19 | *.onnx
20 |
21 | test_*.jpg
22 | ReX_*.jpg
23 | causal_*.jpg
24 | driver.py
25 | data_dump
26 | deepcover.toml
27 | *.egg
28 | build
29 | htmlcov
30 | __pycache__/
31 | one_image/
32 | data
33 | *.pickle
34 | *.swp
35 | *.db
36 | *.h5
37 | *.hdf5
38 | *.pt
39 | *.json.gz
40 | .coverage
41 | rex_ai.egg-info/*
42 | *.json
46 |
47 | docs/_build/
48 | .jupyter
49 | .ipynb_checkpoints/
50 | *.ipynb
51 |
52 | .DS_Store
53 | .AppleDouble
54 | .LSOverride
55 |
56 | # Icon must end with two \r
57 | Icon
58 |
59 | # Thumbnails
60 | ._*
61 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file for Sphinx projects
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | # Set the OS, Python version and other tools you might need
8 | build:
9 | os: ubuntu-22.04
10 | tools:
11 | python: "3.12"
12 | jobs:
13 | post_create_environment:
14 | # Install poetry
15 | # https://python-poetry.org/docs/#installing-manually
16 | - python -m pip install poetry
17 | post_install:
18 | # Install package and all dependencies
19 | # https://python-poetry.org/docs/managing-dependencies/#dependency-groups
20 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH poetry install --with dev
21 |
22 | # Build documentation in the "docs/" directory with Sphinx
23 | sphinx:
24 | configuration: docs/conf.py
25 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs
26 | # builder: "dirhtml"
27 | # Fail on all warnings to avoid broken references
28 | # fail_on_warning: true
29 |
30 | # Optionally build your docs in additional formats such as PDF and ePub
31 | # formats:
32 | # - pdf
33 | # - epub
34 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 |
3 | ## Setting up a local development environment
4 |
5 | While you can install ReX through various methods, for development purposes we use [poetry](https://python-poetry.org/) for installation, development dependency management, and building the package. Install poetry following [its installation instructions](https://python-poetry.org/docs/) - it's recommended to install using `pipx`.
6 |
7 | Clone this repo and `cd` into it.
8 |
9 | Install ReX and its dependencies by running `poetry install`.
10 | This will install the versions of the dependencies given in `poetry.lock`.
11 | This ensures that the development environment is consistent for different people.
12 |
13 | The development dependencies (for generating documentation, linting, and running tests) are marked as optional, so to install these you will need to run instead `poetry install --with dev`.
14 |
15 | There are also some additional optional dependencies that are only required for working with 3D data.
16 | You can install these using `poetry install --extras 3D`.
17 |
18 | N.B. that poetry by default creates its own virtual environment for the project.
19 | However if you run `poetry install` in an activated virtual environment, it will detect and respect this.
20 | See the [poetry docs](https://python-poetry.org/docs/basic-usage/#using-your-virtual-environment) for more information.
21 |
22 | ## Testing
23 |
24 | We use [pytest](https://docs.pytest.org/en/stable/index.html) with the [pytest-cov](https://github.com/pytest-dev/pytest-cov) plugin.
25 |
26 | Run the tests by running `pytest`, which will automatically run all files of the form `test_*.py` or `*_test.py` in the current directory and its subdirectories.
27 | Run `pytest --cov=rex_xai tests/` to get a coverage report printed to the terminal.
28 | See [pytest-cov's documentation](https://pytest-cov.readthedocs.io/en/latest/) for additional reporting options.
29 |
30 | As the end-to-end tests which run the whole ReX pipeline can take a while to run, we have split the tests into two sub-directories: `tests/unit_tests/` and `tests/snapshot_tests/`.
31 | During development you may wish to only run the faster unit tests.
32 | You can do this by specifying the directory: `pytest tests/unit_tests/`
33 | Both sets of tests are run by GitHub Actions upon a pull request.
34 |
35 | ### Updating snapshots
36 |
37 | Most of the end-to-end tests which run the whole ReX pipeline are 'snapshot tests' using the [syrupy](https://github.com/syrupy-project/syrupy) package.
38 | These tests involve comparing an object returned by the function under test to a previously saved 'snapshot' of that object.
39 | This can help identify unintentional changes in results that are introduced by new development.
40 | Note that snapshots are based on the text representation of an object, so don't necessarily capture *all* results you may care about.
41 |
42 | If a snapshot test fails, follow the steps below to confirm if the changes are expected or not and update the snapshots:
43 |
44 | * Run `pytest -vv` to see a detailed comparison of changes compared to the snapshot
45 | * Check whether these are expected or not. For example, if you have added an additional parameter in the `CausalArgs` class, you expect that parameter value to be missing from the snapshot.
46 | * If you only see expected differences in the snapshot, you can update the snapshots to match the new results by running `pytest --snapshot-update` and commit the updated files.
47 |
48 | ## Generating documentation with Sphinx
49 |
50 | Docs are automatically built on PRs and on updates to the repo's default branch, and are available at <https://rex-xai.readthedocs.io/en/latest/>.
51 |
52 | To build documentation locally using Sphinx and sphinx-autoapi:
53 |
54 | ```sh
55 | cd docs/
56 | make html
57 | ```
58 |
59 | This will automatically generate documentation based on the code and docstrings and produce html files in `docs/_build/html`.
60 |
61 | ### Docstring style
62 |
63 | We prefer [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) docstrings, and use `sphinx.ext.napoleon` to parse them.
64 |
65 | ### Working with notebooks
66 |
67 | The [introduction to working with ReX interactively](https://rex-xai.readthedocs.io/en/latest/notebooks/intro.html) is written as a Jupyter notebook in markdown format.
68 | We use [MyST-NB](https://myst-nb.readthedocs.io/en/latest/index.html) to compile the notebook into html as part of the documentation.
69 | To more easily work with the notebook locally, you can use [Jupytext](https://jupytext.readthedocs.io/en/latest/) to generate an .ipynb notebook from the markdown file, edit the tutorial, and then convert the edited notebook back into md.
70 |
71 | ```sh
72 | # convert to .ipynb
73 | jupytext docs/notebooks/intro.md --to ipynb
74 | # convert back to markdown
75 | jupytext docs/notebooks/intro.ipynb --to myst
76 | ```
77 |
78 | Markdown format allows much clearer diffs when tracking the notebook with version control, so please don't add the .ipynb files to version control.
79 |
80 | ## Code linting and formatting
81 |
82 | This project uses [ruff](https://docs.astral.sh/ruff/) for code linting and formatting, to ensure a consistent code style and identify issues like unused imports.
83 | Install by running `poetry install --with dev`.
84 |
85 | Run the linter on all files in the current working directory with `ruff check`.
86 | Ruff can automatically fix some issues if you run `ruff check --fix`.
87 |
88 | Run `ruff format` to automatically format all files in the current working directory.
89 | Run `ruff format --diff` to get a preview of any changes that would be made.
90 |
91 | ## Type checking
92 |
93 | We use [Pyright](https://microsoft.github.io/pyright/#/) for type checking.
94 | You can [install](https://microsoft.github.io/pyright/#/installation) the command line tool and/or an extension for your favourite editor.
95 | Upon a pull request, a check is run that identifies Pyright errors/warnings in the lines that have been added in the PR.
96 | A review comment will be left for each change.
97 | Ideally, no new errors/warnings will be introduced in a PR, but this is not an enforced requirement to merge.
98 |
99 | ## GitHub Actions
100 |
101 | We use GitHub Actions to automatically run certain checks upon pull requests, and to automate releasing a new ReX version to PyPI.
102 |
103 | On a pull request, the following workflows run:
104 |
105 | * linting and type checking
106 | * installing the package and running tests (using Python 3.10 and 3.13)
107 | * test coverage is also measured
108 | * the docs are also built by ReadTheDocs (separate from GitHub Actions)
109 |
110 | When a new [release](https://docs.github.com/en/repositories/releasing-projects-on-github/about-releases) is created, the following workflows run:
111 |
112 | * installing the package and running tests (using Python 3.13)
113 | * building the package
114 | * checking that the installed package version matches the release tag
115 | * uploading the release to PyPI
116 |
117 | ## Publishing the package on PyPI
118 |
119 | To publish a new ReX version to PyPI, create a [release](https://docs.github.com/en/repositories/releasing-projects-on-github/about-releases).
120 | This will trigger a set of GitHub Actions workflows, which will run tests, check for version number consistency, and then publish the package to PyPI.
121 |
122 | When creating the release, typically the target branch should be `main`.
123 | The target branch should contain all the commits you want to be included in the new release.
124 |
125 | The release should be associated with a tag that has the form "vX.Y.Z" - note the "v" prefix!
126 | This can be a new tag that is created for the most recent commit at the time of the release, or can be a pre-existing tag.
127 |
128 | Give the release a title - this can just be the version number.
129 |
130 | Write some release notes explaining the changes incorporated in this release.
131 | Github offers the option to [automatically generate release notes](https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes) based on PRs merged in since the last release, which can be a good starting point.
132 | Here is [one possible example](https://gist.github.com/andreasonny83/24c733ae50cadf00fcf83bc8beaa8e6a) of how release notes can be structured, to give some ideas of what to include.
133 |
134 | The release can be saved as a draft.
135 | When you are ready, use the "Publish release" button to publish the release and trigger the Github Actions workflow that will publish it to PyPI.
136 |
137 | If there are any issues with the workflow, it can also be re-run manually.
138 | Navigate to the workflow in the Actions tab of the repo and use the "Run workflow" button to run it manually (after fixing any known issues).
139 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ReX-XAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReX: Causal *R*esponsibility *EX*planations for image classifiers
2 |
3 | 
4 |
5 |
6 |
7 | [](https://rex-xai.readthedocs.io/en/latest/)
8 | [](https://github.com/ReX-XAI/ReX/actions/workflows/test-package-and-comment.yml)
9 | [](https://github.com/ReX-XAI/ReX.jl/blob/main/LICENSE)
10 |
11 |
12 |
13 | ---
14 |
15 | ReX is a causal explainability tool for image classifiers. It also works on tabular and 3D data.
16 |
17 | Given an input image and a classifier, ReX calculates a causal responsibility map across the data and identifies a minimal, sufficient, explanation.
18 |
19 |   
20 |
21 | ReX is black-box, that is, agnostic to the internal structure of the classifier.
22 | ReX finds single explanations, non-contiguous explanations (for partially obscured images), multiple independent explanations, contrastive explanations and lots of other things!
23 | It has a host of options and parameters, allowing you to fine tune it to your data.
24 |
25 | For background information and detailed usage instructions, see our [documentation](https://rex-xai.readthedocs.io/en/latest/).
26 |
27 |
28 |
29 | ## Installation
30 |
31 | ReX can be installed using `pip`.
32 | We recommend creating a virtual environment to install ReX.
33 | ReX has been tested using versions of Python >= 3.10.
34 | The following instructions assume `conda`:
35 |
36 | ```bash
37 | conda create -n rex python=3.13
38 | conda activate rex
39 | pip install rex_xai
40 | ```
41 |
42 | This should install an executable `rex` in your path.
43 |
44 | > **Note:**
45 | >
46 | > By default, `onnxruntime` will be installed.
47 | > If you wish to use a GPU, you should uninstall `onnxruntime` and install `onnxruntime-gpu` instead.
48 | > You can alternatively clone the project and edit the `pyproject.toml` to read "onnxruntime-gpu >= 1.17.0" rather than "onnxruntime >= 1.17.0".
49 |
50 | If you want to use ReX with 3D data, you will need to install some optional extra dependencies:
51 |
52 | ```bash
53 | pip install 'rex_xai[3D]'
54 | ```
55 |
56 |
57 |
58 | ## Feedback
59 |
60 | Bug reports, questions, and suggestions for enhancements are welcome - please [check the GitHub Issues](https://github.com/ReX-XAI/ReX/issues) to see if there is already a relevant issue, or [open a new one](https://github.com/ReX-XAI/ReX/issues/new)!
61 |
62 | ## How to Contribute
63 |
64 | Your contributions are highly valued and welcomed. To get started, please review the guidelines outlined in the [CONTRIBUTING.md](/CONTRIBUTING.md) file. We look forward to your participation!
65 |
--------------------------------------------------------------------------------
/assets/bus_757.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/bus_757.png
--------------------------------------------------------------------------------
/assets/id_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/id_curve.png
--------------------------------------------------------------------------------
/assets/ladybird_301.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_301.png
--------------------------------------------------------------------------------
/assets/ladybird_301_raw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_301_raw.png
--------------------------------------------------------------------------------
/assets/ladybird_rm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/ladybird_rm.png
--------------------------------------------------------------------------------
/assets/lizard_subs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/lizard_subs.png
--------------------------------------------------------------------------------
/assets/occluded_bus_rm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/occluded_bus_rm.png
--------------------------------------------------------------------------------
/assets/peacock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock.jpg
--------------------------------------------------------------------------------
/assets/peacock_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_05.png
--------------------------------------------------------------------------------
/assets/peacock_84_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_00.png
--------------------------------------------------------------------------------
/assets/peacock_84_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_01.png
--------------------------------------------------------------------------------
/assets/peacock_84_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_84_02.png
--------------------------------------------------------------------------------
/assets/peacock_comp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_comp.png
--------------------------------------------------------------------------------
/assets/peacock_exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/peacock_exp.png
--------------------------------------------------------------------------------
/assets/rex-assumptions-768x259.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex-assumptions-768x259.png
--------------------------------------------------------------------------------
/assets/rex-structure-600x129.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex-structure-600x129.png
--------------------------------------------------------------------------------
/assets/rex_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/rex_logo.png
--------------------------------------------------------------------------------
/assets/rex_logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
73 |
--------------------------------------------------------------------------------
/assets/spectrum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/spectrum.png
--------------------------------------------------------------------------------
/assets/tandem.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem.jpg
--------------------------------------------------------------------------------
/assets/tandem_contrastive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_contrastive.png
--------------------------------------------------------------------------------
/assets/tandem_exp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_exp.png
--------------------------------------------------------------------------------
/assets/tandem_multi_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_1.png
--------------------------------------------------------------------------------
/assets/tandem_multi_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_2.png
--------------------------------------------------------------------------------
/assets/tandem_multi_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/assets/tandem_multi_3.png
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/rex_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/docs/_static/rex_logo.png
--------------------------------------------------------------------------------
/docs/background.md:
--------------------------------------------------------------------------------
1 | # Background information
2 |
3 | ReX is a causal explainability tool for image classifiers.
4 | ReX is black-box, that is, agnostic to the internal structure of the classifier.
5 | We assume that we can modify the inputs and send them to the classifier, observing the output.
6 | ReX outperforms other tools on [single explanations](https://www.hanachockler.com/eccv/), [non-contiguous explanations](https://www.hanachockler.com/iccv2021/) (for partially obscured images), and [multiple explanations](http://www.hanachockler.com/multirex/).
7 |
8 | 
9 |
10 | ## Assumptions
11 |
12 | ReX works on the assumption that if we can intervene on the inputs to a model and observe changes in its outputs, we can use this information to reason about the way the DNN makes its decisions.
13 |
14 | 
15 |
16 | ## Presentations about ReX
17 |
18 | * [Attacking your black box classifier with ReX](https://www.hanachockler.com/rex-2/)
19 | * [Causal Explanations For Image Classifiers](https://www.hanachockler.com/hana-chockler-causal-xai-workshop-102023/)
20 |
21 | ## Papers
22 |
23 | 1. [Causal Explanations for Image Classifiers](https://arxiv.org/pdf/2411.08875). Under review. This paper introduces the tool ReX.
24 | 2. [Multiple Different Black Box Explanations for Image Classifiers](http://www.hanachockler.com/multirex/). Under review. This paper introduces MULTI-ReX for multiple explanations.
25 | 3. [3D ReX: Causal Explanations in 3D Neuroimaging Classification](https://arxiv.org/pdf/2502.12181). Presented at [Imageomics-AAAI-25](https://sites.google.com/view/imageomics-aaai-25/home?authuser=0). 3D explanations for neuroimaging.
26 | 4. [Explanations for Occluded Images](http://www.hanachockler.com/iccv2021/). In ICCV’21. This paper introduces causality for image classifier explanations. Note: the tool is called DC-Causal in this paper.
27 | 5. [Explaining Image Classifiers using Statistical Fault Localization](http://www.hanachockler.com/eccv/). In ECCV’20. The first paper on ReX. Note: the tool is called DeepCover in this paper.
28 |
--------------------------------------------------------------------------------
/docs/command_line.md:
--------------------------------------------------------------------------------
1 | # Command line usage
2 |
3 |
4 |
5 | ```{argparse}
6 | :module: rex_xai.config
7 | :func: cmdargs_parser
8 | :prog: ReX
9 | ```
10 |
11 |
12 |
13 | ## Model formats
14 |
15 | ### Onnx
16 |
17 | ReX natively understands onnx files. Train or download a model (e.g. [Resnet50](https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet50-v1-7.onnx)) and, from this directory, run:
18 |
19 | ```bash
20 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --output dog_exp.jpg
21 | ```
22 |
23 | ### Pytorch
24 |
25 | ReX also works with PyTorch, but you will need to write some custom code to provide ReX with the prediction function and model shape, as well as preprocess the input data.
26 | See the sample scripts in `scripts/`.
27 |
28 | ```bash
29 | rex tests/test_data/dog.jpg --script scripts/pytorch.py -vv --output dog_exp.jpg
30 | ```
31 |
32 | ## Saving output in a database
33 |
34 | To store all output in a sqlite database, use:
35 |
36 | ```bash
37 | rex --model -db
38 | ```
39 |
40 | ReX will create the db if it does not already exist.
41 | It will append to any db with the given name, so be careful not to use the same database if you are restarting an experiment.
42 | ReX comes with a number of database functions which allow you to load it as a pandas dataframe for analysis.
43 |
44 | ## Config
45 |
46 | ReX looks for the config file `rex.toml` in the current working directory and then `$HOME/.config/rex.toml` on unix-like systems.
47 |
48 | If you want to use a custom location, use:
49 |
50 | ```bash
51 | rex --model --config
52 | ```
53 |
54 | An example config file is included in the repo as `example.rex.toml`.
55 | Rename this to `rex.toml` if you wish to use it. ReX will ignore the file `example.rex.toml` if you do not rename it.
56 |
57 | ### Overriding the config
58 |
59 | Some options from the config file can be overridden at the command line when calling ReX.
60 | In particular, you can change the number of iterations of the algorithm:
61 |
62 | ```bash
63 | rex --model --iters 5
64 | ```
65 |
66 | ## Preprocessing
67 |
68 | Input data should be transformed in the same way the model's training data was transformed before training.
69 | For PyTorch models, you should specify the preprocessing steps in the custom script.
70 | See the sample scripts in `scripts/` for examples of using models provided by the [torchvision](https://pytorch.org/vision/stable/index.html) and [timm](https://huggingface.co/docs/timm/index) packages.
71 | You can also write a script to load an `onnx` model as well if you have custom preprocessing to do.
72 |
73 | Otherwise, ReX tries to make reasonable guesses for image preprocessing.
74 | This includes resizing the image to match that needed for the model, converting it to a PyTorch tensor, and normalising the data. This can be controlled in `rex.toml`.
75 | In the event that the model input is multi-channel and the image is greyscale, then ReX will convert the image to pseudo-RGB.
76 | If you want more control over the conversion, you can do the conversion yourself and pass in the converted image.
77 |
78 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = "ReX"
10 | copyright = "2024, David Kelly"
11 | author = "David Kelly, Liz Ing-Simmons and other contributors"
12 |
13 | # -- General configuration ---------------------------------------------------
14 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
15 |
16 | extensions = [
17 | "autoapi.extension",
18 | "sphinx.ext.autodoc.typehints",
19 | "sphinx.ext.napoleon",
20 | "sphinx.ext.intersphinx",
21 | 'sphinxarg.ext',
22 | "myst_nb"
23 | ]
24 |
25 | templates_path = ["_templates"]
26 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
27 |
28 | # -- Options for HTML output -------------------------------------------------
29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
30 |
31 | html_theme = "alabaster"
32 | html_sidebars = {
33 | "**": [
34 | "about.html",
35 | "searchfield.html",
36 | "navigation.html",
37 | "relations.html",
38 | "donate.html",
39 | ]
40 | }
41 | html_static_path = ["_static"]
42 | html_theme_options = {
43 | "logo": "rex_logo.png"
44 | }
45 |
46 | # -- AutoAPI -----------------------------------------------------------------
47 | # https://sphinx-autoapi.readthedocs.io/en/latest/
48 | autoapi_dirs = ["../rex_xai/"]
49 | autodoc_typehints = "description"
50 |
51 | # -- Intersphinx --------------------------------------------------------------
52 | # https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html
53 | intersphinx_mapping = {
54 | "python": ("https://docs.python.org/3", None),
55 | "torch": ("https://pytorch.org/docs/stable/", None),
56 | "sqla": ("https://docs.sqlalchemy.org/en/latest/", None),
57 | }
58 |
59 | # -- MyST --------------------------------------------------------------
60 | # https://myst-parser.readthedocs.io/en/latest/
61 | myst_enable_extensions = [
62 | "attrs_inline"
63 | ]
64 |
65 | nb_execution_timeout = 300
66 |
--------------------------------------------------------------------------------
/docs/contrastive.md:
--------------------------------------------------------------------------------
1 | # Contrastive Explanations
2 |
3 | This page describes **contrastive** explanations. These are *necessary* and *sufficient* explanations.
4 | A normal explanation from ReX is only sufficient.
5 |
6 | From the command line, the basic call is
7 |
8 | ```bash
9 | rex --script pytorch.py --contrastive
10 | ```
11 |
12 | ## Example
13 |
14 | This image is a tandem bike
15 |
16 | ```{image} ../assets/tandem.jpg
17 | :alt: Tandem Bike
18 | :align: center
19 | ```
20 |
21 |
22 | If we run
23 |
24 | ```bash
25 | rex tandem.jpg --script ../tests/scripts/pytorch_resnet50.py -vv --output tandem_exp.png
26 | ```
27 |
28 | we get
29 |
30 | ```{image} ../assets/tandem_exp.png
31 | :alt: Tandem Bike Explanation
32 | :scale: 200%
33 | :align: center
34 | ```
35 |
36 |
37 |
38 | Passing the highlighted pixels to the model, against the baseline defined in `rex.toml`, is enough to get the classification `tandem`.
39 | However, if we remove these highlighted pixels and leave the rest alone, we still get `tandem`. Why is that? Because these pixels are *sufficient*
40 | to get `tandem`, but they aren't *necessary*. There must be at least one more, independent, explanation for `tandem` in the image.
41 |
42 | We can try to find other sufficient explanations using `--multi`.
43 |
44 | ```bash
45 | rex tandem.jpg --script ../tests/scripts/pytorch_resnet50.py --multi -vv --output tandem_exp.png
46 | ```
47 | 
48 | 
49 | 
50 |
51 | Each one of these is, by itself, sufficient for `tandem`. By default, `contrastive` uses combinations from the set of found sufficient explanations to
52 | find a combination which is also *necessary*: if we remove these pixels then we no longer have a `tandem`, if we have only these pixels, we have `tandem`.
53 | This explanation is both sufficient to get `tandem` and necessary to get `tandem`.
54 |
55 | ```{image} ../assets/tandem_contrastive.png
56 | :alt: Tandem Bike Contrastive Explanation
57 | :scale: 200%
58 | :align: center
59 | ```
60 |
61 | ## Notes
62 |
63 | `contrastive` uses multiple explanations under the hood and then tests all combinations of discovered explanations. As a result, it can fail to find a contrastive explanation,
64 | especially if there are more sufficient explanations than discovered by `--multi`. In this case, ReX quits with an error message.
65 |
66 | `contrastive` takes an optional `int` at the command line, indicating how many spotlights to launch. This defaults to `10`, just like with multiple explanations.
67 |
68 | The contrastive algorithm is an instance of the [set packing](https://en.wikipedia.org/wiki/Set_packing) problem. As such, in the worst case scenario, it can be quite expensive to compute.
69 |
70 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # ReX: Causal Responsibility Explanations for image classifiers
2 |
3 | **ReX** is a causal explainability tool for image classifiers.
4 | ReX is black-box, that is, agnostic to the internal structure of the classifier.
5 | We assume that we can modify the inputs and send them to the classifier, observing the output.
6 | ReX provides sufficient, minimal single explanations, non-contiguous explanations (for partially obscured images), multiple explanations
7 | and contrastive explanations (sufficient, necessary and minimal).
8 |
9 | ```{image} ../assets/rex-structure-600x129.png
10 | :alt: ReX organisation
11 | :width: 600px
12 | :align: center
13 | ```
14 |
15 | For more information and links to the papers, see the [background](background) page.
16 |
17 | ```{include} ../README.md
18 | :start-after:
19 | :end-before:
20 | ```
21 |
22 | ## Quickstart
23 |
24 | ReX requires as input an image and a model.
25 | ReX natively understands onnx files. Train or download a model (e.g. [Resnet50](https://github.com/onnx/models/blob/main/validated/vision/classification/resnet/model/resnet50-v1-7.onnx)) and, from this directory, run:
26 |
27 | ```bash
28 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --output dog_exp.jpg
29 | ```
30 |
31 | To view an interactive plot for the responsibility map, run:
32 |
33 | ```bash
34 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx -vv --surface
35 | ```
36 |
37 | To save the extracted explanation to a file:
38 |
39 | ```bash
40 | rex tests/test_data/dog.jpg --model resnet50-v1-7.onnx --output dog_exp.jpg
41 | ```
42 |
43 | ReX also works with PyTorch, but you will need to write some custom code to provide ReX with the prediction function and model shape, as well as preprocess the input data.
44 | See the sample scripts in `tests/scripts/`.
45 |
46 | ```bash
47 | rex tests/test_data/dog.jpg --script tests/scripts/pytorch_resnet50.py -vv --output dog_exp.jpg
48 | ```
49 |
50 | Other options:
51 |
52 | ```bash
53 | # with spatial search rather than the default global search
54 | rex --model --strategy spatial
55 |
56 | # to run multiple explanations
57 | rex --model --multi
58 |
59 | # to view a responsibility landscape heatmap
60 | rex --model --heatmap
61 |
62 | # to save a responsibility landscape surface plot
63 | rex --model --surface
64 | ```
65 |
66 | ReX configuration is mainly handled via a config file; some options can also be set on the command line.
67 | ReX looks for the config file `rex.toml` in the current working directory and then `$HOME/.config/rex.toml` on unix-like systems.
68 |
69 | If you want to use a custom location, use:
70 |
71 | ```bash
72 | rex --model --config
73 | ```
74 |
75 | An example config file is included in the repo as `example.rex.toml`.
76 | Rename this to `rex.toml` if you wish to use it.
77 |
78 | ## Command line usage
79 |
80 | ```{include} command_line.md
81 | :start-after:
82 | :end-before:
83 | ```
84 |
85 | ## Examples
86 |
87 | ### Explanation
88 |
89 | An explanation for a ladybird. This explanation was produced with 20 iterations, using the default masking colour (0). The minimal, sufficient explanation itself
90 | is pretty printed using the settings in `[rex.visual]` in `rex.toml`.
91 |
92 |   
93 |
94 | Setting `raw = true` in `rex.toml` produces the image which was actually classified by the model.
95 |
96 | 
97 |
98 | ### Multiple Explanations
99 |
100 | ```bash
101 | rex tests/test_data/peacock.jpg --model resnet50-v1-7.onnx --strategy multi --output peacock.png
102 | ```
103 |
104 | The number of explanations found depends on the model and some of the settings in `rex.toml`.
105 |
106 | {w=200px}   
107 |
108 | ### Occluded Images
109 |
110 | 
111 |
112 | 
113 |
114 | 
115 |
116 | ### Explanation Quality
117 |
118 | ```bash
119 | rex tests/test_data/ladybird.jpg --script tests/scripts/pytorch_resnet50.py --analyse
120 |
121 | INFO:ReX:area 0.000399, entropy difference 6.751189, insertion curve 0.964960, deletion curve 0.046096
122 | ```
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 | ```{toctree}
136 | :maxdepth: 2
137 | :caption: Contents:
138 | background.md
139 | command_line.md
140 | notebooks/intro
141 | config.md
142 | multiple.md
143 | contrastive.md
144 | ```
145 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/docs/index.rst
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/multiple.md:
--------------------------------------------------------------------------------
1 | # Multiple Explanations
2 |
3 | This page describes **multiple** explanations, see [Multiple Different Black Box Explanations for Image Classifiers](http://www.hanachockler.com/multirex/).
4 |
5 | ## Example
6 |
7 | An image classification may have more than one sufficient explanation. Take this image of a peacock
8 |
9 | ```{image} ../assets/peacock.jpg
10 | :alt: Peacock
11 | :align: center
12 | ```
13 |
14 | The global explanation is:
15 |
16 |
17 | ```{image} ../assets/peacock_exp.png
18 | :alt: Peacock Explanation
19 | :scale: 120%
20 | :align: center
21 | ```
22 |
23 | But it's very likely that there's more than one. This small part of the tail is enough to get the classification `peacock`, but there are many
24 | other possible sources of information that match that classification. ReX can try to find them.
25 |
26 | ReX searches the responsibility map for sufficient explanations. It does this by launching `spotlights` which explore the space, using the responsibility
27 | as a guide. How many spotlights are launched is a parameter (by default: 10) and is set with the `--multi` flag; `--multi` takes an optional
28 | integer argument.
29 |
30 | ```bash
31 | rex peacock.jpg --script ../tests/scripts/pytorch_resnet50.py --multi 5 -vv --output peacock_exp.png
32 | ```
33 | we get
34 |
35 | ```{image} ../assets/peacock_comp.png
36 | :alt: Peacock Explanation
37 | :align: center
38 | :scale: 120%
39 | ```
40 |
41 | ReX has found 4 distinct, non-overlapping explanations. The original global explanation is still there, but we also have 3 other explanations.
42 | Two of these explanations (highlighted in white and red respectively) are [*disjoint* explanations](https://arxiv.org/pdf/2411.08875).
43 |
44 | ## Overlap
45 |
46 | The peacock shows 4 non-overlapping explanations, but this is a parameter. We can set the allowed degree of overlap by changing
47 | `permitted_overlap` in the [config](explanation_multi). This sets the [dice coefficient](https://en.wikipedia.org/wiki/Dice-S%C3%B8rensen_coefficient)
48 | of the explanations.
49 |
50 | If we set `permitted_overlap = 0.5`
51 |
52 | ```bash
53 | rex peacock.jpg --script ../tests/scripts/pytorch_resnet50.py --multi 10 -vv --output peacock_exp.png
54 | ```
55 |
56 | ```{image} ../assets/peacock_05.png
57 | :alt: Peacock Explanation Overlap
58 | :align: center
59 | :scale: 120%
60 | ```
61 |
62 | ## Notes
63 | Multi-ReX has many options and parameters, see [config](explanation_multi) for the complete list.
64 |
65 | The `spotlight` requires an objective function to guide its search of the responsibility landscape. By default this is `none`: if
66 | the spotlight fails to find an explanation in one location, it takes a random jump to another. Alternatively, `mean` moves the spotlight
67 | in the direction of the greater mean responsibility.
68 |
69 |
--------------------------------------------------------------------------------
/docs/script.md:
--------------------------------------------------------------------------------
1 | ## Script usage
2 | ReX can take in scripts that define the model behaviour, the preprocessing for the model and how the model's output can be interpreted by ReX. This is to allow the users to provide custom preprocessing/models to ReX.
3 |
4 | As outlined in the [command line section](command_line.md), the user can pass in the script using the `--script` argument.
5 |
6 | ```bash
7 | rex imgs/dog.jpg --script scripts/pytorch.py -vv --output dog_exp.jpg
8 | ```
9 |
10 | ### Contents of the python script
11 | There are three main components to the script:
12 | - A preprocess function which takes in the following parameters and returns a Data object:
13 | - path: The path to the image
14 | - shape: The shape of the model input
15 | - device: The device the data is on e.g. "cuda"
16 | - mode: The mode of the data e.g. "RGB", "L", "voxel"
17 | - A function that calls the model called prediction_function that takes in the following parameters and returns a list of Prediction objects:
18 | - mutants: Mutants created by ReX to run inference on
19 | - target: The target class, default None
20 | - raw: Whether to return the raw output (e.g. the probability of the classification) or not, default False
21 | - binary_threshold: The threshold for binary classification e.g. 0.5, default None
22 | - Model shape function that returns the shape of the model input
23 | - Any other helper functions that are needed for the above functions
24 |
25 | #### Preprocessing function
26 |
27 | The preprocessing function is responsible for loading the image and transforming it to the correct shape for the model.
28 |
29 | ```python
30 | def preprocess(path, shape, device, mode) -> Data:
31 | ```
32 | The **key steps** in the preprocess function are:
33 | - Load the data from the path
34 | - Transform the data to requirements of the model
35 | - Return a Data object
36 |
37 |
38 | The function should return a Data object. The Data object contains the following fields:
39 | - `input` -> The raw input
40 | - `data` -> The transformed input for the model
41 | - `model_shape` -> The shape of the model input
42 | - `device` -> The device the data is on e.g. "cuda"
43 | - `mode` -> The mode of the data e.g. "RGB", "L", "voxel"
44 | - `process` -> A boolean that indicates whether the data mode should be accessed or not
45 | - `model_height` -> The height of the model input
46 | - `model_width` -> The width of the model input
47 | - `model_height` -> The height of the model input
48 | - `model_channels` -> The number of channels in the model input
49 | - `transposed` -> Whether the data is transposed or not
50 | - `model_order` -> The order of the model input e.g. "first" or "last"
51 | - `background` -> The value of the background of the image e.g. 0 or 255 ... etc. For a range of values, use a tuple e.g. (0, 255)
52 | - `context` -> The context of the image e.g. the specific background like a beach or a road that can be used as an occlusion if specified as mask value
53 |
54 | The Data object can be initialised with the `input`, `model_shape`, `device` and optionally the `mode` and `process`.:
55 | Example:
56 | ```python
57 | data = Data(input, model_shape, device, mode="voxel", process=False)
58 | # Set the other attributes of the Data object separately like so
59 | data.model_height = 224
60 | ```
61 |
62 |
63 | #### Prediction function
64 |
65 | The prediction function is responsible for running inference on the model, processing and returning the output.
66 |
67 | ```python
68 | def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
69 | ```
70 |
71 | **Parameters**:
72 | - `mutants` -> A list of mutants to run inference on
73 | - `target` -> The target class
74 | - `raw` -> Whether to return the raw output (e.g. the probability of the classification) or not
75 | - `binary_threshold` -> The threshold for binary classification e.g. 0.5
76 |
77 | **Returns**:
78 | - A list of Prediction objects or a float if raw is True
79 |
80 | The Prediction object contains the following fields:
81 | - `classification` -> The classification of the mutant: Optional[int]
82 | - `confidence` -> The confidence of the classification: Optional[float]
83 | - `bounding_box` -> The bounding box for the classification: Optional[NDArray]
84 | - `target` -> The target class: Optional[int]
85 | - `target_confidence` -> The confidence of the target class: Optional[float]
86 |
87 | #### Model shape function
88 |
89 | The model shape function is responsible for returning the shape of the model input.
90 |
91 | ```python
92 | def model_shape() -> []:
93 | ```
94 | **Example:**
95 | ```python
96 | def model_shape():
97 | return ["N", 3, 224, 224]
98 | ```
99 |
100 | ---
101 | Example scripts can be found in the `tests/scripts` and `scripts` directory.
--------------------------------------------------------------------------------
/example.rex.toml:
--------------------------------------------------------------------------------
1 | [rex]
2 | # masking value for mutations, can be either an integer, float or
3 | # one of the following built-in occlusions 'spectral', 'min', 'mean'
4 | # mask_value = 0
5 |
6 | # random seed, only set for reproducibility
7 | # seed = 42
8 |
9 | # whether to use gpu or not, defaults to true
10 | # gpu = true
11 |
12 | # batch size for the model
13 | # batch_size = 64
14 |
15 | [rex.onnx]
16 | # means for min-max normalization
17 | # means = [0.485, 0.456, 0.406]
18 |
19 | # stds = [0.229, 0.224, 0.225]
20 |
21 | # binary model confidence threshold. Anything >= threshold will be classified as 1, otherwise 0
22 | # binary_threshold = 0.5
23 |
24 | # norm = 255.0
25 |
26 | # intra_op_num_threads = 8
27 |
28 | # inter_op_num_threads = 8
29 |
30 | # ort_logger = 3
31 |
32 | [rex.visual]
33 | # whether to show progress bar in the terminal, defaults to true
34 | # progress_bar = false
35 |
36 | # resize the explanation to the size of the original image. This uses cubic interpolation and will not be as visually accurate as not resizing, defaults to false
37 | # resize = true
38 |
39 | # include classification and confidence information in title of plot, defaults to true
40 | # info = false
41 |
42 | # produce unvarnished image with actual masking value, defaults to false
43 | # raw = false
44 |
45 | # pretty printing colour for explanations, defaults to 200
46 | # colour = 100
47 |
48 | # matplotlib colourscheme for responsibility map plotting, defaults to 'magma'
49 | # heatmap_colours = 'coolwarm'
50 |
51 | # alpha blend for main image, defaults to 0.2 (PIL Image.blend parameter)
52 | # alpha = 0.2
53 |
54 | # overlay a 10*10 grid on an explanation, defaults to false
55 | # grid = false
56 |
57 | # mark quickshift segmentation on image
58 | # mark_segments = false
59 |
60 | # multi_style explanations, either or
61 | # multi_style = "composite"
62 |
63 | [causal]
64 | # maximum depth of tree, defaults to 10, note that search can actually go beyond this number on occasion, as the
65 | # check only occurs at the end of an iteration
66 | # tree_depth = 30
67 |
68 | # limit on number of combinations to consider, defaults to none.
69 | # It is **not** the total work done by ReX over all iterations. Leaving the search limit at none
70 | # can potentially be very expensive.
71 | # search_limit = 1000
72 |
73 | # number of times to run the algorithm, defaults to 20
74 | # iters = 30
75 |
76 | # minimum child size, in pixels
77 | # min_box_size = 10
78 |
79 | # remove passing mutants which have a confidence less than this value. Defaults to 0.0 (meaning all mutants are considered)
80 | # confidence_filter = 0.5
81 |
82 | # whether to weight responsibility by prediction confidence, default to false
83 | # weighted = false
84 |
85 | # queue_style = "intersection" | "area" | "all" | "dc", defaults to "area"
86 | # queue_style = "area"
87 |
88 | # maximum number of things to hold in search queue, either an integer or 'all'
89 | # queue_len = 1
90 |
91 | # concentrate: weight responsibility by tree depth of passing partition. Defaults to false
92 | # concentrate = true
93 |
94 | # subtract responsibility for non-target classifications. Defaults to false
95 | # negative_responsibility = true
96 |
97 | [causal.distribution]
98 | # distribution for splitting the box, defaults to uniform. Possible choices are 'uniform' | 'binom' | 'betabinom' | 'adaptive'
99 | # distribution = 'uniform'
100 |
101 | # blend = 0.5
102 |
103 | # supplemental arguments for distribution creation; these are ignored if the distribution does not take any parameters
104 | # dist_args = [1.1, 1.1]
105 |
106 | [explanation]
107 | # iterate through pixel ranking in chunks, defaults to causal.min_box_size
108 | # chunk_size = 10
109 |
110 | # causal explanations are minimal by definition, but very small explanations might have very low confidence. ReX will keep
111 | # looking for an explanation of confidence greater than or equal to the model confidence on the original input.
112 | # This is especially useful to reduce errors due to floating point imprecision when batching calls to the model.
113 | # Defaults to 0.0, with maximum value 1.0.
114 | # minimum_confidence_threshold = 0.1
115 |
116 | [explanation.spatial]
117 | # initial search radius
118 | # initial_radius = 25
119 |
120 | # increment to change radius
121 | # radius_eta = 0.2
122 |
123 | # number of times to expand before quitting, defaults to 4
124 | # no_expansions = 4
125 |
126 | [explanation.multi]
127 | # multi method (just spotlight so far)
128 | # method = 'spotlight'
129 |
130 | # no of spotlights to launch
131 | # spotlights = 10
132 |
133 | # default size of spotlight
134 | # spotlight_size = 24
135 |
136 | # decrease spotlight by this amount
137 | # spotlight_eta = 0.2
138 |
139 | # maximum number of random steps that a spotlight can make before quitting
140 | # max_spotlight_budget = 40
141 |
142 | # objective function for spotlight search. Possible options 'mean' | 'max' | "none"
143 | # objective_function = 'none'
144 |
145 | # permitted_overlap = 0.5
146 |
147 | [explanation.evaluation]
148 |
149 | # normalise insertion/deletion curves by confidence of original data, defaults to true
150 | # normalise_curves = true
151 |
152 | # insertion/deletion curve step size
153 | # insertion_step = 100
154 |
155 |
156 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "rex-xai"
3 | version = "0.3.1"
4 | description = "causal Responsibility-based eXplanations of black-box-classifiers"
5 | authors = [
6 | { name = "David Kelly", email = "dkellino@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 |
11 | dependencies = [
12 | "numpy==1.26.4",
13 | "scipy>=1.10",
14 | "imutils>=0.5.4",
15 | "toml>=0.10",
16 | "anytree>=2.8.0",
17 | "fastcache>=1.1.0",
18 | "tqdm>=4.65.0",
19 | "sqlalchemy>=2.0.16",
20 | "matplotlib>=3.7.1",
21 | "onnxruntime>=1.18.0",
22 | "scikit-image>=0.21.0",
23 | "pandas>=2.2.0",
24 | "pillow>=10.3.0",
25 | "torch>=2.6.0"
26 | ]
27 |
28 | [project.optional-dependencies]
29 |
30 | 3D = [
31 | "nibabel>=5.2.1",
32 | "kaleido==0.2.1",
33 | "plotly>=5.4.0",
34 | "dash>=2.1.0"
35 | ]
36 |
37 | [project.urls]
 38 | homepage = "https://rex-xai.readthedocs.io/"
39 | repository = "https://github.com/ReX-XAI/ReX"
40 | documentation = "https://rex-xai.readthedocs.io/"
41 | "Bug Tracker" = "https://github.com/ReX-XAI/ReX/issues"
42 |
43 | [tool.poetry.scripts]
44 | ReX = "rex_xai.rex_wrapper:main"
45 |
46 | [build-system]
47 | requires = ["poetry-core"]
48 | build-backend = "poetry.core.masonry.api"
49 |
50 | [tool.poetry.group.dev]
51 | optional = true
52 |
53 | [tool.poetry.group.dev.dependencies]
54 | ruff = "^0.6.8"
55 | pytest = "^8.3.3"
56 | sphinx = "^8.0.2"
57 | myst-parser = "^4.0.0"
58 | sphinx-autoapi = "^3.3.2"
59 | pyright = "^1.1.383"
60 | pytest-cov = "^5.0.0"
61 | syrupy = "^4.7.2"
62 | torchvision = "^0.21.0"
63 | pytest-sugar = "^1.0.0"
64 | cached-path = "^1.6.3"
65 | sphinx-argparse = "^0.5.2"
66 | myst-nb = "^1.1.2"
67 | jupytext = "^1.16.7"
68 | plotly = "^5.4.0"
69 | dash = "^2.1.0"
70 | kaleido = "0.2.1"
71 |
72 | [tool.pyright]
73 | include = ["rex_xai"]
74 | exclude = ["scripts"]
75 | reportMissingTypeStubs = false
76 |
--------------------------------------------------------------------------------
/rex_xai/__init__.py:
--------------------------------------------------------------------------------
  1 | # pylint: disable=invalid-name
2 |
--------------------------------------------------------------------------------
/rex_xai/explanation/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from typing import Tuple
3 | import numpy as np
4 | import torch as tt
5 | from scipy.integrate import simpson
6 | from scipy.signal import periodogram
7 | from skimage.measure import shannon_entropy
8 |
9 | from rex_xai.explanation.explanation import Explanation
10 | from rex_xai.utils._utils import get_map_locations
11 | from rex_xai.mutants.mutant import _apply_to_data
12 | from rex_xai.utils._utils import set_boolean_mask_value, xlogx
13 |
14 |
class Evaluation:
    """Quality metrics for a finished Explanation: the fraction of input
    needed for sufficiency, spectral/Shannon entropies of the responsibility
    map, and insertion/deletion curves."""

    # TODO does this need to be an object? Probably not...
    # TODO consider inheritance from Explanation object
    def __init__(self, explanation: "Explanation") -> None:
        self.explanation = explanation

    def ratio(self) -> float:
        """Returns percentage of data required for sufficient explanation.

        Raises:
            ValueError: if the Explanation has no explanation or no channel
                information.
        """
        if (
            self.explanation.explanation is None
            or self.explanation.data.model_channels is None
        ):
            raise ValueError("Invalid Explanation object")

        final_mask = self.explanation.final_mask
        # The mask may be a torch tensor or a numpy array; branch explicitly.
        # On tensors `.size` is a method (not an element count), so the old
        # numpy-style division raised TypeError; use numel() instead.
        if isinstance(final_mask, tt.Tensor):
            return (tt.count_nonzero(final_mask) / final_mask.numel()).item()
        return np.count_nonzero(final_mask) / final_mask.size  # type: ignore

    def spectral_entropy(self) -> Tuple[float, float]:
        """Returns (entropy, maximum possible entropy) of the power spectral
        density of the responsibility map.

        This code is a simplified version of
        https://github.com/raphaelvallat/antropy/blob/master/src/antropy/entropy.py
        """
        _, psd = periodogram(self.explanation.target_map)
        psd_norm = psd / psd.sum()
        ent = -np.sum(xlogx(psd_norm))
        if len(psd_norm.shape) == 2:
            max_ent = np.log2(len(psd_norm[0]))
        else:
            max_ent = np.log2(len(psd_norm))
        return ent, max_ent

    def entropy_loss(self):
        """Returns (shannon entropy of the input, shannon entropy of the
        explanation), indicating how much information the explanation keeps."""
        img = np.array(self.explanation.data.input)
        assert self.explanation.explanation is not None
        exp = shannon_entropy(self.explanation.explanation.detach().cpu().numpy())

        return shannon_entropy(img), exp

    def insertion_deletion_curve(self, prediction_func, normalise=False):
        """Computes areas under the insertion and deletion curves by
        progressively revealing (insertion) / occluding (deletion) locations
        in responsibility-ranked order, `insertion_step` locations at a time.

        Returns (insertion_auc, deletion_auc), optionally normalised by the
        original target confidence.
        """
        insertion_curve = []
        deletion_curve = []

        assert self.explanation.data.target is not None
        assert self.explanation.data.target.confidence is not None

        assert self.explanation.data.data is not None
        insertion_mask = tt.zeros(
            self.explanation.data.data.squeeze(0).shape, dtype=tt.bool
        ).to(self.explanation.data.device)
        deletion_mask = tt.ones(
            self.explanation.data.data.squeeze(0).shape, dtype=tt.bool
        ).to(self.explanation.data.device)
        im = []
        dm = []

        step = self.explanation.args.insertion_step
        ranking = get_map_locations(map=self.explanation.target_map)
        # NOTE(review): iters floors the division, but the loop below also
        # processes a final partial chunk, so the normalisation constant may
        # undercount by one step — confirm intended.
        iters = len(ranking) // step

        for i in range(0, len(ranking), step):
            chunk = ranking[i : i + step]
            for _, loc in chunk:
                set_boolean_mask_value(
                    insertion_mask,
                    self.explanation.data.mode,
                    self.explanation.data.model_order,
                    loc,
                )
                set_boolean_mask_value(
                    deletion_mask,
                    self.explanation.data.mode,
                    self.explanation.data.model_order,
                    loc,
                    val=False,
                )
            im.append(
                _apply_to_data(insertion_mask, self.explanation.data, 0).squeeze(0)
            )
            dm.append(
                _apply_to_data(deletion_mask, self.explanation.data, 0).squeeze(0)
            )

            # run the model whenever a full batch has accumulated
            if len(im) == self.explanation.args.batch_size:
                self.__batch(im, dm, prediction_func, insertion_curve, deletion_curve)
                im = []
                dm = []

        # flush any final partial batch
        if im != [] and dm != []:
            self.__batch(im, dm, prediction_func, insertion_curve, deletion_curve)

        i_auc = simpson(insertion_curve, dx=step)
        d_auc = simpson(deletion_curve, dx=step)

        if normalise:
            const = self.explanation.data.target.confidence * iters * step
            i_auc /= const
            d_auc /= const

        return i_auc, d_auc

    # # def sensitivity(self):
    # # pass

    # # def infidelity(self):
    # # pass

    def __batch(self, im, dm, prediction_func, insertion_curve, deletion_curve):
        """Runs one batched prediction over the insertion and deletion mutants
        and appends the target-class confidences to the respective curves."""
        assert self.explanation.data.target is not None
        ip = prediction_func(tt.stack(im).to(self.explanation.data.device), raw=True)
        dp = prediction_func(tt.stack(dm).to(self.explanation.data.device), raw=True)
        for p in range(0, ip.shape[0]):
            insertion_curve.append(
                ip[p, self.explanation.data.target.classification].item()
            )  # type: ignore
            deletion_curve.append(
                dp[p, self.explanation.data.target.classification].item()
            )  # type: ignore
145 |
--------------------------------------------------------------------------------
/rex_xai/explanation/multi_explanation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """generate multiple explanations from a responsibility landscape """
4 |
import os
import re
import sys

import numpy as np
import torch as tt
from itertools import combinations

from rex_xai.explanation.explanation import Explanation
from rex_xai.mutants.distributions import random_coords, Distribution
from rex_xai.utils.logger import logger
from rex_xai.utils._utils import powerset, clause_area, SpatialSearch
from rex_xai.output.visualisation import (
    save_multi_explanation,
    save_image,
    plot_image_grid,
)
21 |
22 |
23 | class MultiExplanation(Explanation):
    def __init__(self, maps, prediction_func, data, args, run_stats):
        """Initialise a multi-explanation search; see Explanation for the
        shared parameters."""
        super().__init__(maps, prediction_func, data, args, run_stats)
        # masks found so far and their confidences, populated by extract()
        self.explanations = []
        self.explanation_confidences = []
28 |
29 | def __repr__(self) -> str:
30 | pred_func = repr(self.prediction_func)
31 | match_func_name = re.search(r"("
34 |
35 | run_stats = {k: round(v, 5) for k, v in self.run_stats.items()}
36 |
37 | exp_text = (
38 | "MultiExplanation:"
39 | + f"\n\tCausalArgs: {type(self.args)}"
40 | + f"\n\tData: {self.data}"
41 | + f"\n\tprediction function: {pred_func}"
42 | + f"\n\tResponsibilityMaps: {self.maps}"
43 | + f"\n\trun statistics: {run_stats} (5 dp)"
44 | )
45 |
46 | if len(self.explanations) == 0:
47 | return (
48 | exp_text
49 | + f"\n\texplanations: {self.explanations}"
50 | + f"\n\texplanation confidences: {self.explanation_confidences}"
51 | )
52 | else:
53 | return (
54 | exp_text
55 | + f"\n\texplanations: {len(self.explanations)} explanations of {type(self.explanations[0])} and shape {self.explanations[0].shape}"
56 | + f"\n\texplanation confidences: {[round(x, ndigits=5) for x in self.explanation_confidences]} (5 dp)"
57 | )
58 |
    def save(self, path, mask=None, multi=None, multi_style=None, clauses=None):
        """Save the extracted explanations to disk in the configured style:
        "contrastive" (single masked image), "separate" (one file per
        explanation, suffixed _i), or "composite" (clause overlays).

        NOTE(review): the `mask` and `multi` parameters are never read —
        `mask` is immediately shadowed by the loop variable below; confirm
        they can be removed from the signature.
        """
        if multi_style is None:
            multi_style = self.args.multi_style
        if multi_style == "contrastive":
            super().save(path, mask=self.final_mask)
        if multi_style == "separate":
            logger.info("saving explanations in multiple different files")
            for i, mask in enumerate(self.explanations):
                name, ext = os.path.splitext(path)
                exp_path = f"{name}_{i}{ext}"
                super().save(exp_path, mask=mask)
        elif multi_style == "composite":
            logger.info("using composite style to save explanations")
            if clauses is None:
                # no clauses supplied: overlay every explanation at once
                clause = range(0, len(self.explanations))
                save_multi_explanation(
                    self.explanations, self.data, self.args, clause=clause, path=path
                )
            else:
                # embed the clause tuple into the filename
                name, ext = os.path.splitext(path)
                new_name = f"{name}_{clauses}{ext}"
                save_multi_explanation(
                    self.explanations,
                    self.data,
                    self.args,
                    clause=clauses,
                    path=new_name,
                )
87 |
88 | def show(self, path=None, multi_style=None, clauses=None):
89 | if multi_style is None:
90 | multi_style = self.args.multi_style
91 | outs = []
92 |
93 | for mask in self.explanations:
94 | out = save_image(mask, self.data, self.args, path=None)
95 | outs.append(out)
96 |
97 | if multi_style == "separate":
98 | for mask in self.explanations:
99 | out = save_image(mask, self.data, self.args, path=None)
100 | outs.append(out)
101 |
102 | elif multi_style == "composite":
103 | if clauses is None:
104 | clause = tuple([i for i in range(len(self.explanations))])
105 | out = save_multi_explanation(
106 | self.explanations, self.data, self.args, clause=clause, path=None
107 | )
108 | outs.append(out)
109 | else:
110 | for clause in clauses:
111 | out = save_multi_explanation(
112 | self.explanations,
113 | self.data,
114 | self.args,
115 | clause=clause,
116 | path=None,
117 | )
118 | outs.append(out)
119 |
120 | if len(outs) > 1:
121 | plot_image_grid(outs)
122 | else:
123 | return outs[0]
124 |
125 | def extract(self, method=None):
126 | self.blank()
127 | # we start with the global max explanation
128 | logger.info("spotlight number 1 (global max)")
129 | conf = self._Explanation__global() # type: ignore
130 | if self.final_mask is not None:
131 | self.explanations.append(self.final_mask)
132 | self.explanation_confidences.append(conf)
133 | self.blank()
134 |
135 | for i in range(0, self.args.spotlights - 1):
136 | logger.info("spotlight number %d", i + 2)
137 | conf = self.spotlight_search()
138 | if self.final_mask is not None:
139 | self.explanations.append(self.final_mask)
140 | self.explanation_confidences.append(conf)
141 | self.blank()
142 | logger.info(
143 | "ReX has found a total of %d explanations via spotlight search",
144 | len(self.explanations),
145 | )
146 |
147 | def __dice(self, d1, d2):
148 | """calculates dice coefficient between two numpy arrays of the same dimensions"""
149 | d_sum = d1.sum() + d2.sum()
150 | if d_sum == 0:
151 | return 0
152 | intersection = tt.logical_and(d1, d2)
153 | return np.abs((2.0 * intersection.sum() / d_sum).item())
154 |
    def separate_by(self, dice_coefficient: float, reverse=True):
        """Group the non-empty explanations into clauses whose members
        pairwise overlap by at most `dice_coefficient` (Dice similarity),
        returning the candidate clauses sorted by clause_area.

        NOTE(review): the early `break` assumes powerset() yields subsets
        grouped by size (largest first when reverse=True) — confirm against
        utils.powerset.
        """
        exps = []
        sizes = dict()

        # keep only non-empty explanations, remembering their sizes
        for i, exp in enumerate(self.explanations):
            size = tt.count_nonzero(exp)
            if size > 0:
                exps.append(i)
                sizes[i] = size

        clause_len = 0
        clauses = []

        # mark every pair of explanations that overlaps too much
        perms = combinations(exps, 2)
        bad_pairs = set()
        for perm in perms:
            left, right = perm
            if (
                self.__dice(self.explanations[left], self.explanations[right])
                > dice_coefficient
            ):
                bad_pairs.add(perm)

        # keep maximal subsets containing no bad pair; stop once subset
        # size starts shrinking
        for s in powerset(exps, reverse=reverse):
            found = True
            for bp in bad_pairs:
                if bp[0] in s and bp[1] in s:
                    found = False
                    break
            if found:
                if len(s) >= clause_len:
                    clause_len = len(s)
                    clauses.append(s)
                else:
                    break

        clauses = sorted(clauses, key=lambda x: clause_area(x, sizes))
        return clauses
193 |
194 | def contrastive(self, clauses):
195 | for clause in clauses:
196 | for subset in powerset(clause, reverse=False):
197 | mask = sum([self.explanations[x] for x in subset])
198 | mask = mask.to(tt.bool) # type: ignore
199 | sufficient = tt.where(mask, self.data.data, self.data.mask_value) # type: ignore
200 | counterfactual = tt.where(mask, self.data.mask_value, self.data.data) # type: ignore
201 | ps = self.prediction_func(sufficient)[0]
202 | pn = self.prediction_func(counterfactual)[0]
203 |
204 | if (
205 | ps.classification == self.data.target.classification # type: ignore
206 | and pn.classification != self.data.target.classification # type: ignore
207 | ):
208 | logger.info(
209 | "found sufficient and necessary explanation of class %d, %d with confidence %f",
210 | ps.classification,
211 | pn.classification,
212 | pn.confidence,
213 | )
214 | self.final_mask = mask
215 | return subset
216 | logger.warning(
217 | "ReX is unable to find a counterfactual, so not producing an output. Exiting here..."
218 | )
219 | exit()
220 |
    def __random_step_from(self, origin, width, height, step=5):
        """Take one random diagonal step of size `step` from `origin` = (c, r),
        clamped to [0, width] x [0, height].

        NOTE(review): the sole caller (spotlight_search) passes model_height
        as `width` and model_width as `height`, which looks swapped relative
        to the (c, r) unpacking here — confirm.
        """
        c, r = origin
        # flip a coin to move left (0) or right (1)
        c_dir = np.random.randint(0, 2)
        c = c - step if c_dir == 0 else c + step
        if c < 0:
            c = 0
        if c > width:
            c = width

        # flip a coin to move down (0) or up (1)
        r_dir = np.random.randint(0, 2)
        r = r - step if r_dir == 0 else r + step
        if r < 0:
            r = 0
        if r > height:
            r = height
        logger.debug(f"trying new location: moving from {origin} to {(c, r)}")
        return (c, r)
240 |
241 | def __random_location(self):
242 | assert self.data.model_width is not None
243 | assert self.data.model_height is not None
244 | origin = random_coords(
245 | Distribution.Uniform,
246 | self.data.model_width * self.data.model_height,
247 | )
248 |
249 | return np.unravel_index(origin, (self.data.model_height, self.data.model_width)) # type: ignore
250 |
    def spotlight_search(self, origin=None):
        """Launch one spotlight: run a spatial search around `origin` (random
        when None), wandering to new centres until an explanation is found or
        args.max_spotlight_budget steps are exhausted. Returns the confidence
        from the last spatial search.
        """
        if origin is None:
            centre = self.__random_location()
        else:
            centre = origin

        ret, resp, conf = self._Explanation__spatial(  # type: ignore
            centre=centre, expansion_limit=self.args.no_expansions
        )

        steps = 0
        while ret == SpatialSearch.NotFound and steps < self.args.max_spotlight_budget:
            if self.args.spotlight_objective_function == "none":
                # no objective function: jump to a fresh random location
                centre = self.__random_location()
                ret, resp, conf = self._Explanation__spatial(  # type: ignore
                    centre=centre, expansion_limit=self.args.no_expansions
                )
            else:
                # hill-climb: take random steps until local responsibility improves
                # NOTE(review): __random_step_from is declared (origin, width,
                # height) but receives (model_height, model_width) here — the
                # arguments look swapped; confirm. This inner loop also has no
                # budget of its own and could spin indefinitely if
                # responsibility never improves.
                new_resp = 0.0
                while new_resp < resp:
                    centre = self.__random_step_from(
                        centre,
                        self.data.model_height,
                        self.data.model_width,
                        step=self.args.spotlight_step,
                    )
                    ret, new_resp, conf = self._Explanation__spatial(  # type: ignore
                        centre=centre, expansion_limit=self.args.no_expansions
                    )
                    if ret == SpatialSearch.Found:
                        return conf
                ret, resp, conf = self._Explanation__spatial(  # type: ignore
                    centre=centre, expansion_limit=self.args.no_expansions
                )
            steps += 1
        return conf
287 |
--------------------------------------------------------------------------------
/rex_xai/input/input_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from typing import Optional
3 | import numpy as np
4 | import torch as tt
5 |
6 | from enum import Enum
7 |
8 | from rex_xai.mutants.occlusions import spectral_occlusion, context_occlusion
9 | from rex_xai.responsibility.prediction import Prediction
10 | from rex_xai.utils.logger import logger
11 | from rex_xai.utils._utils import ReXDataError
12 |
13 | Setup = Enum("Setup", ["ONNXMPS", "ONNX", "PYTORCH"])
14 |
15 |
16 | def _guess_mode(input):
17 | if hasattr(input, "mode"):
18 | return input.mode
19 | if hasattr(input, "shape"):
20 | if len(input.shape) == 4:
21 | return "voxel"
22 | else:
23 | return "spectral"
24 |
25 |
26 | class Data:
27 | def __init__(
28 | self, input, model_shape, device="cpu", mode=None, process=True
29 | ) -> None:
30 | self.input = input
31 | self.mode = None
32 | self.target: Optional[Prediction] = None
33 | self.device = device
34 | self.setup: Optional[Setup] = None
35 | self.transposed = False
36 |
37 | self.mode = mode
38 | if mode is None:
39 | self.mode = _guess_mode(input)
40 |
41 | self.model_shape = model_shape
42 | height, width, channels, order, depth = self.__get_shape()
43 | self.model_height: Optional[int] = height
44 | self.model_width: Optional[int] = width
45 | self.model_depth: Optional[int] = depth
46 | self.model_channels: Optional[int] = channels if channels is not None else 1
47 | self.model_order = order
48 | self.mask_value = None
49 | self.background = None
50 | self.context = None
51 |
52 | if process:
53 | if self.mode == "RGB":
54 | if self.model_order == "first":
55 | self.transposed = True
56 | elif self.mode in ("tabular", "spectral"):
57 | self.data = self.input
58 | self.match_data_to_model_shape()
59 | elif self.mode == "voxel":
60 | self.data = self.input
61 | else:
62 | raise NotImplementedError
63 |
64 | def set_height(self, h: int):
65 | self.model_height = h
66 |
67 | def set_width(self, w: int):
68 | self.model_width = w
69 |
70 | def set_channels(self, c=None):
71 | self.model_channels = c
72 |
73 | def __repr__(self) -> str:
74 | data_info = f"Data: {self.mode}, {self.model_shape}, {self.model_height}, {self.model_width}, {self.model_channels}, {self.model_order}"
75 | if self.target is not None:
76 | target_info = repr(self.target)
77 | data_info = data_info + "\n\t Target:" + target_info
78 | return data_info
79 |
80 | def set_classification(self, cl):
81 | self.classification = cl
82 |
83 | def match_data_to_model_shape(self):
84 | """
85 | a PIL image has the from H * W * C, so
86 | if the model takes C * H * W we need to transpose self.data to
87 | get it into the correct form for the model to consume
88 | This function does *not* add in the batch channel at the beginning
89 | """
90 | assert self.data is not None
91 | if self.mode == "RGB" and self.model_order == "first":
92 | self.data = self.data.transpose(2, 0, 1) # type: ignore
93 | self.transposed = True
94 | if self.mode in ("tabular", "spectral"):
95 | self.data = self.generic_tab_preprocess()
96 | if self.mode == "voxel":
97 | pass
98 | self.data = self.try_unsqueeze()
99 |
100 | def generic_tab_preprocess(self):
101 | if isinstance(self.input, np.ndarray):
102 | self.data = self.input.astype("float32")
103 | arr = tt.from_numpy(self.data).to(self.device)
104 | else:
105 | arr = self.input
106 | for _ in range(len(self.model_shape) - len(arr.shape)):
107 | arr = arr.unsqueeze(0)
108 | return arr
109 |
110 | def load_data(self, astype="float32"):
111 | img = self.input.resize((self.model_height, self.model_width))
112 | img = np.array(img).astype(astype)
113 | self.data = img
114 | self.match_data_to_model_shape()
115 | self.data = tt.from_numpy(self.data).to(self.device)
116 |
117 | def _normalise_rgb_data(self, means, stds, norm):
118 | assert self.data is not None
119 | if self.model_channels != 3:
120 | raise ReXDataError(
121 | f"expected RGB data, but got data with the shape {self.model_shape}"
122 | )
123 |
124 | normed_data = self.data
125 | if norm is not None:
126 | normed_data /= norm
127 |
128 | if self.model_order == "first":
129 | if means is not None:
130 | for i, m in enumerate(means):
131 | normed_data[:, i, :, :] = normed_data[:, i, :, :] - m
132 | if stds is not None:
133 | for i, s in enumerate(stds):
134 | normed_data[:, i, :, :] = normed_data[:, i, :, :] / s
135 |
136 | if self.model_order == "last":
137 | if means is not None:
138 | for i, m in enumerate(means):
139 | normed_data[:, :, i] = normed_data[:, :, i] - m
140 | if stds is not None:
141 | for i, s in enumerate(stds):
142 | normed_data[:, :, i] = normed_data[:, :, i] / s
143 |
144 | return normed_data
145 |
146 | def try_unsqueeze(self):
147 | out = self.data
148 | if self.model_order == "first":
149 | dim = 0
150 | else:
151 | dim = -1
152 | if isinstance(self.data, tt.Tensor):
153 | for _ in range(len(self.model_shape) - len(self.data.shape) - 1):
154 | out = tt.unsqueeze(out, dim=dim) # type: ignore
155 | out = tt.unsqueeze(out, dim=0) # type: ignore
156 | else:
157 | for _ in range(len(self.model_shape) - len(self.data.shape) - 1): # type: ignore
158 | out = np.expand_dims(out, axis=dim) # type: ignore
159 | out = np.expand_dims(out, axis=0) # type: ignore
160 | return out
161 |
162 | def generic_image_preprocess(
163 | self,
164 | means=None,
165 | stds=None,
166 | astype="float32",
167 | norm: Optional[float] = 255.0,
168 | ):
169 | self.load_data(astype=astype)
170 |
171 | if self.mode == "RGB" and self.data is not None:
172 | self.data = self._normalise_rgb_data(means, stds, norm)
173 | self.try_unsqueeze()
174 | if self.mode == "L":
175 | self.data = self._normalise_rgb_data(means, stds, norm)
176 |
177 | def __get_shape(self):
178 | """returns height, width, channels, order, depth for the model"""
179 | if self.mode == "spectral":
180 | # an array of the form (h, w), so no channel info or order or depth
181 | if len(self.model_shape) == 2:
182 | return self.model_shape[0], self.model_shape[1], 1, None, None
183 | # an array of the form (batch, h, w), so no channel info or order or depth
184 | if len(self.model_shape) == 3:
185 | return self.model_shape[1], self.model_shape[2], 1, None, None
186 | if self.mode == "RGB":
187 | if len(self.model_shape) == 4:
188 | _, a, b, c = self.model_shape
189 | if a in (1, 3, 4):
190 | return b, c, a, "first", None
191 | else:
192 | return a, b, c, "last", None
193 | if self.mode == "voxel":
194 | if len(self.model_shape) == 4:
195 | _, w, h, d = self.model_shape # If batch is present
196 | return w, h, None, None, d
197 | else:
198 | w, h, d = self.model_shape
199 | return w, h, None, None, d
200 |
201 | raise ReXDataError(
202 | f"Incompatible 'mode' {self.mode} and 'model_shape' ({self.model_shape}), cannot get valid shape of Data object so exiting here"
203 | )
204 |
205 | def set_mask_value(self, m):
206 | assert self.data is not None
207 | # if m is a number, then if might still need to be normalised
208 |
209 | if m == "spectral" and self.mode != "spectral":
210 | logger.warning(
211 | "Mask value 'spectral' can only be used if mode is also 'spectral', using default mask value 0 instead"
212 | )
213 | m = 0
214 |
215 | match m:
216 | case int() | float() as m:
217 | self.mask_value = m
218 | case "min":
219 | self.mask_value = tt.min(self.data).item() # type: ignore
220 | case "mean":
221 | self.mask_value = tt.mean(self.data).item() # type: ignore
222 | case "spectral":
223 | self.mask_value = lambda m, d: spectral_occlusion(
224 | m, d, device=self.device
225 | )
226 | case "context":
227 | self.mask_value = lambda m, d: context_occlusion(m, d, self.context)
228 | # TODO: Add args for noise and setting the context as currently only available through custom script
229 | case _:
230 | raise ValueError(
231 | f"Invalid mask value {m}. Should be an integer, float, or one of 'min', 'mean', 'spectral'"
232 | )
233 |
--------------------------------------------------------------------------------
/rex_xai/input/onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """onnx model management"""
4 |
5 | from typing import Optional, Union, List
6 | import sys
7 | import os
8 | import torch as tt
9 | import platform
10 | from scipy.special import softmax
11 | import numpy as np
12 |
13 | import onnxruntime as ort
14 | from onnxruntime import InferenceSession
15 | from rex_xai.responsibility.prediction import Prediction, from_pytorch_tensor
16 | from rex_xai.input.input_data import Setup
17 |
18 | from rex_xai.utils.logger import logger
19 |
20 |
class OnnxRunner:
    """Wraps an onnxruntime InferenceSession and produces ReX-compatible
    prediction functions for cpu and gpu execution paths."""

    def __init__(self, session: InferenceSession, setup: Setup, device) -> None:
        self.session = session
        # only the model's first input and first output are handled here
        self.input_shape = session.get_inputs()[0].shape
        self.output_shape = session.get_outputs()[0].shape
        self.input_name = session.get_inputs()[0].name
        self.output_name = session.get_outputs()[0].name
        self.setup: Setup = setup
        self.device: str = device

    def run_on_cpu(
        self,
        tensors: Union[tt.Tensor, List[tt.Tensor]],
        target: Optional[Prediction],
        raw: bool,
        binary_threshold: Optional[float] = None,
    ):
        """Convert a pytorch tensor, or list of tensors, to numpy arrays on the cpu for onnx inference."""
        # check if it's a single tensor or a list of tensors
        if isinstance(tensors, list):
            tensor_size = tensors[0].shape[0]
        else:
            tensor_size = tensors.shape[0]

        # NOTE(review): when `tensors` is a list whose first entry has batch
        # dim 1, this branch calls .detach() on the list itself, which would
        # raise AttributeError — confirm lists always take the stack branch
        if tensor_size == 1:
            tensors = tensors.detach().cpu().numpy()  # type: ignore
        else:
            tensors = np.stack([t.detach().cpu().numpy() for t in tensors])  # type: ignore

        preds = []

        try:
            prediction = self.session.run(None, {self.input_name: tensors})[0]
            for i in range(0, prediction.shape[0]):
                confidences = softmax(prediction[i])
                if raw:
                    # NOTE(review): this returns after the *first* row only,
                    # so a batched call with raw=True discards the softmax of
                    # every other row — confirm intended. The inner loop also
                    # shadows the outer loop variable `i`.
                    for i in range(len(self.output_shape) - len(confidences.shape)):
                        confidences = np.expand_dims(confidences, axis=0)
                    return confidences
                if binary_threshold is not None:
                    # binary classifier: threshold the single output column
                    if confidences[0] >= binary_threshold:
                        classification = 1
                    else:
                        classification = 0
                    tc = confidences[0]
                else:
                    classification = np.argmax(confidences)
                    if target is not None:
                        tc = confidences[target.classification]
                    else:
                        tc = None
                preds.append(
                    Prediction(
                        classification,
                        confidences[classification],
                        None,
                        target=target,
                        target_confidence=tc,
                    )
                )

            return preds
        except Exception as e:
            # any onnxruntime failure is treated as fatal
            logger.fatal(e)
            sys.exit(-1)

    def run_with_data_on_device(
        self,
        tensors,
        device,
        tsize,
        binary_threshold,
        raw=False,
        device_id=0,
        target=None,
    ):
        """Run inference with input/output buffers bound directly on the
        device via onnxruntime io_binding, avoiding host round-trips.

        NOTE(review): `tsize` and `binary_threshold` are accepted but never
        used on this path — confirm whether binary models are supported on
        gpu.
        """
        # input_shape = self.session.get_inputs()[0].shape # Gets the shape of the input (e.g [batch_size, 3, 224, 224])
        batch_size = len(tensors) if isinstance(tensors, list) else tensors.shape[0]

        if isinstance(tensors, list):
            # NOTE(review): only the first tensor's buffer pointer is bound
            # while `shape` declares the whole batch — this assumes all list
            # entries live contiguously starting at tensors[0]; confirm
            # upstream guarantees this
            tensors = [m.contiguous() for m in tensors]
            shape = tuple(
                [batch_size] + list(self.input_shape)[1:]
            )  # batch_size + remaining input shape
            ptr = tensors[0].data_ptr()

        else:
            tensors = tensors.contiguous()
            shape = tuple(tensors.shape)
            ptr = tensors.data_ptr()

        binding = self.session.io_binding()
        binding.bind_input(
            name=self.input_name,
            device_type=device,
            device_id=device_id,
            element_type=np.float32,
            shape=shape,
            buffer_ptr=ptr,
        )

        output_shape = [batch_size] + list(self.output_shape[1:])

        # pre-allocated output buffer on the same device
        z_tensor = tt.empty(output_shape, dtype=tt.float32, device=device).contiguous()

        binding.bind_output(
            name=self.output_name,
            device_type=device,
            device_id=device_id,
            element_type=np.float32,
            shape=tuple(z_tensor.shape),
            buffer_ptr=z_tensor.data_ptr(),
        )

        self.session.run_with_iobinding(binding)
        if raw:
            return z_tensor
        return from_pytorch_tensor(z_tensor, target=target)

    def gen_prediction_function(self):
        """Return a (prediction_function, input_shape) pair appropriate for
        the configured device.

        NOTE(review): implicitly returns None for devices other than "cpu"
        and "cuda" (unless setup is ONNXMPS) — confirm callers never hit
        that case.
        """
        if self.device == "cpu" or self.setup == Setup.ONNXMPS:
            return (
                lambda tensor,
                target=None,
                raw=False,
                binary_threshold=None: self.run_on_cpu(
                    tensor, target, raw, binary_threshold
                ),
                self.input_shape,
            )
        if self.device == "cuda":
            return (
                lambda tensor,
                target=None,
                device=self.device,
                raw=False,
                binary_threshold=None: self.run_with_data_on_device(
                    tensor,
                    device,
                    len(tensor),
                    binary_threshold,
                    raw=raw,
                    target=target,
                ),
                self.input_shape,
            )
167 |
168 |
def get_prediction_function(args):
    """Build an onnxruntime InferenceSession from args.model and return the
    (prediction_function, input_shape) pair produced by OnnxRunner."""
    sess_options = ort.SessionOptions()

    ort.set_default_logger_severity(args.ort_logger)

    if not args.gpu:
        # cpu-only inference session
        logger.info("using cpu for onnx inference session")
        sess = ort.InferenceSession(
            args.model,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        return OnnxRunner(sess, Setup.PYTORCH, "cpu").gen_prediction_function()

    logger.info("using gpu for onnx inference session")
    if platform.uname().system == "Darwin":
        # note this is only true for M+ chips; onnx does not yet seem to
        # support data copying on mps, so inference data is staged via the
        # cpu (hence the ONNXMPS setup)
        providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
        sess_options.intra_op_num_threads = args.intra_op_num_threads
        sess_options.inter_op_num_threads = args.inter_op_num_threads
        device = "mps"
        setup = Setup.ONNXMPS
    else:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        device = "cuda"
        setup = Setup.PYTORCH

    # set up session with the chosen gpu providers
    sess = ort.InferenceSession(
        args.model, sess_options=sess_options, providers=providers
    )  # type: ignore
    return OnnxRunner(sess, setup, device).gen_prediction_function()
210 |
--------------------------------------------------------------------------------
/rex_xai/mutants/distributions.py:
--------------------------------------------------------------------------------
1 | """distributions module"""
2 |
3 | from typing import Optional, Tuple
4 | from enum import Enum
5 | import numpy as np
6 | from scipy.stats import binom, betabinom
7 | from rex_xai.utils.logger import logger
8 |
9 | Distribution = Enum("Distribution", ["Binomial", "Uniform", "BetaBinomial", "Adaptive"])
10 |
11 |
12 | def _betabinom2d(height, width, alpha, beta):
13 | bx = betabinom(width, alpha, beta)
14 | by = betabinom(height, alpha, beta)
15 |
16 | w = np.array([bx.pmf(i) for i in range(0, width)]) # type: ignore
17 | h = np.array([by.pmf(i) for i in range(0, height)]) # type: ignore
18 |
19 | w = np.expand_dims(w, axis=0)
20 | h = np.expand_dims(h, axis=0)
21 |
22 | u = (h.T * w / np.sum(h.T * w)).ravel()
23 | p = np.random.choice(np.arange(0, len(u)), p=u)
24 | return p
25 |
26 |
27 | def _blend(dist, alpha, base):
28 | pmf = np.array([base.pmf(x) for x in range(0, len(dist))])
29 | blend = ((1.0 - alpha) * pmf) + (alpha * dist)
30 | blend /= np.sum(blend)
31 | return blend
32 |
33 |
def _2d_adaptive(map, args: Tuple[int, int, int, int], alpha=0.0, base=None) -> int:
    """Sample a flattened position from the window
    map[args[0]:args[1], args[2]:args[3]], weighted by the responsibility
    values; falls back to a uniform draw when the map is missing or all-zero.

    NOTE(review): `alpha` and `base` are currently unused (blending is
    commented out below). The uniform fallback starts at 1, so position 0 is
    never drawn — possibly an off-by-one; confirm. Negative map values would
    make the normalised weights invalid for np.random.choice.
    """
    # if the map exists and is not 0.0 everywhere...
    if map is not None and np.max(map) > 0.0:
        s = map[args[0] : args[1], args[2] : args[3]]
        # flatten() copies, so the in-place normalisation below does not
        # mutate the caller's map
        sf = np.ndarray.flatten(s)
        # sf = np.max(sf) - sf
        sf /= np.sum(sf)

        # base = betabinom(0, len(sf), 1.1, 1.1)
        # if base is not None:
        # sf = _blend(alpha, base)
        pos = np.random.choice(np.arange(0, len(sf)), p=sf)
        return pos

    # if the map is empty or doesn't exist, return uniform
    return np.random.randint(1, (args[1] - args[0]) * (args[3] - args[2]))
50 |
51 |
def str2distribution(d: str) -> Distribution:
    """converts string into Distribution enum

    Unrecognised strings log a warning and default to Distribution.Uniform.
    """
    lookup = {
        "binom": Distribution.Binomial,
        "uniform": Distribution.Uniform,
        "betabinom": Distribution.BetaBinomial,
        "adaptive": Distribution.Adaptive,
    }
    try:
        return lookup[d]
    except KeyError:
        logger.warning(
            "Invalid distribution '%s', reverting to default value Distribution.Uniform",
            d,
        )
        return Distribution.Uniform
68 |
69 |
def random_coords(d: Optional[Distribution], *args, map=None) -> Optional[int]:
    """generates random coordinates given a distribution and args

    @param d: the sampling distribution to use (or None)
    @param args: distribution-specific arguments (see branch comments below)
    @param map: responsibility map, used only by Distribution.Adaptive

    Returns a sampled coordinate/index, or None when the distribution is
    unrecognised or sampling raises a ValueError.
    """

    try:
        if d == Distribution.Adaptive:
            # args[0] is a (row_start, row_stop, col_start, col_stop) window
            return _2d_adaptive(map, args[0])

        if d == Distribution.Uniform:
            # args[0] is the exclusive upper bound; note the lower bound is 1
            return np.random.randint(1, args[0])  # type: ignore

        if d == Distribution.Binomial:
            # args[0] unpacks as (start, stop, *distribution parameters)
            start, stop, *dist_args = args[0]
            # NOTE(review): dist_args is passed to binom as a list, so rvs()
            # returns an array when it holds a single p -- confirm callers
            # expect that rather than a plain int
            return binom(stop - start - 1, dist_args).rvs() + start

        if d == Distribution.BetaBinomial:
            # args[1]/args[2] are height/width, args[3] holds (alpha, beta)
            return _betabinom2d(args[1], args[2], args[3][0], args[3][1])

        return None
    except ValueError:
        return None
90 |
--------------------------------------------------------------------------------
/rex_xai/mutants/mutant.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from typing import List, Optional
3 | import numpy as np
4 | import torch as tt
5 | from PIL import Image # type: ignore
6 |
7 | try:
8 | from anytree.cachedsearch import find
9 | except ImportError:
10 | from anytree.search import find
11 |
12 | import matplotlib.pyplot as plt
13 | from rex_xai.mutants.box import Box
14 | from rex_xai.responsibility.prediction import Prediction
15 | from rex_xai.utils.logger import logger
16 | from rex_xai.input.input_data import Data
17 | from rex_xai.utils._utils import add_boundaries, set_boolean_mask_value
18 |
19 | __combinations = [
20 | [
21 | 0,
22 | ],
23 | [
24 | 1,
25 | ],
26 | [
27 | 2,
28 | ],
29 | [
30 | 3,
31 | ],
32 | [0, 1],
33 | [0, 2],
34 | [0, 3],
35 | [1, 2],
36 | [1, 3],
37 | [2, 3],
38 | [0, 1, 2],
39 | [0, 1, 3],
40 | [0, 2, 3],
41 | [1, 2, 3],
42 | ]
43 |
44 |
def _apply_to_data(mask, data: Data, masking_func):
    """Apply a boolean mask to data.data, occluding the masked-out entries.

    <masking_func> is either a scalar fill value (int/float) or a callable
    taking (mask, data); anything else falls back to a fill value of 0.
    """
    if isinstance(masking_func, (float, int)):
        return tt.where(mask, data.data, masking_func)  # type: ignore
    if callable(masking_func):
        return masking_func(mask, data.data)

    # unrecognised masking_func: warn and occlude with zeros
    logger.warning("applying default masking value of 0")
    return tt.where(mask, data.data, 0)  # type: ignore
54 |
55 |
def get_combinations():
    """Return the module-level list of box-index combinations.

    Note: the caller receives the shared list itself, not a copy.
    """
    return __combinations
58 |
59 |
class Mutant:
    """A partially-occluded version of the input, defined by a boolean mask.

    The mask is built from 'static' regions (fixed for this pass) and
    'active' regions (the boxes currently under test); everything outside
    the mask is concealed by <masking_func> when the mutant is applied.
    """

    def __init__(self, data: Data, static, active, masking_func) -> None:
        self.shape = tuple(
            data.model_shape[1:]
        )  # the first element of shape is the batch information, so we drop that
        self.mode = data.mode
        # number of colour channels; defaults to 1 when the model doesn't say
        self.channels: int = (
            data.model_channels if data.model_channels is not None else 1
        )
        self.order = data.model_order
        # boolean visibility mask: True = keep the underlying datum
        self.mask = tt.zeros(self.shape, dtype=tt.bool, device=data.device)
        self.static = static
        # '_'-joined names of the active boxes; doubles as the mutant's name
        self.active = active
        self.prediction: Optional[Prediction] = None
        # set by update_status() when the model still predicts the target
        self.passing = False
        self.masking_func = masking_func
        # deepest box depth included in this mutant's mask
        self.depth = 0

    def __repr__(self) -> str:
        return f"ACTIVE: {self.active}, PREDICTION: {self.prediction}, PASSING: {self.passing}"

    def get_name(self):
        """Return the mutant's name (the '_'-joined active box names)."""
        return self.active

    def update_status(self, target):
        """Mark the mutant as passing if its prediction matches <target>."""
        if self.prediction is not None:
            if target.classification == self.prediction.classification:
                self.passing = True

    def get_length(self):
        """Number of active boxes in this mutant."""
        return len(self.active.split("_"))

    def get_active_boxes(self):
        """List of the active box names."""
        return self.active.split("_")

    def area(self) -> int:
        """Return the total area *not* concealed by the mutant."""
        tensor = tt.count_nonzero(self.mask)
        # NOTE(review): count_nonzero never returns None and a 0-d count has
        # numel() == 1, so this guard looks like dead code -- kept as-is
        if tensor.numel() == 0 or tensor is None:
            return 0
        else:
            # the mask covers all channels, so divide the raw count down
            return int(tensor.item()) // self.channels

    def set_static_mask_regions(self, names, search_tree):
        """Reveal every named box found in <search_tree>."""
        for name in names:
            box = find(search_tree, lambda node: node.name == name)
            if box is not None:
                self.depth = max(self.depth, box.depth)
                self.set_mask_region_to_true(box)

    def set_active_mask_regions(self, boxes: List[Box]):
        """Reveal each box in <boxes>."""
        for box in boxes:
            self.depth = max(self.depth, box.depth)
            self.set_mask_region_to_true(box)

    def set_mask_region_to_true(self, box: Box):
        """Reveal the region covered by <box> in the boolean mask."""
        set_boolean_mask_value(self.mask, self.mode, self.order, box)

    def apply_to_data(self, data: Data):
        """Return <data> with everything outside the mask occluded."""
        return _apply_to_data(self.mask, data, self.masking_func)

    def save_mutant(self, data: Data, name=None, segs=None):
        """Write a visualisation of the mutant to disk.

        @param data: the underlying input data
        @param name: optional output filename; defaults to '<mutant name>.png'
        @param segs: optional segmentation whose boundaries are drawn on top
        """
        if data.mode == "RGB":
            # NOTE(review): PIL's resize takes (width, height); passing
            # (model_height, model_width) is only safe while the two are
            # equal -- confirm for non-square model inputs
            m = np.array(data.input.resize((data.model_height, data.model_width)))
            mask = self.mask.squeeze().detach().cpu().numpy()
            if data.transposed:
                # if transposed, we have C * H * W, so change that to H * W * C
                m = np.where(mask, m.transpose((2, 0, 1)), 0)
                m = m.transpose((1, 2, 0))
            else:
                # TODO m = m.transpose((0, 2, 1))
                m = np.where(mask, m, 255)
            # draw on the segment_mask, if available
            if segs is not None:
                m = add_boundaries(m, segs)
            img = Image.fromarray(m, data.mode)
            if name is not None:
                img.save(name)
            else:
                img.save(f"{self.get_name()}.png")
        # spectral or time series data
        if data.mode == "spectral":
            m = self.apply_to_data(data)
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.plot(m[0][0].detach().cpu().numpy())
            plt.savefig(f"{self.get_name()}.png")
        # 3d image
        if data.mode == "voxel":
            # TODO
            logger.info("saving 3d mutants is not yet implemented")
            pass
152 |
--------------------------------------------------------------------------------
/rex_xai/mutants/occlusions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import torch as tt
3 | import numpy as np
4 | from scipy.ndimage import gaussian_filter
5 |
6 |
7 | def __split_groups(neg_mask):
8 | # for some reason, it's much faster to do this on the cpu with numpy
9 | # than it is to use tensor_split
10 | return np.split(neg_mask, np.where(np.diff(neg_mask) > 1)[0] + 1)
11 |
12 |
def spectral_occlusion(mask: tt.Tensor, data: tt.Tensor, noise=0.02, device="cpu"):
    """Linear interpolated occlusion for spectral data, with optional added noise.

    @param mask: boolean valued NDArray
    @param data: data to be occluded
    @param noise: parameter for optional gaussian noise.
    Set to 0.0 if you want simple linear interpolation

    @return torch.Tensor
    """
    # indices of all masked-out positions, grouped into contiguous runs below
    neg_mask = tt.where(mask == 0)[0]
    split = __split_groups(neg_mask.detach().cpu().numpy())

    # strangely, this all seems to run faster if we do it on the cpu.
    # TODO Needs further investigation
    local_data = np.copy(data.detach().cpu().numpy())

    for s in split:
        # NOTE(review): this *returns* on the first run of length <= 1,
        # abandoning any later runs; if short runs should merely be skipped
        # this was probably meant to be `continue` -- confirm before changing
        if len(s) <= 1:
            return tt.from_numpy(local_data).to(device)
        start = s[0]
        stop = s[-1]
        # endpoint values for the interpolation; assumes data is shaped
        # (1, 1, n) -- TODO confirm
        dstart = data[:, :, start][0][0].item()
        dstop = data[:, :, stop][0][0].item()
        interp = tt.linspace(dstart, dstop, stop - start)
        if noise > 0.0:
            # NOTE(review): adds a float64 numpy array in-place to a float32
            # torch tensor; relies on torch's implicit numpy interop and
            # in-place type promotion rules -- verify on the torch version used
            interp += np.random.normal(0, noise, len(interp))

        # overwrite the masked-out run with the (noisy) interpolation
        local_data[0, 0, start:stop] = interp

    return tt.from_numpy(local_data).to(device)
44 |
45 |
46 | # Occlusions such as beach, sky, and other context-based occlusions
47 |
48 |
49 | # Medical-based occlusions could be a CT scan of a healthy patient
def context_occlusion(mask: tt.Tensor, data: tt.Tensor, context: tt.Tensor, noise=0.5):
    """Context based occlusion with optional added noise.

    @param mask: boolean valued NDArray
    @param data: data to be occluded
    @param context: data to be used as occlusion e.g. CT scan of a healthy patient or a road
    @param noise: parameter for optional gaussian noise.
                  Set to 0.0 for no noise

    @return torch.Tensor
    """
    occluder = context
    if noise > 0.0:
        # smooth the context before substituting it in
        occluder = tt.tensor(gaussian_filter(occluder, sigma=noise), dtype=tt.float32)
    # keep <data> where the mask is set, use the (possibly blurred) context elsewhere
    return tt.where(mask == 0, occluder, data)
64 |
--------------------------------------------------------------------------------
/rex_xai/output/database.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from datetime import datetime
3 | import zlib
4 | import torch as tt
5 | import sqlalchemy as sa
6 | from sqlalchemy import Boolean, Float, String, create_engine
7 | from sqlalchemy import Column, Integer, Unicode
8 | from sqlalchemy.orm import sessionmaker, DeclarativeBase
9 | from ast import literal_eval
10 |
11 | import pandas as pd
12 | import numpy as np
13 |
14 | from rex_xai.utils.logger import logger
15 | from rex_xai.input.config import CausalArgs, Strategy
16 | from rex_xai.explanation.explanation import Explanation
17 | from rex_xai.explanation.multi_explanation import MultiExplanation
18 |
19 |
def _dataframe(db, table):
    """Load <table> from the sqlite database at path <db> into a DataFrame."""
    return pd.read_sql_table(table, f"sqlite:///{db}")
22 |
23 |
24 | def _to_numpy(buffer, shape, dtype):
25 | return np.frombuffer(zlib.decompress(buffer), dtype=dtype).reshape(shape)
26 |
27 |
def db_to_pandas(db, dtype=np.float32, table="rex", process=True):
    """for interactive use

    Loads <table> from the sqlite db at <db>; when <process> is True the
    compressed 'responsibility' and 'explanation' blobs are decoded back
    into ndarrays of their stored shapes.
    """
    df = _dataframe(db, table=table)

    if process:
        def _unpack(row, col, shape_col, target_dtype):
            # decompress the stored blob back into its original shape
            return _to_numpy(row[col], literal_eval(row[shape_col]), target_dtype)

        df["responsibility"] = df.apply(
            lambda row: _unpack(row, "responsibility", "responsibility_shape", dtype),
            axis=1,
        )
        df["explanation"] = df.apply(
            lambda row: _unpack(row, "explanation", "explanation_shape", np.bool_),
            axis=1,
        )

    return df
48 |
49 |
def __multi_update(
    db,
    explanation,
    classification,
    target,
    target_map,
    final_mask,
    time_taken,
    multi_no,
):
    """Record sub-explanation <multi_no> of a multi-explanation in <db>."""
    if isinstance(final_mask, tt.Tensor):
        final_mask = final_mask.detach().cpu().numpy()

    stats = explanation.run_stats
    add_to_database(
        db,
        explanation.args,
        classification,
        target.confidence,
        target_map,
        final_mask,
        explanation.explanation_confidences[multi_no],
        time_taken,
        stats["total_passing"],
        stats["total_failing"],
        stats["max_depth_reached"],
        stats["avg_box_size"],
        multi=True,
        multi_no=multi_no,
    )
78 |
79 |
def update_database(
    db,
    explanation: Explanation | MultiExplanation,  # type: ignore
    time_taken=None,
    multi=False,
    clauses=None,
):
    """Persist an explanation (or each member of a MultiExplanation) to <db>.

    @param db: an open database session
    @param explanation: the explanation object to store
    @param time_taken: optional wall-clock duration of the run
    @param multi: store every sub-explanation of a MultiExplanation
    @param clauses: optional whitelist of sub-explanation indices to store
    """
    target_map = explanation.target_map
    if isinstance(target_map, tt.Tensor):
        target_map = target_map.detach().cpu().numpy()

    target = explanation.data.target
    if target is None:
        logger.warning("unable to update database as target is None")
        return
    classification = int(target.classification)  # type: ignore

    if not multi:
        final_mask = explanation.final_mask
        if explanation.final_mask is None:
            logger.warning("unable to update database as explanation is empty")
            return
        if isinstance(final_mask, tt.Tensor):
            final_mask = final_mask.detach().cpu().numpy()

        add_to_database(
            db,
            explanation.args,
            classification,
            target.confidence,
            target_map,
            final_mask,
            explanation.explanation_confidence,
            time_taken,
            explanation.run_stats["total_passing"],
            explanation.run_stats["total_failing"],
            explanation.run_stats["max_depth_reached"],
            explanation.run_stats["avg_box_size"],
        )
        return

    if type(explanation) is not MultiExplanation:
        logger.warning(
            "unable to update database, multi=True is only valid for MultiExplanation objects"
        )
        return

    for c, final_mask in enumerate(explanation.explanations):
        # honour the optional whitelist of sub-explanations
        if clauses is not None and c not in clauses:
            logger.warning("ignoring %s", c)
            continue
        __multi_update(
            db,
            explanation,
            classification,
            target,
            target_map,
            final_mask,
            time_taken,
            c,
        )
156 |
157 |
def add_to_database(
    db,
    args: CausalArgs,
    target,
    confidence,
    responsibility,
    explanation,
    explanation_confidence,
    time_taken,
    passing,
    failing,
    depth_reached,
    avg_box_size,
    multi=False,
    multi_no=None,
):
    """Build a DataBaseEntry from one explanation run and commit it to <db>."""
    # derive a (time-based) primary key; multi explanations mix in their index
    if multi:
        entry_id = hash(str(datetime.now().time()) + str(multi_no))
    else:
        entry_id = hash(str(datetime.now().time()))

    entry = DataBaseEntry(
        entry_id,
        args.path,
        target,
        confidence,
        responsibility,
        responsibility.shape,
        explanation,
        explanation.shape,
        explanation_confidence,
        time_taken,
        depth_reached=depth_reached,
        avg_box_size=avg_box_size,
        tree_depth=args.tree_depth,
        search_limit=args.search_limit,
        iters=args.iters,
        min_size=args.min_box_size,
        distribution=str(args.distribution),
        distribution_args=str(args.distribution_args),
    )
    entry.multi = multi
    entry.multi_no = multi_no
    entry.passing = passing
    entry.failing = failing
    entry.total_work = passing + failing
    entry.method = str(args.strategy)

    # strategy-specific columns
    if args.strategy == Strategy.Spatial:
        entry.spatial_radius = args.spatial_initial_radius
        entry.spatial_eta = args.spatial_radius_eta
    if args.strategy == Strategy.MultiSpotlight:
        entry.spotlights = args.spotlights
        entry.spotlight_size = args.spotlight_size
        entry.spotlight_eta = args.spotlight_eta
        entry.obj_function = args.spotlight_objective_function

    db.add(entry)
    db.commit()
220 |
221 |
class Base(DeclarativeBase):
    """Declarative base class for ReX's SQLAlchemy ORM models."""

    pass
224 |
225 |
class NumpyType(sa.types.TypeDecorator):
    """Column type that stores numpy buffers as zlib-compressed blobs."""

    impl = sa.types.LargeBinary

    cache_ok = True

    def process_bind_param(self, value, dialect):
        # compress on the way in; None passes through untouched
        if value is None:
            return None
        return zlib.compress(value, 9)

    def process_result_value(self, value, dialect):
        # values come back still compressed; see _to_numpy for decoding
        return value
237 |
238 |
class DataBaseEntry(Base):
    """ORM row for the 'rex' table: one explanation run (or one member of a
    multi-explanation), including the compressed responsibility map and
    explanation mask plus the causal-search parameters that produced them."""

    __tablename__ = "rex"
    id = Column(Integer, primary_key=True)
    path = Column(Unicode(100))
    target = Column(Integer)
    confidence = Column(Float)
    time = Column(Float)
    responsibility = Column(NumpyType)
    responsibility_shape = Column(Unicode)
    total_work = Column(Integer)
    passing = Column(Integer)
    failing = Column(Integer)
    explanation = Column(NumpyType)
    explanation_shape = Column(Unicode)
    explanation_confidence = Column(Float)
    multi = Column(Boolean)
    multi_no = Column(Integer)

    # causal specific columns
    depth_reached = Column(Integer)
    avg_box_size = Column(Float)
    tree_depth = Column(Integer)
    search_limit_per_iter = Column(Integer)
    iters = Column(Integer)
    min_size = Column(Integer)
    distribution = Column(String)
    distribution_args = Column(String)

    # explanation specific columns
    spatial_radius = Column(Integer)
    spatial_eta = Column(Float)

    # spotlight columns
    method = Column(String)
    spotlights = Column(Integer)
    spotlight_size = Column(Integer)
    spotlight_eta = Column(Float)
    obj_function = Column(String)

    def __init__(
        self,
        id,
        path,
        target,
        confidence,
        responsibility,
        responsibility_shape,
        explanation,
        explanation_shape,
        explanation_confidence,
        time_taken,
        passing=None,
        failing=None,
        total_work=None,
        multi=False,
        multi_no=None,
        depth_reached=None,
        avg_box_size=None,
        tree_depth=None,
        search_limit=None,
        iters=None,
        min_size=None,
        distribution=None,
        distribution_args=None,
        initial_radius=None,
        radius_eta=None,
        method=None,
        spotlights=0,
        spotlight_size=0,
        spotlight_eta=0.0,
        obj_function=None,
    ):
        """Populate a row; parameters mirror the mapped columns above."""
        self.id = id
        self.path = path
        self.target = target
        self.confidence = confidence
        self.responsibility = responsibility
        # shapes are stored as their string repr; see db_to_pandas for decoding
        self.responsibility_shape = str(responsibility_shape)
        self.explanation = explanation
        self.explanation_shape = str(explanation_shape)
        self.explanation_confidence = explanation_confidence
        self.time = time_taken
        self.total_work = total_work
        self.passing = passing
        self.failing = failing
        # multi status
        self.multi = multi
        self.multi_no = multi_no
        # causal
        self.depth_reached = depth_reached
        self.avg_box_size = avg_box_size
        self.tree_depth = tree_depth
        # BUG FIX: this previously assigned to `self.search_limit`, which is
        # not a mapped column, so the search limit was silently never written
        # to the database; the mapped column is `search_limit_per_iter`
        self.search_limit_per_iter = search_limit
        self.iters = iters
        self.min_size = min_size
        self.distribution = distribution
        self.distribution_args = distribution_args
        # spatial
        self.spatial_radius = initial_radius
        self.spatial_eta = radius_eta
        self.method = method
        # spotlights
        self.spotlights = spotlights
        self.spotlight_size = spotlight_size
        self.spotlight_eta = spotlight_eta
        self.obj_function = obj_function
345 |
346 |
def initialise_rex_db(name, echo=False):
    """Create the 'rex' table (if absent) in sqlite db <name>; return a session."""
    engine = create_engine(f"sqlite:///{name}", echo=echo)
    Base.metadata.create_all(engine, tables=[DataBaseEntry.__table__], checkfirst=True)  # type: ignore
    session_factory = sessionmaker(bind=engine)
    return session_factory()
353 |
--------------------------------------------------------------------------------
/rex_xai/responsibility/prediction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from typing import Optional
4 | import torch as tt
5 | import torch.nn.functional as F
6 | from numpy.typing import NDArray
7 | from typing import List
8 | from rex_xai.utils._utils import ff
9 |
10 |
class Prediction:
    """A model prediction, optionally tracking a separate target class."""

    def __init__(
        self,
        pred=None,
        conf=None,
        box=None,
        target=None,
        target_confidence=None,
    ) -> None:
        self.classification: Optional[int] = pred
        self.confidence: Optional[float] = conf
        self.bounding_box: Optional[NDArray] = box
        # keep only the target's class id, not the whole Prediction object
        self.target: Optional[int] = None if target is None else target.classification
        self.target_confidence: Optional[float] = target_confidence

    def __repr__(self) -> str:
        if self.bounding_box is not None:
            return f"CLASS: {self.classification}, CONF: {self.confidence:.5f}, TARGET_CLASS: {self.target}, TARGET_CONFIDENCE: {ff(self.target_confidence, '.5f')}, BOUNDING_BOX: {self.bounding_box}"
        if self.is_passing():
            return (
                f"FOUND_CLASS: {self.classification}, CONF: {self.confidence:.5f}"
            )
        if self.target is None:
            return f"FOUND_CLASS: {self.classification}, FOUND_CONF: {self.confidence:.5f}, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a"
        return f"FOUND_CLASS: {self.classification}, FOUND_CONF: {self.confidence:.5f}, TARGET_CLASS: {self.target}, TARGET_CONFIDENCE: {ff(self.target_confidence, '.5f')}"

    def get_class(self):
        """Return the predicted class id."""
        return self.classification

    def is_empty(self):
        """True when either the class or its confidence is missing."""
        return self.classification is None or self.confidence is None

    def is_passing(self):
        """True when the prediction matches the tracked target class."""
        return self.target == self.classification
48 |
49 |
def from_pytorch_tensor(tensor, target=None, binary_threshold=None) -> List[Prediction]:
    """Turn a batch of raw logits into a list of Prediction objects.

    @param tensor: logits of shape (batch, classes)
    @param target: optional target Prediction; when given, each result also
        records the softmax confidence of the target's class
    @param binary_threshold: currently unused
        (TODO get this to handle binary models)
    """
    probabilities = F.softmax(tensor, dim=1)
    top_scores, top_labels = tt.topk(probabilities, 1)

    predictions = []
    for row, (score, label) in enumerate(zip(top_scores, top_labels)):
        prediction = Prediction(label.item(), score.item())
        if target is not None:
            # NOTE(review): stores the whole target Prediction here, whereas
            # Prediction.__init__ stores only target.classification -- confirm
            prediction.target = target
            prediction.target_confidence = probabilities[
                row, target.classification
            ].item()
        predictions.append(prediction)

    return predictions
63 |
--------------------------------------------------------------------------------
/rex_xai/responsibility/resp_maps.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import numpy as np
3 | from typing import List
4 |
5 | from typing import Optional
6 |
7 | try:
8 | from anytree.cachedsearch import find
9 | except ImportError:
10 | from anytree import find
11 |
12 | from rex_xai.mutants.box import Box
13 | from rex_xai.input.config import CausalArgs
14 | from rex_xai.mutants.mutant import Mutant
15 | from rex_xai.input.input_data import Data
16 | from rex_xai.utils.logger import logger
17 | from rex_xai.utils._utils import ReXMapError
18 |
19 |
class ResponsibilityMaps:
    """Per-class responsibility maps plus a retrieval count for each class."""

    def __init__(self) -> None:
        # classification id -> 2d/3d float32 responsibility map
        self.maps = {}
        # classification id -> number of (incrementing) retrievals of that map
        self.counts = {}

    def __repr__(self) -> str:
        return str(self.counts)

    def get(self, k, increment=False):
        """Return the map for class <k>, or None if absent.

        @param increment: also bump the retrieval count for <k>
        """
        try:
            if increment:
                self.counts[k] += 1  # type: ignore
            return self.maps[k]
        except KeyError:
            return

    def new_map(self, k: int, height, width, depth=None):
        """Create a zeroed map for class <k>; 3d when <depth> is given."""
        if depth is not None:
            self.maps[k] = np.zeros((height, width, depth), dtype="float32")
            self.counts[k] = 1
        else:
            self.maps[k] = np.zeros((height, width), dtype="float32")
            self.counts[k] = 1

    def items(self):
        """Dict-style view over (class, map) pairs."""
        return self.maps.items()

    def keys(self):
        """Dict-style view over the known classes."""
        return self.maps.keys()

    def len(self):
        """Number of per-class maps held."""
        return len(self.maps)

    def merge(self, maps):
        """Accumulate another maps dict into this one, keyed by class.

        NOTE(review): an all-zero map triggers `break`, which abandons every
        *remaining* entry of <maps> (dict-order dependent); if the intent was
        to skip just that entry this should be `continue` -- confirm
        """
        for k, v in maps.items():
            if np.max(v) == 0:
                break
            if k in self.maps:
                self.maps[k] += v
            else:
                self.maps[k] = v

    def negative_responsibility(self, target):
        """Subtract every other class's responsibility from <target>'s map."""
        for k, v in self.maps.items():
            if k != target:
                logger.debug(
                    f"subtracting responsibility for class {k} from class {target}"
                )
                self.maps[target] = self.maps[target] - v  # type: ignore

    def responsibility(self, mutant: Mutant, args: CausalArgs):
        """Responsibility contribution of each box part of <mutant>.

        Each active part gets an equal share 1/len(parts), optionally
        weighted by the mutant's prediction confidence when args.weighted.
        NOTE(review): indexes by the final character of each part name into a
        fixed-size-4 array, so it assumes names end in a digit 0-3 -- confirm
        """
        responsibility = np.zeros(4, dtype=np.float32)
        parts = mutant.get_active_boxes()
        r = 1 / len(parts)
        for p in parts:
            i = np.uint(p[-1])
            if (
                args.weighted
                and mutant.prediction is not None
                and mutant.prediction.confidence is not None
            ):
                responsibility[i] += r * mutant.prediction.confidence
            else:
                responsibility[i] += r
        return responsibility

    def update_maps(
        self, mutants: List[Mutant], args: CausalArgs, data: Data, search_tree
    ):
        """Update the different responsibility maps with all passing mutants
        @params mutants: list of mutants
        @params args: causal args
        @params data: data
        @params search_tree: tree of boxes

        Mutates in place, does not return a value
        """

        for mutant in mutants:
            r = self.responsibility(mutant, args)

            k = None
            # check that there is a prediction value
            if mutant.prediction is not None:
                k = mutant.prediction.classification
            # if there's no prediction value, raise an exception
            if k is None:
                raise ReXMapError("the provided mutant has no known classification")
            # check if k has been seen before and has a map. If k is new, make a new map
            if k not in self.maps:
                self.new_map(k, data.model_height, data.model_width, data.model_depth)

            # get the responsibility map for k
            resp_map = self.get(k, increment=True)
            if resp_map is None:
                raise ValueError(
                    f"unable to open or generate a responsibility map for classification {k}"
                )

            # we only increment responsibility for active boxes, not static boxes
            for box_name in mutant.get_active_boxes():
                box: Optional[Box] = find(search_tree, lambda n: n.name == box_name)
                if box is not None and box.area() > 0:
                    index = np.uint(box_name[-1])
                    local_r = r[index]
                    # print(box.depth)
                    if args.concentrate:
                        # deeper boxes receive proportionally more credit
                        local_r *= box.depth
                    # Don't delete this code just yet as this is an alternative (less brutal)
                    # local_r *= 1.0 / box.area()
                    # scaling strategy that needs further investigation
                    # scale = depth - 1
                    # local_r = 2**(local_r * scale)

                    # <section> is a view into resp_map, so the += below
                    # mutates the stored map in place
                    if data.mode == "spectral":
                        section = resp_map[0, box.col_start : box.col_stop]
                    elif data.mode == "RGB":
                        section = resp_map[
                            box.row_start : box.row_stop,
                            box.col_start : box.col_stop,
                        ]
                    elif data.mode == "voxel":
                        section = resp_map[
                            box.row_start : box.row_stop,
                            box.col_start : box.col_stop,
                            box.depth_start : box.depth_stop,
                        ]
                    else:
                        logger.warning("not yet implemented")
                        raise NotImplementedError

                    section += local_r
                    self.maps[k] = resp_map

    def subset(self, id):
        """Shrink the collection down to just the maps/counts for <id>."""
        m = self.maps.get(id)
        c = self.counts.get(id)
        self.maps = {id: m}
        self.counts = {id: c}
159 |
--------------------------------------------------------------------------------
/rex_xai/rex_wrapper.py:
--------------------------------------------------------------------------------
1 | """main entry point to ReX"""
2 |
3 | from rex_xai.utils._utils import get_device
4 | from rex_xai.input.config import get_all_args
5 | from rex_xai.output.database import initialise_rex_db
6 |
7 | from rex_xai.explanation.rex import explanation
8 | from rex_xai.utils.logger import logger, set_log_level
9 | from rex_xai.input.config import validate_args
10 |
11 |
def main():
    """main entry point to ReX cmdline tool"""
    args = get_all_args()
    validate_args(args)
    set_log_level(args.verbosity, logger)

    # pick mps/cuda/cpu depending on availability and the --gpu flag
    device = get_device(args.gpu)

    logger.debug("running ReX with the following args:\n %s", args)

    # optionally open a sqlite database to record results in
    db = None
    if args.db is not None:
        db = initialise_rex_db(args.db)

    explanation(args, device, db)
27 |
--------------------------------------------------------------------------------
/rex_xai/utils/_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import importlib.metadata
4 | from itertools import chain, combinations
5 | from enum import Enum
6 | from typing import Tuple, Union, Dict
7 | from numpy.typing import NDArray
8 | import torch as tt
9 | import numpy as np
10 | from skimage.segmentation import mark_boundaries
11 | from rex_xai.utils.logger import logger
12 | from rex_xai.mutants.box import Box
13 |
14 | Strategy = Enum("Strategy", ["Global", "Spatial", "MultiSpotlight", "Contrastive"])
15 |
16 | Queue = Enum("Queue", ["Area", "All", "Intersection", "DC"])
17 |
18 | SpatialSearch = Enum("SpatialSearch", ["NotFound", "Found"])
19 |
20 |
def one_d_permute(tensor):
    """Randomly shuffle a 1d tensor; returns (shuffled, permutation)."""
    permutation = tt.randperm(len(tensor))
    shuffled = tensor[permutation]
    return shuffled, permutation
24 |
25 |
def powerset(r, reverse=True):
    """All non-empty subsets of <r>, as tuples, ordered by subset size.

    @param r: a sized iterable of elements
    @param reverse: when True (the default), larger subsets come first

    Returns a list. (Previously the reverse=True branch returned a one-shot
    `reversed` iterator while reverse=False returned a list; a list in both
    branches is re-iterable and backward compatible with iterating callers.)
    """
    ps = list(chain.from_iterable(combinations(r, lim) for lim in range(1, len(r) + 1)))
    if reverse:
        ps.reverse()
    return ps
32 |
33 |
def clause_area(clause, areas: Dict) -> int:
    """Sum the areas of every part named in <clause>."""
    return sum(areas[c] for c in clause)
39 |
40 |
class ReXError(Exception):
    """Base class for all ReX-specific errors."""

    pass


class ReXTomlError(ReXError):
    """Raised for invalid values in a ReX toml config file."""

    def __init__(self, message) -> None:
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        return f"ReXTomlError: {self.message}"


class ReXPathError(ReXError):
    """Raised when a path supplied to ReX does not exist."""

    def __init__(self, message) -> None:
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        return f"ReXPathError: no such file exists at {self.message}"


class ReXScriptError(ReXError):
    """Raised for problems in a user-supplied model script."""

    def __init__(self, message) -> None:
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        return f"ReXScriptError: {self.message}"


class ReXDataError(ReXError):
    """Raised for malformed or unsupported input data."""

    def __init__(self, message) -> None:
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        return f"ReXDataError: {self.message}"


class ReXMapError(ReXError):
    """Raised for problems with responsibility maps."""

    def __init__(self, message) -> None:
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        return f"ReXMapError: {self.message}"
88 |
89 |
90 | def xlogx(ps):
91 | f = np.vectorize(_xlogx)
92 | return f(ps)
93 |
94 |
95 | def _xlogx(p):
96 | if p == 0.0:
97 | return 0.0
98 | else:
99 | return p * np.log2(p)
100 |
101 |
def add_boundaries(
    img: Union[NDArray, tt.Tensor], segs: NDArray, colour=None
) -> NDArray:
    """Draw the segment boundaries of <segs> onto <img>, as a uint8 image."""
    if colour is None:
        marked = mark_boundaries(img, segs, mode="thick")
    else:
        marked = mark_boundaries(img, segs, colour, mode="thick")
    # mark_boundaries yields floats (hence the rescale to the 0-255 range)
    marked = (marked * 255).astype(np.uint8)  # type: ignore
    return marked
112 |
113 |
def get_device(gpu: bool):
    """Choose the torch device to run on.

    When <gpu> is requested, prefers Apple's mps backend, then cuda — but
    only when the backend is actually available; otherwise warns and falls
    back to cpu.
    """
    if gpu:
        if tt.backends.mps.is_available():
            return tt.device("mps")
        # BUG FIX: the original tested `tt.device("cuda")`, which merely
        # constructs an (always truthy) device object and never checks that
        # cuda is usable; use the availability query instead
        if tt.cuda.is_available():
            return tt.device("cuda")
        logger.warning("gpu not available")
    return tt.device("cpu")
122 |
123 |
def get_map_locations(map, reverse=True):
    """Rank every cell of a responsibility map by its value.

    @param map: responsibility map, torch tensor or ndarray
    @param reverse: when True (the default), highest responsibility first
    @return: list of (value, index-tuple) pairs sorted by value

    Values are stored as plain floats: np.nditer yields (reusable) 0-d array
    views, so storing them directly risks aliasing, and float keys also sort
    faster than 0-d arrays.
    """
    if isinstance(map, tt.Tensor):
        map = map.detach().cpu().numpy()
    coords = []
    for i, r in enumerate(np.nditer(map)):
        coords.append((float(r), np.unravel_index(i, map.shape)))
    coords = sorted(coords, reverse=reverse)
    return coords
132 |
133 |
def set_boolean_mask_value(
    tensor,
    mode,
    order,
    coords: Union[Box, Tuple[NDArray, NDArray]],
    val: bool = True,
):
    """Set a region of the boolean mask <tensor> to <val>, in place.

    @param tensor: the boolean mask to mutate
    @param mode: data modality ("RGB", "L", "spectral", "tabular", "voxel")
    @param order: channel order; "first" means channel-first layouts
    @param coords: either a Box, or a tuple of index arrays
    @param val: the value written into the selected region

    Raises ReXError for an unrecognised mode.
    """
    # first normalise <coords> into h/w (and d for voxel) indices or slices
    if isinstance(coords, Box):
        if mode in ("spectral", "tabular"):
            # 1d data: h/w are plain start/stop column indices
            h = coords.col_start
            w = coords.col_stop
        elif mode == "voxel":
            h = slice(coords.row_start, coords.row_stop)
            w = slice(coords.col_start, coords.col_stop)
            d = slice(coords.depth_start, coords.depth_stop)
        else:
            h = slice(coords.row_start, coords.row_stop)
            w = slice(coords.col_start, coords.col_stop)
    else:
        if mode == "voxel":
            h = coords[0]
            w = coords[1]
            d = coords[2]
        else:
            h = coords[0]
            w = coords[1]

    # three channels
    if mode == "RGB":
        # (C, H, W)
        if order == "first":
            tensor[:, h, w] = val
        # (H, W, C)
        else:
            tensor[h, w, :] = val
    elif mode == "L":
        if order == "first":
            # (1, H, W)
            tensor[0, h, w] = val
        else:
            tensor[h, w, :] = val
    elif mode in ("spectral", "tabular"):
        # 1d mask, or a (1, n) mask with a leading singleton dimension
        if len(tensor.shape) == 1:
            tensor[h:w] = val
        else:
            tensor[0, h:w] = val
    # elif mode == "tabular":

    elif mode == "voxel":
        tensor[h, w, d] = val
    else:
        raise ReXError("mode not recognised")
186 |
187 |
def ff(obj, fmt):
    """
    Like format(obj, fmt), but returns the string 'None' if obj is None.
    See the help for format() to see acceptable values for fmt.
    """
    if obj is None:
        return "None"
    return format(obj, fmt)
194 |
195 |
def version():
    """Return the installed version string of the rex-xai package
    (raises importlib.metadata.PackageNotFoundError if not installed)."""
    return importlib.metadata.version("rex-xai")
198 |
--------------------------------------------------------------------------------
/rex_xai/utils/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """simple wrapper around python logging module"""
3 |
4 | import logging
5 | import sys
6 | import os
7 |
8 |
def set_log_level(i: int, rex_logger):
    """sets the logging level for rex: 0 -> CRITICAL, 1 -> WARNING,
    2 -> INFO, any other value -> DEBUG"""
    levels = {0: logging.CRITICAL, 1: logging.WARNING, 2: logging.INFO}
    rex_logger.setLevel(levels.get(i, logging.DEBUG))
19 |
20 |
# silence TensorFlow's C++ logging before any TF-backed library initialises
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# shared logger used throughout the ReX package; defaults to WARNING on stdout
logger = logging.getLogger("ReX")
logging.basicConfig(stream=sys.stdout, level=logging.WARNING)
24 |
--------------------------------------------------------------------------------
/scripts/pytorch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import platform
4 | from torchvision.models import get_model
5 | from torchvision import transforms as T
6 | import torch as tt
7 | import torch.nn.functional as F
8 | from PIL import Image # type: ignore
9 | from rex_xai.input.input_data import Data
10 | from rex_xai.responsibility.prediction import from_pytorch_tensor
11 |
12 |
# Load a pretrained ResNet50 and put it in inference mode.
model = get_model('resnet50', weights="DEFAULT")
model.eval()

# Move the model to an accelerator only when one is actually available;
# the previous unconditional model.to("cuda") crashed on CPU-only hosts.
if platform.uname().system == "Darwin":
    if tt.backends.mps.is_available():
        model.to("mps")
elif tt.cuda.is_available():
    model.to("cuda")
20 |
21 |
def preprocess(path, shape, device, mode) -> Data:
    """Open the image at <path> as RGB, wrap it in a Data object, and attach
    the normalised 224x224 tensor the model consumes."""
    # open the image with mode "RGB"
    img = Image.open(path).convert("RGB")
    # create a Data object
    data = Data(img, shape, device, mode='RGB')
    to_model_input = T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # manually set the data to the transformed image for model consumption
    data.data = to_model_input(img).unsqueeze(0).to(device)  # type: ignore
    return data
38 |
39 |
def prediction_function(mutants, target=None, raw=False, binary_threshold=False):
    """Run the model on a batch of mutants.

    With raw=True return softmax probabilities (used for insertion/deletion
    curves); otherwise return Prediction objects via from_pytorch_tensor.
    """
    # inference is faster without gradient tracking
    with tt.no_grad():
        logits = model(mutants)
        if raw:
            return F.softmax(logits, dim=1)
        # from_pytorch_tensor converts the raw tensor to a Prediction object;
        # substitute your own conversion here if preferred
        return from_pytorch_tensor(logits, target=target)
48 |
49 |
def model_shape():
    """Expected model input shape: batch ("N"), 3 channels, 224x224."""
    batch, channels, height, width = "N", 3, 224, 224
    return [batch, channels, height, width]
52 |
--------------------------------------------------------------------------------
/scripts/spectral_pytorch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import platform
4 | import numpy as np
5 | import torch as tt
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from rex_xai.input.input_data import Data
10 | from rex_xai.responsibility.prediction import from_pytorch_tensor
11 |
class ConvNet(nn.Module):
    """1-D CNN for binary classification of spectral inputs of length 1356."""

    @staticmethod
    def _conv_block(in_ch, out_ch, conv_kernel, pool_kernel, pool_stride):
        # Conv -> BatchNorm -> ReLU -> MaxPool; the module order and
        # Sequential indices match the original checkpoint's state_dict keys.
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=conv_kernel, stride=1, padding=0),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=pool_kernel, stride=pool_stride),
        )

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_block(1, 20, 7, 7, 3)
        self.layer2 = self._conv_block(20, 40, 5, 5, 2)
        self.layer3 = self._conv_block(40, 40, 3, 3, 1)
        self.fc1 = nn.Linear(8640, 60)  # 40 channels * 216 positions after pooling
        self.fc2 = nn.Linear(60, 2)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        # NOTE(review): F.dropout defaults to training=True, so dropout is
        # applied even in eval mode — confirm this is intended before changing.
        x = F.dropout(x, 0.2)
        x = self.fc2(x)
        return x
41 |
# Select the compute device once so model construction and checkpoint
# loading agree: Apple-silicon "mps" on macOS, otherwise "cuda".
if platform.uname().system == "Darwin":
    _device = "mps"
else:
    _device = "cuda"

model = ConvNet().to(_device)
# map_location must match the device in use; it was previously hard-coded
# to 'mps', which breaks checkpoint loading on CUDA (Linux) hosts.
model.load_state_dict(
    tt.load("simple_DNA_model.pt", map_location=_device, weights_only=True)
)
model.eval()
50 |
51 |
# NOTE(review): redundant — the model was already moved to this device when it
# was constructed above; safe to delete once confirmed.
if platform.uname().system == "Darwin":
    model.to("mps")
else:
    model.to("cuda")
56 |
def preprocess(path, shape, device, mode) -> Data:
    """Load a 1-D spectrum from a .npy file and wrap it in a spectral Data object.

    The array is shaped to (1, 1, 1356) and moved to <device>; previously the
    device was hard-coded to 'mps', which fails on non-macOS hosts.
    """
    spectrum = tt.from_numpy(np.load(path)).float().unsqueeze(0).unsqueeze(0).to(device)
    data = Data(spectrum, (1, 1356), device, mode="spectral", process=True)
    data.set_width(1356)
    data.set_height(1)
    return data
63 |
def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
    """Classify a batch of mutants; raw=True returns softmax probabilities,
    otherwise Prediction objects via from_pytorch_tensor."""
    with tt.no_grad():
        logits = model(mutants)
        if raw:
            return F.softmax(logits, dim=1)
        return from_pytorch_tensor(logits, target=target)
70 |
def model_shape():
    """Expected model input shape: batch ("N"), 1 channel, 1356 samples."""
    batch, channels, length = "N", 1, 1356
    return [batch, channels, length]
73 |
74 |
--------------------------------------------------------------------------------
/scripts/timm_script.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import platform
4 | import timm
5 | from PIL import Image
6 | import torch as tt
7 | import torch.nn.functional as F
8 | from rex_xai.input.input_data import Data
9 | from rex_xai.responsibility.prediction import from_pytorch_tensor
10 |
# Load a pretrained timm ResNet50 and put it in inference mode.
model = timm.create_model("resnet50.a1_in1k", pretrained=True)
model.eval()

# Use the documented availability check tt.backends.mps.is_available();
# tt.mps.is_available is not present in all torch versions, and the rest of
# the codebase already uses the backends API.
if platform.uname().system == "Darwin":
    if tt.backends.mps.is_available():
        model.to("mps")
else:
    if tt.cuda.is_available():
        model.to("cuda")
20 |
21 |
def preprocess(path, shape, device, mode) -> Data:
    """Build a Data object for the image at <path>, preprocessed with the
    transform matching the timm model's pretrained configuration."""
    cfg = timm.data.resolve_data_config(model.pretrained_cfg)  # type: ignore
    to_model_input = timm.data.create_transform(**cfg)  # type: ignore
    img = Image.open(path).convert("RGB")
    data = Data(img, shape, device, mode='RGB')
    data.data = to_model_input(img).unsqueeze(0).to(device)
    return data
32 |
33 |
def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
    """Run the timm model on a batch of mutants.

    raw=True returns softmax probabilities (for insertion/deletion curves);
    otherwise Prediction objects are built via from_pytorch_tensor.
    """
    with tt.no_grad():
        tensor = model(mutants)
        if raw:
            return F.softmax(tensor, dim=1)
        # pass target through, matching the other example scripts; previously
        # the target parameter was accepted but silently ignored
        return from_pytorch_tensor(tensor, target=target)
40 |
41 |
def model_shape():
    """Return ["N", C, H, W] using the input size from the model's
    pretrained timm configuration."""
    input_size = timm.data.resolve_data_config(model.pretrained_cfg)['input_size']
    return ["N", *input_size]
47 |
--------------------------------------------------------------------------------
/shell/_rex:
--------------------------------------------------------------------------------
1 | #compdef ReX
2 |
_rex() {
    local state

    # Fixes: the last spec previously ended in a trailing backslash, so the
    # closing `}` was passed to _arguments as an argument; also corrects the
    # "sufficent" typo in the --contrastive description.
    _arguments \
        '1: :_files' \
        '--help[show help]' \
        '--surface[Plot a 3D responsibility map, optionally save to ]' \
        '--heatmap[Plot a 2D responsibility map, optionally save to ]' \
        '--script[PyTorch compatible python script]:::_files' \
        '--model[onnx model to use]:::_files' \
        '--no_extract[Do not extract an explanation from the responsibility map]' \
        '--config[optional config file to use with ReX]' \
        '--output[save explanation to ]:::_files' \
        {-v,--verbosity}'[set verbosity level, one of -v, -vv, or -vvv]' \
        {-q,--quiet}'[set verbosity level to 0 (errors only)]' \
        '--strategy[explanation extraction strategy, by default ]:::(global spatial)' \
        {-db,--database}'[store output in sqlite database , creating it if necessary]:::_files' \
        '--multi[multiple explanations, with optional spotlights]' \
        '--contrastive[sufficient and necessary explanations, with optional spotlights]' \
        '--iters[number of iterations of the main algorithm to perform]' \
        '--analyse[perform some basic analysis of the explanation, optionally saved to ]' \
        '--mode[assist ReX with input type, one of tabular | spectral | RGB | voxel]:::(tabular spectral RGB voxel)' \
        '--spectral[shortcut for --mode spectral]' \
        '--version[print version number and quit]'
}
28 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as tt
3 | import pytest
4 | from cached_path import cached_path
5 | from rex_xai.utils._utils import get_device
6 | from rex_xai.mutants.box import initialise_tree
7 | from rex_xai.input.config import CausalArgs, process_custom_script, Strategy
8 | from rex_xai.mutants.distributions import Distribution
9 | from rex_xai.explanation.rex import (
10 | calculate_responsibility,
11 | get_prediction_func_from_args,
12 | load_and_preprocess_data,
13 | predict_target,
14 | try_preprocess,
15 | )
16 | from rex_xai.explanation.explanation import Explanation
17 | from rex_xai.explanation.multi_explanation import MultiExplanation
18 | from syrupy.extensions.amber.serializer import AmberDataSerializer
19 | from syrupy.filters import props
20 | from syrupy.matchers import path_type
21 |
22 | from rex_xai.input.input_data import Data
23 |
24 |
@pytest.fixture
def snapshot_explanation(snapshot):
    """Snapshot fixture tuned for Explanation objects: excludes fields that
    differ between runs/systems or are too large to store."""
    excluded = props(
        "obj_function",  # pointer to function that will differ between runs
        "spotlight_objective_function",  # pointer to function that will differ between runs
        "script",  # path that differs between systems
        "script_location",  # path that differs between systems
        "model",
        "target_map",  # large array
        "final_mask",  # large array
        "explanation",  # large array
    )
    # serialise CausalArgs as a named tuple so `exclude` works on custom classes
    causal_args_matcher = path_type(
        types=(CausalArgs,),
        replacer=lambda data, _: AmberDataSerializer.object_as_named_tuple(data),  # type: ignore
    )
    return snapshot.with_defaults(exclude=excluded, matcher=causal_args_matcher)
45 |
46 |
@pytest.fixture(scope="session")
def resnet50():
    """Local path to the ONNX ResNet50 model (downloaded once, then cached)."""
    return cached_path(
        "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/resnet/model/resnet50-v1-7.onnx"
    )
53 |
54 |
@pytest.fixture(scope="session")
def DNA_model():
    """Local path to the ONNX spectral DNA model (downloaded once, then cached)."""
    return cached_path(
        "https://github.com/ReX-XAI/models/raw/6f66a5c0e1480411436be828ee8312e72f0035e1/spectral/simple_DNA_model.onnx"
    )
61 |
62 |
@pytest.fixture
def args():
    """Baseline CausalArgs for tests: dog image, few iterations, CPU only."""
    causal_args = CausalArgs()
    causal_args.path = "tests/test_data/dog.jpg"
    causal_args.iters = 2
    causal_args.search_limit = 1000
    causal_args.gpu = False
    return causal_args
72 |
73 |
@pytest.fixture
def args_custom(args):
    """Base args configured with the custom pytorch_resnet50 script and a fixed seed."""
    process_custom_script("tests/scripts/pytorch_resnet50.py", args)
    args.seed = 42

    return args
80 |
81 |
@pytest.fixture
def args_torch_swin_v2_t(args):
    """Base args configured with the swin_v2_t script and a fixed seed."""
    process_custom_script("tests/scripts/pytorch_swin_v2_t.py", args)
    args.seed = 42

    return args
88 |
89 |
@pytest.fixture
def args_onnx(args, resnet50):
    """Base args pointed at the cached ONNX ResNet50 model, fixed seed."""
    args.model = resnet50
    args.seed = 100

    return args
96 |
97 |
@pytest.fixture
def args_multi(args_custom):
    """Custom args tuned for multi-explanation: peacock image, 5 iterations,
    MultiSpotlight strategy with 5 spotlights."""
    args = args_custom
    args.path = "tests/test_data/peacock.jpg"
    args.iters = 5
    args.strategy = Strategy.MultiSpotlight
    args.spotlights = 5

    return args
107 |
108 |
@pytest.fixture
def model_shape(args_custom):
    """Model input shape declared by the custom resnet50 script."""
    _, shape = get_prediction_func_from_args(args_custom)
    return shape
114 |
115 |
@pytest.fixture
def prediction_func(args_custom):
    """Prediction function built from the custom resnet50 script."""
    predict, _ = get_prediction_func_from_args(args_custom)
    return predict
121 |
122 |
@pytest.fixture
def model_shape_swin_v2_t(args_torch_swin_v2_t):
    """Model input shape declared by the swin_v2_t script."""
    _, shape = get_prediction_func_from_args(args_torch_swin_v2_t)
    return shape
128 |
129 |
@pytest.fixture
def prediction_func_swin_v2_t(args_torch_swin_v2_t):
    """Prediction function built from the swin_v2_t script."""
    predict, _ = get_prediction_func_from_args(args_torch_swin_v2_t)
    return predict
135 |
136 |
@pytest.fixture
def data(args_custom, model_shape, cpu_device):
    """Raw preprocessed Data for the custom script (no mask value set)."""
    return try_preprocess(args_custom, model_shape, device=cpu_device)
141 |
142 |
@pytest.fixture
def data_custom(args_custom, model_shape, cpu_device):
    """Preprocessed Data for the custom script with the mask value applied."""
    processed = load_and_preprocess_data(model_shape, cpu_device, args_custom)
    processed.set_mask_value(args_custom.mask_value)
    return processed
148 |
149 |
@pytest.fixture
def data_multi(args_multi, model_shape, prediction_func, cpu_device):
    """Preprocessed Data for the multi-explanation run, target pre-predicted."""
    multi_data = load_and_preprocess_data(model_shape, cpu_device, args_multi)
    multi_data.set_mask_value(args_multi.mask_value)
    multi_data.target = predict_target(multi_data, prediction_func)
    return multi_data
156 |
157 |
@pytest.fixture(scope="session")
def cpu_device():
    """Session-wide CPU torch device."""
    return get_device(gpu=False)
163 |
164 |
@pytest.fixture
def exp_custom(data_custom, args_custom, prediction_func):
    """Explanation built from the custom resnet50 fixtures (not yet extracted)."""
    data_custom.target = predict_target(data_custom, prediction_func)
    resp_maps, stats = calculate_responsibility(
        data_custom, args_custom, prediction_func
    )
    return Explanation(resp_maps, prediction_func, data_custom, args_custom, stats)
174 |
175 |
@pytest.fixture
def exp_onnx(args_onnx, cpu_device):
    """Explanation built end-to-end from the ONNX ResNet50 model."""
    predict, shape = get_prediction_func_from_args(args_onnx)
    onnx_data = load_and_preprocess_data(shape, cpu_device, args_onnx)
    onnx_data.set_mask_value(args_onnx.mask_value)
    onnx_data.target = predict_target(onnx_data, predict)
    resp_maps, stats = calculate_responsibility(onnx_data, args_onnx, predict)
    return Explanation(resp_maps, predict, onnx_data, args_onnx, stats)
186 |
187 |
@pytest.fixture
def exp_extracted(exp_custom):
    """exp_custom after running extraction with the Global strategy."""
    exp_custom.extract(Strategy.Global)

    return exp_custom
193 |
194 |
@pytest.fixture
def exp_multi(args_multi, data_multi, prediction_func):
    """Extracted MultiExplanation for the peacock multi-spotlight run."""
    resp_maps, stats = calculate_responsibility(data_multi, args_multi, prediction_func)
    multi = MultiExplanation(resp_maps, prediction_func, data_multi, args_multi, stats)
    multi.extract(args_multi.strategy)
    return multi
203 |
204 |
@pytest.fixture
def data_3d():
    """Synthetic voxel Data: a (1, 64, 64, 64) float32 volume with a block of 1s."""
    voxel = np.zeros((1, 64, 64, 64), dtype=np.float32)
    # NOTE(review): this indexes axes (0, 1, 2), so the 0:30 range is clipped to
    # the single channel and axis 3 is set in full; possibly intended
    # voxel[0, 0:30, 20:30, 20:35] — confirm before changing (snapshots depend on it).
    voxel[0:30, 20:30, 20:35] = 1
    return Data(
        input=voxel,
        model_shape=[1, 64, 64, 64],
        device="cpu",
        mode="voxel"
    )
215 |
@pytest.fixture
def data_2d():
    """2D Data fixture.

    NOTE(review): np.arange(1, 64, 64) is arange(start=1, stop=64, step=64),
    i.e. the single-element array [1]; a 64x64 array may have been intended —
    confirm against consumers before changing.
    """
    return Data(
        input=np.arange(1, 64, 64),
        model_shape=[1, 64, 64],
        device="cpu"
    )
223 |
@pytest.fixture
def box_3d():
    """Root Box tree spanning a full 64x64x64 volume, uniform distribution."""
    return initialise_tree(
        r_lim=64,
        c_lim=64,
        d_lim=64,
        r_start=0,
        c_start=0,
        d_start=0,
        distribution=Distribution.Uniform,
        distribution_args=None,
    )
236 |
@pytest.fixture
def box_2d():
    """Root Box tree spanning a full 64x64 plane, uniform distribution."""
    return initialise_tree(
        r_lim=64,
        c_lim=64,
        r_start=0,
        c_start=0,
        distribution=Distribution.Uniform,
        distribution_args=None,
    )
247 |
@pytest.fixture
def resp_map_2d():
    """Empty 64x64 float32 responsibility map (numpy)."""
    return np.zeros((64, 64), dtype=np.float32)
251 |
@pytest.fixture
def resp_map_3d():
    """64x64x64 float32 torch responsibility map with a small block of 1s."""
    volume = tt.zeros((64, 64, 64), dtype=tt.float32)
    volume[0:10, 20:25, 20:35] = 1
    return volume
--------------------------------------------------------------------------------
/tests/scripts/pytorch_3d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs.Tensor import Tensor
3 | import torch as tt
4 | import torch.nn.functional as F
5 | from rex_xai.input.input_data import Data
6 | from rex_xai.responsibility.prediction import from_pytorch_tensor
7 | from monai.networks.nets import DenseNet121
8 | from monai.transforms import Resize, LoadImage
9 |
10 | # Load the sample 3D model
11 | device = tt.device("cuda" if tt.cuda.is_available() else "cpu")
12 | model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1).to(device)
13 | model.to(device)
14 | model.eval()
15 |
16 |
17 | def preprocess(path, shape, device, mode) -> Data:
18 | """
19 | The preprocessing function is executed before the
20 | model is called. It is used to prepare the input.
21 |
22 | Args:
23 | path: str
24 | The path to the input data
25 | shape: tuple
26 | The shape of the input data
27 | device: str
28 | The device to use indicate where the data should be loaded: "cpu" or "cuda"
29 | mode: str
30 | The mode of the input data: "voxel" for 3D data
31 |
32 | Returns:
33 | Data: The input data object
34 | The input data object that contains the processed data
35 | and the metadata of the input data including the mode, model_height,
36 | model_width, model_depth, model_channels, model_order, and transposed.
37 | """
38 | transform = Resize(spatial_size=(64, 64, 64))
39 | if path is str:
40 | volume = LoadImage()(path)
41 | transformed_volume = transform(volume)
42 | data = Data(transformed_volume, shape, device, mode=mode, process=False)
43 | elif path is Tensor:
44 | transformed_volume = transform(path)
45 | data = Data(transformed_volume, shape, device, mode=mode, process=False)
46 | else:
47 | raise ValueError("Invalid input type")
48 | data.mode = "voxel"
49 | data.model_shape = shape
50 | data.model_height = 64
51 | data.model_width = 64
52 | data.model_depth = 64
53 | data.transposed = True
54 |
55 | return data
56 |
57 |
58 | def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
59 | """
60 | The prediction function calls the model itself and returns the output.
61 |
62 | Args:
63 | mutants: Data
64 | The input data object
65 | target: None
66 | The target label
67 | raw: bool
68 | A flag to indicate if the output should be raw or not
69 | binary_threshold: None
70 | The binary threshold value
71 |
72 | Returns:
73 | list[Prediction]
74 | A list of prediction objects which each contain the output tensor,
75 | the target label, the confidence of the label, the classification confidence,
76 | and the classification label.
77 |
78 | """
79 | with tt.no_grad():
80 | tensor = model(mutants)
81 | if raw:
82 | return F.softmax(tensor, dim=1)
83 | return from_pytorch_tensor(tensor)
84 |
85 |
86 | def model_shape():
87 | return ["N", 64, 64, 64]
88 |
--------------------------------------------------------------------------------
/tests/scripts/pytorch_resnet50.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from torchvision.models import resnet50
4 | from torchvision import transforms as T
5 | import torch as tt
6 | import torch.nn.functional as F
7 | from PIL import Image # type: ignore
8 | from rex_xai.input.input_data import Data
9 | from rex_xai.responsibility.prediction import from_pytorch_tensor
10 |
11 |
12 | model = resnet50(weights="ResNet50_Weights.DEFAULT")
13 | model.eval()
14 |
15 | def preprocess(path, shape, device, mode) -> Data:
16 | transform = T.Compose(
17 | [
18 | T.Resize((224, 224)),
19 | T.ToTensor(),
20 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21 | ]
22 | )
23 | # open the image with mode "RGB"
24 | img = Image.open(path).convert("RGB")
25 | # create a Data object
26 | data = Data(img, shape, device, mode='RGB')
27 | # manually set the data to the transformed image for model consumption
28 | data.data = transform(img).unsqueeze(0).to(device) # type: ignore
29 |
30 | return data
31 |
32 |
33 | def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
34 | with tt.no_grad():
35 | tensor = model(mutants)
36 | if raw:
37 | return F.softmax(tensor, dim=1)
38 | # from_pytorch_tensor consumes a tensor and converts it to a Prediction object
39 | # you can alternatively use your own function here
40 | return from_pytorch_tensor(tensor, target=target)
41 |
42 |
43 | def model_shape():
44 | return ["N", 3, 224, 224]
45 |
--------------------------------------------------------------------------------
/tests/scripts/pytorch_swin_v2_t.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from torchvision.models import swin_v2_t
4 | from torchvision import transforms as T
5 | import torch as tt
6 | import torch.nn.functional as F
7 | from PIL import Image # type: ignore
8 | from rex_xai.input.input_data import Data
9 | from rex_xai.responsibility.prediction import from_pytorch_tensor
10 |
11 |
12 | model = swin_v2_t(weights="DEFAULT")
13 | model.eval()
14 | model.to("cpu")
15 |
16 |
17 | def preprocess(path, shape, device, mode) -> Data:
18 | transform = T.Compose(
19 | [
20 | T.Resize((260, 260), T.InterpolationMode.BICUBIC),
21 | T.CenterCrop(256),
22 | T.ToTensor(),
23 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
24 | ]
25 | )
26 | img = Image.open(path).convert("RGB")
27 | data = Data(img, shape, device, mode=mode, process=False)
28 | data.data = transform(img).unsqueeze(0).to(device) # type: ignore
29 | data.mode = "RGB"
30 | data.model_shape = shape
31 | data.model_height = 256
32 | data.model_width = 256
33 | data.model_channels = 3
34 | data.transposed = True
35 | data.model_order = "first"
36 |
37 | return data
38 |
39 |
40 | def prediction_function(mutants, target=None, raw=False, binary_threshold=None):
41 | with tt.no_grad():
42 | tensor = model(mutants)
43 | if raw:
44 | return F.softmax(tensor, dim=1)
45 | return from_pytorch_tensor(tensor)
46 |
47 |
48 | def model_shape():
49 | return ["N", 3, 256, 256]
50 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/__snapshots__/_explanation_onnx_test.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test__explanation_snapshot
3 | Explanation:
4 | CausalArgs:
5 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
6 | Target:FOUND_CLASS: 209, FOUND_CONF: 0.39424, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
7 | prediction function: . >
8 | ResponsibilityMaps: {209: 1}
9 | run statistics: {'total_passing': 41, 'total_failing': 169, 'max_depth_reached': 4, 'avg_box_size': 2747.41667} (5 dp)
10 | explanation: of shape torch.Size([3, 224, 224])
11 | final mask: of shape (3, 224, 224)
12 | explanation confidence: 0.12158 (5 dp)
13 | # ---
14 | # name: test__explanation_snapshot.1
15 | 7409862900317702363
16 | # ---
17 | # name: test_extract_analyze[Strategy.Global]
18 | 976237796511694519
19 | # ---
20 | # name: test_extract_analyze[Strategy.Global].1
21 | dict({
22 | 'area': 0.1654,
23 | 'deletion_curve': 0.0259,
24 | 'entropy': 5.3764,
25 | 'insertion_curve': 0.503,
26 | })
27 | # ---
28 | # name: test_extract_analyze[Strategy.Spatial]
29 | 976237796511694519
30 | # ---
31 | # name: test_extract_analyze[Strategy.Spatial].1
32 | dict({
33 | 'area': 0.1654,
34 | 'deletion_curve': 0.0259,
35 | 'entropy': 5.3764,
36 | 'insertion_curve': 0.503,
37 | })
38 | # ---
39 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/__snapshots__/_explanation_test.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test__explanation_snapshot[1]
3 | Explanation:
4 | CausalArgs:
5 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
6 | Target:FOUND_CLASS: 207, FOUND_CONF: 0.20462, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
7 | prediction function:
8 | ResponsibilityMaps: {207: 1}
9 | run statistics: {'total_passing': 84, 'total_failing': 126, 'max_depth_reached': 5, 'avg_box_size': 183.5625} (5 dp)
10 | explanation: of shape torch.Size([3, 224, 224])
11 | final mask: of shape (3, 224, 224)
12 | explanation confidence: 0.01513 (5 dp)
13 | # ---
14 | # name: test__explanation_snapshot[1].1
15 | 2537636458340732753
16 | # ---
17 | # name: test__explanation_snapshot[64]
18 | Explanation:
19 | CausalArgs:
20 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
21 | Target:FOUND_CLASS: 207, FOUND_CONF: 0.20462, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
22 | prediction function:
23 | ResponsibilityMaps: {207: 1}
24 | run statistics: {'total_passing': 84, 'total_failing': 126, 'max_depth_reached': 5, 'avg_box_size': 183.5625} (5 dp)
25 | explanation: of shape torch.Size([3, 224, 224])
26 | final mask: of shape (3, 224, 224)
27 | explanation confidence: 0.01513 (5 dp)
28 | # ---
29 | # name: test__explanation_snapshot[64].1
30 | 2537636458340732753
31 | # ---
32 | # name: test__explanation_snapshot_diff_model_shape[1]
33 | Explanation:
34 | CausalArgs:
35 | Data: Data: RGB, ['N', 3, 256, 256], 256, 256, 3, first
36 | Target:FOUND_CLASS: 215, FOUND_CONF: 0.48740, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
37 | prediction function:
38 | ResponsibilityMaps: {215: 1}
39 | run statistics: {'total_passing': 60, 'total_failing': 136, 'max_depth_reached': 5, 'avg_box_size': 262.0} (5 dp)
40 | explanation: of shape torch.Size([3, 256, 256])
41 | final mask: of shape (3, 256, 256)
42 | explanation confidence: 0.03730 (5 dp)
43 | # ---
44 | # name: test__explanation_snapshot_diff_model_shape[1].1
45 | -5471936325618413956
46 | # ---
47 | # name: test__explanation_snapshot_diff_model_shape[64]
48 | Explanation:
49 | CausalArgs:
50 | Data: Data: RGB, ['N', 3, 256, 256], 256, 256, 3, first
51 | Target:FOUND_CLASS: 215, FOUND_CONF: 0.48740, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
52 | prediction function:
53 | ResponsibilityMaps: {215: 1}
54 | run statistics: {'total_passing': 60, 'total_failing': 136, 'max_depth_reached': 5, 'avg_box_size': 262.0} (5 dp)
55 | explanation: of shape torch.Size([3, 256, 256])
56 | final mask: of shape (3, 256, 256)
57 | explanation confidence: 0.03730 (5 dp)
58 | # ---
59 | # name: test__explanation_snapshot_diff_model_shape[64].1
60 | -5471936325618413956
61 | # ---
62 | # name: test_extract_analyze[Strategy.Global]
63 | -2185390445883109608
64 | # ---
65 | # name: test_extract_analyze[Strategy.Global].1
66 | dict({
67 | 'area': 0.0112,
68 | 'deletion_curve': 1.0199,
69 | 'entropy': 7.1353,
70 | 'insertion_curve': 1.2399,
71 | })
72 | # ---
73 | # name: test_extract_analyze[Strategy.Spatial]
74 | -2185390445883109608
75 | # ---
76 | # name: test_extract_analyze[Strategy.Spatial].1
77 | dict({
78 | 'area': 0.0112,
79 | 'deletion_curve': 1.0199,
80 | 'entropy': 7.1353,
81 | 'insertion_curve': 1.2399,
82 | })
83 | # ---
84 | # name: test_multiexplanation_snapshot
85 | MultiExplanation:
86 | CausalArgs:
87 | Data: Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
88 | Target:FOUND_CLASS: 84, FOUND_CONF: 0.51611, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
89 | prediction function:
90 | ResponsibilityMaps: {84: 1}
91 | run statistics: {'total_passing': 217, 'total_failing': 329, 'max_depth_reached': 6, 'avg_box_size': 87.1} (5 dp)
92 | explanations: 5 explanations of and shape torch.Size([3, 224, 224])
93 | explanation confidences: [0.00628, 0.04587, 0.00387, 0.00446, 0.07882] (5 dp)
94 | # ---
95 | # name: test_multiexplanation_snapshot.1
96 | list([
97 | tuple(
98 | 1,
99 | 2,
100 | 3,
101 | 4,
102 | ),
103 | tuple(
104 | 0,
105 | 1,
106 | 3,
107 | 4,
108 | ),
109 | ])
110 | # ---
111 | # name: test_multiexplanation_snapshot.2
112 | 7802689268791639071
113 | # ---
114 | # name: test_multiexplanation_snapshot.3
115 | 6035056368165721418
116 | # ---
117 | # name: test_multiexplanation_snapshot.4
118 | -5651945023692276586
119 | # ---
120 | # name: test_multiexplanation_snapshot.5
121 | 2488684738058746373
122 | # ---
123 | # name: test_multiexplanation_snapshot.6
124 | 3892690372276336173
125 | # ---
126 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/__snapshots__/explanation_test.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args1]
3 | 3477467290869834009
4 | # ---
5 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args2]
6 | -8754720476463405946
7 | # ---
8 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args3]
9 | -200837265459686415
10 | # ---
11 | # name: test_calculate_responsibility[Distribution.BetaBinomial-dist_args4]
12 | -6449526703691359130
13 | # ---
14 | # name: test_calculate_responsibility[Distribution.Uniform-dist_args0]
15 | -1720903338186101365
16 | # ---
17 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/__snapshots__/load_preprocess_test.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test_preprocess
3 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
4 | # ---
5 | # name: test_preprocess_custom
6 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
7 | # ---
8 | # name: test_preprocess_spectral_mask_on_image_returns_warning
9 | Data: RGB, ['N', 3, 224, 224], 224, 224, 3, first
10 | # ---
11 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/__snapshots__/spectral_test.ambr:
--------------------------------------------------------------------------------
1 | # serializer version: 1
2 | # name: test__explanation_snapshot[1]
3 | Explanation:
4 | CausalArgs:
5 | Data: Data: spectral, ['batch_size', 1, 1356], 1, 1356, 1, None
6 | Target:FOUND_CLASS: 1, FOUND_CONF: 0.99817, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
7 | prediction function: . >
8 | ResponsibilityMaps: {1: 1}
9 | run statistics: {'total_passing': 56, 'total_failing': 0, 'max_depth_reached': 2, 'avg_box_size': 22.875} (5 dp)
10 | explanation: of shape torch.Size([1, 1356])
11 | final mask: of shape (1, 1356)
12 | explanation confidence: 0.99817 (5 dp)
13 | # ---
14 | # name: test__explanation_snapshot[1].1
15 | 3711909154719731854
16 | # ---
17 | # name: test__explanation_snapshot[64]
18 | Explanation:
19 | CausalArgs:
20 | Data: Data: spectral, ['batch_size', 1, 1356], 1, 1356, 1, None
21 | Target:FOUND_CLASS: 1, FOUND_CONF: 0.99817, TARGET_CLASS: n/a, TARGET_CONFIDENCE: n/a
22 | prediction function: . >
23 | ResponsibilityMaps: {1: 1}
24 | run statistics: {'total_passing': 56, 'total_failing': 0, 'max_depth_reached': 2, 'avg_box_size': 22.875} (5 dp)
25 | explanation: of shape torch.Size([1, 1356])
26 | final mask: of shape (1, 1356)
27 | explanation confidence: 0.99817 (5 dp)
28 | # ---
29 | # name: test__explanation_snapshot[64].1
30 | 3711909154719731854
31 | # ---
32 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/_explanation_onnx_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.input.config import Strategy
3 | from rex_xai.explanation.rex import _explanation, analyze, get_prediction_func_from_args
4 |
5 |
def test__explanation_snapshot(args_onnx, cpu_device, snapshot_explanation):
    """End-to-end ONNX explanation matches the stored snapshot (object and pixels)."""
    predict, shape = get_prediction_func_from_args(args_onnx)
    explanation = _explanation(args_onnx, shape, predict, cpu_device, db=None)

    assert explanation == snapshot_explanation
    assert hash(tuple(explanation.explanation.reshape(-1).tolist())) == snapshot_explanation
12 |
13 |
@pytest.mark.parametrize("strategy", [Strategy.Global, Strategy.Spatial])
def test_extract_analyze(exp_onnx, strategy, snapshot):
    """Extract a final mask with each extraction strategy and snapshot both the
    mask and the analysis metrics.

    Snapshot entries match in assertion order, so the mask hash is always the
    first entry and the rounded metrics dict the second.
    """
    exp_onnx.extract(strategy)
    results = analyze(exp_onnx, "RGB")
    # Round to 4 dp for snapshot stability; drop metrics analyze() reports as None.
    results_rounded = {k: round(v, 4) for k, v in results.items() if v is not None}

    assert hash(tuple(exp_onnx.final_mask.reshape(-1).tolist())) == snapshot
    assert results_rounded == snapshot
--------------------------------------------------------------------------------
/tests/snapshot_tests/_explanation_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.explanation.rex import _explanation, analyze
3 | from rex_xai.utils._utils import Strategy
4 |
@pytest.mark.parametrize("batch_size", [1, 64])
def test__explanation_snapshot(
    args_custom, model_shape, prediction_func, cpu_device, batch_size, snapshot_explanation
):
    """The explanation result must be identical regardless of batch size.

    Both parametrisations are checked against snapshots; the snapshot fixture
    matches entries in assertion order.
    """
    args_custom.batch_size = batch_size
    exp = _explanation(args_custom, model_shape, prediction_func, cpu_device, db=None)

    # Serialised Explanation first, then a hash of the flattened tensor.
    assert exp == snapshot_explanation
    assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation
14 |
15 |
@pytest.mark.parametrize("batch_size", [1, 64])
def test__explanation_snapshot_diff_model_shape(
    args_torch_swin_v2_t,
    model_shape_swin_v2_t,
    prediction_func_swin_v2_t,
    cpu_device,
    batch_size,
    snapshot_explanation
):
    """Snapshot test using a model (swin_v2_t) with a different input shape,
    to cover the non-resnet preprocessing path at both batch sizes.
    """
    args_torch_swin_v2_t.batch_size = batch_size

    exp = _explanation(
        args_torch_swin_v2_t,
        model_shape_swin_v2_t,
        prediction_func_swin_v2_t,
        cpu_device,
        db=None,
    )

    # Snapshot entries match in assertion order: serialised Explanation first,
    # then a hash of the flattened explanation tensor.
    assert exp == snapshot_explanation
    assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation
37 |
38 |
@pytest.mark.parametrize("strategy", [Strategy.Global, Strategy.Spatial])
def test_extract_analyze(exp_custom, strategy, snapshot):
    """Extract a final mask with each strategy and snapshot the mask hash and
    the (rounded) analysis metrics; snapshot entries match in assertion order.
    """
    exp_custom.extract(strategy)
    results = analyze(exp_custom, "RGB")
    # Round to 4 dp for snapshot stability; drop metrics analyze() reports as None.
    results_rounded = {k: round(v, 4) for k, v in results.items() if v is not None}

    assert hash(tuple(exp_custom.final_mask.reshape(-1).tolist())) == snapshot
    assert results_rounded == snapshot
47 |
48 |
def test_multiexplanation_snapshot(
    args_multi, model_shape, prediction_func, cpu_device, snapshot_explanation
):
    """Snapshot test for the multi-explanation path.

    separate_by(0.0) splits the explanation into clauses with no permitted
    overlap threshold; each individual explanation tensor is then hashed and
    snapshotted. Snapshot entries match in assertion order.
    """
    exp = _explanation(args_multi, model_shape, prediction_func, cpu_device, db=None)
    clauses = exp.separate_by(0.0)

    assert exp == snapshot_explanation
    assert clauses == snapshot_explanation
    for explanation in exp.explanations:
        assert hash(tuple(explanation.reshape(-1).tolist())) == snapshot_explanation
59 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/explanation_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.mutants.distributions import Distribution
3 | from rex_xai.explanation.rex import calculate_responsibility, predict_target
4 |
5 |
6 | @pytest.mark.parametrize(
7 | "distribution,dist_args",
8 | [
9 | (Distribution.Uniform, []),
10 | (Distribution.BetaBinomial, [1, 1]),
11 | (Distribution.BetaBinomial, [0.5, 0.5]),
12 | (Distribution.BetaBinomial, [1, 0.5]),
13 | (Distribution.BetaBinomial, [0.5, 1]),
14 | ],
15 | )
16 | def test_calculate_responsibility(
17 | data_custom, args_custom, prediction_func, distribution, dist_args, snapshot
18 | ):
19 | args_custom.distribution = distribution
20 | if dist_args:
21 | args_custom.distribution_args = dist_args
22 | data_custom.target = predict_target(data_custom, prediction_func)
23 | maps, _ = calculate_responsibility(data_custom, args_custom, prediction_func)
24 | target_map = maps.get(data_custom.target.classification)
25 |
26 | assert hash(tuple(target_map.reshape(-1))) == snapshot
27 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/load_preprocess_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.utils._utils import ReXDataError
3 | from rex_xai.explanation.rex import (
4 | load_and_preprocess_data,
5 | try_preprocess,
6 | get_prediction_func_from_args
7 | )
8 |
9 |
def test_preprocess(args, model_shape, cpu_device, snapshot):
    """Loading and preprocessing the default test input matches the stored snapshot."""
    data = load_and_preprocess_data(model_shape, cpu_device, args)
    assert data == snapshot
13 |
14 |
def test_preprocess_custom(args_custom, model_shape, cpu_device, snapshot):
    """Preprocessing with a custom (script-based) args config matches its snapshot."""
    data = load_and_preprocess_data(model_shape, cpu_device, args_custom)
    assert data == snapshot
18 |
19 |
def test_preprocess_spectral_mask_on_image_returns_warning(
    args, model_shape, cpu_device, snapshot, caplog
):
    """A 'spectral' mask value on image data is rejected: a warning is logged
    and the mask value falls back to 0, after which preprocessing proceeds
    normally (checked against the snapshot).
    """
    args.mask_value = "spectral"
    data = try_preprocess(args, model_shape, device=cpu_device)

    # try_preprocess mutates args in place when it rejects the mask value.
    assert args.mask_value == 0
    assert (
        caplog.records[0].msg
        == "spectral is not suitable for images. Changing mask_value to 0"
    )
    assert data == snapshot
32 |
33 |
def test_preprocess_npy(args, DNA_model, cpu_device, caplog):
    """try_preprocess on a .npy input: warns in the default mode, raises in 'tabular'.

    With no explicit mode the generic handler logs that it cannot deal with the
    datatype; switching the mode to 'tabular' makes it raise ReXDataError
    instead (the file's shape does not fit the DNA model).

    Fix: dropped the `snapshot` fixture parameter, which was requested but
    never used by this test.
    """
    args.path = "tests/test_data/DoublePeakClass 0 Mean.npy"
    args.model = DNA_model
    _, model_shape = get_prediction_func_from_args(args)

    try_preprocess(args, model_shape, device=cpu_device)

    assert caplog.records[0].msg == "we do not generically handle this datatype"

    args.mode = "tabular"
    with pytest.raises(ReXDataError):
        try_preprocess(args, model_shape, device=cpu_device)
46 |
def test_preprocess_incompatible_shapes(args, model_shape, cpu_device):
    """A tabular .npy input whose shape is incompatible with the (image) model
    shape raises ReXDataError.

    Fix: dropped the `caplog` fixture parameter, which was requested but never
    used by this test.
    """
    args.path = "tests/test_data/DoublePeakClass 0 Mean.npy"
    args.mode = "tabular"

    with pytest.raises(ReXDataError):
        try_preprocess(args, model_shape, device=cpu_device)
53 |
54 |
--------------------------------------------------------------------------------
/tests/snapshot_tests/spectral_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from cached_path import cached_path
3 | from rex_xai.input.config import Strategy
4 | from rex_xai.explanation.rex import _explanation, get_prediction_func_from_args
5 |
6 |
@pytest.fixture(scope="session")
def DNA_model():
    """Session-scoped path to the pinned simple_DNA_model.onnx, downloaded
    (and cached) once per test session via cached_path."""
    return cached_path(
        "https://github.com/ReX-XAI/models/raw/6f66a5c0e1480411436be828ee8312e72f0035e1/spectral/simple_DNA_model.onnx"
    )
13 |
14 |
@pytest.fixture
def args_spectral(args, DNA_model):
    """CausalArgs preconfigured for the spectral DNA example: spectral mode and
    masking, a fixed seed for reproducibility, and the Global strategy."""
    overrides = {
        "model": DNA_model,
        "path": "tests/test_data/spectrum_class_DNA.npy",
        "mode": "spectral",
        "mask_value": "spectral",
        "seed": 15,
        "strategy": Strategy.Global,
    }
    for attr, value in overrides.items():
        setattr(args, attr, value)

    return args
25 |
26 |
@pytest.mark.parametrize(
    "batch_size", [1, 64]
)
def test__explanation_snapshot(args_spectral, cpu_device, batch_size, snapshot_explanation):
    """End-to-end snapshot test for spectral data: the result must be identical
    at both batch sizes. Snapshot entries match in assertion order.
    """
    args_spectral.batch_size = batch_size
    prediction_func, model_shape = get_prediction_func_from_args(args_spectral)
    exp = _explanation(args_spectral, model_shape, prediction_func, cpu_device, db=None)

    # Serialised Explanation first, then a hash of the flattened tensor.
    assert exp == snapshot_explanation
    assert hash(tuple(exp.explanation.reshape(-1).tolist())) == snapshot_explanation
37 |
--------------------------------------------------------------------------------
/tests/test_data/004_0002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/004_0002.jpg
--------------------------------------------------------------------------------
/tests/test_data/2008_000033.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/2008_000033.jpg
--------------------------------------------------------------------------------
/tests/test_data/DoublePeakClass 0 Mean 1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 0 Mean 1.npy
--------------------------------------------------------------------------------
/tests/test_data/DoublePeakClass 0 Mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 0 Mean.npy
--------------------------------------------------------------------------------
/tests/test_data/DoublePeakClass 1 Mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 1 Mean.npy
--------------------------------------------------------------------------------
/tests/test_data/DoublePeakClass 2 Mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/DoublePeakClass 2 Mean.npy
--------------------------------------------------------------------------------
/tests/test_data/ILSVRC2012_val_00047302.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/ILSVRC2012_val_00047302.JPEG
--------------------------------------------------------------------------------
/tests/test_data/TCGA_DU_7018_19911220_14.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/TCGA_DU_7018_19911220_14.tif
--------------------------------------------------------------------------------
/tests/test_data/bike.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/bike.jpg
--------------------------------------------------------------------------------
/tests/test_data/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/dog.jpg
--------------------------------------------------------------------------------
/tests/test_data/dog_hide.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/dog_hide.jpg
--------------------------------------------------------------------------------
/tests/test_data/ladybird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/ladybird.jpg
--------------------------------------------------------------------------------
/tests/test_data/lizard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/lizard.jpg
--------------------------------------------------------------------------------
/tests/test_data/peacock.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/peacock.jpg
--------------------------------------------------------------------------------
/tests/test_data/positive193.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/positive193.npy
--------------------------------------------------------------------------------
/tests/test_data/rex-test-all-config.toml:
--------------------------------------------------------------------------------
1 | [rex]
2 | # masking value for mutations, can be either an integer, float or
3 | # one of the following built-in occlusions 'spectral', 'min', 'mean'
4 | mask_value = "mean"
5 |
6 | # random seed, only set for reproducibility
7 | seed = 42
8 |
9 | # whether to use gpu or not, defaults to true
10 | gpu = false
11 |
12 | # batch size for the model
13 | batch_size = 32
14 |
15 | [rex.onnx]
16 | # means for min-max normalization
17 | means = [0.485, 0.456, 0.406]
18 |
19 | # standard deviations for min-max normalization
20 | stds = [0.229, 0.224, 0.225]
21 |
22 | # binary model confidence threshold. Anything >= threshold will be classified as 1, otherwise 0
23 | binary_threshold = 0.5
24 |
25 | # norm
26 | norm = 1.0
27 |
28 | [rex.visual]
29 |
30 | # include classification and confidence information in title of plot, defaults to true
31 | info = false
32 |
33 | # pretty printing colour for explanations, defaults to 200
34 | colour = 150
35 |
36 | # alpha blend for main image, defaults to 0.2 (PIL Image.blend parameter)
37 | alpha = 0.1
38 |
39 | # produce unvarnished image with actual masking value, defaults to false
40 | raw = true
41 |
42 | # resize the explanation to the size of the original image. This uses cubic interpolation and will not be as visually accurate as not resizing, defaults to false
43 | resize = true
44 |
45 | # whether to show progress bar in the terminal, defaults to true
46 | progress_bar = false
47 |
48 | # overlay a 10*10 grid on an explanation, defaults to false
49 | grid = true
50 |
51 | # mark quickshift segmentation on image
52 | mark_segments = true
53 |
54 | # matplotlib colourscheme for responsibility map plotting, defaults to 'magma'
55 | heatmap_colours = 'viridis'
56 |
57 | # style for multiple explanations, e.g. "separate"
58 | multi_style = "separate"
59 |
60 | [causal]
61 | # maximum depth of tree, defaults to 10, note that search can actually go beyond this number on occasion, as the
62 | # check only occurs at the end of an iteration
63 | tree_depth = 5
64 |
65 | # limit on number of combinations to consider, defaults to none.
66 | # It is **not** the total work done by ReX over all iterations. Leaving the search limit at none
67 | # can potentially be very expensive.
68 | search_limit = 1000
69 |
70 | # number of times to run the algorithm, defaults to 20
71 | iters = 30
72 |
73 | # minimum child size, in pixels
74 | min_box_size = 20
75 |
76 | # remove passing mutants which have a confidence less than this value. Defaults to 0.0 (meaning all mutants are considered)
77 | confidence_filter = 0.5
78 |
79 | # whether to weight responsibility by prediction confidence, default to false
80 | weighted = true
81 |
82 | # queue_style = "intersection" | "area" | "all" | "dc", defaults to "area"
83 | queue_style = "intersection"
84 |
85 | # maximum number of things to hold in search queue, either an integer or 'all'
86 | queue_len = 2
87 |
88 | # concentrate: weight responsibility by size and depth of passing partition. Defaults to false
89 | concentrate = true
90 |
91 | [causal.distribution]
92 | # distribution for splitting the box, defaults to uniform. Possible choices are 'uniform' | 'binom' | 'betabinom' | 'adaptive'
93 | distribution = 'betabinom'
94 |
95 | # blend one of the above distributions with the responsibility map treated as a distribution
96 | blend = 0.5
97 |
98 | # supplemental arguments for distribution creation; these are ignored if the distribution does not take any parameters
99 | distribution_args = [1.1, 1.1]
100 |
101 | [explanation]
102 | # iterate through pixel ranking in chunks, defaults to causal.min_box_size
103 | chunk_size = 16
104 |
105 | [explanation.spatial]
106 | # initial search radius
107 | spatial_initial_radius = 20
108 |
109 | # increment to change radius
110 | spatial_radius_eta = 0.1
111 |
112 | # number of times to expand before quitting, defaults to 4
113 | no_expansions = 1
114 |
115 | [explanation.multi]
116 | # multi method (just spotlight so far)
117 | strategy = 'spotlight'
118 |
119 | # no of spotlights to launch
120 | spotlights = 5
121 |
122 | # default size of spotlight
123 | spotlight_size = 10
124 |
125 | # decrease spotlight by this amount
126 | spotlight_eta = 0.5
127 |
128 | spotlight_step = 10
129 |
130 | # maximum number of random steps that a spotlight can make before quitting
131 | max_spotlight_budget = 30
132 |
133 | # objective function for spotlight search. Possible options 'mean' | 'max' | "none", defaults to "none"
134 | spotlight_objective_function = 'mean'
135 |
136 | # how much overlap to allow between different explanations. This is the dice coefficient, so
137 | # 0.0 means no overlap is permitted, and 1.0 means total overlap is permitted
138 | permitted_overlap = 0.1
139 |
140 | [explanation.evaluation]
141 |
142 | # insertion/deletion curve step size
143 | insertion_step = 50
144 |
145 | # normalise insertion/deletion curves by confidence of original data
146 | normalise_curves = false
147 |
--------------------------------------------------------------------------------
/tests/test_data/spectrum_class_DNA.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/spectrum_class_DNA.npy
--------------------------------------------------------------------------------
/tests/test_data/spectrum_class_noDNA.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/spectrum_class_noDNA.npy
--------------------------------------------------------------------------------
/tests/test_data/starfish.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/starfish.jpg
--------------------------------------------------------------------------------
/tests/test_data/tennis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/tennis.jpg
--------------------------------------------------------------------------------
/tests/test_data/testimage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReX-XAI/ReX/d479870702dbfd4df73ad5d569e78087a0f1aeff/tests/test_data/testimage.png
--------------------------------------------------------------------------------
/tests/unit_tests/box_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import numpy as np
3 |
4 | from rex_xai.mutants.box import box_dimensions
5 |
6 |
def test_data(data_3d, data_2d):
    """Fixture sanity checks: model shapes and the modes inferred from them."""
    assert (data_3d.model_shape, data_3d.mode) == ([1, 64, 64, 64], "voxel")
    assert (data_2d.model_shape, data_2d.mode) == ([1, 64, 64], "spectral")
12 |
def test_initialise_tree_3d(box_3d):
    """A fresh 3D root box spans the full 64^3 volume and reports consistent
    geometry from its repr, shape, area and corner accessors."""
    full_extent = (0, 64, 0, 64, 0, 64)

    assert box_3d.depth_start == 0
    assert box_3d.depth_stop == 64

    assert box_dimensions(box_3d) == full_extent

    assert repr(box_3d) == (
        "Box < name: R, row_start: 0, row_stop: 64, "
        "col_start: 0, col_stop: 64, depth_start: 0, depth_stop: 64, "
        "volume: 262144"
    )

    assert box_3d.shape() == (64, 64, 64)
    assert box_3d.area() == 64 ** 3
    assert box_3d.corners() == full_extent
27 |
28 |
def test_initialise_tree_2d(box_2d):
    """A fresh 2D root box has None depth attributes and spans the full 64x64 area."""
    # 2D boxes carry depth attributes set to None rather than omitting them.
    assert box_2d.depth_start is None
    assert box_2d.depth_stop is None

    full_extent = (0, 64, 0, 64)
    assert box_dimensions(box_2d) == full_extent
    # corners() duplicates box_dimensions() for 2D boxes; one of the two could
    # arguably be removed.
    assert box_2d.corners() == full_extent

    assert repr(box_2d) == (
        "Box < name: R, row_start: 0, row_stop: 64, "
        "col_start: 0, col_stop: 64, area: 4096"
    )

    assert box_2d.shape() == (64, 64)
    assert box_2d.area() == 64 * 64
47 |
def test_spawn_children_3d(box_3d, resp_map_3d):
    """Splitting the 64^3 root box yields 4 children that exactly partition its
    volume, with the expected split points under a fixed RNG seed.
    """
    # Set seed
    np.random.seed(24)
    children_3d = box_3d.spawn_children(min_size=20, mode="voxel", map=resp_map_3d)
    assert len(children_3d) == 4
    assert children_3d[0].area() < 262144

    # The four children must tile the parent exactly (no gaps, no overlap).
    total_area_3d = (
        children_3d[0].area()
        + children_3d[1].area()
        + children_3d[2].area()
        + children_3d[3].area()
    )
    assert total_area_3d == 262144

    # Expected per-child volumes given seed 24 (row split 1, col split 24).
    volumes = [1536, 2560, 96768, 161280]

    # Check splitting of boxes
    row_split, col_split = 1, 24

    row_starts = [0, 0, row_split, row_split]
    row_stops = [row_split, row_split, 64, 64]
    col_starts = [0, col_split, 0, col_split]
    col_stops = [col_split, 64, col_split, 64]

    for i in range(4):
        assert children_3d[i].name == f"R:{i}"
        # Not split through depth
        assert children_3d[i].depth_start == 0
        assert children_3d[i].depth_stop == 64

        assert children_3d[i].area() == volumes[i]

        assert children_3d[i].row_start == row_starts[i]
        assert children_3d[i].row_stop == row_stops[i]
        assert children_3d[i].col_start == col_starts[i]
        assert children_3d[i].col_stop == col_stops[i]
85 |
86 |
87 |
def test_spawn_children_2d(box_2d, resp_map_2d):
    """Splitting the 64x64 root box yields 4 children that exactly partition
    its area, with the expected split points under a fixed RNG seed.

    Fix: removed a leftover debug `print(children_2d)` statement.
    """
    # Fix numpy's RNG so spawn_children picks deterministic split points.
    np.random.seed(24)
    children_2d = box_2d.spawn_children(min_size=2, mode="RGB", map=resp_map_2d)
    assert len(children_2d) == 4
    assert children_2d[0].area() < 4096

    # The four children must tile the parent exactly (no gaps, no overlap).
    total_area_2d = (
        children_2d[0].area()
        + children_2d[1].area()
        + children_2d[2].area()
        + children_2d[3].area()
    )
    assert total_area_2d == 4096

    # Expected per-child areas given seed 24 (row split 6, col split 35).
    areas = [210, 174, 2030, 1682]

    # Check splitting of boxes
    row_split, col_split = 6, 35

    row_starts = [0, 0, row_split, row_split]
    row_stops = [row_split, row_split, 64, 64]
    col_starts = [0, col_split, 0, col_split]
    col_stops = [col_split, 64, col_split, 64]

    for i in range(4):
        assert children_2d[i].name == f"R:{i}"
        assert children_2d[i].area() == areas[i]

        assert children_2d[i].row_start == row_starts[i]
        assert children_2d[i].row_stop == row_stops[i]
        assert children_2d[i].col_start == col_starts[i]
        assert children_2d[i].col_stop == col_stops[i]
121 |
122 |
123 |
--------------------------------------------------------------------------------
/tests/unit_tests/cmd_args_test.py:
--------------------------------------------------------------------------------
1 | from types import ModuleType
2 |
3 | import pytest
4 | from rex_xai.utils._utils import Strategy
5 | from rex_xai.input.config import CausalArgs, cmdargs_parser, process_cmd_args, shared_args
6 |
7 |
@pytest.fixture
def non_default_cmd_args():
    """Parsed command-line namespace exercising every shared and causal CLI
    option with a non-default value, for the process/shared args tests below.
    """
    args_list = [
        "filename.jpg",
        "--output",
        "output_path.jpg",
        "--config",
        "path/to/rex.toml",
        "--processed",
        "--script",
        "tests/scripts/pytorch_resnet50.py",
        "-vv",
        "--surface",
        "surface_path.jpg",
        "--heatmap",
        "heatmap_path.jpg",
        "--model",
        "path/to/model.onnx",
        "--strategy",
        "multi",
        "--database",
        "path/to/database.db",
        "--multi",
        "5",
        "--iters",
        "10",
        "--analyse",
        "--mode",
        "RGB",
    ]
    parser = cmdargs_parser()
    cmd_args = parser.parse_args(args_list)

    return cmd_args
42 |
43 |
def test_process_cmd_args(non_default_cmd_args):
    """process_cmd_args maps the parsed CLI options onto a CausalArgs instance."""
    causal_args = CausalArgs()
    process_cmd_args(non_default_cmd_args, causal_args)

    # --script is imported as a module; --strategy "multi" maps to MultiSpotlight.
    assert isinstance(causal_args.script, ModuleType)
    assert causal_args.path == non_default_cmd_args.filename
    assert causal_args.strategy == Strategy.MultiSpotlight
    assert causal_args.iters == int(non_default_cmd_args.iters)
    assert causal_args.analyse
    assert causal_args.spotlights == int(non_default_cmd_args.multi)
54 |
55 |
def test_process_shared_args(non_default_cmd_args):
    """shared_args copies every shared CLI option onto the CausalArgs object."""
    causal_args = CausalArgs()
    shared_args(non_default_cmd_args, causal_args)

    # (CausalArgs attribute, parsed cmd-line attribute) pairs that must agree.
    mirrored = [
        ("config_location", "config"),
        ("model", "model"),
        ("surface", "surface"),
        ("heatmap", "heatmap"),
        ("output", "output"),
        ("verbosity", "verbose"),
        ("db", "database"),
        ("mode", "mode"),
        ("processed", "processed"),
    ]
    for args_attr, cmd_attr in mirrored:
        assert getattr(causal_args, args_attr) == getattr(non_default_cmd_args, cmd_attr)
69 |
70 |
def test_quiet_overrides_verbose():
    """--quiet wins over -vv: the resulting verbosity is forced to 0."""
    cmd_args = cmdargs_parser().parse_args(["filename.jpg", "-vv", "--quiet"])
    args = CausalArgs()
    shared_args(cmd_args, args)

    assert args.verbosity == 0
79 |
80 |
def test_contrastive():
    """--contrastive N selects the Contrastive strategy with N spotlights."""
    cmd_args = cmdargs_parser().parse_args(["filename.jpg", "--contrastive", "5"])
    args = CausalArgs()
    process_cmd_args(cmd_args, args)

    assert args.strategy == Strategy.Contrastive
    assert args.spotlights == int(cmd_args.contrastive)
90 |
91 |
def test_spectral():
    """--spectral switches the data mode to 'spectral'."""
    cmd_args = cmdargs_parser().parse_args(["filename.jpg", "--spectral"])
    args = CausalArgs()
    shared_args(cmd_args, args)

    assert args.mode == "spectral"
100 |
--------------------------------------------------------------------------------
/tests/unit_tests/config_test.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import pytest
4 | from rex_xai.utils._utils import Queue, Strategy
5 | from rex_xai.input.config import CausalArgs, process_config_dict, read_config_file
6 | from rex_xai.mutants.distributions import Distribution
7 |
8 |
@pytest.fixture
def non_default_args():
    """CausalArgs with every field set to a non-default value.

    The values here must mirror tests/test_data/rex-test-all-config.toml
    exactly: test_process_config_dict compares vars() of the args produced
    from that file against this fixture, so the two must be kept in sync.
    """
    non_default_args = CausalArgs()
    # rex
    non_default_args.mask_value = "mean"
    non_default_args.seed = 42
    non_default_args.gpu = False
    non_default_args.batch_size = 32
    # rex.onnx
    non_default_args.means = [0.485, 0.456, 0.406]
    non_default_args.stds = [0.229, 0.224, 0.225]
    non_default_args.binary_threshold = 0.5
    non_default_args.norm = 1.0
    # rex.visual
    non_default_args.info = False
    non_default_args.colour = 150
    non_default_args.alpha = 0.1
    non_default_args.raw = True
    non_default_args.resize = True
    non_default_args.progress_bar = False
    non_default_args.grid = True
    non_default_args.mark_segments = True
    non_default_args.heatmap_colours = "viridis"
    non_default_args.multi_style = "separate"
    # causal
    non_default_args.tree_depth = 5
    non_default_args.search_limit = 1000
    non_default_args.iters = 30
    non_default_args.min_box_size = 20
    non_default_args.confidence_filter = 0.5
    non_default_args.weighted = True
    non_default_args.queue_style = Queue.Intersection
    non_default_args.queue_len = 2
    non_default_args.concentrate = True
    # causal.distribution
    non_default_args.distribution = Distribution.BetaBinomial
    non_default_args.blend = 0.5
    non_default_args.distribution_args = [1.1, 1.1]
    # explanation
    non_default_args.chunk_size = 16
    non_default_args.spatial_initial_radius = 20
    non_default_args.spatial_radius_eta = 0.1
    non_default_args.no_expansions = 1
    # explanation.multi
    non_default_args.strategy = Strategy.MultiSpotlight
    non_default_args.spotlights = 5
    non_default_args.spotlight_size = 10
    non_default_args.spotlight_eta = 0.5
    non_default_args.spotlight_step = 10
    non_default_args.max_spotlight_budget = 30
    non_default_args.spotlight_objective_function = "mean"
    non_default_args.permitted_overlap = 0.1
    # explanation.evaluation
    non_default_args.insertion_step = 50
    non_default_args.normalise_curves = False

    return non_default_args
66 |
67 |
def test_process_config_dict(non_default_args):
    """Every value in rex-test-all-config.toml lands on the matching CausalArgs field."""
    args = CausalArgs()
    parsed = read_config_file("tests/test_data/rex-test-all-config.toml")

    process_config_dict(parsed, args)

    assert vars(args) == vars(non_default_args)
75 |
76 |
def test_process_config_dict_empty():
    """Processing an empty config dict leaves CausalArgs completely untouched."""
    args = CausalArgs()
    before = copy.deepcopy(args)

    process_config_dict({}, args)

    assert vars(args) == vars(before)
85 |
86 |
def test_process_config_dict_invalid_arg(caplog):
    """Unknown config keys are skipped with a warning rather than raising."""
    args = CausalArgs()

    process_config_dict({"explanation": {"chunk": 10}}, args)

    expected = "Invalid or misplaced parameter 'chunk', skipping!"
    assert caplog.records[0].message == expected
95 |
96 |
def test_process_config_dict_invalid_distribution(caplog):
    """An unrecognised distribution name falls back to the Uniform default and
    logs a warning rather than raising.
    """
    args = CausalArgs()
    config_dict = {
        "causal": {"distribution": {"distribution": "an-invalid-distribution"}}
    }

    process_config_dict(config_dict, args)

    assert args.distribution == Distribution.Uniform
    assert (
        caplog.records[0].message
        == "Invalid distribution 'an-invalid-distribution', reverting to default value Distribution.Uniform"
    )
110 |
111 |
def test_process_config_dict_uniform_distribution():
    """For the uniform distribution any supplied distribution_args are discarded
    (uniform takes no parameters)."""
    args = CausalArgs()
    process_config_dict(
        {
            "causal": {
                "distribution": {
                    "distribution": "uniform",
                    "distribution_args": [0.0, 0.0],
                }
            }
        },
        args,
    )

    assert args.distribution == Distribution.Uniform
    assert args.distribution_args is None
124 |
125 |
def test_process_config_dict_distribution_args():
    """For a parameterised distribution (betabinom) the supplied
    distribution_args are kept as-is.
    """
    args = CausalArgs()
    config_dict = {
        "causal": {
            "distribution": {
                "distribution": "betabinom",
                "distribution_args": [0.0, 0.0],
            }
        }
    }

    process_config_dict(config_dict, args)

    assert args.distribution == Distribution.BetaBinomial
    assert args.distribution_args == [0.0, 0.0]
141 |
142 |
def test_process_config_dict_queue_style():
    """queue_style 'all' maps to the Queue.All enum member."""
    args = CausalArgs()

    process_config_dict({"causal": {"queue_style": "all"}}, args)

    assert args.queue_style == Queue.All
149 |
150 |
def test_process_config_dict_queue_style_upper():
    """queue_style parsing is case-insensitive: 'ALL' also maps to Queue.All."""
    args = CausalArgs()

    process_config_dict({"causal": {"queue_style": "ALL"}}, args)

    assert args.queue_style == Queue.All
157 |
158 |
def test_process_config_dict_queue_style_invalid(caplog):
    """An unrecognised queue style falls back to the Queue.Area default and
    logs a warning rather than raising.
    """
    args = CausalArgs()
    config_dict = {"causal": {"queue_style": "an-invalid-queue-style"}}

    process_config_dict(config_dict, args)
    assert args.queue_style == Queue.Area
    assert (
        caplog.records[0].message
        == "Invalid queue style 'an-invalid-queue-style', reverting to default value Queue.Area"
    )
169 |
170 |
def test_process_config_dict_strategy():
    """The multi strategy name 'spotlight' maps to Strategy.MultiSpotlight."""
    args = CausalArgs()

    process_config_dict({"explanation": {"multi": {"strategy": "spotlight"}}}, args)

    assert args.strategy == Strategy.MultiSpotlight
177 |
178 |
def test_process_config_dict_strategy_invalid(caplog):
    # An unknown strategy falls back to Strategy.Global and logs a warning.
    causal_args = CausalArgs()
    process_config_dict(
        {"explanation": {"multi": {"strategy": "an-invalid-strategy"}}}, causal_args
    )

    expected = "Invalid strategy 'an-invalid-strategy', reverting to default value Strategy.Global"
    assert caplog.records[0].message == expected
    assert causal_args.strategy == Strategy.Global
189 |
--------------------------------------------------------------------------------
/tests/unit_tests/data_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import numpy as np
4 | from rex_xai.explanation.rex import load_and_preprocess_data
5 | from rex_xai.input.input_data import Data
6 |
# Tabular/spectral fixture input for Data.
# NOTE(review): np.arange(0, 1, 999) produces a single element ([0]) because
# 999 is the *step*, not the length — presumably np.arange(0, 999) was
# intended so the input matches model_shape [1, 999]. Confirm against
# Data's shape handling before changing; test_data currently passes as-is.
tab = np.arange(0, 1, 999)
# 3D voxel fixture input, shaped to match model_shape [1, 64, 64, 64].
voxel = np.random.rand(1, 64, 64, 64)
9 |
10 |
def test_data():
    # 1-D (tabular) input should be recognised as spectral data.
    spectral = Data(input=tab, model_shape=[1, 999], device="cpu")
    assert spectral.mode == "spectral"
    assert spectral.model_shape == [1, 999]
15 |
16 |
def test_set_mask_value(args_custom, model_shape, cpu_device, caplog):
    # A 'spectral' mask value is rejected for non-spectral data and
    # replaced by the default of 0, with a warning logged.
    data = load_and_preprocess_data(model_shape, cpu_device, args_custom)
    data.set_mask_value("spectral")

    expected = "Mask value 'spectral' can only be used if mode is also 'spectral', using default mask value 0 instead"
    assert caplog.records[0].msg == expected
    assert data.mask_value == 0
26 |
27 |
def test_3D_data():
    # A batch of 3 spatial dimensions should be recognised as voxel data.
    voxel_data = Data(input=voxel, model_shape=[1, 64, 64, 64], device="cpu")
    assert voxel_data.mode == "voxel"
    assert voxel_data.model_shape == [1, 64, 64, 64]
32 |
--------------------------------------------------------------------------------
/tests/unit_tests/database_test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from rex_xai.output.database import db_to_pandas, initialise_rex_db, update_database
5 |
6 |
@pytest.fixture
def db(tmp_path):
    # Fresh ReX database handle backed by a per-test temporary file.
    return initialise_rex_db(tmp_path / "rex.db")
12 |
13 |
def test_update_database(exp_extracted, tmp_path):
    # A successful update leaves a non-empty database file on disk.
    db_path = tmp_path / "rex.db"
    database = initialise_rex_db(db_path)
    update_database(database, exp_extracted)

    assert os.path.exists(db_path)
    assert os.stat(db_path).st_size > 0
20 |
21 |
def test_update_database_no_target(exp_extracted, db, caplog):
    # An explanation without a target is refused and the refusal is logged.
    exp_extracted.data.target = None
    update_database(db, exp_extracted)
    assert caplog.records[0].message == "unable to update database as target is None"
26 |
27 |
def test_update_database_no_exp(exp_extracted, db, caplog):
    # An explanation with an empty final mask is refused and logged.
    exp_extracted.final_mask = None
    update_database(db, exp_extracted)

    expected = "unable to update database as explanation is empty"
    assert caplog.records[0].message == expected
34 |
35 |
def test_read_db(exp_extracted, tmp_path):
    # A single stored explanation round-trips as one 30-column row.
    db_path = tmp_path / "rex.db"
    database = initialise_rex_db(db_path)
    update_database(database, exp_extracted)

    assert db_to_pandas(db_path).shape == (1, 30)
42 |
43 |
def test_no_multi(exp_extracted, db, caplog):
    """multi=True must be rejected for a plain Explanation object.

    Fix: the original signature omitted the ``db`` fixture, so the bare
    name ``db`` resolved to the module-level fixture *function* rather
    than an initialised database handle, and update_database was called
    with the wrong first argument. Requesting the fixture restores the
    test's intent.
    """
    update_database(db, exp_extracted, multi=True)
    assert (
        caplog.records[0].message
        == "unable to update database, multi=True is only valid for MultiExplanation objects"
    )
50 |
51 |
def test_update_database_multiexp(exp_multi, tmp_path):
    # MultiExplanation objects can also be written to the database.
    db_path = tmp_path / "rex.db"
    database = initialise_rex_db(db_path)
    update_database(database, exp_multi)

    assert os.path.exists(db_path)
    assert os.stat(db_path).st_size > 0
58 |
59 |
def test_read_database_multiexp(exp_multi, tmp_path):
    # With multi=True, each sub-explanation becomes its own row.
    db_path = tmp_path / "rex.db"
    database = initialise_rex_db(db_path)
    update_database(database, exp_multi, multi=True)

    frame = db_to_pandas(db_path)
    assert frame.shape == (len(exp_multi.explanations), 30)
67 |
--------------------------------------------------------------------------------
/tests/unit_tests/multi-explanation_test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 | import pytest
6 | from rex_xai.utils._utils import Strategy
7 | from rex_xai.explanation.rex import calculate_responsibility
8 | from rex_xai.explanation.explanation import Explanation
9 | from rex_xai.explanation.multi_explanation import MultiExplanation
10 |
11 |
@pytest.mark.parametrize("spotlights", [5, 10])
def test_multiexplanation(data_multi, args_multi, prediction_func, spotlights, caplog):
    """End-to-end spotlight search: the explanation count reported in the
    final log line matches the number of 'found an explanation' records,
    never exceeds the requested spotlight count, and the first explanation
    equals the global one."""
    args_multi.spotlights = spotlights

    maps, run_stats = calculate_responsibility(data_multi, args_multi, prediction_func)

    # Global explanation, used below as the reference for explanations[0].
    exp = Explanation(maps, prediction_func, data_multi, args_multi, run_stats)
    exp.extract(method=Strategy.Global)

    multi_exp = MultiExplanation(
        maps, prediction_func, data_multi, args_multi, run_stats
    )
    # INFO level is needed so the per-spotlight log records are captured.
    caplog.set_level(logging.INFO)
    multi_exp.extract(Strategy.MultiSpotlight)

    # Count successful spotlights by scanning the captured log records.
    n_exp = 0
    for record in caplog.records:
        print(record)
        if "found an explanation" in record.message:
            n_exp += 1

    assert (
        caplog.records[-1].message
        == f"ReX has found a total of {n_exp} explanations via spotlight search"
    )
    assert n_exp == len(multi_exp.explanations)
    assert len(multi_exp.explanations) <= spotlights  # always true
    assert (
        len(multi_exp.explanations) == spotlights
    )  # not always true but is for this data/parameters
    assert np.array_equal(
        multi_exp.explanations[0].detach().cpu().numpy(), exp.final_mask
    )  # first explanation is global explanation
45 |
46 |
def test_multiexplanation_save_composite(exp_multi, tmp_path):
    # Composite saves: one image with clauses=None, then one per clause
    # grouping, each written as exp_<clause>.png alongside the base path.
    clauses = exp_multi.separate_by(0.0)
    out = tmp_path / "exp.png"

    exp_multi.save(path=out, multi_style="composite", clauses=None)
    assert os.path.exists(out)
    assert os.stat(out).st_size > 0

    for clause in clauses:
        exp_multi.save(path=out, multi_style="composite", clauses=clause)
        clause_file = tmp_path / f"exp_{clause}.png"
        assert os.path.exists(clause_file)
        assert os.stat(clause_file).st_size > 0
61 |
62 |
def test_multiexplanation_save_separate(exp_multi, tmp_path):
    # Separate style writes one indexed file per explanation.
    exp_multi.save(path=tmp_path / "exp.png", multi_style="separate")

    for idx in range(len(exp_multi.explanations)):
        separate_file = tmp_path / f"exp_{idx}.png"
        assert os.path.exists(separate_file)
        assert os.stat(separate_file).st_size > 0
71 |
--------------------------------------------------------------------------------
/tests/unit_tests/preprocessing_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.input.config import validate_args
3 | from rex_xai.explanation.rex import (
4 | predict_target,
5 | try_preprocess,
6 | )
7 |
8 |
def test_preprocess_nii_notimplemented(args, model_shape, cpu_device, caplog):
    # Generic nifti input is unsupported: NotImplemented plus a log message.
    args.path = "tests/test_data/dog.nii"
    result = try_preprocess(args, model_shape, device=cpu_device)

    assert result == NotImplemented
    assert caplog.records[0].msg == "we do not (yet) handle nifti files generically"
15 |
16 |
def test_predict_target(data, prediction_func):
    # The fixture image predicts class 207 with a stable confidence.
    predicted = predict_target(data, prediction_func)

    assert predicted.classification == 207
    assert predicted.confidence == pytest.approx(0.253237, abs=2.5e-6)
22 |
23 |
def test_validate_args(args):
    # A missing input path must be a hard error.
    args.path = None  # type: ignore
    with pytest.raises(FileNotFoundError):
        validate_args(args)
28 |
29 |
def test_preprocess_rgba(args, model_shape, prediction_func, cpu_device, caplog):
    # RGBA images are converted to RGB before prediction.
    args.path = "assets/rex_logo.png"
    data = try_preprocess(args, model_shape, device=cpu_device)
    predict_target(data, prediction_func)

    expected = "RGBA input image provided, converting to RGB"
    assert caplog.records[0].msg == expected
    assert data.mode == "RGB"
    assert data.input.mode == "RGB"
    assert data.data is not None
    assert data.data.shape[1] == 3  # batch, channels, height, width
40 |
--------------------------------------------------------------------------------
/tests/unit_tests/validate_args_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rex_xai.utils._utils import ReXTomlError
3 | from rex_xai.input.config import CausalArgs, validate_args
4 |
5 |
def test_no_path(args):
    # validate_args must reject a configuration with no input path.
    args.path = None  # type: ignore
    with pytest.raises(FileNotFoundError):
        validate_args(args)
10 |
11 |
def test_blend_invalid(caplog):
    # blend is a proportion; values outside [0, 1] are fatal.
    causal_args = CausalArgs()
    causal_args.blend = 20
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)

    expected = "Invalid value '20': must be between 0.0 and 1.0"
    assert caplog.records[0].message == expected
21 |
22 |
def test_permitted_overlap_invalid(caplog):
    # permitted_overlap is a proportion; negative values are fatal.
    causal_args = CausalArgs()
    causal_args.permitted_overlap = -5
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)

    expected = "Invalid value '-5': must be between 0.0 and 1.0"
    assert caplog.records[0].message == expected
32 |
33 |
def test_iters_invalid(caplog):
    # Zero iterations is meaningless and must be rejected.
    causal_args = CausalArgs()
    causal_args.iters = 0
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)
    assert caplog.records[0].message == "Invalid value '0': must be more than 0.0"
40 |
41 |
def test_raw_invalid(caplog):
    # raw must be a boolean; anything else is fatal.
    causal_args = CausalArgs()
    causal_args.raw = 100  # type: ignore
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)
    assert caplog.records[0].message == "Invalid value '100': must be boolean"
48 |
49 |
def test_multi_style_invalid(caplog):
    # multi_style only accepts 'composite' or 'separate'.
    causal_args = CausalArgs()
    causal_args.multi_style = "an-invalid-style"
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)

    expected = "Invalid value 'an-invalid-style' for multi_style, must be 'composite' or 'separate'"
    assert caplog.records[0].message == expected
59 |
60 |
def test_queue_len_invalid(caplog):
    # queue_len must be 'all' or an integer; a float is fatal.
    causal_args = CausalArgs()
    causal_args.queue_len = 7.5  # type: ignore
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)

    expected = "Invalid value '7.5' for queue_len, must be 'all' or an integer"
    assert caplog.records[0].message == expected
70 |
71 |
def test_distribution_args_invalid(caplog):
    """Both failure modes of distribution_args validation are logged and fatal.

    Fix: the second assertion previously re-read caplog.records[0], which
    still held the message from the *first* validate_args call — caplog
    accumulates records across the whole test — so it compared the wrong
    record. Clearing the captured log between the two phases makes each
    assertion inspect the message produced by its own call.
    """
    args = CausalArgs()

    # Phase 1: wrong length (not a 2-element sequence).
    args.distribution_args = 1  # type: ignore
    with pytest.raises(ReXTomlError):
        validate_args(args)
    assert caplog.records[0].message == "distribution args must be length 2, not 1"

    # Phase 2: correct length but a negative value.
    caplog.clear()
    args.distribution_args = [0, -1]
    with pytest.raises(ReXTomlError):
        validate_args(args)
    assert (
        caplog.records[0].message
        == "All values in distribution args must be more than zero"
    )
86 |
87 |
def test_colour_map_invalid(caplog):
    # heatmap_colours must name a matplotlib colourmap.
    causal_args = CausalArgs()
    causal_args.heatmap_colours = "RedBlue"
    with pytest.raises(ReXTomlError):
        validate_args(causal_args)

    expected = "Invalid colourmap 'RedBlue', must be a valid matplotlib colourmap"
    assert caplog.records[0].message == expected
97 |
--------------------------------------------------------------------------------
/tests/unit_tests/visualisation_test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch as tt
4 | from rex_xai.output.visualisation import save_image, voxel_plot
5 |
6 | from rex_xai.input.config import CausalArgs
7 |
8 |
def test_surface(exp_extracted, tmp_path):
    # Surface plot renders to a non-empty file.
    out = tmp_path / "surface.png"
    exp_extracted.surface_plot(path=out)

    assert os.path.exists(out)
    assert os.stat(out).st_size > 0
15 |
16 |
def test_heatmap(exp_extracted, tmp_path):
    # Heatmap plot renders to a non-empty file.
    out = tmp_path / "heatmap.png"
    exp_extracted.heatmap_plot(path=out)

    assert os.path.exists(out)
    assert os.stat(out).st_size > 0
23 |
24 |
def test_save_exp(exp_extracted, tmp_path):
    # Saving an extracted explanation yields a non-empty image file.
    out = tmp_path / "exp.png"
    exp_extracted.save(path=out)

    assert os.path.exists(out)
    assert os.stat(out).st_size > 0
31 |
def test_save_image_3d(data_3d, tmp_path):
    """save_image handles voxel-mode data.

    Fixes: write into pytest's tmp_path rather than the current working
    directory — the original left test.png behind whenever an assertion
    failed before os.remove ran, and polluted the repo checkout even on
    success until cleanup. The stale comment is also corrected: the mask
    is a deterministic half-cube of ones, not random values.
    """
    # Explanation mask for the voxel data: ones in one octant, zeros elsewhere.
    explanation = tt.zeros((1, 64, 64, 64), dtype=tt.bool, device="cpu")
    explanation[0, 32:64, 32:64, 32:64] = 1

    args = CausalArgs()
    args.mode = "voxel"
    args.output = str(tmp_path / "test.png")

    save_image(explanation, data_3d, args, path=args.output)

    assert os.path.exists(args.output)
    assert os.path.getsize(args.output) > 0
44 |
45 |
def test_voxel_plot(data_3d, resp_map_3d, tmp_path):
    """voxel_plot writes one slice image per axis, derived from the given path.

    Fixes: direct output into tmp_path instead of the working directory so
    a failing assertion cannot leak test_{x,y,z}_slice.png files into the
    repository, and drop the leftover debug print of the fixture.
    NOTE(review): assumes voxel_plot derives the slice filenames from the
    directory and stem of `path` (consistent with the original behaviour,
    where "test.png" produced "test_<axis>_slice.png" beside it).
    """
    args = CausalArgs()
    voxel_plot(args, resp_map_3d, data_3d, path=str(tmp_path / "test.png"))

    for axis in ("x", "y", "z"):
        slice_file = tmp_path / f"test_{axis}_slice.png"
        assert os.path.exists(slice_file)
        assert os.path.getsize(slice_file) > 0
55 |
--------------------------------------------------------------------------------