├── .github
│   └── workflows
│       └── tests.yml
├── .gitignore
├── .gitlab-ci.yml
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── COPYING
├── COPYING.LESSER
├── LICENSE
├── README.md
├── docs
│   └── source
│       ├── _static
│       │   └── favicon.svg
│       ├── _templates
│       │   └── modules.rst
│       ├── bibliography.bib
│       ├── bibliography.rst
│       ├── conf.py
│       ├── getting-started.rst
│       ├── how-to
│       │   ├── compute-second-order-gradients.rst
│       │   ├── get-intermediate-relevance.rst
│       │   ├── index.rst
│       │   ├── use-attributors.rst
│       │   ├── use-rules-composites-and-canonizers.rst
│       │   ├── visualize-results.rst
│       │   ├── write-custom-attributors.rst
│       │   ├── write-custom-canonizers.rst
│       │   ├── write-custom-composites.rst
│       │   └── write-custom-rules.rst
│       ├── index.rst
│       ├── reference
│       │   └── index.rst
│       └── tutorial
│           ├── image-classification-vgg-resnet.ipynb
│           └── index.rst
├── pylintrc
├── setup.py
├── share
│   ├── example
│   │   └── feed_forward.py
│   ├── img
│   │   ├── beacon_resnet50_various.webp
│   │   ├── beacon_vgg16_epsilon_gamma_box.png
│   │   ├── beacon_vgg16_various.webp
│   │   ├── zennit.png
│   │   └── zennit.svg
│   ├── merge_maps
│   │   └── vgg16_bn.json
│   └── scripts
│       ├── download-lighthouses.sh
│       ├── palette_fit.py
│       ├── palette_swap.py
│       └── show_cmaps.py
├── src
│   └── zennit
│       ├── __init__.py
│       ├── attribution.py
│       ├── canonizers.py
│       ├── cmap.py
│       ├── composites.py
│       ├── core.py
│       ├── image.py
│       ├── layer.py
│       ├── rules.py
│       ├── torchvision.py
│       └── types.py
├── tests
│   ├── conftest.py
│   ├── helpers.py
│   ├── test_attribution.py
│   ├── test_canonizers.py
│   ├── test_cmap.py
│   ├── test_composites.py
│   ├── test_core.py
│   ├── test_image.py
│   ├── test_rules.py
│   └── test_torchvision.py
└── tox.ini
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 | on:
3 |   push:
4 |     branches: [master]
5 |   pull_request:
6 |     branches: [master]
7 |
8 | jobs:
9 |   test:
10 |     name: test ${{matrix.tox_env}}
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         include:
16 |           - tox_env: py37
17 |             python: "3.7"
18 |           - tox_env: py38
19 |             python: "3.8"
20 |           - tox_env: py39
21 |             python: "3.9"
22 |     steps:
23 |       - uses: actions/checkout@v2
24 |         with:
25 |           fetch-depth: 0
26 |       - name: Install base python for tox
27 |         uses: actions/setup-python@v2
28 |         with:
29 |           python-version: "3.9"
30 |       - name: Install tox
31 |         run: python -m pip install tox
32 |       - name: Install python for test
33 |         uses: actions/setup-python@v2
34 |         with:
35 |           python-version: ${{ matrix.python }}
36 |       - name: Setup test environment
37 |         run: tox -vv --notest -e ${{ matrix.tox_env }}
38 |       - name: Run test
39 |         run: tox --skip-pkg-install -e ${{ matrix.tox_env }}
40 |
41 |
42 |   check:
43 |     name: check ${{ matrix.tox_env }}
44 |     runs-on: ubuntu-latest
45 |     strategy:
46 |       fail-fast: false
47 |       matrix:
48 |         tox_env:
49 |           - flake8
50 |           - pylint
51 |     steps:
52 |       - uses: actions/checkout@v2
53 |         with:
54 |           fetch-depth: 0
55 |       - name: Install base python for tox
56 |         uses: actions/setup-python@v2
57 |         with:
58 |           python-version: "3.9"
59 |       - name: Install tox
60 |         run: python -m pip install tox
61 |       - name: Setup test environment
62 |         run: tox -vv --notest -e ${{ matrix.tox_env }}
63 |       - name: Run test
64 |         run: tox --skip-pkg-install -e ${{ matrix.tox_env }}
65 |
66 |   docs:
67 |     name: docs
68 |     runs-on: ubuntu-latest
69 |     strategy:
70 |       fail-fast: false
71 |     steps:
72 |       - uses: actions/checkout@v2
73 |         with:
74 |           fetch-depth: 0
75 |       - name: Install base python for tox
76 |         uses: actions/setup-python@v2
77 |         with:
78 |           python-version: "3.9"
79 |       - name: Install pandoc
80 |         run: sudo apt-get update -y && sudo apt-get install -y pandoc
81 |       - name: Install tox
82 |         run: python -m pip install tox
83 |       - name: Setup test environment
84 |         run: tox -vv --notest -e docs
85 |       - name: Run test
86 |         run: tox --skip-pkg-install -e docs
87 |
88 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # IDEs and Code editors
2 | .vscode
3 | .idea
4 | .venv
5 | .exrc
6 |
7 | # Python output
8 | .python
9 | *.egg-info
10 | __pycache__
11 |
12 | # Setup output
13 | build/
14 | dist/
15 |
16 | # Results output
17 | result
18 |
19 | # Testing
20 | .tox
21 | .coverage
22 | .coverage.*
23 |
24 | # System Files
25 | .DS_Store
26 | Thumbs.db
27 |
28 | # NPM dependencies
29 | node_modules
30 |
--------------------------------------------------------------------------------
/.gitlab-ci.yml:
--------------------------------------------------------------------------------
1 |
2 | stages:
3 |   - linting
4 |   - unit-tests
5 |
6 | pylint:
7 |   stage: linting
8 |   script:
9 |     - python3.8 -m tox -e pylint
10 |
11 | flake8:
12 |   stage: linting
13 |   script:
14 |     - python3.8 -m tox -e flake8
15 |
16 | pytest:
17 |   stage: unit-tests
18 |   when: always
19 |   script:
20 |     - python3.8 -m tox -e py38
21 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 |   os: ubuntu-20.04
5 |   tools:
6 |     python: "3.9"
7 |
8 | sphinx:
9 |   configuration: docs/source/conf.py
10 |
11 | python:
12 |   install:
13 |     - method: pip
14 |       path: .
15 |       extra_requirements: ["docs"]
16 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide for Zennit
2 |
3 | Thank you for your interest in contributing to Zennit!
4 |
5 | If you would like to fix a bug or add a feature, please write an issue before submitting a pull request.
6 |
7 |
8 | ## Git
9 | We use a linear git-history, where each commit contains a full feature/bug fix,
10 | such that each commit represents an executable version.
11 | The commit message contains a subject followed by an empty line, followed by a detailed description, similar to the following:
12 |
13 | ```
14 | Category: Short subject describing changes (50 characters or less)
15 |
16 | - detailed description, wrapped at 72 characters
17 | - bullet points or sentences are okay
18 | - all changes should be documented and explained
19 | - valid categories are, for example:
20 | - `Docs` for documentation
21 | - `Tests` for tests
22 | - `Composites` for changes in composites
23 | - `Core` for core changes
24 | - `Package` for package-related changes, e.g. in setup.py
25 | ```
26 |
27 | We recommend not using `-m` when committing, as this often results in very short commit messages.
28 |
29 | ## Code Style
30 | We use [PEP8](https://www.python.org/dev/peps/pep-0008) with a line-width of 120 characters. For
31 | docstrings we use [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html).
32 |
33 | We use [`flake8`](https://pypi.org/project/flake8/) for quick style checks and
34 | [`pylint`](https://pypi.org/project/pylint/) for thorough style checks.
35 |
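Both checks can be run locally via their respective Tox environments, for example:

```shell
$ tox -e flake8
$ tox -e pylint
```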
36 | ## Testing
37 | Tests are written using [Pytest](https://docs.pytest.org) and executed
38 | in a separate environment using [Tox](https://tox.readthedocs.io/en/latest/).
39 |
40 | A full style check and all tests can be run by simply calling `tox` in the repository root.
41 |
42 | If you add a new feature, please also include appropriate tests to verify its intended functionality.
43 | We try to keep the code coverage close to 100%.
44 |
45 | ## Documentation
46 | The documentation uses [Sphinx](https://www.sphinx-doc.org). It can be built at
47 | `docs/build` using the respective Tox environment with `tox -e docs`. To rebuild the full
48 | documentation, `tox -e docs -- -aE` can be used.
49 |
50 | The API-documentation is generated from the numpydoc-style docstring of respective modules/classes/functions.
51 |
52 | ### Tutorials
53 | Tutorials are written as Jupyter notebooks in order to execute them using
54 | [Binder](https://mybinder.org/) or [Google
55 | Colab](https://colab.research.google.com/).
56 | They are found at [`docs/source/tutorial`](docs/source/tutorial).
57 | Their output should be empty when committing, as they will be executed when
58 | building the documentation.
59 | To reduce the building time of the documentation, their execution time should
60 | be kept short, i.e. large files like model parameters should not be downloaded
61 | automatically.
62 | To include parameter files for users, include a comment which describes how to
63 | use the full model/data, and provide the necessary code in a comment or an if-condition
64 | which always evaluates to `False`.
65 |
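A minimal sketch of this pattern (the URL and file names are placeholders):

```python
# to use the full model, download the pre-trained parameters, e.g. with
# torch.hub.download_url_to_file('https://<url-to-params>', 'params.pth')
if False:
    # this branch is never executed when building the documentation
    model.load_state_dict(torch.load('params.pth'))
```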
66 | ## Continuous Integration
67 | Linting, tests and the documentation are all checked using a Github Actions
68 | workflow which executes the appropriate tox environments.
69 |
--------------------------------------------------------------------------------
/COPYING.LESSER:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | * Zennit is licensed under the GNU LESSER GENERAL PUBLIC LICENSE VERSION 3 OR
2 | LATER -- see the 'COPYING' and 'COPYING.LESSER' files in the root directory for
3 | details.
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Zennit
2 | 
3 |
4 | [](https://zennit.readthedocs.io/en/latest/?badge=latest)
5 | [](https://github.com/chr5tphr/zennit/actions/workflows/tests.yml)
6 | [](https://pypi.org/project/zennit/)
7 | [](https://github.com/chr5tphr/zennit/blob/master/COPYING.LESSER)
8 |
9 | Zennit (**Z**ennit **e**xplains **n**eural **n**etworks **i**n **t**orch) is a
10 | high-level framework in Python using Pytorch for explaining/exploring neural
11 | networks. It is designed to be highly customizable while serving as a
12 | standardized solution for applying rule-based attribution methods in
13 | research, with a strong focus on Layer-wise Relevance Propagation (LRP).
14 | Zennit strictly requires models to use Pytorch's `torch.nn.Module`
15 | structure (including activation functions).
16 |
17 | Zennit is currently under active development, but should be mostly stable.
18 |
19 | If you find Zennit useful for your research, please consider citing our related
20 | [paper](https://arxiv.org/abs/2106.13200):
21 | ```
22 | @article{anders2021software,
23 | author = {Anders, Christopher J. and
24 | Neumann, David and
25 | Samek, Wojciech and
26 | Müller, Klaus-Robert and
27 | Lapuschkin, Sebastian},
28 | title = {Software for Dataset-wide XAI: From Local Explanations to Global Insights with {Zennit}, {CoRelAy}, and {ViRelAy}},
29 | journal = {CoRR},
30 | volume = {abs/2106.13200},
31 | year = {2021},
32 | }
33 | ```
34 |
35 | ## Documentation
36 | The latest documentation is hosted at
37 | [zennit.readthedocs.io](https://zennit.readthedocs.io/en/latest/).
38 |
39 | ## Install
40 |
41 | To install directly from PyPI using pip, use:
42 | ```shell
43 | $ pip install zennit
44 | ```
45 |
46 | Alternatively, install from a manually cloned repository to try out the examples:
47 | ```shell
48 | $ git clone https://github.com/chr5tphr/zennit.git
49 | $ pip install ./zennit
50 | ```
51 |
52 | ## Usage
53 | At its heart, Zennit registers hooks at Pytorch's Module level to modify the
54 | backward pass to produce rule-based attributions like LRP (instead of the usual
55 | gradient). All rules are implemented as hooks
56 | ([`zennit/rules.py`](src/zennit/rules.py)) and most use the LRP basis
57 | `BasicHook` ([`zennit/core.py`](src/zennit/core.py)).
58 |
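As a minimal sketch (the layer and shapes are only for illustration), a single
rule hook may also be registered by hand, which composites automate for whole
models:

```python
import torch
from torch.nn import Linear

from zennit.rules import Epsilon

linear = Linear(8, 4)
epsilon = Epsilon()
# register the rule's hooks on the module
handles = epsilon.register(linear)

input = torch.randn(1, 8, requires_grad=True)
output = linear(input)
# the resulting gradient is the epsilon-modified one
relevance, = torch.autograd.grad(output, input, torch.ones_like(output))
# remove the hooks again
handles.remove()
```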
59 | **Composites** ([`zennit/composites.py`](src/zennit/composites.py)) are a way
60 | of choosing the right hook for the right layer. In addition to the abstract
61 | **NameMapComposite**, which assigns hooks to layers by name, and
62 | **LayerMapComposite**, which assigns hooks to layers based on their type, there
63 | exist explicit **Composites**, some of which are `EpsilonGammaBox` (`ZBox` in
64 | input, `Epsilon` in dense, `Gamma` in convolutions) or `EpsilonPlus` (`Epsilon`
65 | in dense, `ZPlus` in convolutions). All composites may be used by directly
66 | importing from `zennit.composites`, or by using their snake-case name as key
67 | for `zennit.composites.COMPOSITES`.
68 |
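For instance, the following two lines should construct equivalent composites (a
minimal sketch):

```python
from zennit.composites import COMPOSITES, EpsilonGammaBox

composite = EpsilonGammaBox(low=-3., high=3.)
# the same composite class, looked up by its snake-case name
composite = COMPOSITES['epsilon_gamma_box'](low=-3., high=3.)
```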
69 | **Canonizers** ([`zennit/canonizers.py`](src/zennit/canonizers.py)) temporarily
70 | transform models into a canonical form, if required, like
71 | `SequentialMergeBatchNorm`, which automatically detects and merges BatchNorm
72 | layers followed by linear layers in sequential networks, or
73 | `AttributeCanonizer`, which temporarily overwrites attributes of applicable
74 | modules, e.g. to handle the residual connection in ResNet-Bottleneck modules.
75 |
76 | **Attributors** ([`zennit/attribution.py`](src/zennit/attribution.py)) directly
77 | execute the necessary steps to apply certain attribution methods, like the
78 | simple `Gradient`, `SmoothGrad` or `Occlusion`. An optional **Composite** may
79 | be passed, which will be applied during the **Attributor**'s execution to
80 | compute the modified gradient, or hybrid methods.
81 |
82 | Using all of these components, an LRP-type attribution for VGG16 with
83 | batch-norm layers with respect to label 0 may be computed using:
84 |
85 | ```python
86 | import torch
87 | from torchvision.models import vgg16_bn
88 |
89 | from zennit.composites import EpsilonGammaBox
90 | from zennit.canonizers import SequentialMergeBatchNorm
91 | from zennit.attribution import Gradient
92 |
93 |
94 | data = torch.randn(1, 3, 224, 224)
95 | model = vgg16_bn()
96 |
97 | canonizers = [SequentialMergeBatchNorm()]
98 | composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
99 |
100 | with Gradient(model=model, composite=composite) as attributor:
101 |     out, relevance = attributor(data, torch.eye(1000)[[0]])
102 | ```
103 |
104 | A similar setup using [the example script](share/example/feed_forward.py)
105 | produces the following attribution heatmaps:
106 | 
107 |
108 | For more details and examples, have a look at our
109 | [**documentation**](https://zennit.readthedocs.io/en/latest/).
110 |
111 | ### More Example Heatmaps
112 | More heatmaps of various attribution methods for VGG16 and ResNet50, all
113 | generated using
114 | [`share/example/feed_forward.py`](share/example/feed_forward.py), can be found
115 | below.
116 |
117 |
118 | Heatmaps for VGG16
119 |
120 | 
121 |
122 |
123 |
124 | Heatmaps for ResNet50
125 |
126 | 
127 |
128 |
129 | ## Contributing
130 | See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed instructions on how to contribute.
131 |
132 | ## License
133 | Zennit is licensed under the GNU LESSER GENERAL PUBLIC LICENSE VERSION 3 OR
134 | LATER -- see the [LICENSE](LICENSE), [COPYING](COPYING) and
135 | [COPYING.LESSER](COPYING.LESSER) files for details.
136 |
--------------------------------------------------------------------------------
/docs/source/_static/favicon.svg:
--------------------------------------------------------------------------------
[SVG image file: the Zennit favicon; vector markup omitted from this dump]
--------------------------------------------------------------------------------
/docs/source/_templates/modules.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |     :members:
5 |     :show-inheritance:
6 |
7 | {% block attributes %}
8 | {% if attributes %}
9 | .. rubric:: {{ _('Module Attributes') }}
10 |
11 | .. autosummary::
12 |     :nosignatures:
13 |
14 | {% for item in attributes %}
15 |     {{ item }}
16 | {%- endfor %}
17 | {% endif %}
18 | {% endblock %}
19 |
20 | {% block functions %}
21 | {% if functions %}
22 | .. rubric:: {{ _('Functions') }}
23 |
24 | .. autosummary::
25 |     :nosignatures:
26 |
27 | {% for item in functions %}
28 |     {{ item }}
29 | {%- endfor %}
30 | {% endif %}
31 | {% endblock %}
32 |
33 | {% block classes %}
34 | {% if classes %}
35 | .. rubric:: {{ _('Classes') }}
36 |
37 | .. autosummary::
38 |     :nosignatures:
39 |
40 | {% for item in classes %}
41 |     {{ item }}
42 | {%- endfor %}
43 | {% endif %}
44 | {% endblock %}
45 |
46 | {% block exceptions %}
47 | {% if exceptions %}
48 | .. rubric:: {{ _('Exceptions') }}
49 |
50 | .. autosummary::
51 |     :nosignatures:
52 |
53 | {% for item in exceptions %}
54 |     {{ item }}
55 | {%- endfor %}
56 | {% endif %}
57 | {% endblock %}
58 |
59 | {% block modules %}
60 | {% if modules %}
61 | .. rubric:: Modules
62 |
63 | .. autosummary::
64 |     :toctree:
65 |     :recursive:
66 | {% for item in modules %}
67 |     {{ item }}
68 | {%- endfor %}
69 | {% endif %}
70 | {% endblock %}
71 |
--------------------------------------------------------------------------------
/docs/source/bibliography.bib:
--------------------------------------------------------------------------------
1 |
2 | @inproceedings{zeiler2014visualizing,
3 | author = {Matthew D. Zeiler and
4 | Rob Fergus},
5 | title = {Visualizing and Understanding Convolutional Networks},
6 | booktitle = {Computer Vision - {ECCV} 2014 - 13th European Conference, Zurich,
7 | Switzerland, September 6-12, 2014, Proceedings, Part {I}},
8 | series = {Lecture Notes in Computer Science},
9 | volume = {8689},
10 | pages = {818--833},
11 | publisher = {Springer},
12 | year = {2014},
13 | url = {https://doi.org/10.1007/978-3-319-10590-1_53},
14 | }
15 |
16 |
17 | @article{bach2015pixel,
18 | author = {Sebastian Bach and
19 | Alexander Binder and
20 | Gr{\'e}goire Montavon and
21 | Frederick Klauschen and
22 | Klaus-Robert M{\"u}ller and
23 | Wojciech Samek},
24 | title = {On pixel-wise explanations for non-linear classifier decisions by
25 | layer-wise relevance propagation},
26 | journal = {PloS one},
27 | volume = {10},
28 | number = {7},
29 | pages = {e0130140},
30 | year = {2015},
31 | publisher = {Public Library of Science San Francisco, CA USA},
32 | url = {https://doi.org/10.1371/journal.pone.0130140}
33 | }
34 |
35 | @inproceedings{springenberg2015striving,
36 | author = {Jost Tobias Springenberg and
37 | Alexey Dosovitskiy and
38 | Thomas Brox and
39 | Martin A. Riedmiller},
40 | title = {Striving for Simplicity: The All Convolutional Net},
41 | booktitle = {3rd International Conference on Learning Representations, {ICLR} 2015,
42 | San Diego, CA, USA, May 7-9, 2015, Workshop Track Proceedings},
43 | year = {2015},
44 | url = {http://arxiv.org/abs/1412.6806},
45 | }
46 |
47 | @inproceedings{zhang2016top,
48 | author = {Jianming Zhang and
49 | Zhe L. Lin and
50 | Jonathan Brandt and
51 | Xiaohui Shen and
52 | Stan Sclaroff},
53 | title = {Top-Down Neural Attention by Excitation Backprop},
54 | booktitle = {Computer Vision - {ECCV} 2016 - 14th European Conference, Amsterdam,
55 | The Netherlands, October 11-14, 2016, Proceedings, Part {IV}},
56 | series = {Lecture Notes in Computer Science},
57 | volume = {9908},
58 | pages = {543--559},
59 | publisher = {Springer},
60 | year = {2016},
61 | url = {https://doi.org/10.1007/978-3-319-46493-0_33},
62 | }
63 |
64 | @article{montavon2017explaining,
65 | author = {Gr{\'{e}}goire Montavon and
66 | Sebastian Lapuschkin and
67 | Alexander Binder and
68 | Wojciech Samek and
69 | Klaus{-}Robert M{\"{u}}ller},
70 | title = {Explaining nonlinear classification decisions with deep Taylor decomposition},
71 | journal = {Pattern Recognit.},
72 | volume = {65},
73 | pages = {211--222},
74 | year = {2017},
75 | url = {https://doi.org/10.1016/j.patcog.2016.11.008},
76 | }
77 |
78 | @inproceedings{sundararajan2017axiomatic,
79 | author = {Mukund Sundararajan and
80 | Ankur Taly and
81 | Qiqi Yan},
82 | title = {Axiomatic Attribution for Deep Networks},
83 | booktitle = {Proceedings of the 34th International Conference on Machine Learning,
84 | {ICML} 2017, Sydney, NSW, Australia, 6-11 August 2017},
85 | series = {Proceedings of Machine Learning Research},
86 | volume = {70},
87 | pages = {3319--3328},
88 | publisher = {{PMLR}},
89 | year = {2017},
90 | url = {http://proceedings.mlr.press/v70/sundararajan17a.html},
91 | }
92 |
93 | @article{smilkov2017smoothgrad,
94 | author = {Daniel Smilkov and
95 | Nikhil Thorat and
96 | Been Kim and
97 | Fernanda B. Vi{\'{e}}gas and
98 | Martin Wattenberg},
99 | title = {SmoothGrad: removing noise by adding noise},
100 | journal = {CoRR},
101 | volume = {abs/1706.03825},
102 | year = {2017},
103 | url = {https://arxiv.org/abs/1706.03825},
104 | }
105 |
106 | @article{DBLP:journals/corr/abs-1902-10178,
107 | author = {Sebastian Lapuschkin and
108 | Stephan W{\"{a}}ldchen and
109 | Alexander Binder and
110 | Gr{\'{e}}goire Montavon and
111 | Wojciech Samek and
112 | Klaus{-}Robert M{\"{u}}ller},
113 | title = {Unmasking Clever Hans Predictors and Assessing What Machines Really
114 | Learn},
115 | journal = {CoRR},
116 | volume = {abs/1902.10178},
117 | year = {2019},
118 | url = {http://arxiv.org/abs/1902.10178},
119 | }
120 |
121 | @article{lapuschkin2019unmasking,
122 | title = {Unmasking Clever Hans predictors and assessing what machines really learn},
123 | author = {Sebastian Lapuschkin and
124 | Stephan W{\"a}ldchen and
125 | Alexander Binder and
126 | Gr{\'e}goire Montavon and
127 | Wojciech Samek and
128 | Klaus-Robert M{\"u}ller},
129 | journal = {Nature communications},
130 | volume = {10},
131 | number = {1},
132 | pages = {1--8},
133 | year = {2019},
134 | publisher = {Nature Publishing Group},
135 | url = {https://doi.org/10.1038/s41467-019-08987-4}
136 | }
137 |
138 |
139 | @incollection{montavon2019layer,
140 | author = {Gr{\'{e}}goire Montavon and
141 | Alexander Binder and
142 | Sebastian Lapuschkin and
143 | Wojciech Samek and
144 | Klaus{-}Robert M{\"{u}}ller},
145 | title = {Layer-Wise Relevance Propagation: An Overview},
146 | booktitle = {Explainable {AI:} Interpreting, Explaining and Visualizing Deep Learning},
147 | series = {Lecture Notes in Computer Science},
148 | volume = {11700},
149 | pages = {193--209},
150 | publisher = {Springer},
151 | year = {2019},
152 | url = {https://doi.org/10.1007/978-3-030-28954-6_10},
153 | }
154 |
155 | @inproceedings{dombrowski2019explanations,
156 | author = {Ann{-}Kathrin Dombrowski and
157 | Maximilian Alber and
158 | Christopher J. Anders and
159 | Marcel Ackermann and
160 | Klaus{-}Robert M{\"{u}}ller and
161 | Pan Kessel},
162 | title = {Explanations can be manipulated and geometry is to blame},
163 | booktitle = {Advances in Neural Information Processing Systems 32: Annual Conference
164 | on Neural Information Processing Systems 2019, NeurIPS 2019, December
165 | 8-14, 2019, Vancouver, BC, Canada},
166 | pages = {13567--13578},
167 | year = {2019},
168 | url = {https://proceedings.neurips.cc/paper/2019/hash/bb836c01cdc9120a9c984c525e4b1a4a-Abstract.html},
169 | }
170 |
171 | @inproceedings{anders2020fairwashing,
172 | author = {Christopher J. Anders and
173 | Plamen Pasliev and
174 | Ann{-}Kathrin Dombrowski and
175 | Klaus{-}Robert M{\"{u}}ller and
176 | Pan Kessel},
177 | title = {Fairwashing explanations with off-manifold detergent},
178 | booktitle = {Proceedings of the 37th International Conference on Machine Learning,
179 | {ICML} 2020, 13-18 July 2020, Virtual Event},
180 | series = {Proceedings of Machine Learning Research},
181 | volume = {119},
182 | pages = {314--323},
183 | publisher = {{PMLR}},
184 | year = {2020},
185 | url = {http://proceedings.mlr.press/v119/anders20a.html},
186 | }
187 |
188 | @article{anders2021software,
189 | author = {Christopher J. Anders and
190 | David Neumann and
191 | Wojciech Samek and
192 | Klaus{-}Robert M{\"{u}}ller and
193 | Sebastian Lapuschkin},
194 | title = {Software for Dataset-wide {XAI:} From Local Explanations to Global
195 | Insights with Zennit, CoRelAy, and ViRelAy},
196 | journal = {CoRR},
197 | volume = {abs/2106.13200},
198 | year = {2021},
199 | url = {https://arxiv.org/abs/2106.13200},
200 | }
201 |
202 | @article{andeol2021learning,
203 | author = {L{\'{e}}o And{\'{e}}ol and
204 | Yusei Kawakami and
205 | Yuichiro Wada and
206 | Takafumi Kanamori and
207 | Klaus{-}Robert M{\"{u}}ller and
208 | Gr{\'{e}}goire Montavon},
209 | title = {Learning Domain Invariant Representations by Joint Wasserstein Distance
210 | Minimization},
211 | journal = {CoRR},
212 | volume = {abs/2106.04923},
213 | year = {2021},
214 | url = {https://arxiv.org/abs/2106.04923},
215 | }
216 |
--------------------------------------------------------------------------------
/docs/source/bibliography.rst:
--------------------------------------------------------------------------------
1 | ============
2 | Bibliography
3 | ============
4 |
5 | .. bibliography::
6 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 | import sys
17 | import os
18 | from subprocess import run, CalledProcessError
19 | import inspect
20 | import pkg_resources
21 |
22 | from pybtex.style.formatting.plain import Style as PlainStyle
23 | from pybtex.style.labels import BaseLabelStyle
24 | from pybtex.plugin import register_plugin
25 |
26 |
27 | # -- Project information -----------------------------------------------------
28 | project = 'zennit'
29 | copyright = '2021, chr5tphr'
30 | author = 'chr5tphr'
31 |
32 |
33 | # -- General configuration ---------------------------------------------------
34 |
35 | # Add any Sphinx extension module names here, as strings. They can be
36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
37 | # ones.
38 | extensions = [
39 | 'sphinx.ext.autodoc',
40 | 'sphinx.ext.autosummary',
41 | 'sphinx.ext.napoleon',
42 | 'sphinx.ext.linkcode',
43 | 'sphinx.ext.mathjax',
44 | 'sphinx.ext.extlinks',
45 | 'sphinx_rtd_theme',
46 | 'sphinx_copybutton',
47 | 'sphinxcontrib.datatemplates',
48 | 'sphinxcontrib.bibtex',
49 | 'nbsphinx',
50 | ]
51 |
52 |
53 | def config_inited_handler(app, config):
54 |     os.makedirs(os.path.join(app.srcdir, app.config.generated_path), exist_ok=True)
55 |
56 |
57 | def setup(app):
58 |     app.add_config_value('REVISION', 'master', 'env')
59 |     app.add_config_value('generated_path', '_generated', 'env')
60 |     app.connect('config-inited', config_inited_handler)
61 |
62 |
63 | # Add any paths that contain templates here, relative to this directory.
64 | templates_path = ['_templates']
65 |
66 | # List of patterns, relative to source directory, that match files and
67 | # directories to ignore when looking for source files.
68 | # This pattern also affects html_static_path and html_extra_path.
69 | exclude_patterns = []
70 |
71 | # interactive badges for binder and colab
72 | nbsphinx_prolog = r"""
73 | {% set docname = 'docs/source/' + env.doc2path(env.docname, base=False) %}
74 |
75 | .. raw:: html
76 |
77 |
93 | """
94 |
95 | # autosummary_generate = True
96 |
97 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
98 | copybutton_prompt_is_regexp = True
99 | copybutton_line_continuation_character = "\\"
100 | copybutton_here_doc_delimiter = "EOT"
101 |
102 | # -- Options for HTML output -------------------------------------------------
103 |
104 | # The theme to use for HTML and HTML Help pages. See the documentation for
105 | # a list of builtin themes.
106 | #
107 | html_theme = 'sphinx_rtd_theme'
108 | # html_theme = 'alabaster'
109 |
110 | html_favicon = '_static/favicon.svg'
111 |
112 | # Add any paths that contain custom static files (such as style sheets) here,
113 | # relative to this directory. They are copied after the builtin static files,
114 | # so a file named "default.css" will overwrite the builtin "default.css".
115 | html_static_path = ['_static']
116 |
117 | bibtex_bibfiles = ['bibliography.bib']
118 | bibtex_default_style = 'author_year_style'
119 | bibtex_reference_style = 'author_year'
120 |
121 |
122 | class AuthorYearLabelStyle(BaseLabelStyle):
123 |     def format_labels(self, sorted_entries):
124 |         for entry in sorted_entries:
125 |             yield f'[{entry.persons["author"][0].last_names[0]} et al., {entry.fields["year"]}]'
126 |
127 |
128 | class AuthorYearStyle(PlainStyle):
129 |     default_label_style = AuthorYearLabelStyle
130 |
131 |
132 | register_plugin('pybtex.style.formatting', 'author_year_style', AuthorYearStyle)
133 |
134 |
135 | def getrev():
136 |     try:
137 |         revision = run(
138 |             ['git', 'describe', '--tags', 'HEAD'],
139 |             capture_output=True,
140 |             check=True,
141 |             text=True
142 |         ).stdout[:-1]
143 |     except CalledProcessError:
144 |         revision = 'master'
145 |
146 |     return revision
147 |
148 |
149 | REVISION = getrev()
150 |
151 | extlinks = {
152 | 'repo': (
153 | f'https://github.com/chr5tphr/zennit/blob/{REVISION}/%s',
154 | '%s'
155 | )
156 | }
157 |
158 | LINKCODE_URL = (
159 | f'https://github.com/chr5tphr/zennit/blob/{REVISION}'
160 | '/src/{filepath}#L{linestart}-L{linestop}'
161 | )
162 |
163 |
164 | # revised from https://gist.github.com/nlgranger/55ff2e7ff10c280731348a16d569cb73
165 | def linkcode_resolve(domain, info):
166 |     if domain != 'py' or not info['module']:
167 |         return None
168 |
169 |     modname = info['module']
170 |     topmodulename = modname.split('.')[0]
171 |     fullname = info['fullname']
172 |
173 |     submod = sys.modules.get(modname)
174 |     if submod is None:
175 |         return None
176 |
177 |     obj = submod
178 |     for part in fullname.split('.'):
179 |         try:
180 |             obj = getattr(obj, part)
181 |         except Exception:
182 |             return None
183 |
184 |     try:
185 |         modpath = pkg_resources.require(topmodulename)[0].location
186 |         filepath = os.path.relpath(inspect.getsourcefile(obj), modpath)
187 |         if filepath is None:
188 |             return
189 |     except Exception:
190 |         return None
191 |
192 |     try:
193 |         source, lineno = inspect.getsourcelines(obj)
194 |     except OSError:
195 |         return None
196 |     else:
197 |         linestart, linestop = lineno, lineno + len(source) - 1
198 |
199 |     return LINKCODE_URL.format(filepath=filepath, linestart=linestart, linestop=linestop)
200 |
--------------------------------------------------------------------------------
/docs/source/getting-started.rst:
--------------------------------------------------------------------------------
1 | ================
2 | Getting started
3 | ================
4 |
5 |
6 | Install
7 | -------
8 |
9 | Zennit can be installed directly from PyPI:
10 |
11 | .. code-block:: console
12 |
13 | $ pip install zennit
14 |
15 | For the current development version, or to try out the examples, Zennit may
16 | alternatively be cloned and installed with
17 |
18 | .. code-block:: console
19 |
20 | $ git clone https://github.com/chr5tphr/zennit.git
21 | $ pip install ./zennit
22 |
23 | Basic Usage
24 | -----------
25 |
26 | Zennit implements propagation-based attribution methods by overwriting the
27 | gradient of PyTorch modules in PyTorch's auto-differentiation engine. This means
28 | that Zennit will only work on models which are strictly implemented using
29 | PyTorch modules, including activation functions. The following demonstrates a
30 | setup to compute Layer-wise Relevance Propagation (LRP) relevance for a simple
31 | model and random data.
32 |
33 | .. code-block:: python
34 |
35 | import torch
36 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
37 |
38 |
39 | # setup the model and data
40 | model = Sequential(
41 | Conv2d(3, 10, 3, padding=1),
42 | ReLU(),
43 | Flatten(),
44 | Linear(10 * 32 * 32, 10),
45 | )
46 | input = torch.randn(1, 3, 32, 32)
47 |
48 | The most important high-level structures in Zennit are ``Composites``,
49 | ``Attributors`` and ``Canonizers``.
50 |
51 |
52 | Composites
53 | ^^^^^^^^^^
54 |
55 | Composites map ``Rules`` to modules based on their properties and context to
56 | modify their gradient. The most common composites for LRP are implemented in
57 | :py:mod:`zennit.composites`.
58 |
59 | The following computes LRP relevance using the ``EpsilonPlusFlat`` composite:
60 |
61 | .. code-block:: python
62 |
63 | from zennit.composites import EpsilonPlusFlat
64 |
65 |
66 | # create a composite instance
67 | composite = EpsilonPlusFlat()
68 |
69 | # use the following instead to ignore bias for the relevance
70 | # composite = EpsilonPlusFlat(zero_params='bias')
71 |
72 | # make sure the input requires a gradient
73 | input.requires_grad = True
74 |
75 | # compute the output and gradient within the composite's context
76 | with composite.context(model) as modified_model:
77 |     output = modified_model(input)
78 |     # gradient/ relevance wrt. class/output 0
79 |     output.backward(gradient=torch.eye(10)[[0]])
80 |     # relevance is not accumulated in .grad if using torch.autograd.grad
81 |     # relevance, = torch.autograd.grad(output, input, torch.eye(10)[[0]])
82 |
83 | # gradient is accumulated in input.grad
84 | print('Backward:', input.grad)
85 |
86 |
87 | The context created by :py:func:`zennit.core.Composite.context` registers the
88 | composite, which means that all rules are applied according to the composite's
89 | mapping. See :doc:`/how-to/use-rules-composites-and-canonizers` for information on
90 | using composites, :py:mod:`zennit.composites` for an API reference and
91 | :doc:`/how-to/write-custom-composites` for writing new composites. Available
92 | ``Rules`` can be found in :py:mod:`zennit.rules`, their use is described in
93 | :doc:`/how-to/use-rules-composites-and-canonizers` and how to add new ones is described in
94 | :doc:`/how-to/write-custom-rules`.
95 |
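As a minimal sketch, assuming the sequential model defined above (whose
``Conv2d`` and ``Linear`` submodules are named ``0`` and ``3``), rules may also
be mapped explicitly by submodule name using
:py:class:`zennit.composites.NameMapComposite`:

.. code-block:: python

    from zennit.composites import NameMapComposite
    from zennit.rules import Epsilon, ZPlus


    # assign ZPlus to the convolution and Epsilon to the dense layer by name
    composite = NameMapComposite([
        (['0'], ZPlus()),
        (['3'], Epsilon(epsilon=1e-6)),
    ])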
96 | Attributors
97 | ^^^^^^^^^^^
98 |
99 | Alternatively, *attributors* may be used instead of ``composite.context``.
100 |
101 | .. code-block:: python
102 |
103 | from zennit.attribution import Gradient
104 |
105 |
106 | attributor = Gradient(model, composite)
107 |
108 | with attributor:
109 |     # gradient/ relevance wrt. output/class 1
110 |     output, relevance = attributor(input, torch.eye(10)[[1]])
111 |
112 | print('EpsilonPlusFlat:', relevance)
113 |
114 | Attribution methods which are not propagation-based, like
115 | :py:class:`zennit.attribution.SmoothGrad`, are implemented as attributors and
116 | may be combined with propagation-based (composite) approaches.
117 |
118 | .. code-block:: python
119 |
120 | from zennit.attribution import SmoothGrad
121 |
122 |
123 | # we do not need a composite to compute vanilla SmoothGrad
124 | with SmoothGrad(model, noise_level=0.1, n_iter=10) as attributor:
125 |     # gradient/ relevance wrt. output/class 7
126 |     output, relevance = attributor(input, torch.eye(10)[[7]])
127 |
128 | print('SmoothGrad:', relevance)
129 |
130 | More information on attributors can be found in :doc:`/how-to/use-attributors`
131 | and :doc:`/how-to/write-custom-attributors`.
132 |
133 | Canonizers
134 | ^^^^^^^^^^
135 |
136 | For some modules and operations, Layer-wise Relevance Propagation (LRP) is not
137 | implementation-invariant, e.g. ``BatchNorm -> Dense -> ReLU`` will be attributed
138 | differently than ``Dense -> BatchNorm -> ReLU``. Therefore, LRP needs a
139 | canonical form of the model, which is implemented in ``Canonizers``. These may
140 | be simply supplied when instantiating a composite:
141 |
142 | .. code-block:: python
143 |
144 | from torchvision.models import vgg16
145 | from zennit.composites import EpsilonGammaBox
146 | from zennit.torchvision import VGGCanonizer
147 |
148 |
149 | # instantiate the model
150 | model = vgg16()
151 | # create the canonizers
152 | canonizers = [VGGCanonizer()]
153 | # EpsilonGammaBox needs keyword arguments 'low' and 'high'
154 | composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)
155 |
156 | with Gradient(model, composite) as attributor:
157 |     # gradient/ relevance wrt. output/class 0
158 |     # torchvision.vgg16 has 1000 output classes by default
159 |     output, relevance = attributor(input, torch.eye(1000)[[0]])
160 |
161 | print('EpsilonGammaBox:', relevance)
162 |
163 | Some pre-defined canonizers for models from ``torchvision`` can be found in
164 | :py:mod:`zennit.torchvision`. The :py:class:`zennit.torchvision.VGGCanonizer`
165 | specifically is simply :py:class:`zennit.canonizers.SequentialMergeBatchNorm`,
166 | which may be used when ``BatchNorm`` is used in sequential models. Note that for
167 | ``SequentialMergeBatchNorm`` to work, all functions (linear layers, activations,
168 | ...) must be modules and assigned to their parent module in the order they are
169 | visited (see :py:class:`zennit.canonizers.SequentialMergeBatchNorm`). For more
170 | information on canonizers see :doc:`/how-to/use-rules-composites-and-canonizers` and
171 | :doc:`/how-to/write-custom-canonizers`.
172 |
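As a minimal sketch, :py:class:`zennit.canonizers.SequentialMergeBatchNorm` may
be supplied directly for a custom sequential model with BatchNorm (the shapes
are only for illustration):

.. code-block:: python

    from torch.nn import Sequential, Conv2d, BatchNorm2d, ReLU, Flatten, Linear
    from zennit.canonizers import SequentialMergeBatchNorm
    from zennit.composites import EpsilonPlusFlat


    # the BatchNorm2d directly follows the Conv2d, and all modules are
    # assigned to the parent module in the order they are visited
    bn_model = Sequential(
        Conv2d(3, 10, 3, padding=1),
        BatchNorm2d(10),
        ReLU(),
        Flatten(),
        Linear(10 * 32 * 32, 10),
    )
    composite = EpsilonPlusFlat(canonizers=[SequentialMergeBatchNorm()])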
173 |
174 | Visualizing Results
175 | ^^^^^^^^^^^^^^^^^^^
176 |
177 | While attribution approaches are not limited to the domain of images, they are
178 | predominantly used on image models and produce heat maps of relevance. For
179 | this reason, Zennit implements methods to visualize relevance heat maps.
180 |
181 | .. code-block:: python
182 |
183 | from zennit.image import imsave
184 |
185 |
186 | # sum over the color channels
187 | heatmap = relevance.sum(1)
188 | # get the absolute maximum, to center the heat map around 0
189 | amax = heatmap.abs().numpy().max((1, 2))
190 |
191 | # save heat map with color map 'coldnhot'
192 | imsave(
193 | 'heatmap.png',
194 | heatmap[0],
195 | vmin=-amax,
196 | vmax=amax,
197 | cmap='coldnhot',
198 | level=1.0,
199 | grid=False
200 | )
201 |
202 | Information on ``imsave`` can be found at :py:func:`zennit.image.imsave`.
203 | Saving an image with 3 color channels will result in the image being saved
204 | without a color map but with the channels assumed as RGB. The keyword argument
205 | ``grid`` will create a grid of multiple images over the batch dimension if
206 | ``True``. Custom color maps may be created with
207 | :py:class:`zennit.cmap.ColorMap`, e.g. to save the previous image with a color
208 | map ranging from blue to yellow to red:
209 |
210 | .. code-block:: python
211 |
212 | from zennit.cmap import ColorMap
213 |
214 |
215 | # 00f is blue, ff0 is yellow, f00 is red, 0x80 is the center of the range
216 | cmap = ColorMap('00f,80:ff0,f00')
217 |
218 | imsave(
219 | 'heatmap.png',
220 | heatmap,
221 | vmin=-amax,
222 | vmax=amax,
223 | cmap=cmap,
224 | level=1.0,
225 | grid=True
226 | )
227 |
228 | More details to visualize heat maps and color maps can be found in
229 | :doc:`/how-to/visualize-results`. The ColorMap specification language is
230 | described in :py:class:`zennit.cmap.ColorMap` and built-in color maps are
231 | implemented in :py:obj:`zennit.image.CMAPS`.
232 |
233 | Example Script
234 | --------------
235 |
236 | A ready-to-use example to analyze a few ImageNet models provided by torchvision
237 | can be found at :repo:`share/example/feed_forward.py`.
238 |
239 | The following setup requires bash, cURL and (magic-)file.
240 |
241 | Create a virtual environment, install Zennit and download the example scripts:
242 |
243 | .. code-block:: console
244 |
245 | $ mkdir zennit-example
246 | $ cd zennit-example
247 | $ python -m venv .venv
248 | $ .venv/bin/pip install zennit
249 | $ curl -o feed_forward.py \
250 | 'https://raw.githubusercontent.com/chr5tphr/zennit/master/share/example/feed_forward.py'
251 | $ curl -o download-lighthouses.sh \
252 | 'https://raw.githubusercontent.com/chr5tphr/zennit/master/share/scripts/download-lighthouses.sh'
253 |
254 | Prepare the data required for the example:
255 |
256 | .. code-block:: console
257 |
258 | $ mkdir params data results
259 | $ bash download-lighthouses.sh --output data/lighthouses
260 | $ curl -o params/vgg16-397923af.pth 'https://download.pytorch.org/models/vgg16-397923af.pth'
261 |
262 | This creates the needed directories and downloads the pre-trained VGG16
263 | parameters and 8 images of lighthouses from Wikimedia Commons into the
264 | required label-directory structure for the ImageNet dataset in PyTorch.
265 |
266 | The ``feed_forward.py`` example can then be run using:
267 |
268 | .. code-block:: console
269 |
270 | $ .venv/bin/python feed_forward.py \
271 | data/lighthouses \
272 | 'results/vgg16_epsilon_gamma_box_{sample:02d}.png' \
273 | --inputs 'results/vgg16_input_{sample:02d}.png' \
274 | --parameters params/vgg16-397923af.pth \
275 | --model vgg16 \
276 | --composite epsilon_gamma_box \
277 | --no-bias \
278 | --relevance-norm symmetric \
279 | --cmap coldnhot
280 |
281 | which computes the LRP heatmaps according to the ``epsilon_gamma_box`` composite
282 | and stores them in ``results``, along with the respective input images. Other
283 | possible composites that can be passed to ``--composite`` are, e.g., ``epsilon_plus``,
284 | ``epsilon_alpha2_beta1_flat``, ``guided_backprop``, ``excitation_backprop``.
285 | The bias can be ignored in the LRP-computation by passing ``--no-bias``.
286 |
287 |
288 | ..
289 |     The resulting heatmaps may look like the following:
290 |
291 |     .. image:: /img/beacon_vgg16_epsilon_gamma_box.png
292 |         :alt: Lighthouses with Attributions
293 |
294 | Alternatively, heatmaps for SmoothGrad with absolute relevances may be computed
295 | by omitting ``--composite`` and supplying ``--attributor``:
296 |
297 | .. code-block:: console
298 |
299 | $ .venv/bin/python feed_forward.py \
300 | data/lighthouses \
301 | 'results/vgg16_smoothgrad_{sample:02d}.png' \
302 | --inputs 'results/vgg16_input_{sample:02d}.png' \
303 | --parameters params/vgg16-397923af.pth \
304 | --model vgg16 \
305 | --attributor smoothgrad \
306 | --relevance-norm absolute \
307 | --cmap hot
308 |
309 | For Integrated Gradients, ``--attributor integrads`` may be provided.
310 |
311 | Heatmaps for Occlusion Analysis with unaligned relevances may be computed by
312 | executing:
313 |
314 | .. code-block:: console
315 |
316 | $ .venv/bin/python feed_forward.py \
317 | data/lighthouses \
318 | 'results/vgg16_occlusion_{sample:02d}.png' \
319 | --inputs 'results/vgg16_input_{sample:02d}.png' \
320 | --parameters params/vgg16-397923af.pth \
321 | --model vgg16 \
322 | --attributor occlusion \
323 | --relevance-norm unaligned \
324 | --cmap hot
325 |
326 |
--------------------------------------------------------------------------------
/docs/source/how-to/compute-second-order-gradients.rst:
--------------------------------------------------------------------------------
1 | ================================
2 | Computing Second Order Gradients
3 | ================================
4 |
5 | Sometimes, it may be necessary to compute the gradient of the attribution. One
6 | example is to compute the gradient with respect to the input in order to
7 | find adversarial explanations :cite:p:`dombrowski2019explanations`,
8 | or to regularize or transform the attributions of a network
9 | :cite:p:`anders2020fairwashing`.
10 |
11 | In Zennit, the attribution is computed using the modified gradient, which means
12 | that in order to compute the gradient of the attribution, the second order
13 | gradient needs to be computed. Pytorch natively supports the computation of
14 | higher order gradients, simply by supplying ``create_graph=True`` to
15 | :py:func:`torch.autograd.grad` to declare that the backward function itself
16 | needs to be differentiable.
17 |
18 |
19 | Vanilla Gradient and ReLU
20 | -------------------------
21 |
22 | If we simply need the second order gradient of a model, without using Zennit, we can do the following:
23 |
24 | .. code-block:: python
25 |
26 | import torch
27 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
28 |
29 |
30 | # setup the model and data
31 | model = Sequential(
32 | Conv2d(3, 10, 3, padding=1),
33 | ReLU(),
34 | Flatten(),
35 | Linear(10 * 32 * 32, 10),
36 | )
37 | input = torch.randn(1, 3, 32, 32)
38 |
39 | # make sure the input requires a gradient
40 | input.requires_grad = True
41 |
42 | output = model(input)
43 | # a vector for the vector-jacobian-product, i.e. the grad_output
44 | target = torch.ones_like(output)
45 |
46 | grad, = torch.autograd.grad(output, input, target, create_graph=True)
47 |
48 | # the grad_output for grad
49 | gradtarget = torch.ones_like(grad)
50 | # compute the second order gradient
51 | gradgrad, = torch.autograd.grad(grad, input, gradtarget)
52 |
53 | Here, you might notice that ``gradgrad`` is all zeros, regardless of the input
54 | and model parameters. The culprit is ``ReLU``: its gradient is a step function,
55 | whose own derivative is zero everywhere except at zero, where it is undefined. In order to get a meaningful
56 | gradient, we could instead use a *smooth* activation function in our model.
57 | However, ReLU models are quite common, and we may not like to retrain every
58 | model using only smooth activation functions.
59 |
60 | :cite:t:`dombrowski2019explanations` proposed to replace the ReLU activations
61 | with their smooth variant, the *Softplus* function:
62 |
63 | .. math::
64 |
65 | \text{Softplus}(x;\beta) = \frac{1}{\beta} \log (1 + \exp (\beta x))
66 | \,\text{.}
67 |
68 | With :math:`\beta\rightarrow\infty`, Softplus converges to ReLU, but in
69 | practice, choosing :math:`\beta = 10` is most often sufficient to keep the model
70 | output approximately unchanged while still obtaining a meaningful second order gradient.
71 |
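The quality of the approximation for :math:`\beta = 10` can be verified
numerically; the deviation of Softplus from ReLU is at most
:math:`\log(2) / \beta`, i.e. roughly ``0.069``, attained at :math:`x = 0` (a
minimal sketch):

.. code-block:: python

    import torch

    x = torch.linspace(-5., 5., 1001)
    # maximum absolute deviation between Softplus (beta=10) and ReLU
    deviation = (torch.nn.functional.softplus(x, beta=10.) - torch.relu(x)).abs().max()
    # prints a value close to log(2) / 10 ~ 0.0693
    print(deviation)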
72 | To temporarily replace the ReLU gradients in-place, we can use the
73 | :py:class:`~zennit.rules.ReLUBetaSmooth` rule:
74 |
75 |
76 | .. code-block:: python
77 |
78 | from zennit.composites import BetaSmooth
79 |
80 | # LayerMapComposite which assigns the ReLUBetaSmooth hook to ReLUs
81 | composite = BetaSmooth(beta_smooth=10.)
82 |
83 | with composite.context(model):
84 |     output = model(input)
85 |     target = torch.ones_like(output)
86 |     grad, = torch.autograd.grad(output, input, target, create_graph=True)
87 |
88 | gradtarget = torch.ones_like(grad)
89 | gradgrad, = torch.autograd.grad(grad, input, gradtarget)
90 |
91 | Notice here that we computed the second order gradient **outside** of the
92 | composite context. A property of Pytorch's gradient hooks is that they are
93 | also called when the *second* order gradient with respect to a tensor is
94 | computed.
95 | Due to this, computing the second order gradient *while rules are still
96 | registered* will lead to incorrect results.
97 |
98 | Temporarily Disabling Hooks
99 | ---------------------------
100 |
101 | In order to compute the second order gradient *without* removing the hooks (i.e. to
102 | compute multiple values in a loop), we can temporarily deactivate them using
103 | :py:meth:`zennit.core.Composite.inactive`:
104 |
105 | .. code-block:: python
106 |
107 | with composite.context(model):
108 |     output = model(input)
109 |     target = torch.ones_like(output)
110 |     grad, = torch.autograd.grad(output, input, target, create_graph=True)
111 |
112 |     # temporarily disable all hooks registered by composite
113 |     with composite.inactive():
114 |         gradtarget = torch.ones_like(grad)
115 |         gradgrad, = torch.autograd.grad(grad, input, gradtarget)
116 |
117 | All Attributors support the computation of second order gradients. For gradient-based
118 | attributors like :py:class:`~zennit.attribution.Gradient` or
119 | :py:class:`~zennit.attribution.SmoothGrad`, the ``create_graph=True`` parameter
120 | can be supplied to the class constructor:
121 |
122 | .. code-block:: python
123 |
124 | from zennit.attribution import Gradient
125 | from zennit.composites import EpsilonGammaBox
126 |
127 | # any composites support second order gradients
128 | composite = EpsilonGammaBox(low=-3., high=3.)
129 |
130 | with Gradient(model, composite, create_graph=True) as attributor:
131 |     output, grad = attributor(input, torch.ones_like)
132 |
133 |     # temporarily disable all hooks registered by the attributor's composite
134 |     with attributor.inactive():
135 |         gradtarget = torch.ones_like(grad)
136 |         gradgrad, = torch.autograd.grad(grad, input, gradtarget)
137 |
138 | Here, we also used a different composite, which results in computing the
139 | gradient of the modified gradient. Since the ReLU gradient is ignored (using
140 | the :py:class:`~zennit.rules.Pass` rule) for Layer-wise Relevance
141 | Propagation-specific composites, we do not need to use the
142 | :py:class:`~zennit.rules.ReLUBetaSmooth` rule. However, if this behaviour
143 | should be overwritten, :ref:`cooperative-layermapcomposites` can be used.
144 |
145 | Using Hooks Only
146 | ----------------
147 |
148 | Under the hood, :py:class:`~zennit.core.Hook` has an attribute ``active``; when
149 | it is set to ``False``, the associated backward function will not be executed.
150 | A minimal example without using composites would look like the following:
151 |
152 | .. code-block:: python
153 |
154 |     import torch
155 |     from torch.nn import Conv2d
156 | 
157 |     from zennit.rules import Epsilon
158 | 
159 |     conv = Conv2d(3, 10, 3, padding=1)
160 |     input = torch.randn(1, 3, 32, 32, requires_grad=True)
157 |
158 | # create and register the hook
159 | epsilon = Epsilon()
160 | handles = epsilon.register(conv)
161 |
162 | output = conv(input)
163 | target = torch.ones_like(output)
164 | grad, = torch.autograd.grad(output, input, target, create_graph=True)
165 |
166 | # during this block, epsilon will be inactive
167 | epsilon.active = False
168 | grad_target = torch.ones_like(grad)
169 | gradgrad, = torch.autograd.grad(grad, input, grad_target)
170 | epsilon.active = True
171 |
172 | # after calling handles.remove, epsilon will also be inactive
173 | handles.remove()
174 |
175 | The same can also be achieved here by simply removing the handles before
176 | calling ``torch.autograd.grad`` on ``grad``, although the hooks would then need
177 | to be re-registered in order to compute the epsilon-modified gradient again.
178 |
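179 | As a sketch, this alternative flow could look like the following, starting
180 | right after ``grad`` was computed above:
181 | 
182 | .. code-block:: python
183 | 
184 |     # remove the hooks; the plain second order gradient can now be computed
185 |     handles.remove()
186 |     gradgrad, = torch.autograd.grad(grad, input, torch.ones_like(grad))
187 | 
188 |     # re-register the same hook to compute the epsilon-modified gradient again
189 |     handles = epsilon.register(conv)
190 |     output = conv(input)
191 |     grad, = torch.autograd.grad(output, input, torch.ones_like(output), create_graph=True)
192 |     handles.remove()
193 |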
--------------------------------------------------------------------------------
/docs/source/how-to/get-intermediate-relevance.rst:
--------------------------------------------------------------------------------
1 | ==============================
2 | Getting Intermediate Relevance
3 | ==============================
4 |
5 | In some cases, intermediate gradients or relevances of a model may be needed.
6 | Since Zennit uses PyTorch's autograd engine, the intermediate relevance of an
7 | accessible non-leaf tensor is simply its intermediate gradient, which can be
8 | kept in the tensor's ``.grad`` attribute by calling ``tensor.retain_grad()``
9 | before the gradient computation.
10 |
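11 | For a tensor that is directly accessible, this looks like the following
12 | minimal sketch:
13 | 
14 | .. code-block:: python
15 | 
16 |     import torch
17 | 
18 |     x = torch.randn(1, 4, requires_grad=True)
19 |     hidden = x ** 2
20 |     # keep the gradient of the non-leaf tensor in its .grad attribute
21 |     hidden.retain_grad()
22 |     hidden.sum().backward()
23 |     # hidden.grad now holds d(sum)/d(hidden), here a tensor of ones
24 |     print(hidden.grad)
25 | 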
11 | In most cases when using ``torch.nn.Module``-based models, the intermediate
12 | outputs are not easily accessible, which we can solve by using forward-hooks.
13 |
14 | We create the following setting with some random input data and a simple,
15 | randomly initialized model, for which we want to compute the EpsilonPlusFlat
16 | LRP relevance:
16 |
17 | .. code-block:: python
18 |
19 | import torch
20 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
21 |
22 | from zennit.attribution import Gradient
23 | from zennit.composites import EpsilonPlusFlat
24 |
25 | # setup the model and data
26 | model = Sequential(
27 | Conv2d(3, 10, 3, padding=1),
28 | ReLU(),
29 | Flatten(),
30 | Linear(10 * 32 * 32, 10),
31 | )
32 | input = torch.randn(1, 3, 32, 32)
33 |
34 | # make sure the input requires a gradient
35 | input.requires_grad = True
36 |
37 | # create a composite instance
38 | composite = EpsilonPlusFlat()
39 |
40 | # create a gradient attributor
41 | attributor = Gradient(model, composite)
42 |
43 | Now we create a function ``store_hook`` which we register as a forward hook to
44 | all modules. The function sets the module's attribute ``.output`` to its output
45 | tensor, and ensures the gradient is stored in the tensor's ``.grad`` attribute
46 | even if it is not a leaf-tensor by using ``.retain_grad()``.
47 |
48 | .. code-block:: python
49 |
50 | # create a hook to keep track of intermediate outputs
51 | def store_hook(module, input, output):
52 |         # set the current module's attribute 'output' to its output tensor
53 | module.output = output
54 | # keep the output tensor gradient, even if it is not a leaf-tensor
55 | output.retain_grad()
56 |
57 | # enter the attributor's context to register the rule-hooks
58 | with attributor:
59 | # register the store_hook AFTER the rule-hooks have been registered (by
60 | # entering the context) so we get the last output before the next module
61 | handles = [
62 | module.register_forward_hook(store_hook) for module in model.modules()
63 | ]
64 | # compute the relevance wrt. output/class 1
65 | output, relevance = attributor(input, torch.eye(10)[[1]])
66 |
67 | # remove the hooks using store_hook
68 | for handle in handles:
69 | handle.remove()
70 |
71 | # print the gradient tensors for demonstration
72 | for name, module in model.named_modules():
73 | print(f'{name}: {module.output.grad}')
74 |
75 | The hooks are registered within the attributor's with-context, such that they
76 | are applied after the rule hooks. Once we are finished, we can remove the
77 | store-hooks by calling ``.remove()`` on all handles returned when registering the
78 | hooks.
79 |
80 | Be aware that storing the intermediate outputs and their gradients may require
81 | significantly more memory, depending on the model. In practice, it may be better
82 | to register the store-hook only to modules for which the relevance is needed.
83 |
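84 | For example, to track only the convolutional layers of the model above, the
85 | registration may be restricted by type:
86 | 
87 | .. code-block:: python
88 | 
89 |     # only register the store-hook to Conv2d modules
90 |     handles = [
91 |         module.register_forward_hook(store_hook)
92 |         for module in model.modules()
93 |         if isinstance(module, Conv2d)
94 |     ]
95 |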
--------------------------------------------------------------------------------
/docs/source/how-to/index.rst:
--------------------------------------------------------------------------------
1 | ================
2 | How-Tos
3 | ================
4 |
5 |
6 | These How-Tos give more detailed information on how to use Zennit.
7 |
8 | .. toctree::
9 | :maxdepth: 1
10 |
11 | use-rules-composites-and-canonizers
12 | use-attributors
13 | visualize-results
14 | get-intermediate-relevance
15 | compute-second-order-gradients
16 | write-custom-composites
17 | write-custom-canonizers
18 | write-custom-rules
19 | write-custom-attributors
20 |
--------------------------------------------------------------------------------
/docs/source/how-to/use-attributors.rst:
--------------------------------------------------------------------------------
1 | =================
2 | Using Attributors
3 | =================
4 |
5 | **Attributors** are used both to shorten Zennit's common ``composite.context ->
6 | gradient`` approach, and to provide model-agnostic attribution approaches.
7 | Available **Attributors** can be found in :py:mod:`zennit.attribution`, some of
8 | which are:
9 |
10 | * :py:class:`~zennit.attribution.Gradient`, which computes the gradient
11 | * :py:class:`~zennit.attribution.IntegratedGradients`, which computes the
12 | Integrated Gradients
13 | * :py:class:`~zennit.attribution.SmoothGrad`, which computes SmoothGrad
14 | * :py:class:`~zennit.attribution.Occlusion`, which computes the attribution
15 | based on the model output activation values when occluding parts of the input
16 | with a sliding window
17 |
18 | Using the basic :py:class:`~zennit.attribution.Gradient`, the unmodified
19 | gradient may be computed with:
20 |
21 | .. code-block:: python
22 |
23 | import torch
24 | from torch.nn import Sequential, Conv2d, ReLU, Linear, Flatten
25 | from zennit.attribution import Gradient
26 |
27 | # setup the model
28 | model = Sequential(
29 | Conv2d(3, 8, 3, padding=1),
30 | ReLU(),
31 | Conv2d(8, 16, 3, padding=1),
32 | ReLU(),
33 | Flatten(),
34 | Linear(16 * 32 * 32, 1024),
35 | ReLU(),
36 | Linear(1024, 10),
37 | )
38 | # some random input data
39 | input = torch.randn(1, 3, 32, 32, requires_grad=True)
40 |
41 | # compute the gradient and output using the Gradient attributor
42 | with Gradient(model) as attributor:
43 | output, relevance = attributor(input)
44 |
45 | Computing attributions using a composite can be done with:
46 |
47 | .. code-block:: python
48 |
49 | from zennit.composites import EpsilonPlusFlat
50 |
51 | # prepare the composite
52 | composite = EpsilonPlusFlat()
53 |
54 | # compute the gradient within the composite's context, i.e. the
55 | # EpsilonPlusFlat LRP relevance
56 | with Gradient(model, composite) as attributor:
57 | # torch.eye is used here to get a one-hot encoding of the
58 | # first (index 0) label
59 | output, relevance = attributor(input, torch.eye(10)[[0]])
60 |
61 | which uses the second argument ``attr_output_fn`` of the
62 | :py:class:`~zennit.attribution.Attributor` call to specify a constant tensor
63 | used for the *output relevance* (i.e. ``grad_output``); alternatively, a
64 | function of the output may also be used:
65 |
66 | .. code-block:: python
67 |
68 | def one_hot_max(output):
69 | '''Get the one-hot encoded max at the original indices in dim=1'''
70 | values, indices = output.max(1)
71 | return values[:, None] * torch.eye(output.shape[1])[indices]
72 |
73 | with Gradient(model) as attributor:
74 | output, relevance = attributor(input, one_hot_max)
75 |
76 | The constructor of :py:class:`~zennit.attribution.Attributor` also has a third
77 | argument ``attr_output``, which can either be a constant
78 | :py:class:`~torch.Tensor` or a function of the model's output, and specifies
79 | which *output relevance* (i.e. ``grad_output``) should be used by default. When
80 | nothing is supplied, the default is the *identity*. If the default should
81 | instead be, for example, ones for all outputs, one could write:
82 |
83 | .. code-block:: python
84 |
85 | # compute the gradient and output using the Gradient attributor, and with
86 | # a vector of ones as grad_output
87 | with Gradient(model, attr_output=torch.ones_like) as attributor:
88 | output, relevance = attributor(input)
89 |
90 | Gradient-based **Attributors** like
91 | :py:class:`~zennit.attribution.IntegratedGradients` and
92 | :py:class:`~zennit.attribution.SmoothGrad` may also be used together with
93 | composites to produce *hybrid attributions*:
94 |
95 | .. code-block:: python
96 |
97 | from zennit.attribution import SmoothGrad
98 |
99 | # prepare the composite
100 | composite = EpsilonPlusFlat()
101 |
102 | # do a *smooth* version of EpsilonPlusFlat LRP by using the SmoothGrad
103 | # attributor in combination with the composite
104 | with SmoothGrad(model, composite, noise_level=0.1, n_iter=20) as attributor:
105 | output, relevance = attributor(input, torch.eye(10)[[0]])
106 |
107 | which in this case will draw 20 noisy samples around the input (with the noise
108 | magnitude controlled by ``noise_level``). Note that for Zennit's implementation
109 | of :py:class:`~zennit.attribution.SmoothGrad`, the first sample will always be
110 | the original input, i.e. ``SmoothGrad(model, n_iter=1)`` will produce the same
111 | plain gradient as ``Gradient(model)`` would.
112 |
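113 | This equivalence can be checked directly; a small sketch using the ``input``
114 | from above:
115 | 
116 | .. code-block:: python
117 | 
118 |     with Gradient(model) as attributor:
119 |         _, grad_plain = attributor(input)
120 | 
121 |     with SmoothGrad(model, n_iter=1) as attributor:
122 |         _, grad_smooth = attributor(input)
123 | 
124 |     # both attributions are identical
125 |     assert torch.allclose(grad_plain, grad_smooth)
126 | 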
113 | :py:class:`~zennit.attribution.Occlusion` will move a sliding window with
114 | arbitrary size and strides over an input of any dimensionality. In addition to
115 | the window-size and strides, a function may be supplied, which receives the
116 | input and a mask. By default, everything within the sliding window will be set
117 | to zero. A function :py:func:`zennit.attribution.occlude_independent` is
118 | available to simplify the process of specifying how to fill the window, and to
119 | invert the window if desired. The following perturbs the area within the
120 | sliding window with multiplicative Gaussian noise:
122 |
123 | .. code-block:: python
124 |
125 | from functools import partial
126 | from zennit.attribution import Occlusion, occlude_independent
127 |
128 | input = torch.randn((16, 3, 32, 32))
129 |
130 | attributor = Occlusion(
131 | model,
132 | window=8, # 8x8 overlapping windows
133 | stride=4, # with strides 4x4
134 | occlusion_fn=partial( # occlusion_fn gets the full input and a mask
135 | occlude_independent, # applies fill_fn at provided mask
136 |             fill_fn=lambda x: x * torch.randn_like(x) * 0.2,  # multiply by some noise
137 | invert=False # do not invert, i.e. occlude *within* mask
138 | )
139 | )
140 | with attributor:
141 | # for occlusion, the score for each window-pass is the sum of the
142 | # provided *grad_output*, which we choose as the model output at index 0
143 | output, relevance = attributor(input, lambda out: torch.eye(10)[[0]] * out)
144 |
145 |
146 | Note that while the interface allows passing a composite to any
147 | :py:class:`~zennit.attribution.Attributor`, using a composite with
148 | :py:class:`~zennit.attribution.Occlusion` does not change the outcome, as it
149 | does not utilize the gradient.
150 |
151 | An introduction on how to write custom **Attributors** can be found at
152 | :doc:`/how-to/write-custom-attributors`.
153 |
--------------------------------------------------------------------------------
/docs/source/how-to/write-custom-attributors.rst:
--------------------------------------------------------------------------------
1 | ==========================
2 | Writing Custom Attributors
3 | ==========================
4 |
5 | **Attributors** provide an additional layer of abstraction over the context of
6 | **Composites**, and are used to directly produce *attributions*, which may or
7 | may not be computed with gradients modified by **Composites**.
9 | More information on **Attributors**, examples and their use can be found in
10 | :doc:`/how-to/use-attributors`.
11 |
12 | **Attributors** can be used to implement non-layer-wise or only partly
13 | layer-wise attribution methods.
14 | For this, it is enough to define a subclass of
15 | :py:class:`zennit.attribution.Attributor` and implement its
16 | :py:meth:`~zennit.attribution.Attributor.forward` and optionally its
17 | :py:meth:`~zennit.attribution.Attributor.__init__` methods.
18 |
19 | :py:meth:`~zennit.attribution.Attributor.forward` takes 2 arguments: ``input``,
20 | the tensor with respect to which the attribution shall be computed, and
21 | ``attr_output_fn``, a function that, given the output of the attributed model,
22 | computes the *gradient output* for the gradient computation, e.g. a one-hot
23 | encoding of the target label of the attributed input.
25 | When calling an :py:class:`~zennit.attribution.Attributor`, the ``__call__``
26 | function will ensure ``forward`` receives a valid function to transform the
27 | output of the analyzed model to a tensor which can be used for the
28 | ``grad_output`` argument of :py:func:`torch.autograd.grad`.
29 | A constant tensor or function is provided by the user either to ``__init__`` or
30 | to ``__call__``.
31 | It is expected that :py:meth:`~zennit.attribution.Attributor.forward` will
32 | return a tuple containing, in order, the model output and the attribution.
33 |
34 | As an example, we can implement *gradient times input* in the following way:
35 |
36 | .. code-block:: python
37 |
38 | import torch
39 | from torchvision.models import vgg11
40 |
41 | from zennit.attribution import Attributor
42 |
43 |
44 | class GradientTimesInput(Attributor):
45 | '''Model-agnostic gradient times input.'''
46 | def forward(self, input, attr_output_fn):
47 | '''Compute gradient times input.'''
48 | input_detached = input.detach().requires_grad_(True)
49 | output = self.model(input_detached)
50 | gradient, = torch.autograd.grad(
51 | (output,), (input_detached,), (attr_output_fn(output.detach()),)
52 | )
53 | relevance = gradient * input
54 | return output, relevance
55 |
56 | model = vgg11()
57 | data = torch.randn((1, 3, 224, 224))
58 |
59 | with GradientTimesInput(model) as attributor:
60 | output, relevance = attributor(data)
61 |
62 | :py:class:`~zennit.attribution.Attributor` accepts an optional
63 | :py:class:`~zennit.core.Composite`, which, if supplied, will always be used to
64 | create a context in ``__call__`` around ``forward``.
65 | For the ``GradientTimesInput`` class above, using a **Composite** will probably
66 | not produce anything useful, although more involved combinations of custom
67 | **Rules** and a custom **Attributor** can be used to implement complex
68 | attribution methods with both model-agnostic and layer-wise parts.
69 |
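70 | For reference, passing a **Composite** to a custom **Attributor** works the
71 | same as for the built-in ones, since ``GradientTimesInput`` does not override
72 | ``__init__``; a sketch:
73 | 
74 | .. code-block:: python
75 | 
76 |     from zennit.composites import EpsilonPlusFlat
77 | 
78 |     # the composite's context is entered around forward in __call__
79 |     with GradientTimesInput(model, composite=EpsilonPlusFlat()) as attributor:
80 |         output, relevance = attributor(data)
81 | 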
70 | The following shows an example of *sensitivity analysis*, which is the absolute
71 | value of the gradient, with a custom ``__init__()`` where we can pass the argument
72 | ``sum_channels`` to specify whether the **Attributor** should sum over the
73 | channel dimension:
74 |
75 | .. code-block:: python
76 |
77 | import torch
78 | from torchvision.models import vgg11
79 |
80 | from zennit.attribution import Attributor
81 |
82 |
83 | class SensitivityAnalysis(Attributor):
84 | '''Model-agnostic sensitivity analysis which optionally sums over color
85 | channels.
86 | '''
87 | def __init__(
88 | self, model, sum_channels=False, composite=None, attr_output=None
89 | ):
90 | super().__init__(
91 | model, composite=composite, attr_output=attr_output
92 | )
93 |
94 | self.sum_channels = sum_channels
95 |
96 |
97 | def forward(self, input, attr_output_fn):
98 | '''Compute the absolute gradient (or the sensitivity) and
99 | optionally sum over the color channels.
100 | '''
101 | input_detached = input.detach().requires_grad_(True)
102 | output = self.model(input_detached)
103 | gradient, = torch.autograd.grad(
104 | (output,), (input_detached,), (attr_output_fn(output.detach()),)
105 | )
106 | relevance = gradient.abs()
107 | if self.sum_channels:
108 | relevance = relevance.sum(1)
109 | return output, relevance
110 |
111 | model = vgg11()
112 | data = torch.randn((1, 3, 224, 224))
113 |
114 | with SensitivityAnalysis(model, sum_channels=True) as attributor:
115 | output, relevance = attributor(data)
116 |
--------------------------------------------------------------------------------
/docs/source/how-to/write-custom-canonizers.rst:
--------------------------------------------------------------------------------
1 | =========================
2 | Writing Custom Canonizers
3 | =========================
4 |
5 | **Canonizers** are used to temporarily transform models into a canonical form,
6 | to mitigate the lack of implementation invariance of methods such as Layer-wise
7 | Relevance Propagation (LRP). A general introduction to **Canonizers** can be found here:
8 | :ref:`use-canonizers`.
9 |
10 | As both **Canonizers** and **Composites** (via **Rules**) change the outcome of
11 | the attribution, it can be confusing at first whether a novel network
12 | architecture needs a new set of **Rules** and **Composites**, or whether it
13 | should be adapted to the existing framework using **Canonizers**. While
14 | ultimately it depends on the design preference of the developer, our suggestion
15 | is to go through the following steps in order:
17 |
18 | 1. Check whether a custom **Composite** is enough to correctly attribute the
19 |    model, i.e. the new layer-type is only a composition of existing layer types
20 |    without any unaccounted intermediate steps or incompatibilities with existing
21 |    rules.
22 | 2. If some of the rules which should be used are incompatible without changes
23 |    (e.g. subsequent linear layers), or some parts of a module have intermediate
24 |    computations that are not implemented with sub-modules, it should be checked
25 |    whether a **Canonizer** can be implemented to fix these issues. If you are in
26 |    control of the module in question, check whether rewriting the module with
27 |    sub-modules is easier than implementing a **Canonizer**.
28 | 3. If the module consists of computations which cannot be separated into
29 |    existing modules with compatible rules, or doing so would result in an overly
30 |    complex architecture, a custom **Rule** may be the way to go.
31 |
32 | **Rules** and **Composites** are not designed to change the forward computation
33 | of a model. While **Canonizers** can change the outcome of the forward pass,
34 | this should be used with care, since a modified function output means that the
35 | function itself has been modified, which will therefore result in an attribution
36 | of the modified function instead.
37 |
38 | To implement a custom **Canonizer**, a class inheriting from
39 | :py:class:`zennit.canonizers.Canonizer` needs to implement the following four
40 | methods:
41 |
42 | * :py:meth:`~zennit.canonizers.Canonizer.apply`, which finds the sub-modules
43 | that should be modified by the **Canonizer** and passes their information to ...
44 | * :py:meth:`~zennit.canonizers.Canonizer.register`, which copies the current
45 | instance using :py:meth:`~zennit.canonizers.Canonizer.copy`, applies the
46 | changes that should be introduced by the **Canonizer**, and makes sure they
47 | can be reverted later, using ...
48 | * :py:meth:`~zennit.canonizers.Canonizer.remove`, which reverts the changes
49 |   introduced by the **Canonizer**, e.g. by loading the original parameters which
50 | were temporarily stored, and
51 | * :py:meth:`~zennit.canonizers.Canonizer.copy`, which copies the current
52 | instance, to create an individual instance for each applicable module with the
53 | same parameters.
54 |
55 | Suppose we have a ReLU model (e.g. VGG11) for which we want to compute the
56 | second-order derivative, e.g. to find an adversarial explanation (see
57 | :cite:p:`dombrowski2019explanations`). The ReLU is not differentiable at 0, and
58 | its second order derivative is zero everywhere except at 0, where it is
59 | undefined. :cite:t:`dombrowski2019explanations` replace the ReLU activations in
60 | a model with *Softplus* activations, which, as *beta* approaches infinity,
61 | become identical to the ReLU activation. For a numerical estimate, it is
62 | enough to set *beta* to a relatively large value, e.g. 10. The following is
63 | an implementation of the **SoftplusCanonizer**, which will temporarily replace
64 | the ReLU activations in a model with Softplus activations:
65 |
66 | .. code-block:: python
67 |
68 | import torch
69 |
70 | from zennit.canonizers import Canonizer
71 |
72 |
73 | class SoftplusCanonizer(Canonizer):
74 | '''Replaces ReLUs with Softplus units.'''
75 | def __init__(self, beta=10.):
76 | self.beta = beta
77 | self.module = None
78 | self.relu_children = None
79 |
80 | def apply(self, root_module):
81 | '''Iterate all modules under root_module and register the Canonizer
82 | if they have immediate ReLU sub-modules.
83 | '''
84 | # track the SoftplusCanonizer instances to remove them later
85 | instances = []
86 | # iterate recursively over all modules
87 | for module in root_module.modules():
88 | # get all the direct sub-module instances of torch.nn.ReLU
89 | relu_children = [
90 | (name, child)
91 | for name, child in module.named_children()
92 | if isinstance(child, torch.nn.ReLU)
93 | ]
94 |             # if there is at least one direct ReLU sub-module
95 | if relu_children:
96 | # create a copy (with the same beta parameter)
97 | instance = self.copy()
98 | # register the module
99 | instance.register(module, relu_children)
100 | # add the copy to the instance list
101 | instances.append(instance)
102 | return instances
103 |
104 | def register(self, module, relu_children):
105 |         '''Store the module and its immediate ReLU sub-modules, and then
106 |         overwrite the attribute that corresponds to each ReLU sub-module
107 | with a new instance of ``torch.nn.Softplus``.
108 | '''
109 | self.module = module
110 | self.relu_children = relu_children
111 | for name, _ in relu_children:
112 | # set each of the attributes corresponding to the ReLU to a new
113 | # instance of torch.nn.Softplus
114 | setattr(module, name, torch.nn.Softplus(beta=self.beta))
115 |
116 | def remove(self):
117 |         '''Undo the changes introduced by this Canonizer, by setting the
118 | appropriate attributes of the stored module back to the original
119 | ReLU sub-module instance.
120 | '''
121 | for name, child in self.relu_children:
122 | setattr(self.module, name, child)
123 |
124 | def copy(self):
125 | '''Create a copy of this instance. Each module requires its own
126 | instance to call ``.register``.
127 | '''
128 | return SoftplusCanonizer(beta=self.beta)
129 |
130 |
131 | Note that we can only replace modules by changing their immediate parent. This
132 | means that if ``root_module`` was a ``torch.nn.ReLU`` itself, it would be
133 | impossible to replace it with a ``torch.nn.Softplus`` without replacing the
134 | ``root_module`` itself.
135 |
136 | For demonstration purposes, we can compute the gradient w.r.t. a loss which uses
137 | the gradient of the modified model in the following way:
138 |
139 | .. code-block:: python
140 |
141 | import torch
142 | from torchvision.models import vgg11
143 |
144 | from zennit.core import Composite
145 | from zennit.image import imgify
146 |
147 |
148 | # create a VGG11 model with random parameters
149 | model = vgg11()
150 | # use the Canonizer with an "empty" Composite (without specifying
151 | # module_map), which will not attach rules to any sub-module, thus resulting
152 | # in a plain gradient computation, but with a Canonizer applied
153 | composite = Composite(
154 | canonizers=[SoftplusCanonizer()]
155 | )
156 |
157 | input = torch.randn(1, 3, 224, 224, requires_grad=True)
158 | target = torch.eye(1000)[[0]]
159 | with composite.context(model) as modified_model:
160 | out = modified_model(input)
161 | relevance, = torch.autograd.grad(out, input, target, create_graph=True)
162 | # find adversarial example such that input and its respective
163 | # attribution are close
164 | loss = ((relevance - input.detach()) ** 2).mean()
165 |         # compute the gradient of the loss w.r.t. input, using the second order
166 | # derivative w.r.t. input; note that this currently does not work when
167 | # using BasicHook, which detaches the gradient to avoid wrong values
168 | adv_grad, = torch.autograd.grad(loss, input)
169 |
170 | # visualize adv_grad
171 | imgify(adv_grad[0].abs().sum(0), cmap='hot').show()
172 |
173 |
--------------------------------------------------------------------------------
/docs/source/how-to/write-custom-composites.rst:
--------------------------------------------------------------------------------
1 | =========================
2 | Writing Custom Composites
3 | =========================
4 |
5 | Zennit provides a number of commonly used **Composites**.
6 | While these are often enough for feed-forward-type neural networks, one primary goal of Zennit is to provide the tools to easily customize the computation of rule-based attribution methods.
7 | This is especially useful to analyze novel architectures, for which no attribution-approach has been designed before.
8 |
9 | For most use-cases, using the abstract **Composites** :py:class:`~zennit.composites.LayerMapComposite`, :py:class:`~zennit.composites.SpecialFirstLayerMapComposite`, and :py:class:`~zennit.composites.NameMapComposite` already provides enough freedom to customize which layer should receive which rule. See :ref:`use-composites` for an introduction.
10 | Depending on the setup, it may however be more convenient to either use :py:class:`zennit.core.Composite` directly, or to implement a new **Composite** by creating a subclass of it.
11 | In either case, the :py:class:`~zennit.core.Composite` requires an argument ``module_map``, which is a function with the signature ``(ctx: dict, name: str, module: torch.nn.Module) -> Hook or None``, which, given a context dict, the name of a single module and the module itself, either returns an instance of :py:class:`~zennit.core.Hook` which should be copied and registered to the module, or ``None`` if no ``Hook`` should be applied.
12 | The context dict ``ctx`` can be used to track subsequent calls to the ``module_map`` function, e.g. to count the number of processed modules, or to check whether some condition has been met, e.g. whether a linear layer has been seen earlier.
13 | The ``module_map`` is used in :py:meth:`zennit.core.Composite.register`, where the context dict is initialized to an empty dict ``{}`` before iterating over all the sub-modules of the root-module to which the composite will be registered.
14 | The iteration is done using :py:meth:`torch.nn.Module.named_modules`, which therefore dictates the order in which modules are visited, i.e. depth-first in the order sub-modules were assigned.
15 |
16 | A simple **Composite**, which only provides rules for linear layers that are leaves, and which bases the rule on how many leaf modules were visited before, could be implemented like the following:
17 |
18 |
19 | .. code-block:: python
20 |
21 | import torch
22 | from torchvision.models import vgg16
23 | from zennit.rules import Epsilon, AlphaBeta
24 | from zennit.types import Linear
25 | from zennit.core import Composite
26 | from zennit.attribution import Gradient
27 |
28 |
29 | def module_map(ctx, name, module):
30 | # check whether there is at least one child, i.e. the module is not a leaf
31 | try:
32 | next(module.children())
33 | except StopIteration:
34 | # StopIteration is raised if the iterator has no more elements,
35 | # which means in this case there are no children and module is a leaf
36 | pass
37 | else:
38 | # if StopIteration is not raised on the first element, module is not a leaf
39 | return None
40 |
41 | # if the module is not Linear, we do not want to assign a hook
42 | if not isinstance(module, Linear):
43 | return None
44 |
45 |     # count the number of leaves processed so far in 'leafnum'
46 | if 'leafnum' not in ctx:
47 | ctx['leafnum'] = 0
48 | else:
49 | ctx['leafnum'] += 1
50 |
51 | # the first 10 leaf-modules which are of type Linear should be assigned
52 | # the Alpha2Beta1 rule
53 | if ctx['leafnum'] < 10:
54 | return AlphaBeta(alpha=2, beta=1)
55 | # all other rules should be assigned Epsilon
56 | return Epsilon(epsilon=1e-3)
57 |
58 |
59 | # we can then create a composite by passing the module_map function
60 | # canonizers may also be passed as with all composites
61 | composite = Composite(module_map=module_map)
62 |
63 | # try out the composite
64 | model = vgg16()
65 | with Gradient(model, composite) as attributor:
66 | out, grad = attributor(torch.randn(1, 3, 224, 224))
67 |
68 |
69 | A more general **Composite**, where we can specify which layer number and which type should be assigned which rule, can be implemented by creating a class:
70 |
71 | .. code-block:: python
72 |
73 | from itertools import islice
74 |
75 | import torch
76 | from torchvision.models import vgg16
77 | from zennit.rules import Epsilon, ZBox, Gamma, Pass, Norm
78 | from zennit.types import Linear, Convolution, Activation, AvgPool
79 | from zennit.core import Composite
80 | from zennit.attribution import Gradient
81 |
82 |
83 | class LeafNumberTypeComposite(Composite):
84 | def __init__(self, leafnum_map):
85 | # pass the class method self.mapping as the module_map
86 | super().__init__(module_map=self.mapping)
87 | # set the instance attribute so we can use it in self.mapping
88 | self.leafnum_map = leafnum_map
89 |
90 | def mapping(self, ctx, name, module):
91 | # check whether there is at least one child, i.e. the module is not a leaf
92 | # but this time shorter using itertools.islice to get at most one child
93 | if list(islice(module.children(), 1)):
94 | return None
95 |
96 |         # count the number of leaves processed so far in 'leafnum',
97 |         # this time in a single line using dict.get; here, all leaf layers count, e.g. ReLU
98 | ctx['leafnum'] = ctx.get('leafnum', -1) + 1
99 |
100 | # loop over the leafnum_map and use the first template for which
101 | # the module type matches and the current ctx['leafnum'] falls into
102 | # the bounds
103 | for (low, high), dtype, template in self.leafnum_map:
104 | if isinstance(module, dtype) and low <= ctx['leafnum'] < high:
105 | return template
106 | # if none of the leafnum_map apply this means there is no rule
107 | # matching the current layer
108 | return None
109 |
110 |
111 | # this can be compared with int and will always be larger
112 | inf = float('inf')
113 |
114 | # we create an example leafnum-map, note that Linear is here
115 | # zennit.types.Linear and not torch.nn.Linear
116 | # the first two entries are for demonstration only and would
117 | # in practice most likely be a single "Linear" with appropriate low/high
118 | leafnum_map = [
119 | [(0, 1), Convolution, ZBox(low=-3.0, high=3.0)],
120 | [(0, 1), torch.nn.Linear, ZBox(low=0.0, high=1.0)],
121 | [(1, 17), Linear, Gamma(gamma=0.25)],
122 | [(17, 31), Linear, Epsilon(epsilon=0.5)],
123 | [(31, inf), Linear, Epsilon(epsilon=1e-9)],
124 | # catch all activations
125 | [(0, inf), Activation, Pass()],
126 | # explicit None is possible e.g. to (ab-)use precedence
127 | [(0, 17), torch.nn.MaxPool2d, None],
128 | # catch all AvgPool/MaxPool2d, isinstance also accepts tuples of types
129 | [(0, inf), (AvgPool, torch.nn.MaxPool2d), Norm()],
130 | ]
131 |
132 | # finally, create the composite using the leafnum_map
133 | composite = LeafNumberTypeComposite(leafnum_map)
134 |
135 | # try out the composite
136 | model = vgg16()
137 | with Gradient(model, composite) as attributor:
138 | out, grad = attributor(torch.randn(1, 3, 224, 224))
139 |
140 | In practice, however, we do not recommend using the index of the layer when designing **Composites**, because most of the time, when such a configuration is chosen, it is done to shape the **Composite** for an explicit model.
141 | For these kinds of **Composites**, a :py:class:`~zennit.composites.NameMapComposite` will directly map the name of a sub-module to a Hook, which is a more explicit and transparent way to create a special **Composite** for a single neural network.
142 |
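143 | A sketch of such an explicit name mapping, assuming the sub-module names of ``torchvision``'s VGG16 used above:
144 | 
145 | .. code-block:: python
146 | 
147 |     from zennit.composites import NameMapComposite
148 |     from zennit.rules import ZBox, Gamma, Epsilon
149 | 
150 |     # explicitly map sub-module names to rule templates
151 |     composite = NameMapComposite([
152 |         (['features.0'], ZBox(low=-3.0, high=3.0)),
153 |         (['features.2', 'features.5'], Gamma(gamma=0.25)),
154 |         (['classifier.0', 'classifier.3', 'classifier.6'], Epsilon(epsilon=0.5)),
155 |     ])
156 | 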
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | ====================
2 | Zennit Documentation
3 | ====================
4 |
5 | Zennit (Zennit Explains Neural Networks in Torch) is a Python framework using PyTorch to compute local attributions in the sense of eXplainable AI (XAI) with a focus on Layer-wise Relevance Propagation.
6 | It works by defining *rules* which are used to overwrite the gradient of PyTorch modules in PyTorch's auto-differentiation engine.
7 | Rules are mapped to layers with *composites*, which select rules for modules based on their properties and context, and thereby describe how to compute the attribution of a full model.
8 |
9 | Zennit is available on PyPI and may be installed using:
10 |
11 | .. code-block:: console
12 |
13 | $ pip install zennit
14 |
15 | Contents
16 | --------
17 |
18 | .. toctree::
19 | :maxdepth: 2
20 |
21 | getting-started
22 | how-to/index
23 | tutorial/index
24 | reference/index
25 | bibliography
26 |
27 | Indices and tables
28 | ------------------
29 |
30 | * :ref:`genindex`
31 | * :ref:`modindex`
32 | * :ref:`search`
33 |
34 |
35 | Citing
36 | ------
37 |
38 | If you find Zennit useful, why not cite our related paper :cite:p:`anders2021software`:
39 |
40 | .. code-block:: bibtex
41 |
42 | @article{anders2021software,
43 | author = {Anders, Christopher J. and
44 | Neumann, David and
45 | Samek, Wojciech and
46 | Müller, Klaus-Robert and
47 | Lapuschkin, Sebastian},
48 | title = {Software for Dataset-wide XAI: From Local Explanations to Global Insights with {Zennit}, {CoRelAy}, and {ViRelAy}},
49 | journal = {CoRR},
50 | volume = {abs/2106.13200},
51 | year = {2021},
52 | }
53 |
54 |
--------------------------------------------------------------------------------
/docs/source/reference/index.rst:
--------------------------------------------------------------------------------
1 | ================
2 | API Reference
3 | ================
4 |
5 | .. autosummary::
6 | :toctree:
7 | :nosignatures:
8 | :recursive:
9 | :template: modules.rst
10 |
11 | zennit.attribution
12 | zennit.canonizers
13 | zennit.cmap
14 | zennit.composites
15 | zennit.core
16 | zennit.image
17 | zennit.layer
18 | zennit.rules
19 | zennit.torchvision
20 | zennit.types
21 |
--------------------------------------------------------------------------------
/docs/source/tutorial/index.rst:
--------------------------------------------------------------------------------
1 | ================
2 | Tutorials
3 | ================
4 |
5 | .. toctree::
6 | :maxdepth: 1
7 |
8 | image-classification-vgg-resnet
9 | ..
10 | image-segmentation-with-unet
11 | text-classification-with-tbd
12 | audio-classification-with-tbd
13 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import re
3 | from setuptools import setup, find_packages
4 | from subprocess import run, CalledProcessError
5 |
6 |
7 | def get_long_description(project_path):
8 | '''Fetch the README contents and replace relative links with absolute ones
9 | pointing to github for correct behaviour on PyPI.
10 | '''
11 | try:
12 | revision = run(
13 | ['git', 'describe', '--tags'],
14 | capture_output=True,
15 | check=True,
16 | text=True
17 | ).stdout[:-1]
18 | except CalledProcessError:
19 | try:
20 | with open('PKG-INFO', 'r') as fd:
21 | body = fd.read().partition('\n\n')[2]
22 | if body:
23 | return body
24 | except FileNotFoundError:
25 | revision = 'master'
26 |
27 | with open('README.md', 'r', encoding='utf-8') as fd:
28 | long_description = fd.read()
29 |
30 | link_root = {
31 | '': f'https://github.com/{project_path}/blob',
32 | '!': f'https://raw.githubusercontent.com/{project_path}',
33 | }
34 |
35 | def replace(mobj):
36 | return f'{mobj[1]}[{mobj[2]}]({link_root[mobj[1]]}/{revision}/{mobj[3]})'
37 |
38 | link_rexp = re.compile(r'(!?)\[([^\]]*)\]\((?!https?://|/)([^\)]+)\)')
39 | return link_rexp.sub(replace, long_description)
40 |
41 |
42 | setup(
43 | name='zennit',
44 | use_scm_version=True,
45 | author='chrstphr',
46 | author_email='zennit@j0d.de',
47 | description='Attribution of Neural Networks using PyTorch',
48 | long_description=get_long_description('chr5tphr/zennit'),
49 | long_description_content_type='text/markdown',
50 | url='https://github.com/chr5tphr/zennit',
51 | packages=find_packages(where='src', include=['zennit*']),
52 | package_dir={'': 'src'},
53 | install_requires=[
54 | 'click',
55 | 'numpy',
56 | 'Pillow',
57 | 'torch>=1.7.0',
58 | 'torchvision',
59 | ],
60 | setup_requires=[
61 | 'setuptools_scm',
62 | ],
63 | extras_require={
64 | 'docs': [
65 | 'sphinx-copybutton>=0.4.0',
66 | 'sphinx-rtd-theme>=1.0.0',
67 | 'sphinxcontrib.datatemplates>=0.9.0',
68 | 'sphinxcontrib.bibtex>=2.4.1',
69 | 'nbsphinx>=0.8.8',
70 | 'nbconvert<7.14', # see https://github.com/jupyter/nbconvert/issues/2092
71 | 'ipykernel>=6.13.0',
72 | ],
73 | 'tests': [
74 | 'pytest',
75 | 'pytest-cov',
76 | ]
77 | },
78 | python_requires='>=3.7',
79 | classifiers=[
80 | 'Development Status :: 3 - Alpha',
81 | 'License :: OSI Approved :: GNU Lesser General Public License v3 or later (LGPLv3+)',
82 | 'Programming Language :: Python :: 3.7',
83 | 'Programming Language :: Python :: 3.8',
84 | 'Programming Language :: Python :: 3.9',
85 | ]
86 | )
87 |
--------------------------------------------------------------------------------
/share/example/feed_forward.py:
--------------------------------------------------------------------------------
1 | '''A quick example to generate heatmaps for VGG and ResNet models.'''
2 | import os
3 | from functools import partial
4 |
5 | import click
6 | import torch
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Subset
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
10 | from torchvision.datasets import ImageFolder
11 | from torchvision.models import vgg11, vgg11_bn, vgg16, vgg16_bn, resnet18, resnet50
12 |
13 | from zennit.attribution import Gradient, SmoothGrad, IntegratedGradients, Occlusion
14 | from zennit.composites import COMPOSITES
15 | from zennit.core import Hook
16 | from zennit.image import imsave, CMAPS
17 | from zennit.layer import Sum
18 | from zennit.torchvision import VGGCanonizer, ResNetCanonizer
19 |
20 |
21 | MODELS = {
22 | 'vgg16': (vgg16, VGGCanonizer),
23 | 'vgg16_bn': (vgg16_bn, VGGCanonizer),
24 | 'vgg11': (vgg11, VGGCanonizer),
25 | 'vgg11_bn': (vgg11_bn, VGGCanonizer),
26 | 'resnet18': (resnet18, ResNetCanonizer),
27 | 'resnet50': (resnet50, ResNetCanonizer),
28 | }
29 |
30 | ATTRIBUTORS = {
31 | 'gradient': Gradient,
32 | 'smoothgrad': SmoothGrad,
33 | 'integrads': IntegratedGradients,
34 | 'occlusion': Occlusion,
35 |     'inputxgrad': IntegratedGradients,  # input times gradient via n_iter=1 (see attributor_kwargs below)
36 | }
37 |
38 |
39 | class SumSingle(Hook):
40 |     '''Hook that passes the incoming relevance only to a single summand
41 |     (at index ``dim``) of a Sum layer, setting all other summands to zero.
42 |     '''
43 |     def __init__(self, dim=1):
44 |         super().__init__()
45 |         self.dim = dim
46 | 
47 |     def backward(self, module, grad_input, grad_output):
48 |         # zero relevance for every summand except the one at index self.dim
49 |         elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1])
50 |         elems[self.dim] = grad_output[0]
51 |         return (torch.stack(elems, dim=-1),)
48 |
49 |
50 | class BatchNormalize:
51 | def __init__(self, mean, std, device=None):
52 | self.mean = torch.tensor(mean, device=device)[None, :, None, None]
53 | self.std = torch.tensor(std, device=device)[None, :, None, None]
54 |
55 | def __call__(self, tensor):
56 | return (tensor - self.mean) / self.std
57 |
58 |
59 | class AllowEmptyClassImageFolder(ImageFolder):
60 | '''Subclass of ImageFolder, which only finds non-empty classes, but with their correct indices given other empty
61 |     classes. This counteracts the changes in torchvision 0.10.0, in which DatasetFolder does not allow empty classes
62 |     anymore by default. Versions before 0.10.0 do not expose `find_classes`, and thus this subclass does not change
63 |     the functionality of `ImageFolder` in earlier versions.
64 | '''
65 | def find_classes(self, directory):
66 | with os.scandir(directory) as scanit:
67 | class_info = sorted((entry.name, len(list(os.scandir(entry.path)))) for entry in scanit if entry.is_dir())
68 | class_to_idx = {class_name: index for index, (class_name, n_members) in enumerate(class_info) if n_members}
69 | if not class_to_idx:
70 | raise FileNotFoundError(f'No non-empty classes found in \'{directory}\'.')
71 | return list(class_to_idx), class_to_idx
72 |
73 |
74 | @click.command()
75 | @click.argument('dataset-root', type=click.Path(file_okay=False))
76 | @click.argument('relevance_format', type=click.Path(dir_okay=False, writable=True))
77 | @click.option('--attributor', 'attributor_name', type=click.Choice(list(ATTRIBUTORS)), default='gradient')
78 | @click.option('--composite', 'composite_name', type=click.Choice(list(COMPOSITES)))
79 | @click.option('--model', 'model_name', type=click.Choice(list(MODELS)), default='vgg16_bn')
80 | @click.option('--parameters', type=click.Path(dir_okay=False))
81 | @click.option(
82 | '--inputs',
83 | 'input_format',
84 | type=click.Path(dir_okay=False, writable=True),
85 | help='Input image format string. {sample} is replaced with the sample index.'
86 | )
87 | @click.option('--batch-size', type=int, default=16)
88 | @click.option('--max-samples', type=int)
89 | @click.option('--n-outputs', type=int, default=1000)
90 | @click.option('--cpu/--gpu', default=True)
91 | @click.option('--shuffle/--no-shuffle', default=False)
92 | @click.option('--with-bias/--no-bias', default=True)
93 | @click.option('--with-residual/--no-residual', default=True)
94 | @click.option('--relevance-norm', type=click.Choice(['symmetric', 'absolute', 'unaligned']), default='symmetric')
95 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot')
96 | @click.option('--level', type=float, default=1.0)
97 | @click.option('--seed', type=int, default=0xDEADBEEF)
98 | def main(
99 | dataset_root,
100 | relevance_format,
101 | attributor_name,
102 | composite_name,
103 | model_name,
104 | parameters,
105 | input_format,
106 | batch_size,
107 | max_samples,
108 | n_outputs,
109 | cpu,
110 | shuffle,
111 | with_bias,
112 | with_residual,
113 | cmap,
114 | level,
115 | relevance_norm,
116 | seed
117 | ):
118 | '''Generate heatmaps of an image folder at DATASET_ROOT to files RELEVANCE_FORMAT.
119 | RELEVANCE_FORMAT is a format string, for which {sample} is replaced with the sample index.
120 | '''
121 | # set a manual seed for the RNG
122 | torch.manual_seed(seed)
123 |
124 | # use the gpu if requested and available, else use the cpu
125 | device = torch.device('cuda:0' if torch.cuda.is_available() and not cpu else 'cpu')
126 |
127 | # mean and std of ILSVRC2012 as computed for the torchvision models
128 | norm_fn = BatchNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), device=device)
129 |
130 | # transforms as used for torchvision model evaluation
131 | transform = Compose([
132 | Resize(256),
133 | CenterCrop(224),
134 | ToTensor(),
135 | ])
136 |
137 | # the dataset is a folder containing folders with samples, where each folder corresponds to one label
138 | dataset = AllowEmptyClassImageFolder(dataset_root, transform=transform)
139 |
140 | # limit the number of output samples, if requested, by creating a subset
141 | if max_samples is not None:
142 | if shuffle:
143 | indices = sorted(np.random.choice(len(dataset), min(len(dataset), max_samples), replace=False))
144 | else:
145 | indices = range(min(len(dataset), max_samples))
146 | dataset = Subset(dataset, indices)
147 |
148 | loader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
149 |
150 | model = MODELS[model_name][0]()
151 |
152 | # load model parameters if requested; the parameter file may need to be downloaded separately
153 | if parameters is not None:
154 | state_dict = torch.load(parameters)
155 | model.load_state_dict(state_dict)
156 | model.to(device)
157 | model.eval()
158 |
159 | # disable requires_grad for all parameters, we do not need their modified gradients
160 | for param in model.parameters():
161 | param.requires_grad = False
162 |
163 | # convenience identity matrix to produce one-hot encodings
164 | eye = torch.eye(n_outputs, device=device)
165 |
166 | # function to compute output relevance given the function output and a target
167 | def attr_output_fn(output, target):
168 |         # output times one-hot encoding of the target labels, of shape (len(target), n_outputs)
169 | return output * eye[target]
170 |
171 | # create a composite if composite_name was set, otherwise we do not use a composite
172 | composite = None
173 | if composite_name is not None:
174 | composite_kwargs = {}
175 | if composite_name == 'epsilon_gamma_box':
176 | # the maximal input shape, needed for the ZBox rule
177 | shape = (batch_size, 3, 224, 224)
178 |
179 | # the highest and lowest pixel values for the ZBox rule
180 | composite_kwargs['low'] = norm_fn(torch.zeros(*shape, device=device))
181 | composite_kwargs['high'] = norm_fn(torch.ones(*shape, device=device))
182 | if not with_residual and 'resnet' in model_name:
183 | # skip the residual connection through the Sum added by the ResNetCanonizer
184 | composite_kwargs['layer_map'] = [(Sum, SumSingle(1))]
185 |
186 | # provide the name 'bias' in zero_params if no bias should be used to compute the relevance
187 | if not with_bias and composite_name in [
188 | 'epsilon_gamma_box',
189 | 'epsilon_plus',
190 | 'epsilon_alpha2_beta1',
191 | 'epsilon_plus_flat',
192 | 'epsilon_alpha2_beta1_flat',
193 | 'excitation_backprop',
194 | ]:
195 | composite_kwargs['zero_params'] = ['bias']
196 |
197 | # use torchvision specific canonizers, as supplied in the MODELS dict
198 | composite_kwargs['canonizers'] = [MODELS[model_name][1]()]
199 |
200 | # create a composite specified by a name; the COMPOSITES dict includes all preset composites provided by zennit.
201 | composite = COMPOSITES[composite_name](**composite_kwargs)
202 |
203 | # specify some attributor-specific arguments
204 | attributor_kwargs = {
205 | 'smoothgrad': {'noise_level': 0.1, 'n_iter': 20},
206 | 'integrads': {'n_iter': 20},
207 | 'inputxgrad': {'n_iter': 1},
208 | 'occlusion': {'window': (56, 56), 'stride': (28, 28)},
209 | }.get(attributor_name, {})
210 |
211 | # create an attributor, given the ATTRIBUTORS dict given above. If composite is None, the gradient will not be
212 | # modified for the attribution
213 | attributor = ATTRIBUTORS[attributor_name](model, composite, **attributor_kwargs)
214 |
215 | # the current sample index for creating file names
216 | sample_index = 0
217 |
218 | # the accuracy
219 | accuracy = 0.
220 |
221 | # enter the attributor context outside the data loader loop, such that its canonizers and hooks do not need to be
222 | # registered and removed for each step. This registers the composite (and applies the canonizer) to the model
223 | # within the with-statement
224 | with attributor:
225 | for data, target in loader:
226 | # we use data without the normalization applied for visualization, and with the normalization applied as
227 | # the model input
228 | data_norm = norm_fn(data.to(device))
229 |
230 | # create output relevance function of output with fixed target
231 | output_relevance = partial(attr_output_fn, target=target)
232 |
233 |             # this will compute the modified gradient of the model, where the output relevance is chosen as the
234 | # model's output for the ground-truth label index
235 | output, relevance = attributor(data_norm, output_relevance)
236 |
237 | # sum over the color channel for visualization
238 | relevance = np.array(relevance.sum(1).detach().cpu())
239 |
240 | # normalize between 0. and 1. given the specified strategy
241 | if relevance_norm == 'symmetric':
242 | # 0-aligned symmetric relevance, negative and positive can be compared, the original 0. becomes 0.5
243 | amax = np.abs(relevance).max((1, 2), keepdims=True)
244 | relevance = (relevance + amax) / 2 / amax
245 | elif relevance_norm == 'absolute':
246 | # 0-aligned absolute relevance, only the amplitude of relevance matters, the original 0. becomes 0.
247 | relevance = np.abs(relevance)
248 | relevance /= relevance.max((1, 2), keepdims=True)
249 | elif relevance_norm == 'unaligned':
250 | # do not align, the original minimum value becomes 0., the original maximum becomes 1.
251 | rmin = relevance.min((1, 2), keepdims=True)
252 | rmax = relevance.max((1, 2), keepdims=True)
253 | relevance = (relevance - rmin) / (rmax - rmin)
254 |
255 | for n in range(len(data)):
256 | fname = relevance_format.format(sample=sample_index + n)
257 | # zennit.image.imsave will create an appropriate heatmap given a cmap specification
258 | imsave(fname, relevance[n], vmin=0., vmax=1., level=level, cmap=cmap)
259 | if input_format is not None:
260 | fname = input_format.format(sample=sample_index + n)
261 | # if there are 3 color channels, imsave will not create a heatmap, but instead save the image with
262 | # its appropriate colors
263 | imsave(fname, data[n])
264 | sample_index += len(data)
265 |
266 | # update the accuracy
267 | accuracy += (output.argmax(1) == target).sum().detach().cpu().item()
268 |
269 | accuracy /= len(dataset)
270 | print(f'Accuracy: {accuracy:.2f}')
271 |
272 |
273 | if __name__ == '__main__':
274 | main()
275 |
--------------------------------------------------------------------------------
/share/img/beacon_resnet50_various.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_resnet50_various.webp
--------------------------------------------------------------------------------
/share/img/beacon_vgg16_epsilon_gamma_box.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_vgg16_epsilon_gamma_box.png
--------------------------------------------------------------------------------
/share/img/beacon_vgg16_various.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/beacon_vgg16_various.webp
--------------------------------------------------------------------------------
/share/img/zennit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chr5tphr/zennit/e5699aa7e6fb98bec67505af917d0a17cd81d3b5/share/img/zennit.png
--------------------------------------------------------------------------------
/share/img/zennit.svg:
--------------------------------------------------------------------------------
1 | [SVG markup stripped from this dump: the Zennit logo as a vector image; see share/img/zennit.png for the raster version]
--------------------------------------------------------------------------------
/share/merge_maps/vgg16_bn.json:
--------------------------------------------------------------------------------
1 | [
2 | [["features.0"], "features.1"],
3 | [["features.3"], "features.4"],
4 | [["features.7"], "features.8"],
5 | [["features.10"], "features.11"],
6 | [["features.14"], "features.15"],
7 | [["features.17"], "features.18"],
8 | [["features.20"], "features.21"],
9 | [["features.24"], "features.25"],
10 | [["features.27"], "features.28"],
11 | [["features.30"], "features.31"],
12 | [["features.34"], "features.35"],
13 | [["features.37"], "features.38"],
14 | [["features.40"], "features.41"]
15 | ]
16 |
--------------------------------------------------------------------------------
/share/scripts/palette_fit.py:
--------------------------------------------------------------------------------
1 | '''Script to fit RGB heatmap images to a source color palette.'''
2 | import click
3 | import numpy as np
4 | from PIL import Image
5 |
6 | from zennit.image import CMAPS, palette
7 |
8 |
9 | def gale_shapley(dist):
10 | '''Find a stable matching given a distance matrix.'''
11 | preference = np.argsort(dist, axis=1)
12 | proposed = np.zeros(dist.shape[0], dtype=int)
13 | loners = set(range(dist.shape[0]))
14 | guys = [-1] * dist.shape[0]
15 | gals = [-1] * dist.shape[1]
16 | while loners:
17 | loner = loners.pop()
18 | target = preference[loner, proposed[loner]]
19 | if gals[target] == -1:
20 | gals[target] = loner
21 | guys[loner] = target
22 | elif dist[gals[target], target] > dist[loner, target]:
23 |             # free the previous match before assigning the new, better one
24 |             guys[gals[target]] = -1
25 |             loners.add(gals[target])
26 |             gals[target] = loner
27 |             guys[loner] = target
27 | else:
28 | loners.add(loner)
29 | proposed[loner] += 1
30 | return guys
31 |
32 |
33 | @click.command()
34 | @click.argument('source-file', type=click.Path(exists=True, dir_okay=False))
35 | @click.argument('output-file', type=click.Path(writable=True, dir_okay=False))
36 | @click.option('--strategy', type=click.Choice(['intensity', 'nearest', 'histogram']), default='intensity')
37 | @click.option('--source-cmap', type=click.Choice(list(CMAPS)), default='bwr')
38 | @click.option('--source-level', type=float, default=1.0)
39 | @click.option('--invert/--no-invert', default=False)
40 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot')
41 | @click.option('--level', type=float, default=1.0)
42 | def main(source_file, output_file, strategy, source_cmap, source_level, invert, cmap, level):
43 | '''Fit an existing RGB heatmap image to a color palette.'''
44 | source = np.array(Image.open(source_file).convert('RGB'))
45 | matchpal = palette(source_cmap, source_level)
46 |
47 | if strategy == 'intensity':
48 | # matching based on the source image intensity/ brightness
49 | values = source.astype(float).mean(2)
50 | elif strategy == 'nearest':
51 |         # match by finding the nearest centroids in a source colormap
52 | dists = (np.abs(source[None].astype(float) - matchpal[:, None, None].astype(float))).sum(3)
53 | values = np.argmin(dists, axis=0)
54 | elif strategy == 'histogram':
55 | # match by finding a stable match between the color histogram of the source image and a source colormap
56 | source = np.concatenate([source, np.zeros_like(source[:, :, [0]])], axis=2).view(np.uint32)[..., 0]
57 | uniques, counts = np.unique(source, return_counts=True)
58 | uniques = uniques[np.argsort(counts)[-256:]]
59 | dist = (np.abs(uniques.view(np.uint8).reshape(-1, 1, 4)[..., :3] - matchpal[None])).sum(2)
60 | matches = np.array(gale_shapley(dist))
61 |
62 | ind_bin, ind_h, ind_w = np.nonzero(source[None] == uniques[:, None, None])
63 | values = np.zeros(source.shape[:2], dtype=np.uint8)
64 | values[ind_h, ind_w] = matches[ind_bin]
65 |
66 | values = values.clip(0, 255).astype(np.uint8)
67 | if invert:
68 | values = 255 - values
69 |
70 | img = Image.fromarray(values, mode='P')
71 | pal = palette(cmap, level)
72 | img.putpalette(pal)
73 | img.save(output_file)
74 |
75 |
76 | if __name__ == '__main__':
77 | main()
78 |
--------------------------------------------------------------------------------
/share/scripts/palette_swap.py:
--------------------------------------------------------------------------------
1 | '''Script to swap the palette of heatmap images.'''
2 | import click
3 | from PIL import Image
4 |
5 | from zennit.image import CMAPS, palette
6 |
7 |
8 | @click.command()
9 | @click.argument('image-files', type=click.Path(exists=True, dir_okay=False), nargs=-1)
10 | @click.option('--cmap', type=click.Choice(list(CMAPS)), default='coldnhot')
11 | @click.option('--level', type=float, default=1.0)
12 | def main(image_files, cmap, level):
13 | '''Swap the palette of heatmap image files inline.'''
14 | for fname in image_files:
15 | img = Image.open(fname)
16 | img = img.convert('P')
17 | pal = palette(cmap, level)
18 | img.putpalette(pal)
19 | img.save(fname)
20 |
21 |
22 | if __name__ == '__main__':
23 | main()
24 |
--------------------------------------------------------------------------------
/share/scripts/show_cmaps.py:
--------------------------------------------------------------------------------
1 | '''Script to visually inspect color maps.'''
2 | import click
3 | import numpy as np
4 | from PIL import Image
5 |
6 | from zennit.image import CMAPS, palette
7 |
8 |
9 | def semsstr(string):
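   |     '''Split a semicolon-separated string into a list of its non-empty parts;
   |     lists (such as the click default) are passed through unchanged.'''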
10 | if isinstance(string, list):
11 | return string
12 | return [obj for obj in string.split(';') if obj]
13 |
14 |
15 | @click.command()
16 | @click.argument('output')
17 | @click.option('--cmap', 'colormap_src', type=semsstr, default=list(CMAPS))
18 | @click.option('--level', type=float, default=1.0)
19 | def main(output, colormap_src, level):
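   |     '''Render each color map as a 32-pixel-high horizontal stripe and save all stripes stacked into one image.'''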
20 | print('\n'.join(colormap_src))
21 | palettes = np.stack([palette(obj, level) for obj in colormap_src])
22 | arr = np.repeat(palettes, 32, 0)
23 | img = Image.fromarray(arr)
24 | img.save(output)
25 |
26 |
27 | if __name__ == '__main__':
28 | main()
29 |
--------------------------------------------------------------------------------
/src/zennit/__init__.py:
--------------------------------------------------------------------------------
1 | '''Zennit top-level __init__.'''
2 | from . import attribution
3 | from . import canonizers
4 | from . import cmap
5 | from . import composites
6 | from . import core
7 | from . import image
8 | from . import layer
9 | from . import rules
10 | from . import torchvision
11 | from . import types
12 |
13 |
14 | __all__ = [
15 | 'attribution',
16 | 'canonizers',
17 | 'cmap',
18 | 'composites',
19 | 'core',
20 | 'image',
21 | 'layer',
22 | 'rules',
23 | 'torchvision',
24 | 'types',
25 | ]
26 |
--------------------------------------------------------------------------------
/src/zennit/cmap.py:
--------------------------------------------------------------------------------
1 | # This file is part of Zennit
2 | # Copyright (C) 2019-2021 Christopher J. Anders
3 | #
4 | # zennit/cmap.py
5 | #
6 | # Zennit is free software: you can redistribute it and/or modify it under
7 | # the terms of the GNU Lesser General Public License as published by the Free
8 | # Software Foundation; either version 3 of the License, or (at your option) any
9 | # later version.
10 | #
11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT
12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
14 | # more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this library. If not, see <https://www.gnu.org/licenses/>.
18 | '''Create color maps from a color-map specification language'''
19 | import re
20 | from typing import NamedTuple
21 |
22 | import numpy as np
23 |
24 |
25 | class CMapToken(NamedTuple):
26 | '''Tokens used by the lexer of ColorMap.'''
27 | type: str
28 | value: str
29 | pos: int
30 |
31 |
32 | class ColorNode(NamedTuple):
33 | '''Nodes produced by the parser of ColorMap.'''
34 | index: int
35 | value: np.ndarray
36 |
37 |
38 | class ColorMap:
39 | '''Compile a color map from color-map specification language (cmsl) source code.
40 |
41 |     The color-map specification language (cmsl) describes linear color maps as comma-separated colors, each given as
42 |     hexadecimal RGB values with either one or two digits per color channel. Optionally, a hexadecimal index with one
43 |     or two digits, followed by a colon, may be supplied in front of a color to anchor that color at the given index.
44 |     Values of the ColorMap between color nodes are interpolated linearly. Colors without an index are spaced evenly
45 |     between the neighboring indexed colors. If the first or last supplied index is not 0 (or 00) or f (or ff)
46 |     respectively, an additional node with the same color as the lowest- or highest-indexed color is added at that end
47 |     of the color map. Indices must be supplied in ascending order from left to right, with an arbitrary number of
48 |     non-indexed colors in between. If the first and/or last color is not indexed, it is anchored at 0 (or 00) and f
49 |     (or ff) respectively.
50 |
51 | Parameters
52 | ----------
53 | source : str
54 | Source code to generate the color map.
55 |
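   |     Examples
   |     --------
   |     A sketch of typical usage (the sources here are illustrative; ``np`` is ``numpy``):
   |
   |     >>> bwr = ColorMap('00f,80:fff,f00')  # blue to white (anchored at index 0x80) to red
   |     >>> rgb = bwr(np.linspace(0., 1., 5))  # shape (5, 3), RGB values in [0., 1.]
   |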
56 | '''
57 | _rexp = re.compile(
58 | r'(?P[0-9a-fA-F]{6})|'
59 | r'(?P[0-9a-fA-F]{3})|'
60 | r'(?P[0-9a-fA-F]{1,2})|'
61 | r'(?P:)|'
62 | r'(?P,)|'
63 | r'(?P\s+)|'
64 | r'(?P.+)'
65 | )
66 |
67 | def __init__(self, source):
68 | self._source = None
69 | self.source = source
70 |
71 | @property
72 | def source(self) -> str:
73 | '''Source code property used to generate the color map. May be overwritten with a new string, which will be
74 | compiled to change the color map.
75 | '''
76 | return self._source
77 |
78 | @source.setter
79 | def source(self, value: str):
80 | '''Set source code property and re-compile the color map.'''
81 | try:
82 | tokens = self._lex(value)
83 | nodes = self._parse(tokens)
84 | self._indices, self._colors = self._make_palette(nodes)
85 | except RuntimeError as err:
86 | raise RuntimeError('Compilation of ColorMap failed!') from err
87 |
88 | self._source = value
89 |
90 | @staticmethod
91 | def _lex(string):
92 | '''Lexical scanning of cmsl using regular expressions.'''
93 | return [CMapToken(match.lastgroup, match.group(), match.start()) for match in ColorMap._rexp.finditer(string)]
94 |
95 | @staticmethod
96 | def _parse(tokens):
97 | '''Parse cmsl tokens into a list of color nodes.'''
98 | nodes = []
99 | log = []
100 | for token in tokens:
101 | if token.type == 'index' and not log:
102 | log.append(token)
103 | elif token.type == 'adsep' and len(log) == 1 and log[-1].type == 'index':
104 | log.append(token)
105 | elif token.type in ('shortcolor', 'longcolor'):
106 | if len(log) == 2 and log[-2].type == 'index':
107 | indval = log[-2].value
108 | if len(indval) == 1:
109 | indval = indval * 2
110 | index = int(indval, base=16)
111 | elif not log:
112 | index = None
113 | else:
114 | raise RuntimeError(f'Unexpected {token}')
115 |
116 | value_it = iter(token.value) if token.type == 'longcolor' else token.value
117 | value = [int(''.join(chars), base=16) for chars in zip(*[value_it] * 2)]
118 | nodes.append(ColorNode(index, np.array(value)))
119 | log.append(token)
120 | elif token.type == 'sep' and log and log[-1].type in ('longcolor', 'shortcolor'):
121 | log.clear()
122 | elif token.type != 'whitespace':
123 | raise RuntimeError(f'Unexpected {token}')
124 |
125 | if log and log[-1].type not in ('shortcolor', 'longcolor'):
126 | raise RuntimeError(f'Unexpected {log[-1]}')
127 |
128 | return nodes
129 |
130 | @staticmethod
131 | def _make_palette(nodes):
132 | '''Generate color map indices and colors from a list of color nodes.'''
133 | if len(nodes) < 2:
134 | raise RuntimeError("ColorMap needs at least 2 colors!")
135 | result = []
136 | log = []
137 |
138 | start = nodes.pop(0)
139 | result.append(ColorNode(0, start.value))
140 | if start.index is not None and start.index > 0:
141 | result.append(start)
142 |
143 | for n, node in enumerate(nodes):
144 | if node.index is None:
145 | if n < len(nodes) - 1:
146 | log.append(node)
147 | continue
148 | node = ColorNode(255, node.value)
149 | elif node.index < result[-1].index:
150 | raise RuntimeError('ColorMap indices not ordered! Provided indices are required in ascending order.')
151 | if log:
152 | result += [
153 | ColorNode(
154 | int(result[-1].index * (1. - alpha) + node.index * alpha),
155 | lognode.value
156 | ) for alpha, lognode in zip(np.linspace(0., 1., len(log) + 2)[1:-1], log)
157 | ]
158 | log.clear()
159 | result.append(node)
160 |
161 | result.append(ColorNode(256, result[-1].value))
162 |
163 | indices = np.array([node.index for node in result])
164 | colors = np.stack([node.value for node in result], axis=0)
165 |
166 | return indices, colors
167 |
168 | def __call__(self, x):
169 | '''Map scalar values in the range [0, 1] to RGB. This appends an axis with size 3 to `x`. Values are clipped to
170 | the range [0, 1], and the output will also lie in this range.
171 |
172 | Parameters
173 | ----------
174 | x : obj:`numpy.ndarray`
175 | Input array of arbitrary shape, which will be clipped to range [0, 1], and mapped to RGB using this
176 | ColorMap.
177 |
178 | Returns
179 | -------
180 | obj:`numpy.ndarray`
181 | The input array `x`, clipped to [0, 1] and mapped to RGB given this colormap, where the 3 color channels
182 | are appended as a new axis to the end.
183 | '''
184 | x = (x * 255).clip(0, 255)
185 | index = np.searchsorted(self._indices[:-1], x, side='right')
186 | alpha = ((x - self._indices[index - 1]) / (self._indices[index] - self._indices[index - 1]))[..., None]
187 | return (self._colors[index - 1] * (1 - alpha) + self._colors[index] * alpha) / 255.
188 |
189 | def palette(self, level=1.0):
190 | '''Create an 8-bit palette.
191 |
192 | Parameters
193 | ----------
194 |         level : float
195 |             The level of the color map, 1.0 by default. Values below 1.0 reduce the used color range, with only a
196 |             single color remaining at a value of 0.0. Values above 1.0 saturate towards the extreme colors earlier,
197 |             with an increasingly steep transition at the center of the color map.
198 |
199 | Returns
200 | -------
201 | obj:`numpy.ndarray`
202 | The palette described by an unsigned 8-bit numpy array with 256 entries.
203 | '''
204 | x = np.linspace(-1., 1., 256, dtype=np.float64) * level
205 | x = ((x + 1.) / 2.).clip(0., 1.)
206 | x = self(x)
207 | x = (x * 255.).round(12).clip(0., 255.).astype(np.uint8)
208 | return x
209 |
210 |
211 | class LazyColorMapCache:
212 | '''Dict-like object to store sources for colormaps, and compile and cache them lazily.
213 |
214 | Parameters
215 | ----------
216 | sources : dict
217 | Dict containing a mapping from names to color map specification language source.
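   |
   |     Examples
   |     --------
   |     A minimal sketch ('gray' and its source are illustrative):
   |
   |     >>> cmaps = LazyColorMapCache({'gray': '000,fff'})
   |     >>> cmaps['gray'].palette(level=1.0).shape  # compiled on first access
   |     (256, 3)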
218 | '''
219 | def __init__(self, sources):
220 | self._sources = sources
221 | self._compiled = {}
222 |
223 | def __getitem__(self, name):
224 | if name not in self._sources:
225 | raise KeyError(f'No source for key {name}.')
226 | if name not in self._compiled:
227 | self._compiled[name] = ColorMap(self._sources[name])
228 | return self._compiled[name]
229 |
230 | def __setitem__(self, name, value):
231 | self._sources[name] = value
232 | if name in self._compiled:
233 | self._compiled[name].source = value
234 |
235 | def __delitem__(self, name):
236 | del self._sources[name]
237 | if name in self._compiled:
238 | del self._compiled[name]
239 |
240 | def __iter__(self):
241 | return iter(self._sources)
242 |
243 | def __len__(self):
244 | return len(self._sources)
245 |
--------------------------------------------------------------------------------
/src/zennit/layer.py:
--------------------------------------------------------------------------------
1 | # This file is part of Zennit
2 | # Copyright (C) 2019-2021 Christopher J. Anders
3 | #
4 | # zennit/layer.py
5 | #
6 | # Zennit is free software: you can redistribute it and/or modify it under
7 | # the terms of the GNU Lesser General Public License as published by the Free
8 | # Software Foundation; either version 3 of the License, or (at your option) any
9 | # later version.
10 | #
11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT
12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
14 | # more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this library. If not, see <https://www.gnu.org/licenses/>.
18 | '''Additional Utility Layers'''
19 | import torch
20 |
21 |
22 | class Sum(torch.nn.Module):
23 | '''Compute the sum along an axis.
24 |
25 | Parameters
26 | ----------
27 | dim : int
28 | Dimension over which to sum.
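   |
   |     Examples
   |     --------
   |     A minimal sketch: summing a stacked residual connection over the last axis.
   |
   |     >>> Sum(dim=-1)(torch.stack([torch.ones(2, 3), torch.zeros(2, 3)], dim=-1)).shape
   |     torch.Size([2, 3])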
29 | '''
30 | def __init__(self, dim=-1):
31 | super().__init__()
32 | self.dim = dim
33 |
34 | def forward(self, input):
35 | '''Computes the sum along a dimension.'''
36 | return torch.sum(input, dim=self.dim)
37 |
--------------------------------------------------------------------------------
/src/zennit/torchvision.py:
--------------------------------------------------------------------------------
1 | # This file is part of Zennit
2 | # Copyright (C) 2019-2021 Christopher J. Anders
3 | #
4 | # zennit/torchvision.py
5 | #
6 | # Zennit is free software: you can redistribute it and/or modify it under
7 | # the terms of the GNU Lesser General Public License as published by the Free
8 | # Software Foundation; either version 3 of the License, or (at your option) any
9 | # later version.
10 | #
11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT
12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
14 | # more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this library. If not, see <https://www.gnu.org/licenses/>.
18 | '''Specialized Canonizers for models from torchvision.'''
19 | import torch
20 | from torchvision.models.resnet import Bottleneck as ResNetBottleneck, BasicBlock as ResNetBasicBlock
21 |
22 | from .canonizers import SequentialMergeBatchNorm, AttributeCanonizer, CompositeCanonizer
23 | from .layer import Sum
24 |
25 |
26 | class VGGCanonizer(SequentialMergeBatchNorm):
27 |     '''Canonizer for torchvision.models.vgg* type models. This is so far identical to a SequentialMergeBatchNorm.'''
28 |
29 |
30 | class ResNetBottleneckCanonizer(AttributeCanonizer):
31 | '''Canonizer specifically for Bottlenecks of torchvision.models.resnet* type models.'''
32 | def __init__(self):
33 | super().__init__(self._attribute_map)
34 |
35 | @classmethod
36 | def _attribute_map(cls, name, module):
37 | '''Create a forward function and a Sum module to overload as new attributes for module.
38 |
39 | Parameters
40 | ----------
41 | name : string
42 | Name by which the module is identified.
43 | module : obj:`torch.nn.Module`
44 | Instance of a module. If this is a Bottleneck layer, the appropriate attributes to overload are returned.
45 |
46 | Returns
47 | -------
48 | None or dict
49 | None if `module` is not an instance of Bottleneck, otherwise the appropriate attributes to overload onto
50 | the module instance.
51 | '''
52 | if isinstance(module, ResNetBottleneck):
53 | attributes = {
54 | 'forward': cls.forward.__get__(module),
55 | 'canonizer_sum': Sum(),
56 | }
57 | return attributes
58 | return None
59 |
60 | @staticmethod
61 | def forward(self, x):
62 | '''Modified Bottleneck forward for ResNet.'''
63 | identity = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | identity = self.downsample(x)
78 |
79 | out = torch.stack([identity, out], dim=-1)
80 | out = self.canonizer_sum(out)
81 |
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class ResNetBasicBlockCanonizer(AttributeCanonizer):
88 | '''Canonizer specifically for BasicBlocks of torchvision.models.resnet* type models.'''
89 | def __init__(self):
90 | super().__init__(self._attribute_map)
91 |
92 | @classmethod
93 | def _attribute_map(cls, name, module):
94 | '''Create a forward function and a Sum module to overload as new attributes for module.
95 |
96 | Parameters
97 | ----------
98 | name : string
99 | Name by which the module is identified.
100 | module : obj:`torch.nn.Module`
101 | Instance of a module. If this is a BasicBlock layer, the appropriate attributes to overload are returned.
102 |
103 | Returns
104 | -------
105 | None or dict
106 | None if `module` is not an instance of BasicBlock, otherwise the appropriate attributes to overload onto
107 | the module instance.
108 | '''
109 | if isinstance(module, ResNetBasicBlock):
110 | attributes = {
111 | 'forward': cls.forward.__get__(module),
112 | 'canonizer_sum': Sum(),
113 | }
114 | return attributes
115 | return None
116 |
117 | @staticmethod
118 | def forward(self, x):
119 | '''Modified BasicBlock forward for ResNet.'''
120 | identity = x
121 |
122 | out = self.conv1(x)
123 | out = self.bn1(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv2(out)
127 | out = self.bn2(out)
128 |
129 | if self.downsample is not None:
130 | identity = self.downsample(x)
131 |
132 | out = torch.stack([identity, out], dim=-1)
133 | out = self.canonizer_sum(out)
134 |
135 | out = self.relu(out)
136 |
137 | return out
138 |
139 |
140 | class ResNetCanonizer(CompositeCanonizer):
141 |     '''Canonizer for torchvision.models.resnet* type models. This applies SequentialMergeBatchNorm, and additionally
142 |     adds a Sum module to the Bottleneck and BasicBlock modules and overloads their forward methods to use the Sum
143 |     module instead of simply adding two tensors, such that forward and backward hooks may be applied.'''
144 | def __init__(self):
145 | super().__init__((
146 | SequentialMergeBatchNorm(),
147 | ResNetBottleneckCanonizer(),
148 | ResNetBasicBlockCanonizer(),
149 | ))
150 |
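   |
   | # A minimal usage sketch (EpsilonPlus is one of zennit's composites; any composite
   | # accepting canonizers works the same way): canonizers are passed to a composite,
   | # which applies them while its context is active.
   | #
   | # >>> from torchvision.models import resnet18
   | # >>> from zennit.composites import EpsilonPlus
   | # >>> composite = EpsilonPlus(canonizers=[ResNetCanonizer()])
   | # >>> with composite.context(resnet18()) as model:
   | # ...     pass  # compute attributions here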
--------------------------------------------------------------------------------
/src/zennit/types.py:
--------------------------------------------------------------------------------
1 | # This file is part of Zennit
2 | # Copyright (C) 2019-2021 Christopher J. Anders
3 | #
4 | # zennit/types.py
5 | #
6 | # Zennit is free software: you can redistribute it and/or modify it under
7 | # the terms of the GNU Lesser General Public License as published by the Free
8 | # Software Foundation; either version 3 of the License, or (at your option) any
9 | # later version.
10 | #
11 | # Zennit is distributed in the hope that it will be useful, but WITHOUT
12 | # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
13 | # FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
14 | # more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this library. If not, see <https://www.gnu.org/licenses/>.
18 | '''Type definitions for convenience.'''
19 | import torch
20 |
21 |
22 | class SubclassMeta(type):
23 | '''Meta class to bundle multiple subclasses.'''
24 | def __instancecheck__(cls, inst):
25 | """Implement isinstance(inst, cls) as subclasscheck."""
26 | return cls.__subclasscheck__(type(inst))
27 |
28 | def __subclasscheck__(cls, sub):
29 | """Implement issubclass(sub, cls) with by considering additional __subclass__ members."""
30 | candidates = cls.__dict__.get("__subclass__", tuple())
31 | return type.__subclasscheck__(cls, sub) or issubclass(sub, candidates)
32 |
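   | # Illustrative effect of SubclassMeta with the definitions below:
   | # isinstance(torch.nn.Conv2d(1, 1, 3), Convolution) and
   | # issubclass(torch.nn.Linear, Linear) both evaluate to True, even though the
   | # abstract classes are never actually subclassed.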
33 |
34 | class ConvolutionTranspose(metaclass=SubclassMeta):
35 | '''Abstract base class that describes transposed convolutional modules.'''
36 | __subclass__ = (
37 | torch.nn.modules.conv.ConvTranspose1d,
38 | torch.nn.modules.conv.ConvTranspose2d,
39 | torch.nn.modules.conv.ConvTranspose3d,
40 | )
41 |
42 |
43 | class ConvolutionStandard(metaclass=SubclassMeta):
44 | '''Abstract base class that describes standard (forward) convolutional modules.'''
45 | __subclass__ = (
46 | torch.nn.modules.conv.Conv1d,
47 | torch.nn.modules.conv.Conv2d,
48 | torch.nn.modules.conv.Conv3d,
49 | )
50 |
51 |
52 | class Convolution(metaclass=SubclassMeta):
53 | '''Abstract base class that describes all convolutional modules.'''
54 | __subclass__ = (
55 | ConvolutionStandard,
56 | ConvolutionTranspose,
57 | )
58 |
59 |
60 | class Linear(metaclass=SubclassMeta):
61 | '''Abstract base class that describes linear modules.'''
62 | __subclass__ = (
63 | Convolution,
64 | torch.nn.modules.linear.Linear,
65 | )
66 |
67 |
68 | class BatchNorm(metaclass=SubclassMeta):
69 | '''Abstract base class that describes batch normalization modules.'''
70 | __subclass__ = (
71 | torch.nn.modules.batchnorm.BatchNorm1d,
72 | torch.nn.modules.batchnorm.BatchNorm2d,
73 | torch.nn.modules.batchnorm.BatchNorm3d,
74 | )
75 |
76 |
77 | class AvgPool(metaclass=SubclassMeta):
78 |     '''Abstract base class that describes average-pooling modules.'''
79 | __subclass__ = (
80 | torch.nn.modules.pooling.AvgPool1d,
81 | torch.nn.modules.pooling.AvgPool2d,
82 | torch.nn.modules.pooling.AvgPool3d,
83 | torch.nn.modules.pooling.AdaptiveAvgPool1d,
84 | torch.nn.modules.pooling.AdaptiveAvgPool2d,
85 | torch.nn.modules.pooling.AdaptiveAvgPool3d,
86 | )
87 |
88 |
89 | class MaxPool(metaclass=SubclassMeta):
90 | '''Abstract base class that describes max-pooling modules.'''
91 | __subclass__ = (
92 | torch.nn.modules.pooling.MaxPool1d,
93 | torch.nn.modules.pooling.MaxPool2d,
94 | torch.nn.modules.pooling.MaxPool3d,
95 | torch.nn.modules.pooling.AdaptiveMaxPool1d,
96 | torch.nn.modules.pooling.AdaptiveMaxPool2d,
97 | torch.nn.modules.pooling.AdaptiveMaxPool3d,
98 | )
99 |
100 |
101 | class Activation(metaclass=SubclassMeta):
102 | '''Abstract base class that describes activation modules.'''
103 | __subclass__ = (
104 | torch.nn.modules.activation.ELU,
105 | torch.nn.modules.activation.Hardshrink,
106 | torch.nn.modules.activation.Hardsigmoid,
107 | torch.nn.modules.activation.Hardtanh,
108 | torch.nn.modules.activation.Hardswish,
109 | torch.nn.modules.activation.LeakyReLU,
110 | torch.nn.modules.activation.LogSigmoid,
111 | torch.nn.modules.activation.PReLU,
112 | torch.nn.modules.activation.ReLU,
113 | torch.nn.modules.activation.ReLU6,
114 | torch.nn.modules.activation.RReLU,
115 | torch.nn.modules.activation.SELU,
116 | torch.nn.modules.activation.CELU,
117 | torch.nn.modules.activation.GELU,
118 | torch.nn.modules.activation.Sigmoid,
119 | torch.nn.modules.activation.SiLU,
120 | torch.nn.modules.activation.Softplus,
121 | torch.nn.modules.activation.Softshrink,
122 | torch.nn.modules.activation.Softsign,
123 | torch.nn.modules.activation.Tanh,
124 | torch.nn.modules.activation.Tanhshrink,
125 | torch.nn.modules.activation.Threshold,
126 | )
127 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | '''Configuration and fixtures for testing'''
2 | import random
3 | from itertools import product, groupby
4 | from collections import OrderedDict
5 |
6 | import pytest
7 | import torch
8 | from torch.nn import Conv1d, ConvTranspose1d, Linear
9 | from torch.nn import Conv2d, ConvTranspose2d
10 | from torch.nn import Conv3d, ConvTranspose3d
11 | from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
12 | from torchvision.models import vgg11, resnet18, alexnet
13 | from helpers import prodict, one_hot_max
14 |
15 | from zennit.attribution import identity
16 | from zennit.core import Composite, Hook
17 | from zennit.composites import COMPOSITES
18 | from zennit.composites import EpsilonGammaBox
19 | from zennit.composites import LayerMapComposite
20 | from zennit.composites import MixedComposite
21 | from zennit.composites import NameLayerMapComposite
22 | from zennit.composites import NameMapComposite
23 | from zennit.composites import SpecialFirstLayerMapComposite
24 | from zennit.types import Linear as AnyLinear, Activation
25 |
26 |
27 | def pytest_addoption(parser):
28 | '''Add options to pytest.'''
29 | parser.addoption(
30 | '--batchsize',
31 | default=4,
32 | help='Batch-size for generated samples.'
33 | )
34 |
35 |
36 | def pytest_generate_tests(metafunc):
37 | '''Generate test fixture values based on CLI options.'''
38 | if 'batchsize' in metafunc.fixturenames:
39 | metafunc.parametrize('batchsize', [metafunc.config.getoption('batchsize')], scope='session')
40 |
41 |
42 | @pytest.fixture(
43 | scope='session',
44 | params=[
45 | 0xdeadbeef,
46 | 0xd0c0ffee,
47 | *[pytest.param(seed, marks=pytest.mark.extended) for seed in [
48 | 0xc001bee5, 0xc01dfee7, 0xbe577001, 0xca7b0075, 0x1057b0a7, 0x900ddeed
49 | ]],
50 | ],
51 | ids=hex
52 | )
53 | def rng(request):
54 |     '''Fixture for the Torch random number generator.'''
55 | return torch.manual_seed(request.param)
56 |
57 |
58 | @pytest.fixture(scope='session')
59 | def pyrng(rng):
60 | '''Fixture for the Python random number generator.'''
61 | return random.Random(rng.initial_seed())
62 |
63 |
64 | @pytest.fixture(
65 | scope='session',
66 | params=[
67 | (torch.nn.ReLU, {}),
68 | (torch.nn.Softmax, {'dim': 1}),
69 | (torch.nn.Tanh, {}),
70 | (torch.nn.Sigmoid, {}),
71 | (torch.nn.Softplus, {'beta': 1}),
72 | ],
73 | ids=lambda param: param[0].__name__
74 | )
75 | def module_simple(rng, request):
76 | '''Fixture for simple modules.'''
77 | module_type, kwargs = request.param
78 | return module_type(**kwargs).to(torch.float64).eval()
79 |
80 |
81 | @pytest.fixture(
82 | scope='session',
83 | params=[
84 | *product(
85 | [Linear],
86 | prodict(in_features=[16], out_features=[15], bias=[True, False]),
87 | ),
88 | *product(
89 | [Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d],
90 | prodict(in_channels=[1, 3], out_channels=[1, 3], kernel_size=[2, 3], bias=[True, False]),
91 | ),
92 | ],
93 | ids=lambda param: param[0].__name__
94 | )
95 | def module_linear(rng, request):
96 | '''Fixture for linear modules.'''
97 | module_type, kwargs = request.param
98 | return module_type(**kwargs).to(torch.float64).eval()
99 |
100 |
101 | @pytest.fixture(scope='session')
102 | def module_batchnorm(module_linear, rng):
103 | '''Fixture for BatchNorm-type modules, based on adjacent linear module.'''
104 | module_map = [
105 | ((Linear, Conv1d, ConvTranspose1d), BatchNorm1d),
106 | ((Conv2d, ConvTranspose2d), BatchNorm2d),
107 | ((Conv3d, ConvTranspose3d), BatchNorm3d),
108 | ]
109 | feature_index_map = [
110 | ((ConvTranspose1d, ConvTranspose2d, ConvTranspose3d), 1),
111 | ((Linear, Conv1d, Conv2d, Conv3d), 0),
112 | ]
113 |
114 | batchnorm_type = None
115 | for types, target_type in module_map:
116 | if isinstance(module_linear, types):
117 | batchnorm_type = target_type
118 | break
119 | if batchnorm_type is None:
120 | raise RuntimeError('No batchnorm type for linear layer found.')
121 |
122 | feature_index = None
123 | for types, index in feature_index_map:
124 | if isinstance(module_linear, types):
125 | feature_index = index
126 | break
127 | if feature_index is None:
128 | raise RuntimeError('No feature index for linear layer found.')
129 |
130 | batchnorm = batchnorm_type(num_features=module_linear.weight.shape[feature_index]).to(torch.float64).eval()
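   |     # 'from' is a reserved keyword in Python, so uniform_'s keyword arguments are unpacked from a dict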
131 | batchnorm.weight.data.uniform_(**{'from': 0.1, 'to': 2.0, 'generator': rng})
132 | batchnorm.bias.data.normal_(generator=rng)
133 | batchnorm.eps = 1e-30
134 | return batchnorm
135 |
136 |
137 | @pytest.fixture(scope='session')
138 | def data_linear(rng, batchsize, module_linear):
139 | '''Fixture to create data for a linear module, given an RNG.'''
140 | shape = (batchsize,)
141 | setups = [
142 | (Conv1d, 1, 1),
143 | (ConvTranspose1d, 0, 1),
144 | (Conv2d, 1, 2),
145 | (ConvTranspose2d, 0, 2),
146 | (Conv3d, 1, 3),
147 | (ConvTranspose3d, 0, 3)
148 | ]
149 | if isinstance(module_linear, Linear):
150 | shape += (module_linear.weight.shape[1],)
151 | else:
152 | for module_type, dim, ndims in setups:
153 | if isinstance(module_linear, module_type):
154 | shape += (module_linear.weight.shape[dim],) + (4,) * ndims
155 |
156 | return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
157 |
158 |
159 | @pytest.fixture(scope='session', params=[
160 | (16,),
161 | (4,),
162 | (4, 4),
163 | (4, 4, 4),
164 | ])
165 | def data_simple(request, rng, batchsize):
166 |     '''Fixture to create simple data of various shapes, given an RNG.'''
167 | shape = (batchsize,) + request.param
168 | return torch.empty(*shape, dtype=torch.float64).normal_(generator=rng)
169 |
170 |
171 | COMPOSITE_KWARGS = {
172 | EpsilonGammaBox: {'low': -3., 'high': 3.},
173 | }
174 |
175 |
176 | class PassClone(Hook):
177 | '''Clone of the Pass rule.'''
178 | def backward(self, module, grad_input, grad_output):
179 | '''Directly return grad_output.'''
180 | return grad_output
181 |
182 |
183 | class GradClone(Hook):
184 | '''Explicit rule to return the cloned gradient.'''
185 | def backward(self, module, grad_input, grad_output):
186 |         '''Return a clone of grad_input.'''
187 | return grad_input.clone()
188 |
189 |
190 | @pytest.fixture(scope='session', params=[
191 | None,
192 | [(Linear, GradClone()), (Activation, PassClone())],
193 | ])
194 | def cooperative_layer_map(request):
195 | '''Fixture for a cooperative layer map in LayerMapComposite subtypes.'''
196 | return request.param
197 |
198 |
199 | @pytest.fixture(scope='session', params=[
200 | None,
201 | [(AnyLinear, GradClone())],
202 | ])
203 | def cooperative_first_map(request):
204 | '''Fixture for a cooperative layer map for the first layer in SpecialFirstLayerMapComposite subtypes.'''
205 | return request.param
206 |
207 |
208 | @pytest.fixture(scope='session', params=[
209 | elem for elem in COMPOSITES.values()
210 | if issubclass(elem, LayerMapComposite) and not issubclass(elem, SpecialFirstLayerMapComposite)
211 | ])
212 | def layer_map_composite(request, cooperative_layer_map):
213 | '''Fixture for explicit LayerMapComposites.'''
214 | return request.param(layer_map=cooperative_layer_map, **COMPOSITE_KWARGS.get(request.param, {}))
215 |
216 |
217 | @pytest.fixture(scope='session', params=[
218 | elem for elem in COMPOSITES.values() if issubclass(elem, SpecialFirstLayerMapComposite)
219 | ])
220 | def special_first_layer_map_composite(request, cooperative_layer_map, cooperative_first_map):
221 |     '''Fixture for explicit SpecialFirstLayerMapComposites.'''
222 | return request.param(
223 | layer_map=cooperative_layer_map,
224 | first_map=cooperative_first_map,
225 | **COMPOSITE_KWARGS.get(request.param, {})
226 | )
227 |
228 |
229 | @pytest.fixture(scope='session', params=[Composite, *COMPOSITES.values()])
230 | def any_composite(request):
231 | '''Fixture for all explicitly registered Composites, as well as the empty Composite.'''
232 | return request.param(**COMPOSITE_KWARGS.get(request.param, {}))
233 |
234 |
235 | @pytest.fixture(scope='session')
236 | def name_map_composite(model_vision, layer_map_composite):
237 | '''Fixture to create NameMapComposites based on explicit LayerMapComposites.'''
238 | rule_map = {}
239 | for name, child in model_vision.named_modules():
240 | for dtype, hook_template in layer_map_composite.layer_map:
241 | if isinstance(child, dtype):
242 | rule_map.setdefault(hook_template, []).append(name)
243 | break
244 | name_map = [(tuple(value), key) for key, value in rule_map.items()]
245 | return NameMapComposite(name_map=name_map)
246 |
247 |
248 | @pytest.fixture(scope='session')
249 | def partial_name_map_composite(name_map_composite, pyrng):
250 | '''Fixture to create a randomly sampled partial NameMapComposites.'''
251 | name_map = name_map_composite.name_map
252 | assocs = [(i, j) for i, (keys, _) in enumerate(name_map) for j in range(len(keys))]
253 | accepted_assocs = sorted(pyrng.sample(assocs, len(assocs) // 2))
254 | partial_name_map = [
255 | (tuple(name_map[k][0][n] for _, n in g), name_map[k][1].copy())
256 | for k, g in groupby(accepted_assocs, lambda o: o[0])
257 | ]
258 |
259 | return NameMapComposite(name_map=partial_name_map)
260 |
261 |
262 | @pytest.fixture(scope='session')
263 | def mixed_composite(partial_name_map_composite, special_first_layer_map_composite):
264 |     '''Fixture to create MixedComposites based on an explicit NameMapComposite and
265 | SpecialFirstLayerMapComposites.
266 | '''
267 | composites = [partial_name_map_composite, special_first_layer_map_composite]
268 | return MixedComposite(composites)
269 |
270 |
271 | @pytest.fixture(scope='session')
272 | def name_layer_map_composite(partial_name_map_composite, layer_map_composite):
273 | '''Fixture to create NameLayerMapComposites based on an explicit NameMapComposite and LayerMapComposite.'''
274 | return NameLayerMapComposite(
275 | name_map=partial_name_map_composite.name_map,
276 | layer_map=layer_map_composite.layer_map,
277 | )
278 |
279 |
280 | @pytest.fixture(scope='session', params=[alexnet, vgg11, resnet18])
281 | def model_vision(request):
282 | '''Models to test composites on.'''
283 | return request.param()
284 |
285 |
286 | @pytest.fixture(scope='session')
287 | def model_simple(rng, module_linear, data_linear):
288 |     '''Fixture for a simple model, using a linear module followed by a ReLU, a Flatten, and a dense layer.'''
289 | with torch.no_grad():
290 | intermediate = module_linear(data_linear)
291 | return torch.nn.Sequential(OrderedDict([
292 | ('linr0', module_linear),
293 | ('actv0', torch.nn.ReLU()),
294 | ('flat0', torch.nn.Flatten()),
295 | ('linr1', torch.nn.Linear(intermediate.shape[1:].numel(), 4, dtype=intermediate.dtype)),
296 | ]))
297 |
298 |
299 | @pytest.fixture(scope='session')
300 | def model_simple_grad(data_linear, model_simple):
301 | '''Fixture for gradient wrt. data_linear for model_simple.'''
302 | data = data_linear.detach().requires_grad_()
303 | output = model_simple(data)
304 | grad, = torch.autograd.grad(output, data, output)
305 | return grad
306 |
307 |
308 | @pytest.fixture(scope='session')
309 | def model_simple_output(data_linear, model_simple):
310 | '''Fixture for output given data_linear for model_simple.'''
311 | data = data_linear.detach()
312 | output = model_simple(data)
313 | return output
314 |
315 |
316 | @pytest.fixture(scope='session', params=[
317 | identity,
318 | one_hot_max,
319 | torch.ones_like,
320 | ])
321 | def grad_outputs_func(request):
322 | '''Fixture for common attr_output_fn functions.'''
323 | return request.param
324 |
--------------------------------------------------------------------------------
/tests/helpers.py:
--------------------------------------------------------------------------------
1 | '''Helper functions for various tests.'''
2 | from itertools import product
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from zennit.types import BatchNorm
8 |
9 |
10 | def prodict(**kwargs):
11 |     '''Create a list of dictionaries over the cartesian product of the input keyword argument values.
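   |
   |     For example (illustrative), ``prodict(a=[1, 2], b=[3])`` yields
   |     ``[{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]``.
   |     '''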
12 | return [dict(zip(kwargs, val)) for val in product(*kwargs.values())]
13 |
14 |
15 | def one_hot_max(output):
16 |     '''Get a one-hot encoding at the position of the global maximum of output.'''
17 | return torch.sparse_coo_tensor(
18 | [*zip(np.unravel_index(output.argmax(), output.shape))], [1.], output.shape, dtype=output.dtype
19 | ).to_dense()
20 |
21 |
22 | def assert_identity_hook(equal=True, message=''):
23 | '''Create an assertion hook which checks whether the module does or does not modify its input.'''
24 | def assert_identity(module, input, output):
25 | '''Assert whether the module does or does not modify its input.'''
26 | assert equal == torch.allclose(input[0], output, rtol=1e-5), message
27 | return assert_identity
28 |
29 |
30 | def randomize_bnorm(model):
31 | '''Randomize all BatchNorm module parameters of a model.'''
32 | for module in model.modules():
33 | if isinstance(module, BatchNorm):
34 | module.weight.data.uniform_(0.1, 2.0)
35 | module.running_var.data.uniform_(0.1, 2.0)
36 | module.bias.data.normal_()
37 | module.running_mean.data.normal_()
38 | # smaller eps to reduce error
39 | module.eps = 1e-30
40 | return model
41 |
42 |
43 | def nograd(model):
44 | '''Unset grad requirement for all model parameters.'''
45 | for param in model.parameters():
46 | param.requires_grad = False
47 | return model
48 |
--------------------------------------------------------------------------------
/tests/test_attribution.py:
--------------------------------------------------------------------------------
1 | '''Tests for Attributors.'''
2 | from functools import partial
3 | from itertools import product
4 |
5 | import pytest
6 | import torch
7 |
8 | from zennit.attribution import Gradient, IntegratedGradients, SmoothGrad, Occlusion, occlude_independent
9 |
10 |
11 | class IdentityLogger(torch.nn.Module):
12 | '''Helper-Module to log input tensors.'''
13 | def __init__(self):
14 | super().__init__()
15 | self.tensors = []
16 |
17 | def forward(self, input):
18 | '''Clone input, append to self.tensors and return the cloned tensor.'''
19 | self.tensors.append(input.clone())
20 | return self.tensors[-1]
21 |
22 |
23 | def test_gradient_attributor_inactive(
24 | data_linear, model_simple, model_simple_output, any_composite, grad_outputs_func
25 | ):
26 |     '''Test whether the hooks of Gradient's composite are deactivated inside the attributor's inactive context.'''
27 |
28 | with Gradient(model=model_simple, composite=any_composite, attr_output=grad_outputs_func) as attributor:
29 | # verify that all hooks are active
30 | assert all(hook.active for hook in attributor.composite.hook_refs)
31 | with attributor.inactive():
32 | # verify that all hooks are inactive
33 | assert all(not hook.active for hook in attributor.composite.hook_refs)
34 |
35 |
36 | def test_gradient_attributor_composite(
37 | data_linear, model_simple, model_simple_output, any_composite, grad_outputs_func
38 | ):
39 | '''Test whether composite context and attributor match for Gradient.'''
40 | with any_composite.context(model_simple) as module:
41 | data = data_linear.detach().requires_grad_()
42 | output_context = module(data)
43 | grad_outputs = grad_outputs_func(output_context)
44 | grad_context, = torch.autograd.grad(output_context, data, grad_outputs)
45 |
46 | with Gradient(model=model_simple, composite=any_composite, attr_output=grad_outputs_func) as attributor:
47 | output_attributor, grad_attributor = attributor(data_linear)
48 |
49 | assert torch.allclose(output_context, output_attributor)
50 | assert torch.allclose(grad_context, grad_attributor)
51 | assert torch.allclose(model_simple_output, output_attributor)
52 |
53 |
54 | @pytest.mark.parametrize('use_const,use_call,use_init', product(*[[True, False]] * 3))
55 | def test_gradient_attributor_output_fn(data_simple, grad_outputs_func, use_const, use_call, use_init):
56 |     '''Test whether attributors' attr_output supports functions, constants and None, in any combination of supplied
57 |     or not supplied, for both the attributor initialization and the call.
58 | '''
59 | model = IdentityLogger()
60 |
61 | attr_output = grad_outputs_func(data_simple) if use_const else grad_outputs_func
62 | init_attr_output = attr_output if use_init else None
63 | call_attr_output = attr_output if use_call else None
64 |
65 | with Gradient(model=model, attr_output=init_attr_output) as attributor:
66 | _, grad = attributor(data_simple, attr_output=call_attr_output)
67 |
68 | if (use_call or use_init):
69 | expected_grad = grad_outputs_func(data_simple)
70 | else:
71 | # the identity is the default attr_output
72 | expected_grad = data_simple
73 |
74 | assert torch.allclose(expected_grad, grad), 'Attributor output function gradient mismatch!'
75 |
76 |
77 | def test_gradient_attributor_grad(data_simple):
78 | '''Test whether the gradient of Gradient matches.'''
79 | model = torch.nn.Softplus(beta=1.)
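   |     # view_as creates a new view, so requires_grad_ does not modify the session-scoped fixture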
80 | data = data_simple.view_as(data_simple).requires_grad_()
81 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data))
82 |
83 | with Gradient(model=model, create_graph=True) as attributor:
84 | _, grad = attributor(data, torch.ones_like)
85 | gradgrad, = torch.autograd.grad(grad.sum(), data)
86 |
87 | assert torch.allclose(gradgrad, target), 'Gradient Attributor second order gradient mismatch!'
88 |
89 |
90 | def test_gradient_attributor_output_fn_precedence(data_simple):
91 | '''Test whether the gradient attributor attr_output at call is preferred when it is supplied at both initialization
92 | and call.
93 | '''
94 | model = IdentityLogger()
95 |
96 | init_attr_output = torch.ones_like
97 | call_attr_output = torch.zeros_like
98 |
99 | with Gradient(model=model, attr_output=init_attr_output) as attributor:
100 | _, grad = attributor(data_simple, attr_output=call_attr_output)
101 |
102 | expected_grad = call_attr_output(data_simple)
103 | assert torch.allclose(expected_grad, grad), 'Attributor output function precedence mismatch!'
104 |
105 |
106 | def test_smooth_grad_single(data_linear, model_simple, model_simple_output, model_simple_grad):
107 | '''Test whether SmoothGrad with a single iteration is equal to the gradient.'''
108 | with SmoothGrad(model=model_simple, noise_level=0.1, n_iter=1) as attributor:
109 | output, grad = attributor(data_linear)
110 |
111 | assert torch.allclose(model_simple_grad, grad)
112 | assert torch.allclose(model_simple_output, output)
113 |
114 |
115 | def test_smooth_grad_single_grad(data_simple):
116 | '''Test whether the gradient of SmoothGrad matches.'''
117 | model = torch.nn.Softplus(beta=1.)
118 | data = data_simple.view_as(data_simple).requires_grad_()
119 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data))
120 |
121 | with SmoothGrad(model=model, noise_level=0.1, n_iter=1, create_graph=True) as attributor:
122 | _, grad = attributor(data, torch.ones_like)
123 | gradgrad, = torch.autograd.grad(grad.sum(), data)
124 |
125 | assert torch.allclose(gradgrad, target), 'SmoothGrad Attributor second order gradient mismatch!'
126 |
127 |
128 | @pytest.mark.parametrize('noise_level', [0.0, 0.1, 0.3, 0.5])
129 | def test_smooth_grad_distribution(data_simple, noise_level):
130 | '''Test whether the SmoothGrad sampled distribution matches.'''
131 | model = IdentityLogger()
132 |
133 | dims = tuple(range(1, data_simple.ndim))
134 | noise_var = (noise_level * (data_simple.amax(dims) - data_simple.amin(dims))) ** 2
135 | n_iter = 100
136 |
137 | with SmoothGrad(model=model, noise_level=noise_level, n_iter=n_iter, attr_output=torch.ones_like) as attributor:
138 | _, grad = attributor(data_simple)
139 |
140 | assert len(model.tensors) == n_iter, 'SmoothGrad iterations did not match n_iter!'
141 |
142 | sample_mean = sum(model.tensors) / len(model.tensors)
143 | sample_var = ((sum((tensor - sample_mean) ** 2 for tensor in model.tensors) / len(model.tensors))).mean(dims)
144 |
145 | if noise_level > 0.:
146 | std_ratio = (sample_var / noise_var) ** .5
147 | assert (std_ratio < 1.5).all().item(), 'SmoothGrad sample variance is too high!'
148 | assert (std_ratio > 0.667).all().item(), 'SmoothGrad sample variance is too low!'
149 | else:
150 | assert (sample_var < 1e-9).all().item(), 'SmoothGrad sample variance is too high!'
151 | assert torch.allclose(grad, torch.ones_like(data_simple)), 'SmoothGrad of identity is wrong!'
152 |
153 |
154 | @pytest.mark.parametrize('baseline_fn', [None, torch.zeros_like, torch.ones_like])
155 | def test_integrated_gradients_single(data_linear, model_simple, model_simple_output, model_simple_grad, baseline_fn):
156 | '''Test whether IntegratedGradients with a single iteration is equal to the expected output given multiple
157 | baselines.
158 | '''
159 | with IntegratedGradients(model=model_simple, n_iter=1, baseline_fn=baseline_fn) as attributor:
160 | output, grad = attributor(data_linear)
161 |
162 | if baseline_fn is None:
163 | baseline_fn = torch.zeros_like
164 | expected_grad = model_simple_grad * (data_linear - baseline_fn(data_linear))
165 |
166 | assert torch.allclose(expected_grad, grad), 'Gradient mismatch for IntegratedGradients!'
167 | assert torch.allclose(model_simple_output, output), 'Output mismatch for IntegratedGradients!'
168 |
169 |
170 | def test_integrated_gradients_single_grad(data_simple):
171 | '''Test whether the gradient of IntegratedGradients matches.'''
172 | model = torch.nn.Softplus(beta=1.)
173 | data = data_simple.view_as(data_simple).requires_grad_()
174 | # this is d/dx (x * d/dx softplus(x)), i.e. the gradient of input times gradient of softplus
175 | target = torch.sigmoid(data) * (1 - torch.sigmoid(data)) * data + torch.sigmoid(data)
176 |
177 | with IntegratedGradients(model=model, n_iter=1, baseline_fn=torch.zeros_like, create_graph=True) as attributor:
178 | _, grad = attributor(data, torch.ones_like)
179 | gradgrad, = torch.autograd.grad(grad.sum(), data)
180 |
181 | assert torch.allclose(gradgrad, target), 'IntegratedGradients Attributor second order gradient mismatch!'
182 |
183 |
184 | def test_integrated_gradients_path(data_simple):
185 |     '''Test whether the steps of IntegratedGradients lie on the path from the baseline to the input, and whether
186 |     the attribution of the identity is correct.
187 | '''
188 | model = IdentityLogger()
189 |
190 | dims = tuple(range(1, data_simple.ndim))
191 | n_iter = 100
192 | with IntegratedGradients(model=model, n_iter=n_iter, attr_output=torch.ones_like) as attributor:
193 | _, grad = attributor(data_simple)
194 |
195 | assert len(model.tensors) == n_iter, 'IntegratedGradients iterations did not match n_iter!'
196 |
197 | data_simple_norm = data_simple / (data_simple ** 2).sum(dim=dims, keepdim=True) ** .5
198 | assert all(
199 | torch.allclose(step / (step ** 2).sum(dim=dims, keepdim=True) ** .5, data_simple_norm)
200 | for step in model.tensors
201 | ), 'IntegratedGradients segments do not lie on path!'
202 | assert torch.allclose(data_simple, grad), 'IntegratedGradients of identity is wrong!'
203 |
204 |
205 | @pytest.mark.parametrize('window,stride', zip([1, 2, 4, (1,), (2,), (4,)], [1, 2, 4, (1,), (2,), (4,)]))
206 | def test_occlusion_disjunct(data_simple, window, stride):
207 | '''Function to test whether the inputs used for disjunct occlusion windows are correct.'''
208 | model = IdentityLogger()
209 |
210 | # delete everything except the window
211 | occlusion_fn = partial(occlude_independent, fill_fn=torch.zeros_like, invert=False)
212 |
213 | with Occlusion(model=model, window=window, stride=stride, occlusion_fn=occlusion_fn) as attributor:
214 | attributor(data_simple)
215 |
216 |     # the final forward pass computes the full, unoccluded output, so it is omitted from the reconstruction
217 | reconstruct = sum(model.tensors[:-1])
218 | assert torch.allclose(data_simple, reconstruct), 'Disjunct occlusion does not sum to original input!'
219 |
220 |
221 | @pytest.mark.parametrize(
222 | 'fill_fn,invert', [
223 | (None, False),
224 | (torch.zeros_like, False),
225 | (torch.zeros_like, True),
226 | (torch.ones_like, True),
227 | ]
228 | )
229 | def test_occlusion_single(data_linear, model_simple, model_simple_output, grad_outputs_func, fill_fn, invert):
230 | '''Function to test whether the inputs used for a full occlusion window are correct.'''
231 | window, stride = [data_linear.shape] * 2
232 | if fill_fn is None:
233 | # setting when no occlusion_fn is supplied
234 | occlusion_fn = None
235 | fill_fn = torch.zeros_like
236 | else:
237 | occlusion_fn = partial(occlude_independent, fill_fn=fill_fn, invert=invert)
238 |
239 | identity_logger = IdentityLogger()
240 | model = torch.nn.Sequential(identity_logger, model_simple)
241 |
242 | with Occlusion(
243 | model=model,
244 | window=window,
245 | stride=stride,
246 | attr_output=grad_outputs_func,
247 | occlusion_fn=occlusion_fn,
248 | ) as attributor:
249 | output, score = attributor(data_linear)
250 |
251 | expected_occluded = fill_fn(data_linear) if invert else data_linear
252 | expected_output = model_simple(expected_occluded)
253 | expected_score = grad_outputs_func(expected_output).sum(
254 | tuple(range(1, expected_output.ndim))
255 | )[(slice(None),) + (None,) * (data_linear.ndim - 1)].expand_as(data_linear)
256 |
257 | assert len(identity_logger.tensors) == 2, 'Incorrect number of forward passes for Occlusion!'
258 | assert torch.allclose(identity_logger.tensors[0], expected_occluded), 'Occluded input mismatch!'
259 | assert torch.allclose(model_simple_output, output), 'Output mismatch for Occlusion!'
260 | assert torch.allclose(expected_score, score), 'Scores are incorrect for Occlusion!'
261 |
262 |
263 | @pytest.mark.parametrize('argument,container', product(
264 | ['window', 'stride'],
265 | ['monkey', {3}, ('you', 'are', 'breathtaking'), range(3), [3]]
266 | ))
267 | def test_occlusion_stride_window_typecheck(argument, container):
268 | '''Test whether Occlusion raises a TypeError on incorrect types for window and stride.'''
269 | with pytest.raises(TypeError):
270 | Occlusion(model=None, **{argument: container})
271 |
--------------------------------------------------------------------------------
/tests/test_canonizers.py:
--------------------------------------------------------------------------------
1 | '''Tests for canonizers'''
2 | from collections import OrderedDict
3 | from functools import partial
4 |
5 | import pytest
6 | import torch
7 | from torch.nn import Sequential
8 | from helpers import assert_identity_hook
9 |
10 | from zennit.canonizers import Canonizer, CompositeCanonizer
11 | from zennit.canonizers import SequentialMergeBatchNorm, NamedMergeBatchNorm, AttributeCanonizer
12 | from zennit.core import RemovableHandleList
13 | from zennit.types import BatchNorm
14 |
15 |
16 | def test_merge_batchnorm_consistency(module_linear, module_batchnorm, data_linear):
17 | '''Test whether the output of the merged batchnorm is consistent with its original output.'''
18 | output_linear_before = module_linear(data_linear)
19 | output_batchnorm_before = module_batchnorm(output_linear_before)
20 | canonizer = SequentialMergeBatchNorm()
21 |
22 | try:
23 | canonizer.register((module_linear,), module_batchnorm)
24 | output_linear_canonizer = module_linear(data_linear)
25 | output_batchnorm_canonizer = module_batchnorm(output_linear_canonizer)
26 | finally:
27 | canonizer.remove()
28 |
29 | output_linear_after = module_linear(data_linear)
30 | output_batchnorm_after = module_batchnorm(output_linear_after)
31 |
32 | assert all(torch.allclose(left, right, atol=1e-5) for left, right in [
33 | (output_linear_before, output_linear_after),
34 | (output_batchnorm_before, output_batchnorm_after),
35 | (output_batchnorm_before, output_linear_canonizer),
36 | (output_linear_canonizer, output_batchnorm_canonizer),
37 | ])
38 |
39 |
40 | @pytest.mark.parametrize('canonizer_fn', [
41 | SequentialMergeBatchNorm,
42 | partial(NamedMergeBatchNorm, [(['dense0'], 'bnorm0')]),
43 | ])
44 | def test_merge_batchnorm_apply(canonizer_fn, module_linear, module_batchnorm, data_linear):
45 | '''Test whether SequentialMergeBatchNorm merges BatchNorm modules correctly and keeps the output unchanged.'''
46 | model = Sequential(OrderedDict([
47 | ('dense0', module_linear),
48 | ('bnorm0', module_batchnorm)
49 | ]))
50 | output_before = model(data_linear)
51 |
52 | handles = RemovableHandleList(
53 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!'))
54 | for module in model.modules() if isinstance(module, BatchNorm)
55 | )
56 |
57 | canonizer = canonizer_fn()
58 |
59 | canonizer_handles = RemovableHandleList(canonizer.apply(model))
60 | try:
61 | output_canonizer = model(data_linear)
62 | finally:
63 | handles.remove()
64 | canonizer_handles.remove()
65 |
66 | handles = RemovableHandleList(
67 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!'))
68 | for module in model.modules() if isinstance(module, BatchNorm)
69 | )
70 |
71 | try:
72 | output_after = model(data_linear)
73 | finally:
74 | handles.remove()
75 |
76 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!'
77 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!'
78 |
79 |
80 | def test_attribute_canonizer(module_linear, data_linear):
81 |     '''Test whether AttributeCanonizer overwrites and restores a linear module's forward correctly.'''
82 | model = Sequential(OrderedDict([
83 | ('dense0', module_linear),
84 | ]))
85 | output_before = model(data_linear)
86 |
87 | modules = [module_linear]
88 | module_type = type(module_linear)
89 |
90 | assert all(
91 | module.forward == module_type.forward.__get__(module) for module in modules
92 | ), 'Model has its forward already overwritten!'
93 |
94 | def attribute_map(name, module):
95 | if module is module_linear:
96 | return {'forward': lambda x: module_type.forward.__get__(module)(x) * 2}
97 | return None
98 |
99 | canonizer = AttributeCanonizer(attribute_map)
100 |
101 | handles = RemovableHandleList(canonizer.apply(model))
102 | try:
103 | assert not any(
104 | module.forward == module_type.forward.__get__(module) for module in modules
105 | ), 'Model forward was not overwritten!'
106 | output_canonizer = model(data_linear)
107 | finally:
108 | handles.remove()
109 |
110 | output_after = model(data_linear)
111 |
112 | assert all(
113 | module.forward == module_type.forward.__get__(module) for module in modules
114 | ), 'Model forward was not restored!'
115 | assert torch.allclose(output_canonizer, output_before * 2, rtol=1e-5), 'Canonizer output mismatch after register!'
116 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!'
117 |
118 |
119 | def test_composite_canonizer():
120 | '''Test whether CompositeCanonizer correctly combines two AttributeCanonizer canonizers.'''
121 | module_vanilla = torch.nn.Module()
122 | model = torch.nn.Sequential(module_vanilla)
123 |
124 | canonizer = CompositeCanonizer([
125 | AttributeCanonizer(lambda name, module: {'_test_x': 13}),
126 | AttributeCanonizer(lambda name, module: {'_test_y': 13}),
127 | ])
128 |
129 | handles = RemovableHandleList(canonizer.apply(model))
130 | try:
131 | assert hasattr(module_vanilla, '_test_x'), 'Model attribute _test_x was not overwritten!'
132 | assert hasattr(module_vanilla, '_test_y'), 'Model attribute _test_y was not overwritten!'
133 | finally:
134 | handles.remove()
135 |
136 | assert not hasattr(module_vanilla, '_test_x'), 'Model attribute _test_x was not removed!'
137 | assert not hasattr(module_vanilla, '_test_y'), 'Model attribute _test_y was not removed!'
138 |
139 |
140 | def test_base_canonizer_cooperative_apply():
141 | '''Test whether Canonizer's apply method is cooperative.'''
142 |
143 | class DummyCanonizer(Canonizer):
144 | '''Class to test Canonizer's cooperative apply.'''
145 | def apply(self, root_module):
146 | '''Cooperative apply which appends a string 'dummy' to the result of the parent class.'''
147 | instances = super().apply(root_module)
148 | instances += ['dummy']
149 | return instances
150 |
151 | def register(self):
152 | '''No-op register for abstract method.'''
153 |
154 | def remove(self):
155 | '''No-op remove for abstract method.'''
156 |
157 | canonizer = DummyCanonizer()
158 | model = Sequential()
159 | instances = canonizer.apply(model)
160 | assert 'dummy' in instances, 'Unexpected canonizer instance list!'
161 |
--------------------------------------------------------------------------------
/tests/test_cmap.py:
--------------------------------------------------------------------------------
1 | '''Tests for ColorMap and CMSL.'''
2 | from typing import NamedTuple
3 | import pytest
4 | import numpy as np
5 |
6 | from zennit.cmap import ColorMap, LazyColorMapCache
7 |
8 |
9 | class CMapExample(NamedTuple):
10 | '''Named tuple for example color maps used in tests.'''
11 | source: str
12 | nodes: list
13 |
14 |
15 | CMAPS = [
16 | ('000,fff', [
17 | (0x00, (0x00, 0x00, 0x00)),
18 | (0xff, (0xff, 0xff, 0xff)),
19 | ]),
20 | ('fff,f00', [
21 | (0x00, (0xff, 0xff, 0xff)),
22 | (0xff, (0xff, 0x00, 0x00)),
23 | ]),
24 | ('fff,00f', [
25 | (0x00, (0xff, 0xff, 0xff)),
26 | (0xff, (0x00, 0x00, 0xff)),
27 | ]),
28 | ('000,f00,ff0,fff', [
29 | (0x00, (0x00, 0x00, 0x00)),
30 | (0x55, (0xff, 0x00, 0x00)),
31 | (0xaa, (0xff, 0xff, 0x00)),
32 | (0xff, (0xff, 0xff, 0xff)),
33 | ]),
34 | ('000,00f,0ff', [
35 | (0x00, (0x00, 0x00, 0x00)),
36 | (0x7f, (0x00, 0x00, 0xff)),
37 | (0xff, (0x00, 0xff, 0xff)),
38 | ]),
39 | ('0ff,00f,80:000,f00,ff0,fff', [
40 | (0x00, (0x00, 0xff, 0xff)),
41 | (0x40, (0x00, 0x00, 0xff)),
42 | (0x80, (0x00, 0x00, 0x00)),
43 | (0xaa, (0xff, 0x00, 0x00)),
44 | (0xd4, (0xff, 0xff, 0x00)),
45 | (0xff, (0xff, 0xff, 0xff)),
46 | ]),
47 | ('00f,80:fff,f00', [
48 | (0x00, (0x00, 0x00, 0xff)),
49 | (0x80, (0xff, 0xff, 0xff)),
50 | (0xff, (0xff, 0x00, 0x00)),
51 | ]),
52 | ('0055a4,80:fff,ef4135', [
53 | (0x00, (0x00, 0x55, 0xa4)),
54 | (0x80, (0xff, 0xff, 0xff)),
55 | (0xff, (0xef, 0x41, 0x35)),
56 | ]),
57 | ('0000d0,80:d0d0d0,d00000', [
58 | (0x00, (0x00, 0x00, 0xd0)),
59 | (0x80, (0xd0, 0xd0, 0xd0)),
60 | (0xff, (0xd0, 0x00, 0x00)),
61 | ]),
62 | ('00d0d0,80:d0d0d0,d000d0', [
63 | (0x00, (0x00, 0xd0, 0xd0)),
64 | (0x80, (0xd0, 0xd0, 0xd0)),
65 | (0xff, (0xd0, 0x00, 0xd0)),
66 | ]),
67 | ('00d000,80:d0d0d0,d000d0', [
68 | (0x00, (0x00, 0xd0, 0x00)),
69 | (0x80, (0xd0, 0xd0, 0xd0)),
70 | (0xff, (0xd0, 0x00, 0xd0)),
71 | ]),
72 | ('7:000, 9:ffffff', [
73 | (0x00, (0x00, 0x00, 0x00)),
74 | (0x77, (0x00, 0x00, 0x00)),
75 | (0x99, (0xff, 0xff, 0xff)),
76 | (0xff, (0xff, 0xff, 0xff)),
77 | ]),
78 | ]
79 |
80 |
81 | def interpolate(x, nodes):
82 | '''Interpolate from example color map nodes.'''
83 | xp_addr = np.array([node[0] for node in nodes], dtype=np.float64)
84 | fp_rgb = np.array([node[1] for node in nodes], dtype=np.float64).T
85 | return np.stack([np.interp(x, xp_addr, fp) for fp in fp_rgb], axis=-1).round(12).clip(0., 255.).astype(np.uint8)
86 |
87 |
88 | @pytest.fixture(scope='session', params=CMAPS)
89 | def cmap_example(request):
90 | '''Example color map fixture.'''
91 | return CMapExample(*request.param)
92 |
93 |
94 | @pytest.mark.parametrize('source_code', [
95 | 'this', 'fff', ',,,', '111:111:111', 'fffff,fffff', 'f,f', 'fffffffff', 'ff:', 'ff:fff,00:000'
96 | ])
97 | def test_color_map_wrong_syntax(source_code):
98 | '''Test whether different kinds of syntax errors cause a RuntimeError.'''
99 | with pytest.raises(RuntimeError):
100 | ColorMap(source_code)
101 |
102 |
103 | def test_color_map_nodes_call(cmap_example):
104 | '''Test if the color map nodes have the specified color when calling a ColorMap instance.'''
105 | cmap = ColorMap(cmap_example.source)
106 | input_addr = np.array([node[0] for node in cmap_example.nodes], dtype=np.float64)[None]
107 | expected_rgb = np.array([node[1] for node in cmap_example.nodes], dtype=np.uint8)[None]
108 | cmap_rgb = (cmap(input_addr / 255.) * 255.).round(12).clip(0., 255.).astype(np.uint8)
109 | assert np.allclose(expected_rgb, cmap_rgb)
110 |
111 |
112 | def test_color_map_nodes_palette(cmap_example):
113 | '''Test if the color map nodes have the specified color when using ColorMap.palette.'''
114 | cmap = ColorMap(cmap_example.source)
115 | input_addr = [node[0] for node in cmap_example.nodes]
116 | expected_rgb = np.array([node[1] for node in cmap_example.nodes], dtype=np.uint8)[None]
117 | palette = cmap.palette(level=1.)
118 | cmap_rgb = palette[input_addr]
119 | assert np.allclose(expected_rgb, cmap_rgb)
120 |
121 |
122 | def test_color_map_full_call(cmap_example):
123 | '''Test if the color map nodes have correctly interpolated colors when calling a ColorMap instance.'''
124 | cmap = ColorMap(cmap_example.source)
125 | input_addr = np.arange(256, dtype=np.uint8)
126 | expected_rgb = interpolate(input_addr, cmap_example.nodes)
127 | cmap_rgb = (cmap(input_addr / 255.) * 255.).round(12).clip(0., 255.).astype(np.uint8)
128 | assert np.allclose(expected_rgb, cmap_rgb)
129 |
130 |
131 | def test_color_map_full_palette(cmap_example):
132 | '''Test if the color map nodes have correctly interpolated colors when using ColorMap.palette.'''
133 | input_addr = np.arange(256, dtype=np.uint8)
134 | expected_palette = interpolate(input_addr, cmap_example.nodes)
135 | cmap = ColorMap(cmap_example.source)
136 | cmap_palette = cmap.palette(level=1.0)
137 | assert np.allclose(expected_palette, cmap_palette)
138 |
139 |
140 | def test_color_map_reassign_source_palette(cmap_example):
141 | '''Test if calling a ColorMap instance for which the source was changed produces correctly interpolated colors.'''
142 | cmap = ColorMap('fff,fff')
143 | cmap.source = cmap_example.source
144 |
145 | input_addr = np.arange(256, dtype=np.uint8)
146 | expected_palette = interpolate(input_addr, cmap_example.nodes)
147 | cmap_palette = cmap.palette(level=1.0)
148 | assert np.allclose(expected_palette, cmap_palette)
149 |
150 |
151 | def test_color_map_source_property(cmap_example):
152 | '''Test if the source property of a color map is equal to the specified source code.'''
153 | cmap = ColorMap(cmap_example.source)
154 | assert cmap.source == cmap_example.source, 'Mismatching source!'
155 |
156 |
157 | @pytest.fixture(scope='function')
158 | def lazy_cmap_cache():
159 |     '''Fixture providing a LazyColorMapCache with two entries.'''
160 | return LazyColorMapCache({
161 | 'gray': '000,fff',
162 | 'red': '100,f00',
163 | })
164 |
165 |
166 | class TestLazyColorMapCache:
167 | '''Tests for LazyColorMapCache.'''
168 | @staticmethod
169 | def test_missing(lazy_cmap_cache):
170 | '''Test whether accessing an unknown key causes a KeyError.'''
171 | with pytest.raises(KeyError):
172 | _ = lazy_cmap_cache['no such cmap']
173 |
174 | @staticmethod
175 | def test_get_item_uncompiled(lazy_cmap_cache):
176 | '''Test whether accessing an uncompiled entry compiles and returns the correct ColorMap.'''
177 | cmap = lazy_cmap_cache['red']
178 | assert isinstance(cmap, ColorMap)
179 | assert cmap.source == '100,f00'
180 |
181 | @staticmethod
182 | def test_get_item_cached(lazy_cmap_cache):
183 | '''Test whether accessing a previously compiled and cached entry returns the same ColorMap.'''
184 | cmaps = [
185 | lazy_cmap_cache['red'],
186 | lazy_cmap_cache['red'],
187 | ]
188 | assert cmaps[0] is cmaps[1]
189 |
190 | @staticmethod
191 | def test_set_item_existing(lazy_cmap_cache):
192 | '''Test whether setting an already existing, uncompiled entry and accessing it returns the correct ColorMap.'''
193 | lazy_cmap_cache['red'] = 'fff,f00'
194 | assert lazy_cmap_cache['red'].source == 'fff,f00'
195 |
196 | @staticmethod
197 | def test_set_item_new(lazy_cmap_cache):
198 | '''Test whether setting a new entry and accessing it returns the correct ColorMap.'''
199 | lazy_cmap_cache['blue'] = 'fff,00f'
200 | assert lazy_cmap_cache['blue'].source == 'fff,00f'
201 |
202 | @staticmethod
203 | def test_set_item_compiled(lazy_cmap_cache):
204 | '''Test whether setting an already existing, compiled entry and accessing it returns the same, modified
205 | ColorMap instance.
206 | '''
207 | original_cmap = lazy_cmap_cache['red']
208 | lazy_cmap_cache['red'] = 'fff,f00'
209 | assert lazy_cmap_cache['red'].source == 'fff,f00'
210 | assert original_cmap is lazy_cmap_cache['red']
211 |
212 | @staticmethod
213 | def test_del_item_uncompiled(lazy_cmap_cache):
214 | '''Test whether deleting an uncompiled entry correctly removes the entry.'''
215 | del lazy_cmap_cache['red']
216 | assert 'red' not in lazy_cmap_cache
217 |
218 | @staticmethod
219 | def test_del_item_compiled(lazy_cmap_cache):
220 | '''Test whether deleting a compiled entry correctly removes the entry.'''
221 | _ = lazy_cmap_cache['red']
222 | del lazy_cmap_cache['red']
223 | assert 'red' not in lazy_cmap_cache
224 |
225 | @staticmethod
226 | def test_iter(lazy_cmap_cache):
227 | '''Test whether iterating a LazyColorMapCache returns its keys.'''
228 |         assert list(lazy_cmap_cache) == ['gray', 'red']
229 |
230 | @staticmethod
231 | def test_len(lazy_cmap_cache):
232 | '''Test whether calling len on a LazyColorMapCache returns the correct length.'''
233 | assert len(lazy_cmap_cache) == 2
234 |
--------------------------------------------------------------------------------
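The color-map specification language tested above can be summarized from the CMAPS examples: a source code is a comma-separated list of 3- or 6-digit hex colors, each optionally prefixed with a 1- or 2-digit hex position ``XX:``; colors without explicit positions are spaced evenly. A minimal usage sketch, reusing one of the CMAPS entries:

    import numpy as np
    from zennit.cmap import ColorMap

    # blue at 0x00, white at the explicit center node 0x80, red at 0xff
    cmap = ColorMap('00f,80:fff,f00')

    values = np.linspace(0., 1., 5)  # e.g. normalized relevance in [0, 1]
    rgb = cmap(values)               # float RGB in [0, 1], shape (5, 3)
    pal = cmap.palette(level=1.0)    # uint8 palette of shape (256, 3)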
/tests/test_composites.py:
--------------------------------------------------------------------------------
1 | '''Tests for composites using torchvision models.'''
2 | from types import MethodType
3 | from itertools import product
4 |
5 | from zennit.core import BasicHook, collect_leaves
6 |
7 |
8 | def ishookcopy(hook, hook_template):
9 | '''Check if ``hook`` is a copy of ``hook_template`` (due to copying-mechanics of BasicHook).'''
10 | if isinstance(hook_template, BasicHook):
11 | return all(
12 | getattr(hook, key) == getattr(hook_template, key)
13 | for key in (
14 | 'input_modifiers',
15 | 'param_modifiers',
16 | 'output_modifiers',
17 | 'gradient_mapper',
18 | )
19 | )
20 | return isinstance(hook, type(hook_template))
21 |
22 |
23 | def check_hook_registered(module, hook_template):
24 |     '''Check whether a copy of ``hook_template`` has been registered to ``module``.'''
25 | return any(
26 | ishookcopy(hook_func.__self__, hook_template)
27 | for hook_func in getattr(module, '_forward_pre_hooks').values()
28 | if isinstance(hook_func, MethodType)
29 | )
30 |
31 |
32 | def verify_no_hooks(model):
33 | '''Verify that ``model`` has no registered forward (-pre) hooks.'''
34 | return not any(
35 | any(getattr(module, key) for key in ('_forward_hooks', '_forward_pre_hooks'))
36 | for module in model.modules()
37 | )
38 |
39 |
40 | def test_composite_layer_map_registered(layer_map_composite, model_vision):
41 | '''Tests whether the explicit LayerMapComposites register and unregister their rules correctly.'''
42 | errors = []
43 | with layer_map_composite.context(model_vision):
44 | for child in model_vision.modules():
45 | for dtype, hook_template in layer_map_composite.layer_map:
46 | if isinstance(child, dtype):
47 | if not check_hook_registered(child, hook_template):
48 | errors.append((
49 |                             '{} is of type {} but {} is not registered!',
50 | (child, dtype, hook_template),
51 | ))
52 | break
53 |
54 | if not verify_no_hooks(model_vision):
55 | errors.append(('Model has hooks registered after composite was removed!', ()))
56 |
57 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors))
58 |
59 |
60 | def test_composite_special_first_layer_map_registered(special_first_layer_map_composite, model_vision):
61 |     '''Tests whether the explicit SpecialFirstLayerMapComposites register and unregister their rules correctly.'''
62 | errors = []
63 | try:
64 | special_first_layer, special_first_template, special_first_dtype = next(
65 | (child, hook_template, dtype)
66 | for child, (dtype, hook_template) in product(
67 | collect_leaves(model_vision), special_first_layer_map_composite.first_map
68 | ) if isinstance(child, dtype)
69 | )
70 | except StopIteration:
71 | special_first_layer = None
72 |         special_first_template = special_first_dtype = None
73 |
74 | with special_first_layer_map_composite.context(model_vision):
75 | if special_first_layer is not None and not check_hook_registered(special_first_layer, special_first_template):
76 | errors.append((
77 | 'Special first layer {} is first of {} but {} is not registered!',
78 | (special_first_layer, special_first_dtype, special_first_template)
79 | ))
80 |
81 | children = (child for child in model_vision.modules() if child is not special_first_layer)
82 | for child in children:
83 | for dtype, hook_template in special_first_layer_map_composite.layer_map:
84 | if isinstance(child, dtype):
85 | if not check_hook_registered(child, hook_template):
86 | errors.append((
87 |                             '{} is of type {} but {} is not registered!',
88 | (child, dtype, hook_template),
89 | ))
90 | break
91 |
92 | if not verify_no_hooks(model_vision):
93 | errors.append(('Model has hooks registered after composite was removed!', ()))
94 |
95 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors))
96 |
97 |
98 | def test_composite_name_map_registered(name_map_composite, model_vision):
99 | '''Tests whether the constructed NameMapComposites register and unregister their rules correctly.'''
100 | errors = []
101 | with name_map_composite.context(model_vision):
102 | for name, child in model_vision.named_modules():
103 | for names, hook_template in name_map_composite.name_map:
104 | if name in names:
105 | if not check_hook_registered(child, hook_template):
106 | errors.append((
107 |                             '{} is in the name map for {}, but the hook is not registered!',
108 | (name, hook_template),
109 | ))
110 | break
111 |
112 | if not verify_no_hooks(model_vision):
113 | errors.append(('Model has hooks registered after composite was removed!', ()))
114 |
115 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors))
116 |
117 |
118 | def test_composite_mixed_registered(mixed_composite, model_vision):
119 | '''Tests whether the constructed MixedComposites register and unregister their rules correctly.'''
120 | errors = []
121 |
122 | name_map_composite, special_first_layer_map_composite = mixed_composite.composites
123 |
124 | try:
125 | special_first_layer, special_first_template, special_first_dtype = next(
126 | (child, hook_template, dtype)
127 | for child, (dtype, hook_template) in product(
128 | collect_leaves(model_vision), special_first_layer_map_composite.first_map
129 | ) if isinstance(child, dtype)
130 | )
131 | except StopIteration:
132 | special_first_layer = None
133 |         special_first_template = special_first_dtype = None
134 |
135 | with mixed_composite.context(model_vision):
136 | has_matched_first_layer = False
137 | for name, child in model_vision.named_modules():
138 | has_matched_name_map = False
139 | for names, hook_template in name_map_composite.name_map:
140 | if name in names:
141 | has_matched_name_map = True
142 | if not check_hook_registered(child, hook_template):
143 | errors.append((
144 |                         '{} is in the name map for {}, but the hook is not registered!',
145 | (name, hook_template),
146 | ))
147 | break
148 |
149 | if has_matched_name_map:
150 | continue
151 |
152 |             if not has_matched_first_layer and child is special_first_layer:
153 | has_matched_first_layer = True
154 | if not check_hook_registered(child, special_first_template):
155 | errors.append((
156 | 'Special first layer {} is first of {} but {} is not registered!',
157 | (special_first_layer, special_first_dtype, special_first_template)
158 | ))
159 | continue
160 |
161 | for dtype, hook_template in special_first_layer_map_composite.layer_map:
162 | if isinstance(child, dtype):
163 | if not check_hook_registered(child, hook_template):
164 | errors.append((
165 |                             '{} is of type {} but {} is not registered!',
166 | (child, dtype, hook_template),
167 | ))
168 | break
169 |
170 | if not verify_no_hooks(model_vision):
171 | errors.append(('Model has hooks registered after composite was removed!', ()))
172 |
173 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors))
174 |
175 |
176 | def test_composite_name_layer_map_registered(name_layer_map_composite, model_vision):
177 | '''Tests whether the constructed NameLayerMapComposites register and unregister their rules correctly.'''
178 | errors = []
179 |
180 | name_map_composite, layer_map_composite = name_layer_map_composite.composites
181 |
182 | with name_layer_map_composite.context(model_vision):
183 | for name, child in model_vision.named_modules():
184 |             has_matched_name_map = False
185 |             for names, hook_template in name_map_composite.name_map:
186 |                 if name in names:
187 | has_matched_name_map = True
188 | if not check_hook_registered(child, hook_template):
189 | errors.append((
190 | '{} is first in name map for {}, but is not registered!',
191 | (name, hook_template),
192 | ))
193 | break
194 |
195 | if has_matched_name_map:
196 | continue
197 |
198 | for dtype, hook_template in layer_map_composite.layer_map:
199 | if isinstance(child, dtype):
200 | if not check_hook_registered(child, hook_template):
201 | errors.append((
202 |                             '{} is of type {} but {} is not registered!',
203 | (child, dtype, hook_template),
204 | ))
205 | break
206 |
207 | if not verify_no_hooks(model_vision):
208 | errors.append(('Model has hooks registered after composite was removed!', ()))
209 |
210 | assert not errors, 'Errors:\n ' + '\n '.join(f'[{n}] ' + msg.format(*arg) for n, (msg, arg) in enumerate(errors))
211 |
--------------------------------------------------------------------------------
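A minimal sketch of the composite usage these tests verify; the toy model and the pairing of ``zennit.types.Linear`` with the Epsilon rule are arbitrary example choices:

    import torch
    from zennit.composites import LayerMapComposite
    from zennit.rules import Epsilon
    from zennit.types import Linear

    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))

    # layer_map pairs abstract module types with rule hook templates
    composite = LayerMapComposite(layer_map=[(Linear, Epsilon())])
    with composite.context(model) as modified_model:
        input = torch.randn(1, 8, requires_grad=True)
        output = modified_model(input)
        # the modified gradient of the output is the attribution
        relevance, = torch.autograd.grad(output, input, torch.ones_like(output))
    # leaving the context removes all hooks, as verify_no_hooks checks above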
/tests/test_image.py:
--------------------------------------------------------------------------------
1 | '''Tests for image operations.'''
2 | from typing import NamedTuple
3 | from itertools import product
4 | from io import BytesIO
5 |
6 | import pytest
7 | import numpy as np
8 | from PIL import Image
9 |
10 | from zennit.cmap import ColorMap
11 | from zennit.image import get_cmap, palette, imgify, gridify, imsave, interval_norm_bounds
12 |
13 |
14 | @pytest.fixture(scope='session', params=[
15 | 'gray', '000,fff', ColorMap('000,fff')
16 | ])
17 | def cmap_source(request):
18 | '''Fixture for multiple ways to specify the "gray" color map.'''
19 | return request.param
20 |
21 |
22 | class ImageTuple(NamedTuple):
23 | '''NamedTuple for image-array setups.'''
24 | grid: bool
25 | nchannels: list
26 | channel_front: bool
27 | width: int
28 | height: int
29 | array: np.ndarray
30 |
31 |
32 | @pytest.fixture(scope='session', params=product(
33 | [False, True],
34 | [1, 3],
35 | [False, True],
36 | [5, 10],
37 | [5, 10],
38 | [np.float64, np.uint8]
39 | ))
40 | def image_tuple(request):
41 | '''Image-array setups with varying size, type, number of channels, channel position and grid dimension.'''
42 | grid, nchannels, channel_front, width, height, dtype = request.param
43 |
44 | shape = (height, width)
45 | if channel_front:
46 | shape = (nchannels,) + shape
47 | else:
48 | shape = shape + (nchannels,)
49 | shape = (1,) * grid + shape
50 |
51 | return ImageTuple(
52 | grid,
53 | nchannels,
54 | channel_front,
55 | width,
56 | height,
57 | np.ones(shape, dtype=dtype),
58 | )
59 |
60 |
61 | def test_get_cmap(cmap_source):
62 | '''Test whether get_cmap handles its supported cmap types correctly.'''
63 | cmap = get_cmap(cmap_source)
64 | assert isinstance(cmap, ColorMap), 'Returned object is not a ColorMap!'
65 | assert cmap.source == '000,fff', 'Mismatch in source code of returned ColorMap instance.'
66 |
67 |
68 | def test_palette(cmap_source):
69 | '''Test whether palette returns the correct palette for all of its supported types.'''
70 | pal = palette(cmap_source)
71 | expected_pal = np.repeat(np.arange(256, dtype=np.uint8)[:, None], 3, axis=1)
72 | assert np.allclose(expected_pal, pal)
73 |
74 |
75 | @pytest.mark.parametrize('ndim', [1, 4, 5, 6])
76 | def test_imgify_wrong_dim(ndim):
77 | '''Test whether imgify fails for an unsupported number of dimensions.'''
78 | with pytest.raises(TypeError):
79 | imgify(np.zeros((1,) * ndim))
80 |
81 |
82 | @pytest.mark.parametrize('ndim', [1, 2, 5, 6])
83 | def test_imgify_grid_wrong_dim(ndim):
84 | '''Test whether imgify fails for an unsupported number of dimensions with grid=True.'''
85 | with pytest.raises(TypeError):
86 | imgify(np.zeros((1,) * ndim), grid=True)
87 |
88 |
89 | @pytest.mark.parametrize('grid', [[1], (1,), 1, [1, 1, 1], (1, 1, 1)])
90 | def test_imgify_grid_bad_grid(grid):
91 | '''Test whether imgify fails for unsupported grid values.'''
92 | with pytest.raises(TypeError):
93 | imgify(np.zeros((1,) * 4), grid=grid)
94 |
95 |
96 | @pytest.mark.parametrize('grid,nchannels', product([False, True], [2, 4]))
97 | def test_imgify_wrong_channels(grid, nchannels):
98 |     '''Test whether imgify fails for an unsupported number of channels, with and without grid.'''
99 | with pytest.raises(TypeError):
100 | imgify(np.zeros((1,) * grid + (2, 2, nchannels)), grid=grid)
101 |
102 |
103 | def test_imgify_container(image_tuple):
104 |     '''Test whether imgify produces the correct PIL Image container.'''
105 | image = imgify(image_tuple.array, grid=image_tuple.grid)
106 | assert image.mode == ('P' if image_tuple.nchannels == 1 else 'RGB'), 'Mode mismatch!'
107 | assert image.width == image_tuple.width, 'Width mismatch!'
108 | assert image.height == image_tuple.height, 'Height mismatch!'
109 |
110 |
111 | @pytest.mark.parametrize('vmin,vmax,symmetric', product([None, 1.], [None, 2.], [False, True]))
112 | def test_imgify_normalization(vmin, vmax, symmetric):
113 | '''Test whether imgify normalizes as expected.'''
114 | array = np.array([[-1., 0., 3.]])
115 |
116 | image = imgify(array, cmap='gray', vmin=vmin, vmax=vmax, symmetric=symmetric)
117 |
118 | if vmin is None:
119 | if symmetric:
120 | vmin = -np.abs(array).max()
121 | else:
122 | vmin = array.min()
123 | if vmax is None:
124 | if symmetric:
125 | vmax = np.abs(array).max()
126 | else:
127 | vmax = array.max()
128 |
129 | expected = (((array - vmin) / (vmax - vmin)) * 255.).clip(0, 255).astype(np.uint8)
130 |
131 | assert np.allclose(np.array(image), expected)
132 |
133 |
134 | @pytest.mark.parametrize('ndim', [1, 2, 5, 6])
135 | def test_gridify_wrong_dim(ndim):
136 |     '''Test whether gridify fails for an unsupported number of dimensions.'''
137 | with pytest.raises(TypeError):
138 | gridify(np.zeros((1,) * ndim))
139 |
140 |
141 | @pytest.mark.parametrize('channel_front,nchannels', product([False, True], [2, 4]))
142 | def test_gridify_wrong_channels(channel_front, nchannels):
143 | '''Test whether gridify fails for an unsupported number of channels in both channel positions.'''
144 | shape = (2, 2)
145 | if channel_front:
146 | shape = (nchannels,) + shape
147 | else:
148 | shape = shape + (nchannels,)
149 | shape = (1,) + shape
150 |
151 | with pytest.raises(TypeError):
152 | gridify(np.zeros(shape))
153 |
154 |
155 | @pytest.mark.parametrize('shape,expected_shape', [
156 | [(4, 2, 2, 3), (4, 4, 3)],
157 | [(4, 2, 2, 1), (4, 4, 1)],
158 | [(4, 2, 2), (4, 4, 1)],
159 | [(4, 3, 2, 2), (4, 4, 3)],
160 | [(4, 1, 2, 2), (4, 4, 1)],
161 | ])
162 | def test_gridify_shape(shape, expected_shape):
163 | '''Test whether gridify produces the correct shape.'''
164 | output = gridify(np.zeros(shape))
165 | assert expected_shape == output.shape
166 |
167 |
168 | @pytest.mark.parametrize('fill_value', [None, 0.])
169 | def test_gridify_fill(fill_value):
170 | '''Test whether gridify fills empty pixels with the correct value.'''
171 | array = np.array([[[[1.]]]])
172 | output = gridify(array, fill_value=fill_value, shape=(1, 2))
173 | expected_value = array.min() if fill_value is None else fill_value
174 | assert output[0, 1, 0] == expected_value
175 |
176 |
177 | @pytest.mark.parametrize('writer_params', [None, {}])
178 | def test_imsave_container(image_tuple, writer_params):
179 |     '''Test whether imsave produces a file which loads as the correct PIL Image container.'''
180 | fp = BytesIO()
181 | imsave(fp, image_tuple.array, grid=image_tuple.grid, format='png', writer_params=writer_params)
182 | fp.seek(0)
183 | image = Image.open(fp)
184 | assert image.mode == ('P' if image_tuple.nchannels == 1 else 'RGB'), 'Mode mismatch!'
185 | assert image.width == image_tuple.width, 'Width mismatch!'
186 | assert image.height == image_tuple.height, 'Height mismatch!'
187 |
188 |
189 | @pytest.mark.parametrize('symmetric,dim,expected_bounds', [
190 | (False, None, (np.array([[[[-1.]]], [[[0.]]]]), np.array([[[[-0.2]]], [[[0.8]]]]))),
191 | (False, (1, 2, 3), (np.array([[[[-1.]]], [[[0.]]]]), np.array([[[[-0.2]]], [[[0.8]]]]))),
192 | (False, (0, 1, 2, 3), (np.array([[[[-1.]]]]), np.array([[[[0.8]]]]))),
193 | (True, None, (np.array([[[[-1.]]], [[[-0.8]]]]), np.array([[[[1.]]], [[[0.8]]]]))),
194 | (True, (1, 2, 3), (np.array([[[[-1.]]], [[[-0.8]]]]), np.array([[[[1.]]], [[[0.8]]]]))),
195 | (True, (0, 1, 2, 3), (np.array([[[[-1.]]]]), np.array([[[[1.]]]]))),
196 | ])
197 | def test_interval_norm_bounds(symmetric, dim, expected_bounds):
198 | '''Test whether interval_norm_bounds computes the correct minimum and maximum values.'''
199 | array = np.linspace(-1., 0.8, 10).reshape((2, 1, 5, 1))
200 | bounds = interval_norm_bounds(array, symmetric=symmetric, dim=dim)
201 | assert np.allclose(expected_bounds, bounds)
202 |
--------------------------------------------------------------------------------
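For reference, a minimal sketch of the imgify/imsave path exercised above, assuming imsave forwards the same normalization keywords as imgify; the array shape and the built-in 'coldnhot' color map are arbitrary example choices:

    import numpy as np
    from zennit.image import imgify, imsave

    heatmap = np.random.randn(1, 32, 32)  # single channel, channel-first

    # imgify normalizes and color-maps to a PIL image in palette ('P') mode
    image = imgify(heatmap, cmap='coldnhot', symmetric=True)
    imsave('heatmap.png', heatmap, cmap='coldnhot', symmetric=True)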
/tests/test_rules.py:
--------------------------------------------------------------------------------
1 | '''Tests for various rules. Rules are re-implemented in a slower, less complicated way that closely follows the
2 | definitions in the original works, which makes them easier to verify and thus less likely to be wrong.
3 | '''
4 | from functools import wraps, partial
5 | from copy import deepcopy
6 |
7 | import pytest
8 | import torch
9 | from zennit.rules import Epsilon, ZPlus, AlphaBeta, Gamma, ZBox, Norm, WSquare, Flat
10 | from zennit.rules import Pass, ReLUDeconvNet, ReLUGuidedBackprop, ReLUBetaSmooth
11 | from zennit.rules import zero_bias as name_zero_bias
12 |
13 |
14 | def stabilize(input, epsilon=1e-6):
15 | '''Replicates zennit.core.stabilize for testing.'''
16 | return input + ((input == 0.).to(input) + input.sign()) * epsilon
17 |
18 |
19 | def as_matrix(module_linear, input, output):
20 | '''Get flat weight and bias using the jacobian.'''
21 | jac = torch.autograd.functional.jacobian(module_linear, input[None])
22 | weight = jac.reshape((output.numel(), input.numel()))
23 | bias = output.flatten() - weight @ input.flatten()
24 | return weight, bias
25 |
26 |
27 | RULES_LINEAR = []
28 | RULES_SIMPLE = []
29 |
30 |
31 | def replicates(target_list, replicated_func, **kwargs):
32 | '''Decorator to indicate a replication of a function for testing.'''
33 | def wrapper(func):
34 |         '''Append to ``target_list`` as partials, given ``kwargs``.'''
35 | target_list.append(
36 | pytest.param(
37 | (partial(replicated_func, **kwargs), partial(func, **kwargs)),
38 | id=replicated_func.__name__
39 | )
40 | )
41 | return func
42 | return wrapper
43 |
44 |
45 | def flat_module_params(func):
46 |     '''Decorator to copy the module and completely overwrite its parameters with ones (for rule_flat).'''
47 | @wraps(func)
48 | def wrapped(module_linear, *args, **kwargs):
49 | '''Make a deep copy of module_linear, fill all parameters inline with ones, and call func with the copy.'''
50 | module_copy = deepcopy(module_linear)
51 | for param in module_copy.parameters():
52 | param.requires_grad_(False).fill_(1.0)
53 | return func(module_copy, *args, **kwargs)
54 | return wrapped
55 |
56 |
57 | def matrix_form(func):
58 |     '''Decorator to wrap a function such that weight and bias are supplied in matrix form and input and output
59 |     are flattened appropriately.'''
60 | @wraps(func)
61 | def wrapped(module_linear, input, output, **kwargs):
62 | '''Get flat weight matrix and bias using the jacobian, flatten input and output, and pass arguments to func.'''
63 | weight, bias = as_matrix(module_linear, input[0], output[0])
64 | return func(
65 | weight,
66 | bias,
67 | input.flatten(start_dim=1),
68 | output.flatten(start_dim=1),
69 | **kwargs
70 | ).reshape(input.shape)
71 | return wrapped
72 |
73 |
74 | def with_grad(func):
75 | '''Decorator to wrap function such that the gradient is computed and passed to the function instead of module.'''
76 | @wraps(func)
77 | def wrapped(module, input, output, **kwargs):
78 | '''Get gradient and pass along input, output and keyword arguments to func.'''
79 | gradient, = torch.autograd.grad(module(input), input, output)
80 | return func(
81 | gradient,
82 | input,
83 | output,
84 | **kwargs
85 | )
86 | return wrapped
87 |
88 |
89 | def zero_bias(zero_params, bias):
90 | '''Return a tensor with zeros like ``bias`` if zero_params is equal to or contains the string ``'bias'``, otherwise
91 | return the unmodified tensor ``bias``.'''
92 | if zero_params is None:
93 | zero_params = []
94 | if bias is not None and (zero_params == 'bias' or 'bias' in zero_params):
95 | return torch.zeros_like(bias)
96 | return bias
97 |
98 |
99 | @replicates(RULES_LINEAR, Epsilon, epsilon=1e-6)
100 | @replicates(RULES_LINEAR, Epsilon, epsilon=1e-6, zero_params='bias')
101 | @replicates(RULES_LINEAR, Epsilon, epsilon=1.0)
102 | @replicates(RULES_LINEAR, Epsilon, epsilon=1.0, zero_params='bias')
103 | @replicates(RULES_LINEAR, Norm)
104 | @matrix_form
105 | def rule_epsilon(weight, bias, input, relevance, epsilon=1e-6, zero_params=None):
106 | '''Replicates the Epsilon rule.'''
107 | bias = zero_bias(zero_params, bias)
108 | return input * ((relevance / stabilize(input @ weight.t() + bias, epsilon)) @ weight)
109 |
110 |
111 | @replicates(RULES_LINEAR, ZPlus)
112 | @replicates(RULES_LINEAR, ZPlus, zero_params='bias')
113 | @matrix_form
114 | def rule_zplus(weight, bias, input, relevance, zero_params=None):
115 | '''Replicates the ZPlus rule.'''
116 | bias = zero_bias(zero_params, bias)
117 | wplus = weight.clamp(min=0)
118 | wminus = weight.clamp(max=0)
119 | xplus = input.clamp(min=0)
120 | xminus = input.clamp(max=0)
121 | zval = xplus @ wplus.t() + xminus @ wminus.t() + bias.clamp(min=0)
122 | rfac = relevance / stabilize(zval)
123 | return xplus * (rfac @ wplus) + xminus * (rfac @ wminus)
124 |
125 |
126 | @replicates(RULES_LINEAR, Gamma, gamma=0.25)
127 | @replicates(RULES_LINEAR, Gamma, gamma=0.25, zero_params='bias')
128 | @replicates(RULES_LINEAR, Gamma, gamma=0.5)
129 | @replicates(RULES_LINEAR, Gamma, gamma=0.5, zero_params='bias')
130 | @matrix_form
131 | def rule_gamma(weight, bias, input, relevance, gamma, zero_params=None):
132 | '''Replicates the Gamma rule.'''
133 | output = input @ weight.t() + bias
134 | bias = zero_bias(zero_params, bias)
135 | pinput = input.clamp(min=0)
136 | ninput = input.clamp(max=0)
137 | pwgamma = weight + weight.clamp(min=0) * gamma
138 | nwgamma = weight + weight.clamp(max=0) * gamma
139 | pbgamma = bias + bias.clamp(min=0) * gamma
140 | nbgamma = bias + bias.clamp(max=0) * gamma
141 |
142 | pgrad_out = (relevance / stabilize(pinput @ pwgamma.t() + ninput @ nwgamma.t() + pbgamma)) * (output > 0.)
143 | positive = pinput * (pgrad_out @ pwgamma) + ninput * (pgrad_out @ nwgamma)
144 |
145 | ngrad_out = (relevance / stabilize(pinput @ nwgamma.t() + ninput @ pwgamma.t() + nbgamma)) * (output < 0.)
146 | negative = pinput * (ngrad_out @ nwgamma) + ninput * (ngrad_out @ pwgamma)
147 |
148 | return positive + negative
149 |
150 |
151 | @replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0)
152 | @replicates(RULES_LINEAR, AlphaBeta, alpha=2.0, beta=1.0, zero_params='bias')
153 | @replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0)
154 | @replicates(RULES_LINEAR, AlphaBeta, alpha=1.0, beta=0.0, zero_params='bias')
155 | @matrix_form
156 | def rule_alpha_beta(weight, bias, input, relevance, alpha, beta, zero_params=None):
157 | '''Replicates the AlphaBeta rule.'''
158 | bias = zero_bias(zero_params, bias)
159 | wplus = weight.clamp(min=0)
160 | wminus = weight.clamp(max=0)
161 | xplus = input.clamp(min=0)
162 | xminus = input.clamp(max=0)
163 | zalpha = xplus @ wplus.t() + xminus @ wminus.t() + bias.clamp(min=0)
164 | zbeta = xplus @ wminus.t() + xminus @ wplus.t() + bias.clamp(max=0)
165 | ralpha = relevance / stabilize(zalpha)
166 | rbeta = relevance / stabilize(zbeta)
167 | result_alpha = xplus * (ralpha @ wplus) + xminus * (ralpha @ wminus)
168 | result_beta = xplus * (rbeta @ wminus) + xminus * (rbeta @ wplus)
169 | return alpha * result_alpha - beta * result_beta
170 |
171 |
172 | @replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0)
173 | @replicates(RULES_LINEAR, ZBox, low=-3.0, high=3.0, zero_params='bias')
174 | @matrix_form
175 | def rule_zbox(weight, bias, input, relevance, low, high, zero_params=None):
176 | '''Replicates the ZBox rule.'''
177 | wplus = weight.clamp(min=0)
178 | wminus = weight.clamp(max=0)
179 | low = torch.tensor(low).expand_as(input).to(input)
180 | high = torch.tensor(high).expand_as(input).to(input)
181 | zval = input @ weight.t() - low @ wplus.t() - high @ wminus.t()
182 | rfac = relevance / stabilize(zval)
183 | return input * (rfac @ weight) - low * (rfac @ wplus) - high * (rfac @ wminus)
184 |
185 |
186 | @replicates(RULES_LINEAR, WSquare)
187 | @replicates(RULES_LINEAR, WSquare, zero_params='bias')
188 | @matrix_form
189 | def rule_wsquare(weight, bias, input, relevance, zero_params=None):
190 | '''Replicates the WSquare rule.'''
191 | bias = zero_bias(zero_params, bias)
192 | wsquare = weight ** 2
193 | zval = torch.ones_like(input) @ wsquare.t() + bias ** 2
194 | rfac = relevance / stabilize(zval)
195 | return rfac @ wsquare
196 |
197 |
198 | @replicates(RULES_LINEAR, Flat)
199 | @flat_module_params
200 | @matrix_form
201 | def rule_flat(wflat, bias, input, relevance):
202 | '''Replicates the Flat rule.'''
203 | zval = torch.ones_like(input) @ wflat.t()
204 | rfac = relevance / stabilize(zval)
205 | return rfac @ wflat
206 |
207 |
208 | @replicates(RULES_SIMPLE, Pass)
209 | def rule_pass(module, input, relevance):
210 | '''Replicates the Pass rule.'''
211 | return relevance
212 |
213 |
214 | @replicates(RULES_SIMPLE, ReLUDeconvNet)
215 | def rule_relu_deconvnet(module, input, relevance):
216 | '''Replicates the ReLUDeconvNet rule.'''
217 | return relevance.clamp(min=0)
218 |
219 |
220 | @replicates(RULES_SIMPLE, ReLUGuidedBackprop)
221 | @with_grad
222 | def rule_relu_guidedbackprop(gradient, input, relevance):
223 | '''Replicates the ReLUGuidedBackprop rule.'''
224 | return gradient * (relevance > 0.)
225 |
226 |
227 | @replicates(RULES_SIMPLE, ReLUBetaSmooth, beta_smooth=10.)
228 | @replicates(RULES_SIMPLE, ReLUBetaSmooth, beta_smooth=1.)
229 | def rule_relu_beta_smooth(module, input, relevance, beta_smooth):
230 | '''Replicates the ReLUBetaSmooth rule.'''
231 | return relevance * torch.sigmoid(beta_smooth * input)
232 |
233 |
234 | @pytest.fixture(scope='session', params=RULES_LINEAR)
235 | def rule_pair_linear(request):
236 | '''Fixture to supply ``RULES_LINEAR``.'''
237 | return request.param
238 |
239 |
240 | @pytest.fixture(scope='session', params=RULES_SIMPLE)
241 | def rule_pair_simple(request):
242 | '''Fixture to supply ``RULES_SIMPLE``.'''
243 | return request.param
244 |
245 |
246 | def compare_rule_pair(module, data, rule_pair):
247 | '''Compare rules with their replicated versions.'''
248 | rule_hook, rule_replicated = rule_pair
249 |
250 | input = data.clone().requires_grad_()
251 | handle = rule_hook().register(module)
252 | try:
253 | output = module(input)
254 | relevance_hook, = torch.autograd.grad(output, input, grad_outputs=output)
255 | finally:
256 | handle.remove()
257 |
258 | relevance_replicated = rule_replicated(module, input, output)
259 |
260 | assert torch.allclose(relevance_hook, relevance_replicated, atol=1e-5)
261 |
262 |
263 | def test_linear_rule(module_linear, data_linear, rule_pair_linear):
264 | '''Test whether replicated and original implementations of rules for linear layers agree.'''
265 | compare_rule_pair(module_linear, data_linear, rule_pair_linear)
266 |
267 |
268 | def test_simple_rule(module_simple, data_simple, rule_pair_simple):
269 | '''Test whether replicated and original implementations of rules for simple layers agree.'''
270 | compare_rule_pair(module_simple, data_simple, rule_pair_simple)
271 |
272 |
273 | def test_alpha_beta_invalid_values():
274 | '''Test whether AlphaBeta raises ValueErrors for negative alpha/beta or when alpha - beta is not equal to 1.'''
275 | with pytest.raises(ValueError):
276 | AlphaBeta(alpha=-1.)
277 | with pytest.raises(ValueError):
278 | AlphaBeta(beta=-1.)
279 | with pytest.raises(ValueError):
280 | AlphaBeta(alpha=1., beta=1.)
281 |
282 |
283 | @pytest.mark.parametrize('params', [None, 'weight', ['weight'], 'bias', ['bias'], ['weight', 'bias']])
284 | def test_zero_bias(params):
285 | '''Test whether zero_bias correctly appends 'bias' to the zero_params list/str used for ParamMod.'''
286 | result = name_zero_bias(params)
287 | assert isinstance(result, list)
288 | assert 'bias' in result
289 |
--------------------------------------------------------------------------------
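The hook-registration pattern from compare_rule_pair, as a standalone sketch; the module and data are arbitrary examples:

    import torch
    from zennit.rules import Epsilon

    module = torch.nn.Linear(8, 4)
    input = torch.randn(2, 8).requires_grad_()

    handle = Epsilon(epsilon=1e-6).register(module)  # attach the rule's hooks
    try:
        output = module(input)
        # the rule modifies the gradient; seeding with the output itself
        # yields the relevance, as in the tests above
        relevance, = torch.autograd.grad(output, input, grad_outputs=output)
    finally:
        handle.remove()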
/tests/test_torchvision.py:
--------------------------------------------------------------------------------
1 | '''Tests for torchvision-model-specific canonizers.'''
2 | import pytest
3 | import torch
4 | from torchvision.models import vgg11_bn, resnet18, resnet50
5 | from torchvision.models.resnet import BasicBlock as ResNetBasicBlock, Bottleneck as ResNetBottleneck
6 | from helpers import assert_identity_hook, randomize_bnorm, nograd
7 |
8 | from zennit.core import Composite, RemovableHandleList
9 | from zennit.torchvision import VGGCanonizer, ResNetCanonizer
10 | from zennit.types import BatchNorm
11 |
12 |
13 | def test_vgg_canonizer(batchsize):
14 | '''Test whether VGGCanonizer merges BatchNorm modules correctly and keeps the output unchanged.'''
15 | model = randomize_bnorm(nograd(vgg11_bn().eval().to(torch.float64)))
16 | data = torch.randn((batchsize, 3, 224, 224), dtype=torch.float64)
17 | output_before = model(data)
18 |
19 | handles = RemovableHandleList(
20 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!'))
21 | for module in model.modules() if isinstance(module, BatchNorm)
22 | )
23 |
24 | canonizer = VGGCanonizer()
25 | composite = Composite(canonizers=[canonizer])
26 |
27 | try:
28 | composite.register(model)
29 | output_canonizer = model(data)
30 | finally:
31 | composite.remove()
32 | handles.remove()
33 |
34 | # this assumes the batch-norm is not initialized as the identity
35 | handles = RemovableHandleList(
36 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!'))
37 | for module in model.modules() if isinstance(module, BatchNorm)
38 | )
39 | try:
40 | output_after = model(data)
41 | finally:
42 | handles.remove()
43 |
44 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!'
45 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!'
46 |
47 |
48 | @pytest.mark.parametrize('model_fn,block_type', [
49 | (resnet18, ResNetBasicBlock),
50 | (resnet50, ResNetBottleneck),
51 | ])
52 | def test_resnet_canonizer(batchsize, model_fn, block_type):
53 | '''Test whether ResNetCanonizer overwrites and restores the Bottleneck/BasicBlock forward, merges BatchNorm modules
54 | correctly and keeps the output unchanged.
55 | '''
56 | model = randomize_bnorm(nograd(model_fn().eval().to(torch.float64)))
57 | data = torch.randn((batchsize, 3, 224, 224), dtype=torch.float64)
58 | blocks = [module for module in model.modules() if isinstance(module, block_type)]
59 |
60 | assert blocks, 'Model has no blocks!'
61 | assert all(
62 | block.forward == block_type.forward.__get__(block) for block in blocks
63 | ), 'Model has its forward already overwritten!'
64 |
65 | output_before = model(data)
66 |
67 | handles = RemovableHandleList(
68 | module.register_forward_hook(assert_identity_hook(True, 'BatchNorm was not merged!'))
69 | for module in model.modules() if isinstance(module, BatchNorm)
70 | )
71 |
72 | canonizer = ResNetCanonizer()
73 | composite = Composite(canonizers=[canonizer])
74 |
75 | try:
76 | composite.register(model)
77 | assert not any(
78 | block.forward == block_type.forward.__get__(block) for block in blocks
79 | ), 'Model forward was not overwritten!'
80 | output_canonizer = model(data)
81 | finally:
82 | composite.remove()
83 | handles.remove()
84 |
85 | # this assumes the batch-norm is not initialized as the identity
86 | handles = RemovableHandleList(
87 | module.register_forward_hook(assert_identity_hook(False, 'BatchNorm was not restored!'))
88 | for module in model.modules() if isinstance(module, BatchNorm)
89 | )
90 | try:
91 | output_after = model(data)
92 | finally:
93 | handles.remove()
94 |
95 | assert all(
96 | block.forward == block_type.forward.__get__(block) for block in blocks
97 | ), 'Model forward was not restored!'
98 | assert torch.allclose(output_canonizer, output_before, rtol=1e-5), 'Canonizer changed output after register!'
99 | assert torch.allclose(output_before, output_after, rtol=1e-5), 'Canonizer changed output after remove!'
100 |
--------------------------------------------------------------------------------
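In practice, these canonizers are passed to a composite instead of being registered manually. A minimal sketch, with EpsilonPlusFlat as an arbitrary choice of composite:

    import torch
    from torchvision.models import vgg11_bn
    from zennit.composites import EpsilonPlusFlat
    from zennit.torchvision import VGGCanonizer

    model = vgg11_bn().eval()
    data = torch.randn(1, 3, 224, 224, requires_grad=True)

    # the composite applies the canonizer on register and reverts it on remove
    composite = EpsilonPlusFlat(canonizers=[VGGCanonizer()])
    with composite.context(model) as modified_model:
        output = modified_model(data)
        # one-hot gradient seed for class 0 of the 1000 ImageNet classes
        relevance, = torch.autograd.grad(output, data, torch.eye(1000)[[0]])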
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | skip_missing_interpreters = true
3 | envlist = py37,py38,py39,pylint,flake8,docs
4 |
5 | [testenv]
6 | extras = tests
7 | setenv =
8 | COVERAGE_FILE = {toxworkdir}/.coverage.{envname}
9 | commands =
10 | pytest \
11 | --cov "{envsitepackagesdir}/zennit" \
12 | --cov-config "{toxinidir}/tox.ini" \
13 | {posargs:.}
14 |
15 | [testenv:coverage]
16 | deps =
17 | coverage
18 | setenv =
19 | COVERAGE_FILE = {toxworkdir}/.coverage
20 | skip_install = true
21 | commands =
22 | coverage combine
23 | coverage report -m
24 | depends = py37,py38,py39
25 |
26 | [testenv:docs]
27 | basepython = python3.9
28 | extras = docs
29 | commands =
30 | sphinx-build \
31 | --color \
32 | -W \
33 | --keep-going \
34 | -d "{toxinidir}/docs/doctree" \
35 | -b html \
36 | "{toxinidir}/docs/source" \
37 | "{toxinidir}/docs/build" \
38 | {posargs}
39 |
40 | [testenv:flake8]
41 | basepython = python3.9
42 | changedir = {toxinidir}
43 | deps =
44 | flake8
45 | commands =
46 | flake8 "{toxinidir}/src/zennit" "{toxinidir}/tests" {posargs}
47 |
48 |
49 | [testenv:pylint]
50 | basepython = python3.9
51 | deps =
52 | pylint
53 | pytest
54 | changedir = {toxinidir}
55 | commands =
56 | pylint --rcfile=pylintrc --output-format=parseable {toxinidir}/src/zennit {toxinidir}/tests
57 |
58 | [flake8]
59 | # R902 Too many instance attributes
60 | # R913 Too many arguments
61 | # R914 Too many local variables
62 | # W503 Line-break before binary operator
63 | ignore = R902,R913,R914,W503
64 |
65 | exclude=.venv,.git,.tox,build,dist,docs,*egg,*.ini
66 |
67 | max-line-length = 120
68 |
69 | [pytest]
70 | testpaths = tests
71 | addopts = -ra -l
72 |
73 | [coverage:run]
74 | parallel = true
75 | branch = true
76 |
77 | [coverage:report]
78 | skip_covered = true
79 | show_missing = true
80 |
81 | [coverage:paths]
82 | source = src/zennit
83 | */.tox/*/lib/python*/site-packages/zennit
84 | */src/zennit
85 |
--------------------------------------------------------------------------------
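To reproduce the CI runs locally, the environments above are invoked with tox, e.g. ``python -m tox -e py39`` for the unit tests on Python 3.9, ``python -m tox -e flake8,pylint`` for the style checks, and ``python -m tox -e docs`` to build the documentation. Arguments after ``--`` are forwarded to pytest via ``{posargs}``, e.g. ``python -m tox -e py39 -- tests/test_cmap.py``.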