├── src └── keras_explainable │ ├── engine │ ├── __init__.py │ └── explaining.py │ ├── methods │ ├── __init__.py │ ├── meta.py │ ├── gradient.py │ └── cams.py │ ├── __init__.py │ ├── utils.py │ ├── filters.py │ └── inspection.py ├── docs ├── _static │ ├── .gitignore │ ├── images │ │ ├── cover.jpg │ │ ├── singleton │ │ │ ├── bear.jpg │ │ │ ├── dogcat.jpg │ │ │ ├── goldfish.jpg │ │ │ ├── image02.png │ │ │ ├── soldiers.jpg │ │ │ ├── Dalmatian-2.jpg │ │ │ ├── iStock-157312120.webp │ │ │ ├── multiple-cats-300x225.jpg │ │ │ ├── ILSVRC2012_val_00000073.JPEG │ │ │ ├── ILSVRC2012_val_00000091.JPEG │ │ │ ├── ILSVRC2012_val_00000198.JPEG │ │ │ ├── ILSVRC2012_val_00000476.JPEG │ │ │ ├── ILSVRC2012_val_00002193.JPEG │ │ │ ├── delta_Airbus_diecast_airplane.jpg │ │ │ ├── Images-of-San-Francisco-Garter-Snake.jpg │ │ │ └── _links.txt │ │ └── voc12 │ │ │ ├── 2007_000063.jpg │ │ │ ├── 2007_000068.jpg │ │ │ ├── 2007_000129.jpg │ │ │ ├── 2007_000170.jpg │ │ │ ├── 2007_000175.jpg │ │ │ ├── 2007_000363.jpg │ │ │ ├── 2007_000480.jpg │ │ │ └── 2007_000733.jpg │ └── css │ │ └── custom.css ├── readme.rst ├── contributing.rst ├── authors.rst ├── changelog.rst ├── license.rst ├── requirements.txt ├── index.rst ├── methods │ ├── index.rst │ ├── cams │ │ ├── ttacam.rst │ │ └── gradcam.rst │ └── saliency │ │ ├── smoothgrad.rst │ │ ├── fullgrad.rst │ │ └── gradients.rst ├── Makefile ├── explaining.rst ├── wsol.rst ├── exposure.rst └── conf.py ├── AUTHORS.rst ├── _static └── images │ └── cover.jpg ├── shell ├── format.sh └── lint.sh ├── tests ├── conftest.py └── unit │ ├── methods │ └── meta_test.py │ ├── engine │ └── explaining_test.py │ └── inspection_test.py ├── .style.yapf ├── pyproject.toml ├── .readthedocs.yml ├── CHANGELOG.rst ├── .coveragerc ├── setup.py ├── .gitignore ├── .github └── workflows │ ├── ci.yml │ └── pages.yml ├── tox.ini ├── setup.cfg ├── README.rst ├── LICENSE └── CONTRIBUTING.rst /src/keras_explainable/engine/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changes: 2 | .. include:: ../CHANGELOG.rst 3 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. 
include:: ../LICENSE 8 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Lucas David 6 | -------------------------------------------------------------------------------- /_static/images/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/_static/images/cover.jpg -------------------------------------------------------------------------------- /docs/_static/images/cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/cover.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/bear.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/dogcat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/dogcat.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/goldfish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/goldfish.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/image02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/image02.png -------------------------------------------------------------------------------- /docs/_static/images/singleton/soldiers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/soldiers.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000063.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000068.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000129.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000170.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000170.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000175.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000175.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000363.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000363.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000480.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000480.jpg -------------------------------------------------------------------------------- /docs/_static/images/voc12/2007_000733.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/voc12/2007_000733.jpg -------------------------------------------------------------------------------- /shell/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # isort --sl src/keras_explainable 3 | black --line-length 90 src/keras_explainable 4 | flake8 src/keras_explainable 5 | -------------------------------------------------------------------------------- /docs/_static/images/singleton/Dalmatian-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/Dalmatian-2.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/iStock-157312120.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/iStock-157312120.webp -------------------------------------------------------------------------------- /docs/_static/images/singleton/multiple-cats-300x225.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/multiple-cats-300x225.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/ILSVRC2012_val_00000073.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/ILSVRC2012_val_00000073.JPEG -------------------------------------------------------------------------------- /docs/_static/images/singleton/ILSVRC2012_val_00000091.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/ILSVRC2012_val_00000091.JPEG 
-------------------------------------------------------------------------------- /docs/_static/images/singleton/ILSVRC2012_val_00000198.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/ILSVRC2012_val_00000198.JPEG -------------------------------------------------------------------------------- /docs/_static/images/singleton/ILSVRC2012_val_00000476.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/ILSVRC2012_val_00000476.JPEG -------------------------------------------------------------------------------- /docs/_static/images/singleton/ILSVRC2012_val_00002193.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/ILSVRC2012_val_00002193.JPEG -------------------------------------------------------------------------------- /docs/_static/images/singleton/delta_Airbus_diecast_airplane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/delta_Airbus_diecast_airplane.jpg -------------------------------------------------------------------------------- /docs/_static/images/singleton/Images-of-San-Francisco-Garter-Snake.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucasdavid/keras-explainable/HEAD/docs/_static/images/singleton/Images-of-San-Francisco-Garter-Snake.jpg -------------------------------------------------------------------------------- /src/keras_explainable/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from keras_explainable.methods import cams 2 | from keras_explainable.methods import gradient 3 | from keras_explainable.methods import meta 4 | 5 | __all__ = [ 6 | "cams", 7 | "gradient", 8 | "meta", 9 | ] 10 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for ReadTheDocs, check .readthedocs.yml. 2 | # To build the module reference correctly, make sure every external package 3 | # under `install_requires` in `setup.cfg` is also listed here! 4 | sphinx>=3.2.1 5 | sphinx_redactor_theme 6 | jupyter-sphinx 7 | keras_explainable 8 | sphinx-autodoc-typehints 9 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dummy conftest.py for keras_explainable. 3 | 4 | If you don't know what this is for, just leave it empty. 
5 | Read more about conftest.py under: 6 | - https://docs.pytest.org/en/stable/fixture.html 7 | - https://docs.pytest.org/en/stable/writing_plugins.html 8 | """ 9 | 10 | # import pytest 11 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | spaces_before_comment = 2 4 | indent_width = 2 5 | split_before_logical_operator = true 6 | column_limit = 90 7 | split_before_named_assigns = true 8 | dedent_closing_brackets = true 9 | indent_dictionary_value = false 10 | continuation_indent_width = 2 11 | split_before_default_or_named_assigns = true 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 3 | requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools_scm] 7 | # For smarter version schemes and other configuration options, 8 | # check out https://github.com/pypa/setuptools_scm 9 | version_scheme = "no-guess-dev" 10 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | # Build documentation with MkDocs 12 | #mkdocs: 13 | # configuration: mkdocs.yml 14 | 15 | # Optionally build your docs in additional formats such as PDF 16 | formats: 17 | - pdf 18 | 19 | python: 20 | version: 3.8 21 | install: 22 | - requirements: docs/requirements.txt 23 | - {path: ., method: pip} 24 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. include:: readme.rst 3 | 4 | Contents 5 | ======== 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | 10 | Overview <readme> 11 | Explaining <explaining> 12 | Exposure <exposure> 13 | Methods <methods/index> 14 | WSSL & WSSS <wsol> 15 | Contributions & Help <contributing> 16 | License <license> 17 | Authors <authors> 18 | Changelog <changelog> 19 | Module Reference <api/modules> 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | 28 | .. _Sphinx: https://www.sphinx-doc.org/ 29 | -------------------------------------------------------------------------------- /shell/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # isort --check --sl -c src/keras_explainable 4 | # if ! [ $? -eq 0 ] 5 | # then 6 | # echo "Please run \"sh shell/format.sh\" to format the code." 7 | # exit 1 8 | # fi 9 | # echo "no issues with isort" 10 | 11 | flake8 src/keras_explainable 12 | if ! [ $? -eq 0 ] 13 | then 14 | echo "Please fix the code style issues." 15 | exit 1 16 | fi 17 | echo "no issues with flake8" 18 | black --check --line-length 90 src/keras_explainable 19 | if ! [ $? -eq 0 ] 20 | then 21 | echo "Please run \"sh shell/format.sh\" to format the code." 22 | exit 1 23 | fi 24 | echo "no issues with black" 25 | echo "linting success!"
26 | -------------------------------------------------------------------------------- /docs/methods/index.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | AI Explaining Methods 3 | ===================== 4 | 5 | We list in this page the AI explaining methods implemented in 6 | ``keras-explainable``, as well as a few examples on how to utilize them. 7 | 8 | Saliency and Gradient-based 9 | """"""""""""""""""""""""""" 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | 14 | Gradient Back-propagation <saliency/gradients> 15 | SmoothGrad <saliency/smoothgrad> 16 | Full-Gradient <saliency/fullgrad> 17 | 18 | CAM-Based Techniques 19 | """""""""""""""""""" 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | 24 | Grad-CAM <cams/gradcam> 25 | TTA CAM <cams/ttacam> 26 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.0.2 6 | ============= 7 | 8 | - Improve Documentation 9 | - Add CI GitHub actions 10 | 11 | Version 0.0.1 12 | ============= 13 | 14 | - Start project with scaffolding and add schematics 15 | - Add CAM explaining method 16 | - Add Grad-CAM explaining method 17 | - Add Grad-CAM++ explaining method 18 | - Add Score-CAM explaining method 19 | - Add gradient backprop saliency explaining method 20 | - Add FullGrad explaining method 21 | - Add `engine`, `filters` and `inspection` modules 22 | - Add :py:mod:`keras_explainable.methods.meta` module, containing 23 | implementations for the ``Smooth`` and ``TTA`` procedures 24 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = keras_explainable 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for keras-explainable. 3 | Use setup.cfg to configure your project. 4 | 5 | This file was generated with PyScaffold 4.3.1. 6 | PyScaffold helps you to put up the scaffold of your new Python project.
7 | Learn more under: https://pyscaffold.org/ 8 | """ 9 | from setuptools import setup 10 | 11 | if __name__ == "__main__": 12 | try: 13 | setup(use_scm_version={"version_scheme": "no-guess-dev"}) 14 | except: # noqa 15 | print( 16 | "\n\nAn error occurred while building the project, " 17 | "please ensure you have the most updated version of setuptools, " 18 | "setuptools_scm and wheel with:\n" 19 | " pip install -U setuptools setuptools_scm wheel\n\n" 20 | ) 21 | raise 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 | # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/api/* 46 | docs/_rst/* 47 | docs/_build/* 48 | cover/* 49 | MANIFEST 50 | 51 | # Per-project virtualenvs 52 | .venv*/ 53 | .conda*/ 54 | .python-version 55 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'release' 7 | jobs: 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ["3.10"] 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | echo "Installing dependencies and caching them." 24 | python -m pip install --upgrade pip 25 | pip install numpy pandas matplotlib 26 | pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow_cpu-2.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 27 | pip install setuptools pytest pytest-cov parameterized 28 | pip install . 29 | - name: Lint 30 | run: | 31 | ./shell/lint.sh 32 | continue-on-error: true 33 | - name: Test with pytest 34 | run: | 35 | pytest 36 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | AUTODOCDIR = api 11 | 12 | # User-friendly check for sphinx-build 13 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 14 | $(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH.
If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") 15 | endif 16 | 17 | .PHONY: help clean Makefile 18 | 19 | # Put it first so that "make" without argument is like "make help". 20 | help: 21 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | clean: 24 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 25 | 26 | # Catch-all target: route all unknown targets to Sphinx using the new 27 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 28 | %: Makefile 29 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 30 | -------------------------------------------------------------------------------- /.github/workflows/pages.yml: -------------------------------------------------------------------------------- 1 | name: Pages 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | env: 9 | TF_CPP_MIN_LOG_LEVEL: 2 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.10' 18 | - uses: actions/checkout@master 19 | with: 20 | ref: release 21 | fetch-depth: 0 22 | - name: Cache Dependencies 23 | uses: actions/cache@v2 24 | id: cache 25 | with: 26 | path: ~/.cache/pip 27 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 28 | restore-keys: | 29 | ${{ runner.os }}-pip- 30 | - name: Install Dependencies 31 | run: | 32 | echo "Installing dependencies and caching them." 33 | python -m pip install --upgrade pip 34 | pip install numpy pandas matplotlib 35 | pip install https://files.pythonhosted.org/packages/04/ea/49fd026ac36fdd79bf072294b139170aefc118e487ccb39af019946797e9/tensorflow-2.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 36 | pip install . 37 | - name: Build and Commit 38 | uses: sphinx-notes/pages@v2 39 | with: 40 | requirements_path: ./docs/requirements.txt 41 | - name: Push changes 42 | uses: ad-m/github-push-action@master 43 | with: 44 | github_token: ${{ secrets.GITHUB_TOKEN }} 45 | branch: gh-pages 46 | -------------------------------------------------------------------------------- /docs/_static/images/singleton/_links.txt: -------------------------------------------------------------------------------- 1 | https://raw.githubusercontent.com/haofanwang/Score-CAM/master/images/ILSVRC2012_val_00000073.JPEG 2 | https://raw.githubusercontent.com/haofanwang/Score-CAM/master/images/ILSVRC2012_val_00000091.JPEG 3 | https://raw.githubusercontent.com/haofanwang/Score-CAM/master/images/ILSVRC2012_val_00000198.JPEG 4 | https://raw.githubusercontent.com/haofanwang/Score-CAM/master/images/ILSVRC2012_val_00000476.JPEG 5 | https://raw.githubusercontent.com/haofanwang/Score-CAM/master/images/ILSVRC2012_val_00002193.JPEG 6 | https://raw.githubusercontent.com/keisen/tf-keras-vis/master/docs/examples/images/goldfish.jpg 7 | https://raw.githubusercontent.com/keisen/tf-keras-vis/master/docs/examples/images/bear.jpg 8 | https://raw.githubusercontent.com/keisen/tf-keras-vis/master/docs/examples/images/soldiers.jpg 9 | https://3.bp.blogspot.com/-W__wiaHUjwI/Vt3Grd8df0I/AAAAAAAAA78/7xqUNj8ujtY/s400/image02.png 10 | http://www.aviationexplorer.com/Diecast_Airplanes_Aircraft/delta_Airbus_diecast_airplane.jpg 11 | https://www.petcare.com.au/wp-content/uploads/2017/09/Dalmatian-2.jpg 12 | http://sites.psu.edu/siowfa15/wp-content/uploads/sites/29639/2015/10/dogcat.jpg 13 | https://consciouscat.net/wp-content/uploads/2009/08/multiple-cats-300x225.jpg 14 | 
https://images2.minutemediacdn.com/image/upload/c_crop,h_843,w_1500,x_0,y_78/f_auto,q_auto,w_1100/v1554995977/shape/mentalfloss/iStock-157312120.jpg 15 | http://www.reptilefact.com/wp-content/uploads/2016/08/Images-of-San-Francisco-Garter-Snake.jpg 16 | -------------------------------------------------------------------------------- /tests/unit/methods/meta_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import keras_explainable as ke 5 | 6 | class MetaTest(tf.test.TestCase): 7 | BATCH = 2 8 | SHAPE = [64, 64, 3] 9 | RUN_EAGERLY = False 10 | 11 | def _build_model(self, run_eagerly=RUN_EAGERLY): 12 | input_tensor = tf.keras.Input([None, None, 3], name='inputs') 13 | model = tf.keras.applications.ResNet50V2( 14 | weights=None, 15 | input_tensor=input_tensor, 16 | classifier_activation=None, 17 | ) 18 | model.run_eagerly = run_eagerly 19 | 20 | return model 21 | 22 | def _build_model_with_activations(self, run_eagerly=RUN_EAGERLY): 23 | model = self._build_model(run_eagerly) 24 | 25 | return tf.keras.Model( 26 | inputs=model.inputs, 27 | outputs=[model.output, model.get_layer('avg_pool').input] 28 | ) 29 | 30 | def test_sanity_tta_cam(self): 31 | model = self._build_model_with_activations() 32 | 33 | x, y = map(tf.convert_to_tensor, ( 34 | np.random.rand(self.BATCH, *self.SHAPE), 35 | np.random.randint(10, size=(self.BATCH, 1)) 36 | )) 37 | 38 | tta = ke.methods.meta.tta( 39 | ke.methods.cams.cam, 40 | scales=[0.5], 41 | hflip=True, 42 | ) 43 | logits, maps = tta(model, x, indices=y) 44 | 45 | self.assertIsNotNone(logits) 46 | self.assertEqual(logits.shape, (self.BATCH, 1)) 47 | 48 | self.assertIsNotNone(maps) 49 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 50 | 51 | def test_sanity_smooth_grad(self): 52 | model = self._build_model(run_eagerly=False) 53 | 54 | x, y = map(tf.convert_to_tensor, ( 55 | np.random.rand(self.BATCH, *self.SHAPE), 56 | np.random.randint(10, size=(self.BATCH, 1)) 57 | )) 58 | 59 | smoothgrad = ke.methods.meta.smooth( 60 | ke.methods.gradient.gradients, 61 | repetitions=5, 62 | noise=0.2, 63 | ) 64 | logits, maps = smoothgrad(model, x, y) 65 | 66 | self.assertIsNotNone(logits) 67 | self.assertEqual(logits.shape, (self.BATCH, 1)) 68 | 69 | self.assertIsNotNone(maps) 70 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 71 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /*** 2 | * Code Highlighting. 
3 | **/ 4 | 5 | .highlight { 6 | background-color: #fafafa !important; 7 | padding: 0 !important; 8 | } 9 | 10 | .notranslate { 11 | margin-bottom: 1em; 12 | } 13 | 14 | .highlight > pre { 15 | background: none; 16 | border: none; 17 | /* padding: 1em !important; */ 18 | 19 | -webkit-box-shadow: none; 20 | -moz-box-shadow: none; 21 | box-shadow: none; 22 | } 23 | 24 | .jupyter_container { 25 | background-color: transparent !important; 26 | border: none !important; 27 | margin: .85rem 0 !important; 28 | 29 | -webkit-box-shadow: none !important; 30 | -moz-box-shadow: none !important; 31 | box-shadow: none !important; 32 | } 33 | 34 | .jupyter_container > .code_cell { 35 | border: none !important; 36 | } 37 | 38 | .jupyter_container > .cell_output > .output { 39 | margin: 0; 40 | } 41 | 42 | .jupyter_container > .cell_output > .output:first-child { 43 | margin-top: 0.25em; 44 | } 45 | 46 | .jupyter_container > .cell_output > .output > .highlight { 47 | margin: 0; 48 | } 49 | 50 | .jupyter_container > .cell_output > .output > .highlight > pre { 51 | padding: 0 20px !important; 52 | } 53 | 54 | .jupyter_container > .cell_output > .output:first-child > .highlight > pre { 55 | padding-top: 20px !important; 56 | } 57 | 58 | .jupyter_container > .cell_output > .output:last-child > .highlight > pre { 59 | padding-bottom: 20px !important; 60 | } 61 | 62 | /*** 63 | * Forms 64 | **/ 65 | 66 | .dataframe { 67 | border: none !important; 68 | } 69 | 70 | input[type="text"] { 71 | display: block; 72 | width: 100%; 73 | } 74 | 75 | /* 76 | 77 | input[type="text"]:focus { 78 | border-color: #007daf; 79 | box-shadow: 0 0 0 3px rgba(54, 198, 255, .25); 80 | } 81 | 82 | .function>dt, 83 | .method>dt { 84 | overflow-x: auto; 85 | } */ 86 | 87 | img { 88 | max-width: 100%; 89 | } 90 | 91 | /* Spacing */ 92 | 93 | h1,h2,h3 { 94 | margin-top: 0.5em; 95 | margin-bottom: 0.5em; 96 | } 97 | 98 | p { 99 | margin-bottom: 0.5em; 100 | } 101 | 102 | .admonition, .note { 103 | border: 0; 104 | margin-top: 1em; 105 | margin-bottom: 1em; 106 | } 107 | 108 | .viewcode-link, .viewcode-back { 109 | float: right; 110 | } 111 | 112 | 113 | /*** 114 | * Tables 115 | **/ 116 | 117 | /* table { 118 | color: #666; 119 | border: #eee 1px solid; 120 | width: 100%; 121 | } 122 | 123 | table th { 124 | border: 0; 125 | padding: 0.2em; 126 | } 127 | 128 | table td { 129 | text-align: right; 130 | border: #efefef 1px solid; 131 | padding: 0.2em; 132 | } */ 133 | -------------------------------------------------------------------------------- /src/keras_explainable/__init__.py: -------------------------------------------------------------------------------- 1 | """Keras Explainable Library. 2 | 3 | Efficient explaining AI algorithms for Keras models. 
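Example (a minimal sketch, assuming a trained Keras classifier ``model``, a batch of images ``x``, and the unit indices ``y`` to be explained — see the documentation pages for complete, runnable versions):

    import keras_explainable as ke

    model = ke.inspection.expose(model)
    scores, cams = ke.gradcam(model, x, y, batch_size=32)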
4 | """ 5 | 6 | import sys 7 | 8 | from keras_explainable import filters 9 | from keras_explainable import inspection 10 | from keras_explainable import methods 11 | from keras_explainable import utils 12 | from keras_explainable.engine import explaining 13 | from keras_explainable.engine.explaining import explain 14 | from keras_explainable.engine.explaining import partial_explain 15 | 16 | if sys.version_info[:2] >= (3, 8): 17 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 18 | from importlib.metadata import PackageNotFoundError # pragma: no cover 19 | from importlib.metadata import version 20 | else: 21 | from importlib_metadata import PackageNotFoundError # pragma: no cover 22 | from importlib_metadata import version 23 | 24 | try: 25 | # Change here if project is renamed and does not equal the package name 26 | dist_name = "keras-explainable" 27 | __version__ = version(dist_name) 28 | except PackageNotFoundError: # pragma: no cover 29 | __version__ = "unknown" 30 | finally: 31 | del version, PackageNotFoundError 32 | 33 | cam = partial_explain(methods.cams.cam, postprocessing=filters.positive_normalize) 34 | """Shortcut for :py:func:`methods.cams.cam`, 35 | filtering positively contributing regions. 36 | """ 37 | 38 | gradcam = partial_explain(methods.cams.gradcam, postprocessing=filters.positive_normalize) 39 | """Shortcut for :py:func:`methods.cams.gradcam`, 40 | filtering positively contributing regions. 41 | """ 42 | 43 | gradcampp = partial_explain( 44 | methods.cams.gradcampp, postprocessing=filters.positive_normalize 45 | ) 46 | """Shortcut for :py:func:`methods.cams.gradcampp`, 47 | filtering positively contributing regions. 48 | """ 49 | 50 | scorecam = partial_explain( 51 | methods.cams.scorecam, 52 | postprocessing=filters.positive_normalize, 53 | resizing=False, 54 | ) 55 | """Shortcut for :py:func:`methods.cams.scorecam`, 56 | filtering positively contributing regions. 57 | """ 58 | 59 | gradients = partial_explain( 60 | methods.gradient.gradients, 61 | postprocessing=filters.normalize, 62 | resizing=False, 63 | ) 64 | """Shortcut for :py:func:`methods.gradient.gradients`, 65 | filtering absolutely contributing regions. 66 | """ 67 | 68 | full_gradients = partial_explain( 69 | methods.gradient.full_gradients, 70 | postprocessing=filters.normalize, 71 | resizing=False, 72 | ) 73 | """Shortcut for :py:func:`methods.gradient.full_gradients`, 74 | filtering absolutely contributing regions. 75 | """ 76 | 77 | __all__ = [ 78 | "methods", 79 | "inspection", 80 | "filters", 81 | "utils", 82 | "explaining", 83 | "explain", 84 | "gradients", 85 | "full_gradients", 86 | "cam", 87 | "gradcam", 88 | "gradcampp", 89 | "scorecam", 90 | ] 91 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | [testenv] 11 | description = Invoke pytest to run automated tests 12 | setenv = 13 | TOXINIDIR = {toxinidir} 14 | passenv = 15 | HOME 16 | SETUPTOOLS_* 17 | extras = 18 | testing 19 | commands = 20 | pytest {posargs} 21 | 22 | # # To run `tox -e lint` you need to make sure you have a 23 | # # `.pre-commit-config.yaml` file. 
See https://pre-commit.com 24 | # [testenv:lint] 25 | # description = Perform static analysis and style checks 26 | # skip_install = True 27 | # deps = pre-commit 28 | # passenv = 29 | # HOMEPATH 30 | # PROGRAMDATA 31 | # SETUPTOOLS_* 32 | # commands = 33 | # pre-commit run --all-files {posargs:--show-diff-on-failure} 34 | 35 | [testenv:{build,clean}] 36 | description = 37 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 38 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 39 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 40 | skip_install = True 41 | changedir = {toxinidir} 42 | deps = 43 | build: build[virtualenv] 44 | passenv = 45 | SETUPTOOLS_* 46 | commands = 47 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 48 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 49 | build: python -m build {posargs} 50 | 51 | [testenv:{docs,doctests,linkcheck}] 52 | description = 53 | docs: Invoke sphinx-build to build the docs 54 | doctests: Invoke sphinx-build to run doctests 55 | linkcheck: Check for broken links in the documentation 56 | passenv = 57 | SETUPTOOLS_* 58 | setenv = 59 | DOCSDIR = {toxinidir}/docs 60 | BUILDDIR = {toxinidir}/docs/_build 61 | docs: BUILD = html 62 | doctests: BUILD = doctest 63 | linkcheck: BUILD = linkcheck 64 | deps = 65 | -r {toxinidir}/docs/requirements.txt 66 | # ^ requirements.txt shared with Read The Docs 67 | commands = 68 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 69 | 70 | [testenv:publish] 71 | description = 72 | Publish the package you have been developing to a package index server. 73 | By default, it uses testpypi. If you really want to publish your package 74 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 
75 | skip_install = True 76 | changedir = {toxinidir} 77 | passenv = 78 | # See: https://twine.readthedocs.io/en/latest/ 79 | TWINE_USERNAME 80 | TWINE_PASSWORD 81 | TWINE_REPOSITORY 82 | TWINE_REPOSITORY_URL 83 | deps = twine 84 | commands = 85 | python -m twine check dist/* 86 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 87 | -------------------------------------------------------------------------------- /src/keras_explainable/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | from math import ceil 3 | from typing import List 4 | from typing import Optional 5 | from typing import Tuple 6 | from typing import Union 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from PIL import Image 11 | 12 | # region Generics 13 | 14 | 15 | def tolist(item): 16 | if isinstance(item, list): 17 | return item 18 | 19 | if isinstance(item, (tuple, set)): 20 | return list(item) 21 | 22 | return [item] 23 | 24 | 25 | # endregion 26 | 27 | # region Visualization 28 | 29 | 30 | def get_dims(image): 31 | if hasattr(image, "shape"): 32 | return image.shape 33 | return (len(image), *get_dims(image[0])) 34 | 35 | 36 | def visualize( 37 | images: List[Union[tf.Tensor, np.ndarray]], 38 | titles: Optional[List[str]] = None, 39 | overlays: Optional[List[np.ndarray]] = None, 40 | overlay_alpha: float = 0.75, 41 | rows: Optional[int] = None, 42 | cols: Optional[int] = None, 43 | figsize: Optional[Tuple[float, float]] = None, 44 | cmap: Optional[str] = None, 45 | overlay_cmap: Optional[str] = None, 46 | to_file: Optional[str] = None, 47 | to_buffer: Optional[io.BytesIO] = None, 48 | subplots_ws: float = 0.0, 49 | subplots_hs: float = 0.0, 50 | ): 51 | import matplotlib.pyplot as plt 52 | 53 | dims = get_dims(images) 54 | rank = len(dims) 55 | 56 | if isinstance(images, tf.Tensor): 57 | images = images.numpy() 58 | 59 | if not isinstance(images, (list, tuple)) and rank <= 3: 60 | # A single image was passed (neither a collection nor a batch): 61 | # wrap it in a list, so it is handled uniformly below. 62 | images = [images] 63 | 64 | if rows is None and cols is None: 65 | cols = min(8, len(images)) 66 | rows = ceil(len(images) / cols) 67 | elif rows is None: 68 | rows = ceil(len(images) / cols) 69 | else: 70 | cols = ceil(len(images) / rows) 71 | 72 | plt.figure(figsize=figsize or (4 * cols, 4 * rows)) 73 | 74 | for ix, image in enumerate(images): 75 | plt.subplot(rows, cols, ix + 1) 76 | 77 | if image is not None: 78 | if isinstance(image, tf.Tensor): 79 | image = image.numpy() 80 | 81 | if len(image.shape) > 2 and image.shape[-1] == 1: 82 | image = image[..., 0] 83 | 84 | plt.imshow(image, cmap=cmap) 85 | 86 | if overlays is not None and len(overlays) > ix and overlays[ix] is not None: 87 | oi = overlays[ix] 88 | if len(oi.shape) > 2 and oi.shape[-1] == 1: 89 | oi = oi[..., 0] 90 | plt.imshow(oi, overlay_cmap, alpha=overlay_alpha) 91 | if titles is not None and len(titles) > ix: 92 | plt.title(titles[ix]) 93 | plt.axis("off") 94 | 95 | plt.tight_layout() 96 | plt.subplots_adjust(wspace=subplots_ws, hspace=subplots_hs) 97 | 98 | if to_buffer: 99 | plt.savefig(to_buffer) 100 | return Image.open(to_buffer) 101 | 102 | if to_file is not None: 103 | plt.savefig(to_file) 104 | 105 | 106 | # endregion 107 | -------------------------------------------------------------------------------- /docs/methods/cams/ttacam.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | TTA CAM 3 | ======= 4 | 5 | Test-time augmentation (TTA) is a commonly employed strategy in saliency 6 | detection and weakly supervised segmentation tasks
in order to obtain smoother 7 | and more stable explaining maps. 8 | 9 | We illustrate in this example how to apply TTA to AI explaining methods using 10 | ``keras-explainable``. This can be easily achieved with the following code 11 | template snippet: 12 | 13 | .. code-block:: python 14 | 15 | import keras_explainable as ke 16 | 17 | model = tf.keras.applications.Xception(...) 18 | model = ke.inspection.expose(model) 19 | 20 | ttacam = ke.methods.meta.tta( 21 | ke.methods.cams.cam, 22 | scales=[0.5, 1.0, 1.5, 2.], 23 | hflip=True 24 | ) 25 | _, cams = ke.explain( 26 | ttacam, model, inputs, postprocessing=ke.filters.positive_normalize 27 | ) 28 | 29 | .. jupyter-execute:: 30 | :hide-code: 31 | :hide-output: 32 | 33 | import os 34 | import numpy as np 35 | import pandas as pd 36 | import tensorflow as tf 37 | from keras.utils import load_img, img_to_array 38 | 39 | import keras_explainable as ke 40 | 41 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 42 | SAMPLES = 8 43 | SIZES = (224, 224) 44 | 45 | file_names = os.listdir(SOURCE_DIRECTORY) 46 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 47 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 48 | images = images.astype("uint8")[:SAMPLES] 49 | 50 | We describe these lines in detail below. 51 | Firstly, we employ the :class:`Xception` network pre-trained over the 52 | ImageNet dataset: 53 | 54 | .. jupyter-execute:: 55 | 56 | input_tensor = tf.keras.Input(shape=(None, None, 3), name='inputs') 57 | 58 | model = tf.keras.applications.Xception( 59 | input_tensor=input_tensor, 60 | classifier_activation=None, 61 | weights='imagenet', 62 | ) 63 | 64 | print(f'Xception pretrained over ImageNet was loaded.') 65 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 66 | 67 | We can feed-forward the samples once and get the predicted classes for each 68 | sample. Besides making sure the model is outputting the expected classes, 69 | this step is required in order to determine the most activating units in the 70 | *logits* layer, which improves performance of the explaining methods. 71 | 72 | .. jupyter-execute:: 73 | 74 | from tensorflow.keras.applications.imagenet_utils import preprocess_input 75 | 76 | inputs = images / 127.5 - 1 77 | logits = model.predict(inputs, verbose=0) 78 | indices = np.argsort(logits, axis=-1)[:, ::-1] 79 | 80 | explaining_units = indices[:, :1] # Most likely class of each sample. 81 | 82 | .. jupyter-execute:: 83 | 84 | model = ke.inspection.expose(model) 85 | 86 | ttacam = ke.methods.meta.tta( 87 | ke.methods.cams.cam, 88 | scales=[0.5, 1.0, 1.5, 2.], 89 | hflip=True 90 | ) 91 | _, cams = ke.explain(ttacam, model, inputs, explaining_units, batch_size=1) 92 | 93 | ke.utils.visualize( 94 | images=[*images, *cams, *images], 95 | overlays=[None] * (2 * len(images)) + [*cams], 96 | ) 97 | -------------------------------------------------------------------------------- /docs/methods/saliency/smoothgrad.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | Smooth-Grad 3 | =========== 4 | 5 | In this page, we describe how to obtain *saliency maps* from a trained 6 | Convolutional Neural Network (CNN) with respect to an input signal (an image, 7 | in this case) using the Smooth-Grad AI explaining method.
8 | 9 | Smooth-Grad is a variant of the Gradient Backprop algorithm first described 10 | in the following paper: 11 | 12 | Smilkov, D., Thorat, N., Kim, B., Viégas, F., & Wattenberg, M. (2017). 13 | Smoothgrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825. 14 | Available at: `arxiv/1706.03825 <https://arxiv.org/abs/1706.03825>`_. 15 | 16 | It consists of consecutive repetitions of the Gradient Backprop method, 17 | each applied over a copy of the original sample perturbed with 18 | Gaussian noise. 19 | Finally, averaging the resulting explaining maps produces cleaner 20 | visualizations, robust against marginal noise. 21 | 22 | Briefly, this can be achieved with the following template snippet: 23 | 24 | .. code-block:: python 25 | 26 | import keras_explainable as ke 27 | 28 | model = build_model(...) 29 | model.layers[-1].activation = 'linear' # Usually softmax or sigmoid. 30 | 31 | smoothgrad = ke.methods.meta.smooth( 32 | ke.methods.gradient.gradients, 33 | repetitions=10, 34 | noise=0.1 35 | ) 36 | 37 | logits, maps = ke.explain( 38 | smoothgrad, model, x, y, 39 | batch_size=32, 40 | postprocessing=ke.filters.normalize, 41 | ) 42 | 43 | We will describe each of the steps above in detail. 44 | Firstly, we employ the :class:`Xception` network pre-trained over the 45 | ImageNet dataset: 46 | 47 | .. jupyter-execute:: 48 | :hide-code: 49 | :hide-output: 50 | 51 | import os 52 | import numpy as np 53 | import pandas as pd 54 | import tensorflow as tf 55 | from keras.utils import load_img, img_to_array 56 | 57 | import keras_explainable as ke 58 | 59 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 60 | SAMPLES = 8 61 | SIZES = (299, 299) 62 | 63 | file_names = os.listdir(SOURCE_DIRECTORY) 64 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 65 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 66 | images = images.astype("uint8")[:SAMPLES] 67 | 68 | .. jupyter-execute:: 69 | 70 | model = tf.keras.applications.Xception( 71 | classifier_activation=None, 72 | weights='imagenet', 73 | ) 74 | 75 | print(f'Xception pretrained over ImageNet was loaded.') 76 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 77 | 78 | We can feed-forward the samples once and get the predicted classes for each sample. 79 | Besides making sure the model is outputting the expected classes, this step is 80 | required in order to determine the most activating units in the *logits* layer, 81 | which improves performance of the explaining methods. 82 | 83 | .. jupyter-execute:: 84 | 85 | inputs = images / 127.5 - 1 86 | logits = model.predict(inputs, verbose=0) 87 | indices = np.argsort(logits, axis=-1)[:, ::-1] 88 | explaining_units = indices[:, :1] # Most likely class of each sample. 89 | 90 | keras-explainable implements Smooth-Grad through the meta explaining function 91 | :func:`keras_explainable.methods.meta.smooth`, which wraps any 92 | explaining method and smooths out its outputs. For example: 93 | 94 | .. jupyter-execute:: 95 | 96 | smoothgrad = ke.methods.meta.smooth( 97 | tf.function(ke.methods.gradient.gradients, reduce_retracing=True, jit_compile=True), 98 | repetitions=20, 99 | noise=0.1, 100 | ) 101 | _, smoothed_maps = smoothgrad( 102 | model, 103 | inputs, 104 | explaining_units, 105 | ) 106 | 107 | smoothed_maps = ke.filters.absolute_normalize(smoothed_maps).numpy() 108 | 109 | 110 | For comparative purposes, we also compute the vanilla gradients method: 111 | 112 | ..
jupyter-execute:: 113 | 114 | _, maps = ke.gradients(model, inputs, explaining_units) 115 | 116 | ke.utils.visualize([*images, *maps, *smoothed_maps]) 117 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = keras-explainable 8 | description = Explainable algorithms for Keras models 9 | author = Lucas David 10 | author_email = lucasolivdavid@gmail.com 11 | license = Apache-2.0 12 | license_files = LICENSE 13 | long_description = file: README.rst 14 | long_description_content_type = text/x-rst; charset=UTF-8 15 | url = https://github.com/lucasdavid/keras-explainable 16 | 17 | project_urls = 18 | Documentation = https://pyscaffold.org/ 19 | # Source = https://github.com/pyscaffold/pyscaffold/ 20 | # Changelog = https://pyscaffold.org/en/latest/changelog.html 21 | # Tracker = https://github.com/pyscaffold/pyscaffold/issues 22 | # Conda-Forge = https://anaconda.org/conda-forge/pyscaffold 23 | # Download = https://pypi.org/project/PyScaffold/#files 24 | # Twitter = https://twitter.com/PyScaffold 25 | 26 | # Change if running only on Windows, Mac or Linux (comma-separated) 27 | platforms = any 28 | 29 | # Add here all kinds of additional classifiers as defined under 30 | # https://pypi.org/classifiers/ 31 | classifiers = 32 | Development Status :: 4 - Beta 33 | Programming Language :: Python 34 | 35 | [options] 36 | zip_safe = False 37 | packages = find_namespace: 38 | include_package_data = True 39 | package_dir = 40 | =src 41 | 42 | # Require a min/specific Python version (comma-separated conditions) 43 | # python_requires = >=3.8 44 | 45 | # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. 46 | # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in 47 | # new major versions. This works if the required packages follow Semantic Versioning. 48 | # For more information, check out https://semver.org/. 49 | install_requires = 50 | importlib-metadata; python_version<"3.8" 51 | tensorflow 52 | keras 53 | 54 | [options.packages.find] 55 | where = src 56 | exclude = 57 | tests 58 | 59 | [options.extras_require] 60 | # Add here additional requirements for extra features, to install with: 61 | # `pip install keras-explainable[PDF]` like: 62 | # PDF = ReportLab; RXP 63 | 64 | # Add here test requirements (semicolon/line-separated) 65 | testing = 66 | setuptools 67 | pytest 68 | pytest-cov 69 | parameterized 70 | 71 | [options.entry_points] 72 | # Add here console scripts like: 73 | # console_scripts = 74 | # script_name = keras_explainable.module:function 75 | # For example: 76 | # console_scripts = 77 | # fibonacci = keras_explainable.cli:run 78 | # And any other entry points, for example: 79 | # pyscaffold.cli = 80 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 81 | 82 | [tool:pytest] 83 | # Specify command line options as you would do when invoking pytest directly. 84 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 85 | # in order to write a coverage file that can be read by Jenkins. 86 | # CAUTION: --cov flags may prohibit setting breakpoints while debugging. 87 | # Comment those flags to avoid this pytest issue. 
88 | addopts = 89 | --cov keras_explainable --cov-report term-missing 90 | --verbose 91 | norecursedirs = 92 | dist 93 | build 94 | .tox 95 | testpaths = tests 96 | # Use pytest markers to select/deselect specific tests 97 | # markers = 98 | # slow: mark tests as slow (deselect with '-m "not slow"') 99 | # system: mark end-to-end system tests 100 | 101 | [devpi:upload] 102 | # Options for the devpi: PyPI server and packaging tool 103 | # VCS export must be deactivated since we are using setuptools-scm 104 | no_vcs = 1 105 | formats = bdist_wheel 106 | 107 | [flake8] 108 | # Some sane defaults for the code style checker flake8 109 | max_line_length = 90 110 | extend_ignore = E203, W503 111 | # ^ Black-compatible 112 | # E203 and W503 have edge cases handled by black 113 | exclude = 114 | .tox 115 | build 116 | dist 117 | .eggs 118 | docs/conf.py 119 | 120 | [pyscaffold] 121 | # PyScaffold's parameters when the project was created. 122 | # This will be used when updating. Do not change! 123 | version = 4.3.1 124 | package = keras_explainable 125 | -------------------------------------------------------------------------------- /src/keras_explainable/filters.py: -------------------------------------------------------------------------------- 1 | """Shortcuts for signal filters commonly used in the literature. 2 | 3 | These filters can be used as post- or mid-processing steps for explaining 4 | methods and techniques. 5 | 6 | .. jupyter-execute:: 7 | :hide-code: 8 | :hide-output: 9 | 10 | import numpy as np 11 | import keras_explainable as ke 12 | 13 | """ 14 | 15 | from typing import Tuple 16 | 17 | import tensorflow as tf 18 | 19 | from keras_explainable.inspection import SPATIAL_AXIS 20 | 21 | 22 | def normalize(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 23 | """Normalize the signal into the interval [0, 1]. 24 | 25 | Usage: 26 | 27 | .. jupyter-execute:: 28 | 29 | x = 5 * np.random.normal(size=(4, 16, 16, 3)).round(1) 30 | y = ke.filters.normalize(x).numpy() 31 | print(f"[{x.min()}, {x.max()}] -> [{y.min()}, {y.max()}]") 32 | 33 | Args: 34 | x (tf.Tensor): the input signal to be normalized. 35 | axis (Tuple[int], optional): the dimensions containing positional 36 | information. Defaults to ``SPATIAL_AXIS``. 37 | 38 | Returns: 39 | tf.Tensor: the normalized signal. 40 | """ 41 | x = tf.convert_to_tensor(x) 42 | x -= tf.reduce_min(x, axis=axis, keepdims=True) 43 | 44 | return tf.math.divide_no_nan(x, tf.reduce_max(x, axis=axis, keepdims=True)) 45 | 46 | 47 | def positive(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 48 | """Retain only positive values of the input signal. 49 | 50 | Usage: 51 | 52 | .. jupyter-execute:: 53 | 54 | x = np.asarray([0, -1, 2, -3]) 55 | y = ke.filters.positive(x).numpy() 56 | print(f"{x} -> {y}") 57 | 58 | Args: 59 | x (tf.Tensor): the input signal. 60 | axis (Tuple[int], optional): the dimensions containing positional 61 | information. Defaults to ``SPATIAL_AXIS``. 62 | 63 | Returns: 64 | tf.Tensor: the filtered signal. 65 | """ 66 | return tf.nn.relu(x) 67 | 68 | 69 | def negative(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 70 | """Retain only negative values of the input signal. 71 | 72 | Usage: 73 | 74 | .. jupyter-execute:: 75 | 76 | x = np.asarray([0, -1, 2, -3]) 77 | y = ke.filters.negative(x).numpy() 78 | print(f"{x} -> {y}") 79 | 80 | Args: 81 | x (tf.Tensor): the input signal. 82 | axis (Tuple[int], optional): the dimensions containing positional 83 | information. Defaults to ``SPATIAL_AXIS``.
84 | 85 | Returns: 86 | tf.Tensor: the filtered signal. 87 | """ 88 | return tf.minimum(x, 0) 89 | 90 | 91 | def positive_normalize(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 92 | """Retain only positive values of the input signal and normalize it between 0 and 1. 93 | 94 | Args: 95 | x (tf.Tensor): the input signal. 96 | axis (Tuple[int], optional): the dimensions containing positional 97 | information. Defaults to ``SPATIAL_AXIS``. 98 | 99 | Returns: 100 | tf.Tensor: the filtered signal. 101 | """ 102 | return normalize(positive(x, axis=axis), axis=axis) 103 | 104 | 105 | def absolute_normalize(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 106 | """Take the absolute value of the input signal and normalize it between 0 and 1. 107 | 108 | Args: 109 | x (tf.Tensor): the input signal. 110 | axis (Tuple[int], optional): the dimensions containing positional 111 | information. Defaults to ``SPATIAL_AXIS``. 112 | 113 | Returns: 114 | tf.Tensor: the filtered signal. 115 | """ 116 | return normalize(tf.abs(x), axis=axis) 117 | 118 | 119 | def negative_normalize(x: tf.Tensor, axis: Tuple[int] = SPATIAL_AXIS) -> tf.Tensor: 120 | """Retain only negative values of the input signal and normalize it between 0 and 1. 121 | 122 | Args: 123 | x (tf.Tensor): the input signal. 124 | axis (Tuple[int], optional): the dimensions containing positional 125 | information. Defaults to ``SPATIAL_AXIS``. 126 | 127 | Returns: 128 | tf.Tensor: the filtered signal. 129 | """ 130 | return normalize(negative(x, axis=axis), axis=axis) 131 | 132 | 133 | __all__ = [ 134 | "normalize", 135 | "positive", 136 | "negative", 137 | "positive_normalize", 138 | "absolute_normalize", 139 | "negative_normalize", 140 | ] 141 | -------------------------------------------------------------------------------- /docs/methods/saliency/fullgrad.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | Full Gradients 3 | ============== 4 | 5 | In this page, we describe how to obtain *saliency maps* from a trained 6 | Convolutional Neural Network (CNN) with respect to an input signal (an image, 7 | in this case) using the Full Gradients AI explaining method. 8 | Said maps can be used to explain the model's predictions, determining regions 9 | which most contributed to its effective output. 10 | 11 | FullGrad (short for Full Gradients) extends Gradient Back-propagation by adding the 12 | individual bias contributions to the gradient signal, forming the "full" explaining 13 | maps. This technique is fully described in the paper "Full-gradient representation for 14 | neural network visualization", published in Advances in neural information processing 15 | systems, 32 by Srinivas, S., & Fleuret, F. (2019), 16 | `arxiv.org/1905.00780v4 <https://arxiv.org/abs/1905.00780v4>`_. 17 | 18 | Briefly, this can be achieved with the following template snippet: 19 | 20 | .. code-block:: python 21 | 22 | import keras_explainable as ke 23 | 24 | model = build_model(...) 25 | model.layers[-1].activation = 'linear' # Usually softmax or sigmoid. 26 | 27 | logits = ke.inspection.get_logits_layer(model) 28 | inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits]) 29 | model = ke.inspection.expose(model, inters, logits) 30 | 31 | x, y = ( 32 | np.random.rand(32, 512, 512, 3), 33 | np.random.randint(10, size=[32, 1]) 34 | ) 35 | 36 | logits, maps = ke.full_gradients( 37 | model, 38 | x, 39 | y, 40 | biases=biases, 41 | ) 42 | 43 | We describe these lines in detail below.
44 | Firstly, we employ the :class:`Xception` network pre-trained over the 45 | ImageNet dataset: 46 | 47 | .. jupyter-execute:: 48 | :hide-code: 49 | :hide-output: 50 | 51 | import os 52 | import numpy as np 53 | import pandas as pd 54 | import tensorflow as tf 55 | from keras.utils import load_img, img_to_array 56 | 57 | import keras_explainable as ke 58 | 59 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 60 | SAMPLES = 8 61 | SIZES = (299, 299) 62 | 63 | file_names = os.listdir(SOURCE_DIRECTORY) 64 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 65 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 66 | images = images.astype("uint8")[:SAMPLES] 67 | 68 | .. jupyter-execute:: 69 | 70 | model = tf.keras.applications.Xception( 71 | classifier_activation=None, 72 | weights="imagenet", 73 | ) 74 | 75 | print("Xception pretrained over ImageNet was loaded.") 76 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 77 | 78 | We can feed-forward the samples once and get the predicted classes for each sample. 79 | Besides making sure the model is outputting the expected classes, this step is 80 | required in order to determine the most activating units in the *logits* layer, 81 | which improves performance of the explaining methods. 82 | 83 | .. jupyter-execute:: 84 | 85 | from tensorflow.keras.applications.imagenet_utils import preprocess_input 86 | 87 | inputs = preprocess_input(images.astype("float").copy(), mode="tf") 88 | logits = model.predict(inputs, verbose=0) 89 | indices = np.argsort(logits, axis=-1)[:, ::-1] 90 | explaining_units = indices[:, :1] # First-most likely classes. 91 | 92 | The FullGrad algorithm, implemented through the 93 | :func:`keras_explainable.methods.gradient.full_gradients`, 94 | expects a model that exposes all layers containing biases (besides the output). 95 | Thus, we must first expose them. The most efficient way to do so is 96 | by collecting the layers directly: 97 | 98 | .. jupyter-execute:: 99 | 100 | logits = ke.inspection.get_logits_layer(model) 101 | inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits]) 102 | model = ke.inspection.expose(model, inters, logits) 103 | 104 | Now we can obtain the FullGrad maps by simply calling the :func:`ke.full_gradients` shortcut: 105 | 106 | .. jupyter-execute:: 107 | 108 | _, maps = ke.full_gradients( 109 | model, 110 | inputs, 111 | explaining_units, 112 | biases=biases, 113 | ) 114 | 115 | ke.utils.visualize( 116 | images=[*images, *maps, *images], 117 | overlays=[None] * (2 * len(images)) + [*maps], 118 | ) 119 | 120 | .. note:: 121 | 122 | Passing the list of ``biases`` as a parameter to the 123 | :func:`~keras_explainable.full_gradients` function is not required, but it 124 | is generally a good idea, as it avoids collecting them again at each call. 125 | -------------------------------------------------------------------------------- /docs/methods/cams/gradcam.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Grad-CAM 3 | ======== 4 | 5 | This example illustrates how to explain predictions of a Convolutional Neural 6 | Network (CNN) using Grad-CAM. This can be easily achieved with the following 7 | code template snippet: 8 | 9 | .. code-block:: python 10 | 11 | import keras_explainable as ke 12 | 13 | model = tf.keras.applications.ResNet50V2(...)
14 | model = ke.inspection.expose(model) 15 | 16 | scores, cams = ke.gradcam(model, x, y, batch_size=32) 17 | 18 | In this page, we describe how to obtain *Class Activation Maps* (CAMs) from a 19 | trained Convolutional Neural Network (CNN) with respect to an input signal 20 | (an image, in this case) using the Grad-CAM visualization method. 21 | Said maps can be used to explain the model's predictions, determining regions 22 | which most contributed to its effective output. 23 | 24 | Grad-CAM is a form of visualizing regions that most contributed to the output 25 | of a given logit unit of a neural network, oftentimes associated with the 26 | prediction of the occurrence of a class in the problem domain. This method 27 | was first described in the following article: 28 | 29 | Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D. 30 | (2017). Grad-CAM: Visual explanations from deep networks via gradient-based 31 | localization. In Proceedings of the IEEE international conference on computer 32 | vision (pp. 618-626). 33 | 34 | Briefly, this can be achieved with the following template snippet: 35 | 36 | .. code-block:: python 37 | 38 | import keras_explainable as ke 39 | 40 | model = ke.inspection.expose(build_model(...)) 41 | logits, cams = ke.gradcam(model, x, y, batch_size=32) 42 | 43 | We describe these lines in detail below. 44 | 45 | .. jupyter-execute:: 46 | :hide-code: 47 | :hide-output: 48 | 49 | import os 50 | import numpy as np 51 | import pandas as pd 52 | import tensorflow as tf 53 | from keras.utils import load_img, img_to_array 54 | 55 | import keras_explainable as ke 56 | 57 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 58 | SAMPLES = 8 59 | SIZES = (299, 299) 60 | 61 | file_names = os.listdir(SOURCE_DIRECTORY) 62 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 63 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 64 | images = images.astype("uint8")[:SAMPLES] 65 | 66 | Firstly, we employ the :class:`Xception` network pre-trained over the 67 | ImageNet dataset: 68 | 69 | .. jupyter-execute:: 70 | 71 | model = tf.keras.applications.Xception( 72 | classifier_activation=None, 73 | weights='imagenet', 74 | ) 75 | 76 | print("Xception pretrained over ImageNet was loaded.") 77 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 78 | 79 | We can feed-forward the samples once and get the predicted classes for each sample. 80 | Besides making sure the model is outputting the expected classes, this step is 81 | required in order to determine the most activating units in the *logits* layer, 82 | which improves performance of the explaining methods. 83 | 84 | .. jupyter-execute:: 85 | 86 | from tensorflow.keras.applications.imagenet_utils import preprocess_input 87 | 88 | inputs = preprocess_input(images.astype("float").copy(), mode="tf") 89 | logits = model.predict(inputs, verbose=0) 90 | indices = np.argsort(logits, axis=-1)[:, ::-1] 91 | 92 | explaining_units = indices[:, :1] # First most likely class of each sample. 93 | 94 | Grad-CAM works by computing the differential of an activation function, 95 | usually associated with the prediction of a given class, with respect to pixels 96 | contained in the activation map retrieved from an intermediate convolutional 97 | signal (oftentimes the output of the last convolutional layer).
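In code, the core of this computation can be sketched as follows. This is a minimal, hypothetical illustration for a single explained unit per sample; the library's actual implementation (in ``src/keras_explainable/methods/cams.py``) is more general, relying on the batch jacobian to explain multiple units at once:

.. code-block:: python

    import tensorflow as tf

    def gradcam_sketch(model, x, units):
        # The exposed model returns (logits, activations), where the
        # activations tensor has shape (batch, height, width, kernels).
        with tf.GradientTape() as tape:
            logits, activations = model(x, training=False)
            scores = tf.gather(logits, units, axis=-1, batch_dims=1)

        # Importance weights: gradients of the explained unit,
        # global-average-pooled over the spatial dimensions.
        grads = tape.gradient(scores, activations)
        weights = tf.reduce_mean(grads, axis=(1, 2), keepdims=True)

        # Linear combination of the activation maps, retaining only the
        # positively contributing regions (as in the original paper).
        cams = tf.nn.relu(tf.reduce_sum(weights * activations, axis=-1, keepdims=True))
        return scores, cams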
98 | 99 | CAM-based methods implemented here expect the model to output both logits and 100 | activation signal, so their respective representative tensors are exposed and 101 | the jacobian can be computed from the former with respect to the latter. 102 | Hence, we modify the current ``model`` --- which only outputs logits at this 103 | point --- to expose both the activation maps and the logits signals: 104 | 105 | .. jupyter-execute:: 106 | 107 | model = ke.inspection.expose(model) 108 | _, cams = ke.gradcam(model, inputs, explaining_units) 109 | 110 | ke.utils.visualize( 111 | images=[*images, *cams, *images], 112 | overlays=[None] * (2 * len(images)) + [*cams], 113 | ) 114 | 115 | .. note:: 116 | 117 | To increase efficiency, we sub-select only the top :math:`K` scoring 118 | classification units to explain. The jacobian will only be computed for 119 | these :math:`NK` outputs. 120 | 121 | Following the original Grad-CAM paper, we only consider the positive 122 | contributing regions in the creation of the CAMs, collapsing negatively 123 | contributing and unrelated regions together. 124 | This is done automatically by :py:func:`ke.gradcam`, which assigns 125 | the default value :py:func:`filters.positive_normalize` to the 126 | ``postprocessing`` parameter. 127 | -------------------------------------------------------------------------------- /tests/unit/engine/explaining_test.py: -------------------------------------------------------------------------------- 1 | from parameterized import parameterized 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | import keras_explainable as ke 7 | 8 | TEST_EXPLAIN_SANITY_GRADIENTS_EXCLUDE = ( 9 | ke.methods.gradient.full_gradients, 10 | ) 11 | 12 | class ExplainTest(tf.test.TestCase): 13 | BATCH = 2 14 | SHAPE = [64, 64, 3] 15 | RUN_EAGERLY = False 16 | 17 | def _build_model(self, run_eagerly=False, jit_compile=False): 18 | input_tensor = tf.keras.Input([None, None, 3], name='inputs') 19 | model = tf.keras.applications.ResNet50V2( 20 | weights=None, 21 | input_tensor=input_tensor, 22 | classifier_activation=None, 23 | ) 24 | model.compile( 25 | optimizer='sgd', 26 | loss='sparse_categorical_crossentropy', 27 | metrics=['accuracy'], 28 | run_eagerly=run_eagerly, 29 | jit_compile=jit_compile, 30 | ) 31 | 32 | return model 33 | 34 | def _build_model_with_activations(self, run_eagerly=False, jit_compile=False): 35 | model = self._build_model(run_eagerly, jit_compile) 36 | 37 | return tf.keras.Model( 38 | inputs=model.inputs, 39 | outputs=[model.output, model.get_layer('avg_pool').input] 40 | ) 41 | 42 | @parameterized.expand([(m,) for m in ke.methods.cams.METHODS]) 43 | def test_explain_sanity_cams(self, explaining_method): 44 | model = self._build_model_with_activations() 45 | 46 | x, y = ( 47 | np.random.rand(self.BATCH, *self.SHAPE), 48 | np.random.randint(10, size=(self.BATCH, 1)) 49 | ) 50 | 51 | logits, maps = ke.explain(explaining_method, model, x, y) 52 | 53 | self.assertIsNotNone(logits) 54 | self.assertEqual(logits.shape, (self.BATCH, 1)) 55 | 56 | self.assertIsNotNone(maps) 57 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 58 | 59 | @parameterized.expand([ 60 | (False, True), 61 | (True, False), 62 | ]) 63 | def test_explain_cams_jit_compile(self, run_eagerly, jit_compile): 64 | model = self._build_model_with_activations(run_eagerly, jit_compile) 65 | 66 | x, y = ( 67 | np.random.rand(self.BATCH, *self.SHAPE), 68 | np.random.randint(10, size=(self.BATCH, 1)) 69 | ) 70 | 71 | logits, maps = 
ke.explain(ke.methods.cams.gradcam, model, x, y) 72 | 73 | self.assertIsNotNone(logits) 74 | self.assertEqual(logits.shape, (self.BATCH, 1)) 75 | 76 | self.assertIsNotNone(maps) 77 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 78 | 79 | @parameterized.expand([ 80 | (m,) 81 | for m in ke.methods.gradient.METHODS 82 | if m not in TEST_EXPLAIN_SANITY_GRADIENTS_EXCLUDE 83 | ]) 84 | def test_explain_sanity_gradients(self, explaining_method): 85 | model = self._build_model() 86 | 87 | x, y = ( 88 | np.random.rand(self.BATCH, *self.SHAPE), 89 | np.random.randint(10, size=(self.BATCH, 1)) 90 | ) 91 | 92 | logits, maps = ke.explain(explaining_method, model, x, y) 93 | 94 | self.assertIsNotNone(logits) 95 | self.assertEqual(logits.shape, (self.BATCH, 1)) 96 | 97 | self.assertIsNotNone(maps) 98 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 99 | 100 | def test_explain_tta_cam(self): 101 | model = self._build_model_with_activations() 102 | 103 | x, y = ( 104 | np.random.rand(self.BATCH, *self.SHAPE), 105 | np.random.randint(10, size=(self.BATCH, 1)) 106 | ) 107 | 108 | explaining_method = ke.methods.meta.tta( 109 | ke.methods.cams.cam, 110 | scales=[0.5], 111 | hflip=True, 112 | ) 113 | logits, maps = ke.explain(explaining_method, model, x, y) 114 | 115 | self.assertIsNotNone(logits) 116 | self.assertEqual(logits.shape, (self.BATCH, 1)) 117 | 118 | self.assertIsNotNone(maps) 119 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 120 | 121 | def test_explain_smoothgrad(self): 122 | model = self._build_model(run_eagerly=True) 123 | 124 | x, y = ( 125 | np.random.rand(self.BATCH, *self.SHAPE), 126 | np.random.randint(10, size=(self.BATCH, 1)) 127 | ) 128 | 129 | explaining_method = ke.methods.meta.smooth( 130 | ke.methods.gradient.gradients, 131 | repetitions=3, 132 | noise=0.1, 133 | ) 134 | logits, maps = ke.explain(explaining_method, model, x, y) 135 | 136 | self.assertIsNotNone(logits) 137 | self.assertEqual(logits.shape, (self.BATCH, 1)) 138 | 139 | self.assertIsNotNone(maps) 140 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 141 | 142 | def test_explain_sanity_fullgradients(self): 143 | model = self._build_model() 144 | logits = ke.inspection.get_logits_layer(model) 145 | inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits]) 146 | 147 | model = ke.inspection.expose(model, inters, logits) 148 | 149 | x, y = ( 150 | np.random.rand(self.BATCH, *self.SHAPE), 151 | np.random.randint(10, size=(self.BATCH, 1)) 152 | ) 153 | 154 | logits, maps = ke.explain( 155 | ke.methods.gradient.full_gradients, 156 | model, 157 | x, 158 | y, 159 | biases=biases, 160 | ) 161 | 162 | self.assertIsNotNone(logits) 163 | self.assertEqual(logits.shape, (self.BATCH, 1)) 164 | 165 | self.assertIsNotNone(maps) 166 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 167 | -------------------------------------------------------------------------------- /tests/unit/inspection_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | import keras_explainable as ke 5 | 6 | 7 | class InspectionTest(tf.test.TestCase): 8 | BATCH = 2 9 | SHAPE = [64, 64, 3] 10 | RUN_EAGERLY = False 11 | 12 | def _compile(self, model, run_eagerly=False, jit_compile=False): 13 | model.compile( 14 | optimizer='sgd', 15 | loss='sparse_categorical_crossentropy', 16 | metrics=['accuracy'], 17 | run_eagerly=run_eagerly, 18 | jit_compile=jit_compile, 19 | ) 20 | 21 | 
return model 22 | 23 | def _assert_valid_results(self, model): 24 | x, y = ( 25 | np.random.rand(self.BATCH, *self.SHAPE), 26 | np.random.randint(10, size=(self.BATCH, 1)) 27 | ) 28 | 29 | logits, maps = ke.gradcam(model, x, y) 30 | 31 | self.assertIsNotNone(logits) 32 | self.assertEqual(logits.shape, (self.BATCH, 1)) 33 | 34 | self.assertIsNotNone(maps) 35 | self.assertEqual(maps.shape, (self.BATCH, *self.SHAPE[:2], 1)) 36 | 37 | self.assertGreater(maps.max(), 0.) 38 | 39 | def test_functional_backbone(self): 40 | model = tf.keras.applications.ResNet50V2( 41 | input_shape=[64, 64, 3], 42 | classifier_activation=None, 43 | weights=None, 44 | classes=10, 45 | ) 46 | self._compile(model) 47 | 48 | exposed = ke.inspection.expose(model) 49 | self._assert_valid_results(exposed) 50 | 51 | def test_sequential_nested_backbone(self): 52 | rn50 = tf.keras.applications.ResNet50V2( 53 | input_shape=[None, None, 3], 54 | include_top=False, 55 | weights=None, 56 | ) 57 | model = tf.keras.Sequential([ 58 | tf.keras.Input([None, None, 3]), 59 | rn50, 60 | tf.keras.layers.GlobalAveragePooling2D(name="avg_pool"), 61 | tf.keras.layers.Dense(10, name="logits"), 62 | tf.keras.layers.Activation("softmax", name="predictions"), 63 | ]) 64 | self._compile(model) 65 | 66 | exposed = ke.inspection.expose(model) 67 | self._assert_valid_results(exposed) 68 | 69 | def test_sequential_nested_backbone_with_pooling(self): 70 | self.skipTest( 71 | "Nodes are not being properly appended when a model is nested. " 72 | "Skipping this test while this is not fixed. " 73 | "See #34977 and #16123 for more information." 74 | ) 75 | 76 | rn50 = tf.keras.applications.ResNet50V2( 77 | input_shape=[None, None, 3], 78 | include_top=False, 79 | weights=None, 80 | pooling="avg", 81 | ) 82 | model = tf.keras.Sequential([ 83 | tf.keras.Input([None, None, 3]), 84 | rn50, 85 | tf.keras.layers.Dense(10, name="logits"), 86 | tf.keras.layers.Activation("softmax", name="predictions"), 87 | ]) 88 | self._compile(model) 89 | 90 | exposed = ke.inspection.expose( 91 | model, 92 | {"name": ("resnet50v2", "avg_pool"), "link": "input", "node": 1} 93 | ) 94 | self._assert_valid_results(exposed) 95 | 96 | def test_functional_nested_backbone(self): 97 | rn50 = tf.keras.applications.ResNet50V2( 98 | input_shape=[None, None, 3], 99 | include_top=False, 100 | weights=None, 101 | ) 102 | x = tf.keras.Input([None, None, 3]) 103 | y = rn50(x) 104 | y = tf.keras.layers.GlobalAveragePooling2D(name="avg_pool")(y) 105 | y = tf.keras.layers.Dense(10, name="logits")(y) 106 | y = tf.keras.layers.Activation("softmax", name="predictions")(y) 107 | model = tf.keras.Model(x, y) 108 | self._compile(model) 109 | 110 | exposed = ke.inspection.expose(model) 111 | self._assert_valid_results(exposed) 112 | 113 | def test_functional_nested_backbone_with_pooling(self): 114 | rn50 = tf.keras.applications.ResNet50V2( 115 | input_shape=[None, None, 3], 116 | include_top=False, 117 | weights=None, 118 | pooling="avg", 119 | ) 120 | x = tf.keras.Input([None, None, 3]) 121 | y = rn50(x) 122 | y = tf.keras.layers.Dense(10, name="logits")(y) 123 | y = tf.keras.layers.Activation("softmax", name="predictions")(y) 124 | model = tf.keras.Model(x, y) 125 | self._compile(model) 126 | 127 | exposed = ke.inspection.expose(model) 128 | self._assert_valid_results(exposed) 129 | 130 | def test_functional_flatten_backbone(self): 131 | rn50 = tf.keras.applications.ResNet50V2( 132 | input_shape=[None, None, 3], 133 | include_top=False, 134 | weights=None, 135 | ) 136 | y = rn50.output 137 | y = 
tf.keras.layers.GlobalAveragePooling2D(name="avg_pool")(y) 138 | y = tf.keras.layers.Dense(10, name="logits")(y) 139 | y = tf.keras.layers.Activation("softmax", name="predictions")(y) 140 | model = tf.keras.Model(rn50.input, y) 141 | self._compile(model) 142 | 143 | exposed = ke.inspection.expose(model) 144 | self._assert_valid_results(exposed) 145 | 146 | def test_functional_flatten_backbone_with_pooling(self): 147 | rn50 = tf.keras.applications.ResNet50V2( 148 | input_shape=[None, None, 3], 149 | include_top=False, 150 | weights=None, 151 | pooling="avg", 152 | ) 153 | y = rn50.output 154 | y = tf.keras.layers.Dense(10, name="logits")(y) 155 | y = tf.keras.layers.Activation("softmax", name="predictions")(y) 156 | model = tf.keras.Model(rn50.input, y) 157 | self._compile(model) 158 | 159 | exposed = ke.inspection.expose(model) 160 | self._assert_valid_results(exposed) 161 | -------------------------------------------------------------------------------- /docs/explaining.rst: -------------------------------------------------------------------------------- 1 | ============================== 2 | Explaining Model's Predictions 3 | ============================== 4 | 5 | This library has the function :func:`~keras_explainable.explain` as its core 6 | component, which is used to execute any AI explaining method and technique. 7 | Think of it as the :meth:`keras.Model.fit` or :meth:`keras.Model.predict` 8 | loops of Keras' models, in which the execution graph of the operations 9 | contained in a model is compiled (conditioned to :attr:`Model.run_eagerly` 10 | and :attr:`Model.jit_compile`) and the explaining maps are computed 11 | according to the method's strategy. 12 | 13 | Just like in :meth:`keras.Model.predict`, :func:`~keras_explainable.explain` 14 | allows various types of input data and retrieves the Model's associated 15 | distribution strategy in order to distribute the workload across multiple 16 | GPUs and/or workers. 17 | 18 | 19 | .. jupyter-execute:: 20 | :hide-code: 21 | :hide-output: 22 | 23 | import os 24 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 25 | 26 | import numpy as np 27 | import pandas as pd 28 | import tensorflow as tf 29 | from keras.utils import load_img, img_to_array 30 | 31 | import keras_explainable as ke 32 | 33 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 34 | SAMPLES = 8 35 | SIZES = (299, 299) 36 | 37 | file_names = os.listdir(SOURCE_DIRECTORY) 38 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 39 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 40 | images = images.astype("uint8")[:SAMPLES] 41 | 42 | We demonstrate below how predictions can be explained using the 43 | Xception network trained over ImageNet, using a few image samples. 44 | Firstly, we load the network: 45 | 46 | .. jupyter-execute:: 47 | 48 | model = tf.keras.applications.Xception( 49 | classifier_activation=None, 50 | weights='imagenet', 51 | ) 52 | 53 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 54 | 55 | We can feed-forward the samples once and get the predicted classes for each sample. 56 | Besides making sure the model is outputting the expected classes, this step is 57 | required in order to determine the most activating units in the *logits* layer, 58 | which improves performance of the explaining methods. 59 | 60 | ..
jupyter-execute:: 61 | 62 | from tensorflow.keras.applications.imagenet_utils import preprocess_input, decode_predictions 63 | 64 | inputs = preprocess_input(images.astype("float").copy(), mode="tf") 65 | logits = model.predict(inputs, verbose=0) 66 | 67 | indices = np.argsort(logits, axis=-1)[:, ::-1] 68 | probs = tf.nn.softmax(logits).numpy() 69 | predictions = decode_predictions(probs, top=1) 70 | 71 | ke.utils.visualize( 72 | images=images, 73 | titles=[ 74 | ", ".join(f"{klass} {prob:.0%}" for code, klass, prob in p) 75 | for p in predictions 76 | ] 77 | ) 78 | 79 | Finally, we can simply run all available explaining methods: 80 | 81 | .. jupyter-execute:: 82 | :hide-output: 83 | 84 | explaining_units = indices[:, :1] # First most likely class. 85 | 86 | # Gradient Back-propagation 87 | _, g_maps = ke.gradients(model, inputs, explaining_units) 88 | 89 | # Full-Gradient 90 | logits = ke.inspection.get_logits_layer(model) 91 | inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits]) 92 | model_exp = ke.inspection.expose(model, inters, logits) 93 | _, fg_maps = ke.full_gradients(model_exp, inputs, explaining_units, biases=biases) 94 | 95 | # CAM-Based 96 | model_exp = ke.inspection.expose(model) 97 | _, c_maps = ke.cam(model_exp, inputs, explaining_units) 98 | _, gc_maps = ke.gradcam(model_exp, inputs, explaining_units) 99 | _, gcpp_maps = ke.gradcampp(model_exp, inputs, explaining_units) 100 | _, sc_maps = ke.scorecam(model_exp, inputs, explaining_units) 101 | 102 | .. jupyter-execute:: 103 | :hide-code: 104 | 105 | all_maps = (g_maps, fg_maps, c_maps, gc_maps, gcpp_maps, sc_maps) 106 | 107 | _images = images.repeat(1 + len(all_maps), axis=0) 108 | _titles = 'original Gradients Full-Grad CAM Grad-CAM Grad-CAM++ Score-CAM'.split() 109 | _overlays = sum(zip([None] * len(images), *all_maps), ()) 110 | 111 | ke.utils.visualize(_images, _titles, _overlays, cols=1 + len(all_maps)) 112 | 113 | The functions above are simply shortcuts for 114 | :func:`~keras_explainable.engine.explaining.explain`, using their conventional 115 | hyper-parameters and post-processing functions. 116 | For more flexibility, you can use the regular form: 117 | 118 | .. code-block:: python 119 | 120 | logits, cams = ke.explain( 121 | ke.methods.cams.gradcam, 122 | model_exp, 123 | inputs, 124 | explaining_units, 125 | batch_size=32, 126 | postprocessing=ke.filters.positive_normalize, 127 | ) 128 | 129 | While the :func:`~keras_explainable.engine.explaining.explain` function is a convenient 130 | wrapper, transparently distributing the workload based on the distribution strategy 131 | associated with the model, it is not a necessary component in the overall functioning 132 | of the library. Alternatively, one can call any explaining method directly: 133 | 134 | ..
code-block:: python 135 | 136 | gradcam = ke.methods.cams.gradcam 137 | # Uncomment the following to compile the explaining pass: 138 | # gradcam = tf.function(ke.methods.cams.gradcam, reduce_retracing=True, jit_compile=True) 139 | 140 | logits, cams = gradcam(model, inputs, explaining_units) 141 | 142 | cams = ke.filters.positive_normalize(cams) 143 | cams = tf.image.resize(cams, (299, 299)).numpy() 144 | -------------------------------------------------------------------------------- /docs/wsol.rst: -------------------------------------------------------------------------------- 1 | =============================================================== 2 | Weakly Supervised Object Localization and Semantic Segmentation 3 | =============================================================== 4 | 5 | Object localization and segmentation cues can be extracted from models 6 | trained over multi-label datasets in a weakly supervised setup. 7 | 8 | An example of this technique is OC-CSE, which was first described in 9 | the paper "Unlocking the potential of ordinary classifier: Class-specific 10 | adversarial erasing framework for weakly supervised semantic segmentation.", 11 | by Kweon et al. (2021) [`link `_]. 12 | Its original code (written in PyTorch) is available at 13 | `KAIST-vilab/OC-CSE `_, but 14 | we will actually load its TensorFlow alternative, available at 15 | `lucasdavid/resnet38d-tf `_: 16 | 17 | .. jupyter-execute:: 18 | :hide-code: 19 | :hide-output: 20 | 21 | import os 22 | import numpy as np 23 | import pandas as pd 24 | import tensorflow as tf 25 | from keras.utils import load_img, img_to_array 26 | 27 | import keras_explainable as ke 28 | 29 | SOURCE_DIRECTORY = 'docs/_static/images/voc12/' 30 | SAMPLES = 8 31 | SIZES = (384, 384) 32 | 33 | file_names = sorted(os.listdir(SOURCE_DIRECTORY)) 34 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 35 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 36 | images = images.astype("uint8")[:SAMPLES] 37 | label_indices = [[8, 11], [2], [1, 14], [4, 14], [16], [2], [0, 14], [13, 14]] 38 | labels = np.zeros((len(label_indices), 20)) 39 | for i, l in enumerate(label_indices): 40 | labels[i, l] = 1. 41 | 42 | def pascal_voc_classes(): 43 | return np.asarray(( 44 | "aeroplane bicycle bird boat bottle bus car cat chair cow diningtable " 45 | "dog horse motorbike person pottedplant sheep sofa train tvmonitor" 46 | ).split()) 47 | 48 | def pascal_voc_colors(): 49 | return np.asarray([ 50 | [0, 0, 0], # background 51 | [128, 0, 0], 52 | [0, 128, 0], 53 | [128, 128, 0], 54 | [0, 0, 128], 55 | [128, 0, 128], 56 | [0, 128, 128], 57 | [128, 128, 128], 58 | [64, 0, 0], 59 | [192, 0, 0], 60 | [64, 128, 0], 61 | [192, 128, 0], 62 | [64, 0, 128], 63 | [192, 0, 128], 64 | [64, 128, 128], 65 | [192, 128, 128], 66 | [0, 64, 0], 67 | [128, 64, 0], 68 | [0, 192, 0], 69 | [128, 192, 0], 70 | [0, 64, 128], 71 | [224, 224, 192] # void (contours, outline and padded regions) 72 | ]) / 255. 73 | 74 | 75 | .. jupyter-execute:: 76 | 77 | COLORS = pascal_voc_colors() 78 | CLASSES = pascal_voc_classes() 79 | WEIGHTS = 'docs/_build/data/resnet38d_voc2012_occse.h5' 80 | 81 | ! mkdir -p docs/_build/data 82 | ! wget -q -nc https://raw.githubusercontent.com/lucasdavid/resnet38d-tf/main/resnet38d.py 83 | ! 
wget -q -nc https://github.com/lucasdavid/resnet38d-tf/releases/download/0.0.1/resnet38d_voc2012_occse.h5 -P docs/_build/data/ 84 | 85 | from resnet38d import ResNet38d 86 | 87 | input_tensor = tf.keras.Input(shape=(None, None, 3), name="inputs") 88 | rn38d = ResNet38d(input_tensor=input_tensor, weights=WEIGHTS) 89 | 90 | print(f"ResNet38-d with {WEIGHTS} pre-trained weights loaded.") 91 | print(f"Spatial map sizes: {rn38d.get_layer('s5/ac').input.shape}") 92 | 93 | ! rm resnet38d.py 94 | 95 | We can feed-forward the samples once and get the predicted classes for each sample. 96 | Besides making sure the model is outputting the expected classes, this step is 97 | required in order to determine the most activating units in the *logits* layer, 98 | which improves performance of the explaining methods. 99 | 100 | .. jupyter-execute:: 101 | 102 | prec = tf.keras.applications.imagenet_utils.preprocess_input 103 | 104 | inputs = prec(images.astype("float").copy(), mode='torch') 105 | probs = rn38d.predict(inputs, verbose=0) 106 | 107 | Finally, we can run CAM and its test-time augmented (TTA) variant: 108 | 109 | .. jupyter-execute:: 110 | 111 | rn38d = ke.inspection.expose(rn38d, "s5/ac", 'avg_pool') 112 | 113 | # Vanilla CAM 114 | _, cams = ke.cam(rn38d, inputs, batch_size=4) 115 | 116 | # TTA-CAM 117 | tta_cam_method = ke.methods.meta.tta( 118 | ke.methods.cams.cam, 119 | scales=[0.5, 1.0, 1.5, 2.], 120 | hflip=True, 121 | ) 122 | _, tta_cams = ke.explain( 123 | tta_cam_method, 124 | rn38d, 125 | inputs, 126 | batch_size=4, 127 | postprocessing=ke.filters.positive_normalize, 128 | ) 129 | 130 | Explaining maps can be converted into color maps, 131 | respecting the conventional Pascal color mapping: 132 | 133 | .. jupyter-execute:: 134 | 135 | def cams_to_colors(labels, maps, colors): 136 | overlays = [] 137 | labels = labels.astype(bool) 138 | 139 | for i in range(8): 140 | l = labels[i] 141 | c = colors[l] 142 | m = maps[i][..., l] 143 | o = np.einsum('dc,hwd->hwc', c, m).clip(0, 1) 144 | overlays.append(o) 145 | 146 | return overlays 147 | 148 | cam_overlays = cams_to_colors(labels, cams, COLORS[1:21]) 149 | tta_overlays = cams_to_colors(labels, tta_cams, COLORS[1:21]) 150 | 151 | ke.utils.visualize([*images, *cam_overlays, *tta_overlays]) 152 | -------------------------------------------------------------------------------- /docs/methods/saliency/gradients.rst: -------------------------------------------------------------------------------- 1 | ================== 2 | Gradient Back-prop 3 | ================== 4 | 5 | In this page, we describe how to obtain *saliency maps* from a trained 6 | Convolutional Neural Network (CNN) with respect to an input signal (an image, 7 | in this case) using the Gradient backprop AI explaining method. 8 | Said maps can be used to explain the model's predictions, determining regions 9 | which most contributed to its effective output. 10 | 11 | Gradient Back-propagation (or Gradient Backprop, for short) is an early 12 | form of visualizing and explaining the salient and contributing features 13 | considered in the decision process of a neural network, first 14 | described in the following article: 15 | 16 | Simonyan, K., Vedaldi, A., & Zisserman, A. (2013). 17 | Deep inside convolutional networks: Visualising image classification 18 | models and saliency maps. arXiv preprint arXiv:1312.6034. 19 | Available at: `arxiv/1312.6034 `_. 20 | 21 | Briefly, this can be achieved with the following template snippet: 22 | 23 | ..
code-block:: python 24 | 25 | import keras_explainable as ke 26 | 27 | model = build_model(...) 28 | model.layers[-1].activation = 'linear' # Usually softmax or sigmoid. 29 | 30 | logits, maps = ke.gradients(model, x, y, batch_size=32) 31 | 32 | We detail each of the necessary steps below. Firstly, we employ the 33 | :class:`Xception` network pre-trained over the ImageNet dataset: 34 | 35 | .. jupyter-execute:: 36 | :hide-code: 37 | :hide-output: 38 | 39 | import os 40 | import numpy as np 41 | import pandas as pd 42 | import tensorflow as tf 43 | from keras.utils import load_img, img_to_array 44 | 45 | import keras_explainable as ke 46 | 47 | SOURCE_DIRECTORY = 'docs/_static/images/singleton/' 48 | SAMPLES = 8 49 | SIZES = (299, 299) 50 | 51 | file_names = os.listdir(SOURCE_DIRECTORY) 52 | image_paths = [os.path.join(SOURCE_DIRECTORY, f) for f in file_names if f != '_links.txt'] 53 | images = np.stack([img_to_array(load_img(ip).resize(SIZES)) for ip in image_paths]) 54 | images = images.astype("uint8")[:SAMPLES] 55 | 56 | .. jupyter-execute:: 57 | 58 | model = tf.keras.applications.Xception( 59 | classifier_activation=None, 60 | weights='imagenet', 61 | ) 62 | 63 | print("Xception pretrained over ImageNet was loaded.") 64 | print(f"Spatial map sizes: {model.get_layer('avg_pool').input.shape}") 65 | 66 | We can feed-forward the samples once and get the predicted classes for each sample. 67 | Besides making sure the model is outputting the expected classes, this step is 68 | required in order to determine the most activating units in the *logits* layer, 69 | which improves performance of the explaining methods. 70 | 71 | .. jupyter-execute:: 72 | 73 | from tensorflow.keras.applications.imagenet_utils import preprocess_input 74 | 75 | inputs = preprocess_input(images.astype("float").copy(), mode="tf") 76 | logits = model.predict(inputs, verbose=0) 77 | indices = np.argsort(logits, axis=-1)[:, ::-1] 78 | explaining_units = indices[:, :1] # First most likely class. 79 | 80 | Gradient Backprop can be obtained by computing the differential of a function 81 | (usually expressing the logit score for a given class) with respect to pixels 82 | contained in the input signal (usually expressing an image): 83 | 84 | .. jupyter-execute:: 85 | 86 | logits, maps = ke.gradients(model, inputs, explaining_units) 87 | 88 | ke.utils.visualize([*images, *maps]) 89 | 90 | .. note:: 91 | 92 | If the parameter ``indices`` in ``gradients`` is not set, an 93 | explanation for each unit in the explaining layer will be provided, 94 | possibly resulting in *OOM* errors for models containing many units. 95 | 96 | To increase efficiency, we sub-select only the top :math:`K` scoring 97 | classification units to explain. The jacobian will only be computed 98 | for these :math:`NK` outputs. 99 | 100 | Under the hood, :func:`keras_explainable.gradients` is simply 101 | executing the following call to the 102 | :func:`explain` function: 103 | 104 | .. code-block:: python 105 | 106 | logits, maps = ke.explain( 107 | methods.gradient.gradients, 108 | model, 109 | inputs, 110 | explaining_units, 111 | postprocessing=filters.absolute_normalize, 112 | ) 113 | 114 | Following the Gradient Backprop paper, we consider the positive and 115 | negative contributing regions in the creation of the saliency maps 116 | by computing their individual absolute contributions before 117 | normalizing them. Different strategies can be employed by 118 | changing the ``postprocessing`` parameter. 119 | 120 | ..
note:: 121 | 122 | For more information on the :func:`~keras_explainable.explain` function, 123 | check its documentation or its own examples page. 124 | 125 | Of course, we can obtain the same result by directly calling the 126 | :func:`~keras_explainable.methods.gradient.gradients` function 127 | (though it will not leverage the model's inner distributed strategy 128 | and data optimizations implemented in :func:`~keras_explainable.explain`): 129 | 130 | .. jupyter-execute:: 131 | 132 | gradients = tf.function( 133 | ke.methods.gradient.gradients, jit_compile=True, reduce_retracing=True 134 | ) 135 | _, direct_maps = gradients(model, inputs, explaining_units) 136 | 137 | direct_maps = ke.filters.absolute_normalize(direct_maps) 138 | direct_maps = tf.image.resize(direct_maps, (299, 299)) 139 | direct_maps = direct_maps.numpy() 140 | 141 | np.testing.assert_array_almost_equal(maps, direct_maps) 142 | print('Maps computed with `explain` and `methods.gradient.gradients` are the same!') 143 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Keras Explainable 3 | ================= 4 | 5 | .. image:: https://github.com/lucasdavid/keras-explainable/actions/workflows/ci.yml/badge.svg?branch=release 6 | :alt: GitHub Actions build status 7 | :target: https://github.com/lucasdavid/keras-explainable/actions/workflows/ci.yml 8 | 9 | .. image:: https://img.shields.io/badge/docs-0.0.2-blue 10 | :alt: Documentation status 11 | :target: https://lucasdavid.github.io/keras-explainable 12 | 13 | Efficient AI explaining algorithms for Keras models. 14 | 15 | .. image:: _static/images/cover.jpg 16 | :alt: Examples of explaining methods employed to explain outputs from various example images. 17 | 18 | Installation 19 | ------------ 20 | 21 | .. code-block:: shell 22 | 23 | pip install tensorflow 24 | pip install git+https://github.com/lucasdavid/keras-explainable.git 25 | 26 | Usage 27 | ----- 28 | 29 | This example illustrates how to explain predictions of a Convolutional Neural 30 | Network (CNN) using Grad-CAM. This can be easily achieved with the following 31 | example: 32 | 33 | .. code-block:: python 34 | 35 | import keras_explainable as ke 36 | 37 | model = tf.keras.applications.ResNet50V2(...) 38 | model = ke.inspection.expose(model) 39 | 40 | scores, cams = ke.gradcam(model, x, y, batch_size=32) 41 | 42 | Implemented Explaining Methods 43 | ------------------------------ 44 | 45 | .. table:: 46 | :widths: auto 47 | :align: left 48 | 49 | =========================== ========= ====================================================== ================== 50 | Method Kind Description Reference 51 | =========================== ========= ====================================================== ================== 52 | Gradient Back-propagation gradient Computes the gradient of the output activation unit `docs `_ 53 | being explained with respect to each unit in the input `paper `_ 54 | signal. 55 | Full-Gradient gradient Adds the individual contributions of each bias factor `docs `_ 56 | in the model to the extracted gradient, forming the `paper `_ 57 | "full gradient" representation. 58 | CAM CAM Creates class-specific maps by linearly combining the `docs `_ 59 | activation maps from the last convolutional `paper `_ 60 | layer, scaled by their contributions to the unit of 61 | interest.
62 | Grad-CAM CAM Linear combination of activation maps, weighted by `docs `_ 63 | the gradient of the output unit with respect to the `paper `_ 64 | maps themselves. 65 | Grad-CAM++ CAM Weights pixels in the activation maps in order to `docs `_ 66 | counterbalance the gradient contributions, resulting `paper `_ 67 | in similar activation intensity over multiple instances of objects. 68 | Score-CAM CAM Uses each activation map to mask the input signal, `docs `_ 69 | which is then feed-forwarded once more. `paper `_ 70 | The activation intensity is computed for each 71 | masked input, and the maps are combined, 72 | weighted by their relative activation 73 | retention. 74 | SmoothGrad Meta Consecutive applications of an AI explaining method, `docs `_ 75 | adding Gaussian noise to the input signal each time. `paper `_ 76 | TTA Meta Consecutive applications of an AI explaining method, `docs `_ 77 | applying augmentation to the input signal each time. `paper `_ 78 | =========================== ========= ====================================================== ================== 79 | -------------------------------------------------------------------------------- /src/keras_explainable/methods/meta.py: -------------------------------------------------------------------------------- 1 | """Implementation of various Meta techniques. 2 | 3 | These can be used conjointly with CAM and Gradient-based methods, 4 | providing cleaner and more robust results. 5 | """ 6 | 7 | from functools import partial 8 | from typing import Callable 9 | from typing import List 10 | from typing import Tuple 11 | 12 | import tensorflow as tf 13 | 14 | from keras_explainable.inspection import SPATIAL_AXIS 15 | 16 | 17 | def smooth( 18 | method: Callable, 19 | repetitions: int = 20, 20 | noise: float = 0.1, 21 | ) -> Callable: 22 | """Smooth Meta Explaining Method. 23 | 24 | This technique consists of repeatedly applying an AI explaining method, considering 25 | small variations of the input signal each time (perturbed with Gaussian noise). 26 | 27 | Usage: 28 | 29 | .. code-block:: python 30 | 31 | x = np.random.normal(size=(1, 224, 224, 3)) 32 | y = np.asarray([[16, 32]]) 33 | 34 | model = tf.keras.applications.ResNet50V2(classifier_activation=None) 35 | 36 | smoothgrad = ke.methods.meta.smooth( 37 | ke.methods.gradient.gradients, 38 | repetitions=20, 39 | noise=0.1, 40 | ) 41 | 42 | scores, maps = smoothgrad(model, x, y) 43 | 44 | References: 45 | 46 | - Smilkov, D., Thorat, N., Kim, B., Viégas, F., & Wattenberg, M. (2017). 47 | SmoothGrad: removing noise by adding noise. arXiv preprint arXiv:1706.03825. 48 | Available at: [arxiv/1706.03825](https://arxiv.org/abs/1706.03825) 49 | 50 | Args: 51 | method (Callable): the explaining method to be smoothed 52 | repetitions (int, optional): number of repetitions. Defaults to 20. 53 | noise (float, optional): standard deviation of the Gaussian noise 54 | added to the input signal. Defaults to 0.1. 55 | 56 | Returns: 57 | Callable: a function with the same signature as ``method``, returning the averaged logits and explaining maps.
58 | 59 | """ 60 | def apply( 61 | model: tf.keras.Model, 62 | inputs: tf.Tensor, 63 | *args, 64 | **params, 65 | ): 66 | logits, maps = method(model, inputs, *args, **params) 67 | shape = tf.concat(([repetitions - 1], tf.shape(inputs)), axis=0) 68 | 69 | noisy_inputs = inputs + tf.random.normal(shape, 0, noise, dtype=inputs.dtype) 70 | 71 | with tf.control_dependencies([logits, maps]): 72 | for step in tf.range(repetitions - 1): 73 | batch_inputs = noisy_inputs[step] 74 | batch_logits, batch_maps = method(model, batch_inputs, *args, **params) 75 | 76 | logits += batch_logits 77 | maps += batch_maps 78 | 79 | return ( 80 | logits / repetitions, 81 | maps / repetitions, 82 | ) 83 | 84 | apply.__name__ = f"{method.__name__}_smooth" 85 | return apply 86 | 87 | 88 | def tta( 89 | method: Callable, 90 | scales: List[float] = [0.5, 1.5, 2.0], 91 | hflip: bool = True, 92 | resize_method: str = "bilinear", 93 | ) -> Callable: 94 | """Computes the TTA version of a visualization method. 95 | 96 | Usage: 97 | 98 | .. code-block:: python 99 | 100 | x = np.random.normal(size=(1, 224, 224, 3)) 101 | y = np.asarray([[16, 32]]) 102 | 103 | model = tf.keras.applications.ResNet50V2(classifier_activation=None) 104 | 105 | tta_gradients = ke.methods.meta.tta( 106 | ke.methods.gradient.gradients, 107 | scales=[0.5, 1.5, 2.0], 108 | hflip=True, 109 | ) 110 | 111 | scores, maps = tta_gradients(model, x, y) 112 | 113 | Args: 114 | method (Callable): the explaining method to be augmented 115 | scales (List[float], optional): a list of coefficients to scale the inputs by. 116 | Defaults to [0.5, 1.5, 2.0]. 117 | hflip (bool, optional): whether or not to horizontally flip the inputs. 118 | Defaults to True. 119 | resize_method (str, optional): the resizing method used. Defaults to "bilinear". 120 | 121 | Returns: 122 | Callable: a function with the same signature as ``method``, returning the averaged logits and explaining maps.
123 | """ 124 | scales = tf.convert_to_tensor(scales, dtype=tf.float32) 125 | 126 | def apply( 127 | model: tf.keras.Model, 128 | inputs: tf.Tensor, 129 | spatial_axis: Tuple[int] = SPATIAL_AXIS, 130 | **params, 131 | ): 132 | method_ = partial(method, spatial_axis=spatial_axis, **params) 133 | 134 | shapes = tf.shape(inputs) 135 | sizes = shapes[1:-1] 136 | 137 | logits, maps = _forward(method_, model, inputs, sizes, None, False, resize_method) 138 | 139 | if hflip: 140 | with tf.control_dependencies([logits, maps]): 141 | logits_r, maps_r = _forward( 142 | method_, model, inputs, sizes, None, True, resize_method 143 | ) 144 | logits += logits_r 145 | maps += maps_r 146 | 147 | for idx in tf.range(scales.shape[0]): 148 | scale = scales[idx] 149 | logits_r, maps_r = _forward( 150 | method_, model, inputs, sizes, scale, False, resize_method 151 | ) 152 | logits += logits_r 153 | maps += maps_r 154 | 155 | if hflip: 156 | logits_r, maps_r = _forward( 157 | method_, model, inputs, sizes, scale, True, resize_method 158 | ) 159 | logits += logits_r 160 | maps += maps_r 161 | 162 | repetitions = scales.shape[0] + 1  # scaled passes, plus the identity pass. 163 | if hflip: 164 | repetitions *= 2 165 | 166 | logits /= repetitions 167 | maps /= repetitions 168 | 169 | return logits, maps 170 | 171 | def _forward(method, model, inputs, sizes, scale, hflip, resize_method): 172 | if hflip: 173 | inputs = tf.image.flip_left_right(inputs) 174 | 175 | if scale is not None: 176 | resizes = tf.cast(sizes, tf.float32) 177 | resizes = tf.cast(scale * resizes, tf.int32) 178 | inputs = tf.image.resize(inputs, resizes, method=resize_method) 179 | 180 | logits, maps = method(model, inputs) 181 | 182 | if hflip: 183 | maps = tf.image.flip_left_right(maps) 184 | 185 | maps = tf.image.resize(maps, sizes, method=resize_method) 186 | 187 | return logits, maps 188 | 189 | apply.__name__ = f"{method.__name__}_tta" 190 | return apply 191 | 192 | 193 | __all__ = [ 194 | "smooth", 195 | "tta", 196 | ] 197 | -------------------------------------------------------------------------------- /docs/exposure.rst: -------------------------------------------------------------------------------- 1 | ============================= 2 | Exposing Intermediate Signals 3 | ============================= 4 | 5 | This page details the exposure procedure, necessary for most AI explaining 6 | methods, and which can be eased with the help of the 7 | :func:`~keras_explainable.inspection.expose` function. 8 | 9 | Simple Exposition Examples 10 | -------------------------- 11 | 12 | Many explaining techniques require us to expose the intermediate tensors 13 | so their respective signals can be used, or so the gradient of the output 14 | can be computed with respect to their signals. 15 | For example, Grad-CAM computes the gradient of an output unit with respect 16 | to the activation signal coming from the last positional layer in the model: 17 | 18 | .. code-block:: python 19 | 20 | with tf.GradientTape() as tape: 21 | logits, activations = model(x) 22 | 23 | gradients = tape.batch_jacobian(logits, activations) 24 | 25 | This evidently means that the ``activations`` signal, a tensor of 26 | shape ``(batch, height, width, ..., kernels)``, must be available at runtime. 27 | For that to happen, we must redefine the model, setting its outputs 28 | to contain the :class:`KerasTensor`'s objects that reference both 29 | ``logits`` and ``activations`` tensors: 30 | 31 | ..
jupyter-execute:: 32 | 33 | import numpy as np 34 | import tensorflow as tf 35 | from keras import Input, Model, Sequential 36 | from keras.applications import ResNet50V2 37 | from keras.layers import Activation, Dense, GlobalAveragePooling2D 38 | 39 | import keras_explainable as ke 40 | 41 | rn50 = ResNet50V2(weights=None, classifier_activation=None) 42 | # activations_tensor = rn50.get_layer("avg_pool").input # or... 43 | activations_tensor = rn50.get_layer("post_relu").output 44 | 45 | model = Model(rn50.input, [rn50.output, activations_tensor]) 46 | 47 | print(model.name) 48 | print(f" input: {model.input}") 49 | print(" outputs:") 50 | for o in model.outputs: 51 | print(f" {o}") 52 | 53 | Which can be simplified with: 54 | 55 | .. code-block:: python 56 | 57 | model = ke.inspection.expose(rn50) 58 | 59 | The :func:`~keras_explainable.inspection.expose` function inspects the model, 60 | searching for the *logits* layer (the last layer containing a kernel property) and the 61 | *global pooling* layer, an instance of the :class:`GlobalPooling` or 62 | :class:`Flatten` layer classes. The output of the former and the input of the 63 | latter are collected and a new model is defined. 64 | 65 | You can also manually indicate the names of the argument and output layers. 66 | All options below are equivalent: 67 | 68 | .. code-block:: python 69 | 70 | model = ke.inspection.expose(rn50, "post_relu", "predictions") 71 | model = ke.inspection.expose( 72 | rn50, 73 | {"name": "post_relu", "link": "output"}, 74 | {"name": "predictions"}, 75 | ) 76 | model = ke.inspection.expose( 77 | rn50, 78 | {"name": "post_relu", "link": "output", "node": 0}, 79 | {"name": "predictions", "link": "output", "node": 0}, 80 | ) 81 | model = ke.inspection.expose( 82 | rn50, 83 | {"name": "avg_pool", "link": "input"}, 84 | "predictions", 85 | ) 86 | 87 | Grad-CAM (or Grad-CAM++) can be called immediately after that: 88 | 89 | .. jupyter-execute:: 90 | 91 | inputs = np.random.normal(size=(4, 224, 224, 3)) 92 | indices = np.asarray([[4], [9], [0], [2]]) 93 | 94 | scores, cams = ke.gradcam(model, inputs, indices) 95 | 96 | print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]") 97 | print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]") 98 | 99 | Exposing Nested Models 100 | ---------------------- 101 | 102 | Unfortunately, some model topologies can make exposure a little tricky. 103 | An example of this is when nesting multiple models, producing more than one 104 | ``Input`` object and multiple conceptual graphs at once. 105 | Then, if one naively collects ``KerasTensor``'s from the model, disconnected 106 | nodes may be retrieved, resulting in the exception ``ValueError: Graph disconnected`` 107 | being raised: 108 | 109 | ..
jupyter-execute:: 110 | :raises: ValueError 111 | 112 | rn50 = ResNet50V2(weights=None, include_top=False) 113 | 114 | x = Input([224, 224, 3], name="input_images") 115 | y = rn50(x) 116 | y = GlobalAveragePooling2D(name="avg_pool")(y) 117 | y = Dense(10, name="logits")(y) 118 | y = Activation("softmax", name="predictions", dtype="float32")(y) 119 | 120 | rn50_clf = Model(x, y, name="resnet50v2_clf") 121 | rn50_clf.summary() 122 | 123 | logits = rn50_clf.get_layer("logits").output 124 | activations = rn50_clf.get_layer("resnet50v2").output 125 | 126 | model = tf.keras.Model(rn50_clf.input, [logits, activations]) 127 | scores, cams = ke.gradcam(model, inputs, indices) 128 | 129 | print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]") 130 | print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]") 131 | 132 | The operations in ``rn50`` appear in two conceptual graphs. The first, defined 133 | when ``ResNet50V2(...)`` was invoked, contains all operations associated with the layers 134 | in the ResNet50 architecture. The second one, on the other hand, is defined when 135 | invoking :meth:`Layer.__call__` of each layer (``rn50``, ``GAP``, ``Dense`` and 136 | ``Activation``). 137 | 138 | When calling ``rn50_clf.get_layer("resnet50v2").output`` (which is equivalent 139 | to ``rn50_clf.get_layer("resnet50v2").get_output_at(0)``), the :class:`Node` 140 | from the first graph is retrieved. 141 | This ``Node`` is not associated with ``rn50_clf.input`` or ``logits``, and thus 142 | the error is raised. 143 | 144 | There are multiple ways to correctly access the Node from the second graph. One of them 145 | is to retrieve the input from the ``GAP`` layer, as it only appeared in one graph: 146 | 147 | .. jupyter-execute:: 148 | 149 | model = ke.inspection.expose( 150 | rn50_clf, {"name": "avg_pool", "link": "input"}, "predictions" 151 | ) 152 | scores, cams = ke.gradcam(model, inputs, indices) 153 | 154 | print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]") 155 | print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]") 156 | 157 | .. jupyter-execute:: 158 | :hide-code: 159 | :hide-output: 160 | 161 | del rn50, rn50_clf, model 162 | 163 | .. note:: 164 | 165 | The alternatives ``ke.inspection.expose(rn50_clf, "resnet50v2", "predictions")`` 166 | and ``ke.inspection.expose(rn50_clf)`` would work as well. 167 | In the former, the **last** output node is retrieved. 168 | In the latter, the **last** input node (there's only one) associated 169 | with the ``GAP`` layer is retrieved. 170 | 171 | Access Nested Layer Signals 172 | """"""""""""""""""""""""""" 173 | 174 | Another problem occurs when the global pooling layer is not part of the layer set 175 | of the outermost model. Even though we can still reference it using a name 176 | composition, collecting its output results in a ``ValueError: Graph disconnected``. 177 | 178 | This problem occurs because Keras does not create ``Nodes`` for inner layers in a nested 179 | model, when that model is reused. Instead, the model is treated as a single operation 180 | in the conceptual graph, with a single new ``Node`` being created to represent it. 181 | Calling :func:`keras_explainable.inspection.expose` over the model will expand the 182 | parameter ``arguments`` into ``{"name": ("ResNet50V2", "avg_pool"), "link": "input", "node": "last"}``, 183 | but because no new nodes were created for the ``GAP`` layer, the :class:`KerasTensor` 184 | associated with the first conceptual graph is retrieved, and the error ensues. 185 | 186 | ..
jupyter-execute:: 187 | :raises: ValueError 188 | 189 | rn50 = ResNet50V2(weights=None, include_top=False, pooling="avg") 190 | rn50_clf = Sequential([ 191 | Input([224, 224, 3], name="input_images"), 192 | rn50, 193 | Dense(10, name="logits"), 194 | Activation("softmax", name="predictions", dtype="float32"), 195 | ]) 196 | 197 | model = ke.inspection.expose(rn50_clf) 198 | scores, cams = ke.gradcam(model, inputs, indices) 199 | 200 | print(f"scores:{scores.shape} in [{scores.min()}, {scores.max()}]") 201 | print(f"cams:{cams.shape} in [{cams.min()}, {cams.max()}]") 202 | 203 | 204 | .. warning:: 205 | 206 | Since TensorFlow 2, nodes are no longer being stacked in ``_inbound_nodes`` 207 | for layers in nested models, which obstructs access to intermediate 208 | signals contained in a nested model, and makes the remainder of this 209 | section obsolete. 210 | To avoid this problem, it is recommended to "flatten out" the model before 211 | explaining it, or to avoid nesting models altogether. 212 | 213 | For more information, see the GitHub issue 214 | `#16123 `_. 215 | 216 | If you are using TensorFlow < 2.0, nodes are created for each operation 217 | in the inner model, and you may collect their internal signal by simply: 218 | 219 | .. code-block:: python 220 | 221 | model = ke.inspection.expose(rn50_clf) 222 | # ... or: ke.inspection.expose(rn50_clf, ("resnet50v2", "post_relu")) 223 | # ... or: ke.inspection.expose( 224 | # rn50_clf, {"name": ("resnet50v2", "avg_pool"), "link": "input"} 225 | # ) 226 | 227 | scores, cams = ke.gradcam(model, inputs, indices) 228 | 229 | .. note:: 230 | 231 | The above works because :func:`~keras_explainable.inspection.expose` 232 | will recursively search for a ``GAP`` layer within the nested models. 233 | -------------------------------------------------------------------------------- /src/keras_explainable/methods/gradient.py: -------------------------------------------------------------------------------- 1 | """Implementation of various Gradient-based AI explaining methods and techniques. 2 | """ 3 | 4 | from functools import partial 5 | from typing import Callable 6 | from typing import List 7 | from typing import Optional 8 | from typing import Tuple 9 | 10 | import tensorflow as tf 11 | 12 | from keras_explainable import filters 13 | from keras_explainable import inspection 14 | from keras_explainable.inspection import KERNEL_AXIS 15 | from keras_explainable.inspection import SPATIAL_AXIS 16 | 17 | 18 | def transpose_jacobian( 19 | x: tf.Tensor, spatial_rank: int = len(SPATIAL_AXIS) 20 | ) -> tf.Tensor: 21 | """Transpose the Jacobian of shape (b,g,...) into (b,...,g). 22 | 23 | Args: 24 | x (tf.Tensor): the jacobian tensor. 25 | spatial_rank (int, optional): the spatial rank of ``x``. 26 | Defaults to ``len(SPATIAL_AXIS)``. 27 | 28 | Returns: 29 | tf.Tensor: the transposed jacobian. 30 | """ 31 | dims = [2 + i for i in range(spatial_rank)] 32 | 33 | return tf.transpose(x, [0] + dims + [1]) 34 | 35 | 36 | def gradients( 37 | model: tf.keras.Model, 38 | inputs: tf.Tensor, 39 | indices: Optional[tf.Tensor] = None, 40 | indices_axis: int = KERNEL_AXIS, 41 | indices_batch_dims: int = -1, 42 | spatial_axis: Tuple[int] = SPATIAL_AXIS, 43 | gradient_filter: Callable = tf.abs, 44 | ) -> Tuple[tf.Tensor, tf.Tensor]: 45 | """Computes the Gradient Back-propagation Visualization Method. 46 | 47 | This technique computes the gradient of the output activation unit being explained 48 | with respect to each unit in the input signal.
49 | Features (channels) in each pixel of the input signal are averaged in absolute value, 50 | following the original implementation: 51 | 52 | .. math:: 53 | 54 | f(x) = ψ(∇_xf(x)) 55 | 56 | This method expects `inputs` to be a batch of positional signals of 57 | shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``, 58 | where ``(H', W', ...)`` are the sizes of the visual receptive field 59 | in the explained activation layer and `L` is the number of labels 60 | represented within the model's output logits. 61 | 62 | If `indices` is passed, the specific logits indexed by elements in this 63 | tensor are selected before the gradients are computed, effectively 64 | reducing the columns in the jacobian, and the size of the output explaining map. 65 | 66 | Usage: 67 | 68 | .. code-block:: python 69 | 70 | x = np.random.normal(size=(1, 224, 224, 3)) 71 | y = np.asarray([[16, 32]]) 72 | 73 | model = tf.keras.applications.ResNet50V2(classifier_activation=None) 74 | scores, maps = ke.methods.gradient.gradients(model, x, y) 75 | 76 | References: 77 | 78 | - Simonyan, K., Vedaldi, A., & Zisserman, A. (2013). 79 | Deep inside convolutional networks: Visualising image classification 80 | models and saliency maps. arXiv preprint 81 | `arXiv:1312.6034 `_. 82 | 83 | Args: 84 | model (tf.keras.Model): the model being explained 85 | inputs (tf.Tensor): the input data 86 | indices (Optional[tf.Tensor], optional): indices that should be gathered 87 | from ``outputs``. Defaults to None. 88 | indices_axis (int, optional): the axis containing the indices to gather. 89 | Defaults to ``KERNEL_AXIS``. 90 | indices_batch_dims (int, optional): the number of dimensions to broadcast 91 | in the ``tf.gather`` operation. Defaults to ``-1``. 92 | spatial_axis (Tuple[int], optional): the dimensions containing positional 93 | information. Defaults to ``SPATIAL_AXIS``. 94 | gradient_filter (Callable, optional): filter before channel combining. 95 | Defaults to ``tf.abs``. 96 | 97 | Returns: 98 | Tuple[tf.Tensor, tf.Tensor]: the logits and saliency maps. 99 | 100 | """ 101 | with tf.GradientTape(watch_accessed_variables=False) as tape: 102 | tape.watch(inputs) 103 | logits = model(inputs, training=False) 104 | logits = inspection.gather_units( 105 | logits, indices, indices_axis, indices_batch_dims 106 | ) 107 | 108 | maps = tape.batch_jacobian(logits, inputs) 109 | maps = gradient_filter(maps) 110 | maps = tf.reduce_mean(maps, axis=-1) 111 | maps = transpose_jacobian(maps, len(spatial_axis)) 112 | 113 | return logits, maps 114 | 115 | 116 | def _resized_psi_dfx( 117 | inputs: tf.Tensor, 118 | outputs: tf.Tensor, 119 | sizes: tf.Tensor, 120 | psi: Callable = filters.absolute_normalize, 121 | spatial_axis: Tuple[int] = SPATIAL_AXIS, 122 | ) -> tf.Tensor: 123 | """Filter and resize the gradient tensor. 124 | 125 | Args: 126 | inputs (tf.Tensor): the input signal. 127 | outputs (tf.Tensor): the output signal. 128 | sizes (tf.Tensor): the expected sizes. 129 | psi (Callable, optional): the filtering function. Defaults to 130 | :func:`~keras_explainable.filters.absolute_normalize`. 131 | spatial_axis (Tuple[int], optional): the spatial axes in the signal. 132 | Defaults to ``SPATIAL_AXIS``. 133 | 134 | Returns: 135 | tf.Tensor: the resized and processed tensor.
136 |     """
137 |     t = outputs * inputs
138 |     t = psi(t, spatial_axis)
139 |     t = tf.reduce_mean(t, axis=-1, keepdims=True)
140 |     # t = transpose_jacobian(t, len(spatial_axis))
141 |     t = tf.image.resize(t, sizes)
142 | 
143 |     return t
144 | 
145 | 
146 | def full_gradients(
147 |     model: tf.keras.Model,
148 |     inputs: tf.Tensor,
149 |     indices: Optional[tf.Tensor] = None,
150 |     indices_axis: int = KERNEL_AXIS,
151 |     indices_batch_dims: int = -1,
152 |     spatial_axis: Tuple[int] = SPATIAL_AXIS,
153 |     psi: Callable = filters.absolute_normalize,
154 |     biases: Optional[List[tf.Tensor]] = None,
155 | ):
156 |     """Computes the Full-Grad Visualization Method.
157 | 
158 |     This technique adds the individual contributions of each bias factor
159 |     in the model to the extracted gradient, forming the "full gradient"
160 |     representation, and it can be summarized by the following equation:
161 | 
162 |     .. math::
163 | 
164 |         f(x) = ψ(∇_xf(x) \\odot x) + ∑_{l\\in L} ∑_{c\\in c_l} ψ(f^b(x)_c)
165 | 
166 |     This method expects `inputs` to be a batch of positional signals of
167 |     shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``,
168 |     where ``(H', W', ...)`` are the sizes of the visual receptive field
169 |     in the explained activation layer and `L` is the number of labels
170 |     represented within the model's output logits.
171 | 
172 |     If `indices` is passed, the specific logits indexed by elements in this
173 |     tensor are selected before the gradients are computed, effectively
174 |     reducing the columns in the jacobian, and the size of the output explaining map.
175 | 
176 |     Furthermore, the cached list of ``biases`` can be passed as a parameter to this
177 |     method. If none is passed, it will be inferred at runtime, implying a marginal
178 |     increase in execution overhead during tracing.
179 | 
180 |     Usage:
181 | 
182 |     .. code-block:: python
183 | 
184 |         x = np.random.normal(size=(1, 224, 224, 3))
185 |         y = np.asarray([[16, 32]])
186 | 
187 |         model = tf.keras.applications.ResNet50V2(classifier_activation=None)
188 | 
189 |         logits = ke.inspection.get_logits_layer(model)
190 |         inters, biases = ke.inspection.layers_with_biases(model, exclude=[logits])
191 |         model = ke.inspection.expose(model, inters, logits)
192 | 
193 |         scores, cams = ke.methods.gradient.full_gradients(model, x, y, biases=biases)
194 | 
195 |     References:
196 |         - Srinivas S, Fleuret F. Full-gradient representation for neural network
197 |           visualization. `arxiv.org/1905.00780 `_,
198 |           2019.
199 | 
200 |     Args:
201 |         model (tf.keras.Model): the model being explained.
202 |         inputs (tf.Tensor): the input data.
203 |         indices (Optional[tf.Tensor], optional): indices that should be gathered
204 |             from ``outputs``. Defaults to None.
205 |         indices_axis (int, optional): the axis containing the indices to gather.
206 |             Defaults to ``KERNEL_AXIS``.
207 |         indices_batch_dims (int, optional): the number of dimensions to broadcast
208 |             in the ``tf.gather`` operation. Defaults to ``-1``.
209 |         spatial_axis (Tuple[int], optional): the dimensions containing positional
210 |             information. Defaults to ``SPATIAL_AXIS``.
211 |         psi (Callable, optional): filter operation applied before combining the
212 |             intermediate signals. Defaults to ``filters.absolute_normalize``.
213 |         biases (List[tf.Tensor], optional): list of biases associated with each
214 |             intermediate signal exposed by the model. If none is passed, it will
215 |             be inferred from the endpoints (nodes) output by the model.
216 | 
217 |     Returns:
218 |         Tuple[tf.Tensor, tf.Tensor]: the logits and saliency maps.
219 | 220 | """ 221 | shape = tf.shape(inputs) 222 | sizes = [shape[a] for a in spatial_axis] 223 | 224 | resized_psi_dfx_ = partial( 225 | _resized_psi_dfx, 226 | sizes=sizes, 227 | psi=psi, 228 | spatial_axis=spatial_axis, 229 | ) 230 | 231 | if biases is None: 232 | _, *intermediates = (i._keras_history.layer for i in model.outputs) 233 | biases = inspection.biases(intermediates) 234 | 235 | with tf.GradientTape(watch_accessed_variables=False) as tape: 236 | tape.watch(inputs) 237 | logits, *intermediates = model(inputs, training=False) 238 | logits = inspection.gather_units( 239 | logits, indices, indices_axis, indices_batch_dims 240 | ) 241 | 242 | grad_input, *grad_inter = tape.gradient(logits, [inputs, *intermediates]) 243 | 244 | maps = resized_psi_dfx_(inputs, grad_input) 245 | for b, i in zip(biases, grad_inter): 246 | maps += resized_psi_dfx_(b, i) 247 | 248 | return logits, maps 249 | 250 | 251 | METHODS = [ 252 | gradients, 253 | full_gradients, 254 | ] 255 | """Available Gradient-based AI Explaining methods. 256 | 257 | This list contains all available methods implemented in this module, 258 | and it is kept and used for introspection and validation purposes. 259 | """ 260 | 261 | __all__ = [ 262 | "gradients", 263 | "full_gradients", 264 | ] 265 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is execfile()d with the current directory set to its containing dir. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import os 11 | import sys 12 | import shutil 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | __location__ = os.path.dirname(__file__) 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | sys.path.insert(0, os.path.join(__location__, "../src")) 22 | 23 | # -- Run sphinx-apidoc ------------------------------------------------------- 24 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 25 | # `sphinx-build -b html . _build/html`. See Issue: 26 | # https://github.com/readthedocs/readthedocs.org/issues/1139 27 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 28 | # setup.py install" in the RTD Advanced Settings. 
29 | # Additionally it helps us to avoid running apidoc manually 30 | 31 | try: # for Sphinx >= 1.7 32 | from sphinx.ext import apidoc 33 | except ImportError: 34 | from sphinx import apidoc 35 | 36 | output_dir = os.path.join(__location__, "api") 37 | module_dir = os.path.join(__location__, "../src/keras_explainable") 38 | try: 39 | shutil.rmtree(output_dir) 40 | except FileNotFoundError: 41 | pass 42 | 43 | try: 44 | import sphinx 45 | 46 | cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}" 47 | 48 | args = cmd_line.split(" ") 49 | if tuple(sphinx.__version__.split(".")) >= ("1", "7"): 50 | # This is a rudimentary parse_version to avoid external dependencies 51 | args = args[1:] 52 | 53 | apidoc.main(args) 54 | except Exception as e: 55 | print("Running `sphinx-apidoc` failed!\n{}".format(e)) 56 | 57 | import sphinx_redactor_theme 58 | 59 | # -- General configuration --------------------------------------------------- 60 | 61 | # If your documentation needs a minimal Sphinx version, state it here. 62 | # needs_sphinx = '1.0' 63 | 64 | # Add any Sphinx extension module names here, as strings. They can be extensions 65 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 66 | extensions = [ 67 | "sphinx.ext.autodoc", 68 | "sphinx.ext.intersphinx", 69 | "sphinx.ext.todo", 70 | "sphinx.ext.autosummary", 71 | "sphinx.ext.viewcode", 72 | "sphinx.ext.coverage", 73 | "sphinx.ext.doctest", 74 | "sphinx.ext.ifconfig", 75 | "sphinx.ext.mathjax", 76 | "sphinx.ext.napoleon", 77 | "sphinx_autodoc_typehints", 78 | "jupyter_sphinx", 79 | ] 80 | 81 | # from sphinx_execute_code import directives 82 | 83 | # Add any paths that contain templates here, relative to this directory. 84 | templates_path = ["_templates"] 85 | 86 | # The suffix of source filenames. 87 | source_suffix = ".rst" 88 | 89 | # The encoding of source files. 90 | # source_encoding = 'utf-8-sig' 91 | 92 | # The master toctree document. 93 | master_doc = "index" 94 | 95 | # General information about the project. 96 | project = "keras-explainable" 97 | copyright = "2022, Lucas David" 98 | 99 | # The version info for the project you're documenting, acts as replacement for 100 | # |version| and |release|, also used in various other places throughout the 101 | # built documents. 102 | # 103 | # version: The short X.Y version. 104 | # release: The full version, including alpha/beta/rc tags. 105 | # If you don’t need the separation provided between version and release, 106 | # just set them both to the same value. 107 | try: 108 | from keras_explainable import __version__ as version 109 | except ImportError: 110 | version = "" 111 | 112 | if not version or version.lower() == "unknown": 113 | version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD 114 | 115 | release = version 116 | 117 | # The language for content autogenerated by Sphinx. Refer to documentation 118 | # for a list of supported languages. 119 | # language = None 120 | 121 | # There are two options for replacing |today|: either, you set today to some 122 | # non-false value, then it is used: 123 | # today = '' 124 | # Else, today_fmt is used as the format for a strftime call. 125 | # today_fmt = '%B %d, %Y' 126 | 127 | # List of patterns, relative to source directory, that match files and 128 | # directories to ignore when looking for source files. 129 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] 130 | 131 | # The reST default role (used for this markup: `text`) to use for all documents. 
132 | # default_role = None 133 | 134 | # If true, '()' will be appended to :func: etc. cross-reference text. 135 | # add_function_parentheses = True 136 | 137 | # If true, the current module name will be prepended to all description 138 | # unit titles (such as .. function::). 139 | add_module_names = False 140 | 141 | # If true, sectionauthor and moduleauthor directives will be shown in the 142 | # output. They are ignored by default. 143 | # show_authors = False 144 | 145 | # The name of the Pygments (syntax highlighting) style to use. 146 | pygments_style = "vs" 147 | 148 | # A list of ignored prefixes for module index sorting. 149 | # modindex_common_prefix = [] 150 | 151 | # If true, keep warnings as "system message" paragraphs in the built documents. 152 | # keep_warnings = False 153 | 154 | # If this is True, todo emits a warning for each TODO entries. The default is False. 155 | todo_emit_warnings = True 156 | 157 | # -- Options for HTML output ------------------------------------------------- 158 | 159 | # The theme to use for HTML and HTML Help pages. See the documentation for 160 | # a list of builtin themes. 161 | # html_theme = 'sphinx_book_theme' 162 | html_theme = 'sphinx_redactor_theme' 163 | 164 | # Theme options are theme-specific and customize the look and feel of a theme 165 | # further. For a list of options available for each theme, see the 166 | # documentation. 167 | html_theme_options = { 168 | # "sidebar_width": "300px", "page_width": "1200px" 169 | # "repository_url": "https://github.com/lucasdavid/keras-explainable", 170 | # "use_repository_button": True, 171 | } 172 | 173 | 174 | # Add any paths that contain custom themes here, relative to this directory. 175 | # html_theme_path = [] 176 | html_theme_path = [sphinx_redactor_theme.get_html_theme_path()] 177 | 178 | # The name for this set of Sphinx documents. If None, it defaults to 179 | # " v documentation". 180 | html_title = "Keras Explainable" 181 | 182 | # A shorter title for the navigation bar. Default is the same as html_title. 183 | # html_short_title = None 184 | 185 | # The name of an image file (relative to this directory) to place at the top 186 | # of the sidebar. 187 | # html_logo = "" 188 | 189 | # The name of an image file (within the static path) to use as favicon of the 190 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 191 | # pixels large. 192 | # html_favicon = None 193 | 194 | # Add any paths that contain custom static files (such as style sheets) here, 195 | # relative to this directory. They are copied after the builtin static files, 196 | # so a file named "default.css" will overwrite the builtin "default.css". 197 | html_static_path = ["_static"] 198 | 199 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 200 | # using the given strftime format. 201 | # html_last_updated_fmt = '%b %d, %Y' 202 | 203 | # If true, SmartyPants will be used to convert quotes and dashes to 204 | # typographically correct entities. 205 | # html_use_smartypants = True 206 | 207 | # Custom sidebar templates, maps document names to template names. 208 | # html_sidebars = {} 209 | 210 | # Additional templates that should be rendered to pages, maps page names to 211 | # template names. 212 | # html_additional_pages = {} 213 | 214 | # If false, no module index is generated. 215 | # html_domain_indices = True 216 | 217 | # If false, no index is generated. 
218 | # html_use_index = True 219 | 220 | # If true, the index is split into individual pages for each letter. 221 | # html_split_index = False 222 | 223 | # If true, links to the reST sources are added to the pages. 224 | # html_show_sourcelink = True 225 | 226 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 227 | html_show_sphinx = False 228 | 229 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 230 | # html_show_copyright = True 231 | 232 | # If true, an OpenSearch description file will be output, and all pages will 233 | # contain a tag referring to it. The value of this option must be the 234 | # base URL from which the finished HTML is served. 235 | # html_use_opensearch = '' 236 | 237 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 238 | # html_file_suffix = None 239 | 240 | # Output file base name for HTML help builder. 241 | htmlhelp_basename = "keras-explainable-doc" 242 | 243 | # -- Options for LaTeX output ------------------------------------------------ 244 | 245 | latex_elements = { 246 | # The paper size ("letterpaper" or "a4paper"). 247 | # "papersize": "letterpaper", 248 | # The font size ("10pt", "11pt" or "12pt"). 249 | # "pointsize": "10pt", 250 | # Additional stuff for the LaTeX preamble. 251 | # "preamble": "", 252 | } 253 | 254 | # Grouping the document tree into LaTeX files. List of tuples 255 | # (source start file, target name, title, author, documentclass [howto/manual]). 256 | latex_documents = [ 257 | ("index", "user_guide.tex", "keras-explainable Documentation", "Lucas David", "manual") 258 | ] 259 | 260 | # The name of an image file (relative to this directory) to place at the top of 261 | # the title page. 262 | # latex_logo = "" 263 | 264 | # For "manual" documents, if this is true, then toplevel headings are parts, 265 | # not chapters. 266 | # latex_use_parts = False 267 | 268 | # If true, show page references after internal links. 269 | # latex_show_pagerefs = False 270 | 271 | # If true, show URL addresses after external links. 272 | # latex_show_urls = False 273 | 274 | # Documents to append as an appendix to all manuals. 275 | # latex_appendices = [] 276 | 277 | # If false, no module index is generated. 
278 | # latex_domain_indices = True 279 | 280 | # -- External mapping -------------------------------------------------------- 281 | python_version = ".".join(map(str, sys.version_info[0:2])) 282 | intersphinx_mapping = { 283 | "sphinx": ("https://www.sphinx-doc.org/en/master", None), 284 | "python": ("https://docs.python.org/" + python_version, None), 285 | "matplotlib": ("https://matplotlib.org", None), 286 | "numpy": ("https://numpy.org/doc/stable", None), 287 | "sklearn": ("https://scikit-learn.org/stable", None), 288 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 289 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 290 | "setuptools": ("https://setuptools.pypa.io/en/stable/", None), 291 | "pyscaffold": ("https://pyscaffold.org/en/stable", None), 292 | } 293 | 294 | print(f"loading configurations for {project} {version} ...", file=sys.stderr) 295 | 296 | def builder_inited(app): 297 | app.add_css_file('css/custom.css') 298 | 299 | def setup(app): 300 | app.connect('builder-inited', builder_inited) 301 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Lucas David 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/keras_explainable/engine/explaining.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any 3 | from typing import Callable 4 | from typing import Dict 5 | from typing import List 6 | from typing import Optional 7 | from typing import Tuple 8 | from typing import Union 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | from keras import callbacks as callbacks_module 13 | from keras.callbacks import Callback 14 | from keras.engine import data_adapter 15 | from keras.engine.training import _is_tpu_multi_host 16 | from keras.engine.training import _minimum_control_deps 17 | from keras.engine.training import potentially_ragged_concat 18 | from keras.engine.training import reduce_per_replica 19 | from keras.utils import tf_utils 20 | from tensorflow.python.eager import context 21 | 22 | from keras_explainable.inspection import SPATIAL_AXIS 23 | 24 | 25 | def explain_step( 26 | model: tf.keras.Model, 27 | method: Callable, 28 | data: Tuple[tf.Tensor], 29 | spatial_axis: Tuple[int, int] = SPATIAL_AXIS, 30 | postprocessing: Callable = None, 31 | resizing: Optional[Union[bool, tf.Tensor]] = True, 32 | **params, 33 | ) -> Tuple[tf.Tensor, tf.Tensor]: 34 | inputs, indices, _ = data_adapter.unpack_x_y_sample_weight(data) 35 | logits, maps = method( 36 | model=model, 37 | inputs=inputs, 38 | indices=indices, 39 | spatial_axis=spatial_axis, 40 | **params, 41 | ) 42 | 43 | if postprocessing is not None: 44 | maps = postprocessing(maps, axis=spatial_axis) 45 | 46 | if resizing is not None and resizing is not False: 47 | if resizing is True: 48 | resizing = tf.shape(inputs)[1:-1] 49 | maps = tf.image.resize(maps, resizing) 50 | 51 | return logits, maps 52 | 53 | 54 | def make_explain_function( 55 | model: tf.keras.Model, 56 | method: Callable, 57 | params: Dict[str, Any], 58 | force: bool = False, 59 | ): 60 | explain_function = getattr(model, "explain_function", None) 61 | 62 | if explain_function is not None and not force: 63 | return explain_function 64 | 65 | def explain_function(iterator): 66 | """Runs a single explain step.""" 67 | 68 | def run_step(data): 69 | outputs = explain_step(model, method, data, **params) 70 | # Ensure counter is updated only if `test_step` succeeds. 
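            # `_minimum_control_deps(outputs)` returns the tensors that must be
            # materialized before the increment below runs, so partially
            # executed steps are never counted (the same pattern Keras uses in
            # `make_test_function`).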
71 |             with tf.control_dependencies(_minimum_control_deps(outputs)):
72 |                 model._explain_counter.assign_add(1)
73 |             return outputs
74 | 
75 |         if model._jit_compile:
76 |             run_step = tf.function(run_step, jit_compile=True, reduce_retracing=True)
77 | 
78 |         data = next(iterator)
79 |         outputs = model.distribute_strategy.run(run_step, args=(data,))
80 |         outputs = reduce_per_replica(
81 |             outputs, model.distribute_strategy, reduction="concat"
82 |         )
83 |         return outputs
84 | 
85 |     if not model.run_eagerly:
86 |         explain_function = tf.function(explain_function, reduce_retracing=True)
87 | 
88 |     model.explain_function = explain_function
89 | 
90 |     return explain_function
91 | 
92 | 
93 | def make_data_handler(
94 |     model,
95 |     x,
96 |     y,
97 |     batch_size=None,
98 |     steps=None,
99 |     max_queue_size=10,
100 |     workers=1,
101 |     use_multiprocessing=False,
102 | ):
103 |     dataset_types = (tf.compat.v1.data.Dataset, tf.data.Dataset)
104 |     if (
105 |         model._in_multi_worker_mode() or _is_tpu_multi_host(model.distribute_strategy)
106 |     ) and isinstance(x, dataset_types):
107 |         try:
108 |             opts = tf.data.Options()
109 |             opts.experimental_distribute.auto_shard_policy = (
110 |                 tf.data.experimental.AutoShardPolicy.DATA
111 |             )
112 |             x = x.with_options(opts)
113 |         except ValueError:
114 |             warnings.warn(
115 |                 "Using explain with MultiWorkerMirroredStrategy "
116 |                 "or TPUStrategy and AutoShardPolicy.FILE might lead to "
117 |                 "out-of-order results. Consider setting it to "
118 |                 "AutoShardPolicy.DATA.",
119 |                 stacklevel=2,
120 |             )
121 | 
122 |     return data_adapter.get_data_handler(
123 |         x=x,
124 |         y=y,
125 |         batch_size=batch_size,
126 |         steps_per_epoch=steps,
127 |         initial_epoch=0,
128 |         epochs=1,
129 |         max_queue_size=max_queue_size,
130 |         workers=workers,
131 |         use_multiprocessing=use_multiprocessing,
132 |         model=model,
133 |         steps_per_execution=model._steps_per_execution,
134 |     )
135 | 
136 | 
137 | def explain(
138 |     method: Callable,
139 |     model: tf.keras.Model,
140 |     x: Union[np.ndarray, tf.Tensor, tf.data.Dataset],
141 |     y: Optional[Union[np.ndarray, tf.Tensor]] = None,
142 |     batch_size: Optional[int] = None,
143 |     verbose: Union[str, int] = "auto",
144 |     steps: Optional[int] = None,
145 |     callbacks: List[Callback] = None,
146 |     max_queue_size: int = 10,
147 |     workers: int = 1,
148 |     use_multiprocessing: bool = False,
149 |     force: bool = True,
150 |     **method_params,
151 | ) -> Tuple[np.ndarray, np.ndarray]:
152 |     """Explain the outputs of ``model`` with respect to the inputs or an intermediate
153 |     signal, using an AI explaining method.
154 | 
155 |     Usage:
156 | 
157 |     .. code-block:: python
158 | 
159 |         x = np.random.normal(size=(1, 224, 224, 3))
160 |         y = np.asarray([[16, 32]])
161 | 
162 |         model = tf.keras.applications.ResNet50V2(classifier_activation=None)
163 | 
164 |         scores, maps = ke.explain(
165 |             ke.methods.gradient.gradients,
166 |             model, x, y,
167 |             postprocessing=filters.absolute_normalize,
168 |         )
169 | 
170 |     Args:
171 |         method (Callable): an AI explaining function, such as the ones contained
172 |             in the `methods` module.
173 |         model (tf.keras.Model): the model whose predictions should be explained.
174 |         x (Union[np.ndarray, tf.Tensor, tf.data.Dataset]): the input data for the model.
175 |         y (Optional[Union[np.ndarray, tf.Tensor]], optional): the indices in the output
176 |             tensor that should be explained. If None, an activation map is computed
177 |             for each unit. Defaults to None.
178 |         batch_size (Optional[int], optional): the batch size used by ``method``.
179 |             Defaults to None (which Keras' data adapters resolve to 32).
180 |         verbose (Union[str, int], optional): whether to show a progress bar during
181 |             the calculation of the explaining maps. Defaults to "auto".
182 |         steps (Optional[int], optional): the number of steps, if ``x`` is a
183 |             ``tf.data.Dataset`` of unknown cardinality. Defaults to None.
184 |         callbacks (List[Callback], optional): list of callbacks called during the
185 |             explaining procedure. Defaults to None.
186 |         max_queue_size (int, optional): the queue size when retrieving inputs.
187 |             Used if ``x`` is a generator. Defaults to 10.
188 |         workers (int, optional): the number of workers used when retrieving inputs.
189 |             Defaults to 1.
190 |         use_multiprocessing (bool, optional): whether to employ multi-processing or
191 |             multi-threading when retrieving inputs, when ``x`` is a generator.
192 |             Defaults to False.
193 |         force (bool, optional): whether to force the creation of the explaining
194 |             function. Can be set to False if the same function is always applied
195 |             to a model, avoiding retracing. Defaults to True.
196 | 
197 |     Besides the parameters described above, any named parameters passed to this function
198 |     will be collected into ``method_params`` and passed onto the :func:`explain_step`
199 |     and ``method`` functions. The most common ones are:
200 | 
201 |     - **indices_batch_dims** (int): The dimensions marked as ``batch`` when gathering
202 |       units described by ``y``. Ignored if ``y`` is None.
203 |     - **indices_axis** (int): The axes from which to gather units described by ``y``.
204 |       Ignored if ``y`` is None.
205 |     - **spatial_axis** (Tuple[int]): The axes containing the positional visual info.
206 |       We assume `inputs` to contain 2D images or videos in the shape
207 |       `(B1, B2, ..., BN, H, W, 3)`.
208 |       For 3D image data, set `spatial_axis` to `(1, 2, 3)` or `(-4, -3, -2)`.
209 |     - **postprocessing** (Callable): A function to process the activation maps before
210 |       normalization (the most common being `maximum(x, 0)` and `abs`).
211 | 
212 |     Raises:
213 |         ValueError: if the explaining method produces an unexpected (empty) result.
214 | 
215 |     Returns:
216 |         Tuple[np.ndarray, np.ndarray]: logits and explaining maps tensors.
217 |     """
218 | 
219 |     if not hasattr(model, "_explain_counter"):
220 |         agg = tf.VariableAggregation.ONLY_FIRST_REPLICA
221 |         model._explain_counter = tf.Variable(0, dtype="int64", aggregation=agg)
222 | 
223 |     outputs = None
224 |     with model.distribute_strategy.scope():
225 |         # Creates a `tf.data.Dataset` and handles batch and epoch iteration.
226 |         data_handler = make_data_handler(
227 |             model,
228 |             x,
229 |             y,
230 |             batch_size=batch_size,
231 |             steps=steps,
232 |             max_queue_size=max_queue_size,
233 |             workers=workers,
234 |             use_multiprocessing=use_multiprocessing,
235 |         )
236 | 
237 |         # Container that configures and calls `tf.keras.Callback`s.
238 |         if not isinstance(callbacks, callbacks_module.CallbackList):
239 |             callbacks = callbacks_module.CallbackList(
240 |                 callbacks,
241 |                 add_history=True,
242 |                 add_progbar=verbose != 0,
243 |                 model=model,
244 |                 verbose=verbose,
245 |                 epochs=1,
246 |                 steps=data_handler.inferred_steps,
247 |             )
248 | 
249 |         explain_function = make_explain_function(model, method, method_params, force)
250 |         model._explain_counter.assign(0)
251 |         callbacks.on_predict_begin()
252 |         batch_outputs = None
253 |         for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
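            # NOTE: explaining requires a single pass over `x`; the handler
            # above was built with `epochs=1`, so this loop body runs once.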
254 |             with data_handler.catch_stop_iteration():
255 |                 for step in data_handler.steps():
256 |                     callbacks.on_predict_batch_begin(step)
257 |                     tmp_batch_outputs = explain_function(iterator)
258 |                     if data_handler.should_sync:
259 |                         context.async_wait()
260 |                     batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
261 |                     if outputs is None:
262 |                         outputs = tf.nest.map_structure(
263 |                             lambda batch_output: [batch_output],
264 |                             batch_outputs,
265 |                         )
266 |                     else:
267 |                         tf.__internal__.nest.map_structure_up_to(
268 |                             batch_outputs,
269 |                             lambda output, batch_output: output.append(batch_output),
270 |                             outputs,
271 |                             batch_outputs,
272 |                         )
273 |                     end_step = step + data_handler.step_increment
274 |                     callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
275 |         if batch_outputs is None:
276 |             raise ValueError(
277 |                 "Unexpected result of `explain_function` "
278 |                 "(empty `batch_outputs`). Please use "
279 |                 "`Model.compile(..., run_eagerly=True)`, or "
280 |                 "`tf.config.run_functions_eagerly(True)` for more "
281 |                 "information on where it went wrong, or file an "
282 |                 "issue/bug report with `keras-explainable`."
283 |             )
284 |         callbacks.on_predict_end()
285 |     all_outputs = tf.__internal__.nest.map_structure_up_to(
286 |         batch_outputs, potentially_ragged_concat, outputs
287 |     )
288 |     return tf_utils.sync_to_numpy_or_python_type(all_outputs)
289 | 
290 | 
291 | def partial_explain(method: Callable, **default_params):
292 |     """Wrapper for explaining methods.
293 | 
294 |     Args:
295 |         method (Callable): the explaining method being wrapped by ``explain``.
296 |     """
297 | 
298 |     def _partial_method_explain(*args, **params):
299 |         params = {**default_params, **params}
300 |         return explain(method, *args, **params)
301 | 
302 |     _partial_method_explain.__name__ = f"{method.__name__}_explain"
303 | 
304 |     return _partial_method_explain
305 | 
--------------------------------------------------------------------------------
/src/keras_explainable/inspection.py:
--------------------------------------------------------------------------------
1 | """Inspection utils for models and layers.
2 | """
3 | 
4 | from typing import Dict
5 | from typing import List
6 | from typing import Optional
7 | from typing import Tuple
8 | from typing import Type
9 | from typing import Union
10 | 
11 | import tensorflow as tf
12 | from keras.engine.base_layer import Layer
13 | from keras.engine.keras_tensor import KerasTensor
14 | from keras.engine.training import Model
15 | from keras.layers.normalization.batch_normalization import BatchNormalizationBase
16 | from keras.layers.normalization.layer_normalization import LayerNormalization
17 | from keras.layers.pooling.base_global_pooling1d import GlobalPooling1D
18 | from keras.layers.pooling.base_global_pooling2d import GlobalPooling2D
19 | from keras.layers.pooling.base_global_pooling3d import GlobalPooling3D
20 | from keras.layers.reshaping.flatten import Flatten
21 | 
22 | from keras_explainable.utils import tolist
23 | 
24 | E = Union[str, int, tf.Tensor, KerasTensor, Dict[str, Union[str, int]]]
25 | 
26 | KERNEL_AXIS = -1
27 | SPATIAL_AXIS = (-3, -2)
28 | 
29 | NORMALIZATION_LAYERS = (
30 |     BatchNormalizationBase,
31 |     LayerNormalization,
32 | )
33 | 
34 | POOLING_LAYERS = (
35 |     Flatten,
36 |     GlobalPooling1D,
37 |     GlobalPooling2D,
38 |     GlobalPooling3D,
39 | )
40 | 
41 | 
42 | def get_nested_layer(
43 |     model: Model,
44 |     name: Union[str, List[str]],
45 | ) -> Layer:
46 |     """Retrieve a nested layer in the model.
47 | 
48 |     Args:
49 |         model (Model): the model containing the nested layer.
50 |         name (Union[str, List[str]]): a string (or list of strings) containing
51 |             the name of the layer (or a list of names, each of which references
52 |             a recursively nested module up to the layer of interest).
53 | 
54 |     Example:
55 | 
56 |     .. code-block:: python
57 | 
58 |         model = tf.keras.Sequential([
59 |             tf.keras.applications.ResNet101V2(include_top=False, pooling='avg'),
60 |             tf.keras.layers.Dense(10, activation='softmax', name='predictions')
61 |         ])
62 | 
63 |         pooling_layer = get_nested_layer(model, ('resnet101v2', 'avg_pool'))
64 | 
65 |     Raises:
66 |         ValueError: if ``name`` is not a nested member of ``model``.
67 | 
68 |     Returns:
69 |         tf.keras.layers.Layer: the retrieved layer.
70 |     """
71 |     for n in tolist(name):
72 |         model = model.get_layer(n)
73 | 
74 |     return model
75 | 
76 | 
77 | def get_logits_layer(
78 |     model: Model,
79 |     name: str = None,
80 | ) -> Layer:
81 |     """Retrieve the "logits" layer.
82 | 
83 |     Args:
84 |         model (Model): the model containing the logits layer.
85 |         name (str, optional): the name of the layer, if known. Defaults to None.
86 | 
87 |     Raises:
88 |         ValueError: if a logits layer cannot be found.
89 | 
90 |     Returns:
91 |         Layer: the retrieved logits layer.
92 |     """
93 |     return find_layer_with(model, name, properties=["kernel"])
94 | 
95 | 
96 | def get_global_pooling_layer(
97 |     model: Model,
98 |     name: str = None,
99 | ) -> Layer:
100 |     """Retrieve the last global pooling layer.
101 | 
102 |     Args:
103 |         model (Model): the model containing the pooling layer.
104 |         name (str, optional): the name of the layer, if known. Defaults to None.
105 | 
106 |     Raises:
107 |         ValueError: if a pooling layer cannot be found.
108 | 
109 |     Returns:
110 |         Layer: the retrieved pooling layer.
111 |     """
112 |     return find_layer_with(model, name, klass=POOLING_LAYERS)
113 | 
114 | 
115 | def find_layer_with(
116 |     model: Model,
117 |     name: Optional[str] = None,
118 |     properties: Optional[Tuple[str]] = None,
119 |     klass: Optional[Tuple[Type[Layer]]] = None,
120 |     search_reversed: bool = True,
121 | ) -> Layer:
122 |     """Find a layer within a model that satisfies all required properties.
123 | 
124 |     Args:
125 |         model (Model): the container model.
126 |         name (Optional[str], optional): the name of the layer, if known.
127 |             Defaults to None.
128 |         properties (Optional[Tuple[str]], optional): a list of properties that
129 |             should be visible from the searched layer. Defaults to None.
130 |         klass (Optional[Tuple[Type[Layer]]], optional): a collection of classes
131 |             allowed for the searched layer. Defaults to None.
132 |         search_reversed (bool, optional): whether to search from last to first.
133 |             Defaults to True.
134 | 
135 |     Raises:
136 |         ValueError: if no search parameters are passed.
137 |         ValueError: if no valid layer can be found with the specified search
138 |             parameters.
139 | 
140 |     Returns:
141 |         Layer: the layer satisfying all search parameters.
142 |     """
143 |     search_params = (name, properties, klass)
144 |     if all(p is None for p in search_params):
145 |         raise ValueError(
146 |             "At least one of the search parameters must "
147 |             "be set when calling `find_layer_with`, indicating the "
148 |             "necessary properties of the layer being retrieved."
149 | ) 150 | 151 | if name is not None: 152 | return get_nested_layer(model, name) 153 | 154 | layers = model._flatten_layers(include_self=False) 155 | if search_reversed: 156 | layers = reversed(list(layers)) 157 | for layer in layers: 158 | if klass and not isinstance(layer, klass): 159 | continue 160 | if properties and not all(hasattr(layer, p) for p in properties): 161 | continue 162 | 163 | return layer # `layer` matches all conditions. 164 | 165 | raise ValueError( 166 | f"A valid layer couldn't be inferred from the name=`{name}`, " 167 | f"klass=`{klass}` and properties=`{properties}`. Make sure these " 168 | "attributes correctly reflect a layer in the model." 169 | ) 170 | 171 | 172 | def endpoints(model: Model, endpoints: List[E]) -> List[KerasTensor]: 173 | """Collect intermediate endpoints in a model based on structured descriptors. 174 | 175 | Args: 176 | model (Model): the model containing the endpoints to be collected. 177 | endpoints (List[E]): descriptors of endpoints that should be collected. 178 | 179 | Raises: 180 | ValueError: raised whenever one of the endpoint descriptors is invalid 181 | or it does not describe a nested layer in the `model`. 182 | 183 | Returns: 184 | List[KerasTensor]: a list containing the endpoints of interest. 185 | """ 186 | endpoints_ = [] 187 | 188 | for ep in endpoints: 189 | if isinstance(ep, int): 190 | ep = {"layer": model.layers[ep]} 191 | elif isinstance(ep, Layer): 192 | ep = {"layer": ep} 193 | elif isinstance(ep, str): 194 | ep = {"name": ep} 195 | 196 | if not isinstance(ep, dict): 197 | raise ValueError( 198 | f"Illegal type {type(ep)} for endpoint {ep}. Expected a " 199 | "layer index (`int`), layer name (`str`), a layer " 200 | "(`keras.layers.Layer`) or a dictionary with " 201 | "`name`/`layer`, `link` and `node` keys." 202 | ) 203 | 204 | if "layer" in ep: 205 | layer = ep["layer"] 206 | else: 207 | layer = get_nested_layer(model, ep["name"]) 208 | 209 | link = ep.get("link", "output") 210 | node = ep.get("node", "last") 211 | 212 | if node == "last": 213 | node = len(layer._inbound_nodes) - 1 214 | 215 | endpoint = ( 216 | layer.get_input_at(node) if link == "input" else layer.get_output_at(node) 217 | ) 218 | 219 | endpoints_.append(endpoint) 220 | 221 | return endpoints_ 222 | 223 | 224 | def expose( 225 | model: Model, 226 | arguments: Optional[E] = None, 227 | outputs: Optional[E] = None, 228 | ) -> Model: 229 | """Creates a new model that exposes all endpoints described by 230 | ``arguments`` and ``outputs``. 231 | 232 | Args: 233 | model (Model): The model being explained. 234 | arguments (Optional[E], optional): Name of the argument layer/tensor in 235 | the model. The jacobian of the output explaining units will be computed 236 | with respect to the input signal of this layer. This argument can also 237 | be an integer, a dictionary representing the intermediate signal or 238 | the pooling layer itself. If None is passed, the penultimate layer 239 | is assumed to be a GAP layer. Defaults to None. 240 | outputs (Optional[E], optional): Name of the output layer in the model. 241 | The jacobian will be computed for the activation signal of units in this 242 | layer. This argument can also be an integer, a dictionary representing 243 | the output signal and the logits layer itself. If None is passed, 244 | the last layer is assumed to be the logits layer. Defaults to None. 245 | 246 | Returns: 247 | Model: the exposed model, whose outputs contain the intermediate and 248 | output tensors. 
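    Example (a minimal sketch, assuming a stock classifier whose logits and
    global pooling layers can be inferred automatically):

    .. code-block:: python

        model = tf.keras.applications.ResNet50V2(classifier_activation=None)
        exposed = expose(model)

        # `exposed` now outputs the logits and the input of the GAP layer
        # (the last positional signal in the network):
        logits, activations = exposed(np.random.uniform(size=(1, 224, 224, 3)))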
249 |     """
250 |     if outputs is None:
251 |         outputs = get_logits_layer(model)
252 |     if isinstance(arguments, (str, tuple)):
253 |         arguments = {"name": arguments}
254 |     if arguments is None:
255 |         gpl = get_global_pooling_layer(model)
256 |         arguments = {"layer": gpl, "link": "input"}
257 | 
258 |     outputs = tolist(outputs)
259 |     arguments = tolist(arguments)
260 | 
261 |     tensors = endpoints(model, outputs + arguments)
262 | 
263 |     return Model(
264 |         inputs=model.inputs,
265 |         outputs=tensors,
266 |     )
267 | 
268 | 
269 | def gather_units(
270 |     tensor: tf.Tensor,
271 |     indices: Optional[tf.Tensor],
272 |     axis: int = -1,
273 |     batch_dims: int = -1,
274 | ) -> tf.Tensor:
275 |     """Gather units (in the last axis) from a tensor.
276 | 
277 |     Args:
278 |         tensor (tf.Tensor): the input tensor.
279 |         indices (tf.Tensor, optional): the indices that should be gathered.
280 |         axis (int, optional): the axis from which indices should be taken,
281 |             used to fine-control gathering. Defaults to -1.
282 |         batch_dims (int, optional): the number of batch dimensions, used to
283 |             fine-control gathering. Defaults to -1.
284 | 
285 |     Returns:
286 |         tf.Tensor: the gathered units.
287 |     """
288 |     if indices is None:
289 |         return tensor
290 | 
291 |     return tf.gather(tensor, indices, axis=axis, batch_dims=batch_dims)
292 | 
293 | 
294 | def layers_with_biases(
295 |     model: Model,
296 |     exclude: Tuple[Layer] = (),
297 |     return_biases: bool = True,
298 | ) -> List[Layer]:
299 |     """Extract layers containing biases from a model.
300 | 
301 |     Args:
302 |         model (Model): the model inspected.
303 |         exclude (Tuple[Layer], optional): a list of layers to ignore. Defaults to ().
304 |         return_biases (bool, optional): whether or not to return the biases as well.
305 |             Defaults to True.
306 | 
307 |     Returns:
308 |         List[Layer]: a list of layers, if ``return_biases`` is False.
309 |         Tuple[List[Layer], List[tf.Tensor]]: the layers and their biases, otherwise.
310 |     """
311 |     layers = [
312 |         layer
313 |         for layer in model._flatten_layers(include_self=False)
314 |         if (
315 |             layer not in exclude
316 |             and (
317 |                 isinstance(layer, NORMALIZATION_LAYERS)
318 |                 or hasattr(layer, "bias")
319 |                 and layer.bias is not None
320 |             )
321 |         )
322 |     ]
323 | 
324 |     if return_biases:
325 |         return layers, biases(layers)
326 | 
327 |     return layers
328 | 
329 | 
330 | def biases(
331 |     layers: List[Layer],
332 | ) -> List[tf.Tensor]:
333 |     """Recursively retrieve the biases from layers.
334 | 
335 |     Layers containing an implicit bias are unrolled before being returned. For
336 |     instance, the Batch Normalization layer, whose equation is defined by
337 |     :math:`y(x) = \\frac{x - \\mu}{\\sigma} w + b`, will have bias equal to:
338 | 
339 |     .. math::
340 | 
341 |         \\frac{-\\mu w}{\\sigma} + b
342 | 
343 |     Args:
344 |         layers (List[Layer]): a list of layers from which
345 |             biases should be extracted.
346 | 
347 |     Returns:
348 |         List[tf.Tensor]: a list of all biases retrieved.
349 |     """
350 |     biases = []
351 | 
352 |     for layer in layers:
353 |         if isinstance(layer, NORMALIZATION_LAYERS):
354 |             # Batch norm := ((x - m)/s)*w + b
355 |             # Hence bias factor is -m*w/s + b, with s = sqrt(var + eps).
356 |             biases.append(
357 |                 -layer.moving_mean
358 |                 * layer.gamma
359 |                 / tf.sqrt(layer.moving_variance + layer.epsilon)
360 |                 + layer.beta
361 |             )
362 | 
363 |         elif hasattr(layer, "bias") and layer.bias is not None:
364 |             biases.append(layer.bias)
365 | 
366 |     return biases
367 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Contributing
3 | ============
4 | 
5 | Welcome to the ``keras-explainable`` contributor's guide.
6 | 
7 | This document focuses on getting any potential contributor familiarized
8 | with the development processes, but `other kinds of contributions`_ are also
9 | appreciated.
10 | 
11 | If you are new to using git_ or have never collaborated in a project previously,
12 | please have a look at `contribution-guide.org`_. Other resources are also
13 | listed in the excellent `guide created by FreeCodeCamp`_ [#contrib1]_.
14 | 
15 | Please notice that all users and contributors are expected to be **open,
16 | considerate, reasonable, and respectful**. When in doubt, `Python Software
17 | Foundation's Code of Conduct`_ is a good reference in terms of behavior
18 | guidelines.
19 | 
20 | Issue Reports
21 | =============
22 | 
23 | If you experience bugs or general issues with ``keras-explainable``, please have a look
24 | at the `issue tracker`_. If you don't see anything useful there, please feel
25 | free to fire an issue report.
26 | 
27 | .. tip::
28 |    Please don't forget to include the closed issues in your search.
29 |    Sometimes a solution was already reported, and the problem is considered
30 |    **solved**.
31 | 
32 | New issue reports should include information about your programming environment
33 | (e.g., operating system, Python version) and steps to reproduce the problem.
34 | Please try also to simplify the reproduction steps to a very minimal example
35 | that still illustrates the problem you are facing. By removing other factors,
36 | you help us to identify the root cause of the issue.
37 | 
38 | Documentation Improvements
39 | ==========================
40 | 
41 | You can help improve the ``keras-explainable`` docs by making them more readable
42 | and coherent, or by adding missing information and correcting mistakes.
43 | 
44 | ``keras-explainable`` documentation uses Sphinx_ as its main documentation compiler.
45 | This means that the docs are kept in the same repository as the project code, and
46 | that any documentation update is done in the same way as a code contribution.
47 | 
48 | .. todo:: Don't forget to mention which markup language you are using.
49 | 
50 |    e.g., reStructuredText_ or CommonMark_ with MyST_ extensions.
51 | 
52 | .. todo:: If your project is hosted on GitHub, you can also mention the following tip:
53 | 
54 |    .. tip::
55 |       Please notice that the `GitHub web interface`_ provides a quick way of
56 |       proposing changes to ``keras-explainable``'s files. While this mechanism can
57 |       be tricky for normal code contributions, it works perfectly fine for
58 |       contributing to the docs, and can be quite handy.
59 | 
60 |       If you are interested in trying this method out, please navigate to
61 |       the ``docs`` folder in the source repository_, find the file you
62 |       would like to change, and click on the little pencil icon at the
63 |       top to open `GitHub's code editor`_. Once you finish editing the file,
64 |       please write a message in the form at the bottom of the page describing
65 |       the changes you have made and the motivations behind them, and
66 |       submit your proposal.
67 | 
68 | When working on documentation changes in your local machine, you can
69 | compile them using |tox|_::
70 | 
71 |     tox -e docs
72 | 
73 | and use Python's built-in web server for a preview in your web browser
74 | (``http://localhost:8000``)::
75 | 
76 |     python3 -m http.server --directory 'docs/_build/html'
77 | 
78 | Code Contributions
79 | ==================
80 | 
81 | .. todo:: Please include a reference or explanation about the internals of the project.
82 | 
83 |    An architecture description, design principles or at least a summary of the
84 |    main concepts will make it easy for potential contributors to get started
85 |    quickly.
86 | 
87 | Submit an issue
88 | ---------------
89 | 
90 | Before you work on any non-trivial code contribution, it's best to first create
91 | a report in the `issue tracker`_ to start a discussion on the subject.
92 | This often provides additional considerations and avoids unnecessary work.
93 | 
94 | Create an environment
95 | ---------------------
96 | 
97 | Before you start coding, we recommend creating an isolated `virtual
98 | environment`_ to avoid any problems with your installed Python packages.
99 | This can easily be done via either |virtualenv|_::
100 | 
101 |     virtualenv <PATH TO VENV>
102 |     source <PATH TO VENV>/bin/activate
103 | 
104 | or Miniconda_::
105 | 
106 |     conda create -n keras-explainable python=3 six virtualenv pytest pytest-cov
107 |     conda activate keras-explainable
108 | 
109 | Clone the repository
110 | --------------------
111 | 
112 | #. Create a user account on |the repository service| if you do not already have one.
113 | #. Fork the project repository_: click on the *Fork* button near the top of the
114 |    page. This creates a copy of the code under your account on |the repository service|.
115 | #. Clone this copy to your local disk::
116 | 
117 |     git clone git@github.com:YourLogin/keras-explainable.git
118 |     cd keras-explainable
119 | 
120 | #. You should run::
121 | 
122 |     pip install -U pip setuptools -e .
123 | 
124 |    to be able to import the package under development in the Python REPL.
125 | 
126 | .. todo:: if you are not using pre-commit, please remove the following item:
127 | 
128 | #. Install |pre-commit|_::
129 | 
130 |     pip install pre-commit
131 |     pre-commit install
132 | 
133 |    ``keras-explainable`` comes with a lot of hooks configured to automatically
134 |    help the developer check the code being written.
135 | 
136 | Implement your changes
137 | ----------------------
138 | 
139 | #. Create a branch to hold your changes::
140 | 
141 |     git checkout -b my-feature
142 | 
143 |    and start making changes. Never work on the main branch!
144 | 
145 | #. Start your work on this branch. Don't forget to add docstrings_ to new
146 |    functions, modules and classes, especially if they are part of public APIs.
147 | 
148 | #. Add yourself to the list of contributors in ``AUTHORS.rst``.
149 | 
150 | #. When you’re done editing, do::
151 | 
152 |     git add <MODIFIED FILES>
153 |     git commit
154 | 
155 |    to record your changes in git_.
156 | 
157 | .. todo:: if you are not using pre-commit, please remove the following item:
158 | 
159 |    Please make sure to see the validation messages from |pre-commit|_ and fix
160 |    any eventual issues.
161 |    This should automatically use flake8_/black_ to check/fix the code style
162 |    in a way that is compatible with the project.
163 | 
164 | .. important:: Don't forget to add unit tests and documentation in case your
165 |    contribution adds an additional feature and is not just a bugfix.
166 | 
167 |    Moreover, writing a `descriptive commit message`_ is highly recommended.
168 |    In case of doubt, you can check the commit history with::
169 | 
170 |       git log --graph --decorate --pretty=oneline --abbrev-commit --all
171 | 
172 |    to look for recurring communication patterns.
173 | 
174 | #. Please check that your changes don't break any unit tests with::
175 | 
176 |       tox
177 | 
178 |    (after having installed |tox|_ with ``pip install tox`` or ``pipx``).
179 | 
180 |    You can also use |tox|_ to run several other pre-configured tasks in the
181 |    repository. Try ``tox -av`` to see a list of the available checks.
182 | 
183 | Submit your contribution
184 | ------------------------
185 | 
186 | #. If everything works fine, push your local branch to |the repository service| with::
187 | 
188 |       git push -u origin my-feature
189 | 
190 | #. Go to the web page of your fork and click |contribute button|
191 |    to send your changes for review.
192 | 
193 | .. todo:: If you are using GitHub, you can uncomment the following paragraph
194 | 
195 |    Find more detailed information in `creating a PR`_. You might also want to open
196 |    the PR as a draft first and mark it as ready for review after the feedback
197 |    from the continuous integration (CI) system and any required fixes.
198 | 
199 | Troubleshooting
200 | ---------------
201 | 
202 | The following tips can be used when facing problems building or testing the
203 | package:
204 | 
205 | #. Make sure to fetch all the tags from the upstream repository_.
206 |    The command ``git describe --abbrev=0 --tags`` should return the version you
207 |    are expecting. If you are trying to run CI scripts in a fork repository,
208 |    make sure to push all the tags.
209 |    You can also try to remove all the egg files or the complete egg folder, i.e.,
210 |    ``.eggs``, as well as the ``*.egg-info`` folders in the ``src`` folder or
211 |    potentially in the root of your project.
212 | 
213 | #. Sometimes |tox|_ misses out when new dependencies are added, especially to
214 |    ``setup.cfg`` and ``docs/requirements.txt``. If you find any problems with
215 |    missing dependencies when running a command with |tox|_, try to recreate the
216 |    ``tox`` environment using the ``-r`` flag. For example, instead of::
217 | 
218 |       tox -e docs
219 | 
220 |    try running::
221 | 
222 |       tox -r -e docs
223 | 
224 | #. Make sure to have a reliable |tox|_ installation that uses the correct
225 |    Python version (e.g., 3.7+). When in doubt you can run::
226 | 
227 |       tox --version
228 |       # OR
229 |       which tox
230 | 
231 |    If you have trouble and are seeing weird errors upon running |tox|_, you can
232 |    also try to create a dedicated `virtual environment`_ with a |tox|_ binary
233 |    freshly installed. For example::
234 | 
235 |       virtualenv .venv
236 |       source .venv/bin/activate
237 |       .venv/bin/pip install tox
238 |       .venv/bin/tox -e all
239 | 
240 | #. `Pytest can drop you`_ in an interactive session in case an error occurs.
241 |    In order to do that you need to pass a ``--pdb`` option (for example by
242 |    running ``tox -- -k <NAME OF THE FAILING TEST> --pdb``).
243 |    You can also set up breakpoints manually instead of using the ``--pdb`` option.
244 | 
245 | Maintainer tasks
246 | ================
247 | 
248 | Releases
249 | --------
250 | 
251 | .. todo:: This section assumes you are using PyPI to publicly release your package.
252 | 
253 |    If instead you are using a different/private package index, please update
254 |    the instructions accordingly.
255 | 
256 | If you are part of the group of maintainers and have correct user permissions
257 | on PyPI_, the following steps can be used to release a new version for
258 | ``keras-explainable``:
259 | 
260 | #. Make sure all unit tests are successful.
261 | #. Tag the current commit on the main branch with a release tag, e.g., ``v1.2.3``.
262 | #. Push the new tag to the upstream repository_, e.g., ``git push upstream v1.2.3``.
263 | #. Clean up the ``dist`` and ``build`` folders with ``tox -e clean``
264 |    (or ``rm -rf dist build``)
265 |    to avoid confusion with old builds and Sphinx docs.
266 | #. Run ``tox -e build`` and check that the files in ``dist`` have
267 |    the correct version (no ``.dirty`` or git_ hash) according to the git_ tag.
268 |    Also check the sizes of the distributions; if they are too big (e.g., >
269 |    500KB), unwanted clutter may have been accidentally included.
270 | #. Run ``tox -e publish -- --repository pypi`` and check that everything was
271 |    uploaded to PyPI_ correctly.
272 | 
273 | .. [#contrib1] Even though these resources focus on open source projects and
274 |    communities, the general ideas behind collaborating with other developers
275 |    to collectively create software are universal and can be applied to all sorts
276 |    of environments, including private companies and proprietary code bases.
277 | 
278 | .. <-- start -->
279 | .. todo:: Please review and change the following definitions:
280 | 
281 | .. |the repository service| replace:: GitHub
282 | .. |contribute button| replace:: "Create pull request"
283 | 
284 | .. _repository: https://github.com/lucasdavid/keras-explainable
285 | .. _issue tracker: https://github.com/lucasdavid/keras-explainable/issues
286 | .. <-- end -->
287 | 
288 | .. |virtualenv| replace:: ``virtualenv``
289 | .. |pre-commit| replace:: ``pre-commit``
290 | .. |tox| replace:: ``tox``
291 | 
292 | .. _black: https://pypi.org/project/black/
293 | .. _CommonMark: https://commonmark.org/
294 | .. _contribution-guide.org: https://www.contribution-guide.org/
295 | .. _creating a PR: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request
296 | .. _descriptive commit message: https://chris.beams.io/posts/git-commit
297 | .. _docstrings: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
298 | .. _first-contributions tutorial: https://github.com/firstcontributions/first-contributions
299 | .. _flake8: https://flake8.pycqa.org/en/stable/
300 | .. _git: https://git-scm.com
301 | .. _GitHub's fork and pull request workflow: https://guides.github.com/activities/forking/
302 | .. _guide created by FreeCodeCamp: https://github.com/FreeCodeCamp/how-to-contribute-to-open-source
303 | .. _Miniconda: https://docs.conda.io/en/latest/miniconda.html
304 | .. _MyST: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html
305 | .. _other kinds of contributions: https://opensource.guide/how-to-contribute
306 | .. _pre-commit: https://pre-commit.com/
307 | .. _PyPI: https://pypi.org/
308 | .. _PyScaffold's contributor's guide: https://pyscaffold.org/en/stable/contributing.html
309 | .. _Pytest can drop you: https://docs.pytest.org/en/stable/how-to/failures.html#using-python-library-pdb-with-pytest
310 | .. _Python Software Foundation's Code of Conduct: https://www.python.org/psf/conduct/
311 | .. _reStructuredText: https://www.sphinx-doc.org/en/master/usage/restructuredtext/
312 | .. _Sphinx: https://www.sphinx-doc.org/en/master/
313 | .. _tox: https://tox.wiki/en/stable/
314 | .. _virtual environment: https://realpython.com/python-virtual-environments-a-primer/
315 | .. _virtualenv: https://virtualenv.pypa.io/en/stable/
316 | 
317 | .. _GitHub web interface: https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files
318 | .. _GitHub's code editor: https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files
--------------------------------------------------------------------------------
/src/keras_explainable/methods/cams.py:
--------------------------------------------------------------------------------
1 | """Implementation of various CAM-based AI explaining methods and techniques.
2 | """
3 | 
4 | from typing import Optional
5 | from typing import Tuple
6 | from typing import Union
7 | 
8 | import tensorflow as tf
9 | from keras.backend import int_shape
10 | from keras.engine.base_layer import Layer
11 | 
12 | from keras_explainable.filters import normalize
13 | from keras_explainable.inspection import KERNEL_AXIS
14 | from keras_explainable.inspection import SPATIAL_AXIS
15 | from keras_explainable.inspection import gather_units
16 | from keras_explainable.inspection import get_logits_layer
17 | 
18 | 
19 | def cam(
20 |     model: tf.keras.Model,
21 |     inputs: tf.Tensor,
22 |     indices: Optional[tf.Tensor] = None,
23 |     indices_axis: int = KERNEL_AXIS,
24 |     indices_batch_dims: int = -1,
25 |     spatial_axis: Tuple[int] = SPATIAL_AXIS,
26 |     logits_layer: Optional[Union[str, Layer]] = None,
27 | ) -> Tuple[tf.Tensor, tf.Tensor]:
28 |     """Computes the CAM Visualization Method.
29 | 
30 |     This method expects ``inputs`` to be a batch of positional signals of
31 |     shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``,
32 |     where ``(H', W', ...)`` are the sizes of the visual receptive field
33 |     in the explained activation layer and ``L`` is the number of labels
34 |     represented within the model's output logits.
35 | 
36 |     If ``indices`` is passed, the specific logits indexed by elements in
37 |     this tensor are selected before the maps are combined,
38 |     effectively reducing the number of explained units, and the size of
39 |     the output explaining map.
40 | 
41 |     Usage:
42 | 
43 |     .. code-block:: python
44 | 
45 |         x = np.random.normal(size=(1, 224, 224, 3))
46 |         y = np.asarray([[16, 32]])
47 | 
48 |         model = tf.keras.applications.ResNet50V2(classifier_activation=None)
49 |         model = ke.inspection.expose(model)
50 | 
51 |         scores, cams = ke.methods.cams.cam(model, x, y)
52 | 
53 |     References:
54 |       - Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., & Torralba, A. (2016).
55 |         Learning deep features for discriminative localization. In Proceedings
56 |         of the IEEE conference on computer vision and pattern
57 |         recognition (pp. 2921-2929). Available at:
58 |         `arxiv/1512.04150 <https://arxiv.org/abs/1512.04150>`_.
59 | 
60 |     Args:
61 |         model (tf.keras.Model): the model being explained
62 |         inputs (tf.Tensor): the input data
63 |         indices (Optional[tf.Tensor], optional): indices that should be gathered
64 |             from ``outputs``. Defaults to None.
65 |         indices_axis (int, optional): the axis containing the indices to gather.
66 |             Defaults to ``KERNEL_AXIS``.
67 |         indices_batch_dims (int, optional): the number of dimensions to broadcast
68 |             in the ``tf.gather`` operation. Defaults to ``-1``.
69 |         spatial_axis (Tuple[int], optional): the dimensions containing positional
70 |             information. Defaults to ``SPATIAL_AXIS``.
71 |         logits_layer (Optional[Union[str, Layer]], optional): the logits layer, or its
72 |             name, whose kernel weighs the activations. Defaults to None (inferred).
73 | 
74 |     Returns:
75 |         Tuple[tf.Tensor, tf.Tensor]: the logits and Class Activation Maps (CAMs).
76 | 
77 |     """
78 |     logits, activations = model(inputs, training=False)
79 |     logits = gather_units(logits, indices, indices_axis, indices_batch_dims)
80 | 
81 |     if isinstance(logits_layer, str) or logits_layer is None:
82 |         logits_layer = get_logits_layer(model, name=logits_layer)
83 | 
84 |     weights = gather_units(
85 |         tf.squeeze(logits_layer.kernel), indices, axis=-1, batch_dims=0
86 |     )
87 | 
88 |     dims = "kc" if indices is None else "kbc"
89 |     maps = tf.einsum(f"b...k,{dims}->b...c", activations, weights)
90 | 
91 |     return logits, maps
92 | 
93 | 
94 | def gradcam(
95 |     model: tf.keras.Model,
96 |     inputs: tf.Tensor,
97 |     indices: Optional[tf.Tensor] = None,
98 |     indices_axis: int = KERNEL_AXIS,
99 |     indices_batch_dims: int = -1,
100 |     spatial_axis: Tuple[int] = SPATIAL_AXIS,
101 | ):
102 |     """Computes the Grad-CAM Visualization Method.
103 | 
104 |     This method expects `inputs` to be a batch of positional signals of
105 |     shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``,
106 |     where ``(H', W', ...)`` are the sizes of the visual receptive field
107 |     in the explained activation layer and `L` is the number of labels
108 |     represented within the model's output logits.
109 | 
110 |     If `indices` is passed, the specific logits indexed by elements in this
111 |     tensor are selected before the gradients are computed, effectively
112 |     reducing the columns in the jacobian, and the size of the output explaining map.
113 | 
114 |     Usage:
115 | 
116 |     .. code-block:: python
117 | 
118 |         x = np.random.normal(size=(1, 224, 224, 3))
119 |         y = np.asarray([[16, 32]])
120 | 
121 |         model = tf.keras.applications.ResNet50V2(classifier_activation=None)
122 |         model = ke.inspection.expose(model)
123 | 
124 |         scores, cams = ke.methods.cams.gradcam(model, x, y)
125 | 
126 |     References:
127 |       - Selvaraju, R. R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., & Batra, D.
128 |         (2017). Grad-CAM: Visual explanations from deep networks via gradient-based
129 |         localization. In Proceedings of the IEEE international conference on computer
130 |         vision (pp. 618-626).
131 |         Available at: `arxiv/1610.02391 <https://arxiv.org/abs/1610.02391>`_.
132 | 
133 |     Args:
134 |         model (tf.keras.Model): the model being explained
135 |         inputs (tf.Tensor): the input data
136 |         indices (Optional[tf.Tensor], optional): indices that should be gathered
137 |             from ``outputs``. Defaults to None.
138 |         indices_axis (int, optional): the axis containing the indices to gather.
139 |             Defaults to ``KERNEL_AXIS``.
140 |         indices_batch_dims (int, optional): the number of dimensions to broadcast
141 |             in the ``tf.gather`` operation. Defaults to ``-1``.
142 |         spatial_axis (Tuple[int], optional): the dimensions containing positional
143 |             information. Defaults to ``SPATIAL_AXIS``.
144 | 
145 |     Returns:
146 |         Tuple[tf.Tensor, tf.Tensor]: the logits and Class Activation Maps (CAMs).
147 | 148 | """ 149 | with tf.GradientTape(watch_accessed_variables=False) as tape: 150 | tape.watch(inputs) 151 | logits, activations = model(inputs, training=False) 152 | logits = gather_units(logits, indices, indices_axis, indices_batch_dims) 153 | 154 | dlda = tape.batch_jacobian(logits, activations) 155 | weights = tf.reduce_mean(dlda, axis=spatial_axis) 156 | maps = tf.einsum("b...k,bck->b...c", activations, weights) 157 | 158 | return logits, maps 159 | 160 | 161 | def gradcampp( 162 | model: tf.keras.Model, 163 | inputs: tf.Tensor, 164 | indices: Optional[tf.Tensor] = None, 165 | indices_axis: int = KERNEL_AXIS, 166 | indices_batch_dims: int = -1, 167 | spatial_axis: Tuple[int] = SPATIAL_AXIS, 168 | ): 169 | """Computes the Grad-CAM++ Visualization Method. 170 | 171 | This method expects `inputs` to be a batch of positional signals of 172 | shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``, 173 | where ``(H', W', ...)`` are the sizes of the visual receptive field 174 | in the explained activation layer and `L` is the number of labels 175 | represented within the model's output logits. 176 | 177 | If `indices` is passed, the specific logits indexed by elements in this 178 | tensor are selected before the gradients are computed, effectively 179 | reducing the columns in the jacobian, and the size of the output explaining map. 180 | 181 | Usage: 182 | 183 | .. code-block:: python 184 | 185 | x = np.random.normal((1, 224, 224, 3)) 186 | y = np.asarray([[16, 32]]) 187 | 188 | model = tf.keras.applications.ResNet50V2(classifier_activation=None) 189 | model = ke.inspection.expose(model) 190 | 191 | scores, cams = ke.methods.cams.gradcampp(model, x, y) 192 | 193 | References: 194 | - Chattopadhay, A., Sarkar, A., Howlader, P., & Balasubramanian, V. N. 195 | (2018, March). Grad-cam++: Generalized gradient-based visual explanations 196 | for deep convolutional networks. In 2018 IEEE winter conference on 197 | applications of computer vision (WACV) (pp. 839-847). IEEE. 198 | - Grad-CAM++'s official implementation. Github. Available at: 199 | `adityac94/Grad-CAM++ `_ 200 | 201 | Args: 202 | model (tf.keras.Model): the model being explained 203 | inputs (tf.Tensor): the input data 204 | indices (Optional[tf.Tensor], optional): indices that should be gathered 205 | from ``outputs``. Defaults to None. 206 | indices_axis (int, optional): the axis containing the indices to gather. 207 | Defaults to ``KERNEL_AXIS``. 208 | indices_batch_dims (int, optional): the number of dimensions to broadcast 209 | in the ``tf.gather`` operation. Defaults to ``-1``. 210 | spatial_axis (Tuple[int], optional): the dimensions containing positional 211 | information. Defaults to ``SPATIAL_AXIS``. 212 | 213 | Returns: 214 | Tuple[tf.Tensor, tf.Tensor]: the logits and Class Activation Maps (CAMs). 
215 | 216 | """ 217 | with tf.GradientTape(watch_accessed_variables=False) as tape: 218 | tape.watch(inputs) 219 | logits, activations = model(inputs, training=False) 220 | logits = gather_units(logits, indices, indices_axis, indices_batch_dims) 221 | 222 | dlda = tape.batch_jacobian(logits, activations) 223 | 224 | dyda = tf.einsum("bc,bc...k->bc...k", tf.exp(logits), dlda) 225 | d2 = dlda**2 226 | d3 = dlda**3 227 | aab = tf.reduce_sum(activations, axis=spatial_axis) # (BK) 228 | akc = tf.math.divide_no_nan( 229 | d2, 230 | 2.0 * d2 + tf.einsum("bk,bc...k->bc...k", aab, d3), # (2*(BUHWK) + (BK)*BUHWK) 231 | ) 232 | 233 | # Tensorflow has a glitch that doesn't allow this form: 234 | # weights = tf.einsum('bc...k,bc...k->bck', akc, tf.nn.relu(dyda)) # w: buk 235 | # So we use this one instead: 236 | weights = tf.reduce_sum(akc * tf.nn.relu(dyda), axis=spatial_axis) 237 | 238 | maps = tf.einsum("bck,b...k->b...c", weights, activations) # a: bhwk, m: buhw 239 | 240 | return logits, maps 241 | 242 | 243 | def scorecam( 244 | model: tf.keras.Model, 245 | inputs: tf.Tensor, 246 | indices: Optional[tf.Tensor] = None, 247 | indices_axis: int = KERNEL_AXIS, 248 | indices_batch_dims: int = -1, 249 | spatial_axis: Tuple[int] = SPATIAL_AXIS, 250 | ): 251 | """Computes the Score-CAM Visualization Method. 252 | 253 | This method expects `inputs` to be a batch of positional signals of 254 | shape ``BHW...C``, and will return a tensor of shape ``BH'W'...L``, 255 | where ``(H', W', ...)`` are the sizes of the visual receptive field 256 | in the explained activation layer and `L` is the number of labels 257 | represented within the model's output logits. 258 | 259 | If `indices` is passed, the specific logits indexed by elements in this 260 | tensor are selected before the gradients are computed, effectively 261 | reducing the columns in the jacobian, and the size of the output explaining map. 262 | 263 | Usage: 264 | 265 | .. code-block:: python 266 | 267 | x = np.random.normal((1, 224, 224, 3)) 268 | y = np.asarray([[16, 32]]) 269 | 270 | model = tf.keras.applications.ResNet50V2(classifier_activation=None) 271 | model = ke.inspection.expose(model) 272 | 273 | scores, cams = ke.methods.cams.scorecam(model, x, y) 274 | 275 | References: 276 | - Score-CAM: Score-Weighted Visual Explanations for Convolutional 277 | Neural Networks. Available at: 278 | `arxiv/1910.01279 `_ 279 | 280 | Args: 281 | model (tf.keras.Model): the model being explained 282 | inputs (tf.Tensor): the input data 283 | indices (Optional[tf.Tensor], optional): indices that should be gathered 284 | from ``outputs``. Defaults to None. 285 | indices_axis (int, optional): the axis containing the indices to gather. 286 | Defaults to ``KERNEL_AXIS``. 287 | indices_batch_dims (int, optional): the number of dimensions to broadcast 288 | in the ``tf.gather`` operation. Defaults to ``-1``. 289 | spatial_axis (Tuple[int], optional): the dimensions containing positional 290 | information. Defaults to ``SPATIAL_AXIS``. 291 | 292 | Returns: 293 | Tuple[tf.Tensor, tf.Tensor]: the logits and Class Activation Maps (CAMs). 
294 | 295 | """ 296 | scores, activations = model(inputs, training=False) 297 | scores = gather_units(scores, indices, indices_axis, indices_batch_dims) 298 | 299 | classes = int_shape(scores)[-1] or tf.shape(scores)[-1] 300 | kernels = int_shape(activations)[-1] or tf.shape(activations)[-1] 301 | 302 | shape = tf.shape(inputs) 303 | sizes = [shape[a] for a in spatial_axis] 304 | maps = tf.zeros([shape[0]] + sizes + [classes]) 305 | 306 | for i in tf.range(kernels): 307 | mask = activations[..., i : i + 1] 308 | mask = normalize(mask, axis=spatial_axis) 309 | mask = tf.image.resize(mask, sizes) 310 | 311 | si, _ = model(inputs * mask, training=False) 312 | si = gather_units(si, indices, indices_axis, indices_batch_dims) 313 | si = tf.einsum("bc,bhw->bhwc", si, mask[..., 0]) 314 | maps += si 315 | 316 | return scores, maps 317 | 318 | 319 | METHODS = [ 320 | cam, 321 | gradcam, 322 | gradcampp, 323 | scorecam, 324 | ] 325 | """Available CAM-based AI Explaining methods. 326 | 327 | This list contains all available methods implemented in this module, 328 | and it is kept and used for introspection and validation purposes. 329 | """ 330 | 331 | __all__ = [ 332 | "cam", 333 | "gradcam", 334 | "gradcampp", 335 | "scorecam", 336 | ] 337 | --------------------------------------------------------------------------------
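The usage snippets in the docstrings above omit their imports. What follows is a minimal, self-contained sketch of the whole pipeline, under the assumption that ``expose`` rewires the model to output the ``(logits, activations)`` pair these methods consume, as the docstrings describe; the clipping, normalization, and resizing steps at the end are illustrative post-processing, not part of the methods themselves, and the exact signature of ``filters.normalize`` is assumed from its use in ``scorecam`` above:

.. code-block:: python

    import numpy as np
    import tensorflow as tf

    import keras_explainable as ke

    # Any image classifier works here; ResNet50V2 matches the docstring
    # examples. `weights=None` skips the pre-trained weight download,
    # which is enough for a smoke test of the explaining pipeline.
    model = tf.keras.applications.ResNet50V2(
        weights=None, classifier_activation=None
    )
    model = ke.inspection.expose(model)  # now outputs (logits, activations)

    x = tf.constant(
        np.random.normal(size=(2, 224, 224, 3)), dtype=tf.float32
    )
    y = np.asarray([[16], [32]])  # one unit of interest per sample

    scores, maps = ke.methods.cams.gradcam(model, x, y)

    # `maps` holds one low-resolution map per sample and indexed unit.
    # Clipping negatives, normalizing, and resizing to the input's
    # spatial size are the usual steps before overlaying the maps.
    maps = tf.nn.relu(maps)
    maps = ke.filters.normalize(maps)  # assumed signature; see filters.py
    maps = tf.image.resize(maps, x.shape[1:3]).numpy()

    print(scores.shape, maps.shape)  # (2, 1) and (2, 224, 224, 1)

Swapping ``gradcam`` for ``cam``, ``gradcampp``, or ``scorecam`` requires no other change, since all four entries in ``METHODS`` share the same ``(model, inputs, indices, ...)`` interface and return the same ``(logits, maps)`` pair.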