├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .gitlab-ci.yml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── COPYING ├── COPYING.LESSER ├── LICENSE ├── README.md ├── docs └── source │ ├── _static │ └── favicon.svg │ ├── _templates │ └── modules.rst │ ├── bibliography.bib │ ├── bibliography.rst │ ├── conf.py │ ├── getting-started.rst │ ├── how-to │ ├── compute-second-order-gradients.rst │ ├── get-intermediate-relevance.rst │ ├── index.rst │ ├── use-attributors.rst │ ├── use-rules-composites-and-canonizers.rst │ ├── visualize-results.rst │ ├── write-custom-attributors.rst │ ├── write-custom-canonizers.rst │ ├── write-custom-composites.rst │ └── write-custom-rules.rst │ ├── index.rst │ ├── reference │ └── index.rst │ └── tutorial │ ├── image-classification-vgg-resnet.ipynb │ └── index.rst ├── pylintrc ├── setup.py ├── share ├── example │ └── feed_forward.py ├── img │ ├── beacon_resnet50_various.webp │ ├── beacon_vgg16_epsilon_gamma_box.png │ ├── beacon_vgg16_various.webp │ ├── zennit.png │ └── zennit.svg ├── merge_maps │ └── vgg16_bn.json └── scripts │ ├── download-lighthouses.sh │ ├── palette_fit.py │ ├── palette_swap.py │ └── show_cmaps.py ├── src └── zennit │ ├── __init__.py │ ├── attribution.py │ ├── canonizers.py │ ├── cmap.py │ ├── composites.py │ ├── core.py │ ├── image.py │ ├── layer.py │ ├── rules.py │ ├── torchvision.py │ └── types.py ├── tests ├── conftest.py ├── helpers.py ├── test_attribution.py ├── test_canonizers.py ├── test_cmap.py ├── test_composites.py ├── test_core.py ├── test_image.py ├── test_rules.py └── test_torchvision.py └── tox.ini /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: 3 | push: 4 | branches: [master] 5 | pull_request: 6 | branches: [master] 7 | 8 | jobs: 9 | test: 10 | name: test ${{matrix.tox_env}} 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | include: 16 | - tox_env: py37 17 | python: "3.7" 18 | - tox_env: py38 19 | python: "3.8" 20 | - tox_env: py39 21 | python: "3.9" 22 | steps: 23 | - uses: actions/checkout@v2 24 | with: 25 | fetch-depth: 0 26 | - name: Install base python for tox 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: "3.9" 30 | - name: Install tox 31 | run: python -m pip install tox 32 | - name: Install python for test 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: ${{ matrix.python }} 36 | - name: Setup test environment 37 | run: tox -vv --notest -e ${{ matrix.tox_env }} 38 | - name: Run test 39 | run: tox --skip-pkg-install -e ${{ matrix.tox_env }} 40 | 41 | 42 | check: 43 | name: check ${{ matrix.tox_env }} 44 | runs-on: ubuntu-latest 45 | strategy: 46 | fail-fast: false 47 | matrix: 48 | tox_env: 49 | - flake8 50 | - pylint 51 | steps: 52 | - uses: actions/checkout@v2 53 | with: 54 | fetch-depth: 0 55 | - name: Install base python for tox 56 | uses: actions/setup-python@v2 57 | with: 58 | python-version: "3.9" 59 | - name: Install tox 60 | run: python -m pip install tox 61 | - name: Setup test environment 62 | run: tox -vv --notest -e ${{ matrix.tox_env }} 63 | - name: Run test 64 | run: tox --skip-pkg-install -e ${{ matrix.tox_env }} 65 | 66 | docs: 67 | name: docs 68 | runs-on: ubuntu-latest 69 | strategy: 70 | fail-fast: false 71 | steps: 72 | - uses: actions/checkout@v2 73 | with: 74 | fetch-depth: 0 75 | - name: Install base python for tox 76 | uses: actions/setup-python@v2 77 | with: 78 | python-version: "3.9" 79 | - name: Install pandoc 80 | run: sudo 
apt-get update -y && sudo apt-get install -y pandoc 81 | - name: Install tox 82 | run: python -m pip install tox 83 | - name: Setup test environment 84 | run: tox -vv --notest -e docs 85 | - name: Run test 86 | run: tox --skip-pkg-install -e docs 87 | 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDEs and Code editors 2 | .vscode 3 | .idea 4 | .venv 5 | .exrc 6 | 7 | # Python output 8 | .python 9 | *.egg-info 10 | __pycache__ 11 | 12 | # Setup output 13 | build/ 14 | dist/ 15 | 16 | # Results output 17 | result 18 | 19 | # Testing 20 | .tox 21 | .coverage 22 | .coverage.* 23 | 24 | # System Files 25 | .DS_Store 26 | Thumbs.db 27 | 28 | # NPM dependencies 29 | node_modules 30 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | 2 | stages: 3 | - linting 4 | - unit-tests 5 | 6 | pylint: 7 | stage: linting 8 | script: 9 | - python3.8 -m tox -e pylint 10 | 11 | flake8: 12 | stage: linting 13 | script: 14 | - python3.8 -m tox -e flake8 15 | 16 | pytest: 17 | stage: unit-tests 18 | when: always 19 | script: 20 | - python3.8 -m tox -e py38 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.9" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - method: pip 14 | path: . 15 | extra_requirements: ["docs"] 16 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide for Zennit 2 | 3 | Thank you for your interest in contributing to Zennit! 4 | 5 | If you would like to fix a bug or add a feature, please write an issue before submitting a pull request. 6 | 7 | 8 | ## Git 9 | We use a linear git-history, where each commit contains a full feature/bug fix, 10 | such that each commit represents an executable version. 11 | The commit message contains a subject followed by an empty line, followed by a detailed description, similar to the following: 12 | 13 | ``` 14 | Category: Short subject describing changes (50 characters or less) 15 | 16 | - detailed description, wrapped at 72 characters 17 | - bullet points or sentences are okay 18 | - all changes should be documented and explained 19 | - valid categories are, for example: 20 | - `Docs` for documentation 21 | - `Tests` for tests 22 | - `Composites` for changes in composites 23 | - `Core` for core changes 24 | - `Package` for package-related changes, e.g. in setup.py 25 | ``` 26 | 27 | We recommend to not use `-m` for committing, as this often results in very short commit messages. 28 | 29 | ## Code Style 30 | We use [PEP8](https://www.python.org/dev/peps/pep-0008) with a line-width of 120 characters. For 31 | docstrings we use [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html). 32 | 33 | We use [`flake8`](https://pypi.org/project/flake8/) for quick style checks and 34 | [`pylint`](https://pypi.org/project/pylint/) for thorough style checks. 
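As a quick local check before opening a pull request, both style checkers can be run through the same Tox environments used in the CI configuration (see `.gitlab-ci.yml` and the GitHub Actions workflow), assuming `tox` is installed; for example:

```shell
# quick style check with flake8
$ tox -e flake8
# thorough style check with pylint
$ tox -e pylint
```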
35 | 36 | ## Testing 37 | Tests are written using [Pytest](https://docs.pytest.org) and executed 38 | in a separate environment using [Tox](https://tox.readthedocs.io/en/latest/). 39 | 40 | A full style check and all tests can be run by simply calling `tox` in the repository root. 41 | 42 | If you add a new feature, please also include appropriate tests to verify its intended functionality. 43 | We try to keep the code coverage close to 100%. 44 | 45 | ## Documentation 46 | The documentation uses [Sphinx](https://www.sphinx-doc.org). It can be built at 47 | `docs/build` using the respective Tox environment with `tox -e docs`. To rebuild the full 48 | documentation, `tox -e docs -- -aE` can be used. 49 | 50 | The API-documentation is generated from the numpydoc-style docstring of respective modules/classes/functions. 51 | 52 | ### Tutorials 53 | Tutorials are written as Jupyter notebooks in order to execute them using 54 | [Binder](https://mybinder.org/) or [Google 55 | Colab](https://colab.research.google.com/). 56 | They are found at [`docs/source/tutorial`](docs/source/tutorial). 57 | Their output should be empty when committing, as they will be executed when 58 | building the documentation. 59 | To reduce the building time of the documentation, their execution time should 60 | be kept short, i.e. large files like model parameters should not be downloaded 61 | automatically. 62 | To include parameter files for users, include a comment which describes how to 63 | use the full model/data, and provide the necessary code in a comment or an if-condition 64 | which always evaluates to `False`. 65 | 66 | ## Continuous Integration 67 | Linting, tests and the documentation are all checked using a Github Actions 68 | workflow which executes the appropriate tox environments. 69 | -------------------------------------------------------------------------------- /COPYING.LESSER: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 
31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 
98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | * Zennit is licensed under the GNU LESSER GENERAL PUBLIC LICENSE VERSION 3 OR 2 | LATER -- see the 'COPYING' and 'COPYING.LESSER' files in the root directory for 3 | details. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zennit 2 | ![Zennit-Logo](share/img/zennit.png) 3 | 4 | [![Documentation Status](https://readthedocs.org/projects/zennit/badge/?version=latest)](https://zennit.readthedocs.io/en/latest/?badge=latest) 5 | [![tests](https://github.com/chr5tphr/zennit/actions/workflows/tests.yml/badge.svg)](https://github.com/chr5tphr/zennit/actions/workflows/tests.yml) 6 | [![PyPI Version](https://img.shields.io/pypi/v/zennit)](https://pypi.org/project/zennit/) 7 | [![License](https://img.shields.io/pypi/l/zennit)](https://github.com/chr5tphr/zennit/blob/master/COPYING.LESSER) 8 | 9 | Zennit (**Z**ennit **e**xplains **n**eural **n**etworks **i**n **t**orch) is a 10 | high-level framework in Python using Pytorch for explaining/exploring neural 11 | networks. Its design philosophy is intended to provide high customizability and 12 | integration as a standardized solution for applying rule-based attribution 13 | methods in research, with a strong focus on Layerwise Relevance Propagation 14 | (LRP). Zennit strictly requires models to use Pytorch's `torch.nn.Module` 15 | structure (including activation functions). 16 | 17 | Zennit is currently under active development, but should be mostly stable. 18 | 19 | If you find Zennit useful for your research, please consider citing our related 20 | [paper](https://arxiv.org/abs/2106.13200): 21 | ``` 22 | @article{anders2021software, 23 | author = {Anders, Christopher J. and 24 | Neumann, David and 25 | Samek, Wojciech and 26 | Müller, Klaus-Robert and 27 | Lapuschkin, Sebastian}, 28 | title = {Software for Dataset-wide XAI: From Local Explanations to Global Insights with {Zennit}, {CoRelAy}, and {ViRelAy}}, 29 | journal = {CoRR}, 30 | volume = {abs/2106.13200}, 31 | year = {2021}, 32 | } 33 | ``` 34 | 35 | ## Documentation 36 | The latest documentation is hosted at 37 | [zennit.readthedocs.io](https://zennit.readthedocs.io/en/latest/). 38 | 39 | ## Install 40 | 41 | To install directly from PyPI using pip, use: 42 | ```shell 43 | $ pip install zennit 44 | ``` 45 | 46 | Alternatively, install from a manually cloned repository to try out the examples: 47 | ```shell 48 | $ git clone https://github.com/chr5tphr/zennit.git 49 | $ pip install ./zennit 50 | ``` 51 | 52 | ## Usage 53 | At its heart, Zennit registers hooks at Pytorch's Module level, to modify the 54 | backward pass to produce rule-based attributions like LRP (instead of the usual 55 | gradient). All rules are implemented as hooks 56 | ([`zennit/rules.py`](src/zennit/rules.py)) and most use the LRP basis 57 | `BasicHook` ([`zennit/core.py`](src/zennit/core.py)). 
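As a minimal sketch of this hook mechanism, a single rule may also be registered directly on an individual module, mirroring the hook-only example in the second-order-gradients how-to of the documentation; the composite-based workflow shown below is the usual way to use Zennit:

```python
import torch
from torch.nn import Conv2d

from zennit.rules import Epsilon


# a single convolutional layer and some random input data
conv = Conv2d(3, 10, 3, padding=1)
input = torch.randn(1, 3, 32, 32, requires_grad=True)

# create the Epsilon LRP rule (a hook) and register it on the module
epsilon = Epsilon()
handles = epsilon.register(conv)

# the backward pass now computes the Epsilon-modified gradient
output = conv(input)
relevance, = torch.autograd.grad(output, input, torch.ones_like(output))

# remove the hook to restore the unmodified gradient
handles.remove()
```

Registering rules by hand quickly becomes impractical for deep models, which is what the following components address.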
58 | 59 | **Composites** ([`zennit/composites.py`](src/zennit/composites.py)) are a way 60 | of choosing the right hook for the right layer. In addition to the abstract 61 | **NameMapComposite**, which assigns hooks to layers by name, and 62 | **LayerMapComposite**, which assigns hooks to layers based on their Type, there 63 | exist explicit **Composites**, some of which are `EpsilonGammaBox` (`ZBox` in 64 | input, `Epsilon` in dense, `Gamma` in convolutions) or `EpsilonPlus` (`Epsilon` 65 | in dense, `ZPlus` in convolutions). All composites may be used by directly 66 | importing from `zennit.composites`, or by using their snake-case name as key 67 | for `zennit.composites.COMPOSITES`. 68 | 69 | **Canonizers** ([`zennit/canonizers.py`](src/zennit/canonizers.py)) temporarily 70 | transform models into a canonical form, if required, like 71 | `SequentialMergeBatchNorm`, which automatically detects and merges BatchNorm 72 | layers followed by linear layers in sequential networks, or 73 | `AttributeCanonizer`, which temporarily overwrites attributes of applicable 74 | modules, e.g. to handle the residual connection in ResNet-Bottleneck modules. 75 | 76 | **Attributors** ([`zennit/attribution.py`](src/zennit/attribution.py)) directly 77 | execute the necessary steps to apply certain attribution methods, like the 78 | simple `Gradient`, `SmoothGrad` or `Occlusion`. An optional **Composite** may 79 | be passed, which will be applied during the **Attributor**'s execution to 80 | compute the modified gradient, or hybrid methods. 81 | 82 | Using all of these components, an LRP-type attribution for VGG16 with 83 | batch-norm layers with respect to label 0 may be computed using: 84 | 85 | ```python 86 | import torch 87 | from torchvision.models import vgg16_bn 88 | 89 | from zennit.composites import EpsilonGammaBox 90 | from zennit.canonizers import SequentialMergeBatchNorm 91 | from zennit.attribution import Gradient 92 | 93 | 94 | data = torch.randn(1, 3, 224, 224) 95 | model = vgg16_bn() 96 | 97 | canonizers = [SequentialMergeBatchNorm()] 98 | composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers) 99 | 100 | with Gradient(model=model, composite=composite) as attributor: 101 | out, relevance = attributor(data, torch.eye(1000)[[0]]) 102 | ``` 103 | 104 | A similar setup using [the example script](share/example/feed_forward.py) 105 | produces the following attribution heatmaps: 106 | ![beacon heatmaps](share/img/beacon_vgg16_epsilon_gamma_box.png) 107 | 108 | For more details and examples, have a look at our 109 | [**documentation**](https://zennit.readthedocs.io/en/latest/). 110 | 111 | ### More Example Heatmaps 112 | More heatmaps of various attribution methods for VGG16 and ResNet50, all 113 | generated using 114 | [`share/example/feed_forward.py`](share/example/feed_forward.py), can be found 115 | below. 116 | 117 |
 118 | <details><summary>Heatmaps for VGG16</summary> 119 | 120 | ![vgg16 heatmaps](share/img/beacon_vgg16_various.webp) 121 | </details>
122 | 123 |
 124 | <details><summary>Heatmaps for ResNet50</summary> 125 | 126 | ![resnet50 heatmaps](share/img/beacon_resnet50_various.webp) 127 | </details>
128 | 129 | ## Contributing 130 | See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed instructions on how to contribute. 131 | 132 | ## License 133 | Zennit is licensed under the GNU LESSER GENERAL PUBLIC LICENSE VERSION 3 OR 134 | LATER -- see the [LICENSE](LICENSE), [COPYING](COPYING) and 135 | [COPYING.LESSER](COPYING.LESSER) files for details. 136 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 21 | 23 | 25 | 28 | 32 | 36 | 37 | 48 | 49 | 72 | 77 | 78 | 80 | 81 | 83 | image/svg+xml 84 | 86 | 87 | 88 | 89 | 94 | 101 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /docs/source/_templates/modules.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | :members: 5 | :show-inheritance: 6 | 7 | {% block attributes %} 8 | {% if attributes %} 9 | .. rubric:: {{ _('Module Attributes') }} 10 | 11 | .. autosummary:: 12 | :nosignatures: 13 | 14 | {% for item in attributes %} 15 | {{ item }} 16 | {%- endfor %} 17 | {% endif %} 18 | {% endblock %} 19 | 20 | {% block functions %} 21 | {% if functions %} 22 | .. rubric:: {{ _('Functions') }} 23 | 24 | .. autosummary:: 25 | :nosignatures: 26 | 27 | {% for item in functions %} 28 | {{ item }} 29 | {%- endfor %} 30 | {% endif %} 31 | {% endblock %} 32 | 33 | {% block classes %} 34 | {% if classes %} 35 | .. rubric:: {{ _('Classes') }} 36 | 37 | .. autosummary:: 38 | :nosignatures: 39 | 40 | {% for item in classes %} 41 | {{ item }} 42 | {%- endfor %} 43 | {% endif %} 44 | {% endblock %} 45 | 46 | {% block exceptions %} 47 | {% if exceptions %} 48 | .. rubric:: {{ _('Exceptions') }} 49 | 50 | .. autosummary:: 51 | :nosignatures: 52 | 53 | {% for item in exceptions %} 54 | {{ item }} 55 | {%- endfor %} 56 | {% endif %} 57 | {% endblock %} 58 | 59 | {% block modules %} 60 | {% if modules %} 61 | .. rubric:: Modules 62 | 63 | .. autosummary:: 64 | :toctree: 65 | :recursive: 66 | {% for item in modules %} 67 | {{ item }} 68 | {%- endfor %} 69 | {% endif %} 70 | {% endblock %} 71 | -------------------------------------------------------------------------------- /docs/source/bibliography.bib: -------------------------------------------------------------------------------- 1 | 2 | @inproceedings{zeiler2014visualizing, 3 | author = {Matthew D. 
Zeiler and 4 | Rob Fergus}, 5 | title = {Visualizing and Understanding Convolutional Networks}, 6 | booktitle = {Computer Vision - {ECCV} 2014 - 13th European Conference, Zurich, 7 | Switzerland, September 6-12, 2014, Proceedings, Part {I}}, 8 | series = {Lecture Notes in Computer Science}, 9 | volume = {8689}, 10 | pages = {818--833}, 11 | publisher = {Springer}, 12 | year = {2014}, 13 | url = {https://doi.org/10.1007/978-3-319-10590-1_53}, 14 | } 15 | 16 | 17 | @article{bach2015pixel, 18 | author = {Sebastian Bach and 19 | Alexander Binder and 20 | Gr{\'e}goire Montavon and 21 | Frederick Klauschen and 22 | Klaus-Robert M{\"u}ller and 23 | Wojciech Samek}, 24 | title = {On pixel-wise explanations for non-linear classifier decisions by 25 | layer-wise relevance propagation}, 26 | journal = {PloS one}, 27 | volume = {10}, 28 | number = {7}, 29 | pages = {e0130140}, 30 | year = {2015}, 31 | publisher = {Public Library of Science San Francisco, CA USA}, 32 | url = {https://doi.org/10.1371/journal.pone.0130140} 33 | } 34 | 35 | @inproceedings{springenberg2015striving, 36 | author = {Jost Tobias Springenberg and 37 | Alexey Dosovitskiy and 38 | Thomas Brox and 39 | Martin A. Riedmiller}, 40 | title = {Striving for Simplicity: The All Convolutional Net}, 41 | booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015, 42 | San Diego, CA, USA, May 7-9, 2015, Workshop Track Proceedings}, 43 | year = {2015}, 44 | url = {http://arxiv.org/abs/1412.6806}, 45 | } 46 | 47 | @inproceedings{zhang2016top, 48 | author = {Jianming Zhang and 49 | Zhe L. Lin and 50 | Jonathan Brandt and 51 | Xiaohui Shen and 52 | Stan Sclaroff}, 53 | title = {Top-Down Neural Attention by Excitation Backprop}, 54 | booktitle = {Computer Vision - {ECCV} 2016 - 14th European Conference, Amsterdam, 55 | The Netherlands, October 11-14, 2016, Proceedings, Part {IV}}, 56 | series = {Lecture Notes in Computer Science}, 57 | volume = {9908}, 58 | pages = {543--559}, 59 | publisher = {Springer}, 60 | year = {2016}, 61 | url = {https://doi.org/10.1007/978-3-319-46493-0_33}, 62 | } 63 | 64 | @article{montavon2017explaining, 65 | author = {Gr{\'{e}}goire Montavon and 66 | Sebastian Lapuschkin and 67 | Alexander Binder and 68 | Wojciech Samek and 69 | Klaus{-}Robert M{\"{u}}ller}, 70 | title = {Explaining nonlinear classification decisions with deep Taylor decomposition}, 71 | journal = {Pattern Recognit.}, 72 | volume = {65}, 73 | pages = {211--222}, 74 | year = {2017}, 75 | url = {https://doi.org/10.1016/j.patcog.2016.11.008}, 76 | } 77 | 78 | @inproceedings{sundararajan2017axiomatic, 79 | author = {Mukund Sundararajan and 80 | Ankur Taly and 81 | Qiqi Yan}, 82 | title = {Axiomatic Attribution for Deep Networks}, 83 | booktitle = {Proceedings of the 34th International Conference on Machine Learning, 84 | {ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017}, 85 | series = {Proceedings of Machine Learning Research}, 86 | volume = {70}, 87 | pages = {3319--3328}, 88 | publisher = {{PMLR}}, 89 | year = {2017}, 90 | url = {http://proceedings.mlr.press/v70/sundararajan17a.html}, 91 | } 92 | 93 | @article{smilkov2017smoothgrad, 94 | author = {Daniel Smilkov and 95 | Nikhil Thorat and 96 | Been Kim and 97 | Fernanda B. 
Vi{\'{e}}gas and 98 | Martin Wattenberg}, 99 | title = {SmoothGrad: removing noise by adding noise}, 100 | journal = {CoRR}, 101 | volume = {abs/1706.03825}, 102 | year = {2017}, 103 | url = {https://arxiv.org/abs/1706.03825}, 104 | } 105 | 106 | @article{DBLP:journals/corr/abs-1902-10178, 107 | author = {Sebastian Lapuschkin and 108 | Stephan W{\"{a}}ldchen and 109 | Alexander Binder and 110 | Gr{\'{e}}goire Montavon and 111 | Wojciech Samek and 112 | Klaus{-}Robert M{\"{u}}ller}, 113 | title = {Unmasking Clever Hans Predictors and Assessing What Machines Really 114 | Learn}, 115 | journal = {CoRR}, 116 | volume = {abs/1902.10178}, 117 | year = {2019}, 118 | url = {http://arxiv.org/abs/1902.10178}, 119 | } 120 | 121 | @article{lapuschkin2019unmasking, 122 | title = {Unmasking Clever Hans predictors and assessing what machines really learn}, 123 | author = {Sebastian Lapuschkin and 124 | Stephan W{\"a}ldchen and 125 | Alexander Binder and 126 | Gr{\'e}goire Montavon and 127 | Wojciech Samek and 128 | Klaus-Robert M{\"u}ller}, 129 | journal = {Nature communications}, 130 | volume = {10}, 131 | number = {1}, 132 | pages = {1--8}, 133 | year = {2019}, 134 | publisher = {Nature Publishing Group}, 135 | url = {https://doi.org/10.1038/s41467-019-08987-4} 136 | } 137 | 138 | 139 | @incollection{montavon2019layer, 140 | author = {Gr{\'{e}}goire Montavon and 141 | Alexander Binder and 142 | Sebastian Lapuschkin and 143 | Wojciech Samek and 144 | Klaus{-}Robert M{\"{u}}ller}, 145 | title = {Layer-Wise Relevance Propagation: An Overview}, 146 | booktitle = {Explainable {AI:} Interpreting, Explaining and Visualizing Deep Learning}, 147 | series = {Lecture Notes in Computer Science}, 148 | volume = {11700}, 149 | pages = {193--209}, 150 | publisher = {Springer}, 151 | year = {2019}, 152 | url = {https://doi.org/10.1007/978-3-030-28954-6_10}, 153 | } 154 | 155 | @inproceedings{dombrowski2019explanations, 156 | author = {Ann{-}Kathrin Dombrowski and 157 | Maximilian Alber and 158 | Christopher J. Anders and 159 | Marcel Ackermann and 160 | Klaus{-}Robert M{\"{u}}ller and 161 | Pan Kessel}, 162 | title = {Explanations can be manipulated and geometry is to blame}, 163 | booktitle = {Advances in Neural Information Processing Systems 32: Annual Conference 164 | on Neural Information Processing Systems 2019, NeurIPS 2019, December 165 | 8-14, 2019, Vancouver, BC, Canada}, 166 | pages = {13567--13578}, 167 | year = {2019}, 168 | url = {https://proceedings.neurips.cc/paper/2019/hash/bb836c01cdc9120a9c984c525e4b1a4a-Abstract.html}, 169 | } 170 | 171 | @inproceedings{anders2020fairwashing, 172 | author = {Christopher J. Anders and 173 | Plamen Pasliev and 174 | Ann{-}Kathrin Dombrowski and 175 | Klaus{-}Robert M{\"{u}}ller and 176 | Pan Kessel}, 177 | title = {Fairwashing explanations with off-manifold detergent}, 178 | booktitle = {Proceedings of the 37th International Conference on Machine Learning, 179 | {ICML} 2020, 13-18 July 2020, Virtual Event}, 180 | series = {Proceedings of Machine Learning Research}, 181 | volume = {119}, 182 | pages = {314--323}, 183 | publisher = {{PMLR}}, 184 | year = {2020}, 185 | url = {http://proceedings.mlr.press/v119/anders20a.html}, 186 | } 187 | 188 | @article{anders2021software, 189 | author = {Christopher J. 
Anders and 190 | David Neumann and 191 | Wojciech Samek and 192 | Klaus{-}Robert M{\"{u}}ller and 193 | Sebastian Lapuschkin}, 194 | title = {Software for Dataset-wide {XAI:} From Local Explanations to Global 195 | Insights with Zennit, CoRelAy, and ViRelAy}, 196 | journal = {CoRR}, 197 | volume = {abs/2106.13200}, 198 | year = {2021}, 199 | url = {https://arxiv.org/abs/2106.13200}, 200 | } 201 | 202 | @article{andeol2021learning, 203 | author = {L{\'{e}}o And{\'{e}}ol and 204 | Yusei Kawakami and 205 | Yuichiro Wada and 206 | Takafumi Kanamori and 207 | Klaus{-}Robert M{\"{u}}ller and 208 | Gr{\'{e}}goire Montavon}, 209 | title = {Learning Domain Invariant Representations by Joint Wasserstein Distance 210 | Minimization}, 211 | journal = {CoRR}, 212 | volume = {abs/2106.04923}, 213 | year = {2021}, 214 | url = {https://arxiv.org/abs/2106.04923}, 215 | } 216 | -------------------------------------------------------------------------------- /docs/source/bibliography.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Bibliography 3 | ============ 4 | 5 | .. bibliography:: 6 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | import sys 17 | import os 18 | from subprocess import run, CalledProcessError 19 | import inspect 20 | import pkg_resources 21 | 22 | from pybtex.style.formatting.plain import Style as PlainStyle 23 | from pybtex.style.labels import BaseLabelStyle 24 | from pybtex.plugin import register_plugin 25 | 26 | 27 | # -- Project information ----------------------------------------------------- 28 | project = 'zennit' 29 | copyright = '2021, chr5tphr' 30 | author = 'chr5tphr' 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be 36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 37 | # ones. 38 | extensions = [ 39 | 'sphinx.ext.autodoc', 40 | 'sphinx.ext.autosummary', 41 | 'sphinx.ext.napoleon', 42 | 'sphinx.ext.linkcode', 43 | 'sphinx.ext.mathjax', 44 | 'sphinx.ext.extlinks', 45 | 'sphinx_rtd_theme', 46 | 'sphinx_copybutton', 47 | 'sphinxcontrib.datatemplates', 48 | 'sphinxcontrib.bibtex', 49 | 'nbsphinx', 50 | ] 51 | 52 | 53 | def config_inited_handler(app, config): 54 | os.makedirs(os.path.join(app.srcdir, app.config.generated_path), exist_ok=True) 55 | 56 | 57 | def setup(app): 58 | app.add_config_value('REVISION', 'master', 'env') 59 | app.add_config_value('generated_path', '_generated', 'env') 60 | app.connect('config-inited', config_inited_handler) 61 | 62 | 63 | # Add any paths that contain templates here, relative to this directory. 
64 | templates_path = ['_templates'] 65 | 66 | # List of patterns, relative to source directory, that match files and 67 | # directories to ignore when looking for source files. 68 | # This pattern also affects html_static_path and html_extra_path. 69 | exclude_patterns = [] 70 | 71 | # interactive badges for binder and colab 72 | nbsphinx_prolog = r""" 73 | {% set docname = 'docs/source/' + env.doc2path(env.docname, base=False) %} 74 | 75 | .. raw:: html 76 | 77 |
78 | This page was generated from 79 | {{ docname|e }} 80 |
81 | Interactive online version: 82 | 83 | 84 | launch binder 85 | 86 | 87 | 88 | 89 | Open in Colab 90 | 91 | 92 |
93 | """ 94 | 95 | # autosummary_generate = True 96 | 97 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 98 | copybutton_prompt_is_regexp = True 99 | copybutton_line_continuation_character = "\\" 100 | copybutton_here_doc_delimiter = "EOT" 101 | 102 | # -- Options for HTML output ------------------------------------------------- 103 | 104 | # The theme to use for HTML and HTML Help pages. See the documentation for 105 | # a list of builtin themes. 106 | # 107 | html_theme = 'sphinx_rtd_theme' 108 | # html_theme = 'alabaster' 109 | 110 | html_favicon = '_static/favicon.svg' 111 | 112 | # Add any paths that contain custom static files (such as style sheets) here, 113 | # relative to this directory. They are copied after the builtin static files, 114 | # so a file named "default.css" will overwrite the builtin "default.css". 115 | html_static_path = ['_static'] 116 | 117 | bibtex_bibfiles = ['bibliography.bib'] 118 | bibtex_default_style = 'author_year_style' 119 | bibtex_reference_style = 'author_year' 120 | 121 | 122 | class AuthorYearLabelStyle(BaseLabelStyle): 123 | def format_labels(self, sorted_entries): 124 | for entry in sorted_entries: 125 | yield f'[{entry.persons["author"][0].last_names[0]} et al., {entry.fields["year"]}]' 126 | 127 | 128 | class AuthorYearStyle(PlainStyle): 129 | default_label_style = AuthorYearLabelStyle 130 | 131 | 132 | register_plugin('pybtex.style.formatting', 'author_year_style', AuthorYearStyle) 133 | 134 | 135 | def getrev(): 136 | try: 137 | revision = run( 138 | ['git', 'describe', '--tags', 'HEAD'], 139 | capture_output=True, 140 | check=True, 141 | text=True 142 | ).stdout[:-1] 143 | except CalledProcessError: 144 | revision = 'master' 145 | 146 | return revision 147 | 148 | 149 | REVISION = getrev() 150 | 151 | extlinks = { 152 | 'repo': ( 153 | f'https://github.com/chr5tphr/zennit/blob/{REVISION}/%s', 154 | '%s' 155 | ) 156 | } 157 | 158 | LINKCODE_URL = ( 159 | f'https://github.com/chr5tphr/zennit/blob/{REVISION}' 160 | '/src/{filepath}#L{linestart}-L{linestop}' 161 | ) 162 | 163 | 164 | # revised from https://gist.github.com/nlgranger/55ff2e7ff10c280731348a16d569cb73 165 | def linkcode_resolve(domain, info): 166 | if domain != 'py' or not info['module']: 167 | return None 168 | 169 | modname = info['module'] 170 | topmodulename = modname.split('.')[0] 171 | fullname = info['fullname'] 172 | 173 | submod = sys.modules.get(modname) 174 | if submod is None: 175 | return None 176 | 177 | obj = submod 178 | for part in fullname.split('.'): 179 | try: 180 | obj = getattr(obj, part) 181 | except Exception: 182 | return None 183 | 184 | try: 185 | modpath = pkg_resources.require(topmodulename)[0].location 186 | filepath = os.path.relpath(inspect.getsourcefile(obj), modpath) 187 | if filepath is None: 188 | return 189 | except Exception: 190 | return None 191 | 192 | try: 193 | source, lineno = inspect.getsourcelines(obj) 194 | except OSError: 195 | return None 196 | else: 197 | linestart, linestop = lineno, lineno + len(source) - 1 198 | 199 | return LINKCODE_URL.format(filepath=filepath, linestart=linestart, linestop=linestop) 200 | -------------------------------------------------------------------------------- /docs/source/getting-started.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | Getting started 3 | ================ 4 | 5 | 6 | Install 7 | ------- 8 | 9 | Zennit can be installed directly from PyPI: 10 | 11 | .. 
code-block:: console 12 | 13 | $ pip install zennit 14 | 15 | For the current development version, or to try out examples, Zennit may be 16 | alternatively cloned and installed with 17 | 18 | .. code-block:: console 19 | 20 | $ git clone https://github.com/chr5tphr/zennit.git 21 | $ pip install ./zennit 22 | 23 | Basic Usage 24 | ----------- 25 | 26 | Zennit implements propagation-based attribution methods by overwriting the 27 | gradient of PyTorch modules in PyTorch's auto-differentiation engine. This means 28 | that Zennit will only work on models which are strictly implemented using 29 | PyTorch modules, including activation functions. The following demonstrates a 30 | setup to compute Layer-wise Relevance Propagation (LRP) relevance for a simple 31 | model and random data. 32 | 33 | .. code-block:: python 34 | 35 | import torch 36 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten 37 | 38 | 39 | # setup the model and data 40 | model = Sequential( 41 | Conv2d(3, 10, 3, padding=1), 42 | ReLU(), 43 | Flatten(), 44 | Linear(10 * 32 * 32, 10), 45 | ) 46 | input = torch.randn(1, 3, 32, 32) 47 | 48 | The most important high-level structures in Zennit are ``Composites``, 49 | ``Attributors`` and ``Canonizers``. 50 | 51 | 52 | Composites 53 | ^^^^^^^^^^ 54 | 55 | Composites map ``Rules`` to modules based on their properties and context to 56 | modify their gradient. The most common composites for LRP are implemented in 57 | :py:mod:`zennit.composites`. 58 | 59 | The following computes LRP relevance using the ``EpsilonPlusFlat`` composite: 60 | 61 | .. code-block:: python 62 | 63 | from zennit.composites import EpsilonPlusFlat 64 | 65 | 66 | # create a composite instance 67 | composite = EpsilonPlusFlat() 68 | 69 | # use the following instead to ignore bias for the relevance 70 | # composite = EpsilonPlusFlat(zero_params='bias') 71 | 72 | # make sure the input requires a gradient 73 | input.requires_grad = True 74 | 75 | # compute the output and gradient within the composite's context 76 | with composite.context(model) as modified_model: 77 | output = modified_model(input) 78 | # gradient/ relevance wrt. class/output 0 79 | output.backward(gradient=torch.eye(10)[[0]]) 80 | # relevance is not accumulated in .grad if using torch.autograd.grad 81 | # relevance, = torch.autograd.grad(output, input, torch.eye(10)[[0]) 82 | 83 | # gradient is accumulated in input.grad 84 | print('Backward:', input.grad) 85 | 86 | 87 | The context created by :py:func:`zennit.core.Composite.context` registers the 88 | composite, which means that all rules are applied according to the composite's 89 | mapping. See :doc:`/how-to/use-rules-composites-and-canonizers` for information on 90 | using composites, :py:mod:`zennit.composites` for an API reference and 91 | :doc:`/how-to/write-custom-composites` for writing new compositors. Available 92 | ``Rules`` can be found in :py:mod:`zennit.rules`, their use is described in 93 | :doc:`/how-to/use-rules-composites-and-canonizers` and how to add new ones is described in 94 | :doc:`/how-to/write-custom-rules`. 95 | 96 | Attributors 97 | ^^^^^^^^^^^ 98 | 99 | Alternatively, *attributors* may be used instead of ``composite.context``. 100 | 101 | .. code-block:: python 102 | 103 | from zennit.attribution import Gradient 104 | 105 | 106 | attributor = Gradient(model, composite) 107 | 108 | with attributor: 109 | # gradient/ relevance wrt. 
output/class 1 110 | output, relevance = attributor(input, torch.eye(10)[[1]]) 111 | 112 | print('EpsilonPlusFlat:', relevance) 113 | 114 | Attribution methods which are not propagation-based, like 115 | :py:class:`zennit.attribution.SmoothGrad` are implemented as attributors, and 116 | may be combined with propagation-based (composite) approaches. 117 | 118 | .. code-block:: python 119 | 120 | from zennit.attribution import SmoothGrad 121 | 122 | 123 | # we do not need a composite to compute vanilla SmoothGrad 124 | with SmoothGrad(model, noise_level=0.1, n_iter=10) as attributor: 125 | # gradient/ relevance wrt. output/class 7 126 | output, relevance = attributor(input, torch.eye(10)[[7]]) 127 | 128 | print('SmoothGrad:', relevance) 129 | 130 | More information on attributors can be found in :doc:`/how-to/use-attributors` 131 | and :doc:`/how-to/write-custom-attributors`. 132 | 133 | Canonizers 134 | ^^^^^^^^^^ 135 | 136 | For some modules and operations, Layer-wise Relevance Propagation (LRP) is not 137 | implementation-invariant, eg. ``BatchNorm -> Dense -> ReLU`` will be attributed 138 | differently than ``Dense -> BatchNorm -> ReLU``. Therefore, LRP needs a 139 | canonical form of the model, which is implemented in ``Canonizers``. These may 140 | be simply supplied when instantiating a composite: 141 | 142 | .. code-block:: python 143 | 144 | from torchvision.models import vgg16 145 | from zennit.composites import EpsilonGammaBox 146 | from zennit.torchvision import VGGCanonizer 147 | 148 | 149 | # instantiate the model 150 | model = vgg16() 151 | # create the canonizers 152 | canonizers = [VGGCanonizer()] 153 | # EpsilonGammaBox needs keyword arguments 'low' and 'high' 154 | composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers) 155 | 156 | with Gradient(model, composite) as attributor: 157 | # gradient/ relevance wrt. output/class 0 158 | # torchvision.vgg16 has 1000 output classes by default 159 | output, relevance = attributor(input, torch.eye(1000)[[0]]) 160 | 161 | print('EpsilonGammaBox:', relevance) 162 | 163 | Some pre-defined canonizers for models from ``torchvision`` can be found in 164 | :py:mod:`zennit.torchvision`. The :py:class:`zennit.torchvision.VGGCanonizer` 165 | specifically is simply :py:class:`zennit.canonizers.SequentialMergeBatchNorm`, 166 | which may be used when ``BatchNorm`` is used in sequential models. Note that for 167 | ``SequentialMergeBatchNorm`` to work, all functions (linear layers, activations, 168 | ...) must be modules and assigned to their parent module in the order they are 169 | visited (see :py:class:`zennit.canonizers.SequentialMergeBatchNorm`). For more 170 | information on canonizers see :doc:`/how-to/use-rules-composites-and-canonizers` and 171 | :doc:`/how-to/write-custom-canonizers`. 172 | 173 | 174 | Visualizing Results 175 | ^^^^^^^^^^^^^^^^^^^ 176 | 177 | While attribution approaches are not limited to the domain of images, they are 178 | predominantly used on image models and produce heat maps of relevance. For 179 | this reason, Zennit implements methods to visualize relevance heat maps. 180 | 181 | .. 
code-block:: python 182 | 183 | from zennit.image import imsave 184 | 185 | 186 | # sum over the color channels 187 | heatmap = relevance.sum(1) 188 | # get the absolute maximum, to center the heat map around 0 189 | amax = heatmap.abs().numpy().max((1, 2)) 190 | 191 | # save heat map with color map 'coldnhot' 192 | imsave( 193 | 'heatmap.png', 194 | heatmap[0], 195 | vmin=-amax, 196 | vmax=amax, 197 | cmap='coldnhot', 198 | level=1.0, 199 | grid=False 200 | ) 201 | 202 | Information on ``imsave`` can be found at :py:func:`zennit.image.imsave`. 203 | Saving an image with 3 color channels will result in the image being saved 204 | without a color map but with the channels assumed as RGB. The keyword argument 205 | ``grid`` will create a grid of multiple images over the batch dimension if 206 | ``True``. Custom color maps may be created with 207 | :py:class:`zennit.cmap.ColorMap`, eg. to save the previous image with a color 208 | map ranging from blue to yellow to red: 209 | 210 | .. code-block:: python 211 | 212 | from zennit.cmap import ColorMap 213 | 214 | 215 | # 00f is blue, ff0 is yellow, f00 is red, 0x80 is the center of the range 216 | cmap = ColorMap('00f,80:ff0,f00') 217 | 218 | imsave( 219 | 'heatmap.png', 220 | heatmap, 221 | vmin=-amax, 222 | vmax=amax, 223 | cmap=cmap, 224 | level=1.0, 225 | grid=True 226 | ) 227 | 228 | More details to visualize heat maps and color maps can be found in 229 | :doc:`/how-to/visualize-results`. The ColorMap specification language is 230 | described in :py:class:`zennit.cmap.ColorMap` and built-in color maps are 231 | implemented in :py:obj:`zennit.image.CMAPS`. 232 | 233 | Example Script 234 | -------------- 235 | 236 | A ready-to use example to analyze a few ImageNet models provided by torchvision 237 | can be found at :repo:`share/example/feed_forward.py`. 238 | 239 | The following setup requires bash, cURL and (magic-)file. 240 | 241 | Create a virtual environment, install Zennit and download the example scripts: 242 | 243 | .. code-block:: console 244 | 245 | $ mkdir zennit-example 246 | $ cd zennit-example 247 | $ python -m venv .venv 248 | $ .venv/bin/pip install zennit 249 | $ curl -o feed_forward.py \ 250 | 'https://raw.githubusercontent.com/chr5tphr/zennit/master/share/example/feed_forward.py' 251 | $ curl -o download-lighthouses.sh \ 252 | 'https://raw.githubusercontent.com/chr5tphr/zennit/master/share/scripts/download-lighthouses.sh' 253 | 254 | Prepare the data required for the example: 255 | 256 | .. code-block:: console 257 | 258 | $ mkdir params data results 259 | $ bash download-lighthouses.sh --output data/lighthouses 260 | $ curl -o params/vgg16-397923af.pth 'https://download.pytorch.org/models/vgg16-397923af.pth' 261 | 262 | This creates the needed directories and downloads the pre-trained vgg16 263 | parameters and 8 images of light houses from wikimedia commons into the 264 | required label-directory structure for the imagenet dataset in PyTorch. 265 | 266 | The ``feed_forward.py`` example can then be run using: 267 | 268 | .. 
code-block:: console 269 | 270 | $ .venv/bin/python feed_forward.py \ 271 | data/lighthouses \ 272 | 'results/vgg16_epsilon_gamma_box_{sample:02d}.png' \ 273 | --inputs 'results/vgg16_input_{sample:02d}.png' \ 274 | --parameters params/vgg16-397923af.pth \ 275 | --model vgg16 \ 276 | --composite epsilon_gamma_box \ 277 | --no-bias \ 278 | --relevance-norm symmetric \ 279 | --cmap coldnhot 280 | 281 | which computes the lrp heatmaps according to the ``epsilon_gamma_box`` rule and 282 | stores them in results, along with the respective input images. Other possible 283 | composites that can be passed to ``--composites`` are, e.g., ``epsilon_plus``, 284 | ``epsilon_alpha2_beta1_flat``, ``guided_backprop``, ``excitation_backprop``. 285 | The bias can be ignored in the LRP-computation by passing ``--no-bias``. 286 | 287 | 288 | .. 289 | The resulting heatmaps may look like the following: 290 | 291 | .. image:: /img/beacon_vgg16_epsilon_gamma_box.png 292 | :alt: Lighthouses with Attributions 293 | 294 | Alternatively, heatmaps for SmoothGrad with absolute relevances may be computed 295 | by omitting ``--composite`` and supplying ``--attributor``: 296 | 297 | .. code-block:: console 298 | 299 | $ .venv/bin/python feed_forward.py \ 300 | data/lighthouses \ 301 | 'results/vgg16_smoothgrad_{sample:02d}.png' \ 302 | --inputs 'results/vgg16_input_{sample:02d}.png' \ 303 | --parameters params/vgg16-397923af.pth \ 304 | --model vgg16 \ 305 | --attributor smoothgrad \ 306 | --relevance-norm absolute \ 307 | --cmap hot 308 | 309 | For Integrated Gradients, ``--attributor integrads`` may be provided. 310 | 311 | Heatmaps for Occlusion Analysis with unaligned relevances may be computed by 312 | executing: 313 | 314 | .. code-block:: console 315 | 316 | $ .venv/bin/python feed_forward.py \ 317 | data/lighthouses \ 318 | 'results/vgg16_occlusion_{sample:02d}.png' \ 319 | --inputs 'results/vgg16_input_{sample:02d}.png' \ 320 | --parameters params/vgg16-397923af.pth \ 321 | --model vgg16 \ 322 | --attributor occlusion \ 323 | --relevance-norm unaligned \ 324 | --cmap hot 325 | 326 | -------------------------------------------------------------------------------- /docs/source/how-to/compute-second-order-gradients.rst: -------------------------------------------------------------------------------- 1 | ================================ 2 | Computing Second Order Gradients 3 | ================================ 4 | 5 | Sometimes, it may be necessary to compute the gradient of the attribution. One 6 | example is to compute the gradient with respect to the input in order to 7 | find adversarial explanations :cite:p:`dombrowski2019explanations`, 8 | or to regularize or transform the attributions of a network 9 | :cite:p:`anders2020fairwashing`. 10 | 11 | In Zennit, the attribution is computed using the modified gradient, which means 12 | that in order to compute the gradient of the attribution, the second order 13 | gradient needs to be computed. Pytorch natively supports the computation of 14 | higher order gradients, simply by supplying ``create_graph=True`` with 15 | :py:func:`torch.autograd.grad` to declare that the backward-function needs to 16 | be backward-able itself. 17 | 18 | 19 | Vanilla Gradient and ReLU 20 | ------------------------- 21 | 22 | If we simply need the second order gradient of a model, without using Zennit, we can do the following: 23 | 24 | .. 
code-block:: python 25 | 26 | import torch 27 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten 28 | 29 | 30 | # setup the model and data 31 | model = Sequential( 32 | Conv2d(3, 10, 3, padding=1), 33 | ReLU(), 34 | Flatten(), 35 | Linear(10 * 32 * 32, 10), 36 | ) 37 | input = torch.randn(1, 3, 32, 32) 38 | 39 | # make sure the input requires a gradient 40 | input.requires_grad = True 41 | 42 | output = model(input) 43 | # a vector for the vector-jacobian-product, i.e. the grad_output 44 | target = torch.ones_like(output) 45 | 46 | grad, = torch.autograd.grad(output, input, target, create_graph=True) 47 | 48 | # the grad_output for grad 49 | gradtarget = torch.ones_like(grad) 50 | # compute the second order gradient 51 | gradgrad, = torch.autograd.grad(grad, input, gradtarget) 52 | 53 | Here, you might notice that ``gradgrad`` is all zeros, regardless of the input 54 | and model parameters. The culprit is ``ReLU``, which has a gradient of zero 55 | everywhere except at zero, where it is undefined. In order to get a meaningful 56 | gradient, we could instead use a *smooth* activation function in our model. 57 | However, ReLU models are quite common, and we may not like to retrain every 58 | model using only smooth activation functions. 59 | 60 | :cite:t:`dombrowski2019explanations` proposed to replace the ReLU activations 61 | with its smooth variation, the *Softplus* function: 62 | 63 | .. math:: 64 | 65 | \text{Softplus}(x;\beta) = \frac{1}{\beta} \log (1 + \exp (\beta x)) 66 | \,\text{.} 67 | 68 | With :math:`\beta\rightarrow\infty`, Softplus will be equivalent to ReLU, but in 69 | practice choosing :math:`\beta = 10` is most often sufficient to keep the model 70 | output unchanged but still obtain a meaningful second order gradient. 71 | 72 | To temporarily replace the ReLU gradients in-place, we can use the 73 | :py:class:`~zennit.rules.ReLUBetaSmooth` rule: 74 | 75 | 76 | .. code-block:: python 77 | 78 | from zennit.composites import BetaSmooth 79 | 80 | # LayerMapComposite which assigns the ReLUBetaSmooth hook to ReLUs 81 | composite = BetaSmooth(beta_smooth=10.) 82 | 83 | with composite.context(model): 84 | output = model(input) 85 | target = torch.ones_like(output) 86 | grad, = torch.autograd.grad(output, input, target, create_graph=True) 87 | 88 | gradtarget = torch.ones_like(grad) 89 | gradgrad, = torch.autograd.grad(grad, input, gradtarget) 90 | 91 | Notice here that we computed the second order gradient **outside** of the 92 | composite context. A property of the Pytorch gradients hooks is that they are 93 | also called when the *second* order gradient with respect to a tensor is 94 | computed. 95 | Due to this, computing the second order gradient *while rules are still 96 | registered* will lead to incorrect results. 97 | 98 | Temporarily Disabling Hooks 99 | --------------------------- 100 | 101 | In order compute the second order gradient *without* removing the hooks (i.e. to 102 | compute multiple values in a loop), we can temporarily deactivate them using 103 | :py:meth:`zennit.core.Composite.inactive`: 104 | 105 | .. 
code-block:: python 106 | 107 | with composite.context(model): 108 | output = model(input) 109 | target = torch.ones_like(output) 110 | grad, = torch.autograd.grad(output, input, target, create_graph=True) 111 | 112 | # temporarily disable all hooks registered by composite 113 | with composite.inactive(): 114 | gradtarget = torch.ones_like(grad) 115 | gradgrad, = torch.autograd.grad(grad, input, gradtarget) 116 | 117 | All Attributors support the computation of gradients. For gradient-based 118 | attributors like :py:class:`~zennit.attribution.Gradient` or 119 | :py:class:`~zennit.attribution.SmoothGrad`, the ``create_graph=True`` parameter 120 | can be supplied to the class constructor: 121 | 122 | .. code-block:: python 123 | 124 | from zennit.attribution import Gradient 125 | from zennit.composites import EpsilonGammaBox 126 | 127 | # any composites support second order gradients 128 | composite = EpsilonGammaBox(low=-3., high=3.) 129 | 130 | with Gradient(model, composite, create_graph=True) as attributor: 131 | output, grad = attributor(input, torch.ones_like) 132 | 133 | # temporarily disable all hooks registered by the attributor's composite 134 | with attributor.inactive(): 135 | gradtarget = torch.ones_like(grad) 136 | gradgrad, = torch.autograd.grad(grad, input, gradtarget) 137 | 138 | Here, we also used a different composite, which results in the gradient 139 | computation of the modified gradient. Since the ReLU gradient is ignored (using 140 | the :py:class:`~zennit.rules.Pass` rule) for Layer-wise Relevance 141 | Propagation-specific composites, we do not need to use the 142 | :py:class:`~zennit.rules.ReLUBetaSmooth` rule. However, if this behaviour 143 | should be overwritten, :ref:`cooperative-layermapcomposites` can be used. 144 | 145 | Using Hooks Only 146 | ---------------- 147 | 148 | Under the hood, :py:class:`~zennit.core.Hook` has an attribute ``active``, 149 | which, when set to ``False``, will not execute the associated backward function. 150 | A minimal example without using composites would look like the following: 151 | 152 | .. code-block:: python 153 | 154 | from zennit.rules import Epsilon 155 | 156 | conv = Conv2d(3, 10, 3, padding=1) 157 | 158 | # create and register the hook 159 | epsilon = Epsilon() 160 | handles = epsilon.register(conv) 161 | 162 | output = conv(input) 163 | target = torch.ones_like(output) 164 | grad, = torch.autograd.grad(output, input, target, create_graph=True) 165 | 166 | # during this block, epsilon will be inactive 167 | epsilon.active = False 168 | grad_target = torch.ones_like(grad) 169 | gradgrad, = torch.autograd.grad(grad, input, grad_target) 170 | epsilon.active = True 171 | 172 | # after calling handles.remove, epsilon will also be inactive 173 | handles.remove() 174 | 175 | The same can here also be achieved by simply removing the handles before calling 176 | ``torch.autograd.grad`` on ``grad``, although the hooks would then need to be 177 | re-registered in order to compute the epsilon-modified gradient again. 178 | -------------------------------------------------------------------------------- /docs/source/how-to/get-intermediate-relevance.rst: -------------------------------------------------------------------------------- 1 | ============================== 2 | Getting Intermediate Relevance 3 | ============================== 4 | 5 | In some cases, intermediate gradients or relevances of a model may be needed. 
6 | Since Zennit uses PyTorch's autograd engine, intermediate relevances are simply 7 | the intermediate gradients of accessible non-leaf tensors; they can be kept 8 | in the tensor's ``.grad`` attribute by calling ``tensor.retain_grad()`` before 9 | the gradient computation. 10 | 11 | In most cases when using ``torch.nn.Module``-based models, the intermediate 12 | outputs are not easily accessible, which we can solve by using forward-hooks. 13 | 14 | We create the following setting with some random input data and a simple, randomly 15 | initialized model, for which we want to compute the LRP EpsilonPlus relevance: 16 | 17 | .. code-block:: python 18 | 19 | import torch 20 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten 21 | 22 | from zennit.attribution import Gradient 23 | from zennit.composites import EpsilonPlusFlat 24 | 25 | # setup the model and data 26 | model = Sequential( 27 | Conv2d(3, 10, 3, padding=1), 28 | ReLU(), 29 | Flatten(), 30 | Linear(10 * 32 * 32, 10), 31 | ) 32 | input = torch.randn(1, 3, 32, 32) 33 | 34 | # make sure the input requires a gradient 35 | input.requires_grad = True 36 | 37 | # create a composite instance 38 | composite = EpsilonPlusFlat() 39 | 40 | # create a gradient attributor 41 | attributor = Gradient(model, composite) 42 | 43 | Now we create a function ``store_hook`` which we register as a forward hook to 44 | all modules. The function sets the module's attribute ``.output`` to its output 45 | tensor, and ensures the gradient is stored in the tensor's ``.grad`` attribute 46 | even if it is not a leaf-tensor by using ``.retain_grad()``. 47 | 48 | .. code-block:: python 49 | 50 | # create a hook to keep track of intermediate outputs 51 | def store_hook(module, input, output): 52 | # set the current module's attribute 'output' to its output tensor 53 | module.output = output 54 | # keep the output tensor gradient, even if it is not a leaf-tensor 55 | output.retain_grad() 56 | 57 | # enter the attributor's context to register the rule-hooks 58 | with attributor: 59 | # register the store_hook AFTER the rule-hooks have been registered (by 60 | # entering the context) so we get the last output before the next module 61 | handles = [ 62 | module.register_forward_hook(store_hook) for module in model.modules() 63 | ] 64 | # compute the relevance wrt. output/class 1 65 | output, relevance = attributor(input, torch.eye(10)[[1]]) 66 | 67 | # remove the hooks using store_hook 68 | for handle in handles: 69 | handle.remove() 70 | 71 | # print the gradient tensors for demonstration 72 | for name, module in model.named_modules(): 73 | print(f'{name}: {module.output.grad}') 74 | 75 | The hooks are registered within the attributor's with-context, such that they 76 | are applied after the rule hooks. Once we are finished, we can remove the 77 | store-hooks by calling ``.remove()`` on all handles returned when registering the 78 | hooks. 79 | 80 | Be aware that storing the intermediate outputs and their gradients may require 81 | significantly more memory, depending on the model. In practice, it may be better 82 | to register the store-hook only to modules for which the relevance is needed. 83 | -------------------------------------------------------------------------------- /docs/source/how-to/index.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | How-Tos 3 | ================ 4 | 5 | 6 | These How-Tos give more detailed information on how to use Zennit. 7 | 8 | ..
toctree:: 9 | :maxdepth: 1 10 | 11 | use-rules-composites-and-canonizers 12 | use-attributors 13 | visualize-results 14 | get-intermediate-relevance 15 | compute-second-order-gradients 16 | write-custom-composites 17 | write-custom-canonizers 18 | write-custom-rules 19 | write-custom-attributors 20 | -------------------------------------------------------------------------------- /docs/source/how-to/use-attributors.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Using Attributors 3 | ================= 4 | 5 | **Attributors** are used to both shorten Zennit's common ``composite.context -> 6 | gradient`` approach, as well as provide model-agnostic attribution approaches. 7 | Available **Attributors** can be found in :py:mod:`zennit.attribution`, some of 8 | which are: 9 | 10 | * :py:class:`~zennit.attribution.Gradient`, which computes the gradient 11 | * :py:class:`~zennit.attribution.IntegratedGradients`, which computes the 12 | Integrated Gradients 13 | * :py:class:`~zennit.attribution.SmoothGrad`, which computes SmoothGrad 14 | * :py:class:`~zennit.attribution.Occlusion`, which computes the attribution 15 | based on the model output activation values when occluding parts of the input 16 | with a sliding window 17 | 18 | Using the basic :py:class:`~zennit.attribution.Gradient`, the unmodified 19 | gradient may be computed with: 20 | 21 | .. code-block:: python 22 | 23 | import torch 24 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten 25 | from zennit.attribution import Gradient 26 | 27 | # setup the model 28 | model = Sequential( 29 | Conv2d(3, 8, 3, padding=1), 30 | ReLU(), 31 | Conv2d(8, 16, 3, padding=1), 32 | ReLU(), 33 | Flatten(), 34 | Linear(16 * 32 * 32, 1024), 35 | ReLU(), 36 | Linear(1024, 10), 37 | ) 38 | # some random input data 39 | input = torch.randn(1, 3, 32, 32, requires_grad=True) 40 | 41 | # compute the gradient and output using the Gradient attributor 42 | with Gradient(model) as attributor: 43 | output, relevance = attributor(input) 44 | 45 | Computing attributions using a composite can be done with: 46 | 47 | .. code-block:: python 48 | 49 | from zennit.composites import EpsilonPlusFlat 50 | 51 | # prepare the composite 52 | composite = EpsilonPlusFlat() 53 | 54 | # compute the gradient within the composite's context, i.e. the 55 | # EpsilonPlusFlat LRP relevance 56 | with Gradient(model, composite) as attributor: 57 | # torch.eye is used here to get a one-hot encoding of the 58 | # first (index 0) label 59 | output, relevance = attributor(input, torch.eye(10)[[0]]) 60 | 61 | which uses the second argument ``attr_output_fn`` of the call to 62 | :py:class:`~zennit.attribution.Attributor` to specify a constant tensor used for 63 | the *output relevance* (i.e. ``grad_output``), but alternatively, a function 64 | of the output may also be used: 65 | 66 | .. code-block:: python 67 | 68 | def one_hot_max(output): 69 | '''Get the one-hot encoded max at the original indices in dim=1''' 70 | values, indices = output.max(1) 71 | return values[:, None] * torch.eye(output.shape[1])[indices] 72 | 73 | with Gradient(model) as attributor: 74 | output, relevance = attributor(input, one_hot_max) 75 | 76 | The constructor of :py:class:`~zennit.attribution.Attributor` also has a third 77 | argument ``attr_output``, which also can either be a constant 78 | :py:class:`~torch.Tensor`, or a function of the model's output and specifies 79 | which *output relevance* (i.e. 
``grad_output``) should be used by default. When 80 | not supplying anything, the default will be the *identity*. If the default 81 | should be for example ones for all outputs, one could write: 82 | 83 | .. code-block:: python 84 | 85 | # compute the gradient and output using the Gradient attributor, and with 86 | # a vector of ones as grad_output 87 | with Gradient(model, attr_output=torch.ones_like) as attributor: 88 | output, relevance = attributor(input) 89 | 90 | Gradient-based **Attributors** like 91 | :py:class:`~zennit.attribution.IntegratedGradients` and 92 | :py:class:`~zennit.attribution.SmoothGrad` may also be used together with 93 | composites to produce *hybrid attributions*: 94 | 95 | .. code-block:: python 96 | 97 | from zennit.attribution import SmoothGrad 98 | 99 | # prepare the composite 100 | composite = EpsilonPlusFlat() 101 | 102 | # do a *smooth* version of EpsilonPlusFlat LRP by using the SmoothGrad 103 | # attributor in combination with the composite 104 | with SmoothGrad(model, composite, noise_level=0.1, n_iter=20) as attributor: 105 | output, relevance = attributor(input, torch.eye(10)[[0]]) 106 | 107 | which in this case will sample 20 samples in an epsilon-ball (size controlled 108 | with `noise_level`) around the input. Note that for Zennit's implementation of 109 | :py:class:`~zennit.attribution.SmoothGrad`, the first sample will always be the 110 | original input, i.e. ``SmoothGrad(model, n_iter=1)`` will produce the plain 111 | gradient as ``Gradient(model)`` would. 112 | 113 | :py:class:`~zennit.attribution.Occlusion` will move a sliding window with 114 | arbitrary size and strides over an input with any dimensionality. In addition to 115 | specifying window-size and strides, a function may be specified, which will be 116 | supplied with the input and a mask. When using the default, everything within 117 | the sliding window will be set to zero. A function 118 | :py:func:`zennit.attribution.occlude_independent` is available to simplify the 119 | process of specifying how to fill the window, and to invert the window if 120 | desired. The following adds some gaussian noise to the area within the sliding 121 | window: 122 | 123 | .. code-block:: python 124 | 125 | from functools import partial 126 | from zennit.attribution import Occlusion, occlude_independent 127 | 128 | input = torch.randn((16, 3, 32, 32)) 129 | 130 | attributor = Occlusion( 131 | model, 132 | window=8, # 8x8 overlapping windows 133 | stride=4, # with strides 4x4 134 | occlusion_fn=partial( # occlusion_fn gets the full input and a mask 135 | occlude_independent, # applies fill_fn at provided mask 136 | fill_fn=lambda x: x * torch.randn_like(x) * 0.2, # add some noise 137 | invert=False # do not invert, i.e. occlude *within* mask 138 | ) 139 | ) 140 | with attributor: 141 | # for occlusion, the score for each window-pass is the sum of the 142 | # provided *grad_output*, which we choose as the model output at index 0 143 | output, relevance = attributor(input, lambda out: torch.eye(10)[[0]] * out) 144 | 145 | 146 | Note that while the interface allows to pass a composite for any 147 | :py:class:`~zennit.attribution.Attributor`, using a composite with 148 | :py:class:`~zennit.attribution.Occlusion` does not change the outcome, as it 149 | does not utilize the gradient. 150 | 151 | An introduction on how to write custom **Attributors** can be found at 152 | :doc:`/how-to/write-custom-attributors`. 
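For completeness, a similar sketch applies to :py:class:`~zennit.attribution.IntegratedGradients`, which was listed above but not demonstrated. The snippet below assumes it accepts a composite in the same way as :py:class:`~zennit.attribution.Gradient` and :py:class:`~zennit.attribution.SmoothGrad`, with ``n_iter`` controlling the number of integration steps; it is a sketch of the hybrid-attribution pattern, not part of the original example set:

.. code-block:: python

    from zennit.attribution import IntegratedGradients
    from zennit.composites import EpsilonPlusFlat

    # prepare the composite as before
    composite = EpsilonPlusFlat()

    # integrate the (composite-modified) gradient along a path towards the
    # input, yielding a hybrid of Integrated Gradients and EpsilonPlusFlat LRP
    with IntegratedGradients(model, composite, n_iter=20) as attributor:
        output, relevance = attributor(input, torch.eye(10)[[0]])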
153 | -------------------------------------------------------------------------------- /docs/source/how-to/write-custom-attributors.rst: -------------------------------------------------------------------------------- 1 | ========================== 2 | Writing Custom Attributors 3 | ========================== 4 | 5 | **Attributors** provide an additional layer of abstraction over the context of 6 | **Composites**, and are used to directly produce *attributions*, which may or 7 | may not be computed with modified gradients, if they are used, from 8 | **Composites**. 9 | More information on **Attributors**, examples and their use can be found in 10 | :doc:`/how-to/use-attributors`. 11 | 12 | **Attributors** can be used to implement non-layer-wise or only partly 13 | layer-wise attribution methods. 14 | For this, it is enough to define a subclass of 15 | :py:class:`zennit.attribution.Attributor` and implement its 16 | :py:meth:`~zennit.attribution.Attributor.forward` and optionally its 17 | :py:meth:`~zennit.attribution.Attributor.__init__` methods. 18 | 19 | :py:meth:`~zennit.attribution.Attributor.forward` takes 2 arguments, the tensor 20 | with respect to which the attribution shall be computed ``input``, and 21 | ``attr_output_fn``, which is a function that, given the output of the 22 | attributed model, computes the *gradient output* for the gradient computation, 23 | which is, for example, a one-hot encoding of the target label of the attributed 24 | input. 25 | When calling an :py:class:`~zennit.attribution.Attributor`, the ``__call__`` 26 | function will ensure ``forward`` receives a valid function to transform the 27 | output of the analyzed model to a tensor which can be used for the 28 | ``grad_output`` argument of :py:func:`torch.autograd.grad`. 29 | A constant tensor or function is provided by the user either to ``__init__`` or 30 | to ``__call__``. 31 | It is expected that :py:meth:`~zennit.attribution.Attributor.forward` will 32 | return a tuple containing, in order, the model output and the attribution. 33 | 34 | As an example, we can implement *gradient times input* in the following way: 35 | 36 | .. code-block:: python 37 | 38 | import torch 39 | from torchvision.models import vgg11 40 | 41 | from zennit.attribution import Attributor 42 | 43 | 44 | class GradientTimesInput(Attributor): 45 | '''Model-agnostic gradient times input.''' 46 | def forward(self, input, attr_output_fn): 47 | '''Compute gradient times input.''' 48 | input_detached = input.detach().requires_grad_(True) 49 | output = self.model(input_detached) 50 | gradient, = torch.autograd.grad( 51 | (output,), (input_detached,), (attr_output_fn(output.detach()),) 52 | ) 53 | relevance = gradient * input 54 | return output, relevance 55 | 56 | model = vgg11() 57 | data = torch.randn((1, 3, 224, 224)) 58 | 59 | with GradientTimesInput(model) as attributor: 60 | output, relevance = attributor(data) 61 | 62 | :py:class:`~zennit.attribution.Attributor` accepts an optional 63 | :py:class:`~zennit.core.Composite`, which, if supplied, will always be used to 64 | create a context in ``__call__`` around ``forward``. 65 | For the ``GradientTimesInput`` class above, using a **Composite** will probably 66 | not produce anything useful, although more involved combinations of custom 67 | **Rules** and a custom **Attributor** can be used to implement complex 68 | attribution methods with both model-agnostic and layer-wise parts. 
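Purely to illustrate the mechanics (a minimal sketch; as noted, the result for ``GradientTimesInput`` is unlikely to be useful), a composite can be supplied as the second argument to the inherited ``__init__``:

.. code-block:: python

    from zennit.composites import EpsilonPlusFlat

    # when the attributor is called, the composite's context is created
    # around forward, so self.model computes the modified gradient
    composite = EpsilonPlusFlat()

    with GradientTimesInput(model, composite) as attributor:
        output, relevance = attributor(data)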
69 | 70 | The following shows an example of *sensitivity analysis*, which is the absolute 71 | value of the gradient, with a custom ``__init__()`` where we can pass the argument 72 | ``sum_channels`` to specify whether the **Attributor** should sum over the 73 | channel dimension: 74 | 75 | .. code-block:: python 76 | 77 | import torch 78 | from torchvision.models import vgg11 79 | 80 | from zennit.attribution import Attributor 81 | 82 | 83 | class SensitivityAnalysis(Attributor): 84 | '''Model-agnostic sensitivity analysis which optionally sums over color 85 | channels. 86 | ''' 87 | def __init__( 88 | self, model, sum_channels=False, composite=None, attr_output=None 89 | ): 90 | super().__init__( 91 | model, composite=composite, attr_output=attr_output 92 | ) 93 | 94 | self.sum_channels = sum_channels 95 | 96 | 97 | def forward(self, input, attr_output_fn): 98 | '''Compute the absolute gradient (or the sensitivity) and 99 | optionally sum over the color channels. 100 | ''' 101 | input_detached = input.detach().requires_grad_(True) 102 | output = self.model(input_detached) 103 | gradient, = torch.autograd.grad( 104 | (output,), (input_detached,), (attr_output_fn(output.detach()),) 105 | ) 106 | relevance = gradient.abs() 107 | if self.sum_channels: 108 | relevance = relevance.sum(1) 109 | return output, relevance 110 | 111 | model = vgg11() 112 | data = torch.randn((1, 3, 224, 224)) 113 | 114 | with SensitivityAnalysis(model, sum_channels=True) as attributor: 115 | output, relevance = attributor(data) 116 | -------------------------------------------------------------------------------- /docs/source/how-to/write-custom-canonizers.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | Writing Custom Canonizers 3 | ========================= 4 | 5 | **Canonizers** are used to temporarily transform models into a canonical form to 6 | mitigate the lack of implementation invariance of methods such as Layer-wise Relevance 7 | Propagation (LRP). A general introduction to **Canonizers** can be found here: 8 | :ref:`use-canonizers`. 9 | 10 | As both **Canonizers** and **Composites** (via **Rules**) change the outcome of 11 | the attribution, it can be a little bit confusing in the beginning when 12 | challenged with the question whether a novel network architecture needs a new 13 | set of **Rules** and **Composites**, or if it should be adapted to the existing 14 | framework using **Canonizers**. While ultimately it depends on the design 15 | preference of the developer, our suggestion is to go through the following steps 16 | in order: 17 | 18 | 1. Check whether a custom **Composite** is enough to correctly attribute the 19 | model, i.e. the new layer-type is only a composition of existing layer types 20 | without any unaccounted intermediate steps or incompatibilities with existing 21 | rules. 22 | 2. If some of the rules which should be used are incompatible without changes 23 | (e.g. subsequent linear layers), or some parts of a module have intermediate 24 | computations that are not implemented with sub-modules, it should be checked 25 | whether a **Canonizer** can be implemented to fix these issues. If you are in 26 | control of the module in question, check whether rewriting the module with 27 | sub-modules is easier than implementing a **Canonizer**. 28 | 3.
If the module consists of computations which cannot be separated into 29 | existing modules with compatible rules, or would result in an overly complex 30 | architecture, a custom **Rule** may be the choice to go with. 31 | 32 | **Rules** and **Composites** are not designed to change the forward computation 33 | of a model. While **Canonizers** can change the outcome of the forward pass, 34 | this should be used with care, since a modified function output means that the 35 | function itself has been modified, which will therefore result in an attribution 36 | of the modified function instead. 37 | 38 | To implement a custom **Canonizer**, a class inheriting from 39 | :py:class:`zennit.canonizers.Canonizer` needs to implement the following four 40 | methods: 41 | 42 | * :py:meth:`~zennit.canonizers.Canonizer.apply`, which finds the sub-modules 43 | that should be modified by the **Canonizer** and passes their information to ... 44 | * :py:meth:`~zennit.canonizers.Canonizer.register`, which copies the current 45 | instance using :py:meth:`~zennit.canonizers.Canonizer.copy`, applies the 46 | changes that should be introduced by the **Canonizer**, and makes sure they 47 | can be reverted later, using ... 48 | * :py:meth:`~zennit.canonizers.Canonizer.remove`, which reverts the changes 49 | introduced by the **Canonizer**, by i.e. loading the original parameters which 50 | were temporarily stored, and 51 | * :py:meth:`~zennit.canonizers.Canonizer.copy`, which copies the current 52 | instance, to create an individual instance for each applicable module with the 53 | same parameters. 54 | 55 | Suppose we have a ReLU model (e.g. VGG11) for which we want to compute the 56 | second-order derivative, e.g. to find an adversarial explanation (see 57 | :cite:p:`dombrowski2019explanations`). The ReLU is not differentiable at 0, and 58 | its second order derivative is zero everywhere except at 0, where it is 59 | undefined. :cite:t:`dombrowski2019explanations` replace the ReLU activations in 60 | a model with *Softplus* activations, which when running *beta* towards infinity 61 | will be identical to the ReLU activation. For the numerical estimate, it is 62 | enough to set *beta* to a relatively large value, e.g. to 10. The following is 63 | an implementation of the **SoftplusCanonizer**, which will temporarily replace 64 | the ReLU activations in a model with Softplus activations: 65 | 66 | .. code-block:: python 67 | 68 | import torch 69 | 70 | from zennit.canonizers import Canonizer 71 | 72 | 73 | class SoftplusCanonizer(Canonizer): 74 | '''Replaces ReLUs with Softplus units.''' 75 | def __init__(self, beta=10.): 76 | self.beta = beta 77 | self.module = None 78 | self.relu_children = None 79 | 80 | def apply(self, root_module): 81 | '''Iterate all modules under root_module and register the Canonizer 82 | if they have immediate ReLU sub-modules. 
83 | ''' 84 | # track the SoftplusCanonizer instances to remove them later 85 | instances = [] 86 | # iterate recursively over all modules 87 | for module in root_module.modules(): 88 | # get all the direct sub-module instances of torch.nn.ReLU 89 | relu_children = [ 90 | (name, child) 91 | for name, child in module.named_children() 92 | if isinstance(child, torch.nn.ReLU) 93 | ] 94 | # if there is at least on direct ReLU sub-module 95 | if relu_children: 96 | # create a copy (with the same beta parameter) 97 | instance = self.copy() 98 | # register the module 99 | instance.register(module, relu_children) 100 | # add the copy to the instance list 101 | instances.append(instance) 102 | return instances 103 | 104 | def register(self, module, relu_children): 105 | '''Store the module and the immediate ReLU-sub-modules, and then 106 | overwrite the attributes that corresponds to each ReLU-sub-modules 107 | with a new instance of ``torch.nn.Softplus``. 108 | ''' 109 | self.module = module 110 | self.relu_children = relu_children 111 | for name, _ in relu_children: 112 | # set each of the attributes corresponding to the ReLU to a new 113 | # instance of torch.nn.Softplus 114 | setattr(module, name, torch.nn.Softplus(beta=self.beta)) 115 | 116 | def remove(self): 117 | '''Undo the changes introduces by this Canonizer, by setting the 118 | appropriate attributes of the stored module back to the original 119 | ReLU sub-module instance. 120 | ''' 121 | for name, child in self.relu_children: 122 | setattr(self.module, name, child) 123 | 124 | def copy(self): 125 | '''Create a copy of this instance. Each module requires its own 126 | instance to call ``.register``. 127 | ''' 128 | return SoftplusCanonizer(beta=self.beta) 129 | 130 | 131 | Note that we can only replace modules by changing their immediate parent. This 132 | means that if ``root_module`` was a ``torch.nn.ReLU`` itself, it would be 133 | impossible to replace it with a ``torch.nn.Softplus`` without replacing the 134 | ``root_module`` itself. 135 | 136 | For demonstration purposes, we can compute the gradient w.r.t. a loss which uses 137 | the gradient of the modified model in the following way: 138 | 139 | .. code-block:: python 140 | 141 | import torch 142 | from torchvision.models import vgg11 143 | 144 | from zennit.core import Composite 145 | from zennit.image import imgify 146 | 147 | 148 | # create a VGG11 model with random parameters 149 | model = vgg11() 150 | # use the Canonizer with an "empty" Composite (without specifying 151 | # module_map), which will not attach rules to any sub-module, thus resulting 152 | # in a plain gradient computation, but with a Canonizer applied 153 | composite = Composite( 154 | canonizers=[SoftplusCanonizer()] 155 | ) 156 | 157 | input = torch.randn(1, 3, 224, 224, requires_grad=True) 158 | target = torch.eye(1000)[[0]] 159 | with composite.context(model) as modified_model: 160 | out = modified_model(input) 161 | relevance, = torch.autograd.grad(out, input, target, create_graph=True) 162 | # find adversarial example such that input and its respective 163 | # attribution are close 164 | loss = ((relevance - input.detach()) ** 2).mean() 165 | # compute the gradient of input w.r.t. loss, using the second order 166 | # derivative w.r.t. 
input; note that this currently does not work when 167 | # using BasicHook, which detaches the gradient to avoid wrong values 168 | adv_grad, = torch.autograd.grad(loss, input) 169 | 170 | # visualize adv_grad 171 | imgify(adv_grad[0].abs().sum(0), cmap='hot').show() 172 | 173 | -------------------------------------------------------------------------------- /docs/source/how-to/write-custom-composites.rst: -------------------------------------------------------------------------------- 1 | ========================= 2 | Writing Custom Composites 3 | ========================= 4 | 5 | Zennit provides a number of commonly used **Composites**. 6 | While these are often enough for feed-forward-type neural networks, one primary goal of Zennit is to provide the tools to easily customize the computation of rule-based attribution methods. 7 | This is especially useful to analyze novel architectures, for which no attribution-approach has been designed before. 8 | 9 | For most use-cases, using the abstract **Composites** :py:class:`~zennit.composites.LayerMapComposite`, :py:class:`~zennit.composites.SpecialFirstLayerMapComposite`, and :py:class:`~zennit.composites.NameMapComposite` already provides enough freedom to customize which Layer should receive which rule. See :ref:`use-composites` for an introduction. 10 | Depending on the setup, it may however be more convenient to either directly use or implement a new **Composite** by creating a Subclass from :py:class:`zennit.core.Composite`. 11 | In either case, the :py:class:`~zennit.core.Composite` requires an argument ``module_map``, which is a function with the signature ``(ctx: dict, name: str, module: torch.nn.Module) -> Hook or None``, which, given a context dict, the name of a single module and the module itself, either returns an instance of :py:class:`~zennit.core.Hook` which should be copied and registered to the module, or ``None`` if no ``Hook`` should be applied. 12 | The context dict ``ctx`` can be used to track subsequent calls to the ``module_map`` function, e.g. to count the number of processed modules, or to verify if some condition has been met before, e.g. a linear layer has been seen before. 13 | The ``module_map`` is used in :py:meth:`zennit.core.Composite.register`, where the context dict is initialized to an empty dict ``{}`` before iterating over all the sub-modules of the root-module to which the composite will be registered. 14 | The iteration is done using :py:meth:`torch.nn.Module.named_modules`, which will therefore dictate the order modules are visited, which is depth-first in the order sub-modules were assigned. 15 | 16 | A simple **Composite**, which only provides rules for linear layers that are leaves and bases the rule on how many leaf modules were visited before could be implemented like the following: 17 | 18 | 19 | .. code-block:: python 20 | 21 | import torch 22 | from torchvision.models import vgg16 23 | from zennit.rules import Epsilon, AlphaBeta 24 | from zennit.types import Linear 25 | from zennit.core import Composite 26 | from zennit.attribution import Gradient 27 | 28 | 29 | def module_map(ctx, name, module): 30 | # check whether there is at least one child, i.e. 
the module is not a leaf 31 | try: 32 | next(module.children()) 33 | except StopIteration: 34 | # StopIteration is raised if the iterator has no more elements, 35 | # which means in this case there are no children and module is a leaf 36 | pass 37 | else: 38 | # if StopIteration is not raised on the first element, module is not a leaf 39 | return None 40 | 41 | # if the module is not Linear, we do not want to assign a hook 42 | if not isinstance(module, Linear): 43 | return None 44 | 45 | # count the number of the leaves processed yet in 'leafnum' 46 | if 'leafnum' not in ctx: 47 | ctx['leafnum'] = 0 48 | else: 49 | ctx['leafnum'] += 1 50 | 51 | # the first 10 leaf-modules which are of type Linear should be assigned 52 | # the Alpha2Beta1 rule 53 | if ctx['leafnum'] < 10: 54 | return AlphaBeta(alpha=2, beta=1) 55 | # all other rules should be assigned Epsilon 56 | return Epsilon(epsilon=1e-3) 57 | 58 | 59 | # we can then create a composite by passing the module_map function 60 | # canonizers may also be passed as with all composites 61 | composite = Composite(module_map=module_map) 62 | 63 | # try out the composite 64 | model = vgg16() 65 | with Gradient(model, composite) as attributor: 66 | out, grad = attributor(torch.randn(1, 3, 224, 224)) 67 | 68 | 69 | A more general **Composite**, where we can specify which layer number and which type should be assigned which rule, can be implemented by creating a class: 70 | 71 | .. code-block:: python 72 | 73 | from itertools import islice 74 | 75 | import torch 76 | from torchvision.models import vgg16 77 | from zennit.rules import Epsilon, ZBox, Gamma, Pass, Norm 78 | from zennit.types import Linear, Convolution, Activation, AvgPool 79 | from zennit.core import Composite 80 | from zennit.attribution import Gradient 81 | 82 | 83 | class LeafNumberTypeComposite(Composite): 84 | def __init__(self, leafnum_map): 85 | # pass the class method self.mapping as the module_map 86 | super().__init__(module_map=self.mapping) 87 | # set the instance attribute so we can use it in self.mapping 88 | self.leafnum_map = leafnum_map 89 | 90 | def mapping(self, ctx, name, module): 91 | # check whether there is at least one child, i.e. the module is not a leaf 92 | # but this time shorter using itertools.islice to get at most one child 93 | if list(islice(module.children(), 1)): 94 | return None 95 | 96 | # count the number of the leaves processed yet in 'leafnum' 97 | # this time in a single line with get and all layers count, e.g. 
ReLU 98 | ctx['leafnum'] = ctx.get('leafnum', -1) + 1 99 | 100 | # loop over the leafnum_map and use the first template for which 101 | # the module type matches and the current ctx['leafnum'] falls into 102 | # the bounds 103 | for (low, high), dtype, template in self.leafnum_map: 104 | if isinstance(module, dtype) and low <= ctx['leafnum'] < high: 105 | return template 106 | # if none of the leafnum_map apply this means there is no rule 107 | # matching the current layer 108 | return None 109 | 110 | 111 | # this can be compared with int and will always be larger 112 | inf = float('inf') 113 | 114 | # we create an example leafnum-map, note that Linear is here 115 | # zennit.types.Linear and not torch.nn.Linear 116 | # the first two entries are for demonstration only and would 117 | # in practice most likely be a single "Linear" with appropriate low/high 118 | leafnum_map = [ 119 | [(0, 1), Convolution, ZBox(low=-3.0, high=3.0)], 120 | [(0, 1), torch.nn.Linear, ZBox(low=0.0, high=1.0)], 121 | [(1, 17), Linear, Gamma(gamma=0.25)], 122 | [(17, 31), Linear, Epsilon(epsilon=0.5)], 123 | [(31, inf), Linear, Epsilon(epsilon=1e-9)], 124 | # catch all activations 125 | [(0, inf), Activation, Pass()], 126 | # explicit None is possible e.g. to (ab-)use precedence 127 | [(0, 17), torch.nn.MaxPool2d, None], 128 | # catch all AvgPool/MaxPool2d, isinstance also accepts tuples of types 129 | [(0, inf), (AvgPool, torch.nn.MaxPool2d), Norm()], 130 | ] 131 | 132 | # finally, create the composite using the leafnum_map 133 | composite = LeafNumberTypeComposite(leafnum_map) 134 | 135 | # try out the composite 136 | model = vgg16() 137 | with Gradient(model, composite) as attributor: 138 | out, grad = attributor(torch.randn(1, 3, 224, 224)) 139 | 140 | In practice, however, we do not recommend to use the index of the layer when designing **Composites**, because most of the time, when such a configuration is chosen, it is done to shape the **Composite** for an explicit model. 141 | For these kinds of **Composites**, a :py:class:`~zennit.composites.NameMapComposite` will directly map the name of a sub-module to a Hook, which is a more explicit and transparent way to create a special **Composite** for a single neural network. 142 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ==================== 2 | Zennit Documentation 3 | ==================== 4 | 5 | Zennit (Zennit Explains Neural Networks in Torch) is a python framework using PyTorch to compute local attributions in the sense of eXplainable AI (XAI) with a focus on Layer-wise Relevance Propagation. 6 | It works by defining *rules* which are used to overwrite the gradient of PyTorch modules in PyTorch's auto-differentiation engine. 7 | Rules are mapped to layers with *composites*, which contain directions to compute the attributions of a full model, which maps rules to modules based on their properties and context. 8 | 9 | Zennit is available on PyPI and may be installed using: 10 | 11 | .. code-block:: console 12 | 13 | $ pip install zennit 14 | 15 | Contents 16 | -------- 17 | 18 | .. 
toctree:: 19 | :maxdepth: 2 20 | 21 | getting-started 22 | how-to/index 23 | tutorial/index 24 | reference/index 25 | bibliography 26 | 27 | Indices and tables 28 | ------------------ 29 | 30 | * :ref:`genindex` 31 | * :ref:`modindex` 32 | * :ref:`search` 33 | 34 | 35 | Citing 36 | ------ 37 | 38 | If you find Zennit useful, why not cite our related paper :cite:p:`anders2021software`: 39 | 40 | .. code-block:: bibtex 41 | 42 | @article{anders2021software, 43 | author = {Anders, Christopher J. and 44 | Neumann, David and 45 | Samek, Wojciech and 46 | Müller, Klaus-Robert and 47 | Lapuschkin, Sebastian}, 48 | title = {Software for Dataset-wide XAI: From Local Explanations to Global Insights with {Zennit}, {CoRelAy}, and {ViRelAy}}, 49 | journal = {CoRR}, 50 | volume = {abs/2106.13200}, 51 | year = {2021}, 52 | } 53 | 54 | -------------------------------------------------------------------------------- /docs/source/reference/index.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | API Reference 3 | ================ 4 | 5 | .. autosummary:: 6 | :toctree: 7 | :nosignatures: 8 | :recursive: 9 | :template: modules.rst 10 | 11 | zennit.attribution 12 | zennit.canonizers 13 | zennit.cmap 14 | zennit.composites 15 | zennit.core 16 | zennit.image 17 | zennit.layer 18 | zennit.rules 19 | zennit.torchvision 20 | zennit.types 21 | -------------------------------------------------------------------------------- /docs/source/tutorial/index.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | Tutorials 3 | ================ 4 | 5 | .. toctree:: 6 | :maxdepth: 1 7 | 8 | image-classification-vgg-resnet 9 | .. 10 | image-segmentation-with-unet 11 | text-classification-with-tbd 12 | audio-classification-with-tbd 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import re 3 | from setuptools import setup, find_packages 4 | from subprocess import run, CalledProcessError 5 | 6 | 7 | def get_long_description(project_path): 8 | '''Fetch the README contents and replace relative links with absolute ones 9 | pointing to github for correct behaviour on PyPI. 
10 | ''' 11 | try: 12 | revision = run( 13 | ['git', 'describe', '--tags'], 14 | capture_output=True, 15 | check=True, 16 | text=True 17 | ).stdout[:-1] 18 | except CalledProcessError: 19 | try: 20 | with open('PKG-INFO', 'r') as fd: 21 | body = fd.read().partition('\n\n')[2] 22 | if body: 23 | return body 24 | except FileNotFoundError: 25 | revision = 'master' 26 | 27 | with open('README.md', 'r', encoding='utf-8') as fd: 28 | long_description = fd.read() 29 | 30 | link_root = { 31 | '': f'https://github.com/{project_path}/blob', 32 | '!': f'https://raw.githubusercontent.com/{project_path}', 33 | } 34 | 35 | def replace(mobj): 36 | return f'{mobj[1]}[{mobj[2]}]({link_root[mobj[1]]}/{revision}/{mobj[3]})' 37 | 38 | link_rexp = re.compile(r'(!?)\[([^\]]*)\]\((?!https?://|/)([^\)]+)\)') 39 | return link_rexp.sub(replace, long_description) 40 | 41 | 42 | setup( 43 | name='zennit', 44 | use_scm_version=True, 45 | author='chrstphr', 46 | author_email='zennit@j0d.de', 47 | description='Attribution of Neural Networks using PyTorch', 48 | long_description=get_long_description('chr5tphr/zennit'), 49 | long_description_content_type='text/markdown', 50 | url='https://github.com/chr5tphr/zennit', 51 | packages=find_packages(where='src', include=['zennit*']), 52 | package_dir={'': 'src'}, 53 | install_requires=[ 54 | 'click', 55 | 'numpy', 56 | 'Pillow', 57 | 'torch>=1.7.0', 58 | 'torchvision', 59 | ], 60 | setup_requires=[ 61 | 'setuptools_scm', 62 | ], 63 | extras_require={ 64 | 'docs': [ 65 | 'sphinx-copybutton>=0.4.0', 66 | 'sphinx-rtd-theme>=1.0.0', 67 | 'sphinxcontrib.datatemplates>=0.9.0', 68 | 'sphinxcontrib.bibtex>=2.4.1', 69 | 'nbsphinx>=0.8.8', 70 | 'nbconvert<7.14', # see https://github.com/jupyter/nbconvert/issues/2092 71 | 'ipykernel>=6.13.0', 72 | ], 73 | 'tests': [ 74 | 'pytest', 75 | 'pytest-cov', 76 | ] 77 | }, 78 | python_requires='>=3.7', 79 | classifiers=[ 80 | 'Development Status :: 3 - Alpha', 81 | 'License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)', 82 | 'Programming Language :: Python :: 3.7', 83 | 'Programming Language :: Python :: 3.8', 84 | 'Programming Language :: Python :: 3.9', 85 | ] 86 | ) 87 | -------------------------------------------------------------------------------- /share/example/feed_forward.py: -------------------------------------------------------------------------------- 1 | '''A quick example to generate heatmaps for vgg16.''' 2 | import os 3 | from functools import partial 4 | 5 | import click 6 | import torch 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Subset 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor 10 | from torchvision.datasets import ImageFolder 11 | from torchvision.models import vgg11, vgg11_bn, vgg16, vgg16_bn, resnet18, resnet50 12 | 13 | from zennit.attribution import Gradient, SmoothGrad, IntegratedGradients, Occlusion 14 | from zennit.composites import COMPOSITES 15 | from zennit.core import Hook 16 | from zennit.image import imsave, CMAPS 17 | from zennit.layer import Sum 18 | from zennit.torchvision import VGGCanonizer, ResNetCanonizer 19 | 20 | 21 | MODELS = { 22 | 'vgg16': (vgg16, VGGCanonizer), 23 | 'vgg16_bn': (vgg16_bn, VGGCanonizer), 24 | 'vgg11': (vgg11, VGGCanonizer), 25 | 'vgg11_bn': (vgg11_bn, VGGCanonizer), 26 | 'resnet18': (resnet18, ResNetCanonizer), 27 | 'resnet50': (resnet50, ResNetCanonizer), 28 | } 29 | 30 | ATTRIBUTORS = { 31 | 'gradient': Gradient, 32 | 'smoothgrad': SmoothGrad, 33 | 'integrads': IntegratedGradients, 34 | 
'occlusion': Occlusion, 35 | 'inputxgrad': IntegratedGradients, 36 | } 37 | 38 | 39 | class SumSingle(Hook): 40 | def __init__(self, dim=1): 41 | super().__init__() 42 | self.dim = dim 43 | 44 | def backward(self, module, grad_input, grad_output): 45 | elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1]) 46 | elems[self.dim] = grad_output[0] 47 | return (torch.stack(elems, dim=-1),) 48 | 49 | 50 | class BatchNormalize: 51 | def __init__(self, mean, std, device=None): 52 | self.mean = torch.tensor(mean, device=device)[None, :, None, None] 53 | self.std = torch.tensor(std, device=device)[None, :, None, None] 54 | 55 | def __call__(self, tensor): 56 | return (tensor - self.mean) / self.std 57 | 58 | 59 | class AllowEmptyClassImageFolder(ImageFolder): 60 | '''Subclass of ImageFolder, which only finds non-empty classes, but with their correct indices given other empty 61 | classes. This counter-acts the changes in torchvision 0.10.0, in which DatasetFolder does not allow empty classes 62 | anymore by default. Versions before 0.10.0 do not expose `find_classes`, and thus this change does not change the 63 | functionality of `ImageFolder` in earlier versions. 64 | ''' 65 | def find_classes(self, directory): 66 | with os.scandir(directory) as scanit: 67 | class_info = sorted((entry.name, len(list(os.scandir(entry.path)))) for entry in scanit if entry.is_dir()) 68 | class_to_idx = {class_name: index for index, (class_name, n_members) in enumerate(class_info) if n_members} 69 | if not class_to_idx: 70 | raise FileNotFoundError(f'No non-empty classes found in \'{directory}\'.') 71 | return list(class_to_idx), class_to_idx 72 | 73 | 74 | @click.command() 75 | @click.argument('dataset-root', type=click.Path(file_okay=False)) 76 | @click.argument('relevance_format', type=click.Path(dir_okay=False, writable=True)) 77 | @click.option('--attributor', 'attributor_name', type=click.Choice(list(ATTRIBUTORS)), default='gradient') 78 | @click.option('--composite', 'composite_name', type=click.Choice(list(COMPOSITES))) 79 | @click.option('--model', 'model_name', type=click.Choice(list(MODELS)), default='vgg16_bn') 80 | @click.option('--parameters', type=click.Path(dir_okay=False)) 81 | @click.option( 82 | '--inputs', 83 | 'input_format', 84 | type=click.Path(dir_okay=False, writable=True), 85 | help='Input image format string. {sample} is replaced with the sample index.' 
86 | ) 87 | @click.option('--batch-size', type=int, default=16) 88 | @click.option('--max-samples', type=int) 89 | @click.option('--n-outputs', type=int, default=1000) 90 | @click.option('--cpu/--gpu', default=True) 91 | @click.option('--shuffle/--no-shuffle', default=False) 92 | @click.option('--with-bias/--no-bias', default=True) 93 | @click.option('--with-residual/--no-residual', default=True) 94 | @click.option('--relevance-norm', type=click.Choice(['symmetric', 'absolute', 'unaligned']), default='symmetric') 95 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot') 96 | @click.option('--level', type=float, default=1.0) 97 | @click.option('--seed', type=int, default=0xDEADBEEF) 98 | def main( 99 | dataset_root, 100 | relevance_format, 101 | attributor_name, 102 | composite_name, 103 | model_name, 104 | parameters, 105 | input_format, 106 | batch_size, 107 | max_samples, 108 | n_outputs, 109 | cpu, 110 | shuffle, 111 | with_bias, 112 | with_residual, 113 | cmap, 114 | level, 115 | relevance_norm, 116 | seed 117 | ): 118 | '''Generate heatmaps of an image folder at DATASET_ROOT to files RELEVANCE_FORMAT. 119 | RELEVANCE_FORMAT is a format string, for which {sample} is replaced with the sample index. 120 | ''' 121 | # set a manual seed for the RNG 122 | torch.manual_seed(seed) 123 | 124 | # use the gpu if requested and available, else use the cpu 125 | device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu') 126 | 127 | # mean and std of ILSVRC2012 as computed for the torchvision models 128 | norm_fn = BatchNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), device=device) 129 | 130 | # transforms as used for torchvision model evaluation 131 | transform = Compose([ 132 | Resize(256), 133 | CenterCrop(224), 134 | ToTensor(), 135 | ]) 136 | 137 | # the dataset is a folder containing folders with samples, where each folder corresponds to one label 138 | dataset = AllowEmptyClassImageFolder(dataset_root, transform=transform) 139 | 140 | # limit the number of output samples, if requested, by creating a subset 141 | if max_samples is not None: 142 | if shuffle: 143 | indices = sorted(np.random.choice(len(dataset), min(len(dataset), max_samples), replace=False)) 144 | else: 145 | indices = range(min(len(dataset), max_samples)) 146 | dataset = Subset(dataset, indices) 147 | 148 | loader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size) 149 | 150 | model = MODELS[model_name][0]() 151 | 152 | # load model parameters if requested; the parameter file may need to be downloaded separately 153 | if parameters is not None: 154 | state_dict = torch.load(parameters) 155 | model.load_state_dict(state_dict) 156 | model.to(device) 157 | model.eval() 158 | 159 | # disable requires_grad for all parameters, we do not need their modified gradients 160 | for param in model.parameters(): 161 | param.requires_grad = False 162 | 163 | # convenience identity matrix to produce one-hot encodings 164 | eye = torch.eye(n_outputs, device=device) 165 | 166 | # function to compute output relevance given the function output and a target 167 | def attr_output_fn(output, target): 168 | # output times one-hot encoding of the target labels of size (len(target), 1000) 169 | return output * eye[target] 170 | 171 | # create a composite if composite_name was set, otherwise we do not use a composite 172 | composite = None 173 | if composite_name is not None: 174 | composite_kwargs = {} 175 | if composite_name == 'epsilon_gamma_box': 176 | # the maximal input shape, 
needed for the ZBox rule 177 | shape = (batch_size, 3, 224, 224) 178 | 179 | # the highest and lowest pixel values for the ZBox rule 180 | composite_kwargs['low'] = norm_fn(torch.zeros(*shape, device=device)) 181 | composite_kwargs['high'] = norm_fn(torch.ones(*shape, device=device)) 182 | if not with_residual and 'resnet' in model_name: 183 | # skip the residual connection through the Sum added by the ResNetCanonizer 184 | composite_kwargs['layer_map'] = [(Sum, SumSingle(1))] 185 | 186 | # provide the name 'bias' in zero_params if no bias should be used to compute the relevance 187 | if not with_bias and composite_name in [ 188 | 'epsilon_gamma_box', 189 | 'epsilon_plus', 190 | 'epsilon_alpha2_beta1', 191 | 'epsilon_plus_flat', 192 | 'epsilon_alpha2_beta1_flat', 193 | 'excitation_backprop', 194 | ]: 195 | composite_kwargs['zero_params'] = ['bias'] 196 | 197 | # use torchvision specific canonizers, as supplied in the MODELS dict 198 | composite_kwargs['canonizers'] = [MODELS[model_name][1]()] 199 | 200 | # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit. 201 | composite = COMPOSITES[composite_name](**composite_kwargs) 202 | 203 | # specify some attributor-specific arguments 204 | attributor_kwargs = { 205 | 'smoothgrad': {'noise_level': 0.1, 'n_iter': 20}, 206 | 'integrads': {'n_iter': 20}, 207 | 'inputxgrad': {'n_iter': 1}, 208 | 'occlusion': {'window': (56, 56), 'stride': (28, 28)}, 209 | }.get(attributor_name, {}) 210 | 211 | # create an attributor, given the ATTRIBUTORS dict given above. If composite is None, the gradient will not be 212 | # modified for the attribution 213 | attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs) 214 | 215 | # the current sample index for creating file names 216 | sample_index = 0 217 | 218 | # the accuracy 219 | accuracy = 0. 220 | 221 | # enter the attributor context outside the data loader loop, such that its canonizers and hooks do not need to be 222 | # registered and removed for each step. This registers the composite (and applies the canonizer) to the model 223 | # within the with-statement 224 | with attributor: 225 | for data, target in loader: 226 | # we use data without the normalization applied for visualization, and with the normalization applied as 227 | # the model input 228 | data_norm = norm_fn(data.to(device)) 229 | 230 | # create output relevance function of output with fixed target 231 | output_relevance = partial(attr_output_fn, target=target) 232 | 233 | # this will compute the modified gradient of model, where the output relevance is chosen by the as the 234 | # model's output for the ground-truth label index 235 | output, relevance = attributor(data_norm, output_relevance) 236 | 237 | # sum over the color channel for visualization 238 | relevance = np.array(relevance.sum(1).detach().cpu()) 239 | 240 | # normalize between 0. and 1. given the specified strategy 241 | if relevance_norm == 'symmetric': 242 | # 0-aligned symmetric relevance, negative and positive can be compared, the original 0. becomes 0.5 243 | amax = np.abs(relevance).max((1, 2), keepdims=True) 244 | relevance = (relevance + amax) / 2 / amax 245 | elif relevance_norm == 'absolute': 246 | # 0-aligned absolute relevance, only the amplitude of relevance matters, the original 0. becomes 0. 
247 | relevance = np.abs(relevance) 248 | relevance /= relevance.max((1, 2), keepdims=True) 249 | elif relevance_norm == 'unaligned': 250 | # do not align, the original minimum value becomes 0., the original maximum becomes 1. 251 | rmin = relevance.min((1, 2), keepdims=True) 252 | rmax = relevance.max((1, 2), keepdims=True) 253 | relevance = (relevance - rmin) / (rmax - rmin) 254 | 255 | for n in range(len(data)): 256 | fname = relevance_format.format(sample=sample_index + n) 257 | # zennit.image.imsave will create an appropriate heatmap given a cmap specification 258 | imsave(fname, relevance[n], vmin=0., vmax=1., level=level, cmap=cmap) 259 | if input_format is not None: 260 | fname = input_format.format(sample=sample_index + n) 261 | # if there are 3 color channels, imsave will not create a heatmap, but instead save the image with 262 | # its appropriate colors 263 | imsave(fname, data[n]) 264 | sample_index += len(data) 265 | 266 | # update the accuracy 267 | accuracy += (output.argmax(1) == target).sum().detach().cpu().item() 268 | 269 | accuracy /= len(dataset) 270 | print(f'Accuracy: {accuracy:.2f}') 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /share/img/beacon_resnet50_various.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_resnet50_various.webp -------------------------------------------------------------------------------- /share/img/beacon_vgg16_epsilon_gamma_box.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_vgg16_epsilon_gamma_box.png -------------------------------------------------------------------------------- /share/img/beacon_vgg16_various.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_vgg16_various.webp -------------------------------------------------------------------------------- /share/img/zennit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/zennit.png -------------------------------------------------------------------------------- /share/img/zennit.svg: -------------------------------------------------------------------------------- 1 | 2 | 21 | 23 | 25 | 28 | 32 | 36 | 37 | 40 | 44 | 48 | 49 | 60 | 71 | 72 | 92 | 97 | 98 | 100 | 101 | 103 | image/svg+xml 104 | 106 | 107 | 108 | 109 | 110 | 115 | 124 | 131 | 138 | 152 | 159 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /share/merge_maps/vgg16_bn.json: -------------------------------------------------------------------------------- 1 | [ 2 | [["features.0"], "features.1"], 3 | [["features.3"], "features.4"], 4 | [["features.7"], "features.8"], 5 | [["features.10"], "features.11"], 6 | [["features.14"], "features.15"], 7 | [["features.17"], "features.18"], 8 | [["features.20"], "features.21"], 9 | [["features.24"], "features.25"], 10 | [["features.27"], "features.28"], 11 | [["features.30"], "features.31"], 12 | [["features.34"], "features.35"], 13 | [["features.37"], "features.38"], 14 | 
[["features.40"], "features.41"] 15 | ] 16 | -------------------------------------------------------------------------------- /share/scripts/palette_fit.py: -------------------------------------------------------------------------------- 1 | '''Script to fit RGB heatmap images to a source color palette.''' 2 | import click 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from zennit.image import CMAPS, palette 7 | 8 | 9 | def gale_shapley(dist): 10 | '''Find a stable matching given a distance matrix.''' 11 | preference = np.argsort(dist, axis=1) 12 | proposed = np.zeros(dist.shape[0], dtype=int) 13 | loners = set(range(dist.shape[0])) 14 | guys = [-1] * dist.shape[0] 15 | gals = [-1] * dist.shape[1] 16 | while loners: 17 | loner = loners.pop() 18 | target = preference[loner, proposed[loner]] 19 | if gals[target] == -1: 20 | gals[target] = loner 21 | guys[loner] = target 22 | elif dist[gals[target], target] > dist[loner, target]: 23 | gals[target] = loner 24 | guys[loner] = target 25 | guys[gals[target]] = -1 26 | loners.add(gals[target]) 27 | else: 28 | loners.add(loner) 29 | proposed[loner] += 1 30 | return guys 31 | 32 | 33 | @click.command() 34 | @click.argument('source-file', type=click.Path(exists=True, dir_okay=False)) 35 | @click.argument('output-file', type=click.Path(writable=True, dir_okay=False)) 36 | @click.option('--strategy', type=click.Choice(['intensity', 'nearest', 'histogram']), default='intensity') 37 | @click.option('--source-cmap', type=click.Choice(list(CMAPS)), default='bwr') 38 | @click.option('--source-level', type=float, default=1.0) 39 | @click.option('--invert/--no-invert', default=False) 40 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot') 41 | @click.option('--level', type=float, default=1.0) 42 | def main(source_file, output_file, strategy, source_cmap, source_level, invert, cmap, level): 43 | '''Fit an existing RGB heatmap image to a color palette.''' 44 | source = np.array(Image.open(source_file).convert('RGB')) 45 | matchpal = palette(source_cmap, source_level) 46 | 47 | if strategy == 'intensity': 48 | # matching based on the source image intensity/ brightness 49 | values = source.astype(float).mean(2) 50 | elif strategy == 'nearest': 51 | # match by finding the neareast centroids in a source colormap 52 | dists = (np.abs(source[None].astype(float) - matchpal[:, None, None].astype(float))).sum(3) 53 | values = np.argmin(dists, axis=0) 54 | elif strategy == 'histogram': 55 | # match by finding a stable match between the color histogram of the source image and a source colormap 56 | source = np.concatenate([source, np.zeros_like(source[:, :, [0]])], axis=2).view(np.uint32)[..., 0] 57 | uniques, counts = np.unique(source, return_counts=True) 58 | uniques = uniques[np.argsort(counts)[-256:]] 59 | dist = (np.abs(uniques.view(np.uint8).reshape(-1, 1, 4)[..., :3] - matchpal[None])).sum(2) 60 | matches = np.array(gale_shapley(dist)) 61 | 62 | ind_bin, ind_h, ind_w = np.nonzero(source[None] == uniques[:, None, None]) 63 | values = np.zeros(source.shape[:2], dtype=np.uint8) 64 | values[ind_h, ind_w] = matches[ind_bin] 65 | 66 | values = values.clip(0, 255).astype(np.uint8) 67 | if invert: 68 | values = 255 - values 69 | 70 | img = Image.fromarray(values, mode='P') 71 | pal = palette(cmap, level) 72 | img.putpalette(pal) 73 | img.save(output_file) 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /share/scripts/palette_swap.py: 
-------------------------------------------------------------------------------- 1 | '''Script to swap the palette of heatmap images.''' 2 | import click 3 | from PIL import Image 4 | 5 | from zennit.image import CMAPS, palette 6 | 7 | 8 | @click.command() 9 | @click.argument('image-files', type=click.Path(exists=True, dir_okay=False), nargs=-1) 10 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot') 11 | @click.option('--level', type=float, default=1.0) 12 | def main(image_files, cmap, level): 13 | '''Swap the palette of heatmap image files inline.''' 14 | for fname in image_files: 15 | img = Image.open(fname) 16 | img = img.convert('P') 17 | pal = palette(cmap, level) 18 | img.putpalette(pal) 19 | img.save(fname) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /share/scripts/show_cmaps.py: -------------------------------------------------------------------------------- 1 | '''Script to visually inspect color maps.''' 2 | import click 3 | import numpy as np 4 | from PIL import Image 5 | 6 | from zennit.image import CMAPS, palette 7 | 8 | 9 | def semsstr(string): 10 | if isinstance(string, list): 11 | return string 12 | return [obj for obj in string.split(';') if obj] 13 | 14 | 15 | @click.command() 16 | @click.argument('output') 17 | @click.option('--cmap', 'colormap_src', type=semsstr, default=list(CMAPS)) 18 | @click.option('--level', type=float, default=1.0) 19 | def main(output, colormap_src, level): 20 | print('\n'.join(colormap_src)) 21 | palettes = np.stack([palette(obj, level) for obj in colormap_src]) 22 | arr = np.repeat(palettes, 32, 0) 23 | img = Image.fromarray(arr) 24 | img.save(output) 25 | 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /src/zennit/__init__.py: -------------------------------------------------------------------------------- 1 | '''Zennit top-level __init__.''' 2 | from . import attribution 3 | from . import canonizers 4 | from . import cmap 5 | from . import composites 6 | from . import core 7 | from . import image 8 | from . import layer 9 | from . import rules 10 | from . import torchvision 11 | from . import types 12 | 13 | 14 | __all__ = [ 15 | 'attribution', 16 | 'canonizers', 17 | 'cmap', 18 | 'composites', 19 | 'core', 20 | 'image', 21 | 'layer', 22 | 'rules', 23 | 'torchvision', 24 | 'types', 25 | ] 26 | -------------------------------------------------------------------------------- /src/zennit/cmap.py: -------------------------------------------------------------------------------- 1 | # This file is part of Zennit 2 | # Copyright (C) 2019-2021 Christopher J. Anders 3 | # 4 | # zennit/cmap.py 5 | # 6 | # Zennit is free software: you can redistribute it and/or modify it under 7 | # the terms of the GNU Lesser General Public License as published by the Free 8 | # Software Foundation; either version 3 of the License, or (at your option) any 9 | # later version. 10 | # 11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT 12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for 14 | # more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this library. If not, see . 
18 | '''Create color maps from a color-map specification language''' 19 | import re 20 | from typing import NamedTuple 21 | 22 | import numpy as np 23 | 24 | 25 | class CMapToken(NamedTuple): 26 | '''Tokens used by the lexer of ColorMap.''' 27 | type: str 28 | value: str 29 | pos: int 30 | 31 | 32 | class ColorNode(NamedTuple): 33 | '''Nodes produced by the parser of ColorMap.''' 34 | index: int 35 | value: np.ndarray 36 | 37 | 38 | class ColorMap: 39 | '''Compile a color map from color-map specification language (cmsl) source code. 40 | 41 | The color-map specification language (cmsl) is used to specify linear color maps with comma-separated colors 42 | supplied as hexadecimal values for each color channel in RGB, with either 1 or 2 values per channel. Optionally, a 43 | hexadecimal index with either one or two digits may be supplied in front of each color, followed by a colon, to 44 | indicate the index which should be the color. Values for the ColorMap in-between colors will be interpolated 45 | linearly. If no index is supplied, colors without indices will be spaced evenly between indices. If the first and 46 | last indices are supplied but not 0 (or 00) and f (or ff) respectively, they will be added as an additional node in 47 | the color map, with the same color as the colors with the lowest and highest index respectively. If indices are 48 | provided, they must be in ascending order from left to right, with an arbitrary number of non-indexed colors. If 49 | the first and/or last color are not indexed, they are assumed to be 0 (or 00) and f (or ff) respectively. 50 | 51 | Parameters 52 | ---------- 53 | source : str 54 | Source code to generate the color map. 55 | 56 | ''' 57 | _rexp = re.compile( 58 | r'(?P[0-9a-fA-F]{6})|' 59 | r'(?P[0-9a-fA-F]{3})|' 60 | r'(?P[0-9a-fA-F]{1,2})|' 61 | r'(?P:)|' 62 | r'(?P,)|' 63 | r'(?P\s+)|' 64 | r'(?P.+)' 65 | ) 66 | 67 | def __init__(self, source): 68 | self._source = None 69 | self.source = source 70 | 71 | @property 72 | def source(self) -> str: 73 | '''Source code property used to generate the color map. May be overwritten with a new string, which will be 74 | compiled to change the color map. 
75 | ''' 76 | return self._source 77 | 78 | @source.setter 79 | def source(self, value: str): 80 | '''Set source code property and re-compile the color map.''' 81 | try: 82 | tokens = self._lex(value) 83 | nodes = self._parse(tokens) 84 | self._indices, self._colors = self._make_palette(nodes) 85 | except RuntimeError as err: 86 | raise RuntimeError('Compilation of ColorMap failed!') from err 87 | 88 | self._source = value 89 | 90 | @staticmethod 91 | def _lex(string): 92 | '''Lexical scanning of cmsl using regular expressions.''' 93 | return [CMapToken(match.lastgroup, match.group(), match.start()) for match in ColorMap._rexp.finditer(string)] 94 | 95 | @staticmethod 96 | def _parse(tokens): 97 | '''Parse cmsl tokens into a list of color nodes.''' 98 | nodes = [] 99 | log = [] 100 | for token in tokens: 101 | if token.type == 'index' and not log: 102 | log.append(token) 103 | elif token.type == 'adsep' and len(log) == 1 and log[-1].type == 'index': 104 | log.append(token) 105 | elif token.type in ('shortcolor', 'longcolor'): 106 | if len(log) == 2 and log[-2].type == 'index': 107 | indval = log[-2].value 108 | if len(indval) == 1: 109 | indval = indval * 2 110 | index = int(indval, base=16) 111 | elif not log: 112 | index = None 113 | else: 114 | raise RuntimeError(f'Unexpected {token}') 115 | 116 | value_it = iter(token.value) if token.type == 'longcolor' else token.value 117 | value = [int(''.join(chars), base=16) for chars in zip(*[value_it] * 2)] 118 | nodes.append(ColorNode(index, np.array(value))) 119 | log.append(token) 120 | elif token.type == 'sep' and log and log[-1].type in ('longcolor', 'shortcolor'): 121 | log.clear() 122 | elif token.type != 'whitespace': 123 | raise RuntimeError(f'Unexpected {token}') 124 | 125 | if log and log[-1].type not in ('shortcolor', 'longcolor'): 126 | raise RuntimeError(f'Unexpected {log[-1]}') 127 | 128 | return nodes 129 | 130 | @staticmethod 131 | def _make_palette(nodes): 132 | '''Generate color map indices and colors from a list of color nodes.''' 133 | if len(nodes) < 2: 134 | raise RuntimeError("ColorMap needs at least 2 colors!") 135 | result = [] 136 | log = [] 137 | 138 | start = nodes.pop(0) 139 | result.append(ColorNode(0, start.value)) 140 | if start.index is not None and start.index > 0: 141 | result.append(start) 142 | 143 | for n, node in enumerate(nodes): 144 | if node.index is None: 145 | if n < len(nodes) - 1: 146 | log.append(node) 147 | continue 148 | node = ColorNode(255, node.value) 149 | elif node.index < result[-1].index: 150 | raise RuntimeError('ColorMap indices not ordered! Provided indices are required in ascending order.') 151 | if log: 152 | result += [ 153 | ColorNode( 154 | int(result[-1].index * (1. - alpha) + node.index * alpha), 155 | lognode.value 156 | ) for alpha, lognode in zip(np.linspace(0., 1., len(log) + 2)[1:-1], log) 157 | ] 158 | log.clear() 159 | result.append(node) 160 | 161 | result.append(ColorNode(256, result[-1].value)) 162 | 163 | indices = np.array([node.index for node in result]) 164 | colors = np.stack([node.value for node in result], axis=0) 165 | 166 | return indices, colors 167 | 168 | def __call__(self, x): 169 | '''Map scalar values in the range [0, 1] to RGB. This appends an axis with size 3 to `x`. Values are clipped to 170 | the range [0, 1], and the output will also lie in this range. 171 | 172 | Parameters 173 | ---------- 174 | x : obj:`numpy.ndarray` 175 | Input array of arbitrary shape, which will be clipped to range [0, 1], and mapped to RGB using this 176 | ColorMap. 
177 | 178 | Returns 179 | ------- 180 | obj:`numpy.ndarray` 181 | The input array `x`, clipped to [0, 1] and mapped to RGB given this colormap, where the 3 color channels 182 | are appended as a new axis to the end. 183 | ''' 184 | x = (x * 255).clip(0, 255) 185 | index = np.searchsorted(self._indices[:-1], x, side='right') 186 | alpha = ((x - self._indices[index - 1]) / (self._indices[index] - self._indices[index - 1]))[..., None] 187 | return (self._colors[index - 1] * (1 - alpha) + self._colors[index] * alpha) / 255. 188 | 189 | def palette(self, level=1.0): 190 | '''Create an 8-bit palette. 191 | 192 | Parameters 193 | ---------- 194 | level : float 195 | The level of the color map. 1.0 is default. Values below 1.0 reduce the color range, with only a single 196 | color used at value 0.0. Values above 1.0 clip the value earlier towards the maximum, with an increasingly 197 | steep transition at the center of the image. 198 | 199 | Returns 200 | ------- 201 | obj:`numpy.ndarray` 202 | The palette described by an unsigned 8-bit numpy array with 256 entries. 203 | ''' 204 | x = np.linspace(-1., 1., 256, dtype=np.float64) * level 205 | x = ((x + 1.) / 2.).clip(0., 1.) 206 | x = self(x) 207 | x = (x * 255.).round(12).clip(0., 255.).astype(np.uint8) 208 | return x 209 | 210 | 211 | class LazyColorMapCache: 212 | '''Dict-like object to store sources for colormaps, and compile and cache them lazily. 213 | 214 | Parameters 215 | ---------- 216 | sources : dict 217 | Dict containing a mapping from names to color map specification language source. 218 | ''' 219 | def __init__(self, sources): 220 | self._sources = sources 221 | self._compiled = {} 222 | 223 | def __getitem__(self, name): 224 | if name not in self._sources: 225 | raise KeyError(f'No source for key {name}.') 226 | if name not in self._compiled: 227 | self._compiled[name] = ColorMap(self._sources[name]) 228 | return self._compiled[name] 229 | 230 | def __setitem__(self, name, value): 231 | self._sources[name] = value 232 | if name in self._compiled: 233 | self._compiled[name].source = value 234 | 235 | def __delitem__(self, name): 236 | del self._sources[name] 237 | if name in self._compiled: 238 | del self._compiled[name] 239 | 240 | def __iter__(self): 241 | return iter(self._sources) 242 | 243 | def __len__(self): 244 | return len(self._sources) 245 | -------------------------------------------------------------------------------- /src/zennit/layer.py: -------------------------------------------------------------------------------- 1 | # This file is part of Zennit 2 | # Copyright (C) 2019-2021 Christopher J. Anders 3 | # 4 | # zennit/layer.py 5 | # 6 | # Zennit is free software: you can redistribute it and/or modify it under 7 | # the terms of the GNU Lesser General Public License as published by the Free 8 | # Software Foundation; either version 3 of the License, or (at your option) any 9 | # later version. 10 | # 11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT 12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for 14 | # more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this library. If not, see . 18 | '''Additional Utility Layers''' 19 | import torch 20 | 21 | 22 | class Sum(torch.nn.Module): 23 | '''Compute the sum along an axis. 24 | 25 | Parameters 26 | ---------- 27 | dim : int 28 | Dimension over which to sum.
29 | ''' 30 | def __init__(self, dim=-1): 31 | super().__init__() 32 | self.dim = dim 33 | 34 | def forward(self, input): 35 | '''Computes the sum along a dimension.''' 36 | return torch.sum(input, dim=self.dim) 37 | -------------------------------------------------------------------------------- /src/zennit/torchvision.py: -------------------------------------------------------------------------------- 1 | # This file is part of Zennit 2 | # Copyright (C) 2019-2021 Christopher J. Anders 3 | # 4 | # zennit/torchvision.py 5 | # 6 | # Zennit is free software: you can redistribute it and/or modify it under 7 | # the terms of the GNU Lesser General Public License as published by the Free 8 | # Software Foundation; either version 3 of the License, or (at your option) any 9 | # later version. 10 | # 11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT 12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for 14 | # more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this library. If not, see . 18 | '''Specialized Canonizers for models from torchvision.''' 19 | import torch 20 | from torchvision.models.resnet import Bottleneck as ResNetBottleneck, BasicBlock as ResNetBasicBlock 21 | 22 | from .canonizers import SequentialMergeBatchNorm, AttributeCanonizer, CompositeCanonizer 23 | from .layer import Sum 24 | 25 | 26 | class VGGCanonizer(SequentialMergeBatchNorm): 27 | '''Canonizer for torchvision.models.vgg* type models. This is so far identical to a SequentialMergeBatchNorm''' 28 | 29 | 30 | class ResNetBottleneckCanonizer(AttributeCanonizer): 31 | '''Canonizer specifically for Bottlenecks of torchvision.models.resnet* type models.''' 32 | def __init__(self): 33 | super().__init__(self._attribute_map) 34 | 35 | @classmethod 36 | def _attribute_map(cls, name, module): 37 | '''Create a forward function and a Sum module to overload as new attributes for module. 38 | 39 | Parameters 40 | ---------- 41 | name : string 42 | Name by which the module is identified. 43 | module : obj:`torch.nn.Module` 44 | Instance of a module. If this is a Bottleneck layer, the appropriate attributes to overload are returned. 45 | 46 | Returns 47 | ------- 48 | None or dict 49 | None if `module` is not an instance of Bottleneck, otherwise the appropriate attributes to overload onto 50 | the module instance. 
51 | ''' 52 | if isinstance(module, ResNetBottleneck): 53 | attributes = { 54 | 'forward': cls.forward.__get__(module), 55 | 'canonizer_sum': Sum(), 56 | } 57 | return attributes 58 | return None 59 | 60 | @staticmethod 61 | def forward(self, x): 62 | '''Modified Bottleneck forward for ResNet.''' 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out = torch.stack([identity, out], dim=-1) 80 | out = self.canonizer_sum(out) 81 | 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNetBasicBlockCanonizer(AttributeCanonizer): 88 | '''Canonizer specifically for BasicBlocks of torchvision.models.resnet* type models.''' 89 | def __init__(self): 90 | super().__init__(self._attribute_map) 91 | 92 | @classmethod 93 | def _attribute_map(cls, name, module): 94 | '''Create a forward function and a Sum module to overload as new attributes for module. 95 | 96 | Parameters 97 | ---------- 98 | name : string 99 | Name by which the module is identified. 100 | module : obj:`torch.nn.Module` 101 | Instance of a module. If this is a BasicBlock layer, the appropriate attributes to overload are returned. 102 | 103 | Returns 104 | ------- 105 | None or dict 106 | None if `module` is not an instance of BasicBlock, otherwise the appropriate attributes to overload onto 107 | the module instance. 108 | ''' 109 | if isinstance(module, ResNetBasicBlock): 110 | attributes = { 111 | 'forward': cls.forward.__get__(module), 112 | 'canonizer_sum': Sum(), 113 | } 114 | return attributes 115 | return None 116 | 117 | @staticmethod 118 | def forward(self, x): 119 | '''Modified BasicBlock forward for ResNet.''' 120 | identity = x 121 | 122 | out = self.conv1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | 129 | if self.downsample is not None: 130 | identity = self.downsample(x) 131 | 132 | out = torch.stack([identity, out], dim=-1) 133 | out = self.canonizer_sum(out) 134 | 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class ResNetCanonizer(CompositeCanonizer): 141 | '''Canonizer for torchvision.models.resnet* type models. This applies SequentialMergeBatchNorm, as well as 142 | add a Sum module to the Bottleneck modules and overload their forward method to use the Sum module instead of 143 | simply adding two tensors, such that forward and backward hooks may be applied.''' 144 | def __init__(self): 145 | super().__init__(( 146 | SequentialMergeBatchNorm(), 147 | ResNetBottleneckCanonizer(), 148 | ResNetBasicBlockCanonizer(), 149 | )) 150 | -------------------------------------------------------------------------------- /src/zennit/types.py: -------------------------------------------------------------------------------- 1 | # This file is part of Zennit 2 | # Copyright (C) 2019-2021 Christopher J. Anders 3 | # 4 | # zennit/types.py 5 | # 6 | # Zennit is free software: you can redistribute it and/or modify it under 7 | # the terms of the GNU Lesser General Public License as published by the Free 8 | # Software Foundation; either version 3 of the License, or (at your option) any 9 | # later version. 
10 | # 11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT 12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for 14 | # more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this library. If not, see . 18 | '''Type definitions for convenience.''' 19 | import torch 20 | 21 | 22 | class SubclassMeta(type): 23 | '''Meta class to bundle multiple subclasses.''' 24 | def __instancecheck__(cls, inst): 25 | """Implement isinstance(inst, cls) as subclasscheck.""" 26 | return cls.__subclasscheck__(type(inst)) 27 | 28 | def __subclasscheck__(cls, sub): 29 | """Implement issubclass(sub, cls) with by considering additional __subclass__ members.""" 30 | candidates = cls.__dict__.get("__subclass__", tuple()) 31 | return type.__subclasscheck__(cls, sub) or issubclass(sub, candidates) 32 | 33 | 34 | class ConvolutionTranspose(metaclass=SubclassMeta): 35 | '''Abstract base class that describes transposed convolutional modules.''' 36 | __subclass__ = ( 37 | torch.nn.modules.conv.ConvTranspose1d, 38 | torch.nn.modules.conv.ConvTranspose2d, 39 | torch.nn.modules.conv.ConvTranspose3d, 40 | ) 41 | 42 | 43 | class ConvolutionStandard(metaclass=SubclassMeta): 44 | '''Abstract base class that describes standard (forward) convolutional modules.''' 45 | __subclass__ = ( 46 | torch.nn.modules.conv.Conv1d, 47 | torch.nn.modules.conv.Conv2d, 48 | torch.nn.modules.conv.Conv3d, 49 | ) 50 | 51 | 52 | class Convolution(metaclass=SubclassMeta): 53 | '''Abstract base class that describes all convolutional modules.''' 54 | __subclass__ = ( 55 | ConvolutionStandard, 56 | ConvolutionTranspose, 57 | ) 58 | 59 | 60 | class Linear(metaclass=SubclassMeta): 61 | '''Abstract base class that describes linear modules.''' 62 | __subclass__ = ( 63 | Convolution, 64 | torch.nn.modules.linear.Linear, 65 | ) 66 | 67 | 68 | class BatchNorm(metaclass=SubclassMeta): 69 | '''Abstract base class that describes batch normalization modules.''' 70 | __subclass__ = ( 71 | torch.nn.modules.batchnorm.BatchNorm1d, 72 | torch.nn.modules.batchnorm.BatchNorm2d, 73 | torch.nn.modules.batchnorm.BatchNorm3d, 74 | ) 75 | 76 | 77 | class AvgPool(metaclass=SubclassMeta): 78 | '''Abstract base class that describes sum-pooling modules.''' 79 | __subclass__ = ( 80 | torch.nn.modules.pooling.AvgPool1d, 81 | torch.nn.modules.pooling.AvgPool2d, 82 | torch.nn.modules.pooling.AvgPool3d, 83 | torch.nn.modules.pooling.AdaptiveAvgPool1d, 84 | torch.nn.modules.pooling.AdaptiveAvgPool2d, 85 | torch.nn.modules.pooling.AdaptiveAvgPool3d, 86 | ) 87 | 88 | 89 | class MaxPool(metaclass=SubclassMeta): 90 | '''Abstract base class that describes max-pooling modules.''' 91 | __subclass__ = ( 92 | torch.nn.modules.pooling.MaxPool1d, 93 | torch.nn.modules.pooling.MaxPool2d, 94 | torch.nn.modules.pooling.MaxPool3d, 95 | torch.nn.modules.pooling.AdaptiveMaxPool1d, 96 | torch.nn.modules.pooling.AdaptiveMaxPool2d, 97 | torch.nn.modules.pooling.AdaptiveMaxPool3d, 98 | ) 99 | 100 | 101 | class Activation(metaclass=SubclassMeta): 102 | '''Abstract base class that describes activation modules.''' 103 | __subclass__ = ( 104 | torch.nn.modules.activation.ELU, 105 | torch.nn.modules.activation.Hardshrink, 106 | torch.nn.modules.activation.Hardsigmoid, 107 | torch.nn.modules.activation.Hardtanh, 108 | torch.nn.modules.activation.Hardswish, 109 | torch.nn.modules.activation.LeakyReLU, 110 | 
torch.nn.modules.activation.LogSigmoid, 111 | torch.nn.modules.activation.PReLU, 112 | torch.nn.modules.activation.ReLU, 113 | torch.nn.modules.activation.ReLU6, 114 | torch.nn.modules.activation.RReLU, 115 | torch.nn.modules.activation.SELU, 116 | torch.nn.modules.activation.CELU, 117 | torch.nn.modules.activation.GELU, 118 | torch.nn.modules.activation.Sigmoid, 119 | torch.nn.modules.activation.SiLU, 120 | torch.nn.modules.activation.Softplus, 121 | torch.nn.modules.activation.Softshrink, 122 | torch.nn.modules.activation.Softsign, 123 | torch.nn.modules.activation.Tanh, 124 | torch.nn.modules.activation.Tanhshrink, 125 | torch.nn.modules.activation.Threshold, 126 | ) 127 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | '''Configuration and fixtures for testing''' 2 | import random 3 | from itertools import product, groupby 4 | from collections import OrderedDict 5 | 6 | import pytest 7 | import torch 8 | from torch.nn import Conv1d, ConvTranspose1d, Linear 9 | from torch.nn import Conv2d, ConvTranspose2d 10 | from torch.nn import Conv3d, ConvTranspose3d 11 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d 12 | from torchvision.models import vgg11, resnet18, alexnet 13 | from helpers import prodict, one_hot_max 14 | 15 | from zennit.attribution import identity 16 | from zennit.core import Composite, Hook 17 | from zennit.composites import COMPOSITES 18 | from zennit.composites import EpsilonGammaBox 19 | from zennit.composites import LayerMapComposite 20 | from zennit.composites import MixedComposite 21 | from zennit.composites import NameLayerMapComposite 22 | from zennit.composites import NameMapComposite 23 | from zennit.composites import SpecialFirstLayerMapComposite 24 | from zennit.types import Linear as AnyLinear, Activation 25 | 26 | 27 | def pytest_addoption(parser): 28 | '''Add options to pytest.''' 29 | parser.addoption( 30 | '--batchsize', 31 | default=4, 32 | help='Batch-size for generated samples.' 
33 | ) 34 | 35 | 36 | def pytest_generate_tests(metafunc): 37 | '''Generate test fixture values based on CLI options.''' 38 | if 'batchsize' in metafunc.fixturenames: 39 | metafunc.parametrize('batchsize', [metafunc.config.getoption('batchsize')], scope='session') 40 | 41 | 42 | @pytest.fixture( 43 | scope='session', 44 | params=[ 45 | 0xdeadbeef, 46 | 0xd0c0ffee, 47 | *[pytest.param(seed, marks=pytest.mark.extended) for seed in [ 48 | 0xc001bee5, 0xc01dfee7, 0xbe577001, 0xca7b0075, 0x1057b0a7, 0x900ddeed 49 | ]], 50 | ], 51 | ids=hex 52 | ) 53 | def rng(request): 54 | '''Fixture for the NumPy random number generator.''' 55 | return torch.manual_seed(request.param) 56 | 57 | 58 | @pytest.fixture(scope='session') 59 | def pyrng(rng): 60 | '''Fixture for the Python random number generator.''' 61 | return random.Random(rng.initial_seed()) 62 | 63 | 64 | @pytest.fixture( 65 | scope='session', 66 | params=[ 67 | (torch.nn.ReLU, {}), 68 | (torch.nn.Softmax, {'dim': 1}), 69 | (torch.nn.Tanh, {}), 70 | (torch.nn.Sigmoid, {}), 71 | (torch.nn.Softplus, {'beta': 1}), 72 | ], 73 | ids=lambda param: param[0].__name__ 74 | ) 75 | def module_simple(rng, request): 76 | '''Fixture for simple modules.''' 77 | module_type, kwargs = request.param 78 | return module_type(**kwargs).to(torch.float64).eval() 79 | 80 | 81 | @pytest.fixture( 82 | scope='session', 83 | params=[ 84 | *product( 85 | [Linear], 86 | prodict(in_features=[16], out_features=[15], bias=[True, False]), 87 | ), 88 | *product( 89 | [Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d], 90 | prodict(in_channels=[1, 3], out_channels=[1, 3], kernel_size=[2, 3], bias=[True, False]), 91 | ), 92 | ], 93 | ids=lambda param: param[0].__name__ 94 | ) 95 | def module_linear(rng, request): 96 | '''Fixture for linear modules.''' 97 | module_type, kwargs = request.param 98 | return module_type(**kwargs).to(torch.float64).eval() 99 | 100 | 101 | @pytest.fixture(scope='session') 102 | def module_batchnorm(module_linear, rng): 103 | '''Fixture for BatchNorm-type modules, based on adjacent linear module.''' 104 | module_map = [ 105 | ((Linear, Conv1d, ConvTranspose1d), BatchNorm1d), 106 | ((Conv2d, ConvTranspose2d), BatchNorm2d), 107 | ((Conv3d, ConvTranspose3d), BatchNorm3d), 108 | ] 109 | feature_index_map = [ 110 | ((ConvTranspose1d, ConvTranspose2d, ConvTranspose3d), 1), 111 | ((Linear, Conv1d, Conv2d, Conv3d), 0), 112 | ] 113 | 114 | batchnorm_type = None 115 | for types, target_type in module_map: 116 | if isinstance(module_linear, types): 117 | batchnorm_type = target_type 118 | break 119 | if batchnorm_type is None: 120 | raise RuntimeError('No batchnorm type for linear layer found.') 121 | 122 | feature_index = None 123 | for types, index in feature_index_map: 124 | if isinstance(module_linear, types): 125 | feature_index = index 126 | break 127 | if feature_index is None: 128 | raise RuntimeError('No feature index for linear layer found.') 129 | 130 | batchnorm = batchnorm_type(num_features=module_linear.weight.shape[feature_index]).to(torch.float64).eval() 131 | batchnorm.weight.data.uniform_(**{'from': 0.1, 'to': 2.0, 'generator': rng}) 132 | batchnorm.bias.data.normal_(generator=rng) 133 | batchnorm.eps = 1e-30 134 | return batchnorm 135 | 136 | 137 | @pytest.fixture(scope='session') 138 | def data_linear(rng, batchsize, module_linear): 139 | '''Fixture to create data for a linear module, given an RNG.''' 140 | shape = (batchsize,) 141 | setups = [ 142 | (Conv1d, 1, 1), 143 | (ConvTranspose1d, 0, 1), 144 | (Conv2d, 1, 2), 145 
| (ConvTranspose2d, 0, 2), 146 | (Conv3d, 1, 3), 147 | (ConvTranspose3d, 0, 3) 148 | ] 149 | if isinstance(module_linear, Linear): 150 | shape += (module_linear.weight.shape[1],) 151 | else: 152 | for module_type, dim, ndims in setups: 153 | if isinstance(module_linear, module_type): 154 | shape += (module_linear.weight.shape[dim],) + (4,) * ndims 155 | 156 | return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng) 157 | 158 | 159 | @pytest.fixture(scope='session', params=[ 160 | (16,), 161 | (4,), 162 | (4, 4), 163 | (4, 4, 4), 164 | ]) 165 | def data_simple(request, rng, batchsize): 166 | '''Fixture to create data for a linear module, given an RNG.''' 167 | shape = (batchsize,) + request.param 168 | return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng) 169 | 170 | 171 | COMPOSITE_KWARGS = { 172 | EpsilonGammaBox: {'low': -3., 'high': 3.}, 173 | } 174 | 175 | 176 | class PassClone(Hook): 177 | '''Clone of the Pass rule.''' 178 | def backward(self, module, grad_input, grad_output): 179 | '''Directly return grad_output.''' 180 | return grad_output 181 | 182 | 183 | class GradClone(Hook): 184 | '''Explicit rule to return the cloned gradient.''' 185 | def backward(self, module, grad_input, grad_output): 186 | '''Directly return grad_output.''' 187 | return grad_input.clone() 188 | 189 | 190 | @pytest.fixture(scope='session', params=[ 191 | None, 192 | [(Linear, GradClone()), (Activation, PassClone())], 193 | ]) 194 | def cooperative_layer_map(request): 195 | '''Fixture for a cooperative layer map in LayerMapComposite subtypes.''' 196 | return request.param 197 | 198 | 199 | @pytest.fixture(scope='session', params=[ 200 | None, 201 | [(AnyLinear, GradClone())], 202 | ]) 203 | def cooperative_first_map(request): 204 | '''Fixture for a cooperative layer map for the first layer in SpecialFirstLayerMapComposite subtypes.''' 205 | return request.param 206 | 207 | 208 | @pytest.fixture(scope='session', params=[ 209 | elem for elem in COMPOSITES.values() 210 | if issubclass(elem, LayerMapComposite) and not issubclass(elem, SpecialFirstLayerMapComposite) 211 | ]) 212 | def layer_map_composite(request, cooperative_layer_map): 213 | '''Fixture for explicit LayerMapComposites.''' 214 | return request.param(layer_map=cooperative_layer_map, **COMPOSITE_KWARGS.get(request.param, {})) 215 | 216 | 217 | @pytest.fixture(scope='session', params=[ 218 | elem for elem in COMPOSITES.values() if issubclass(elem, SpecialFirstLayerMapComposite) 219 | ]) 220 | def special_first_layer_map_composite(request, cooperative_layer_map, cooperative_first_map): 221 | '''Fixturer for explicit SpecialFirstLayerMapComposites.''' 222 | return request.param( 223 | layer_map=cooperative_layer_map, 224 | first_map=cooperative_first_map, 225 | **COMPOSITE_KWARGS.get(request.param, {}) 226 | ) 227 | 228 | 229 | @pytest.fixture(scope='session', params=[Composite, *COMPOSITES.values()]) 230 | def any_composite(request): 231 | '''Fixture for all explicitly registered Composites, as well as the empty Composite.''' 232 | return request.param(**COMPOSITE_KWARGS.get(request.param, {})) 233 | 234 | 235 | @pytest.fixture(scope='session') 236 | def name_map_composite(model_vision, layer_map_composite): 237 | '''Fixture to create NameMapComposites based on explicit LayerMapComposites.''' 238 | rule_map = {} 239 | for name, child in model_vision.named_modules(): 240 | for dtype, hook_template in layer_map_composite.layer_map: 241 | if isinstance(child, dtype): 242 | rule_map.setdefault(hook_template, []).append(name) 
243 | break 244 | name_map = [(tuple(value), key) for key, value in rule_map.items()] 245 | return NameMapComposite(name_map=name_map) 246 | 247 | 248 | @pytest.fixture(scope='session') 249 | def partial_name_map_composite(name_map_composite, pyrng): 250 | '''Fixture to create a randomly sampled partial NameMapComposites.''' 251 | name_map = name_map_composite.name_map 252 | assocs = [(i, j) for i, (keys, _) in enumerate(name_map) for j in range(len(keys))] 253 | accepted_assocs = sorted(pyrng.sample(assocs, len(assocs) // 2)) 254 | partial_name_map = [ 255 | (tuple(name_map[k][0][n] for _, n in g), name_map[k][1].copy()) 256 | for k, g in groupby(accepted_assocs, lambda o: o[0]) 257 | ] 258 | 259 | return NameMapComposite(name_map=partial_name_map) 260 | 261 | 262 | @pytest.fixture(scope='session') 263 | def mixed_composite(partial_name_map_composite, special_first_layer_map_composite): 264 | '''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and 265 | SpecialFirstLayerMapComposites. 266 | ''' 267 | composites = [partial_name_map_composite, special_first_layer_map_composite] 268 | return MixedComposite(composites) 269 | 270 | 271 | @pytest.fixture(scope='session') 272 | def name_layer_map_composite(partial_name_map_composite, layer_map_composite): 273 | '''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and LayerMapComposite.''' 274 | return NameLayerMapComposite( 275 | name_map=partial_name_map_composite.name_map, 276 | layer_map=layer_map_composite.layer_map, 277 | ) 278 | 279 | 280 | @pytest.fixture(scope='session', params=[alexnet, vgg11, resnet18]) 281 | def model_vision(request): 282 | '''Models to test composites on.''' 283 | return request.param() 284 | 285 | 286 | @pytest.fixture(scope='session') 287 | def model_simple(rng, module_linear, data_linear): 288 | '''Fixture for a simple model, using a linear module followed by a ReLU and a dense layer.''' 289 | with torch.no_grad(): 290 | intermediate = module_linear(data_linear) 291 | return torch.nn.Sequential(OrderedDict([ 292 | ('linr0', module_linear), 293 | ('actv0', torch.nn.ReLU()), 294 | ('flat0', torch.nn.Flatten()), 295 | ('linr1', torch.nn.Linear(intermediate.shape[1:].numel(), 4, dtype=intermediate.dtype)), 296 | ])) 297 | 298 | 299 | @pytest.fixture(scope='session') 300 | def model_simple_grad(data_linear, model_simple): 301 | '''Fixture for gradient wrt. 
data_linear for model_simple.''' 302 | data = data_linear.detach().requires_grad_() 303 | output = model_simple(data) 304 | grad, = torch.autograd.grad(output, data, output) 305 | return grad 306 | 307 | 308 | @pytest.fixture(scope='session') 309 | def model_simple_output(data_linear, model_simple): 310 | '''Fixture for output given data_linear for model_simple.''' 311 | data = data_linear.detach() 312 | output = model_simple(data) 313 | return output 314 | 315 | 316 | @pytest.fixture(scope='session', params=[ 317 | identity, 318 | one_hot_max, 319 | torch.ones_like, 320 | ]) 321 | def grad_outputs_func(request): 322 | '''Fixture for common attr_output_fn functions.''' 323 | return request.param 324 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | '''Helper functions for various tests.''' 2 | from itertools import product 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from zennit.types import BatchNorm 8 | 9 | 10 | def prodict(**kwargs): 11 | '''Create a dictionary with values which are the cartesian product of the input keyword arguments.''' 12 | return [dict(zip(kwargs, val)) for val in product(*kwargs.values())] 13 | 14 | 15 | def one_hot_max(output): 16 | '''Get the one-hot encoded max.''' 17 | return torch.sparse_coo_tensor( 18 | [*zip(np.unravel_index(output.argmax(), output.shape))], [1.], output.shape, dtype=output.dtype 19 | ).to_dense() 20 | 21 | 22 | def assert_identity_hook(equal=True, message=''): 23 | '''Create an assertion hook which checks whether the module does or does not modify its input.''' 24 | def assert_identity(module, input, output): 25 | '''Assert whether the module does or does not modify its input.''' 26 | assert equal == torch.allclose(input[0], output, rtol=1e-5), message 27 | return assert_identity 28 | 29 | 30 | def randomize_bnorm(model): 31 | '''Randomize all BatchNorm module parameters of a model.''' 32 | for module in model.modules(): 33 | if isinstance(module, BatchNorm): 34 | module.weight.data.uniform_(0.1, 2.0) 35 | module.running_var.data.uniform_(0.1, 2.0) 36 | module.bias.data.normal_() 37 | module.running_mean.data.normal_() 38 | # smaller eps to reduce error 39 | module.eps = 1e-30 40 | return model 41 | 42 | 43 | def nograd(model): 44 | '''Unset grad requirement for all model parameters.''' 45 | for param in model.parameters(): 46 | param.requires_grad = False 47 | return model 48 | -------------------------------------------------------------------------------- /tests/test_attribution.py: -------------------------------------------------------------------------------- 1 | '''Tests for Attributors.''' 2 | from functools import partial 3 | from itertools import product 4 | 5 | import pytest 6 | import torch 7 | 8 | from zennit.attribution import Gradient, IntegratedGradients, SmoothGrad, Occlusion, occlude_independent 9 | 10 | 11 | class IdentityLogger(torch.nn.Module): 12 | '''Helper-Module to log input tensors.''' 13 | def __init__(self): 14 | super().__init__() 15 | self.tensors = [] 16 | 17 | def forward(self, input): 18 | '''Clone input, append to self.tensors and return the cloned tensor.''' 19 | self.tensors.append(input.clone()) 20 | return self.tensors[-1] 21 | 22 | 23 | def test_gradient_attributor_inactive( 24 | data_linear, model_simple, model_simple_output, any_composite, grad_outputs_func 25 | ): 26 | '''Test whether composite context and attributor match for Gradient.''' 27 | 28 | with 
Gradient(model=model_simple, composite=any_composite, attr_output=grad_outputs_func) as attributor: 29 | # verify that all hooks are active 30 | assert all(hook.active for hook in attributor.composite.hook_refs) 31 | with attributor.inactive(): 32 | # verify that all hooks are inactive 33 | assert all(not hook.active for hook in attributor.composite.hook_refs) 34 | 35 | 36 | def test_gradient_attributor_composite( 37 | data_linear, model_simple, model_simple_output, any_composite, grad_outputs_func 38 | ): 39 | '''Test whether composite context and attributor match for Gradient.''' 40 | with any_composite.context(model_simple) as module: 41 | data = data_linear.detach().requires_grad_() 42 | output_context = module(data) 43 | grad_outputs = grad_outputs_func(output_context) 44 | grad_context, = torch.autograd.grad(output_context, data, grad_outputs) 45 | 46 | with Gradient(model=model_simple, composite=any_composite, attr_output=grad_outputs_func) as attributor: 47 | output_attributor, grad_attributor = attributor(data_linear) 48 | 49 | assert torch.allclose(output_context, output_attributor) 50 | assert torch.allclose(grad_context, grad_attributor) 51 | assert torch.allclose(model_simple_output, output_attributor) 52 | 53 | 54 | @pytest.mark.parametrize('use_const,use_call,use_init', product(*[[True, False]] * 3)) 55 | def test_gradient_attributor_output_fn(data_simple, grad_outputs_func, use_const, use_call, use_init): 56 | '''Test whether attributors' attr_output supports functions, constants and None in any of supplied or not supplied 57 | for each the attributor initialization and the call. 58 | ''' 59 | model = IdentityLogger() 60 | 61 | attr_output = grad_outputs_func(data_simple) if use_const else grad_outputs_func 62 | init_attr_output = attr_output if use_init else None 63 | call_attr_output = attr_output if use_call else None 64 | 65 | with Gradient(model=model, attr_output=init_attr_output) as attributor: 66 | _, grad = attributor(data_simple, attr_output=call_attr_output) 67 | 68 | if (use_call or use_init): 69 | expected_grad = grad_outputs_func(data_simple) 70 | else: 71 | # the identity is the default attr_output 72 | expected_grad = data_simple 73 | 74 | assert torch.allclose(expected_grad, grad), 'Attributor output function gradient mismatch!' 75 | 76 | 77 | def test_gradient_attributor_grad(data_simple): 78 | '''Test whether the gradient of Gradient matches.''' 79 | model = torch.nn.Softplus(beta=1.) 80 | data = data_simple.view_as(data_simple).requires_grad_() 81 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data)) 82 | 83 | with Gradient(model=model, create_graph=True) as attributor: 84 | _, grad = attributor(data, torch.ones_like) 85 | gradgrad, = torch.autograd.grad(grad.sum(), data) 86 | 87 | assert torch.allclose(gradgrad, target), 'Gradient Attributor second order gradient mismatch!' 88 | 89 | 90 | def test_gradient_attributor_output_fn_precedence(data_simple): 91 | '''Test whether the gradient attributor attr_output at call is preferred when it is supplied at both initialization 92 | and call. 93 | ''' 94 | model = IdentityLogger() 95 | 96 | init_attr_output = torch.ones_like 97 | call_attr_output = torch.zeros_like 98 | 99 | with Gradient(model=model, attr_output=init_attr_output) as attributor: 100 | _, grad = attributor(data_simple, attr_output=call_attr_output) 101 | 102 | expected_grad = call_attr_output(data_simple) 103 | assert torch.allclose(expected_grad, grad), 'Attributor output function precedence mismatch!' 
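The gradient attributor tests above all exercise the same basic usage pattern: build an attributor around a model, optionally together with a composite, enter its context, and call it on the input to obtain both the model output and the attribution. The following is a minimal sketch of that pattern outside of the test fixtures; the model, the input shape, and the choice of EpsilonGammaBox are illustrative assumptions only, not part of the test suite.

```
import torch

from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox

# illustrative model and input; any torch.nn.Module with a matching input works
model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
data = torch.randn(2, 16)

# low/high bound the input domain, as in COMPOSITE_KWARGS in conftest.py above
composite = EpsilonGammaBox(low=-3., high=3.)

with Gradient(model=model, composite=composite, attr_output=torch.ones_like) as attributor:
    # output is the unmodified forward pass, relevance the modified gradient wrt. the input
    output, relevance = attributor(data)
```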
104 | 105 | 106 | def test_smooth_grad_single(data_linear, model_simple, model_simple_output, model_simple_grad): 107 | '''Test whether SmoothGrad with a single iteration is equal to the gradient.''' 108 | with SmoothGrad(model=model_simple, noise_level=0.1, n_iter=1) as attributor: 109 | output, grad = attributor(data_linear) 110 | 111 | assert torch.allclose(model_simple_grad, grad) 112 | assert torch.allclose(model_simple_output, output) 113 | 114 | 115 | def test_smooth_grad_single_grad(data_simple): 116 | '''Test whether the gradient of SmoothGrad matches.''' 117 | model = torch.nn.Softplus(beta=1.) 118 | data = data_simple.view_as(data_simple).requires_grad_() 119 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data)) 120 | 121 | with SmoothGrad(model=model, noise_level=0.1, n_iter=1, create_graph=True) as attributor: 122 | _, grad = attributor(data, torch.ones_like) 123 | gradgrad, = torch.autograd.grad(grad.sum(), data) 124 | 125 | assert torch.allclose(gradgrad, target), 'SmoothGrad Attributor second order gradient mismatch!' 126 | 127 | 128 | @pytest.mark.parametrize('noise_level', [0.0, 0.1, 0.3, 0.5]) 129 | def test_smooth_grad_distribution(data_simple, noise_level): 130 | '''Test whether the SmoothGrad sampled distribution matches.''' 131 | model = IdentityLogger() 132 | 133 | dims = tuple(range(1, data_simple.ndim)) 134 | noise_var = (noise_level * (data_simple.amax(dims) - data_simple.amin(dims))) ** 2 135 | n_iter = 100 136 | 137 | with SmoothGrad(model=model, noise_level=noise_level, n_iter=n_iter, attr_output=torch.ones_like) as attributor: 138 | _, grad = attributor(data_simple) 139 | 140 | assert len(model.tensors) == n_iter, 'SmoothGrad iterations did not match n_iter!' 141 | 142 | sample_mean = sum(model.tensors) / len(model.tensors) 143 | sample_var = ((sum((tensor - sample_mean) ** 2 for tensor in model.tensors) / len(model.tensors))).mean(dims) 144 | 145 | if noise_level > 0.: 146 | std_ratio = (sample_var / noise_var) ** .5 147 | assert (std_ratio < 1.5).all().item(), 'SmoothGrad sample variance is too high!' 148 | assert (std_ratio > 0.667).all().item(), 'SmoothGrad sample variance is too low!' 149 | else: 150 | assert (sample_var < 1e-9).all().item(), 'SmoothGrad sample variance is too high!' 151 | assert torch.allclose(grad, torch.ones_like(data_simple)), 'SmoothGrad of identity is wrong!' 152 | 153 | 154 | @pytest.mark.parametrize('baseline_fn', [None, torch.zeros_like, torch.ones_like]) 155 | def test_integrated_gradients_single(data_linear, model_simple, model_simple_output, model_simple_grad, baseline_fn): 156 | '''Test whether IntegratedGradients with a single iteration is equal to the expected output given multiple 157 | baselines. 158 | ''' 159 | with IntegratedGradients(model=model_simple, n_iter=1, baseline_fn=baseline_fn) as attributor: 160 | output, grad = attributor(data_linear) 161 | 162 | if baseline_fn is None: 163 | baseline_fn = torch.zeros_like 164 | expected_grad = model_simple_grad * (data_linear - baseline_fn(data_linear)) 165 | 166 | assert torch.allclose(expected_grad, grad), 'Gradient mismatch for IntegratedGradients!' 167 | assert torch.allclose(model_simple_output, output), 'Output mismatch for IntegratedGradients!' 168 | 169 | 170 | def test_integrated_gradients_single_grad(data_simple): 171 | '''Test whether the gradient of IntegratedGradients matches.''' 172 | model = torch.nn.Softplus(beta=1.) 173 | data = data_simple.view_as(data_simple).requires_grad_() 174 | # this is d/dx (x * d/dx softplus(x)), i.e. 
the gradient of input times gradient of softplus 175 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data)) * data + torch.sigmoid(data) 176 | 177 | with IntegratedGradients(model=model, n_iter=1, baseline_fn=torch.zeros_like, create_graph=True) as attributor: 178 | _, grad = attributor(data, torch.ones_like) 179 | gradgrad, = torch.autograd.grad(grad.sum(), data) 180 | 181 | assert torch.allclose(gradgrad, target), 'IntegratedGradients Attributor second order gradient mismatch!' 182 | 183 | 184 | def test_integrated_gradients_path(data_simple): 185 | '''Test whether IntegratedGradients with a single iteration and a zero-baseline is equal to the input times the 186 | gradient. 187 | ''' 188 | model = IdentityLogger() 189 | 190 | dims = tuple(range(1, data_simple.ndim)) 191 | n_iter = 100 192 | with IntegratedGradients(model=model, n_iter=n_iter, attr_output=torch.ones_like) as attributor: 193 | _, grad = attributor(data_simple) 194 | 195 | assert len(model.tensors) == n_iter, 'IntegratedGradients iterations did not match n_iter!' 196 | 197 | data_simple_norm = data_simple / (data_simple ** 2).sum(dim=dims, keepdim=True) ** .5 198 | assert all( 199 | torch.allclose(step / (step ** 2).sum(dim=dims, keepdim=True) ** .5, data_simple_norm) 200 | for step in model.tensors 201 | ), 'IntegratedGradients segments do not lie on path!' 202 | assert torch.allclose(data_simple, grad), 'IntegratedGradients of identity is wrong!' 203 | 204 | 205 | @pytest.mark.parametrize('window,stride', zip([1, 2, 4, (1,), (2,), (4,)], [1, 2, 4, (1,), (2,), (4,)])) 206 | def test_occlusion_disjunct(data_simple, window, stride): 207 | '''Function to test whether the inputs used for disjunct occlusion windows are correct.''' 208 | model = IdentityLogger() 209 | 210 | # delete everything except the window 211 | occlusion_fn = partial(occlude_independent, fill_fn=torch.zeros_like, invert=False) 212 | 213 | with Occlusion(model=model, window=window, stride=stride, occlusion_fn=occlusion_fn) as attributor: 214 | attributor(data_simple) 215 | 216 | # omit final pass for full output 217 | reconstruct = sum(model.tensors[:-1]) 218 | assert torch.allclose(data_simple, reconstruct), 'Disjunct occlusion does not sum to original input!' 
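The occlusion tests above also show the intended usage of the Occlusion attributor: a window and stride define the patches, and an occlusion function built from occlude_independent decides how each patch is perturbed. Below is a minimal sketch of that pattern; the model and input shapes are made up for illustration.

```
from functools import partial

import torch

from zennit.attribution import Occlusion, occlude_independent

# illustrative model and input
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16 * 16, 4))
data = torch.randn(2, 16, 16)

# as in the disjunct test above: with invert=False, everything outside the current
# window is filled with zeros; invert=True would instead occlude the window itself
occlusion_fn = partial(occlude_independent, fill_fn=torch.zeros_like, invert=False)

with Occlusion(
    model=model, window=4, stride=4, attr_output=torch.ones_like, occlusion_fn=occlusion_fn
) as attributor:
    output, score = attributor(data)
```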
219 | 220 | 221 | @pytest.mark.parametrize( 222 | 'fill_fn,invert', [ 223 | (None, False), 224 | (torch.zeros_like, False), 225 | (torch.zeros_like, True), 226 | (torch.ones_like, True), 227 | ] 228 | ) 229 | def test_occlusion_single(data_linear, model_simple, model_simple_output, grad_outputs_func, fill_fn, invert): 230 | '''Function to test whether the inputs used for a full occlusion window are correct.''' 231 | window, stride = [data_linear.shape] * 2 232 | if fill_fn is None: 233 | # setting when no occlusion_fn is supplied 234 | occlusion_fn = None 235 | fill_fn = torch.zeros_like 236 | else: 237 | occlusion_fn = partial(occlude_independent, fill_fn=fill_fn, invert=invert) 238 | 239 | identity_logger = IdentityLogger() 240 | model = torch.nn.Sequential(identity_logger, model_simple) 241 | 242 | with Occlusion( 243 | model=model, 244 | window=window, 245 | stride=stride, 246 | attr_output=grad_outputs_func, 247 | occlusion_fn=occlusion_fn, 248 | ) as attributor: 249 | output, score = attributor(data_linear) 250 | 251 | expected_occluded = fill_fn(data_linear) if invert else data_linear 252 | expected_output = model_simple(expected_occluded) 253 | expected_score = grad_outputs_func(expected_output).sum( 254 | tuple(range(1, expected_output.ndim)) 255 | )[(slice(None),) + (None,) * (data_linear.ndim - 1)].expand_as(data_linear) 256 | 257 | assert len(identity_logger.tensors) == 2, 'Incorrect number of forward passes for Occlusion!' 258 | assert torch.allclose(identity_logger.tensors[0], expected_occluded), 'Occluded input mismatch!' 259 | assert torch.allclose(model_simple_output, output), 'Output mismatch for Occlusion!' 260 | assert torch.allclose(expected_score, score), 'Scores are incorrect for Occlusion!' 261 | 262 | 263 | @pytest.mark.parametrize('argument,container', product( 264 | ['window', 'stride'], 265 | ['monkey', {3}, ('you', 'are', 'breathtaking'), range(3), [3]] 266 | )) 267 | def test_occlusion_stride_window_typecheck(argument, container): 268 | '''Test whether Occlusion raises a TypeError on incorrect types for window and stride.''' 269 | with pytest.raises(TypeError): 270 | Occlusion(model=None, **{argument: container}) 271 | -------------------------------------------------------------------------------- /tests/test_canonizers.py: -------------------------------------------------------------------------------- 1 | '''Tests for canonizers''' 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import pytest 6 | import torch 7 | from torch.nn import Sequential 8 | from helpers import assert_identity_hook 9 | 10 | from zennit.canonizers import Canonizer, CompositeCanonizer 11 | from zennit.canonizers import SequentialMergeBatchNorm, NamedMergeBatchNorm, AttributeCanonizer 12 | from zennit.core import RemovableHandleList 13 | from zennit.types import BatchNorm 14 | 15 | 16 | def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_linear): 17 | '''Test whether the output of the merged batchnorm is consistent with its original output.''' 18 | output_linear_before = module_linear(data_linear) 19 | output_batchnorm_before = module_batchnorm(output_linear_before) 20 | canonizer = SequentialMergeBatchNorm() 21 | 22 | try: 23 | canonizer.register((module_linear,), module_batchnorm) 24 | output_linear_canonizer = module_linear(data_linear) 25 | output_batchnorm_canonizer = module_batchnorm(output_linear_canonizer) 26 | finally: 27 | canonizer.remove() 28 | 29 | output_linear_after = module_linear(data_linear) 30 | 
output_batchnorm_after = module_batchnorm(output_linear_after) 31 | 32 | assert all(torch.allclose(left, right, atol=1e-5) for left, right in [ 33 | (output_linear_before, output_linear_after), 34 | (output_batchnorm_before, output_batchnorm_after), 35 | (output_batchnorm_before, output_linear_canonizer), 36 | (output_linear_canonizer, output_batchnorm_canonizer), 37 | ]) 38 | 39 | 40 | @pytest.mark.parametrize('canonizer_fn', [ 41 | SequentialMergeBatchNorm, 42 | partial(NamedMergeBatchNorm, [(['dense0'], 'bnorm0')]), 43 | ]) 44 | def test_merge_batchnorm_apply(canonizer_fn, module_linear, module_batchnorm, data_linear): 45 | '''Test whether SequentialMergeBatchNorm merges BatchNorm modules correctly and keeps the output unchanged.''' 46 | model = Sequential(OrderedDict([ 47 | ('dense0', module_linear), 48 | ('bnorm0', module_batchnorm) 49 | ])) 50 | output_before = model(data_linear) 51 | 52 | handles = RemovableHandleList( 53 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!')) 54 | for module in model.modules() if isinstance(module, BatchNorm) 55 | ) 56 | 57 | canonizer = canonizer_fn() 58 | 59 | canonizer_handles = RemovableHandleList(canonizer.apply(model)) 60 | try: 61 | output_canonizer = model(data_linear) 62 | finally: 63 | handles.remove() 64 | canonizer_handles.remove() 65 | 66 | handles = RemovableHandleList( 67 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!')) 68 | for module in model.modules() if isinstance(module, BatchNorm) 69 | ) 70 | 71 | try: 72 | output_after = model(data_linear) 73 | finally: 74 | handles.remove() 75 | 76 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!' 77 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!' 78 | 79 | 80 | def test_attribute_canonizer(module_linear, data_linear): 81 | '''Test whether AttributeCanonizer overwrites and restores a linear module's forward correctly. ''' 82 | model = Sequential(OrderedDict([ 83 | ('dense0', module_linear), 84 | ])) 85 | output_before = model(data_linear) 86 | 87 | modules = [module_linear] 88 | module_type = type(module_linear) 89 | 90 | assert all( 91 | module.forward == module_type.forward.__get__(module) for module in modules 92 | ), 'Model has its forward already overwritten!' 93 | 94 | def attribute_map(name, module): 95 | if module is module_linear: 96 | return {'forward': lambda x: module_type.forward.__get__(module)(x) * 2} 97 | return None 98 | 99 | canonizer = AttributeCanonizer(attribute_map) 100 | 101 | handles = RemovableHandleList(canonizer.apply(model)) 102 | try: 103 | assert not any( 104 | module.forward == module_type.forward.__get__(module) for module in modules 105 | ), 'Model forward was not overwritten!' 106 | output_canonizer = model(data_linear) 107 | finally: 108 | handles.remove() 109 | 110 | output_after = model(data_linear) 111 | 112 | assert all( 113 | module.forward == module_type.forward.__get__(module) for module in modules 114 | ), 'Model forward was not restored!' 115 | assert torch.allclose(output_canonizer, output_before * 2, rtol=1e-5), 'Canonizer output mismatch after register!' 116 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!' 
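The canonizer tests above follow the pattern for applying a canonizer manually: Canonizer.apply returns one instance per match, and wrapping these in a RemovableHandleList makes it easy to undo all modifications afterwards. A minimal sketch with an assumed toy model (the layer sizes are arbitrary):

```
import torch

from zennit.canonizers import SequentialMergeBatchNorm
from zennit.core import RemovableHandleList

# illustrative model with a convolution directly followed by a batch norm
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()

canonizer = SequentialMergeBatchNorm()
handles = RemovableHandleList(canonizer.apply(model))
try:
    # while the handles are alive, the batch norm parameters are merged into the convolution
    output = model(torch.randn(1, 3, 8, 8))
finally:
    # restore the original parameters
    handles.remove()
```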
117 | 118 | 119 | def test_composite_canonizer(): 120 | '''Test whether CompositeCanonizer correctly combines two AttributeCanonizer canonizers.''' 121 | module_vanilla = torch.nn.Module() 122 | model = torch.nn.Sequential(module_vanilla) 123 | 124 | canonizer = CompositeCanonizer([ 125 | AttributeCanonizer(lambda name, module: {'_test_x': 13}), 126 | AttributeCanonizer(lambda name, module: {'_test_y': 13}), 127 | ]) 128 | 129 | handles = RemovableHandleList(canonizer.apply(model)) 130 | try: 131 | assert hasattr(module_vanilla, '_test_x'), 'Model attribute _test_x was not overwritten!' 132 | assert hasattr(module_vanilla, '_test_y'), 'Model attribute _test_y was not overwritten!' 133 | finally: 134 | handles.remove() 135 | 136 | assert not hasattr(module_vanilla, '_test_x'), 'Model attribute _test_x was not removed!' 137 | assert not hasattr(module_vanilla, '_test_y'), 'Model attribute _test_y was not removed!' 138 | 139 | 140 | def test_base_canonizer_cooperative_apply(): 141 | '''Test whether Canonizer's apply method is cooperative.''' 142 | 143 | class DummyCanonizer(Canonizer): 144 | '''Class to test Canonizer's cooperative apply.''' 145 | def apply(self, root_module): 146 | '''Cooperative apply which appends a string 'dummy' to the result of the parent class.''' 147 | instances = super().apply(root_module) 148 | instances += ['dummy'] 149 | return instances 150 | 151 | def register(self): 152 | '''No-op register for abstract method.''' 153 | 154 | def remove(self): 155 | '''No-op remove for abstract method.''' 156 | 157 | canonizer = DummyCanonizer() 158 | model = Sequential() 159 | instances = canonizer.apply(model) 160 | assert 'dummy' in instances, 'Unexpected canonizer instance list!' 161 | -------------------------------------------------------------------------------- /tests/test_cmap.py: -------------------------------------------------------------------------------- 1 | '''Tests for ColorMap and CMSL.''' 2 | from typing import NamedTuple 3 | import pytest 4 | import numpy as np 5 | 6 | from zennit.cmap import ColorMap, LazyColorMapCache 7 | 8 | 9 | class CMapExample(NamedTuple): 10 | '''Named tuple for example color maps used in tests.''' 11 | source: str 12 | nodes: list 13 | 14 | 15 | CMAPS = [ 16 | ('000,fff', [ 17 | (0x00, (0x00, 0x00, 0x00)), 18 | (0xff, (0xff, 0xff, 0xff)), 19 | ]), 20 | ('fff,f00', [ 21 | (0x00, (0xff, 0xff, 0xff)), 22 | (0xff, (0xff, 0x00, 0x00)), 23 | ]), 24 | ('fff,00f', [ 25 | (0x00, (0xff, 0xff, 0xff)), 26 | (0xff, (0x00, 0x00, 0xff)), 27 | ]), 28 | ('000,f00,ff0,fff', [ 29 | (0x00, (0x00, 0x00, 0x00)), 30 | (0x55, (0xff, 0x00, 0x00)), 31 | (0xaa, (0xff, 0xff, 0x00)), 32 | (0xff, (0xff, 0xff, 0xff)), 33 | ]), 34 | ('000,00f,0ff', [ 35 | (0x00, (0x00, 0x00, 0x00)), 36 | (0x7f, (0x00, 0x00, 0xff)), 37 | (0xff, (0x00, 0xff, 0xff)), 38 | ]), 39 | ('0ff,00f,80:000,f00,ff0,fff', [ 40 | (0x00, (0x00, 0xff, 0xff)), 41 | (0x40, (0x00, 0x00, 0xff)), 42 | (0x80, (0x00, 0x00, 0x00)), 43 | (0xaa, (0xff, 0x00, 0x00)), 44 | (0xd4, (0xff, 0xff, 0x00)), 45 | (0xff, (0xff, 0xff, 0xff)), 46 | ]), 47 | ('00f,80:fff,f00', [ 48 | (0x00, (0x00, 0x00, 0xff)), 49 | (0x80, (0xff, 0xff, 0xff)), 50 | (0xff, (0xff, 0x00, 0x00)), 51 | ]), 52 | ('0055a4,80:fff,ef4135', [ 53 | (0x00, (0x00, 0x55, 0xa4)), 54 | (0x80, (0xff, 0xff, 0xff)), 55 | (0xff, (0xef, 0x41, 0x35)), 56 | ]), 57 | ('0000d0,80:d0d0d0,d00000', [ 58 | (0x00, (0x00, 0x00, 0xd0)), 59 | (0x80, (0xd0, 0xd0, 0xd0)), 60 | (0xff, (0xd0, 0x00, 0x00)), 61 | ]), 62 | ('00d0d0,80:d0d0d0,d000d0', [ 63 | (0x00, (0x00, 0xd0, 
0xd0)), 64 | (0x80, (0xd0, 0xd0, 0xd0)), 65 | (0xff, (0xd0, 0x00, 0xd0)), 66 | ]), 67 | ('00d000,80:d0d0d0,d000d0', [ 68 | (0x00, (0x00, 0xd0, 0x00)), 69 | (0x80, (0xd0, 0xd0, 0xd0)), 70 | (0xff, (0xd0, 0x00, 0xd0)), 71 | ]), 72 | ('7:000, 9:ffffff', [ 73 | (0x00, (0x00, 0x00, 0x00)), 74 | (0x77, (0x00, 0x00, 0x00)), 75 | (0x99, (0xff, 0xff, 0xff)), 76 | (0xff, (0xff, 0xff, 0xff)), 77 | ]), 78 | ] 79 | 80 | 81 | def interpolate(x, nodes): 82 | '''Interpolate from example color map nodes.''' 83 | xp_addr = np.array([node[0] for node in nodes], dtype=np.float64) 84 | fp_rgb = np.array([node[1] for node in nodes], dtype=np.float64).T 85 | return np.stack([np.interp(x, xp_addr, fp) for fp in fp_rgb], axis=-1).round(12).clip(0., 255.).astype(np.uint8) 86 | 87 | 88 | @pytest.fixture(scope='session', params=CMAPS) 89 | def cmap_example(request): 90 | '''Example color map fixture.''' 91 | return CMapExample(*request.param) 92 | 93 | 94 | @pytest.mark.parametrize('source_code', [ 95 | 'this', 'fff', ',,,', '111:111:111', 'fffff,fffff', 'f,f', 'fffffffff', 'ff:', 'ff:fff,00:000' 96 | ]) 97 | def test_color_map_wrong_syntax(source_code): 98 | '''Test whether different kinds of syntax errors cause a RuntimeError.''' 99 | with pytest.raises(RuntimeError): 100 | ColorMap(source_code) 101 | 102 | 103 | def test_color_map_nodes_call(cmap_example): 104 | '''Test if the color map nodes have the specified color when calling a ColorMap instance.''' 105 | cmap = ColorMap(cmap_example.source) 106 | input_addr = np.array([node[0] for node in cmap_example.nodes], dtype=np.float64)[None] 107 | expected_rgb = np.array([node[1] for node in cmap_example.nodes], dtype=np.uint8)[None] 108 | cmap_rgb = (cmap(input_addr / 255.) * 255.).round(12).clip(0., 255.).astype(np.uint8) 109 | assert np.allclose(expected_rgb, cmap_rgb) 110 | 111 | 112 | def test_color_map_nodes_palette(cmap_example): 113 | '''Test if the color map nodes have the specified color when using ColorMap.palette.''' 114 | cmap = ColorMap(cmap_example.source) 115 | input_addr = [node[0] for node in cmap_example.nodes] 116 | expected_rgb = np.array([node[1] for node in cmap_example.nodes], dtype=np.uint8)[None] 117 | palette = cmap.palette(level=1.) 118 | cmap_rgb = palette[input_addr] 119 | assert np.allclose(expected_rgb, cmap_rgb) 120 | 121 | 122 | def test_color_map_full_call(cmap_example): 123 | '''Test if the color map nodes have correctly interpolated colors when calling a ColorMap instance.''' 124 | cmap = ColorMap(cmap_example.source) 125 | input_addr = np.arange(256, dtype=np.uint8) 126 | expected_rgb = interpolate(input_addr, cmap_example.nodes) 127 | cmap_rgb = (cmap(input_addr / 255.) 
* 255.).round(12).clip(0., 255.).astype(np.uint8) 128 | assert np.allclose(expected_rgb, cmap_rgb) 129 | 130 | 131 | def test_color_map_full_palette(cmap_example): 132 | '''Test if the color map nodes have correctly interpolated colors when using ColorMap.palette.''' 133 | input_addr = np.arange(256, dtype=np.uint8) 134 | expected_palette = interpolate(input_addr, cmap_example.nodes) 135 | cmap = ColorMap(cmap_example.source) 136 | cmap_palette = cmap.palette(level=1.0) 137 | assert np.allclose(expected_palette, cmap_palette) 138 | 139 | 140 | def test_color_map_reassign_source_palette(cmap_example): 141 | '''Test if calling a ColorMap instance for which the source was changed produces correctly interpolated colors.''' 142 | cmap = ColorMap('fff,fff') 143 | cmap.source = cmap_example.source 144 | 145 | input_addr = np.arange(256, dtype=np.uint8) 146 | expected_palette = interpolate(input_addr, cmap_example.nodes) 147 | cmap_palette = cmap.palette(level=1.0) 148 | assert np.allclose(expected_palette, cmap_palette) 149 | 150 | 151 | def test_color_map_source_property(cmap_example): 152 | '''Test if the source property of a color map is equal to the specified source code.''' 153 | cmap = ColorMap(cmap_example.source) 154 | assert cmap.source == cmap_example.source, 'Mismatching source!' 155 | 156 | 157 | @pytest.fixture(scope='function') 158 | def lazy_cmap_cache(): 159 | '''Single fixture for a LazyColorMapCache''' 160 | return LazyColorMapCache({ 161 | 'gray': '000,fff', 162 | 'red': '100,f00', 163 | }) 164 | 165 | 166 | class TestLazyColorMapCache: 167 | '''Tests for LazyColorMapCache.''' 168 | @staticmethod 169 | def test_missing(lazy_cmap_cache): 170 | '''Test whether accessing an unknown key causes a KeyError.''' 171 | with pytest.raises(KeyError): 172 | _ = lazy_cmap_cache['no such cmap'] 173 | 174 | @staticmethod 175 | def test_get_item_uncompiled(lazy_cmap_cache): 176 | '''Test whether accessing an uncompiled entry compiles and returns the correct ColorMap.''' 177 | cmap = lazy_cmap_cache['red'] 178 | assert isinstance(cmap, ColorMap) 179 | assert cmap.source == '100,f00' 180 | 181 | @staticmethod 182 | def test_get_item_cached(lazy_cmap_cache): 183 | '''Test whether accessing a previously compiled and cached entry returns the same ColorMap.''' 184 | cmaps = [ 185 | lazy_cmap_cache['red'], 186 | lazy_cmap_cache['red'], 187 | ] 188 | assert cmaps[0] is cmaps[1] 189 | 190 | @staticmethod 191 | def test_set_item_existing(lazy_cmap_cache): 192 | '''Test whether setting an already existing, uncompiled entry and accessing it returns the correct ColorMap.''' 193 | lazy_cmap_cache['red'] = 'fff,f00' 194 | assert lazy_cmap_cache['red'].source == 'fff,f00' 195 | 196 | @staticmethod 197 | def test_set_item_new(lazy_cmap_cache): 198 | '''Test whether setting a new entry and accessing it returns the correct ColorMap.''' 199 | lazy_cmap_cache['blue'] = 'fff,00f' 200 | assert lazy_cmap_cache['blue'].source == 'fff,00f' 201 | 202 | @staticmethod 203 | def test_set_item_compiled(lazy_cmap_cache): 204 | '''Test whether setting an already existing, compiled entry and accessing it returns the same, modified 205 | ColorMap instance. 
206 | ''' 207 | original_cmap = lazy_cmap_cache['red'] 208 | lazy_cmap_cache['red'] = 'fff,f00' 209 | assert lazy_cmap_cache['red'].source == 'fff,f00' 210 | assert original_cmap is lazy_cmap_cache['red'] 211 | 212 | @staticmethod 213 | def test_del_item_uncompiled(lazy_cmap_cache): 214 | '''Test whether deleting an uncompiled entry correctly removes the entry.''' 215 | del lazy_cmap_cache['red'] 216 | assert 'red' not in lazy_cmap_cache 217 | 218 | @staticmethod 219 | def test_del_item_compiled(lazy_cmap_cache): 220 | '''Test whether deleting a compiled entry correctly removes the entry.''' 221 | _ = lazy_cmap_cache['red'] 222 | del lazy_cmap_cache['red'] 223 | assert 'red' not in lazy_cmap_cache 224 | 225 | @staticmethod 226 | def test_iter(lazy_cmap_cache): 227 | '''Test whether iterating a LazyColorMapCache returns its keys.''' 228 | assert (list(lazy_cmap_cache) == ['gray', 'red']) 229 | 230 | @staticmethod 231 | def test_len(lazy_cmap_cache): 232 | '''Test whether calling len on a LazyColorMapCache returns the correct length.''' 233 | assert len(lazy_cmap_cache) == 2 234 | -------------------------------------------------------------------------------- /tests/test_composites.py: -------------------------------------------------------------------------------- 1 | '''Tests for composites using torchvision models.''' 2 | from types import MethodType 3 | from itertools import product 4 | 5 | from zennit.core import BasicHook, collect_leaves 6 | 7 | 8 | def ishookcopy(hook, hook_template): 9 | '''Check if ``hook`` is a copy of ``hook_template`` (due to copying-mechanics of BasicHook).''' 10 | if isinstance(hook_template, BasicHook): 11 | return all( 12 | getattr(hook, key) == getattr(hook_template, key) 13 | for key in ( 14 | 'input_modifiers', 15 | 'param_modifiers', 16 | 'output_modifiers', 17 | 'gradient_mapper', 18 | ) 19 | ) 20 | return isinstance(hook, type(hook_template)) 21 | 22 | 23 | def check_hook_registered(module, hook_template): 24 | '''Check whether a ``hook_template`` has been registered to ``module``. 
''' 25 | return any( 26 | ishookcopy(hook_func.__self__, hook_template) 27 | for hook_func in getattr(module, '_forward_pre_hooks').values() 28 | if isinstance(hook_func, MethodType) 29 | ) 30 | 31 | 32 | def verify_no_hooks(model): 33 | '''Verify that ``model`` has no registered forward (-pre) hooks.''' 34 | return not any( 35 | any(getattr(module, key) for key in ('_forward_hooks', '_forward_pre_hooks')) 36 | for module in model.modules() 37 | ) 38 | 39 | 40 | def test_composite_layer_map_registered(layer_map_composite, model_vision): 41 | '''Tests whether the explicit LayerMapComposites register and unregister their rules correctly.''' 42 | errors = [] 43 | with layer_map_composite.context(model_vision): 44 | for child in model_vision.modules(): 45 | for dtype, hook_template in layer_map_composite.layer_map: 46 | if isinstance(child, dtype): 47 | if not check_hook_registered(child, hook_template): 48 | errors.append(( 49 | '{} is first of {} but {} is not registered!', 50 | (child, dtype, hook_template), 51 | )) 52 | break 53 | 54 | if not verify_no_hooks(model_vision): 55 | errors.append(('Model has hooks registered after composite was removed!', ())) 56 | 57 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors)) 58 | 59 | 60 | def test_composite_special_first_layer_map_registered(special_first_layer_map_composite, model_vision): 61 | '''Tests whether the explicit SpecialFirstLayerMapComposites register and unregister their rules correctly.''' 62 | errors = [] 63 | try: 64 | special_first_layer, special_first_template, special_first_dtype = next( 65 | (child, hook_template, dtype) 66 | for child, (dtype, hook_template) in product( 67 | collect_leaves(model_vision), special_first_layer_map_composite.first_map 68 | ) if isinstance(child, dtype) 69 | ) 70 | except StopIteration: 71 | special_first_layer = None 72 | special_first_template = None 73 | 74 | with special_first_layer_map_composite.context(model_vision): 75 | if special_first_layer is not None and not check_hook_registered(special_first_layer, special_first_template): 76 | errors.append(( 77 | 'Special first layer {} is first of {} but {} is not registered!', 78 | (special_first_layer, special_first_dtype, special_first_template) 79 | )) 80 | 81 | children = (child for child in model_vision.modules() if child is not special_first_layer) 82 | for child in children: 83 | for dtype, hook_template in special_first_layer_map_composite.layer_map: 84 | if isinstance(child, dtype): 85 | if not check_hook_registered(child, hook_template): 86 | errors.append(( 87 | '{} is first of {} but {} is not registered!', 88 | (child, dtype, hook_template), 89 | )) 90 | break 91 | 92 | if not verify_no_hooks(model_vision): 93 | errors.append(('Model has hooks registered after composite was removed!', ())) 94 | 95 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors)) 96 | 97 | 98 | def test_composite_name_map_registered(name_map_composite, model_vision): 99 | '''Tests whether the constructed NameMapComposites register and unregister their rules correctly.''' 100 | errors = [] 101 | with name_map_composite.context(model_vision): 102 | for name, child in model_vision.named_modules(): 103 | for names, hook_template in name_map_composite.name_map: 104 | if name in names: 105 | if not check_hook_registered(child, hook_template): 106 | errors.append(( 107 | '{} is first in name map for {}, but is not registered!', 108 | (name, hook_template), 109 
| )) 110 | break 111 | 112 | if not verify_no_hooks(model_vision): 113 | errors.append(('Model has hooks registered after composite was removed!', ())) 114 | 115 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors)) 116 | 117 | 118 | def test_composite_mixed_registered(mixed_composite, model_vision): 119 | '''Tests whether the constructed MixedComposites register and unregister their rules correctly.''' 120 | errors = [] 121 | 122 | name_map_composite, special_first_layer_map_composite = mixed_composite.composites 123 | 124 | try: 125 | special_first_layer, special_first_template, special_first_dtype = next( 126 | (child, hook_template, dtype) 127 | for child, (dtype, hook_template) in product( 128 | collect_leaves(model_vision), special_first_layer_map_composite.first_map 129 | ) if isinstance(child, dtype) 130 | ) 131 | except StopIteration: 132 | special_first_layer = None 133 | special_first_template = None 134 | 135 | with mixed_composite.context(model_vision): 136 | has_matched_first_layer = False 137 | for name, child in model_vision.named_modules(): 138 | has_matched_name_map = False 139 | for names, hook_template in name_map_composite.name_map: 140 | if name in names: 141 | has_matched_name_map = True 142 | if not check_hook_registered(child, hook_template): 143 | errors.append(( 144 | '{} is first in name map for {}, but is not registered!', 145 | (name, hook_template), 146 | )) 147 | break 148 | 149 | if has_matched_name_map: 150 | continue 151 | 152 | if not has_matched_first_layer and child == special_first_layer: 153 | has_matched_first_layer = True 154 | if not check_hook_registered(child, special_first_template): 155 | errors.append(( 156 | 'Special first layer {} is first of {} but {} is not registered!', 157 | (special_first_layer, special_first_dtype, special_first_template) 158 | )) 159 | continue 160 | 161 | for dtype, hook_template in special_first_layer_map_composite.layer_map: 162 | if isinstance(child, dtype): 163 | if not check_hook_registered(child, hook_template): 164 | errors.append(( 165 | '{} is first of {} but {} is not registered!', 166 | (child, dtype, hook_template), 167 | )) 168 | break 169 | 170 | if not verify_no_hooks(model_vision): 171 | errors.append(('Model has hooks registered after composite was removed!', ())) 172 | 173 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors)) 174 | 175 | 176 | def test_composite_name_layer_map_registered(name_layer_map_composite, model_vision): 177 | '''Tests whether the constructed NameLayerMapComposites register and unregister their rules correctly.''' 178 | errors = [] 179 | 180 | name_map_composite, layer_map_composite = name_layer_map_composite.composites 181 | 182 | with name_layer_map_composite.context(model_vision): 183 | for name, child in model_vision.named_modules(): 184 | for names, hook_template in name_map_composite.name_map: 185 | has_matched_name_map = False 186 | if name in names: 187 | has_matched_name_map = True 188 | if not check_hook_registered(child, hook_template): 189 | errors.append(( 190 | '{} is first in name map for {}, but is not registered!', 191 | (name, hook_template), 192 | )) 193 | break 194 | 195 | if has_matched_name_map: 196 | continue 197 | 198 | for dtype, hook_template in layer_map_composite.layer_map: 199 | if isinstance(child, dtype): 200 | if not check_hook_registered(child, hook_template): 201 | errors.append(( 202 | '{} is first of {} but {} is not 
registered!', 203 | (child, dtype, hook_template), 204 | )) 205 | break 206 | 207 | if not verify_no_hooks(model_vision): 208 | errors.append(('Model has hooks registered after composite was removed!', ())) 209 | 210 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors)) 211 | -------------------------------------------------------------------------------- /tests/test_image.py: -------------------------------------------------------------------------------- 1 | '''Tests for image operations.''' 2 | from typing import NamedTuple 3 | from itertools import product 4 | from io import BytesIO 5 | 6 | import pytest 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from zennit.cmap import ColorMap 11 | from zennit.image import get_cmap, palette, imgify, gridify, imsave, interval_norm_bounds 12 | 13 | 14 | @pytest.fixture(scope='session', params=[ 15 | 'gray', '000,fff', ColorMap('000,fff') 16 | ]) 17 | def cmap_source(request): 18 | '''Fixture for multiple ways to specify the "gray" color map.''' 19 | return request.param 20 | 21 | 22 | class ImageTuple(NamedTuple): 23 | '''NamedTuple for image-array setups.''' 24 | grid: bool 25 | nchannels: list 26 | channel_front: bool 27 | width: int 28 | height: int 29 | array: np.ndarray 30 | 31 | 32 | @pytest.fixture(scope='session', params=product( 33 | [False, True], 34 | [1, 3], 35 | [False, True], 36 | [5, 10], 37 | [5, 10], 38 | [np.float64, np.uint8] 39 | )) 40 | def image_tuple(request): 41 | '''Image-array setups with varying size, type, number of channels, channel position and grid dimension.''' 42 | grid, nchannels, channel_front, width, height, dtype = request.param 43 | 44 | shape = (height, width) 45 | if channel_front: 46 | shape = (nchannels,) + shape 47 | else: 48 | shape = shape + (nchannels,) 49 | shape = (1,) * grid + shape 50 | 51 | return ImageTuple( 52 | grid, 53 | nchannels, 54 | channel_front, 55 | width, 56 | height, 57 | np.ones(shape, dtype=dtype), 58 | ) 59 | 60 | 61 | def test_get_cmap(cmap_source): 62 | '''Test whether get_cmap handles its supported cmap types correctly.''' 63 | cmap = get_cmap(cmap_source) 64 | assert isinstance(cmap, ColorMap), 'Returned object is not a ColorMap!' 65 | assert cmap.source == '000,fff', 'Mismatch in source code of returned ColorMap instance.' 
66 | 67 | 68 | def test_palette(cmap_source): 69 | '''Test whether palette returns the correct palette for all of its supported types.''' 70 | pal = palette(cmap_source) 71 | expected_pal = np.repeat(np.arange(256, dtype=np.uint8)[:, None], 3, axis=1) 72 | assert np.allclose(expected_pal, pal) 73 | 74 | 75 | @pytest.mark.parametrize('ndim', [1, 4, 5, 6]) 76 | def test_imgify_wrong_dim(ndim): 77 | '''Test whether imgify fails for an unsupported number of dimensions.''' 78 | with pytest.raises(TypeError): 79 | imgify(np.zeros((1,) * ndim)) 80 | 81 | 82 | @pytest.mark.parametrize('ndim', [1, 2, 5, 6]) 83 | def test_imgify_grid_wrong_dim(ndim): 84 | '''Test whether imgify fails for an unsupported number of dimensions with grid=True.''' 85 | with pytest.raises(TypeError): 86 | imgify(np.zeros((1,) * ndim), grid=True) 87 | 88 | 89 | @pytest.mark.parametrize('grid', [[1], (1,), 1, [1, 1, 1], (1, 1, 1)]) 90 | def test_imgify_grid_bad_grid(grid): 91 | '''Test whether imgify fails for unsupported grid values.''' 92 | with pytest.raises(TypeError): 93 | imgify(np.zeros((1,) * 4), grid=grid) 94 | 95 | 96 | @pytest.mark.parametrize('grid,nchannels', product([False, True], [2, 4])) 97 | def test_imgify_wrong_channels(grid, nchannels): 98 | '''Test whether imgify fails for an unsupported number of channels, both with and without grid.''' 99 | with pytest.raises(TypeError): 100 | imgify(np.zeros((1,) * grid + (2, 2, nchannels)), grid=grid) 101 | 102 | 103 | def test_imgify_container(image_tuple): 104 | '''Test whether imgify produces the correct PIL Image container.''' 105 | image = imgify(image_tuple.array, grid=image_tuple.grid) 106 | assert image.mode == ('P' if image_tuple.nchannels == 1 else 'RGB'), 'Mode mismatch!' 107 | assert image.width == image_tuple.width, 'Width mismatch!' 108 | assert image.height == image_tuple.height, 'Height mismatch!' 
109 | 110 | 111 | @pytest.mark.parametrize('vmin,vmax,symmetric', product([None, 1.], [None, 2.], [False, True])) 112 | def test_imgify_normalization(vmin, vmax, symmetric): 113 | '''Test whether imgify normalizes as expected.''' 114 | array = np.array([[-1., 0., 3.]]) 115 | 116 | image = imgify(array, cmap='gray', vmin=vmin, vmax=vmax, symmetric=symmetric) 117 | 118 | if vmin is None: 119 | if symmetric: 120 | vmin = -np.abs(array).max() 121 | else: 122 | vmin = array.min() 123 | if vmax is None: 124 | if symmetric: 125 | vmax = np.abs(array).max() 126 | else: 127 | vmax = array.max() 128 | 129 | expected = (((array - vmin) / (vmax - vmin)) * 255.).clip(0, 255).astype(np.uint8) 130 | 131 | assert np.allclose(np.array(image), expected) 132 | 133 | 134 | @pytest.mark.parametrize('ndim', [1, 2, 5, 6]) 135 | def test_gridify_wrong_dim(ndim): 136 | '''Test whether gridify fails for an unsupported number of dimensions.''' 137 | with pytest.raises(TypeError): 138 | gridify(np.zeros((1,) * ndim)) 139 | 140 | 141 | @pytest.mark.parametrize('channel_front,nchannels', product([False, True], [2, 4])) 142 | def test_gridify_wrong_channels(channel_front, nchannels): 143 | '''Test whether gridify fails for an unsupported number of channels in both channel positions.''' 144 | shape = (2, 2) 145 | if channel_front: 146 | shape = (nchannels,) + shape 147 | else: 148 | shape = shape + (nchannels,) 149 | shape = (1,) + shape 150 | 151 | with pytest.raises(TypeError): 152 | gridify(np.zeros(shape)) 153 | 154 | 155 | @pytest.mark.parametrize('shape,expected_shape', [ 156 | [(4, 2, 2, 3), (4, 4, 3)], 157 | [(4, 2, 2, 1), (4, 4, 1)], 158 | [(4, 2, 2), (4, 4, 1)], 159 | [(4, 3, 2, 2), (4, 4, 3)], 160 | [(4, 1, 2, 2), (4, 4, 1)], 161 | ]) 162 | def test_gridify_shape(shape, expected_shape): 163 | '''Test whether gridify produces the correct shape.''' 164 | output = gridify(np.zeros(shape)) 165 | assert expected_shape == output.shape 166 | 167 | 168 | @pytest.mark.parametrize('fill_value', [None, 0.]) 169 | def test_gridify_fill(fill_value): 170 | '''Test whether gridify fills empty pixels with the correct value.''' 171 | array = np.array([[[[1.]]]]) 172 | output = gridify(array, fill_value=fill_value, shape=(1, 2)) 173 | expected_value = array.min() if fill_value is None else fill_value 174 | assert output[0, 1, 0] == expected_value 175 | 176 | 177 | @pytest.mark.parametrize('writer_params', [None, {}]) 178 | def test_imsave_container(image_tuple, writer_params): 179 | '''Test whether imsave produces a file which loads as the correct PIL Image container.''' 180 | fp = BytesIO() 181 | imsave(fp, image_tuple.array, grid=image_tuple.grid, format='png', writer_params=writer_params) 182 | fp.seek(0) 183 | image = Image.open(fp) 184 | assert image.mode == ('P' if image_tuple.nchannels == 1 else 'RGB'), 'Mode mismatch!' 185 | assert image.width == image_tuple.width, 'Width mismatch!' 186 | assert image.height == image_tuple.height, 'Height mismatch!' 
187 | 188 | 189 | @pytest.mark.parametrize('symmetric,dim,expected_bounds', [ 190 | (False, None, (np.array([[[[-1.]]], [[[0.]]]]), np.array([[[[-0.2]]], [[[0.8]]]]))), 191 | (False, (1, 2, 3), (np.array([[[[-1.]]], [[[0.]]]]), np.array([[[[-0.2]]], [[[0.8]]]]))), 192 | (False, (0, 1, 2, 3), (np.array([[[[-1.]]]]), np.array([[[[0.8]]]]))), 193 | (True, None, (np.array([[[[-1.]]], [[[-0.8]]]]), np.array([[[[1.]]], [[[0.8]]]]))), 194 | (True, (1, 2, 3), (np.array([[[[-1.]]], [[[-0.8]]]]), np.array([[[[1.]]], [[[0.8]]]]))), 195 | (True, (0, 1, 2, 3), (np.array([[[[-1.]]]]), np.array([[[[1.]]]]))), 196 | ]) 197 | def test_interval_norm_bounds(symmetric, dim, expected_bounds): 198 | '''Test whether interval_norm_bounds computes the correct minimum and maximum values.''' 199 | array = np.linspace(-1., 0.8, 10).reshape((2, 1, 5, 1)) 200 | bounds = interval_norm_bounds(array, symmetric=symmetric, dim=dim) 201 | assert np.allclose(expected_bounds, bounds) 202 | -------------------------------------------------------------------------------- /tests/test_rules.py: -------------------------------------------------------------------------------- 1 | '''Tests for various rules. Rules are re-implemented in a slower, less complicated way that closely follows their 2 | definition in the original works, making them easier to compare and thus less likely to be wrong. 3 | ''' 4 | from functools import wraps, partial 5 | from copy import deepcopy 6 | 7 | import pytest 8 | import torch 9 | from zennit.rules import Epsilon, ZPlus, AlphaBeta, Gamma, ZBox, Norm, WSquare, Flat 10 | from zennit.rules import Pass, ReLUDeconvNet, ReLUGuidedBackprop, ReLUBetaSmooth 11 | from zennit.rules import zero_bias as name_zero_bias 12 | 13 | 14 | def stabilize(input, epsilon=1e-6): 15 | '''Replicates zennit.core.stabilize for testing.''' 16 | return input + ((input == 0.).to(input) + input.sign()) * epsilon 17 | 18 | 19 | def as_matrix(module_linear, input, output): 20 | '''Get flat weight and bias using the jacobian.''' 21 | jac = torch.autograd.functional.jacobian(module_linear, input[None]) 22 | weight = jac.reshape((output.numel(), input.numel())) 23 | bias = output.flatten() - weight @ input.flatten() 24 | return weight, bias 25 | 26 | 27 | RULES_LINEAR = [] 28 | RULES_SIMPLE = [] 29 | 30 | 31 | def replicates(target_list, replicated_func, **kwargs): 32 | '''Decorator to indicate a replication of a function for testing.''' 33 | def wrapper(func): 34 | '''Append the pair of partials to ``target_list`` as a pytest param, given ``kwargs``.''' 35 | target_list.append( 36 | pytest.param( 37 | (partial(replicated_func, **kwargs), partial(func, **kwargs)), 38 | id=replicated_func.__name__ 39 | ) 40 | ) 41 | return func 42 | return wrapper 43 | 44 | 45 | def flat_module_params(func): 46 | '''Decorator to copy the module and overwrite its params completely with ones (for rule_flat).''' 47 | @wraps(func) 48 | def wrapped(module_linear, *args, **kwargs): 49 | '''Make a deep copy of module_linear, fill all parameters in-place with ones, and call func with the copy.''' 50 | module_copy = deepcopy(module_linear) 51 | for param in module_copy.parameters(): 52 | param.requires_grad_(False).fill_(1.0) 53 | return func(module_copy, *args, **kwargs) 54 | return wrapped 55 | 56 | 57 | def matrix_form(func): 58 | '''Decorator to wrap a function such that weight and bias are supplied in matrix form and input and output are 59 | flattened appropriately.''' 60 | @wraps(func) 61 | def wrapped(module_linear, input, output, **kwargs): 62 | '''Get flat weight matrix and bias 
using the jacobian, flatten input and output, and pass arguments to func.''' 63 | weight, bias = as_matrix(module_linear, input[0], output[0]) 64 | return func( 65 | weight, 66 | bias, 67 | input.flatten(start_dim=1), 68 | output.flatten(start_dim=1), 69 | **kwargs 70 | ).reshape(input.shape) 71 | return wrapped 72 | 73 | 74 | def with_grad(func): 75 | '''Decorator to wrap function such that the gradient is computed and passed to the function instead of module.''' 76 | @wraps(func) 77 | def wrapped(module, input, output, **kwargs): 78 | '''Get gradient and pass along input, output and keyword arguments to func.''' 79 | gradient, = torch.autograd.grad(module(input), input, output) 80 | return func( 81 | gradient, 82 | input, 83 | output, 84 | **kwargs 85 | ) 86 | return wrapped 87 | 88 | 89 | def zero_bias(zero_params, bias): 90 | '''Return a tensor with zeros like ``bias`` if zero_params is equal to or contains the string ``'bias'``, otherwise 91 | return the unmodified tensor ``bias``.''' 92 | if zero_params is None: 93 | zero_params = [] 94 | if bias is not None and (zero_params == 'bias' or 'bias' in zero_params): 95 | return torch.zeros_like(bias) 96 | return bias 97 | 98 | 99 | @replicates(RULES_LINEAR, Epsilon, epsilon=1e-6) 100 | @replicates(RULES_LINEAR, Epsilon, epsilon=1e-6, zero_params='bias') 101 | @replicates(RULES_LINEAR, Epsilon, epsilon=1.0) 102 | @replicates(RULES_LINEAR, Epsilon, epsilon=1.0, zero_params='bias') 103 | @replicates(RULES_LINEAR, Norm) 104 | @matrix_form 105 | def rule_epsilon(weight, bias, input, relevance, epsilon=1e-6, zero_params=None): 106 | '''Replicates the Epsilon rule.''' 107 | bias = zero_bias(zero_params, bias) 108 | return input * ((relevance / stabilize(input @ weight.t() + bias, epsilon)) @ weight) 109 | 110 | 111 | @replicates(RULES_LINEAR, ZPlus) 112 | @replicates(RULES_LINEAR, ZPlus, zero_params='bias') 113 | @matrix_form 114 | def rule_zplus(weight, bias, input, relevance, zero_params=None): 115 | '''Replicates the ZPlus rule.''' 116 | bias = zero_bias(zero_params, bias) 117 | wplus = weight.clamp(min=0) 118 | wminus = weight.clamp(max=0) 119 | xplus = input.clamp(min=0) 120 | xminus = input.clamp(max=0) 121 | zval = xplus @ wplus.t() + xminus @ wminus.t() + bias.clamp(min=0) 122 | rfac = relevance / stabilize(zval) 123 | return xplus * (rfac @ wplus) + xminus * (rfac @ wminus) 124 | 125 | 126 | @replicates(RULES_LINEAR, Gamma, gamma=0.25) 127 | @replicates(RULES_LINEAR, Gamma, gamma=0.25, zero_params='bias') 128 | @replicates(RULES_LINEAR, Gamma, gamma=0.5) 129 | @replicates(RULES_LINEAR, Gamma, gamma=0.5, zero_params='bias') 130 | @matrix_form 131 | def rule_gamma(weight, bias, input, relevance, gamma, zero_params=None): 132 | '''Replicates the Gamma rule.''' 133 | output = input @ weight.t() + bias 134 | bias = zero_bias(zero_params, bias) 135 | pinput = input.clamp(min=0) 136 | ninput = input.clamp(max=0) 137 | pwgamma = weight + weight.clamp(min=0) * gamma 138 | nwgamma = weight + weight.clamp(max=0) * gamma 139 | pbgamma = bias + bias.clamp(min=0) * gamma 140 | nbgamma = bias + bias.clamp(max=0) * gamma 141 | 142 | pgrad_out = (relevance / stabilize(pinput @ pwgamma.t() + ninput @ nwgamma.t() + pbgamma)) * (output > 0.) 143 | positive = pinput * (pgrad_out @ pwgamma) + ninput * (pgrad_out @ nwgamma) 144 | 145 | ngrad_out = (relevance / stabilize(pinput @ nwgamma.t() + ninput @ pwgamma.t() + nbgamma)) * (output < 0.) 
146 | negative = pinput * (ngrad_out @ nwgamma) + ninput * (ngrad_out @ pwgamma) 147 | 148 | return positive + negative 149 | 150 | 151 | @replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0) 152 | @replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0, zero_params='bias') 153 | @replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0) 154 | @replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0, zero_params='bias') 155 | @matrix_form 156 | def rule_alpha_beta(weight, bias, input, relevance, alpha, beta, zero_params=None): 157 | '''Replicates the AlphaBeta rule.''' 158 | bias = zero_bias(zero_params, bias) 159 | wplus = weight.clamp(min=0) 160 | wminus = weight.clamp(max=0) 161 | xplus = input.clamp(min=0) 162 | xminus = input.clamp(max=0) 163 | zalpha = xplus @ wplus.t() + xminus @ wminus.t() + bias.clamp(min=0) 164 | zbeta = xplus @ wminus.t() + xminus @ wplus.t() + bias.clamp(max=0) 165 | ralpha = relevance / stabilize(zalpha) 166 | rbeta = relevance / stabilize(zbeta) 167 | result_alpha = xplus * (ralpha @ wplus) + xminus * (ralpha @ wminus) 168 | result_beta = xplus * (rbeta @ wminus) + xminus * (rbeta @ wplus) 169 | return alpha * result_alpha - beta * result_beta 170 | 171 | 172 | @replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0) 173 | @replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0, zero_params='bias') 174 | @matrix_form 175 | def rule_zbox(weight, bias, input, relevance, low, high, zero_params=None): 176 | '''Replicates the ZBox rule.''' 177 | wplus = weight.clamp(min=0) 178 | wminus = weight.clamp(max=0) 179 | low = torch.tensor(low).expand_as(input).to(input) 180 | high = torch.tensor(high).expand_as(input).to(input) 181 | zval = input @ weight.t() - low @ wplus.t() - high @ wminus.t() 182 | rfac = relevance / stabilize(zval) 183 | return input * (rfac @ weight) - low * (rfac @ wplus) - high * (rfac @ wminus) 184 | 185 | 186 | @replicates(RULES_LINEAR, WSquare) 187 | @replicates(RULES_LINEAR, WSquare, zero_params='bias') 188 | @matrix_form 189 | def rule_wsquare(weight, bias, input, relevance, zero_params=None): 190 | '''Replicates the WSquare rule.''' 191 | bias = zero_bias(zero_params, bias) 192 | wsquare = weight ** 2 193 | zval = torch.ones_like(input) @ wsquare.t() + bias ** 2 194 | rfac = relevance / stabilize(zval) 195 | return rfac @ wsquare 196 | 197 | 198 | @replicates(RULES_LINEAR, Flat) 199 | @flat_module_params 200 | @matrix_form 201 | def rule_flat(wflat, bias, input, relevance): 202 | '''Replicates the Flat rule.''' 203 | zval = torch.ones_like(input) @ wflat.t() 204 | rfac = relevance / stabilize(zval) 205 | return rfac @ wflat 206 | 207 | 208 | @replicates(RULES_SIMPLE, Pass) 209 | def rule_pass(module, input, relevance): 210 | '''Replicates the Pass rule.''' 211 | return relevance 212 | 213 | 214 | @replicates(RULES_SIMPLE, ReLUDeconvNet) 215 | def rule_relu_deconvnet(module, input, relevance): 216 | '''Replicates the ReLUDeconvNet rule.''' 217 | return relevance.clamp(min=0) 218 | 219 | 220 | @replicates(RULES_SIMPLE, ReLUGuidedBackprop) 221 | @with_grad 222 | def rule_relu_guidedbackprop(gradient, input, relevance): 223 | '''Replicates the ReLUGuidedBackprop rule.''' 224 | return gradient * (relevance > 0.) 225 | 226 | 227 | @replicates(RULES_SIMPLE, ReLUBetaSmooth, beta_smooth=10.) 228 | @replicates(RULES_SIMPLE, ReLUBetaSmooth, beta_smooth=1.) 
229 | def rule_relu_beta_smooth(module, input, relevance, beta_smooth): 230 | '''Replicates the ReLUBetaSmooth rule.''' 231 | return relevance * torch.sigmoid(beta_smooth * input) 232 | 233 | 234 | @pytest.fixture(scope='session', params=RULES_LINEAR) 235 | def rule_pair_linear(request): 236 | '''Fixture to supply ``RULES_LINEAR``.''' 237 | return request.param 238 | 239 | 240 | @pytest.fixture(scope='session', params=RULES_SIMPLE) 241 | def rule_pair_simple(request): 242 | '''Fixture to supply ``RULES_SIMPLE``.''' 243 | return request.param 244 | 245 | 246 | def compare_rule_pair(module, data, rule_pair): 247 | '''Compare rules with their replicated versions.''' 248 | rule_hook, rule_replicated = rule_pair 249 | 250 | input = data.clone().requires_grad_() 251 | handle = rule_hook().register(module) 252 | try: 253 | output = module(input) 254 | relevance_hook, = torch.autograd.grad(output, input, grad_outputs=output) 255 | finally: 256 | handle.remove() 257 | 258 | relevance_replicated = rule_replicated(module, input, output) 259 | 260 | assert torch.allclose(relevance_hook, relevance_replicated, atol=1e-5) 261 | 262 | 263 | def test_linear_rule(module_linear, data_linear, rule_pair_linear): 264 | '''Test whether replicated and original implementations of rules for linear layers agree.''' 265 | compare_rule_pair(module_linear, data_linear, rule_pair_linear) 266 | 267 | 268 | def test_simple_rule(module_simple, data_simple, rule_pair_simple): 269 | '''Test whether replicated and original implementations of rules for simple layers agree.''' 270 | compare_rule_pair(module_simple, data_simple, rule_pair_simple) 271 | 272 | 273 | def test_alpha_beta_invalid_values(): 274 | '''Test whether AlphaBeta raises ValueErrors for negative alpha/beta or when alpha - beta is not equal to 1.''' 275 | with pytest.raises(ValueError): 276 | AlphaBeta(alpha=-1.) 277 | with pytest.raises(ValueError): 278 | AlphaBeta(beta=-1.) 279 | with pytest.raises(ValueError): 280 | AlphaBeta(alpha=1., beta=1.) 
281 | 282 | 283 | @pytest.mark.parametrize('params', [None, 'weight', ['weight'], 'bias', ['bias'], ['weight', 'bias']]) 284 | def test_zero_bias(params): 285 | '''Test whether zero_bias correctly appends 'bias' to the zero_params list/str used for ParamMod.''' 286 | result = name_zero_bias(params) 287 | assert isinstance(result, list) 288 | assert 'bias' in result 289 | -------------------------------------------------------------------------------- /tests/test_torchvision.py: -------------------------------------------------------------------------------- 1 | '''Tests for torchvision-model-specific canonizers.''' 2 | import pytest 3 | import torch 4 | from torchvision.models import vgg11_bn, resnet18, resnet50 5 | from torchvision.models.resnet import BasicBlock as ResNetBasicBlock, Bottleneck as ResNetBottleneck 6 | from helpers import assert_identity_hook, randomize_bnorm, nograd 7 | 8 | from zennit.core import Composite, RemovableHandleList 9 | from zennit.torchvision import VGGCanonizer, ResNetCanonizer 10 | from zennit.types import BatchNorm 11 | 12 | 13 | def test_vgg_canonizer(batchsize): 14 | '''Test whether VGGCanonizer merges BatchNorm modules correctly and keeps the output unchanged.''' 15 | model = randomize_bnorm(nograd(vgg11_bn().eval().to(torch.float64))) 16 | data = torch.randn((batchsize, 3, 224, 224), dtype=torch.float64) 17 | output_before = model(data) 18 | 19 | handles = RemovableHandleList( 20 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!')) 21 | for module in model.modules() if isinstance(module, BatchNorm) 22 | ) 23 | 24 | canonizer = VGGCanonizer() 25 | composite = Composite(canonizers=[canonizer]) 26 | 27 | try: 28 | composite.register(model) 29 | output_canonizer = model(data) 30 | finally: 31 | composite.remove() 32 | handles.remove() 33 | 34 | # this assumes the batch-norm is not initialized as the identity 35 | handles = RemovableHandleList( 36 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!')) 37 | for module in model.modules() if isinstance(module, BatchNorm) 38 | ) 39 | try: 40 | output_after = model(data) 41 | finally: 42 | handles.remove() 43 | 44 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!' 45 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!' 46 | 47 | 48 | @pytest.mark.parametrize('model_fn,block_type', [ 49 | (resnet18, ResNetBasicBlock), 50 | (resnet50, ResNetBottleneck), 51 | ]) 52 | def test_resnet_canonizer(batchsize, model_fn, block_type): 53 | '''Test whether ResNetCanonizer overwrites and restores the Bottleneck/BasicBlock forward, merges BatchNorm modules 54 | correctly and keeps the output unchanged. 55 | ''' 56 | model = randomize_bnorm(nograd(model_fn().eval().to(torch.float64))) 57 | data = torch.randn((batchsize, 3, 224, 224), dtype=torch.float64) 58 | blocks = [module for module in model.modules() if isinstance(module, block_type)] 59 | 60 | assert blocks, 'Model has no blocks!' 61 | assert all( 62 | block.forward == block_type.forward.__get__(block) for block in blocks 63 | ), 'Model has its forward already overwritten!' 
64 | 65 | output_before = model(data) 66 | 67 | handles = RemovableHandleList( 68 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!')) 69 | for module in model.modules() if isinstance(module, BatchNorm) 70 | ) 71 | 72 | canonizer = ResNetCanonizer() 73 | composite = Composite(canonizers=[canonizer]) 74 | 75 | try: 76 | composite.register(model) 77 | assert not any( 78 | block.forward == block_type.forward.__get__(block) for block in blocks 79 | ), 'Model forward was not overwritten!' 80 | output_canonizer = model(data) 81 | finally: 82 | composite.remove() 83 | handles.remove() 84 | 85 | # this assumes the batch-norm is not initialized as the identity 86 | handles = RemovableHandleList( 87 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!')) 88 | for module in model.modules() if isinstance(module, BatchNorm) 89 | ) 90 | try: 91 | output_after = model(data) 92 | finally: 93 | handles.remove() 94 | 95 | assert all( 96 | block.forward == block_type.forward.__get__(block) for block in blocks 97 | ), 'Model forward was not restored!' 98 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!' 99 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!' 100 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | skip_missing_interpreters = true 3 | envlist = py37,py38,py39,pylint,flake8,docs 4 | 5 | [testenv] 6 | extras = tests 7 | setenv = 8 | COVERAGE_FILE = {toxworkdir}/.coverage.{envname} 9 | commands = 10 | pytest \ 11 | --cov "{envsitepackagesdir}/zennit" \ 12 | --cov-config "{toxinidir}/tox.ini" \ 13 | {posargs:.} 14 | 15 | [testenv:coverage] 16 | deps = 17 | coverage 18 | setenv = 19 | COVERAGE_FILE = {toxworkdir}/.coverage 20 | skip_install = true 21 | commands = 22 | coverage combine 23 | coverage report -m 24 | depends = py37,py38,py39 25 | 26 | [testenv:docs] 27 | basepython = python3.9 28 | extras = docs 29 | commands = 30 | sphinx-build \ 31 | --color \ 32 | -W \ 33 | --keep-going \ 34 | -d "{toxinidir}/docs/doctree" \ 35 | -b html \ 36 | "{toxinidir}/docs/source" \ 37 | "{toxinidir}/docs/build" \ 38 | {posargs} 39 | 40 | [testenv:flake8] 41 | basepython = python3.9 42 | changedir = {toxinidir} 43 | deps = 44 | flake8 45 | commands = 46 | flake8 "{toxinidir}/src/zennit" "{toxinidir}/tests" {posargs} 47 | 48 | 49 | [testenv:pylint] 50 | basepython = python3.9 51 | deps = 52 | pylint 53 | pytest 54 | changedir = {toxinidir} 55 | commands = 56 | pylint --rcfile=pylintrc --output-format=parseable {toxinidir}/src/zennit {toxinidir}/tests 57 | 58 | [flake8] 59 | # R902 Too many instance attributes 60 | # R913 Too many arguments 61 | # R914 Too many local variables 62 | # W503 Line-break before binary operator 63 | ignore = R902,R913,R914,W503 64 | 65 | exclude=.venv,.git,.tox,build,dist,docs,*egg,*.ini 66 | 67 | max-line-length = 120 68 | 69 | [pytest] 70 | testpaths = tests 71 | addopts = -ra -l 72 | 73 | [coverage:run] 74 | parallel = true 75 | branch = true 76 | 77 | [coverage:report] 78 | skip_covered = true 79 | show_missing = true 80 | 81 | [coverage:paths] 82 | source = src/zennit 83 | */.tox/*/lib/python*/site-packages/zennit 84 | */src/zennit 85 | --------------------------------------------------------------------------------
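Note: the following is a minimal, hypothetical sketch (not a file of this repository) illustrating the replication-based testing pattern used in tests/test_rules.py above: a Zennit rule hook is registered on a layer, relevance is computed via the gradient, and the result is compared against a plain re-implementation of the same rule. The Linear layer shape, batch size, and tolerance below are illustrative assumptions only.

import torch
from zennit.rules import Epsilon


def stabilize(input, epsilon=1e-6):
    # reference stabilizer, mirroring the replica in tests/test_rules.py
    return input + ((input == 0.).to(input) + input.sign()) * epsilon


module = torch.nn.Linear(4, 3).to(torch.float64)  # hypothetical small layer
data = torch.randn(2, 4, dtype=torch.float64)

# relevance via the Zennit Epsilon hook, following compare_rule_pair
input = data.clone().requires_grad_()
handle = Epsilon(epsilon=1e-6).register(module)
try:
    output = module(input)
    relevance_hook, = torch.autograd.grad(output, input, grad_outputs=output)
finally:
    handle.remove()

# plain re-implementation of the Epsilon rule for a Linear layer, as in rule_epsilon
weight, bias = module.weight, module.bias
zval = input.detach() @ weight.t() + bias
relevance_ref = input.detach() * ((output.detach() / stabilize(zval)) @ weight)

assert torch.allclose(relevance_hook, relevance_ref, atol=1e-5)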