├── .DS_Store
├── .coveragerc
├── .gitattributes
├── .github
├── PULL_REQUEST_TEMPLATE.md
└── workflows
│ └── unit_tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── images
├── .DS_Store
├── bert_ner_explainer.png
├── bert_qa_explainer.png
├── coverage.svg
├── distilbert_example.png
├── distilbert_example_negative.png
├── logotype@1920x.png
├── logotype@1920x_transparent.png
├── multilabel_example.png
├── pairwise_cross_encoder_example.png
├── tight@1920x.png
├── tight@1920x_transparent.png
├── transformers_interpret_A01.ai
├── vision
│ ├── alpha_scaling_sbs.png
│ ├── heatmap_sbs.png
│ ├── masked_image_sbs.png
│ └── overlay_sbs.png
├── zero_shot_example.png
└── zero_shot_example2.png
├── notebooks
├── .DS_Store
├── image_classification_explainer.ipynb
├── multiclass_classification_example.ipynb
└── ner_example.ipynb
├── poetry.lock
├── pyproject.toml
├── requirements.txt
├── setup.py
├── test
├── .DS_Store
├── .gitkeep
├── text
│ ├── test_explainer.py
│ ├── test_multilabel_classification_explainer.py
│ ├── test_question_answering_explainer.py
│ ├── test_sequence_classification_explainer.py
│ ├── test_token_classification_explainer.py
│ └── test_zero_shot_explainer.py
└── vision
│ └── test_image_classification.py
└── transformers_interpret
├── .DS_Store
├── __init__.py
├── attributions.py
├── errors.py
├── explainer.py
└── explainers
├── text
├── __init__.py
├── multilabel_classification.py
├── question_answering.py
├── sequence_classification.py
├── token_classification.py
└── zero_shot_classification.py
└── vision
├── attribution_types.py
└── image_classification.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/.DS_Store
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | # .coveragerc to control coverage.py
2 | # Point at this file with PYCHARM_COVERAGERC=/PATH/TO/THIS/.coveragerc
3 | # May require reading PYCHARM_COVERAGERC in run_coverage.py in intellij helpers
4 |
5 | [run]
6 | branch = True
7 |
8 | omit =
9 | */runners/*
10 | *docs*
11 | *stubs*
12 | *examples*
13 | *tests*
14 | */config/plugins/python/helpers/*
15 |
16 | [report]
17 | omit =
18 | */runners/*
19 | *docs*
20 | *stubs*
21 | *examples*
22 | *test*
23 |
24 |
25 | exclude_lines =
26 |
27 | # Don't complain if non-runnable code isn't run:
28 | if __name__ == .__main__.:
29 |
30 | raise NotImplementedError()
31 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # PR Description
4 |
5 |
6 |
7 |
8 | ## Motivation and Context
9 |
10 |
11 |
12 | References issue: # (ISSUE)
13 |
14 | ## Tests and Coverage
15 |
16 |
17 |
18 | ## Types of changes
19 |
20 |
21 |
22 | - [ ] Bug fix (non-breaking change which fixes an issue)
23 | - [ ] New feature (non-breaking change which adds functionality)
24 | - [ ] Docs (Added to or improved Transformers Interpret's documentation)
25 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
26 |
27 | ## Final Checklist:
28 |
29 |
30 |
31 | - [ ] My code follows the [code style](https://github.com/cdpierse/transformers-interpret/blob/master/CONTRIBUTING.md) of this project.
32 | - [ ] I have updated the documentation accordingly.
33 | - [ ] I have added tests to cover my changes.
34 | - [ ] All new and existing tests passed.
35 |
--------------------------------------------------------------------------------
/.github/workflows/unit_tests.yml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Unit Tests
3 | on: push
4 | jobs:
5 | tests:
6 | runs-on: ubuntu-20.04
7 | steps:
8 | - name: Checkout
9 | uses: actions/checkout@v1
10 |
11 | - name: Set up Python 3.7.13
12 | uses: actions/setup-python@v2
13 | with:
14 | python-version: 3.7.13
15 |
16 | - name: Install poetry
17 | run: |
18 | which python
19 | which pip
20 | pip install poetry
21 |
22 | - name: Install Python dependencies
23 | if: steps.cache-poetry.outputs.cache-hit != 'true'
24 | run: |
25 | poetry install
26 |
27 | - name: Run Unit tests
28 |
29 | run: |
30 | export PATH="$HOME/.pyenv/bin:$PATH"
31 | export PYTHONPATH="."
32 |
33 | poetry run pytest -s --cov=transformers_interpret/ --cov-report term-missing \
34 | test
35 |
36 | - name: Report coverage
37 | run: |
38 | export PATH="$HOME/.pyenv/bin:$PATH"
39 | poetry run coverage report --fail-under=50
40 | poetry run coverage html -d unit_htmlcov
41 |
42 | - uses: actions/upload-artifact@v2
43 | with:
44 | name: ti-unit-coverage
45 | path: unit_htmlcov/
46 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 | references.md
127 | test.html
128 | api.py
129 | .vscode
130 | notebooks/explore_captum.ipynb
131 | *.html
132 |
133 | # Pycharm
134 | .idea/
135 | release_script.py
136 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: https://github.com/psf/black
9 | rev: 22.3.0
10 | hooks:
11 | - id: black
12 | args: [--line-length=120, --target-version=py38]
13 | - repo: https://github.com/PyCQA/isort
14 | rev: 5.10.1
15 | hooks:
16 | - id: isort
17 | args: [-m, '3', --tc, --profile, black]
18 | - repo: https://github.com/myint/autoflake
19 | rev: v1.4
20 | hooks:
21 | - id: autoflake
22 | args: [--in-place, --remove-all-unused-imports]
23 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Pierse"
5 | given-names: "Charles"
6 | title: "Transformers Interpret"
7 | version: 0.5.2
8 | date-released: 2021-02-14
9 | url: "https://github.com/cdpierse/transformers-interpret"
10 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # 👐 Contributing to Transformers Interpret
2 |
3 | First off, thank you for even considering contributing to this package, every contribution big or small is greatly appreciated. Community contributions are what keep projects like this fueled and constantly improving, so a big thanks to you!
4 |
5 | Below are some sections detailing the guidelines we'd like you to follow to make your contribution as seamless as possible.
6 |
7 | - [Code of Conduct](#coc)
8 | - [Asking a Question and Discussions](#question)
9 | - [Issues, Bugs, and Feature Requests](#issue)
10 | - [Submission Guidelines](#submit)
11 | - [Code Style and Formatting](#code)
12 |
13 | ## 📜 Code of Conduct
14 |
15 | As contributors and maintainers of Transformers Interpret, we pledge to respect everyone who contributes by posting issues, updating documentation, submitting pull requests, providing feedback in comments, and any other activities.
16 |
17 | Communication within our community must be constructive and never resort to personal attacks, trolling, public or private harassment, insults, or other unprofessional conduct.
18 |
19 | We promise to extend courtesy and respect to everyone involved in this project regardless of gender, gender identity, sexual orientation, disability, age, race, ethnicity, religion, or level of experience. We expect anyone contributing to the project to do the same.
20 |
21 | If any member of the community violates this code of conduct, the maintainers of Transformers Interpret may take action, removing issues, comments, and PRs or blocking accounts as deemed appropriate.
22 |
23 | If you are subject to or witness unacceptable behavior, or have any other concerns, please drop us a line at [charlespierse@gmail.com](mailto:charlespierse@gmail.com)
24 |
25 | ## 🗣️ Got a Question ? Want to start a Discussion ?
26 |
27 | We would like to use [Github discussions](https://github.com/cdpierse/transformers-interpret/discussions) as the central hub for all community discussions, questions, and everything else in between. While Github discussions is a new service (as of 2021) we believe that it really helps keep this repo as one single source to find all relevant information. Our hope is that this page functions as a record of all the conversations that help contribute to the project's development.
28 |
29 | We also highly encourage general and theoretical discussion. Much of the Transformers Interpret package implements algorithms in the field of explainable AI (XAI) which is itself very nascent, so discussions around the topic are very welcome. Everyone has something to learn and to teach !
30 |
31 | If you are new to [Github discussions](https://github.com/cdpierse/transformers-interpret/discussions) it is a very similar experience to Stack Overflow with an added element of general discussion and discourse rather than solely being question and answer based.
32 |
33 | ## 🪲 Found an Issue or Bug?
34 |
35 | If you find a bug in the source code, you can help us by submitting an issue to our [Issues Page](https://github.com/cdpierse/transformers-interpret/issues). Even better you can submit a Pull Request if you have a fix.
36 |
37 | See [below](#submit) for some guidelines.
38 |
39 | ## ✉️ Submission Guidelines
40 |
41 | ### Submitting an Issue
42 |
43 | Before you submit your issue search the archive, maybe your question was already answered. If you feel like your issue is not specific and more of a general question about a design decision, or algorithm implementation maybe start a [discussion](https://github.com/cdpierse/transformers-interpret/discussions) instead, this helps keep the issues less cluttered and encourages more open ended conversation.
44 |
45 | If your issue appears to be a bug, and hasn't been reported, open a new issue.
46 | Help us to maximize the effort we can spend fixing issues and adding new
47 | features, by not reporting duplicate issues. Providing the following information will increase the
48 | chances of your issue being dealt with quickly:
49 |
50 | - **Describe the bug** - A clear and concise description of what the bug is.
51 | - **To Reproduce** - Steps to reproduce the behavior.
52 | - **Expected behavior** - A clear and concise description of what you expected to happen.
53 | - **Environment**
54 | - Transformers Interpret version
55 | - Python version
56 | - OS
57 | - **Suggest a Fix** - if you can't fix the bug yourself, perhaps you can point to what might be
58 | causing the problem (line of code or commit)
59 |
60 | When you submit a PR you will be presented with a PR template, please fill this in as best you can.
61 |
62 | ### Submitting a Pull Request
63 |
64 | Before you submit your pull request consider the following guidelines:
65 |
66 | - Search [GitHub](https://github.com/cdpierse/transformers-interpret/pulls) for an open or closed Pull Request
67 | that relates to your submission. You don't want to duplicate effort.
68 | - Create a fork of the repository and set up a remote that points to the original project:
69 |
70 | ```shell
71 | git remote add upstream git@github.com:cdpierse/transformers-interpret.git
72 | ```
73 |
74 | - Make your changes in a new git branch, based off master branch:
75 |
76 | ```shell
77 | git checkout -b my-fix-branch master
78 | ```
79 |
80 | - There are three typical branch name conventions we try to stick to, there are of course always exceptions to the rule but generally the branch name format is one of:
81 |
82 | - **fix/my-branch-name** - this is for fixes or patches
83 | - **docs/my-branch-name** - this is for a contribution to some form of documentation i.e. Readme etc
84 | - **feature/my-branch-name** - this is for new features or additions to the code base
85 |
86 | - Create your patch, **including appropriate test cases**.
87 | - **A note on tests**: This package is in an odd position in that it has some heavy external dependencies on the Transformers package, Pytorch, and Captum. In order to adequately test all the features of an explainer it must often be tested alongside a model, we do our best to cover most models and edge cases however this isn't always possible given that some models have their own quirks in implementations that break convention. We don't expect 100% test coverage but generally we'd like to keep it > 90%.
88 | - Follow our [Coding Rules](#rules).
89 | - Avoid checking in files that shouldn't be tracked (e.g `dist`, `build`, `.tmp`, `.idea`). We recommend using a [global](#global-gitignore) gitignore for this.
90 | - Before you commit please run the test suite and coverage and make sure all tests are passing.
91 |
92 | ```shell
93 | coverage run -m pytest -s -v && coverage report -m
94 | ```
95 | - Format your code appropriately:
96 | * This package uses [black](https://black.readthedocs.io/en/stable/) as its formatter. In order to format your code with black run ```black . ``` from the root of the package.
97 |
98 | - Commit your changes using a descriptive commit message.
99 |
100 | ```shell
101 | git commit -a
102 | ```
103 |
104 | Note: the optional commit `-a` command line option will automatically "add" and "rm" edited files.
105 |
106 | * Push your branch to GitHub:
107 |
108 | ```shell
109 | git push origin fix/my-fix-branch
110 | ```
111 |
112 | * In GitHub, send a pull request to `transformers-interpret:dev`.
113 | * If we suggest changes then:
114 |
115 | - Make the required updates.
116 | - Rebase your branch and force push to your GitHub repository (this will update your Pull Request):
117 |
118 | ```shell
119 | git rebase master -i
120 | git push origin fix/my-fix-branch -f
121 | ```
122 |
123 | That's it! Thank you for your contribution!
124 |
125 |
126 | ### After your pull request is merged
127 |
128 | After your pull request is merged, you can safely delete your branch and pull the changes
129 | from the main (upstream) repository:
130 |
131 | - Delete the remote branch on GitHub either through the GitHub web UI or your local shell as follows:
132 |
133 | ```shell
134 | git push origin --delete my-fix-branch
135 | ```
136 |
137 | - Check out the master branch:
138 |
139 | ```shell
140 | git checkout master -f
141 | ```
142 |
143 | - Delete the local branch:
144 |
145 | ```shell
146 | git branch -D my-fix-branch
147 | ```
148 |
149 | - Update your master with the latest upstream version:
150 |
151 | ```shell
152 | git pull --ff upstream master
153 | ```
154 |
155 | ## ✅ Coding Rules
156 |
157 | We generally follow the [Google Python style guide](http://google.github.io/styleguide/pyguide.html).
158 |
159 | ----
160 |
161 | *This guide was inspired by the [Firebase Web Quickstarts contribution guidelines](https://github.com/firebase/quickstart-js/blob/master/CONTRIBUTING.md) and [Databay](https://github.com/Voyz/databay/edit/master/CONTRIBUTING.md)*
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/.DS_Store
--------------------------------------------------------------------------------
/images/bert_ner_explainer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/bert_ner_explainer.png
--------------------------------------------------------------------------------
/images/bert_qa_explainer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/bert_qa_explainer.png
--------------------------------------------------------------------------------
/images/coverage.svg:
--------------------------------------------------------------------------------
1 |
2 |
22 |
--------------------------------------------------------------------------------
/images/distilbert_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/distilbert_example.png
--------------------------------------------------------------------------------
/images/distilbert_example_negative.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/distilbert_example_negative.png
--------------------------------------------------------------------------------
/images/logotype@1920x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/logotype@1920x.png
--------------------------------------------------------------------------------
/images/logotype@1920x_transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/logotype@1920x_transparent.png
--------------------------------------------------------------------------------
/images/multilabel_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/multilabel_example.png
--------------------------------------------------------------------------------
/images/pairwise_cross_encoder_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/pairwise_cross_encoder_example.png
--------------------------------------------------------------------------------
/images/tight@1920x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/tight@1920x.png
--------------------------------------------------------------------------------
/images/tight@1920x_transparent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/tight@1920x_transparent.png
--------------------------------------------------------------------------------
/images/transformers_interpret_A01.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/transformers_interpret_A01.ai
--------------------------------------------------------------------------------
/images/vision/alpha_scaling_sbs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/alpha_scaling_sbs.png
--------------------------------------------------------------------------------
/images/vision/heatmap_sbs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/heatmap_sbs.png
--------------------------------------------------------------------------------
/images/vision/masked_image_sbs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/masked_image_sbs.png
--------------------------------------------------------------------------------
/images/vision/overlay_sbs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/overlay_sbs.png
--------------------------------------------------------------------------------
/images/zero_shot_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/zero_shot_example.png
--------------------------------------------------------------------------------
/images/zero_shot_example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/zero_shot_example2.png
--------------------------------------------------------------------------------
/notebooks/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/notebooks/.DS_Store
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "transformers-interpret"
3 | version = "0.10.0"
4 | description = "Model explainability that works seamlessly with 🤗 transformers. Explain your transformers model in just 2 lines of code."
5 | readme = "README.md"
6 | classifiers = [
7 | "Development Status :: 3 - Alpha",
8 | "Intended Audience :: Developers",
9 | "Programming Language :: Python :: 3",
10 | "Programming Language :: Python :: 3.6",
11 | "Programming Language :: Python :: 3.7",
12 | "Programming Language :: Python :: 3.8",
13 | "Programming Language :: Python :: 3.9",
14 | "Programming Language :: Python :: 3 :: Only",
15 | "Topic :: Software Development :: Libraries :: Python Modules",
16 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
17 | "Topic :: Scientific/Engineering :: Information Analysis",
18 | "Topic :: Text Processing :: Linguistic",
19 | ]
authors = ["Charles Pierse <charlespierse@gmail.com>"]
21 |
22 | [tool.poetry.dependencies]
23 | python = ">=3.7,<4.0"
24 | captum = ">=0.3.1"
25 | transformers = ">=3.0.0"
26 | ipython = "^7.31.1"
27 |
28 | [tool.poetry.dev-dependencies]
29 | pre-commit = "^2.19.0"
30 | twine = "^4.0.1"
31 | pytest = "^5.4.2"
32 | black = {version = "^22.6.0", allow-prereleases = true}
33 | pytest-cov = "^3.0.0"
34 | jupyterlab = "^3.4.5"
35 |
36 |
37 | [build-system]
38 | requires = ["poetry-core>=1.0.0"]
39 | build-backend = "poetry.core.masonry.api"
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest==5.4.2
2 | captum==0.4.1
3 | transformers==4.15.0
4 | ipython==7.31.1
5 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# setup.py — packaging metadata for transformers-interpret.
# NOTE(review): version here (0.9.6) lags pyproject.toml (0.10.0) — confirm
# which file is authoritative before the next release.
from setuptools import find_packages, setup

# Reuse the README as the PyPI long description.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


setup(
    name="transformers-interpret",
    packages=find_packages(
        exclude=[
            "*.tests",
            "*.tests.*",
            "tests.*",
            "tests",
            "examples",
            "docs",
            "out",
            "dist",
            "media",
            "test",
        ]
    ),
    version="0.9.6",
    license="Apache-2.0",
    description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Charles Pierse",
    author_email="charlespierse@gmail.com",
    url="https://github.com/cdpierse/transformers-interpret",
    keywords=[
        "machine learning",
        "natural language processing",  # fixed typo: was "proessing"
        "explainability",
        "transformers",
        "model interpretability",
    ],
    install_requires=["transformers>=3.0.0", "captum>=0.3.1"],
    classifiers=[
        "Development Status :: 3 - Alpha",  # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
        "Intended Audience :: Developers",
        "Topic :: Software Development :: Libraries :: Python Modules",
        "Programming Language :: Python :: 3.8",
    ],
)
48 |
--------------------------------------------------------------------------------
/test/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/test/.DS_Store
--------------------------------------------------------------------------------
/test/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/test/.gitkeep
--------------------------------------------------------------------------------
/test/text/test_explainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from transformers import (
4 | AutoModelForCausalLM,
5 | AutoModelForMaskedLM,
6 | AutoModelForPreTraining,
7 | AutoTokenizer,
8 | PreTrainedModel,
9 | PreTrainedTokenizer,
10 | PreTrainedTokenizerFast,
11 | )
12 |
13 | from transformers_interpret import BaseExplainer
14 |
# Checkpoints shared by every test in this module, loaded once at import time.
# DistilBERT: a model without token_type_ids/position_ids forward parameters.
DISTILBERT_MODEL = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")
DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tiny GPT-2: exercises a tokenizer with no CLS/SEP special tokens.
GPT2_MODEL = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
GPT2_TOKENIZER = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

# Tiny BERT: exercises token-type and position embeddings.
BERT_MODEL = AutoModelForPreTraining.from_pretrained("lysandre/tiny-bert-random")
BERT_TOKENIZER = AutoTokenizer.from_pretrained("lysandre/tiny-bert-random")
23 |
24 |
class DummyExplainer(BaseExplainer):
    """Minimal concrete BaseExplainer used to exercise the shared base-class helpers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, text: str = None):
        """Tokenize ``text`` into ids without adding special tokens."""
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        return token_ids

    def decode(self, input_ids):
        """Map the first sequence of ids back into token strings."""
        first_sequence = input_ids[0]
        return self.tokenizer.convert_ids_to_tokens(first_sequence)

    @property
    def word_attributions(self):
        # Not exercised by these tests; required by the abstract interface.
        pass

    def _run(self):
        pass

    def _calculate_attributions(self):
        pass

    def _forward(self):
        pass
47 |
48 |
def test_explainer_init_distilbert():
    """An explainer built on DistilBERT exposes correct model/tokenizer metadata."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    assert isinstance(explainer.model, PreTrainedModel)
    # isinstance with a tuple of classes replaces the original bitwise `|`
    # of two isinstance() results, which was unidiomatic boolean arithmetic.
    assert isinstance(explainer.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer))
    assert explainer.model_prefix == DISTILBERT_MODEL.base_model_prefix
    assert explainer.device == DISTILBERT_MODEL.device

    # DistilBERT's forward() accepts neither position_ids nor token_type_ids.
    assert explainer.accepts_position_ids is False
    assert explainer.accepts_token_type_ids is False

    assert explainer.model.config.model_type == "distilbert"
    assert explainer.position_embeddings is not None
    assert explainer.word_embeddings is not None
    assert explainer.token_type_embeddings is None
65 |
66 |
def test_explainer_init_bert():
    """An explainer built on BERT exposes correct model/tokenizer metadata."""
    explainer = DummyExplainer(BERT_MODEL, BERT_TOKENIZER)
    assert isinstance(explainer.model, PreTrainedModel)
    # isinstance with a tuple of classes replaces the original bitwise `|`
    # of two isinstance() results, which was unidiomatic boolean arithmetic.
    assert isinstance(explainer.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer))
    assert explainer.model_prefix == BERT_MODEL.base_model_prefix
    assert explainer.device == BERT_MODEL.device

    # BERT's forward() accepts both position_ids and token_type_ids.
    assert explainer.accepts_position_ids is True
    assert explainer.accepts_token_type_ids is True

    assert explainer.model.config.model_type == "bert"
    assert explainer.position_embeddings is not None
    assert explainer.word_embeddings is not None
    assert explainer.token_type_embeddings is not None
83 |
84 |
def test_explainer_init_gpt2():
    """An explainer built on GPT-2 exposes correct model/tokenizer metadata."""
    explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER)
    assert isinstance(explainer.model, PreTrainedModel)
    # isinstance with a tuple of classes replaces the original bitwise `|`
    # of two isinstance() results, which was unidiomatic boolean arithmetic.
    assert isinstance(explainer.tokenizer, (PreTrainedTokenizerFast, PreTrainedTokenizer))
    assert explainer.model_prefix == GPT2_MODEL.base_model_prefix
    assert explainer.device == GPT2_MODEL.device

    assert explainer.accepts_position_ids is True
    assert explainer.accepts_token_type_ids is True

    assert explainer.model.config.model_type == "gpt2"
    assert explainer.position_embeddings is not None
    assert explainer.word_embeddings is not None
100 |
101 |
def test_explainer_init_cpu():
    """The explainer reports a CPU device when the model lives on CPU."""
    prior_device = DISTILBERT_MODEL.device
    try:
        DISTILBERT_MODEL.to("cpu")
        assert DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER).device.type == "cpu"
    finally:
        # Restore the module-level model so later tests see the original device.
        DISTILBERT_MODEL.to(prior_device)
110 |
111 |
def test_explainer_init_cuda():
    """The explainer reports a CUDA device when the model lives on the GPU."""
    # Guard clause instead of the original if/else nesting.
    if not torch.cuda.is_available():
        print("Cuda device not available to test. Skipping.")
        return
    prior_device = DISTILBERT_MODEL.device
    try:
        DISTILBERT_MODEL.to("cuda")
        explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
        assert explainer.device.type == "cuda"
    finally:
        # Restore the module-level model so later tests see the original device.
        DISTILBERT_MODEL.to(prior_device)
123 |
124 |
def test_explainer_make_input_reference_pair():
    """_make_input_reference_pair wraps the text in CLS/SEP on both sequences."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, ref_ids, text_len = explainer._make_input_reference_pair("this is a test string")

    for tensor in (input_ids, ref_ids):
        assert isinstance(tensor, Tensor)
    assert isinstance(text_len, int)

    # Both sequences are [CLS] + text tokens + [SEP].
    assert len(input_ids[0]) == len(ref_ids[0]) == text_len + 2
    assert ref_ids[0][0] == input_ids[0][0] == explainer.cls_token_id
    assert ref_ids[0][-1] == input_ids[0][-1] == explainer.sep_token_id
137 |
138 |
def test_explainer_make_input_reference_pair_gpt2():
    """GPT-2 adds no special tokens, so pair length equals the raw token count."""
    explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER)
    input_ids, ref_ids, text_len = explainer._make_input_reference_pair("this is a test string")

    assert isinstance(input_ids, Tensor) and isinstance(ref_ids, Tensor)
    assert isinstance(text_len, int)
    assert len(input_ids[0]) == len(ref_ids[0]) == text_len
147 |
148 |
def test_explainer_make_input_token_type_pair_no_sep_idx():
    """Without a sep index only position 0 gets token type 0; the rest get 1."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = explainer._make_input_reference_pair("this is a test string")
    token_type_ids, ref_token_type_ids = explainer._make_input_reference_token_type_pair(input_ids)

    # Reference token types start at zero.
    assert ref_token_type_ids[0][0] == 0
    assert token_type_ids[0][0] == 0
    assert all(v == 1 for v in token_type_ids[0][1:])
163 |
164 |
def test_explainer_make_input_token_type_pair_sep_idx():
    """Token types are 0 up to and including the sep index, then 1 afterwards."""
    sep_idx = 3
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = explainer._make_input_reference_pair("this is a test string")
    token_type_ids, ref_token_type_ids = explainer._make_input_reference_token_type_pair(input_ids, sep_idx)

    assert ref_token_type_ids[0][0] == 0
    assert all(v == 0 for v in token_type_ids[0][: sep_idx + 1])
    assert all(v == 1 for v in token_type_ids[0][sep_idx + 1 :])
179 |
180 |
def test_explainer_make_input_reference_position_id_pair():
    """Position ids count up from 0; reference position ids start at 0 as well."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = explainer._make_input_reference_pair("this is a test string")
    position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(input_ids)

    assert ref_position_ids[0][0] == 0
    assert all(pos == idx for idx, pos in enumerate(position_ids[0]))
189 |
190 |
def test_explainer_make_attention_mask():
    """The attention mask is all ones and matches the input length."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = explainer._make_input_reference_pair("this is a test string")
    mask = explainer._make_attention_mask(input_ids)

    assert len(mask[0]) == len(input_ids[0])
    assert all(v == 1 for v in mask[0])
198 |
199 |
def test_explainer_str():
    """__str__ renders the class names of the wrapped model and tokenizer."""
    explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    expected = (
        "DummyExplainer("
        f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__},"
        f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__}"
        ")"
    )
    assert expected == str(explainer)
207 |
--------------------------------------------------------------------------------
/test/text/test_multilabel_classification_explainer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
3 |
4 | from transformers_interpret import MultiLabelClassificationExplainer
5 |
# Checkpoints shared by every test in this module, loaded once at import time.
DISTILBERT_MODEL = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# A small 4-class news classifier keeps multi-label attribution runs fast.
BERT_MODEL = AutoModelForSequenceClassification.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification")
BERT_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification")
11 |
12 |
def test_multilabel_classification_explainer_init_distilbert():
    """A fresh explainer mirrors the model's label maps and holds no attributions."""
    explainer = MultiLabelClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    assert explainer.attribution_type == "lig"
    assert explainer.label2id == DISTILBERT_MODEL.config.label2id
    assert explainer.id2label == DISTILBERT_MODEL.config.id2label
    assert explainer.attributions is None
    assert explainer.labels == []
20 |
21 |
def test_multilabel_classification_explainer_init_bert():
    """Same initialisation contract as the DistilBERT case, on the BERT model."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    assert explainer.attribution_type == "lig"
    assert explainer.label2id == BERT_MODEL.config.label2id
    assert explainer.id2label == BERT_MODEL.config.id2label
    assert explainer.attributions is None
    assert explainer.labels == []
29 |
30 |
def test_multilabel_classification_explainer_word_attributes_is_dict():
    """Calling the explainer yields a dict of per-label attributions."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    result = explainer("this is a sample text")
    assert isinstance(result, dict)
35 |
36 |
def test_multilabel_classification_explainer_word_attributes_is_equals_label_length():
    """One attribution entry is produced for every configured label."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    result = explainer("this is a sample text")
    assert len(result) == len(BERT_MODEL.config.id2label)
41 |
42 |
def test_multilabel_classification_word_attributions_not_calculated_raises():
    """Accessing word_attributions before any call raises ValueError."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    with pytest.raises(ValueError):
        explainer.word_attributions
47 |
48 |
def test_multilabel_classification_viz():
    """Smoke test: visualize() runs after attributions are computed."""
    seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    # The return value was previously bound to an unused local (`wa`).
    seq_explainer("this is a sample text")
    seq_explainer.visualize()
53 |
54 |
@pytest.mark.skip(reason="Slow test")
def test_multilabel_classification_classification_custom_steps():
    """n_steps is forwarded to the attribution computation."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    explainer("I love you , I like you", n_steps=1)
60 |
61 |
@pytest.mark.skip(reason="Slow test")
def test_multilabel_classification_classification_internal_batch_size():
    """internal_batch_size is forwarded to the attribution computation."""
    explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    explainer("I love you , I like you", internal_batch_size=1)
67 |
--------------------------------------------------------------------------------
/test/text/test_question_answering_explainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from transformers import AutoModelForQuestionAnswering, AutoTokenizer
5 |
6 | from transformers_interpret import QuestionAnsweringExplainer
7 | from transformers_interpret.errors import (
8 | AttributionTypeNotSupportedError,
9 | InputIdsNotCalculatedError,
10 | )
11 |
# Checkpoints shared by every test in this module, loaded once at import time.
# NOTE(review): despite the DISTILBERT_ prefix these names load a tiny BERT
# checkpoint fine-tuned on SQuAD v2 — confirm before renaming.
DISTILBERT_QA_MODEL = AutoModelForQuestionAnswering.from_pretrained("mrm8488/bert-tiny-5-finetuned-squadv2")
DISTILBERT_QA_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-tiny-5-finetuned-squadv2")
14 |
15 |
def test_question_answering_explainer_init_distilbert():
    """A fresh QA explainer defaults to LIG attributions with empty state."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    assert explainer.attribution_type == "lig"
    assert explainer.attributions is None
    assert explainer.position == 0
21 |
22 |
def test_question_answering_explainer_init_attribution_type_error():
    """An unknown attribution_type is rejected at construction time."""
    with pytest.raises(AttributionTypeNotSupportedError):
        QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER, attribution_type="UNSUPPORTED")
30 |
31 |
def test_question_answering_encode():
    """encode() returns plain ids with no special tokens added."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    sample = "this is a sample of text to be encoded"
    ids = explainer.encode(sample)
    assert isinstance(ids, list)
    assert ids[0] != explainer.cls_token_id
    assert ids[-1] != explainer.sep_token_id
    # Subword tokenization can only grow the count relative to whitespace words.
    assert len(ids) >= len(sample.split(" "))
41 |
42 |
def test_question_answering_decode():
    """decode() round-trips a question/context pair including special tokens."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    question = "what is his name ?"
    context = "his name is bob"
    input_ids, _, _ = explainer._make_input_reference_pair(question, context)
    tokens = explainer.decode(input_ids)
    assert tokens[0] == explainer.tokenizer.cls_token
    assert tokens[-1] == explainer.tokenizer.sep_token
    assert " ".join(tokens[1:-1]) == question.lower() + " [SEP] " + context.lower()
52 |
53 |
def test_question_answering_word_attributions():
    """Attributions come back as a dict with aligned start/end entries."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    attributions = explainer("what is his name ?", "his name is bob")
    assert isinstance(attributions, dict)
    assert "start" in attributions
    assert "end" in attributions
    assert len(attributions["start"]) == len(attributions["end"])
63 |
64 |
def test_question_answering_word_attributions_input_ids_not_calculated():
    """word_attributions before any call raises ValueError."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    with pytest.raises(ValueError):
        explainer.word_attributions
70 |
71 |
def test_question_answering_start_pos():
    """The predicted answer span for this fixture starts at token index 10."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob")
    assert explainer.start_pos == 10
79 |
80 |
def test_question_answering_end_pos():
    """The predicted answer span for this fixture ends at token index 10."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob")
    assert explainer.end_pos == 10
88 |
89 |
def test_question_answering_start_pos_input_ids_not_calculated():
    """start_pos before any call raises InputIdsNotCalculatedError."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    with pytest.raises(InputIdsNotCalculatedError):
        explainer.start_pos
94 |
95 |
def test_question_answering_end_pos_input_ids_not_calculated():
    """end_pos before any call raises InputIdsNotCalculatedError."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    with pytest.raises(InputIdsNotCalculatedError):
        explainer.end_pos
100 |
101 |
def test_question_answering_predicted_answer():
    """The model extracts 'bob' (lower-cased) as the answer span."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob")
    assert explainer.predicted_answer == "bob"
109 |
110 |
def test_question_answering_predicted_answer_input_ids_not_calculated():
    """predicted_answer before any call raises InputIdsNotCalculatedError."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    with pytest.raises(InputIdsNotCalculatedError):
        explainer.predicted_answer
115 |
116 |
def test_question_answering_visualize():
    """Smoke test: visualize() with no path renders without raising."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob")
    explainer.visualize()
123 |
124 |
def test_question_answering_visualize_save():
    """visualize(path) writes an HTML report to the given path."""
    qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer_question = "what is his name ?"
    explainer_text = "his name is Bob"
    qa_explainer(explainer_question, explainer_text)

    html_filename = "./test/qa_test.html"
    try:
        qa_explainer.visualize(html_filename)
        assert os.path.exists(html_filename)
    finally:
        # Clean up even if the assertion fails so reruns start from a clean slate.
        if os.path.exists(html_filename):
            os.remove(html_filename)
135 |
136 |
def test_question_answering_visualize_save_append_html_file_ending():
    """A missing .html suffix is appended to the save path by visualize()."""
    qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer_question = "what is his name ?"
    explainer_text = "his name is Bob"
    qa_explainer(explainer_question, explainer_text)

    html_filename = "./test/qa_test"
    try:
        qa_explainer.visualize(html_filename)
        assert os.path.exists(html_filename + ".html")
    finally:
        # Clean up even if the assertion fails so reruns start from a clean slate.
        if os.path.exists(html_filename + ".html"):
            os.remove(html_filename + ".html")
147 |
148 |
def xtest_question_answering_custom_steps():
    """Disabled (xtest prefix prevents pytest collection): n_steps pass-through."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob", n_steps=1)
154 |
155 |
def xtest_question_answering_custom_internal_batch_size():
    """Disabled (xtest prefix prevents pytest collection): internal_batch_size pass-through."""
    explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER)
    explainer("what is his name ?", "his name is Bob", internal_batch_size=1)
161 |
--------------------------------------------------------------------------------
/test/text/test_sequence_classification_explainer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
3 |
4 | from transformers_interpret import (
5 | PairwiseSequenceClassificationExplainer,
6 | SequenceClassificationExplainer,
7 | )
8 | from transformers_interpret.errors import (
9 | AttributionTypeNotSupportedError,
10 | InputIdsNotCalculatedError,
11 | )
12 |
# Checkpoints shared by the tests in this module, loaded once at import time.
DISTILBERT_MODEL = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

BERT_MODEL = AutoModelForSequenceClassification.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification")
BERT_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification")

# Presumably used by the PairwiseSequenceClassificationExplainer tests, which
# are not visible in this chunk — verify before removing.
CROSS_ENCODER_MODEL = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-TinyBERT-L-2-v2")
CROSS_ENCODER_TOKENIZER = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-TinyBERT-L-2-v2")
21 |
22 |
def test_sequence_classification_explainer_init_distilbert():
    """A fresh explainer mirrors the model's label maps and holds no attributions."""
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    assert explainer.attribution_type == "lig"
    assert explainer.label2id == DISTILBERT_MODEL.config.label2id
    assert explainer.id2label == DISTILBERT_MODEL.config.id2label
    assert explainer.attributions is None
29 |
30 |
def test_sequence_classification_explainer_init_bert():
    """Same initialisation contract as the DistilBERT case, on the BERT model."""
    explainer = SequenceClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    assert explainer.attribution_type == "lig"
    assert explainer.label2id == BERT_MODEL.config.label2id
    assert explainer.id2label == BERT_MODEL.config.id2label
    assert explainer.attributions is None
37 |
38 |
def test_sequence_classification_explainer_init_attribution_type_error():
    """An unknown attribution_type is rejected at construction time."""
    with pytest.raises(AttributionTypeNotSupportedError):
        SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER, attribution_type="UNSUPPORTED")
46 |
47 |
def test_sequence_classification_explainer_init_with_custom_labels():
    """custom_labels replaces both label maps while preserving their size."""
    labels = ["label_1", "label_2"]
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=labels)
    assert len(explainer.id2label) == len(labels)
    assert len(explainer.label2id) == len(labels)
    # Every mapped name must come from the supplied custom labels.
    assert all(v in labels for v in explainer.id2label.values())
    assert all(k in labels for k in explainer.label2id)
55 |
56 |
def test_sequence_classification_explainer_init_custom_labels_size_error():
    """Supplying fewer custom labels than the model's classes raises ValueError."""
    with pytest.raises(ValueError):
        SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["few_labels"])
60 |
61 |
def test_sequence_classification_encode():
    """encode() returns plain ids with no special tokens added."""
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    sample = "this is a sample of text to be encoded"
    ids = explainer.encode(sample)
    assert isinstance(ids, list)
    assert ids[0] != explainer.cls_token_id
    assert ids[-1] != explainer.sep_token_id
    # Subword tokenization can only grow the count relative to whitespace words.
    assert len(ids) >= len(sample.split(" "))
71 |
72 |
def test_sequence_classification_decode():
    """decode() wraps the lower-cased text in CLS/SEP tokens."""
    text = "I love you , I hate you"
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = explainer._make_input_reference_pair(text)
    tokens = explainer.decode(input_ids)
    assert tokens[0] == explainer.tokenizer.cls_token
    assert tokens[-1] == explainer.tokenizer.sep_token
    assert " ".join(tokens[1:-1]) == text.lower()
81 |
82 |
def test_sequence_classification_run_text_given():
    """_run returns per-token attributions covering CLS, text tokens, and SEP."""
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    word_attributions = explainer._run("I love you, I just love you")
    assert isinstance(word_attributions, list)

    actual_tokens = [token for token, _ in word_attributions]
    expected_tokens = "[CLS] i love you , i just love you [SEP]".split()
    assert actual_tokens == expected_tokens
102 |
103 |
def test_sequence_classification_explain_on_cls_index():
    """Explaining a non-predicted index keeps prediction and selection distinct."""
    explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    explainer._run("I love you , I like you", index=0)
    assert explainer.predicted_class_index == 1
    assert explainer.predicted_class_index != explainer.selected_index
    assert explainer.predicted_class_name != explainer.id2label[explainer.selected_index]
    assert explainer.predicted_class_name != "NEGATIVE"
    assert explainer.predicted_class_name == "POSITIVE"
113 |
114 |
def test_sequence_classification_explain_position_embeddings():
    """On BERT, position-embedding attributions differ from word-embedding ones."""
    text = "I love you , I like you"
    explainer = SequenceClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    from_positions = explainer(text, embedding_type=1)
    from_words = explainer(text, embedding_type=0)

    assert from_positions != from_words
122 |
123 |
124 | def test_sequence_classification_explain_position_embeddings_not_available():
125 | explainer_string = "I love you , I like you"
126 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
127 | pos_attributions = seq_explainer(explainer_string, embedding_type=1)
128 | word_attributions = seq_explainer(explainer_string, embedding_type=0)
129 |
130 | assert pos_attributions == word_attributions
131 |
132 |
133 | def test_sequence_classification_explain_embedding_incorrect_value():
134 | explainer_string = "I love you , I like you"
135 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
136 |
137 | word_attributions = seq_explainer(explainer_string, embedding_type=0)
138 | incorrect_word_attributions = seq_explainer(explainer_string, embedding_type=-42)
139 |
140 | assert incorrect_word_attributions == word_attributions
141 |
142 |
143 | def test_sequence_classification_predicted_class_name_no_id2label_defaults_idx():
144 | explainer_string = "I love you , I like you"
145 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
146 | seq_explainer.id2label = {"test": "value"}
147 | seq_explainer._run(explainer_string)
148 | assert seq_explainer.predicted_class_name == 1
149 |
150 |
151 | def test_sequence_classification_explain_on_cls_name():
152 | explainer_string = "I love you , I like you"
153 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
154 | seq_explainer._run(explainer_string, class_name="NEGATIVE")
155 | assert seq_explainer.predicted_class_index == 1
156 | assert seq_explainer.predicted_class_index != seq_explainer.selected_index
157 | assert seq_explainer.predicted_class_name != seq_explainer.id2label[seq_explainer.selected_index]
158 | assert seq_explainer.predicted_class_name != "NEGATIVE"
159 | assert seq_explainer.predicted_class_name == "POSITIVE"
160 |
161 |
162 | def test_sequence_classification_explain_on_cls_name_with_custom_labels():
163 |     explainer_string = "I love you , I like you"
164 |     seq_explainer = SequenceClassificationExplainer(
165 |         DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["sad", "happy"]
166 |     )
167 |     seq_explainer._run(explainer_string, class_name="sad")  # custom labels replace the model's own id2label
168 |     assert seq_explainer.predicted_class_index == 1
169 |     assert seq_explainer.predicted_class_index != seq_explainer.selected_index
170 |     assert seq_explainer.predicted_class_name != seq_explainer.id2label[seq_explainer.selected_index]
171 |     assert seq_explainer.predicted_class_name != "sad"
172 |     assert seq_explainer.predicted_class_name == "happy"
173 |
174 |
175 | def test_sequence_classification_explain_on_cls_name_not_in_dict():
176 |     explainer_string = "I love you , I like you"
177 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
178 |     seq_explainer._run(explainer_string, class_name="UNKNOWN")  # unknown name -> falls back to the predicted class
179 |     assert seq_explainer.selected_index == 1
180 |     assert seq_explainer.predicted_class_index == 1
181 |
182 |
183 | def test_sequence_classification_explain_raises_on_input_ids_not_calculated():
184 |     with pytest.raises(InputIdsNotCalculatedError):
185 |         seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
186 |         seq_explainer.predicted_class_index  # property access before any explain call
187 |
188 |
189 | def test_sequence_classification_word_attributions():
190 |     explainer_string = "I love you , I like you"
191 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
192 |     seq_explainer(explainer_string)
193 |     assert isinstance(seq_explainer.word_attributions, list)
194 |     for element in seq_explainer.word_attributions:
195 |         assert isinstance(element, tuple)  # (token, score) pairs
196 |
197 |
198 | def test_sequence_classification_word_attributions_not_calculated_raises():
199 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
200 |     with pytest.raises(ValueError):
201 |         seq_explainer.word_attributions  # accessed before any call
202 |
203 |
204 | def test_sequence_classification_explainer_str():
205 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
206 |     s = "SequenceClassificationExplainer("
207 |     s += f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__},"
208 |     s += f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__},"
209 |     s += "\n\tattribution_type='lig',"
210 |     s += ")"
211 |     assert s == seq_explainer.__str__()
212 |
213 |
214 | def test_sequence_classification_viz():
215 |     explainer_string = "I love you , I like you"
216 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
217 |     seq_explainer(explainer_string)
218 |     seq_explainer.visualize()  # smoke test: should not raise
219 |
220 |
221 | def sequence_classification_custom_steps():  # NOTE(review): missing "test_" prefix -> pytest never collects this; confirm intentional
222 |     explainer_string = "I love you , I like you"
223 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
224 |     seq_explainer(explainer_string, n_steps=1)
225 |
226 |
227 | def sequence_classification_internal_batch_size():  # NOTE(review): missing "test_" prefix -> pytest never collects this; confirm intentional
228 |     explainer_string = "I love you , I like you"
229 |     seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
230 |     seq_explainer(explainer_string, internal_batch_size=1)
231 |
232 |
233 | def test_pairwise_sequence_classification():
234 |     string1 = "How many people live in berlin?"
235 |     string2 = "there are 1000000 people living in berlin"
236 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
237 |
238 |     attr = explainer(string1, string2)  # cross-encoder takes the pair as two arguments
239 |     assert explainer.text1 == string1
240 |     assert explainer.text2 == string2
241 |     assert attr
242 |
243 |
244 | def test_pairwise_sequence_classification_flip_attribute_sign():
245 |     string1 = "How many people live in berlin?"
246 |     string2 = "this string is not related to the question."
247 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
248 |
249 |     original_sign_attr = explainer(string1, string2)
250 |     flipped_sign_attr = explainer(string1, string2, flip_sign=True)
251 |
252 |     for flipped_wa, original_wa in zip(flipped_sign_attr, original_sign_attr):
253 |         assert flipped_wa[1] == -original_wa[1]  # score negated per token, tokens unchanged
254 |
255 |
256 | def test_pairwise_sequence_classification_viz():
257 |     string1 = "How many people live in berlin?"
258 |     string2 = "there are 1000000 people living in berlin"
259 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
260 |
261 |     explainer(string1, string2)
262 |     explainer.visualize()  # smoke test: should not raise
263 |
264 |
265 | def test_pairwise_sequence_classification_custom_steps():
266 |     string1 = "How many people live in berlin?"
267 |     string2 = "there are 1000000 people living in berlin"
268 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
269 |
270 |     explainer(string1, string2, n_steps=1)  # n_steps=1 keeps the integration cheap
271 |
272 |
273 | def test_pairwise_sequence_classification_internal_batch_size():
274 |     string1 = "How many people live in berlin?"
275 |     string2 = "there are 1000000 people living in berlin"
276 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
277 |
278 |     explainer(string1, string2, internal_batch_size=1)
279 |
280 |
281 | def test_pairwise_sequence_classification_position_embeddings():
282 |     string1 = "How many people live in berlin?"
283 |     string2 = "there are 1000000 people living in berlin"
284 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
285 |
286 |     explainer(string1, string2, embedding_type=1)  # 1 -> position embeddings
287 |
288 |
289 | def test_pairwise_sequence_classification_position_embeddings_not_accepted():
290 |     string1 = "How many people live in berlin?"
291 |     string2 = "there are 1000000 people living in berlin"
292 |     explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER)
293 |     explainer.accepts_position_ids = False  # simulate a model without position-id support
294 |
295 |     explainer(string1, string2, embedding_type=1)
296 |
--------------------------------------------------------------------------------
/test/text/test_token_classification_explainer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from transformers import AutoModelForTokenClassification, AutoTokenizer
3 |
4 | from transformers_interpret import TokenClassificationExplainer
5 | from transformers_interpret.errors import (
6 | AttributionTypeNotSupportedError,
7 | InputIdsNotCalculatedError,
8 | )
9 |
10 | DISTILBERT_MODEL = AutoModelForTokenClassification.from_pretrained(  # shared fixtures, downloaded once at import time
11 |     "elastic/distilbert-base-cased-finetuned-conll03-english"
12 | )
13 | DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("elastic/distilbert-base-cased-finetuned-conll03-english")
14 |
15 |
16 | BERT_MODEL = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
17 | BERT_TOKENIZER = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
18 |
19 |
20 | def test_token_classification_explainer_init_distilbert():
21 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
22 |     assert ner_explainer.attribution_type == "lig"  # layer integrated gradients is the default
23 |     assert ner_explainer.label2id == DISTILBERT_MODEL.config.label2id
24 |     assert ner_explainer.id2label == DISTILBERT_MODEL.config.id2label
25 |     assert ner_explainer.attributions is None  # nothing computed until the explainer is called
26 |
27 |
28 | def test_token_classification_explainer_init_bert():
29 |     ner_explainer = TokenClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
30 |     assert ner_explainer.attribution_type == "lig"
31 |     assert ner_explainer.label2id == BERT_MODEL.config.label2id
32 |     assert ner_explainer.id2label == BERT_MODEL.config.id2label
33 |     assert ner_explainer.attributions is None
34 |
35 |
36 | def test_token_classification_explainer_init_attribution_type_error():
37 |     with pytest.raises(AttributionTypeNotSupportedError):
38 |         TokenClassificationExplainer(
39 |             DISTILBERT_MODEL,
40 |             DISTILBERT_TOKENIZER,
41 |             attribution_type="UNSUPPORTED",
42 |         )
43 |
44 |
45 | def test_token_classification_selected_indexes_only_ignored_indexes():
46 |     explainer_string = "We visited Paris during the weekend, where Emmanuel Macron lives."
47 |     expected_all_indexes = list(range(15))  # full token-position range for this sentence
48 |     indexes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13]
49 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
50 |
51 |     word_attributions = ner_explainer(explainer_string, ignored_indexes=indexes)
52 |
53 |     assert len(ner_explainer._selected_indexes) == (len(expected_all_indexes) - len(indexes))
54 |
55 |     for index in ner_explainer._selected_indexes:
56 |         assert index in expected_all_indexes
57 |         assert index not in indexes  # an ignored position is never selected
58 |
59 |
60 | def test_token_classification_selected_indexes_only_ignored_labels():
61 |     ignored_labels = ["O", "I-LOC", "B-LOC"]
62 |     indexes = [8, 9, 10]  # positions expected to remain after the labels above are filtered out
63 |     explainer_string = "We visited Paris last weekend, where Emmanuel Macron lives."
64 |
65 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
66 |
67 |     word_attributions = ner_explainer(explainer_string, ignored_labels=ignored_labels)
68 |
69 |     assert len(indexes) == len(ner_explainer._selected_indexes)
70 |
71 |     for index in ner_explainer._selected_indexes:
72 |         assert index in indexes
73 |
74 |
75 | def test_token_classification_selected_indexes_all():
76 | explainer_string = "We visited Paris during the weekend, where Emmanuel Macron lives."
77 | expected_all_indexes = list(range(15))
78 | ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
79 |
80 | word_attributions = ner_explainer(explainer_string)
81 |
82 | assert len(ner_explainer._selected_indexes) == ner_explainer.input_ids.shape[1]
83 |
84 | for i, index in enumerate(ner_explainer._selected_indexes):
85 | assert i == index
86 |
87 |
88 | def test_token_classification_selected_indexes_ignored_indexes_and_labels():
89 |     ignored_labels = ["O", "I-PER", "B-PER"]
90 |     ignored_indexes = [4, 5, 6]
91 |     explainer_string = "We visited Paris last weekend"
92 |     selected_indexes = [3]  # this model erroneously classifies '[SEP]' as a location
93 |
94 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
95 |     word_attributions = ner_explainer(explainer_string, ignored_indexes=ignored_indexes, ignored_labels=ignored_labels)
96 |
97 |     assert len(selected_indexes) == len(ner_explainer._selected_indexes)
98 |
99 |     for i, index in enumerate(ner_explainer._selected_indexes):
100 |         assert selected_indexes[i] == index
101 |
102 |
103 | def test_token_classification_encode():
104 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
105 |
106 |     _input = "this is a sample of text to be encoded"
107 |     tokens = ner_explainer.encode(_input)
108 |     assert isinstance(tokens, list)
109 |     assert tokens[0] != ner_explainer.cls_token_id  # encode adds no special tokens
110 |     assert tokens[-1] != ner_explainer.sep_token_id
111 |     assert len(tokens) >= len(_input.split(" "))  # word-piece may split words, never merges them
112 |
113 |
114 | def test_token_classification_decode():
115 |     explainer_string = "We visited Paris during the weekend"
116 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
117 |     input_ids, _, _ = ner_explainer._make_input_reference_pair(explainer_string)
118 |     decoded = ner_explainer.decode(input_ids)
119 |     assert decoded[0] == ner_explainer.tokenizer.cls_token
120 |     assert decoded[-1] == ner_explainer.tokenizer.sep_token
121 |     assert " ".join(decoded[1:-1]) == explainer_string  # cased tokenizer: text round-trips unchanged
122 |
123 |
124 | def test_token_classification_run_text_given():
125 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
126 |     word_attributions = ner_explainer._run("We visited Paris during the weekend")
127 |     assert isinstance(word_attributions, dict)  # token classification returns a per-token dict, not a list
128 |
129 |     actual_tokens = list(word_attributions.keys())
130 |     expected_tokens = [
131 |         "[CLS]",
132 |         "We",
133 |         "visited",
134 |         "Paris",
135 |         "during",
136 |         "the",
137 |         "weekend",
138 |         "[SEP]",
139 |     ]
140 |     assert actual_tokens == expected_tokens
141 |
142 |
143 | def test_token_classification_explain_position_embeddings():
144 | explainer_string = "We visited Paris during the weekend"
145 | ner_explainer = TokenClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
146 | pos_attributions = ner_explainer(explainer_string, embedding_type=1)
147 | word_attributions = ner_explainer(explainer_string, embedding_type=0)
148 |
149 | for token in ner_explainer.word_attributions.keys():
150 | assert pos_attributions != word_attributions
151 |
152 |
153 | def test_token_classification_explain_position_embeddings_incorrect_value():
154 |     explainer_string = "We visited Paris during the weekend"
155 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
156 |
157 |     word_attributions = ner_explainer(explainer_string, embedding_type=0)
158 |     incorrect_word_attributions = ner_explainer(explainer_string, embedding_type=-42)  # out-of-range value
159 |
160 |     assert incorrect_word_attributions == word_attributions  # invalid embedding_type falls back to word embeddings
161 |
162 |
163 | def test_token_classification_predicted_class_names():
164 |     explainer_string = "We visited Paris during the weekend"
165 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
166 |     ner_explainer._run(explainer_string)
167 |     ground_truths = ["O", "O", "O", "B-LOC", "O", "O", "O", "O"]  # one label per token, [CLS]/[SEP] included
168 |
169 |     assert len(ground_truths) == len(ner_explainer.predicted_class_names)
170 |
171 |     for i, class_id in enumerate(ner_explainer.predicted_class_names):
172 |         assert ground_truths[i] == class_id
173 |
174 |
175 | def test_token_classification_predicted_class_names_no_id2label_defaults_idx():
176 |     explainer_string = "We visited Paris during the weekend"
177 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
178 |     ner_explainer.id2label = {"test": "value"}  # break the id2label lookup on purpose
179 |     ner_explainer._run(explainer_string)
180 |     class_labels = list(range(9))  # raw class indexes 0-8
181 |
182 |     assert len(ner_explainer.predicted_class_names) == 8
183 |
184 |     for class_name in ner_explainer.predicted_class_names:
185 |         assert class_name in class_labels  # falls back to raw class indexes
186 |
187 |
188 | def test_token_classification_explain_raises_on_input_ids_not_calculated():
189 |     with pytest.raises(InputIdsNotCalculatedError):
190 |         ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
191 |         ner_explainer.predicted_class_indexes  # property access before any run
192 |
193 |
194 | def test_token_classification_word_attributions():
195 |     explainer_string = "We visited Paris during the weekend"
196 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
197 |     ner_explainer(explainer_string)
198 |
199 |     assert isinstance(ner_explainer.word_attributions, dict)
200 |
201 |     for token, elements in ner_explainer.word_attributions.items():
202 |         assert isinstance(elements, dict)
203 |         assert list(elements.keys()) == ["label", "attribution_scores"]
204 |         assert isinstance(elements["label"], str)
205 |         assert isinstance(elements["attribution_scores"], list)
206 |         for score in elements["attribution_scores"]:
207 |             assert isinstance(score, tuple)
208 |             assert isinstance(score[0], str)  # token text
209 |             assert isinstance(score[1], float)  # attribution score
210 |
211 |
212 | def test_token_classification_word_attributions_not_calculated_raises():
213 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
214 |     with pytest.raises(ValueError):
215 |         ner_explainer.word_attributions  # accessed before any call
216 |
217 |
218 | def test_token_classification_explainer_str():
219 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
220 |     s = "TokenClassificationExplainer("
221 |     s += f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__},"
222 |     s += f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__},"
223 |     s += "\n\tattribution_type='lig',"
224 |     s += ")"
225 |     assert s == ner_explainer.__str__()
226 |
227 |
228 | def test_token_classification_viz():
229 |     explainer_string = "We visited Paris during the weekend"
230 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
231 |     ner_explainer(explainer_string)
232 |     ner_explainer.visualize()  # smoke test: should not raise
233 |
234 |
235 | def test_token_classification_viz_on_true_classes_value_error():
236 |     explainer_string = "We visited Paris during the weekend"
237 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
238 |     ner_explainer(explainer_string)
239 |     true_classes = ["None", "Location", "None"]  # 3 entries for 8 tokens — presumably the length mismatch triggers the error
240 |     with pytest.raises(ValueError):
241 |         ner_explainer.visualize(true_classes=true_classes)
242 |
243 |
244 | def token_classification_custom_steps():  # NOTE(review): missing "test_" prefix -> pytest never collects this; confirm intentional
245 |     explainer_string = "We visited Paris during the weekend"
246 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
247 |     ner_explainer(explainer_string, n_steps=1)
248 |
249 |
250 | def token_classification_internal_batch_size():  # NOTE(review): missing "test_" prefix -> pytest never collects this; confirm intentional
251 |     explainer_string = "We visited Paris during the weekend"
252 |     ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
253 |     ner_explainer(explainer_string, internal_batch_size=1)
254 |
--------------------------------------------------------------------------------
/test/text/test_zero_shot_explainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest.mock import patch
3 |
4 | import pytest
5 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
6 |
7 | from transformers_interpret import ZeroShotClassificationExplainer
8 | from transformers_interpret.errors import AttributionTypeNotSupportedError
9 |
10 | DISTILBERT_MNLI_MODEL = AutoModelForSequenceClassification.from_pretrained("typeform/distilbert-base-uncased-mnli")  # shared fixture, downloaded at import time
11 | DISTILBERT_MNLI_TOKENIZER = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli")
12 |
13 |
14 | def test_zero_shot_explainer_init_distilbert():
15 |     zero_shot_explainer = ZeroShotClassificationExplainer(
16 |         DISTILBERT_MNLI_MODEL,
17 |         DISTILBERT_MNLI_TOKENIZER,
18 |     )
19 |
20 |     assert zero_shot_explainer.attribution_type == "lig"  # layer integrated gradients is the default
21 |     assert zero_shot_explainer.attributions == []
22 |     assert zero_shot_explainer.label_exists is True
23 |     assert zero_shot_explainer.entailment_key == "ENTAILMENT"
24 |
25 |
26 | def test_zero_shot_explainer_init_attribution_type_error():
27 |     with pytest.raises(AttributionTypeNotSupportedError):
28 |         ZeroShotClassificationExplainer(
29 |             DISTILBERT_MNLI_MODEL,
30 |             DISTILBERT_MNLI_TOKENIZER,
31 |             attribution_type="UNSUPPORTED",
32 |         )
33 |
34 |
35 | @patch.object(
36 |     ZeroShotClassificationExplainer,
37 |     "_entailment_label_exists",
38 |     return_value=(False, None),
39 | )  # force the entailment-label check to report failure
40 | def test_zero_shot_explainer_no_entailment_label(mock_method):
41 |     with pytest.raises(ValueError):
42 |         ZeroShotClassificationExplainer(
43 |             DISTILBERT_MNLI_MODEL,
44 |             DISTILBERT_MNLI_TOKENIZER,
45 |         )
46 |
47 |
48 | def test_zero_shot_explainer_word_attributions():
49 |     zero_shot_explainer = ZeroShotClassificationExplainer(
50 |         DISTILBERT_MNLI_MODEL,
51 |         DISTILBERT_MNLI_TOKENIZER,
52 |     )
53 |     labels = ["urgent", "phone", "tablet", "computer"]
54 |     word_attributions = zero_shot_explainer(
55 |         "I have a problem with my iphone that needs to be resolved asap!!",
56 |         labels=labels,
57 |     )
58 |     assert isinstance(word_attributions, dict)  # one attribution entry per candidate label
59 |     for label in labels:
60 |         assert label in word_attributions.keys()
61 |
62 |
63 | def test_zero_shot_explainer_call_word_attributions_early_raises_error():
64 |     with pytest.raises(ValueError):
65 |         zero_shot_explainer = ZeroShotClassificationExplainer(
66 |             DISTILBERT_MNLI_MODEL,
67 |             DISTILBERT_MNLI_TOKENIZER,
68 |         )
69 |
70 |         zero_shot_explainer.word_attributions  # accessed before any call
71 |
72 |
73 | def test_zero_shot_explainer_word_attributions_include_hypothesis():
74 |     zero_shot_explainer = ZeroShotClassificationExplainer(
75 |         DISTILBERT_MNLI_MODEL,
76 |         DISTILBERT_MNLI_TOKENIZER,
77 |     )
78 |     labels = ["urgent", "phone", "tablet", "computer"]
79 |     word_attributions_with_hyp = zero_shot_explainer(
80 |         "I have a problem with my iphone that needs to be resolved asap!!",
81 |         labels=labels,
82 |         include_hypothesis=True,
83 |     )
84 |     word_attributions_without_hyp = zero_shot_explainer(
85 |         "I have a problem with my iphone that needs to be resolved asap!!",
86 |         labels=labels,
87 |         include_hypothesis=False,
88 |     )
89 |
90 |     for label in labels:
91 |         assert len(word_attributions_with_hyp[label]) > len(word_attributions_without_hyp[label])  # hypothesis tokens add length
92 |
93 |
94 | def test_zero_shot_explainer_visualize():
95 |     zero_shot_explainer = ZeroShotClassificationExplainer(
96 |         DISTILBERT_MNLI_MODEL,
97 |         DISTILBERT_MNLI_TOKENIZER,
98 |     )
99 |
100 |     zero_shot_explainer(
101 |         "I have a problem with my iphone that needs to be resolved asap!!",
102 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],  # NOTE(review): " not"/duplicate "urgent" look like a typo for "not urgent" — confirm
103 |     )
104 |     zero_shot_explainer.visualize()  # smoke test: should not raise
105 |
106 |
107 | def test_zero_shot_explainer_visualize_save():
108 |     zero_shot_explainer = ZeroShotClassificationExplainer(
109 |         DISTILBERT_MNLI_MODEL,
110 |         DISTILBERT_MNLI_TOKENIZER,
111 |     )
112 |
113 |     zero_shot_explainer(
114 |         "I have a problem with my iphone that needs to be resolved asap!!",
115 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
116 |     )
117 |     html_filename = "./test/zero_test.html"
118 |     zero_shot_explainer.visualize(html_filename)
119 |     assert os.path.exists(html_filename)
120 |     os.remove(html_filename)  # clean up the saved artifact
121 |
122 |
123 | def test_zero_shot_explainer_visualize_include_hypothesis():
124 |     zero_shot_explainer = ZeroShotClassificationExplainer(
125 |         DISTILBERT_MNLI_MODEL,
126 |         DISTILBERT_MNLI_TOKENIZER,
127 |     )
128 |
129 |     zero_shot_explainer(
130 |         "I have a problem with my iphone that needs to be resolved asap!!",
131 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
132 |         include_hypothesis=True,
133 |     )
134 |     zero_shot_explainer.visualize()
135 |
136 |
137 | def test_zero_explainer_visualize_save_append_html_file_ending():
138 |     zero_shot_explainer = ZeroShotClassificationExplainer(
139 |         DISTILBERT_MNLI_MODEL,
140 |         DISTILBERT_MNLI_TOKENIZER,
141 |     )
142 |
143 |     zero_shot_explainer(
144 |         "I have a problem with my iphone that needs to be resolved asap!!",
145 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
146 |     )
147 |
148 |     html_filename = "./test/zero_test"  # no extension: ".html" should be appended automatically
149 |     zero_shot_explainer.visualize(html_filename)
150 |     assert os.path.exists(html_filename + ".html")
151 |     os.remove(html_filename + ".html")
152 |
153 |
154 | def test_zero_shot_model_does_not_have_entailment_label():
155 |     with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"l1": 0, "l2": 1, "l3": 2}):  # no entailment label at all
156 |         with pytest.raises(ValueError):
157 |             ZeroShotClassificationExplainer(
158 |                 DISTILBERT_MNLI_MODEL,
159 |                 DISTILBERT_MNLI_TOKENIZER,
160 |             )
161 |
162 |
163 | def test_zero_shot_model_uppercase_entailment():
164 |     with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"ENTAILMENT": 0, "l2": 1, "l3": 2}):  # uppercase variant accepted
165 |         ZeroShotClassificationExplainer(
166 |             DISTILBERT_MNLI_MODEL,
167 |             DISTILBERT_MNLI_TOKENIZER,
168 |         )
169 |
170 |
171 | def test_zero_shot_model_lowercase_entailment():
172 |     with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"entailment": 0, "l2": 1, "l3": 2}):  # lowercase variant accepted
173 |         ZeroShotClassificationExplainer(
174 |             DISTILBERT_MNLI_MODEL,
175 |             DISTILBERT_MNLI_TOKENIZER,
176 |         )
177 |
178 |
179 | def xtest_zero_shot_custom_steps():  # NOTE(review): "xtest_" prefix disables pytest collection — confirm deliberate (slow test?)
180 |     zero_shot_explainer = ZeroShotClassificationExplainer(
181 |         DISTILBERT_MNLI_MODEL,
182 |         DISTILBERT_MNLI_TOKENIZER,
183 |     )
184 |
185 |     zero_shot_explainer(
186 |         "I have a problem with my iphone that needs to be resolved asap!!",
187 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
188 |         n_steps=1,
189 |     )
190 |
191 |
192 | def xtest_zero_shot_internal_batch_size():  # NOTE(review): "xtest_" prefix disables pytest collection — confirm deliberate
193 |     zero_shot_explainer = ZeroShotClassificationExplainer(
194 |         DISTILBERT_MNLI_MODEL,
195 |         DISTILBERT_MNLI_TOKENIZER,
196 |     )
197 |
198 |     zero_shot_explainer(
199 |         "I have a problem with my iphone that needs to be resolved asap!!",
200 |         labels=["urgent", " not", "urgent", "phone", "tablet", "computer"],
201 |         internal_batch_size=1,
202 |     )
203 |
--------------------------------------------------------------------------------
/test/vision/test_image_classification.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import requests
5 | from PIL import Image
6 | from transformers import AutoFeatureExtractor, AutoModelForImageClassification
7 |
8 | from transformers_interpret import ImageClassificationExplainer
9 | from transformers_interpret.explainers.vision.attribution_types import AttributionType
10 |
11 | model_name = "apple/mobilevit-small"
12 | MODEL = AutoModelForImageClassification.from_pretrained(model_name)
13 | FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(model_name)
14 |
15 | IMAGE_LINK = "https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2342&q=80"
16 | TEST_IMAGE = Image.open(requests.get(IMAGE_LINK, stream=True).raw)  # NOTE(review): network fetch at import time — module fails to collect offline
17 |
18 |
19 | def test_image_classification_init():
20 |     img_cls_explainer = ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR)
21 |
22 |     assert img_cls_explainer.model == MODEL
23 |     assert img_cls_explainer.feature_extractor == FEATURE_EXTRACTOR
24 |     assert img_cls_explainer.id2label == MODEL.config.id2label
25 |     assert img_cls_explainer.label2id == MODEL.config.label2id
26 |
27 |     assert img_cls_explainer.attributions is None  # nothing computed until the explainer is called
28 |
29 |
30 | def test_image_classification_init_attribution_type_not_supported():
31 |     with pytest.raises(ValueError):
32 |         ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR, attribution_type="not_supported")
33 |
34 |
35 | def test_image_classification_init_custom_labels():
36 |     labels = [f"label_{i}" for i in range(len(MODEL.config.id2label) - 1)]  # NOTE(review): len-1 labels pass here while a single label fails below — check intended validation rule
37 |     img_cls_explainer = ImageClassificationExplainer(
38 |         model=MODEL, feature_extractor=FEATURE_EXTRACTOR, custom_labels=labels
39 |     )
40 |
41 |     assert list(img_cls_explainer.label2id.keys()) == labels
42 |
43 |
44 | def test_image_classification_init_custom_labels_not_valid():
45 |     with pytest.raises(ValueError):
46 |         ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR, custom_labels=["label_0"])
47 |
48 |
49 | def test_image_classification_call():
50 |     img_cls_explainer = ImageClassificationExplainer(
51 |         model=MODEL,
52 |         feature_extractor=FEATURE_EXTRACTOR,
53 |     )
54 |
55 |     img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)  # minimal steps keep the test fast
56 |
57 |     assert img_cls_explainer.attributions is not None
58 |     assert img_cls_explainer.predicted_index is not None
59 |     assert img_cls_explainer.n_steps == 1
60 |     assert img_cls_explainer.n_steps_noise_tunnel == 1
61 |     assert img_cls_explainer.noise_tunnel_n_samples == 1
62 |     assert img_cls_explainer.internal_batch_size == 1
63 |
64 |
65 | def test_image_classification_call_attribution_type_not_supported():
66 |     img_cls_explainer = ImageClassificationExplainer(
67 |         model=MODEL,
68 |         feature_extractor=FEATURE_EXTRACTOR,
69 |     )
70 |
71 |     with pytest.raises(ValueError):
72 |         img_cls_explainer(
73 |             TEST_IMAGE,
74 |             n_steps=1,
75 |             n_steps_noise_tunnel=1,
76 |             noise_tunnel_n_samples=1,
77 |             internal_batch_size=1,
78 |             noise_tunnel_type="not_supported",
79 |         )
80 |
81 |
82 | def test_image_classification_visualize():
83 |     img_cls_explainer = ImageClassificationExplainer(
84 |         model=MODEL,
85 |         feature_extractor=FEATURE_EXTRACTOR,
86 |     )
87 |
88 |     img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)
89 |
90 |     img_cls_explainer.visualize(method="alpha_scaling", save_path=None, sign="all", outlier_threshold=0.15)  # smoke test: should not raise
91 |
92 |
93 | def test_image_classification_visualize_use_normed_pixel_values():
94 |     img_cls_explainer = ImageClassificationExplainer(
95 |         model=MODEL,
96 |         feature_extractor=FEATURE_EXTRACTOR,
97 |     )
98 |
99 |     img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)
100 |
101 |     img_cls_explainer.visualize(
102 |         method="overlay", save_path=None, sign="all", outlier_threshold=0.15, use_original_image_pixels=False  # render from normalized pixels, not the raw image
103 |     )
104 |
105 |
106 | def test_image_classification_visualize_wrt_class_name():
107 |     img_cls_explainer = ImageClassificationExplainer(
108 |         model=MODEL,
109 |         feature_extractor=FEATURE_EXTRACTOR,
110 |     )
111 |
112 |     img_cls_explainer(
113 |         TEST_IMAGE,
114 |         n_steps=1,
115 |         n_steps_noise_tunnel=1,
116 |         noise_tunnel_n_samples=1,
117 |         internal_batch_size=1,
118 |         class_name="banana",  # attribute w.r.t. a named class rather than the prediction
119 |     )
120 |
121 |     img_cls_explainer.visualize(method="heatmap", save_path=None, sign="all", outlier_threshold=0.15)
122 |
123 |
124 | def test_image_classification_visualize_wrt_class_index():
125 |     img_cls_explainer = ImageClassificationExplainer(
126 |         model=MODEL,
127 |         feature_extractor=FEATURE_EXTRACTOR,
128 |     )
129 |
130 |     img_cls_explainer(
131 |         TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1, index=3
132 |     )
133 |
134 |     img_cls_explainer.visualize(method="overlay", save_path=None, sign="all", outlier_threshold=0.15)
135 |
136 |
def test_image_classification_visualize_integrated_gradients_no_noise_tunnel():
    """Plain integrated gradients (no noise tunnel) should still support visualization."""
    explainer = ImageClassificationExplainer(
        model=MODEL,
        feature_extractor=FEATURE_EXTRACTOR,
        attribution_type=AttributionType.INTEGRATED_GRADIENTS.value,
    )

    explainer(
        TEST_IMAGE,
        index=3,
        n_steps=1,
        n_steps_noise_tunnel=1,
        noise_tunnel_n_samples=1,
        internal_batch_size=1,
    )

    explainer.visualize(method="masked_image", save_path=None, sign="all", outlier_threshold=0.15)
147 |
148 |
def test_image_classification_visualize_save_image():
    """Saving a side-by-side visualization should create an image file on disk."""
    img_cls_explainer = ImageClassificationExplainer(
        model=MODEL,
        feature_extractor=FEATURE_EXTRACTOR,
    )

    img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)

    # Write into a temp dir so the file is cleaned up even if visualize() raises,
    # instead of leaking ./test.png into the working directory on failure.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, "test.png")
        img_cls_explainer.visualize(
            method="overlay",
            save_path=save_path,
            sign="all",
            outlier_threshold=0.15,
            side_by_side=True,
        )
        # The original test only removed the file; also assert it was written.
        assert os.path.exists(save_path)
166 |
167 |
def test_image_classification_visualize_positive_sign_for_unsupported_methods():
    """Methods that do not support signed visualization should still save an image."""
    img_cls_explainer = ImageClassificationExplainer(
        model=MODEL,
        feature_extractor=FEATURE_EXTRACTOR,
    )

    img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)

    # Write into a temp dir so the file is cleaned up even if visualize() raises,
    # instead of leaking ./test.png into the working directory on failure.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, "test.png")
        img_cls_explainer.visualize(
            method="masked_image",
            save_path=save_path,
            sign="all",
            outlier_threshold=0.15,
            side_by_side=True,
        )
        # The original test only removed the file; also assert it was written.
        assert os.path.exists(save_path)
185 |
186 |
def test_image_classification_visualize_unsupported_viz_method():
    """An unknown visualization method name must raise a ValueError."""
    explainer = ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR)

    explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)

    with pytest.raises(ValueError):
        explainer.visualize(method="not_supported", save_path=None, sign="all", outlier_threshold=0.15)
197 |
198 |
def test_image_classification_visualize_unsupported_sign():
    """An unknown attribution sign must raise a ValueError."""
    explainer = ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR)

    explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)

    with pytest.raises(ValueError):
        explainer.visualize(method="overlay", save_path=None, sign="not_supported", outlier_threshold=0.15)
209 |
210 |
def test_image_classification_visualize_side_by_side():
    """Every supported visualization method should render in side-by-side mode."""
    explainer = ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR)

    explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1)

    for viz_method in ("overlay", "heatmap", "masked_image", "alpha_scaling"):
        explainer.visualize(
            method=viz_method,
            save_path=None,
            sign="all",
            outlier_threshold=0.15,
            side_by_side=True,
        )
224 |
--------------------------------------------------------------------------------
/transformers_interpret/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/transformers_interpret/.DS_Store
--------------------------------------------------------------------------------
/transformers_interpret/__init__.py:
--------------------------------------------------------------------------------
1 | from .attributions import Attributions, LIGAttributions # noqa: F401
2 | from .explainer import BaseExplainer # noqa: F401
3 | from .explainers.text.multilabel_classification import ( # noqa: F401
4 | MultiLabelClassificationExplainer,
5 | )
6 | from .explainers.text.question_answering import QuestionAnsweringExplainer # noqa: F401
7 | from .explainers.text.sequence_classification import ( # noqa: F401
8 | PairwiseSequenceClassificationExplainer,
9 | SequenceClassificationExplainer,
10 | )
11 | from .explainers.text.token_classification import ( # noqa: F401
12 | TokenClassificationExplainer,
13 | )
14 | from .explainers.text.zero_shot_classification import ( # noqa: F401
15 | ZeroShotClassificationExplainer,
16 | )
17 | from .explainers.vision.image_classification import ( # noqa: F401
18 | ImageClassificationExplainer,
19 | )
20 |
--------------------------------------------------------------------------------
/transformers_interpret/attributions.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from captum.attr import LayerIntegratedGradients
6 | from captum.attr import visualization as viz
7 |
8 | from transformers_interpret.errors import AttributionsNotCalculatedError
9 |
10 |
class Attributions:
    """Base container pairing a model forward function and an embedding layer
    with the string tokens an attribution is computed over."""

    def __init__(self, custom_forward: Callable, embeddings: nn.Module, tokens: list):
        # Stored as-is; subclasses use these to drive attribution algorithms.
        self.custom_forward = custom_forward
        self.embeddings = embeddings
        self.tokens = tokens
16 |
17 |
class LIGAttributions(Attributions):
    """Layer Integrated Gradients attributions for a model's embedding layer.

    Attributions are computed eagerly in the constructor via captum's
    `LayerIntegratedGradients`, dispatching on which optional model inputs
    (`token_type_ids` / `position_ids`) were supplied.
    """

    def __init__(
        self,
        custom_forward: Callable,
        embeddings: nn.Module,
        tokens: list,
        input_ids: torch.Tensor,
        ref_input_ids: torch.Tensor,
        sep_id: int,
        attention_mask: torch.Tensor,
        target: Optional[Union[int, Tuple, torch.Tensor, List]] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        ref_token_type_ids: Optional[torch.Tensor] = None,
        ref_position_ids: Optional[torch.Tensor] = None,
        internal_batch_size: Optional[int] = None,
        n_steps: int = 50,
    ):
        """
        Args:
            custom_forward (Callable): Forward function captum differentiates through.
            embeddings (nn.Module): Embedding layer attributions are taken w.r.t.
            tokens (list): String tokens aligned with `input_ids`.
            input_ids (torch.Tensor): Token ids of the sequence being explained.
            ref_input_ids (torch.Tensor): Baseline ("reference") ids of the same length.
            sep_id (int): Separator token id. NOTE(review): accepted but never used
                in this class — presumably kept for interface compatibility; confirm.
            attention_mask (torch.Tensor): Mask forwarded to the model.
            target (optional): Target index/indices for the attribution.
            token_type_ids / ref_token_type_ids (optional): Segment ids and their baseline.
            position_ids / ref_position_ids (optional): Position ids and their baseline.
            internal_batch_size (int, optional): Chunk size for captum's internal batching.
            n_steps (int): Number of integration steps for the approximation.
        """
        super().__init__(custom_forward, embeddings, tokens)
        self.input_ids = input_ids
        self.ref_input_ids = ref_input_ids
        self.attention_mask = attention_mask
        self.target = target
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.ref_token_type_ids = ref_token_type_ids
        self.ref_position_ids = ref_position_ids
        self.internal_batch_size = internal_batch_size
        self.n_steps = n_steps

        self.lig = LayerIntegratedGradients(self.custom_forward, self.embeddings)

        # captum requires inputs and baselines as parallel tuples; which branch
        # runs depends on which optional ids were provided by the caller.
        if self.token_type_ids is not None and self.position_ids is not None:
            # Both token type ids and position ids are attributed.
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.token_type_ids, self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_token_type_ids,
                    self.ref_position_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
        elif self.position_ids is not None:
            # Position ids only.
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.position_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_position_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
        elif self.token_type_ids is not None:
            # Token type ids only.
            self._attributions, self.delta = self.lig.attribute(
                inputs=(self.input_ids, self.token_type_ids),
                baselines=(
                    self.ref_input_ids,
                    self.ref_token_type_ids,
                ),
                target=self.target,
                return_convergence_delta=True,
                additional_forward_args=(self.attention_mask),
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )

        else:
            # Plain input ids. NOTE(review): unlike the branches above, no
            # attention mask is passed as an additional forward arg here —
            # confirm this asymmetry is intentional.
            self._attributions, self.delta = self.lig.attribute(
                inputs=self.input_ids,
                baselines=self.ref_input_ids,
                target=self.target,
                return_convergence_delta=True,
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )

    @property
    def word_attributions(self) -> list:
        """Return (token, summed-attribution-score) pairs.

        Requires `summarize()` to have been called first to populate
        `attributions_sum`; raises AttributionsNotCalculatedError when the
        summary is empty. NOTE(review): accessing this before `summarize()`
        raises AttributeError (attribute does not exist yet) rather than the
        custom error.
        """
        wa = []
        if len(self.attributions_sum) >= 1:
            for i, (word, attribution) in enumerate(zip(self.tokens, self.attributions_sum)):
                wa.append((word, float(attribution.cpu().data.numpy())))
            return wa

        else:
            raise AttributionsNotCalculatedError("Attributions are not yet calculated")

    def summarize(self, end_idx=None, flip_sign: bool = False):
        """Collapse per-embedding-dimension attributions to one score per token.

        Sums over the last dimension, optionally flips the sign, truncates to
        the first `end_idx` tokens, and L2-normalizes the result in place
        (stored on `self.attributions_sum`).
        """
        if flip_sign:
            multiplier = -1
        else:
            multiplier = 1
        self.attributions_sum = self._attributions.sum(dim=-1).squeeze(0) * multiplier
        self.attributions_sum = self.attributions_sum[:end_idx] / torch.norm(self.attributions_sum[:end_idx])

    def visualize_attributions(self, pred_prob, pred_class, true_class, attr_class, all_tokens):
        """Package the summarized attributions into a captum
        `VisualizationDataRecord` for HTML rendering."""

        return viz.VisualizationDataRecord(
            self.attributions_sum,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            self.attributions_sum.sum(),
            all_tokens,
            self.delta,
        )
132 |
--------------------------------------------------------------------------------
/transformers_interpret/errors.py:
--------------------------------------------------------------------------------
class AttributionTypeNotSupportedError(RuntimeError):
    """Raised when a particular attribution type is not yet supported by an explainer."""
3 |
4 |
class AttributionsNotCalculatedError(RuntimeError):
    """Raised when a user attempts to access the attributions for a model and sequence before they have been summarized."""
7 |
8 |
class InputIdsNotCalculatedError(RuntimeError):
    """Raised when a user attempts to call a method or attribute that requires input ids."""
11 |
--------------------------------------------------------------------------------
/transformers_interpret/explainer.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import re
3 | from abc import ABC, abstractmethod, abstractproperty
4 | from typing import List, Tuple, Union
5 |
6 | import torch
7 | from transformers import PreTrainedModel, PreTrainedTokenizer
8 |
9 |
class BaseExplainer(ABC):
    """
    Abstract base class for all explainers.

    Handles the tokenizer/model bookkeeping shared by every concrete explainer:
    reference (baseline) token ids, detection of which optional inputs the
    model's forward signature accepts, and which embedding layers are available
    for attribution.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
    ):
        self.model = model
        self.tokenizer = tokenizer

        # GPT-2 has no pad token, so its EOS token serves as the attribution baseline.
        if self.model.config.model_type == "gpt2":
            self.ref_token_id = self.tokenizer.eos_token_id
        else:
            self.ref_token_id = self.tokenizer.pad_token_id

        # Fall back to EOS/BOS for tokenizers without explicit SEP/CLS tokens.
        self.sep_token_id = (
            self.tokenizer.sep_token_id if self.tokenizer.sep_token_id is not None else self.tokenizer.eos_token_id
        )
        self.cls_token_id = (
            self.tokenizer.cls_token_id if self.tokenizer.cls_token_id is not None else self.tokenizer.bos_token_id
        )

        self.model_prefix = model.base_model_prefix

        # RoBERTa-style models expose position_ids/token_type_ids in their
        # signature but are treated here as not accepting them.
        nonstandard_model_types = ["roberta"]
        if (
            self._model_forward_signature_accepts_parameter("position_ids")
            and self.model.config.model_type not in nonstandard_model_types
        ):
            self.accepts_position_ids = True
        else:
            self.accepts_position_ids = False

        if (
            self._model_forward_signature_accepts_parameter("token_type_ids")
            and self.model.config.model_type not in nonstandard_model_types
        ):
            self.accepts_token_type_ids = True
        else:
            self.accepts_token_type_ids = False

        self.device = self.model.device

        self.word_embeddings = self.model.get_input_embeddings()
        self.position_embeddings = None
        self.token_type_embeddings = None

        self._set_available_embedding_types()

    @abstractmethod
    def encode(self, text: str = None):
        """
        Encode given text with a model's tokenizer.
        """
        raise NotImplementedError

    @abstractmethod
    def decode(self, input_ids: torch.Tensor) -> List[str]:
        """
        Decode received input_ids into a list of word tokens.


        Args:
            input_ids (torch.Tensor): Input ids representing
                word tokens for a sentence/document.

        """
        raise NotImplementedError

    # `abc.abstractproperty` is deprecated since Python 3.3; stacking
    # @property over @abstractmethod is the supported equivalent.
    @property
    @abstractmethod
    def word_attributions(self):
        raise NotImplementedError

    @abstractmethod
    def _run(self) -> list:
        raise NotImplementedError

    @abstractmethod
    def _forward(self):
        """
        Forward defines a function for passing inputs
        through a models's forward method.

        """
        raise NotImplementedError

    @abstractmethod
    def _calculate_attributions(self):
        """
        Internal method for calculating the attribution
        values for the input text.

        """
        raise NotImplementedError

    def _make_input_reference_pair(self, text: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Tokenizes `text` to numerical token id representation `input_ids`,
        as well as creating another reference tensor `ref_input_ids` of the same length
        that will be used as baseline for attributions. Additionally
        the length of the text without special tokens is also returned.

        Args:
            text (str): Text for which we are creating both input ids
                and their corresponding reference ids

        Returns:
            Tuple[torch.Tensor, torch.Tensor, int]
        """

        if isinstance(text, list):
            raise NotImplementedError("Lists of text are not currently supported.")

        text_ids = self.encode(text)
        input_ids = self.tokenizer.encode(text, add_special_tokens=True)

        # if no special tokens were added
        if len(text_ids) == len(input_ids):
            ref_input_ids = [self.ref_token_id] * len(text_ids)
        else:
            ref_input_ids = [self.cls_token_id] + [self.ref_token_id] * len(text_ids) + [self.sep_token_id]

        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(text_ids),
        )

    def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns two tensors indicating the corresponding token types for the `input_ids`
        and a corresponding all zero reference token type tensor.
        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional): Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        # Positions up to and including sep_idx get segment 0, the rest segment 1.
        token_type_ids = torch.tensor([0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device).expand_as(
            input_ids
        )
        ref_token_type_ids = torch.zeros_like(token_type_ids, device=self.device).expand_as(input_ids)

        return (token_type_ids, ref_token_type_ids)

    def _make_input_reference_position_id_pair(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns tensors for positional encoding of tokens for input_ids and zeroed tensor for reference ids.

        Args:
            input_ids (torch.Tensor): inputs to create positional encoding.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device)
        ref_position_ids = torch.zeros(seq_len, dtype=torch.long, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
        return (position_ids, ref_position_ids)

    def _make_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return an all-ones attention mask shaped like `input_ids`."""
        return torch.ones_like(input_ids)

    def _get_preds(
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Run the model forward, supplying only the optional inputs its signature accepts."""

        if self.accepts_position_ids and self.accepts_token_type_ids:
            preds = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
            )
            return preds

        elif self.accepts_position_ids:
            preds = self.model(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
            )

            return preds
        elif self.accepts_token_type_ids:
            preds = self.model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )

            return preds
        else:
            preds = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

            return preds

    def _clean_text(self, text: str) -> str:
        """Pad punctuation with spaces and collapse repeated whitespace."""
        # Raw strings avoid the invalid-escape-sequence warning ("\s") that the
        # previous plain string literals triggered on modern Python.
        text = re.sub(r"([.,!?()])", r" \1 ", text)
        text = re.sub(r"\s{2,}", " ", text)
        return text

    def _model_forward_signature_accepts_parameter(self, parameter: str) -> bool:
        """Return True if the model's forward signature has a parameter named `parameter`."""
        signature = inspect.signature(self.model.forward)
        parameters = signature.parameters
        return parameter in parameters

    def _set_available_embedding_types(self):
        """Discover and cache position/token-type embedding layers, if the model has them."""
        model_base = getattr(self.model, self.model_prefix)
        # GPT-2 keeps its positional embeddings on `wpe` rather than an `embeddings` module.
        if self.model.config.model_type == "gpt2" and hasattr(model_base, "wpe"):
            self.position_embeddings = model_base.wpe.weight
        else:
            if hasattr(model_base, "embeddings"):
                self.model_embeddings = getattr(model_base, "embeddings")
                if hasattr(self.model_embeddings, "position_embeddings"):
                    self.position_embeddings = self.model_embeddings.position_embeddings
                if hasattr(self.model_embeddings, "token_type_embeddings"):
                    self.token_type_embeddings = self.model_embeddings.token_type_embeddings

    def __str__(self):
        s = f"{self.__class__.__name__}("
        s += f"\n\tmodel={self.model.__class__.__name__},"
        s += f"\n\ttokenizer={self.tokenizer.__class__.__name__}"
        s += ")"

        return s
249 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/__init__.py:
--------------------------------------------------------------------------------
1 | from .multilabel_classification import MultiLabelClassificationExplainer # noqa: F401
2 | from .question_answering import QuestionAnsweringExplainer # noqa: F401
3 | from .sequence_classification import ( # noqa: F401
4 | PairwiseSequenceClassificationExplainer,
5 | SequenceClassificationExplainer,
6 | )
7 | from .token_classification import TokenClassificationExplainer # noqa: F401
8 | from .zero_shot_classification import ZeroShotClassificationExplainer # noqa: F401
9 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/multilabel_classification.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | from captum.attr import visualization as viz
5 | from transformers import PreTrainedModel, PreTrainedTokenizer
6 |
7 | from .sequence_classification import SequenceClassificationExplainer
8 |
9 | SUPPORTED_ATTRIBUTION_TYPES = ["lig"]
10 |
11 |
class MultiLabelClassificationExplainer(SequenceClassificationExplainer):
    """
    Explainer for independently explaining label attributions in a multi-label fashion
    for models of type `{MODEL_NAME}ForSequenceClassification` from the Transformers package.
    Every label is explained independently and the word attributions are a dictionary of labels
    mapping to the word attributions for that label. Even if the model itself is not multi-label,
    the resulting word attributions treat the labels as independent.

    Calculates attribution for `text` using the given model
    and tokenizer. Since this is a multi-label explainer, the attribution calculation time scales
    linearly with the number of labels.

    This explainer also allows for attributions with respect to a particular embedding type.
    This can be selected by passing a `embedding_type`. The default value is `0` which
    is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
    If a model does not take position ids in its forward method (distilbert) a warning will
    occur and the default word_embeddings will be chosen instead.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type="lig",
        custom_labels: Optional[List[str]] = None,
    ):
        super().__init__(model, tokenizer, attribution_type, custom_labels)
        # One entry per model label, populated on __call__.
        self.labels = []

    @property
    def word_attributions(self) -> dict:
        "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
        if self.attributions != [] and self.labels != []:

            return dict(
                zip(
                    self.labels,
                    [attr.word_attributions for attr in self.attributions],
                )
            )

        else:
            raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        # Strip the BPE word-boundary marker used by RoBERTa-style tokenizers.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs_list[i],
                "",  # including a predicted class name does not make sense for this explainer
                "n/a" if not true_class else true_class,  # no true class name for this explainer by default
                self.labels[i],
                tokens,
            )
            for i in range(len(self.attributions))
        ]

        html = viz.visualize_text(score_viz)

        # Relabel captum's default table headers to fit the multi-label setting.
        new_html_data = html._repr_html_().replace("Predicted Label", "Prediction Score")
        new_html_data = new_html_data.replace("True Label", "n/a")
        html.data = new_html_data

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html

    def _forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Forward function using sigmoid (independent label probabilities) rather
        than softmax, returning the probability for `self.selected_index`."""

        preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
        preds = preds[0]

        # if it is a single output node
        if len(preds[0]) == 1:
            self._single_node_output = True
            self.pred_probs = torch.sigmoid(preds)[0][0]
            return torch.sigmoid(preds)[:, :]

        self.pred_probs = torch.sigmoid(preds)[0][self.selected_index]
        return torch.sigmoid(preds)[:, self.selected_index]

    def __call__(
        self,
        text: str,
        embedding_type: int = 0,
        internal_batch_size: Optional[int] = None,
        n_steps: Optional[int] = None,
    ) -> dict:
        """
        Calculates attributions for `text` using the model
        and tokenizer given in the constructor. Attributions are calculated for
        every label output in the model.

        This explainer also allows for attributions with respect to a particular embedding type.
        This can be selected by passing a `embedding_type`. The default value is `0` which
        is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
        If a model does not take position ids in its forward method (distilbert) a warning will
        occur and the default word_embeddings will be chosen instead.

        Args:
            text (str): Text to provide attributions for.
            embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
            internal_batch_size (int, optional): Divides total #steps * #examples
                data points into chunks of size at most internal_batch_size,
                which are computed (forward / backward passes)
                sequentially. If internal_batch_size is None, then all evaluations are
                processed in one batch.
            n_steps (int, optional): The number of steps used by the approximation
                method. Default: 50.

        Returns:
            dict: A dictionary of label to list of attributions.
        """
        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size

        self.attributions = []
        self.pred_probs_list = []
        # Labels ordered by their id so attributions align with model outputs.
        self.labels = [item[0] for item in sorted(self.label2id.items(), key=lambda x: x[1])]
        self.label_probs_dict = {}
        for i in range(self.model.config.num_labels):
            explainer = SequenceClassificationExplainer(
                self.model,
                self.tokenizer,
            )
            self.selected_index = i
            # Patch the single-label explainer to use this class's sigmoid forward.
            explainer._forward = self._forward
            explainer(text, i, embedding_type)

            self.attributions.append(explainer.attributions)
            self.input_ids = explainer.input_ids
            self.pred_probs_list.append(self.pred_probs)
            self.label_probs_dict[self.id2label[i]] = self.pred_probs

        return self.word_attributions

    def __str__(self):
        s = f"{self.__class__.__name__}("
        s += f"\n\tmodel={self.model.__class__.__name__},"
        s += f"\n\ttokenizer={self.tokenizer.__class__.__name__},"
        s += f"\n\tattribution_type='{self.attribution_type}',"
        s += f"\n\tcustom_labels={self.custom_labels},"
        s += ")"

        return s
178 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/question_answering.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Union
3 |
4 | import torch
5 | from captum.attr import visualization as viz
6 | from torch.nn.modules.sparse import Embedding
7 | from transformers import PreTrainedModel, PreTrainedTokenizer
8 |
9 | from transformers_interpret import BaseExplainer, LIGAttributions
10 | from transformers_interpret.errors import (
11 | AttributionTypeNotSupportedError,
12 | InputIdsNotCalculatedError,
13 | )
14 |
15 | SUPPORTED_ATTRIBUTION_TYPES = ["lig"]
16 |
17 |
18 | class QuestionAnsweringExplainer(BaseExplainer):
19 | """
20 | Explainer for explaining attributions for models of type `{MODEL_NAME}ForQuestionAnswering`
21 | from the Transformers package.
22 | """
23 |
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type: str = "lig",
    ):
        """
        Args:
            model (PreTrainedModel): Pretrained huggingface Question Answering model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".

        Raises:
            AttributionTypeNotSupportedError: If `attribution_type` is not one of
                the types listed in SUPPORTED_ATTRIBUTION_TYPES.
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.attribution_type = attribution_type

        # Most recent attributions for the answer's start and end positions.
        self.attributions: Union[None, LIGAttributions] = None
        self.start_attributions = None
        self.end_attributions = None
        self.input_ids: torch.Tensor = torch.Tensor()

        # 0 => explain the answer-start logits, 1 => the answer-end logits.
        self.position = 0

        self.internal_batch_size = None
        self.n_steps = 50
56 |
57 | def encode(self, text: str) -> list: # type: ignore
58 | "Encode 'text' using tokenizer, special tokens are not added"
59 | return self.tokenizer.encode(text, add_special_tokens=False)
60 |
61 | def decode(self, input_ids: torch.Tensor) -> list:
62 | "Decode 'input_ids' to string using tokenizer"
63 | return self.tokenizer.convert_ids_to_tokens(input_ids[0])
64 |
65 | @property
66 | def word_attributions(self) -> dict:
67 | """
68 | Returns the word attributions (as `dict`) for both start and end positions of QA model.
69 |
70 | Raises error if attributions not calculated.
71 |
72 | """
73 | if self.start_attributions is not None and self.end_attributions is not None:
74 | return {
75 | "start": self.start_attributions.word_attributions,
76 | "end": self.end_attributions.word_attributions,
77 | }
78 |
79 | else:
80 | raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")
81 |
82 | @property
83 | def start_pos(self):
84 | "Returns predicted start position for answer"
85 | if len(self.input_ids) > 0:
86 | preds = self._get_preds(
87 | self.input_ids,
88 | self.token_type_ids,
89 | self.position_ids,
90 | self.attention_mask,
91 | )
92 |
93 | preds = preds[0]
94 | return int(preds.argmax())
95 | else:
96 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`")
97 |
98 | @property
99 | def end_pos(self):
100 | "Returns predicted end position for answer"
101 | if len(self.input_ids) > 0:
102 | preds = self._get_preds(
103 | self.input_ids,
104 | self.token_type_ids,
105 | self.position_ids,
106 | self.attention_mask,
107 | )
108 |
109 | preds = preds[1]
110 | return int(preds.argmax())
111 | else:
112 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`")
113 |
114 | @property
115 | def predicted_answer(self):
116 | "Returns predicted answer span from provided `text`"
117 | if len(self.input_ids) > 0:
118 | preds = self._get_preds(
119 | self.input_ids,
120 | self.token_type_ids,
121 | self.position_ids,
122 | self.attention_mask,
123 | )
124 |
125 | start = preds[0].argmax()
126 | end = preds[1].argmax()
127 | return " ".join(self.decode(self.input_ids)[start : end + 1])
128 | else:
129 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`")
130 |
    def visualize(self, html_filepath: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.
        """
        # Strip the BPE word-boundary marker used by RoBERTa-style tokenizers.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        predicted_answer = self.predicted_answer

        # self.position acts as a mode switch for _forward: 0 selects the
        # start logits, 1 the end logits. Order matters here — it must be set
        # before each corresponding _forward call.
        self.position = 0
        start_pred_probs = self._forward(self.input_ids, self.token_type_ids, self.position_ids)
        start_pos = self.start_pos
        start_pos_str = tokens[start_pos] + " (" + str(start_pos) + ")"
        start_score_viz = self.start_attributions.visualize_attributions(
            float(start_pred_probs),
            str(predicted_answer),
            start_pos_str,
            start_pos_str,
            tokens,
        )

        self.position = 1

        end_pred_probs = self._forward(self.input_ids, self.token_type_ids, self.position_ids)
        end_pos = self.end_pos
        end_pos_str = tokens[end_pos] + " (" + str(end_pos) + ")"
        end_score_viz = self.end_attributions.visualize_attributions(
            float(end_pred_probs),
            str(predicted_answer),
            end_pos_str,
            end_pos_str,
            tokens,
        )

        html = viz.visualize_text([start_score_viz, end_score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
174 |
175 | def _make_input_reference_pair(self, question: str, text: str): # type: ignore
176 | question_ids = self.encode(question)
177 | text_ids = self.encode(text)
178 |
179 | input_ids = [self.cls_token_id] + question_ids + [self.sep_token_id] + text_ids + [self.sep_token_id]
180 |
181 | ref_input_ids = (
182 | [self.cls_token_id]
183 | + [self.ref_token_id] * len(question_ids)
184 | + [self.sep_token_id]
185 | + [self.ref_token_id] * len(text_ids)
186 | + [self.sep_token_id]
187 | )
188 |
189 | return (
190 | torch.tensor([input_ids], device=self.device),
191 | torch.tensor([ref_input_ids], device=self.device),
192 | len(question_ids),
193 | )
194 |
195 | def _get_preds(
196 | self,
197 | input_ids: torch.Tensor,
198 | token_type_ids=None,
199 | position_ids: torch.Tensor = None,
200 | attention_mask: torch.Tensor = None,
201 | ):
202 | if self.accepts_position_ids and self.accepts_token_type_ids:
203 | preds = self.model(
204 | input_ids,
205 | token_type_ids=token_type_ids,
206 | position_ids=position_ids,
207 | attention_mask=attention_mask,
208 | )
209 |
210 | return preds
211 |
212 | elif self.accepts_position_ids:
213 | preds = self.model(
214 | input_ids,
215 | position_ids=position_ids,
216 | attention_mask=attention_mask,
217 | )
218 |
219 | return preds
220 | elif self.accepts_token_type_ids:
221 | preds = self.model(
222 | input_ids,
223 | token_type_ids=token_type_ids,
224 | attention_mask=attention_mask,
225 | )
226 |
227 | return preds
228 | else:
229 | preds = self.model(
230 | input_ids,
231 | attention_mask=attention_mask,
232 | )
233 |
234 | return preds
235 |
236 | def _forward( # type: ignore
237 | self,
238 | input_ids: torch.Tensor,
239 | token_type_ids=None,
240 | position_ids: torch.Tensor = None,
241 | attention_mask: torch.Tensor = None,
242 | ):
243 |
244 | preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
245 |
246 | preds = preds[self.position]
247 |
248 | return preds.max(1).values
249 |
250 | def _run(self, question: str, text: str, embedding_type: int) -> dict:
251 | if embedding_type == 0:
252 | embeddings = self.word_embeddings
253 | try:
254 | if embedding_type == 1:
255 | if self.accepts_position_ids and self.position_embeddings is not None:
256 | embeddings = self.position_embeddings
257 | else:
258 | warnings.warn(
259 | "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
260 | )
261 | embeddings = self.word_embeddings
262 | elif embedding_type == 2:
263 | embeddings = self.model_embeddings
264 |
265 | else:
266 | embeddings = self.word_embeddings
267 | except Exception:
268 | warnings.warn(
269 | "This model doesn't support the embedding type you selected for attributions. Defaulting to word embeddings"
270 | )
271 | embeddings = self.word_embeddings
272 |
273 | self.question = question
274 | self.text = text
275 |
276 | self._calculate_attributions(embeddings)
277 | return self.word_attributions
278 |
    def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
        """Run Layer Integrated Gradients for both QA heads and store results.

        Builds input/baseline id pairs, position ids, token type ids and the
        attention mask for the current `self.question`/`self.text`, then runs
        LIG twice: with `self.position = 0` (start head) and `self.position = 1`
        (end head) — `_forward` reads `self.position` to pick the head.

        Side effects: sets `start_attributions`, `end_attributions` and
        `attributions` (a two-element list of the former).

        Args:
            embeddings (Embedding): embedding layer to attribute against.
        """
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.question, self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        # Strip the byte-level BPE space marker so tokens display cleanly.
        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        self.position = 0  # _forward now scores the start head
        start_lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        start_lig.summarize()
        self.start_attributions = start_lig

        self.position = 1  # switch _forward to the end head
        end_lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        end_lig.summarize()
        self.end_attributions = end_lig
        self.attributions = [self.start_attributions, self.end_attributions]
338 |
339 | def __call__(
340 | self,
341 | question: str,
342 | text: str,
343 | embedding_type: int = 2,
344 | internal_batch_size: int = None,
345 | n_steps: int = None,
346 | ) -> dict:
347 | """
348 | Calculates start and end position word attributions for `question` and `text` using the model
349 | and tokenizer given in the constructor.
350 |
351 | This explainer also allows for attributions with respect to a particlar embedding type.
352 | This can be selected by passing a `embedding_type`. The default value is `2` which
353 | attempts to calculate for all embeddings. If `0` is passed then attributions are w.r.t word_embeddings,
354 | if `1` is passed then attributions are w.r.t position_embeddings.
355 |
356 |
357 | Args:
358 | question (str): The question text
359 | text (str): The text or context from which the model finds an answers
360 | embedding_type (int, optional): The embedding type word(0), position(1), all(2) to calculate attributions for.
361 | Defaults to 2.
362 | internal_batch_size (int, optional): Divides total #steps * #examples
363 | data points into chunks of size at most internal_batch_size,
364 | which are computed (forward / backward passes)
365 | sequentially. If internal_batch_size is None, then all evaluations are
366 | processed in one batch.
367 | n_steps (int, optional): The number of steps used by the approximation
368 | method. Default: 50.
369 |
370 | Returns:
371 | dict: Dict for start and end position word attributions.
372 | """
373 |
374 | if n_steps:
375 | self.n_steps = n_steps
376 | if internal_batch_size:
377 | self.internal_batch_size = internal_batch_size
378 | return self._run(question, text, embedding_type)
379 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/sequence_classification.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Dict, List, Optional, Tuple, Union
3 |
4 | import torch
5 | from captum.attr import visualization as viz
6 | from torch.nn.modules.sparse import Embedding
7 | from transformers import PreTrainedModel, PreTrainedTokenizer
8 |
9 | from transformers_interpret import BaseExplainer, LIGAttributions
10 | from transformers_interpret.errors import (
11 | AttributionTypeNotSupportedError,
12 | InputIdsNotCalculatedError,
13 | )
14 |
15 | SUPPORTED_ATTRIBUTION_TYPES = ["lig"]
16 |
17 |
18 | class SequenceClassificationExplainer(BaseExplainer):
19 | """
20 | Explainer for explaining attributions for models of type
21 | `{MODEL_NAME}ForSequenceClassification` from the Transformers package.
22 |
23 | Calculates attribution for `text` using the given model
24 | and tokenizer.
25 |
26 | Attributions can be forced along the axis of a particular output index or class name.
27 | To do this provide either a valid `index` for the class label's output or if the outputs
28 | have provided labels you can pass a `class_name`.
29 |
30 | This explainer also allows for attributions with respect to a particlar embedding type.
31 | This can be selected by passing a `embedding_type`. The default value is `0` which
32 | is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
33 | If a model does not take position ids in its forward method (distilbert) a warning will
34 | occur and the default word_embeddings will be chosen instead.
35 |
36 | """
37 |
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type: str = "lig",
        custom_labels: Optional[List[str]] = None,
    ):
        """
        Args:
            model (PreTrainedModel): Pretrained huggingface Sequence Classification model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".
            custom_labels (List[str], optional): Applies custom labels to label2id and id2label configs.
                Labels must be same length as the base model configs' labels.
                Labels and ids are applied index-wise. Defaults to None.

        Raises:
            AttributionTypeNotSupportedError: if `attribution_type` is not one of
                SUPPORTED_ATTRIBUTION_TYPES.
            ValueError: if `custom_labels` length differs from the model config's
                label2id size.
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.attribution_type = attribution_type

        if custom_labels is not None:
            # Custom labels must map 1:1 onto the model's existing label ids.
            if len(custom_labels) != len(model.config.label2id):
                raise ValueError(
                    f"""`custom_labels` size '{len(custom_labels)}' should match pretrained model's label2id size
                    '{len(model.config.label2id)}'"""
                )

            self.id2label, self.label2id = self._get_id2label_and_label2id_dict(custom_labels)
        else:
            self.label2id = model.config.label2id
            self.id2label = model.config.id2label

        # Populated lazily: attributions by _calculate_attributions(),
        # input_ids by _make_input_reference_pair().
        self.attributions: Union[None, LIGAttributions] = None
        self.input_ids: torch.Tensor = torch.Tensor()

        # Set by _forward() when the model has a single output logit;
        # visualize() adapts its display accordingly.
        self._single_node_output = False

        # LIG hyperparameters; __call__ may override them per invocation.
        self.internal_batch_size = None
        self.n_steps = 50
84 |
85 | @staticmethod
86 | def _get_id2label_and_label2id_dict(
87 | labels: List[str],
88 | ) -> Tuple[Dict[int, str], Dict[str, int]]:
89 | id2label: Dict[int, str] = dict()
90 | label2id: Dict[str, int] = dict()
91 | for idx, label in enumerate(labels):
92 | id2label[idx] = label
93 | label2id[label] = idx
94 |
95 | return id2label, label2id
96 |
97 | def encode(self, text: str = None) -> list:
98 | return self.tokenizer.encode(text, add_special_tokens=False)
99 |
100 | def decode(self, input_ids: torch.Tensor) -> list:
101 | "Decode 'input_ids' to string using tokenizer"
102 | return self.tokenizer.convert_ids_to_tokens(input_ids[0])
103 |
104 | @property
105 | def predicted_class_index(self) -> int:
106 | "Returns predicted class index (int) for model with last calculated `input_ids`"
107 | if len(self.input_ids) > 0:
108 | # we call this before _forward() so it has to be calculated twice
109 | preds = self.model(self.input_ids)[0]
110 | self.pred_class = torch.argmax(torch.softmax(preds, dim=0)[0])
111 | return torch.argmax(torch.softmax(preds, dim=1)[0]).cpu().detach().numpy()
112 |
113 | else:
114 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`")
115 |
116 | @property
117 | def predicted_class_name(self):
118 | "Returns predicted class name (str) for model with last calculated `input_ids`"
119 | try:
120 | index = self.predicted_class_index
121 | return self.id2label[int(index)]
122 | except Exception:
123 | return self.predicted_class_index
124 |
125 | @property
126 | def word_attributions(self) -> list:
127 | "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
128 | if self.attributions is not None:
129 | return self.attributions.word_attributions
130 | else:
131 | raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")
132 |
    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        # Strip the byte-level BPE space marker so tokens display cleanly.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        attr_class = self.id2label[self.selected_index]

        # Single-logit models have no class labels to show; display the
        # rounded sigmoid probability in their place.
        if self._single_node_output:
            if true_class is None:
                true_class = round(float(self.pred_probs))
            predicted_class = round(float(self.pred_probs))
            attr_class = round(float(self.pred_probs))

        else:
            if true_class is None:
                true_class = self.selected_index
            predicted_class = self.predicted_class_name

        score_viz = self.attributions.visualize_attributions(  # type: ignore
            self.pred_probs,
            predicted_class,
            true_class,
            attr_class,
            tokens,
        )
        html = viz.visualize_text([score_viz])

        if html_filepath:
            # Ensure the output file always has an .html extension.
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
172 |
173 | def _forward( # type: ignore
174 | self,
175 | input_ids: torch.Tensor,
176 | token_type_ids=None,
177 | position_ids: torch.Tensor = None,
178 | attention_mask: torch.Tensor = None,
179 | ):
180 |
181 | preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
182 | preds = preds[0]
183 |
184 | # if it is a single output node
185 | if len(preds[0]) == 1:
186 | self._single_node_output = True
187 | self.pred_probs = torch.sigmoid(preds)[0][0]
188 | return torch.sigmoid(preds)[:, :]
189 |
190 | self.pred_probs = torch.softmax(preds, dim=1)[0][self.selected_index]
191 | return torch.softmax(preds, dim=1)[:, self.selected_index]
192 |
    def _calculate_attributions(self, embeddings: Embedding, index: int = None, class_name: str = None):  # type: ignore
        """Compute LIG attributions for `self.text` against `embeddings`.

        Builds input/baseline id pairs, position ids, token type ids and the
        attention mask, resolves the target output node (explicit `index`,
        then `class_name`, then the model's predicted class), then runs Layer
        Integrated Gradients and stores the result on `self.attributions`.
        """
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        # Target resolution priority: index > class_name > predicted class.
        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            if class_name in self.label2id.keys():
                self.selected_index = int(self.label2id[class_name])
            else:
                s = f"'{class_name}' is not found in self.label2id keys."
                s += "Defaulting to predicted index instead."
                warnings.warn(s)
                self.selected_index = int(self.predicted_class_index)
        else:
            self.selected_index = int(self.predicted_class_index)

        # Strip the byte-level BPE space marker so tokens display cleanly.
        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        lig = LIGAttributions(
            custom_forward=self._forward,
            embeddings=embeddings,
            tokens=reference_tokens,
            input_ids=self.input_ids,
            ref_input_ids=self.ref_input_ids,
            sep_id=self.sep_idx,
            attention_mask=self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )

        lig.summarize()
        self.attributions = lig
244 |
245 | def _run(
246 | self,
247 | text: str,
248 | index: int = None,
249 | class_name: str = None,
250 | embedding_type: int = None,
251 | ) -> list: # type: ignore
252 | if embedding_type is None:
253 | embeddings = self.word_embeddings
254 | else:
255 | if embedding_type == 0:
256 | embeddings = self.word_embeddings
257 | elif embedding_type == 1:
258 | if self.accepts_position_ids and self.position_embeddings is not None:
259 | embeddings = self.position_embeddings
260 | else:
261 | warnings.warn(
262 | "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
263 | )
264 | embeddings = self.word_embeddings
265 | else:
266 | embeddings = self.word_embeddings
267 |
268 | self.text = self._clean_text(text)
269 |
270 | self._calculate_attributions(embeddings=embeddings, index=index, class_name=class_name)
271 | return self.word_attributions # type: ignore
272 |
273 | def __call__(
274 | self,
275 | text: str,
276 | index: int = None,
277 | class_name: str = None,
278 | embedding_type: int = 0,
279 | internal_batch_size: int = None,
280 | n_steps: int = None,
281 | ) -> list:
282 | """
283 | Calculates attribution for `text` using the model
284 | and tokenizer given in the constructor.
285 |
286 | Attributions can be forced along the axis of a particular output index or class name.
287 | To do this provide either a valid `index` for the class label's output or if the outputs
288 | have provided labels you can pass a `class_name`.
289 |
290 | This explainer also allows for attributions with respect to a particular embedding type.
291 | This can be selected by passing a `embedding_type`. The default value is `0` which
292 | is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
293 | If a model does not take position ids in its forward method (distilbert) a warning will
294 | occur and the default word_embeddings will be chosen instead.
295 |
296 | Args:
297 | text (str): Text to provide attributions for.
298 | index (int, optional): Optional output index to provide attributions for. Defaults to None.
299 | class_name (str, optional): Optional output class name to provide attributions for. Defaults to None.
300 | embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
301 | internal_batch_size (int, optional): Divides total #steps * #examples
302 | data points into chunks of size at most internal_batch_size,
303 | which are computed (forward / backward passes)
304 | sequentially. If internal_batch_size is None, then all evaluations are
305 | processed in one batch.
306 | n_steps (int, optional): The number of steps used by the approximation
307 | method. Default: 50.
308 | Returns:
309 | list: List of tuples containing words and their associated attribution scores.
310 | """
311 |
312 | if n_steps:
313 | self.n_steps = n_steps
314 | if internal_batch_size:
315 | self.internal_batch_size = internal_batch_size
316 | return self._run(text, index, class_name, embedding_type=embedding_type)
317 |
318 | def __str__(self):
319 | s = f"{self.__class__.__name__}("
320 | s += f"\n\tmodel={self.model.__class__.__name__},"
321 | s += f"\n\ttokenizer={self.tokenizer.__class__.__name__},"
322 | s += f"\n\tattribution_type='{self.attribution_type}',"
323 | s += ")"
324 |
325 | return s
326 |
327 |
328 | class PairwiseSequenceClassificationExplainer(SequenceClassificationExplainer):
329 | def _make_input_reference_pair(
330 | self, text1: Union[List, str], text2: Union[List, str]
331 | ) -> Tuple[torch.Tensor, torch.Tensor, int]:
332 |
333 | t1_ids = self.tokenizer.encode(text1, add_special_tokens=False)
334 | t2_ids = self.tokenizer.encode(text2, add_special_tokens=False)
335 | input_ids = self.tokenizer.encode([text1, text2], add_special_tokens=True)
336 | if self.model.config.model_type == "roberta":
337 | ref_input_ids = (
338 | [self.cls_token_id]
339 | + [self.ref_token_id] * len(t1_ids)
340 | + [self.sep_token_id]
341 | + [self.sep_token_id]
342 | + [self.ref_token_id] * len(t2_ids)
343 | + [self.sep_token_id]
344 | )
345 |
346 | else:
347 |
348 | ref_input_ids = (
349 | [self.cls_token_id]
350 | + [self.ref_token_id] * len(t1_ids)
351 | + [self.sep_token_id]
352 | + [self.ref_token_id] * len(t2_ids)
353 | + [self.sep_token_id]
354 | )
355 |
356 | return (
357 | torch.tensor([input_ids], device=self.device),
358 | torch.tensor([ref_input_ids], device=self.device),
359 | len(t1_ids) + 1, # +1 for CLS token
360 | )
361 |
    def _calculate_attributions(
        self,
        embeddings: Embedding,
        index: int = None,
        class_name: str = None,
        flip_sign: bool = False,
    ):  # type: ignore
        """Compute LIG attributions for the `self.text1`/`self.text2` pair.

        Mirrors the parent implementation but builds the inputs from the text
        pair and optionally flips the attribution sign for single-output-node
        models (useful to explain dissimilarity in similarity models).
        """
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text1, self.text2)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        # Target resolution priority: index > class_name > predicted class.
        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            if class_name in self.label2id.keys():
                self.selected_index = int(self.label2id[class_name])
            else:
                s = f"'{class_name}' is not found in self.label2id keys."
                s += "Defaulting to predicted index instead."
                warnings.warn(s)
                self.selected_index = int(self.predicted_class_index)
        else:
            self.selected_index = int(self.predicted_class_index)

        # Strip the byte-level BPE space marker so tokens display cleanly.
        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        lig = LIGAttributions(
            custom_forward=self._forward,
            embeddings=embeddings,
            tokens=reference_tokens,
            input_ids=self.input_ids,
            ref_input_ids=self.ref_input_ids,
            sep_id=self.sep_idx,
            attention_mask=self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        # NOTE(review): flip_sign only takes effect when `_single_node_output`
        # is True — presumably set by _forward during the LIG passes above;
        # confirm LIGAttributions invokes custom_forward on construction.
        if self._single_node_output:
            lig.summarize(flip_sign=flip_sign)
        else:
            lig.summarize()
        self.attributions = lig
421 |
422 | def _run(
423 | self,
424 | text1: str,
425 | text2: str,
426 | index: int = None,
427 | class_name: str = None,
428 | embedding_type: int = None,
429 | flip_sign: bool = False,
430 | ) -> list: # type: ignore
431 | if embedding_type is None:
432 | embeddings = self.word_embeddings
433 | else:
434 | if embedding_type == 0:
435 | embeddings = self.word_embeddings
436 | elif embedding_type == 1:
437 | if self.accepts_position_ids and self.position_embeddings is not None:
438 | embeddings = self.position_embeddings
439 | else:
440 | warnings.warn(
441 | "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
442 | )
443 | embeddings = self.word_embeddings
444 | else:
445 | embeddings = self.word_embeddings
446 |
447 | self.text1 = text1
448 | self.text2 = text2
449 |
450 | self._calculate_attributions(
451 | embeddings=embeddings,
452 | index=index,
453 | class_name=class_name,
454 | flip_sign=flip_sign,
455 | )
456 | return self.word_attributions # type: ignore
457 |
458 | def __call__(
459 | self,
460 | text1: str,
461 | text2: str,
462 | index: int = None,
463 | class_name: str = None,
464 | embedding_type: int = 0,
465 | internal_batch_size: int = None,
466 | n_steps: int = None,
467 | flip_sign: bool = False,
468 | ):
469 | """
470 | Calculates pairwise attributions for two inputs `text1` and `text2` using the model
471 | and tokenizer given in the constructor. Pairwise attributions are useful for models where
472 | two distinct inputs separated by the model separator token are fed to the model, such as cross-encoder
473 | models for similarity classification.
474 |
475 | Attributions can be forced along the axis of a particular output index or class name if there is more than one.
476 | To do this provide either a valid `index` for the class label's output or if the outputs
477 | have provided labels you can pass a `class_name`.
478 |
479 | This explainer also allows for attributions with respect to a particular embedding type.
480 | This can be selected by passing a `embedding_type`. The default value is `0` which
481 | is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
482 | If a model does not take position ids in its forward method (distilbert) a warning will
483 | occur and the default word_embeddings will be chosen instead.
484 |
485 | Additionally, this explainer allows for attributions signs to be flipped in cases where the model
486 | only outputs a single node. By default for models that output a single node the attributions are
487 | with respect to the inputs pushing the scores closer to 1.0, however if you want to see the
488 | attributions with respect to scores closer to 0.0 you can pass `flip_sign=True`. For similarity
489 | based models this is useful, as the model might predict a score closer to 0.0 for the two inputs
490 | and in that case we would flip the attributions sign to explain why the two inputs are dissimilar.
491 |
492 | Args:
493 | text1 (str): First text input to provide pairwise attributions for.
494 | text2 (str): Second text to provide pairwise attributions for.
495 | index (int, optional): Optional output index to provide attributions for. Defaults to None.
496 | class_name (str, optional): Optional output class name to provide attributions for. Defaults to None.
497 | embedding_type (int, optional):The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
498 | internal_batch_size (int, optional): Divides total #steps * #examples
499 | data points into chunks of size at most internal_batch_size,
500 | which are computed (forward / backward passes)
501 | sequentially. If internal_batch_size is None, then all evaluations are
502 | processed in one batch.
503 | n_steps (int, optional): The number of steps used by the approximation
504 | method. Default: 50.
505 | flip_sign (bool, optional): Boolean flag determining whether to flip the sign of attributions. Defaults to False.
506 |
507 | Returns:
508 | _type_: _description_
509 | """
510 | if n_steps:
511 | self.n_steps = n_steps
512 | if internal_batch_size:
513 | self.internal_batch_size = internal_batch_size
514 | return self._run(
515 | text1=text1,
516 | text2=text2,
517 | embedding_type=embedding_type,
518 | index=index,
519 | class_name=class_name,
520 | flip_sign=flip_sign,
521 | )
522 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/token_classification.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Dict, List, Optional, Union
3 |
4 | import torch
5 | from captum.attr import visualization as viz
6 | from torch.nn.modules.sparse import Embedding
7 | from transformers import PreTrainedModel, PreTrainedTokenizer
8 |
9 | from transformers_interpret import BaseExplainer
10 | from transformers_interpret.attributions import LIGAttributions
11 | from transformers_interpret.errors import (
12 | AttributionTypeNotSupportedError,
13 | InputIdsNotCalculatedError,
14 | )
15 |
16 | SUPPORTED_ATTRIBUTION_TYPES = ["lig"]
17 |
18 |
19 | class TokenClassificationExplainer(BaseExplainer):
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type="lig",
    ):

        """
        Args:
            model (PreTrainedModel): Pretrained huggingface Token Classification model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".

        Raises:
            AttributionTypeNotSupportedError: if `attribution_type` is not one of
                SUPPORTED_ATTRIBUTION_TYPES.
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.attribution_type: str = attribution_type

        self.label2id = model.config.label2id
        self.id2label = model.config.id2label

        # Optional filters: token indexes / predicted labels to exclude from
        # attribution calculation (see `_selected_indexes`).
        self.ignored_indexes: Optional[List[int]] = None
        self.ignored_labels: Optional[List[str]] = None

        # Populated lazily: attributions (per token index) by
        # _calculate_attributions(), input_ids by _make_input_reference_pair().
        self.attributions: Union[None, Dict[int, LIGAttributions]] = None
        self.input_ids: torch.Tensor = torch.Tensor()

        # LIG hyperparameters; __call__ may override them per invocation.
        self.internal_batch_size = None
        self.n_steps = 50
55 |
56 | def encode(self, text: str = None) -> list:
57 | "Encode the text using tokenizer"
58 | return self.tokenizer.encode(text, add_special_tokens=False)
59 |
60 | def decode(self, input_ids: torch.Tensor) -> list:
61 | "Decode 'input_ids' to string using tokenizer"
62 | return self.tokenizer.convert_ids_to_tokens(input_ids[0])
63 |
64 | @property
65 | def predicted_class_indexes(self) -> List[int]:
66 | "Returns the predicted class indexes (int) for model with last calculated `input_ids`"
67 | if len(self.input_ids) > 0:
68 |
69 | preds = self.model(self.input_ids)
70 | preds = preds[0]
71 | self.pred_class = torch.softmax(preds, dim=2)[0]
72 |
73 | return torch.argmax(torch.softmax(preds, dim=2), dim=2)[0].cpu().detach().numpy()
74 |
75 | else:
76 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`")
77 |
78 | @property
79 | def predicted_class_names(self):
80 | "Returns predicted class names (str) for model with last calculated `input_ids`"
81 | try:
82 | indexes = self.predicted_class_indexes
83 | return [self.id2label[int(index)] for index in indexes]
84 | except Exception:
85 | return self.predicted_class_indexes
86 |
87 | @property
88 | def word_attributions(self) -> Dict:
89 | "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
90 |
91 | if self.attributions is not None:
92 | word_attr = dict()
93 | tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
94 | labels = self.predicted_class_names
95 |
96 | for index, attr in self.attributions.items():
97 | try:
98 | predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
99 | except KeyError:
100 | predicted_class = torch.argmax(self.pred_probs[index]).item()
101 |
102 | word_attr[tokens[index]] = {
103 | "label": predicted_class,
104 | "attribution_scores": attr.word_attributions,
105 | }
106 |
107 | return word_attr
108 | else:
109 | raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")
110 |
111 | @property
112 | def _selected_indexes(self) -> List[int]:
113 | """Returns the indexes for which the attributions must be calculated considering the
114 | ignored indexes and the ignored labels, in that order of priority"""
115 |
116 | selected_indexes = set(range(self.input_ids.shape[1])) # all indexes
117 |
118 | if self.ignored_indexes is not None:
119 | selected_indexes = selected_indexes.difference(set(self.ignored_indexes))
120 |
121 | if self.ignored_labels is not None:
122 | ignored_indexes_extra = []
123 | pred_labels = [self.id2label[id] for id in self.predicted_class_indexes]
124 |
125 | for index, label in enumerate(pred_labels):
126 | if label in self.ignored_labels:
127 | ignored_indexes_extra.append(index)
128 | selected_indexes = selected_indexes.difference(ignored_indexes_extra)
129 |
130 | return sorted(list(selected_indexes))
131 |
    def visualize(self, html_filepath: str = None, true_classes: List[str] = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        # `true_classes`, when given, must provide one label per token.
        if true_classes is not None and len(true_classes) != self.input_ids.shape[1]:
            raise ValueError(f"""The length of `true_classes` must be equal to the number of tokens""")

        score_vizs = []
        # Strip the byte-level BPE space marker so tokens display cleanly.
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        # One visualization row per non-ignored token index.
        for index in self._selected_indexes:
            pred_prob = torch.max(self.pred_probs[index])
            predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]

            # The "attribution class" column shows the token being explained.
            attr_class = tokens[index]
            if true_classes is None:
                true_class = predicted_class
            else:
                true_class = true_classes[index]

            score_vizs.append(
                self.attributions[index].visualize_attributions(
                    pred_prob,
                    predicted_class,
                    true_class,
                    attr_class,
                    tokens,
                )
            )

        html = viz.visualize_text(score_vizs)

        if html_filepath:
            # Ensure the output file always has an .html extension.
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html
176 |
177 | def _forward(
178 | self,
179 | input_ids: torch.Tensor,
180 | position_ids: torch.Tensor = None,
181 | attention_mask: torch.Tensor = None,
182 | ):
183 | if self.accepts_position_ids:
184 | preds = self.model(
185 | input_ids,
186 | position_ids=position_ids,
187 | attention_mask=attention_mask,
188 | )
189 | else:
190 | preds = self.model(input_ids, attention_mask)
191 |
192 | preds = preds.logits # preds.shape = [N_BATCH, N_TOKENS, N_CLASSES]
193 |
194 | self.pred_probs = torch.softmax(preds, dim=2)[0]
195 | return torch.softmax(preds, dim=2)[:, self.index, :]
196 |
    def _calculate_attributions(
        self,
        embeddings: Embedding,
    ) -> None:
        """Calculates LIG attributions for every selected token position of `self.text`.

        Results are stored in `self.attributions` as a dict keyed by token index;
        nothing is returned.
        """
        # Build the model input and its baseline ("reference") counterpart used by LIG.
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        pred_classes = self.predicted_class_indexes
        # Strip the BPE word-boundary marker for display purposes.
        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        ligs = {}

        for index in self._selected_indexes:
            # `self.index` is read by `_forward` to pick the token position whose
            # class probabilities are attributed in this iteration.
            self.index = index
            lig = LIGAttributions(
                self._forward,
                embeddings,
                reference_tokens,
                self.input_ids,
                self.ref_input_ids,
                self.sep_idx,
                self.attention_mask,
                target=int(pred_classes[index]),
                position_ids=self.position_ids,
                ref_position_ids=self.ref_position_ids,
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
            lig.summarize()
            ligs[index] = lig

        self.attributions = ligs
239 |
240 | def _run(
241 | self,
242 | text: str,
243 | embedding_type: int = None,
244 | ) -> dict:
245 | if embedding_type is None:
246 | embeddings = self.word_embeddings
247 | else:
248 | if embedding_type == 0:
249 | embeddings = self.word_embeddings
250 | elif embedding_type == 1:
251 | if self.accepts_position_ids and self.position_embeddings is not None:
252 | embeddings = self.position_embeddings
253 | else:
254 | warnings.warn(
255 | "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
256 | )
257 | embeddings = self.word_embeddings
258 | else:
259 | embeddings = self.word_embeddings
260 |
261 | self.text = self._clean_text(text)
262 |
263 | self._calculate_attributions(embeddings=embeddings)
264 | return self.word_attributions
265 |
266 | def __call__(
267 | self,
268 | text: str,
269 | embedding_type: int = 0,
270 | internal_batch_size: Optional[int] = None,
271 | n_steps: Optional[int] = None,
272 | ignored_indexes: Optional[List[int]] = None,
273 | ignored_labels: Optional[List[str]] = None,
274 | ) -> dict:
275 | """
276 | Args:
277 | text (str): Sentence whose NER predictions are to be explained.
278 | embedding_type (int, default = 0): Custom type of embedding.
279 | internal_batch_size (int, optional): Custom internal batch size for the attributions calculation.
280 | n_steps (int): Custom number of steps in the approximation used in the attributions calculation.
281 | ignored_indexes (List[int], optional): Indexes that are to be ignored by the explainer.
282 | ignored_labels (List[str], optional)): NER labels that are to be ignored by the explainer. The
283 | explainer will ignore those indexes whose predicted label is
284 | in `ignored_labels`.
285 | """
286 |
287 | if n_steps:
288 | self.n_steps = n_steps
289 | if internal_batch_size:
290 | self.internal_batch_size = internal_batch_size
291 |
292 | self.ignored_indexes = ignored_indexes
293 | self.ignored_labels = ignored_labels
294 |
295 | return self._run(text, embedding_type=embedding_type)
296 |
297 | def __str__(self):
298 | s = f"{self.__class__.__name__}("
299 | s += f"\n\tmodel={self.model.__class__.__name__},"
300 | s += f"\n\ttokenizer={self.tokenizer.__class__.__name__},"
301 | s += f"\n\tattribution_type='{self.attribution_type}',"
302 | s += ")"
303 |
304 | return s
305 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/zero_shot_classification.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | from captum.attr import visualization as viz
5 | from torch.nn.modules.sparse import Embedding
6 | from transformers import PreTrainedModel, PreTrainedTokenizer
7 |
8 | from transformers_interpret import LIGAttributions
9 | from transformers_interpret.errors import AttributionTypeNotSupportedError
10 |
11 | from .question_answering import QuestionAnsweringExplainer
12 | from .sequence_classification import SequenceClassificationExplainer
13 |
14 | SUPPORTED_ATTRIBUTION_TYPES = ["lig"]
15 |
16 |
class ZeroShotClassificationExplainer(SequenceClassificationExplainer, QuestionAnsweringExplainer):
    """
    Explainer for explaining attributions for models that can perform
    zero-shot classification, specifically models trained on nli downstream tasks.

    This explainer uses the same "trick" as Huggingface to achieve attributions on
    arbitrary labels provided at inference time.

    Models provided to this explainer must be nli sequence classification models
    and must have the label "entailment" or "ENTAILMENT" in
    `model.config.label2id.keys()` in order for it to work correctly.

    This explainer works by forcing the model to explain its output with respect to
    the entailment class. For each label passed at inference the explainer forms a hypothesis with each
    and calculates attributions for each hypothesis label. The label with the highest predicted probability
    can be accessed via the attribute `predicted_label`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type: str = "lig",
    ):
        """

        Args:
            model (PreTrainedModel): Pretrained huggingface Sequence Classification model. Must be a NLI model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".

        Raises:
            AttributionTypeNotSupportedError: If `attribution_type` is not in `SUPPORTED_ATTRIBUTION_TYPES`.
            ValueError: If the model's config has no "entailment"/"ENTAILMENT" label.
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.label_exists, self.entailment_key = self._entailment_label_exists()
        if not self.label_exists:
            raise ValueError('Expected label "entailment" in `model.label2id` ')

        # All attributions are taken w.r.t. the entailment class' output.
        self.entailment_idx = self.label2id[self.entailment_key]
        # When True, hypothesis tokens are kept in attributions/visualizations.
        self.include_hypothesis = False
        self.attributions = []

        self.internal_batch_size = None
        self.n_steps = 50

    @property
    def word_attributions(self) -> dict:
        "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
        if self.attributions != []:
            if self.include_hypothesis:
                return dict(
                    zip(
                        self.labels,
                        [attr.word_attributions for attr in self.attributions],
                    )
                )
            else:
                # Slice off the hypothesis tokens: only attributions for the
                # original text (up to `sep_idx`) are returned per label.
                spliced_wa = [attr.word_attributions[: self.sep_idx] for attr in self.attributions]
                return dict(zip(self.labels, spliced_wa))
        else:
            raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        # Hide the hypothesis tokens unless the caller asked for them.
        if not self.include_hypothesis:
            tokens = tokens[: self.sep_idx]

        # One row per label; the label is shown as predicted, true and
        # attribution class alike since attribution is per-hypothesis.
        score_viz = [
            self.attributions[i].visualize_attributions(  # type: ignore
                self.pred_probs[i],
                self.labels[i],
                self.labels[i],
                self.labels[i],
                tokens,
            )
            for i in range(len(self.attributions))
        ]
        html = viz.visualize_text(score_viz)

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html

    def _entailment_label_exists(self) -> Tuple[bool, str]:
        """Checks `label2id` for an entailment label.

        Returns:
            `(exists, key)` where `key` is the matching label name
            ("entailment" or "ENTAILMENT"), or `None` when absent.
        """
        if "entailment" in self.label2id.keys():
            return True, "entailment"
        elif "ENTAILMENT" in self.label2id.keys():
            return True, "ENTAILMENT"

        return False, None

    def _get_top_predicted_label_idx(self, text, hypothesis_labels: List[str]) -> int:
        """Runs one forward pass per hypothesis and returns the position (within
        `hypothesis_labels`) of the hypothesis with the highest entailment score.

        Side effect: stores the normalized entailment scores in `self.pred_probs`.
        """
        entailment_outputs = []
        for label in hypothesis_labels:
            input_ids, _, sep_idx = self._make_input_reference_pair(text, label)
            position_ids, _ = self._make_input_reference_position_id_pair(input_ids)
            token_type_ids, _ = self._make_input_reference_token_type_pair(input_ids, sep_idx)
            attention_mask = self._make_attention_mask(input_ids)
            preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
            # Sigmoid of the entailment output; scores are compared across hypotheses.
            entailment_outputs.append(float(torch.sigmoid(preds[0])[0][self.entailment_idx]))

        # Normalize so the stored probabilities sum to 1 across labels.
        normed_entailment_outputs = [float(i) / sum(entailment_outputs) for i in entailment_outputs]

        self.pred_probs = normed_entailment_outputs

        return entailment_outputs.index(max(entailment_outputs))

    def _make_input_reference_pair(
        self,
        text: str,
        hypothesis_text: str,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Builds `[CLS] text [SEP] hypothesis [SEP]` input ids plus a baseline
        ("reference") sequence where text and hypothesis tokens are replaced by
        the reference token.

        Returns:
            `(input_ids, ref_input_ids, len(text_ids))`; the last element is used
            as the separator index when slicing hypothesis tokens off elsewhere.
        """
        hyp_ids = self.encode(hypothesis_text)
        text_ids = self.encode(text)

        input_ids = [self.cls_token_id] + text_ids + [self.sep_token_id] + hyp_ids + [self.sep_token_id]

        ref_input_ids = (
            [self.cls_token_id]
            + [self.ref_token_id] * len(text_ids)
            + [self.sep_token_id]
            + [self.ref_token_id] * len(hyp_ids)
            + [self.sep_token_id]
        )

        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(text_ids),
        )

    def _forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Forward function used by LIG: returns the softmax probability of the
        class at `self.selected_index` (set to the entailment class upstream)."""
        preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)
        preds = preds[0]

        return torch.softmax(preds, dim=1)[:, self.selected_index]

    def _calculate_attributions(self, embeddings: Embedding, class_name: str, index: int = None):  # type: ignore
        """Computes LIG attributions for the current `self.text` / `self.hypothesis_text`
        pair against `class_name` (the entailment label) and appends the result to
        `self.attributions`."""
        # Model inputs plus their LIG baseline counterparts.
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text, self.hypothesis_text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        # Read by `_forward` to select the attributed class.
        self.selected_index = int(self.label2id[class_name])

        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        if self.include_hypothesis:
            lig.summarize()
        else:
            # Summarize only the text portion, excluding hypothesis tokens.
            lig.summarize(self.sep_idx)
        self.attributions.append(lig)

    def __call__(
        self,
        text: str,
        labels: List[str],
        embedding_type: int = 0,
        hypothesis_template="this text is about {} .",
        include_hypothesis: bool = False,
        internal_batch_size: int = None,
        n_steps: int = None,
    ) -> dict:
        """
        Calculates attribution for `text` using the model and
        tokenizer given in the constructor. Since `self.model` is
        a NLI type model each label in `labels` is formatted to the
        `hypothesis_template`. By default attributions are provided for all
        labels. The top predicted label can be found in the `predicted_label`
        attribute.

        Attribution is forced to be on the axis of whatever index
        the entailment class resolves to. e.g. {"entailment": 0, "neutral": 1, "contradiction": 2 }
        in the above case attributions would be for the label at index 0.

        This explainer also allows for attributions with respect to a particular embedding type.
        This can be selected by passing a `embedding_type`. The default value is `0` which
        is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
        If a model does not take position ids in its forward method (distilbert) a warning will
        occur and the default word_embeddings will be chosen instead.

        The default `hypothesis_template` can also be overridden by providing a formattable
        string which accepts exactly one formattable value for the label.

        If `include_hypothesis` is set to `True` then the word attributions and visualization
        of the attributions will also included the hypothesis text which gives a complete indication
        of what the model sees at inference.

        Args:
            text (str): Text to provide attributions for.
            labels (List[str]): The labels to classify the text to. If only one label is provided in the list then
                attributions are guaranteed to be for that label.
            embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for.
                Defaults to 0.
            hypothesis_template (str, optional): Hypothesis presented to NLI model given text.
                Defaults to "this text is about {} .".
            include_hypothesis (bool, optional): Alternative option to include hypothesis text in attributions
                and visualization. Defaults to False.
            internal_batch_size (int, optional): Divides total #steps * #examples
                data points into chunks of size at most internal_batch_size,
                which are computed (forward / backward passes)
                sequentially. If internal_batch_size is None, then all evaluations are
                processed in one batch.
            n_steps (int, optional): The number of steps used by the approximation
                method. Default: 50.
        Returns:
            list: List of tuples containing words and their associated attribution scores.
        """

        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size
        # Reset per-call state so repeated calls do not accumulate results.
        self.attributions = []
        self.pred_probs = []
        self.include_hypothesis = include_hypothesis
        self.labels = labels
        self.hypothesis_labels = [hypothesis_template.format(label) for label in labels]

        predicted_text_idx = self._get_top_predicted_label_idx(text, self.hypothesis_labels)

        # One attribution pass per label; `super().__call__` resolves (via MRO) to
        # the sequence classification explainer's call, targeting the entailment class.
        for i, _ in enumerate(self.labels):
            self.hypothesis_text = self.hypothesis_labels[i]
            self.predicted_label = labels[i] + " (" + self.entailment_key.lower() + ")"
            super().__call__(
                text,
                class_name=self.entailment_key,
                embedding_type=embedding_type,
            )
        self.predicted_label = self.labels[predicted_text_idx]
        return self.word_attributions
303 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/vision/attribution_types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, unique
2 |
3 |
@unique
class AttributionType(Enum):
    """Supported attribution algorithms: plain Integrated Gradients ("IG") or
    Integrated Gradients smoothed with a noise tunnel ("IGNT")."""

    INTEGRATED_GRADIENTS = "IG"
    INTEGRATED_GRADIENTS_NOISE_TUNNEL = "IGNT"
8 |
9 |
@unique
class NoiseTunnelType(Enum):
    """Smoothing variants passed as `nt_type` to captum's `NoiseTunnel.attribute`.

    `@unique` added for consistency with `AttributionType` and to guard against
    accidental duplicate values; member values are unchanged.
    """

    SMOOTHGRAD = "smoothgrad"
    SMOOTHGRAD_SQUARED = "smoothgrad_sq"
    VARGRAD = "vargrad"
14 |
--------------------------------------------------------------------------------
/transformers_interpret/explainers/vision/image_classification.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from enum import Enum
3 | from typing import Dict, List, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | from captum.attr import IntegratedGradients, NoiseTunnel
7 | from captum.attr import visualization as viz
8 | from PIL import Image
9 | from transformers.image_utils import ImageFeatureExtractionMixin
10 | from transformers.modeling_utils import PreTrainedModel
11 |
12 | from .attribution_types import AttributionType, NoiseTunnelType
13 |
14 |
class ImageClassificationExplainer:
    """
    This class is used to explain the output of a model on an image.

    Attributions are computed with captum's Integrated Gradients ("IG") or
    Integrated Gradients wrapped in a NoiseTunnel ("IGNT", the default) and can
    be rendered via `visualize`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        feature_extractor: ImageFeatureExtractionMixin,
        attribution_type: str = AttributionType.INTEGRATED_GRADIENTS_NOISE_TUNNEL.value,
        custom_labels: Optional[List[str]] = None,
    ):
        """
        Args:
            model (PreTrainedModel): Pretrained huggingface image classification model.
            feature_extractor (ImageFeatureExtractionMixin): Feature extractor matching the model.
            attribution_type (str, optional): "IG" or "IGNT". Defaults to "IGNT".
            custom_labels (List[str], optional): Replacement class labels; must match the
                model's `label2id` size. Defaults to None.

        Raises:
            ValueError: If `attribution_type` is unsupported or `custom_labels` has the wrong size.
        """
        self.model = model
        self.feature_extractor = feature_extractor
        if attribution_type not in [attribution.value for attribution in AttributionType]:
            raise ValueError(f"Attribution type {attribution_type} not supported.")

        self.attribution_type = attribution_type

        if custom_labels is not None:
            if len(custom_labels) != len(model.config.label2id):
                raise ValueError(
                    f"""`custom_labels` size '{len(custom_labels)}' should match pretrained model's label2id size
                    '{len(model.config.label2id)}'"""
                )

            self.id2label, self.label2id = self._get_id2label_and_label2id_dict(custom_labels)
        else:
            self.label2id = model.config.label2id
            self.id2label = model.config.id2label

        self.device = self.model.device

        # Attribution hyper-parameters; each is overridable per call in `__call__`.
        self.internal_batch_size = None
        self.n_steps = 50
        self.n_steps_noise_tunnel = 5
        self.noise_tunnel_n_samples = 10
        self.noise_tunnel_type = NoiseTunnelType.SMOOTHGRAD.value

        self.attributions = None

    def visualize(
        self,
        save_path: Union[str, None] = None,
        method: str = "overlay",
        sign: str = "all",
        outlier_threshold: float = 0.1,
        use_original_image_pixels: bool = True,
        side_by_side: bool = False,
    ):
        """Renders the attributions computed by the last `__call__`.

        Args:
            save_path (str, optional): If given, the figure is also saved to this path.
            method (str, optional): "overlay", "heatmap", "alpha_scaling" or "masked_image".
            sign (str, optional): Attribution sign to display. "all" is not valid for
                "alpha_scaling"/"masked_image" and falls back to "positive" with a warning.
            outlier_threshold (float, optional): Fraction (0-1) of outlying attribution
                values to clip before rendering.
            use_original_image_pixels (bool, optional): Plot over the resized original image
                rather than the normalized model input. Defaults to True.
            side_by_side (bool, optional): Also display the original image in a second panel.
        """
        # `outlier_threshold` is given as a fraction; captum expects a percentage.
        outlier_threshold = min(outlier_threshold * 100, 100)
        # CHW -> HWC so the attributions line up with a plottable image.
        attributions_t = np.transpose(self.attributions.squeeze(), (1, 2, 0))
        if use_original_image_pixels:
            np_image = np.asarray(
                self.feature_extractor.resize(self._image, size=(attributions_t.shape[0], attributions_t.shape[1]))
            )
        else:
            # uses the normalized image pixels which is what the model sees, but can be hard to interpret visually
            np_image = np.transpose(self.inputs.squeeze(), (1, 2, 0))

        if sign == "all" and method in ["alpha_scaling", "masked_image"]:
            warnings.warn(
                "sign='all' is not supported for method='alpha_scaling' or method='masked_image'. "
                "Please use sign='positive', sign='negative', or sign='absolute'. "
                "Changing sign to default 'positive'."
            )
            sign = "positive"

        visualizer = ImageAttributionVisualizer(
            attributions=attributions_t,
            pixel_values=np_image,
            outlier_threshold=outlier_threshold,
            pred_class=self.id2label[self.predicted_index],
            visualization_method=method,
            sign=sign,
            side_by_side=side_by_side,
        )

        viz_result = visualizer()
        if save_path:
            viz_result[0].savefig(save_path)

        return viz_result

    def _forward_func(self, inputs):
        """Model forward returning raw logits; used as captum's attribution target."""
        outputs = self.model(inputs)
        return outputs["logits"]

    def _calculate_attributions(self, class_name: Union[int, None], index: Union[int, None]) -> np.ndarray:
        """Calculates attributions for the already-extracted `self.inputs`.

        Target selection priority: explicit `index` > `class_name` > the model's
        predicted class. (The previous implementation ignored `class_name`
        whenever `index` was None — the `else` branch overwrote the selection —
        and treated `index=0` as "not given".)
        """
        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            self.selected_index = self.label2id[class_name]
        else:
            self.selected_index = self.predicted_index

        if self.attribution_type == AttributionType.INTEGRATED_GRADIENTS.value:
            ig = IntegratedGradients(self._forward_func)
            self.attributions, self.delta = ig.attribute(
                self.inputs,
                target=self.selected_index,
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
                return_convergence_delta=True,
            )
            self.delta = self.delta.cpu().detach().numpy()
        elif self.attribution_type == AttributionType.INTEGRATED_GRADIENTS_NOISE_TUNNEL.value:
            ig_nt = IntegratedGradients(self._forward_func)
            nt = NoiseTunnel(ig_nt)
            self.attributions = nt.attribute(
                self.inputs,
                nt_samples=self.noise_tunnel_n_samples,
                nt_type=self.noise_tunnel_type,
                target=self.selected_index,
                n_steps=self.n_steps_noise_tunnel,
            )

        # Move everything to numpy for visualization/return.
        self.inputs = self.inputs.cpu().detach().numpy()
        self.attributions = self.attributions.cpu().detach().numpy()
        return self.attributions

    @staticmethod
    def _get_id2label_and_label2id_dict(
        labels: List[str],
    ) -> Tuple[Dict[int, str], Dict[str, int]]:
        """Builds forward and reverse label lookup dicts from an ordered label list."""
        id2label: Dict[int, str] = dict()
        label2id: Dict[str, int] = dict()
        for idx, label in enumerate(labels):
            id2label[idx] = label
            label2id[label] = idx

        return id2label, label2id

    def __call__(
        self,
        image: Image,
        index: int = None,
        class_name: str = None,
        internal_batch_size: Union[int, None] = None,
        n_steps: Union[int, None] = None,
        n_steps_noise_tunnel: Union[int, None] = None,
        noise_tunnel_n_samples: Union[int, None] = None,
        noise_tunnel_type: str = NoiseTunnelType.SMOOTHGRAD.value,
    ):
        """Extracts features from `image`, runs the model, and returns attributions.

        Args:
            image (Image): PIL image to explain.
            index (int, optional): Explicit target class index; takes priority over `class_name`.
            class_name (str, optional): Target class label (looked up in `label2id`).
            internal_batch_size (int, optional): Chunk size for the IG computation.
            n_steps (int, optional): IG approximation steps.
            n_steps_noise_tunnel (int, optional): IG steps when using the noise tunnel.
            noise_tunnel_n_samples (int, optional): Number of noise tunnel samples.
            noise_tunnel_type (str, optional): One of the `NoiseTunnelType` values.
                Defaults to "smoothgrad".

        Raises:
            ValueError: If `noise_tunnel_type` is not a valid `NoiseTunnelType` value.
        """
        self._image: Image = image
        try:
            self.noise_tunnel_type = NoiseTunnelType(noise_tunnel_type).value
        except ValueError:
            raise ValueError(f"noise_tunnel_type must be one of {NoiseTunnelType.__members__}")

        self.inputs = self.feature_extractor(image, return_tensors="pt").to(self.device)["pixel_values"]
        self.predicted_index = self.model(self.inputs).logits.argmax().item()

        # Truthy overrides only: caller-supplied values replace the defaults.
        if n_steps:
            self.n_steps = n_steps
        if n_steps_noise_tunnel:
            self.n_steps_noise_tunnel = n_steps_noise_tunnel
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size
        if noise_tunnel_n_samples:
            self.noise_tunnel_n_samples = noise_tunnel_n_samples

        return self._calculate_attributions(class_name, index)
180 |
181 |
class VisualizationMethods(Enum):
    """Valid `method` values accepted by `ImageAttributionVisualizer`."""

    HEATMAP = "heatmap"
    OVERLAY = "overlay"
    ALPHA_SCALING = "alpha_scaling"
    MASKED_IMAGE = "masked_image"
187 |
188 |
class SignType(Enum):
    """Valid `sign` values accepted by `ImageAttributionVisualizer`; "absolute"
    is translated to captum's "absolute_value" at render time."""

    ALL = "all"
    POSITIVE = "positive"
    NEGATIVE = "negative"
    ABSOLUTE = "absolute"
194 |
195 |
class ImageAttributionVisualizer:
    """Renders image attributions with captum's visualization utilities.

    Supports four rendering styles (heatmap, overlay, alpha scaling, masked
    image), each with an optional side-by-side original-image panel. The four
    public plot methods previously duplicated ~25 lines each; they now share
    the `_render` helper, parameterized only by the captum method name and the
    plot titles (titles preserved byte-for-byte).
    """

    def __init__(
        self,
        attributions: np.ndarray,
        pixel_values: np.ndarray,
        outlier_threshold: float,
        pred_class: str,
        sign: str,
        visualization_method: str,
        side_by_side: bool = False,
    ):
        """
        Args:
            attributions (np.ndarray): Attribution scores, HWC layout matching `pixel_values`.
            pixel_values (np.ndarray): Image pixels rendered under/beside the attributions.
            outlier_threshold (float): Percentage of outlying attribution values to clip.
            pred_class (str): Human-readable predicted class, used in plot titles.
            sign (str): One of the `SignType` values.
            visualization_method (str): One of the `VisualizationMethods` values.
            side_by_side (bool, optional): Also show the original image in a second panel.

        Raises:
            ValueError: If `visualization_method` or `sign` is not a recognized value.
        """
        self.attributions = attributions
        self.pixel_values = pixel_values
        self.outlier_threshold = outlier_threshold
        self.pred_class = pred_class
        # Render pyplot figures directly only when running under IPython/Jupyter.
        self.render_pyplot = self._using_ipython()
        self.side_by_side = side_by_side
        try:
            self.visualization_method = VisualizationMethods(visualization_method)
        except ValueError:
            raise ValueError(
                f"""`visualization_method` must be one of the following: {list(VisualizationMethods.__members__.keys())}"""
            )

        try:
            self.sign = SignType(sign)
        except ValueError:
            raise ValueError(f"""`sign` must be one of the following: {list(SignType.__members__.keys())}""")

        # Dispatch table replaces the previous if/elif chain; all members covered.
        self.plot_function = {
            VisualizationMethods.HEATMAP: self.heatmap,
            VisualizationMethods.OVERLAY: self.overlay,
            VisualizationMethods.ALPHA_SCALING: self.alpha_scaling,
            VisualizationMethods.MASKED_IMAGE: self.masked_image,
        }[self.visualization_method]

    def _captum_sign(self) -> str:
        # captum names the absolute variant "absolute_value"; our public API uses "absolute".
        return "absolute_value" if self.sign.value == "absolute" else self.sign.value

    def _render(self, method: str, sbs_title: str, single_title: str):
        """Shared renderer: one captum call parameterized by method name and titles."""
        sign = self._captum_sign()
        if self.side_by_side:
            return viz.visualize_image_attr_multiple(
                attr=self.attributions,
                original_image=self.pixel_values,
                methods=["original_image", method],
                signs=["all", sign],
                show_colorbar=True,
                use_pyplot=self.render_pyplot,
                outlier_perc=self.outlier_threshold,
                titles=["Original Image", sbs_title],
            )
        return viz.visualize_image_attr(
            original_image=self.pixel_values,
            attr=self.attributions,
            sign=sign,
            method=method,
            show_colorbar=True,
            outlier_perc=self.outlier_threshold,
            title=single_title,
            use_pyplot=self.render_pyplot,
        )

    def overlay(self):
        """Blended attribution heatmap over the original image."""
        return self._render(
            "blended_heat_map",
            f"Heatmap overlay IG. Prediction: {self.pred_class}",
            f"Heatmap Overlay IG. Prediction - {self.pred_class}",
        )

    def heatmap(self):
        """Standalone attribution heatmap."""
        return self._render(
            "heat_map",
            f"Heatmap IG. Prediction: {self.pred_class}",
            f"Heatmap IG. Prediction - {self.pred_class}",
        )

    def alpha_scaling(self):
        """Original pixels with alpha scaled by attribution strength."""
        return self._render(
            "alpha_scaling",
            f"Alpha Scaled IG. Prediction: {self.pred_class}",
            f"Alpha Scaling IG. Prediction - {self.pred_class}",
        )

    def masked_image(self):
        """Original image masked by the attribution values."""
        return self._render(
            "masked_image",
            f"Masked Image IG. Prediction: {self.pred_class}",
            f"Masked Image IG. Prediction - {self.pred_class}",
        )

    def _using_ipython(self) -> bool:
        # `__IPYTHON__` is injected into the namespace by IPython; probing it via
        # eval avoids importing IPython just for this check.
        try:
            eval("__IPYTHON__")
        except NameError:
            return False
        else:  # pragma: no cover
            return True

    def __call__(self):
        return self.plot_function()
337 |
--------------------------------------------------------------------------------