├── .DS_Store ├── .coveragerc ├── .gitattributes ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── unit_tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── images ├── .DS_Store ├── bert_ner_explainer.png ├── bert_qa_explainer.png ├── coverage.svg ├── distilbert_example.png ├── distilbert_example_negative.png ├── logotype@1920x.png ├── logotype@1920x_transparent.png ├── multilabel_example.png ├── pairwise_cross_encoder_example.png ├── tight@1920x.png ├── tight@1920x_transparent.png ├── transformers_interpret_A01.ai ├── vision │ ├── alpha_scaling_sbs.png │ ├── heatmap_sbs.png │ ├── masked_image_sbs.png │ └── overlay_sbs.png ├── zero_shot_example.png └── zero_shot_example2.png ├── notebooks ├── .DS_Store ├── image_classification_explainer.ipynb ├── multiclass_classification_example.ipynb └── ner_example.ipynb ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── setup.py ├── test ├── .DS_Store ├── .gitkeep ├── text │ ├── test_explainer.py │ ├── test_multilabel_classification_explainer.py │ ├── test_question_answering_explainer.py │ ├── test_sequence_classification_explainer.py │ ├── test_token_classification_explainer.py │ └── test_zero_shot_explainer.py └── vision │ └── test_image_classification.py └── transformers_interpret ├── .DS_Store ├── __init__.py ├── attributions.py ├── errors.py ├── explainer.py └── explainers ├── text ├── __init__.py ├── multilabel_classification.py ├── question_answering.py ├── sequence_classification.py ├── token_classification.py └── zero_shot_classification.py └── vision ├── attribution_types.py └── image_classification.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/.DS_Store -------------------------------------------------------------------------------- /.coveragerc: 
-------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | # Point at this file with PYCHARM_COVERAGERC=/PATH/TO/THIS/.coveragerc 3 | # May require reading PYCHARM_COVERAGERC in run_coverage.py in intellij helpers 4 | 5 | [run] 6 | branch = True 7 | 8 | omit = 9 | */runners/* 10 | *docs* 11 | *stubs* 12 | *examples* 13 | *tests* 14 | */config/plugins/python/helpers/* 15 | 16 | [report] 17 | omit = 18 | */runners/* 19 | *docs* 20 | *stubs* 21 | *examples* 22 | *test* 23 | 24 | 25 | exclude_lines = 26 | 27 | # Don't complain if non-runnable code isn't run: 28 | if __name__ == .__main__.: 29 | 30 | raise NotImplementedError() 31 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # PR Description 4 | 5 | 6 | 7 | 8 | ## Motivation and Context 9 | 10 | 11 | 12 | References issue: # (ISSUE) 13 | 14 | ## Tests and Coverage 15 | 16 | 17 | 18 | ## Types of changes 19 | 20 | 21 | 22 | - [ ] Bug fix (non-breaking change which fixes an issue) 23 | - [ ] New feature (non-breaking change which adds functionality) 24 | - [ ] Docs (Added to or improved Transformers Interpret's documentation) 25 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 26 | 27 | ## Final Checklist: 28 | 29 | 30 | 31 | - [ ] My code follows the [code style](https://github.com/cdpierse/transformers-interpret/blob/master/CONTRIBUTING.md) of this project. 32 | - [ ] I have updated the documentation accordingly. 33 | - [ ] I have added tests to cover my changes. 
34 | - [ ] All new and existing tests passed. 35 | -------------------------------------------------------------------------------- /.github/workflows/unit_tests.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Unit Tests 3 | on: push 4 | jobs: 5 | tests: 6 | runs-on: ubuntu-20.04 7 | steps: 8 | - name: Checkout 9 | uses: actions/checkout@v1 10 | 11 | - name: Set up Python 3.7.13 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.7.13 15 | 16 | - name: Install poetry 17 | run: | 18 | which python 19 | which pip 20 | pip install poetry 21 | 22 | - name: Install Python dependencies 23 | if: steps.cache-poetry.outputs.cache-hit != 'true' 24 | run: | 25 | poetry install 26 | 27 | - name: Run Unit tests 28 | 29 | run: | 30 | export PATH="$HOME/.pyenv/bin:$PATH" 31 | export PYTHONPATH="." 32 | 33 | poetry run pytest -s --cov=transformers_interpret/ --cov-report term-missing \ 34 | test 35 | 36 | - name: Report coverage 37 | run: | 38 | export PATH="$HOME/.pyenv/bin:$PATH" 39 | poetry run coverage report --fail-under=50 40 | poetry run coverage html -d unit_htmlcov 41 | 42 | - uses: actions/upload-artifact@v2 43 | with: 44 | name: ti-unit-coverage 45 | path: unit_htmlcov/ 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller
builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | references.md 127 | test.html 128 | api.py 129 | .vscode 130 | notebooks/explore_captum.ipynb 131 | *.html 132 | 133 | # Pycharm 134 | .idea/ 135 | release_script.py 136 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 22.3.0 10 | hooks: 11 | - id: black 12 | args: [--line-length=120, --target-version=py38] 13 | - repo: https://github.com/PyCQA/isort 14 | rev: 5.10.1 15 | hooks: 16 | - id: isort 17 | args: [-m, '3', --tc, --profile, black] 18 | - repo: https://github.com/myint/autoflake 19 | rev: v1.4 20 | hooks: 21 | - id: autoflake 22 | args: [--in-place, --remove-all-unused-imports] 23 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Pierse" 5 | given-names: "Charles" 6 | title: "Transformers Interpret" 7 | version: 0.5.2 8 | date-released: 2021-02-14 9 | url: "https://github.com/cdpierse/transformers-interpret" 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # 👐 Contributing to Transformers Interpret 2 | 3 | First off, thank you for even considering contributing to this package, every contribution big or small is greatly appreciated. Community contributions are what keep projects like this fueled and constantly improving, so a big thanks to you! 4 | 5 | Below are some sections detailing the guidelines we'd like you to follow to make your contribution as seamless as possible. 6 | 7 | - [Code of Conduct](#coc) 8 | - [Asking a Question and Discussions](#question) 9 | - [Issues, Bugs, and Feature Requests](#issue) 10 | - [Submission Guidelines](#submit) 11 | - [Code Style and Formatting](#code) 12 | 13 | ## 📜 Code of Conduct 14 | 15 | As contributors and maintainers of Transformers Interpret, we pledge to respect everyone who contributes by posting issues, updating documentation, submitting pull requests, providing feedback in comments, and any other activities. 16 | 17 | Communication within our community must be constructive and never resort to personal attacks, trolling, public or private harassment, insults, or other unprofessional conduct. 18 | 19 | We promise to extend courtesy and respect to everyone involved in this project regardless of gender, gender identity, sexual orientation, disability, age, race, ethnicity, religion, or level of experience. We expect anyone contributing to the project to do the same. 20 | 21 | If any member of the community violates this code of conduct, the maintainers of Transformers Interpret may take action, removing issues, comments, and PRs or blocking accounts as deemed appropriate. 
22 | 23 | If you are subject to or witness unacceptable behavior, or have any other concerns, please drop us a line at [charlespierse@gmail.com](mailto://charlespierse@gmail.com) 24 | 25 | ## 🗣️ Got a Question ? Want to start a Discussion ? 26 | 27 | We would like to use [Github discussions](https://github.com/cdpierse/transformers-interpret/discussions) as the central hub for all community discussions, questions, and everything else in between. While Github discussions is a new service (as of 2021) we believe that it really helps keep this repo as one single source to find all relevant information. Our hope is that this page functions as a record of all the conversations that help contribute to the project's development. 28 | 29 | We also highly encourage general and theoretical discussion. Much of the Transformers Interpret package implements algorithms in the field of explainable AI (XAI) which is itself very nascent, so discussions around the topic are very welcome. Everyone has something to learn and to teach ! 30 | 31 | If you are new to [Github discussions](https://github.com/cdpierse/transformers-interpret/discussions) it is a very similar experience to Stack Overflow with an added element of general discussion and discourse rather than solely being question and answer based. 32 | 33 | ## 🪲 Found an Issue or Bug? 34 | 35 | If you find a bug in the source code, you can help us by submitting an issue to our [Issues Page](https://github.com/cdpierse/transformers-interpret/issues). Even better you can submit a Pull Request if you have a fix. 36 | 37 | See [below](#submit) for some guidelines. 38 | 39 | ## ✉️ Submission Guidelines 40 | 41 | ### Submitting an Issue 42 | 43 | Before you submit your issue search the archive, maybe your question was already answered. 
If you feel like your issue is not specific and more of a general question about a design decision, or algorithm implementation maybe start a [discussion](https://github.com/cdpierse/transformers-interpret/discussions) instead, this helps keep the issues less cluttered and encourages more open ended conversation. 44 | 45 | If your issue appears to be a bug, and hasn't been reported, open a new issue. 46 | Help us to maximize the effort we can spend fixing issues and adding new 47 | features, by not reporting duplicate issues. Providing the following information will increase the 48 | chances of your issue being dealt with quickly: 49 | 50 | - **Describe the bug** - A clear and concise description of what the bug is. 51 | - **To Reproduce**- Steps to reproduce the behavior. 52 | - **Expected behavior** - A clear and concise description of what you expected to happen. 53 | - **Environment** 54 | - Transformers Interpret version 55 | - Python version 56 | - OS 57 | - **Suggest a Fix** - if you can't fix the bug yourself, perhaps you can point to what might be 58 | causing the problem (line of code or commit) 59 | 60 | When you submit a PR you will be presented with a PR template, please fill this in as best you can. 61 | 62 | ### Submitting a Pull Request 63 | 64 | Before you submit your pull request consider the following guidelines: 65 | 66 | - Search [GitHub](https://github.com/cdpierse/transformers-interpret/pulls) for an open or closed Pull Request 67 | that relates to your submission. You don't want to duplicate effort. 
68 | - Create a fork of a repository and set up a remote that points to the original project: 69 | 70 | ```shell 71 | git remote add upstream git@github.com:cdpierse/transformers-interpret.git 72 | ``` 73 | 74 | - Make your changes in a new git branch, based off master branch: 75 | 76 | ```shell 77 | git checkout -b my-fix-branch master 78 | ``` 79 | 80 | - There are three typical branch name conventions we try to stick to, there are of course always exceptions to the rule but generally the branch name format is one of: 81 | 82 | - **fix/my-branch-name** - this is for fixes or patches 83 | - **docs/my-branch-name** - this is for a contribution to some form of documentation i.e. Readme etc 84 | - **feature/my-branch-name** - this is for new features or additions to the code base 85 | 86 | - Create your patch, **including appropriate test cases**. 87 | - **A note on tests**: This package is in an odd position in that it has some heavy external dependencies on the Transformers package, Pytorch, and Captum. In order to adequately test all the features of an explainer it must often be tested alongside a model, we do our best to cover most models and edge cases however this isn't always possible given that some models have their own quirks in implementations that break convention. We don't expect 100% test coverage but generally we'd like to keep it > 90%. 88 | - Follow our [Coding Rules](#rules). 89 | - Avoid checking in files that shouldn't be tracked (e.g `dist`, `build`, `.tmp`, `.idea`). We recommend using a [global](#global-gitignore) gitignore for this. 90 | - Before you commit please run the test suite and coverage and make sure all tests are passing. 91 | 92 | ```shell 93 | coverage run -m pytest -s -v && coverage report -m 94 | ``` 95 | - Format your code appropriately: 96 | * This package uses [black](https://black.readthedocs.io/en/stable/) as its formatter. In order to format your code with black run ```black . ``` from the root of the package.
97 | 98 | - Commit your changes using a descriptive commit message. 99 | 100 | ```shell 101 | git commit -a 102 | ``` 103 | 104 | Note: the optional commit `-a` command line option will automatically "add" and "rm" edited files. 105 | 106 | * Push your branch to GitHub: 107 | 108 | ```shell 109 | git push origin fix/my-fix-branch 110 | ``` 111 | 112 | * In GitHub, send a pull request to `transformers-interpret:dev`. 113 | * If we suggest changes then: 114 | 115 | - Make the required updates. 116 | - Rebase your branch and force push to your GitHub repository (this will update your Pull Request): 117 | 118 | ```shell 119 | git rebase master -i 120 | git push origin fix/my-fix-branch -f 121 | ``` 122 | 123 | That's it! Thank you for your contribution! 124 | 125 | 126 | ### After your pull request is merged 127 | 128 | After your pull request is merged, you can safely delete your branch and pull the changes 129 | from the main (upstream) repository: 130 | 131 | - Delete the remote branch on GitHub either through the GitHub web UI or your local shell as follows: 132 | 133 | ```shell 134 | git push origin --delete my-fix-branch 135 | ``` 136 | 137 | - Check out the master branch: 138 | 139 | ```shell 140 | git checkout master -f 141 | ``` 142 | 143 | - Delete the local branch: 144 | 145 | ```shell 146 | git branch -D my-fix-branch 147 | ``` 148 | 149 | - Update your master with the latest upstream version: 150 | 151 | ```shell 152 | git pull --ff upstream master 153 | ``` 154 | 155 | ## ✅ Coding Rules 156 | 157 | We generally follow the [Google Python style guide](http://google.github.io/styleguide/pyguide.html). 
158 | 159 | ---- 160 | 161 | *This guide was inspired by the [Firebase Web Quickstarts contribution guidelines](https://github.com/firebase/quickstart-js/blob/master/CONTRIBUTING.md) and [Databay](https://github.com/Voyz/databay/edit/master/CONTRIBUTING.md)* 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/.DS_Store -------------------------------------------------------------------------------- /images/bert_ner_explainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/bert_ner_explainer.png -------------------------------------------------------------------------------- /images/bert_qa_explainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/bert_qa_explainer.png -------------------------------------------------------------------------------- /images/coverage.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 92% 19 | 92% 20 | 21 | 22 | -------------------------------------------------------------------------------- /images/distilbert_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/distilbert_example.png -------------------------------------------------------------------------------- /images/distilbert_example_negative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/distilbert_example_negative.png -------------------------------------------------------------------------------- /images/logotype@1920x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/logotype@1920x.png -------------------------------------------------------------------------------- /images/logotype@1920x_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/logotype@1920x_transparent.png -------------------------------------------------------------------------------- /images/multilabel_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/multilabel_example.png -------------------------------------------------------------------------------- 
/images/pairwise_cross_encoder_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/pairwise_cross_encoder_example.png -------------------------------------------------------------------------------- /images/tight@1920x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/tight@1920x.png -------------------------------------------------------------------------------- /images/tight@1920x_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/tight@1920x_transparent.png -------------------------------------------------------------------------------- /images/transformers_interpret_A01.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/transformers_interpret_A01.ai -------------------------------------------------------------------------------- /images/vision/alpha_scaling_sbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/alpha_scaling_sbs.png -------------------------------------------------------------------------------- /images/vision/heatmap_sbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/heatmap_sbs.png 
-------------------------------------------------------------------------------- /images/vision/masked_image_sbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/masked_image_sbs.png -------------------------------------------------------------------------------- /images/vision/overlay_sbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/vision/overlay_sbs.png -------------------------------------------------------------------------------- /images/zero_shot_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/zero_shot_example.png -------------------------------------------------------------------------------- /images/zero_shot_example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/images/zero_shot_example2.png -------------------------------------------------------------------------------- /notebooks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/notebooks/.DS_Store -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "transformers-interpret" 3 | version = "0.10.0" 4 | description = "Model explainability that works seamlessly with 🤗 transformers. 
Explain your transformers model in just 2 lines of code." 5 | readme = "README.md" 6 | classifiers = [ 7 | "Development Status :: 3 - Alpha", 8 | "Intended Audience :: Developers", 9 | "Programming Language :: Python :: 3", 10 | "Programming Language :: Python :: 3.6", 11 | "Programming Language :: Python :: 3.7", 12 | "Programming Language :: Python :: 3.8", 13 | "Programming Language :: Python :: 3.9", 14 | "Programming Language :: Python :: 3 :: Only", 15 | "Topic :: Software Development :: Libraries :: Python Modules", 16 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 17 | "Topic :: Scientific/Engineering :: Information Analysis", 18 | "Topic :: Text Processing :: Linguistic", 19 | ] 20 | authors = ["Charles Pierse "] 21 | 22 | [tool.poetry.dependencies] 23 | python = ">=3.7,<4.0" 24 | captum = ">=0.3.1" 25 | transformers = ">=3.0.0" 26 | ipython = "^7.31.1" 27 | 28 | [tool.poetry.dev-dependencies] 29 | pre-commit = "^2.19.0" 30 | twine = "^4.0.1" 31 | pytest = "^5.4.2" 32 | black = {version = "^22.6.0", allow-prereleases = true} 33 | pytest-cov = "^3.0.0" 34 | jupyterlab = "^3.4.5" 35 | 36 | 37 | [build-system] 38 | requires = ["poetry-core>=1.0.0"] 39 | build-backend = "poetry.core.masonry.api" 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==5.4.2 2 | captum==0.4.1 3 | transformers==4.15.0 4 | ipython==7.31.1 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | from setuptools import find_packages 4 | 5 | with open("README.md", "r", encoding="utf-8") as fh: 6 | long_description = fh.read() 7 | 8 | 9 | setup( 10 | name="transformers-interpret", 11 | packages=find_packages( 12 | exclude=[ 13 | "*.tests", 14 | "*.tests.*", 15 | 
"tests.*", 16 | "tests", 17 | "examples", 18 | "docs", 19 | "out", 20 | "dist", 21 | "media", 22 | "test", 23 | ] 24 | ), 25 | version="0.9.6", 26 | license="Apache-2.0", 27 | description="Transformers Interpret is a model explainability tool designed to work exclusively with 🤗 transformers.", 28 | long_description=long_description, 29 | long_description_content_type="text/markdown", 30 | author="Charles Pierse", 31 | author_email="charlespierse@gmail.com", 32 | url="https://github.com/cdpierse/transformers-interpret", 33 | keywords=[ 34 | "machine learning", 35 | "natural language proessing", 36 | "explainability", 37 | "transformers", 38 | "model interpretability", 39 | ], 40 | install_requires=["transformers>=3.0.0", "captum>=0.3.1"], 41 | classifiers=[ 42 | "Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package 43 | "Intended Audience :: Developers", 44 | "Topic :: Software Development :: Libraries :: Python Modules", 45 | "Programming Language :: Python :: 3.8", 46 | ], 47 | ) 48 | -------------------------------------------------------------------------------- /test/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/test/.DS_Store -------------------------------------------------------------------------------- /test/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/test/.gitkeep -------------------------------------------------------------------------------- /test/text/test_explainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from transformers import ( 4 | AutoModelForCausalLM, 5 | 
AutoModelForMaskedLM, 6 | AutoModelForPreTraining, 7 | AutoTokenizer, 8 | PreTrainedModel, 9 | PreTrainedTokenizer, 10 | PreTrainedTokenizerFast, 11 | ) 12 | 13 | from transformers_interpret import BaseExplainer 14 | 15 | DISTILBERT_MODEL = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased") 16 | DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased") 17 | 18 | GPT2_MODEL = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2") 19 | GPT2_TOKENIZER = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2") 20 | 21 | BERT_MODEL = AutoModelForPreTraining.from_pretrained("lysandre/tiny-bert-random") 22 | BERT_TOKENIZER = AutoTokenizer.from_pretrained("lysandre/tiny-bert-random") 23 | 24 | 25 | class DummyExplainer(BaseExplainer): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | 29 | def encode(self, text: str = None): 30 | return self.tokenizer.encode(text, add_special_tokens=False) 31 | 32 | def decode(self, input_ids): 33 | return self.tokenizer.convert_ids_to_tokens(input_ids[0]) 34 | 35 | @property 36 | def word_attributions(self): 37 | pass 38 | 39 | def _run(self): 40 | pass 41 | 42 | def _calculate_attributions(self): 43 | pass 44 | 45 | def _forward(self): 46 | pass 47 | 48 | 49 | def test_explainer_init_distilbert(): 50 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 51 | assert isinstance(explainer.model, PreTrainedModel) 52 | assert isinstance(explainer.tokenizer, PreTrainedTokenizerFast) | isinstance( 53 | explainer.tokenizer, PreTrainedTokenizer 54 | ) 55 | assert explainer.model_prefix == DISTILBERT_MODEL.base_model_prefix 56 | assert explainer.device == DISTILBERT_MODEL.device 57 | 58 | assert explainer.accepts_position_ids is False 59 | assert explainer.accepts_token_type_ids is False 60 | 61 | assert explainer.model.config.model_type == "distilbert" 62 | assert explainer.position_embeddings is not None 63 | assert explainer.word_embeddings is not None 64 | 
assert explainer.token_type_embeddings is None 65 | 66 | 67 | def test_explainer_init_bert(): 68 | explainer = DummyExplainer(BERT_MODEL, BERT_TOKENIZER) 69 | assert isinstance(explainer.model, PreTrainedModel) 70 | assert isinstance(explainer.tokenizer, PreTrainedTokenizerFast) | isinstance( 71 | explainer.tokenizer, PreTrainedTokenizer 72 | ) 73 | assert explainer.model_prefix == BERT_MODEL.base_model_prefix 74 | assert explainer.device == BERT_MODEL.device 75 | 76 | assert explainer.accepts_position_ids is True 77 | assert explainer.accepts_token_type_ids is True 78 | 79 | assert explainer.model.config.model_type == "bert" 80 | assert explainer.position_embeddings is not None 81 | assert explainer.word_embeddings is not None 82 | assert explainer.token_type_embeddings is not None 83 | 84 | 85 | def test_explainer_init_gpt2(): 86 | explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER) 87 | assert isinstance(explainer.model, PreTrainedModel) 88 | assert isinstance(explainer.tokenizer, PreTrainedTokenizerFast) | isinstance( 89 | explainer.tokenizer, PreTrainedTokenizer 90 | ) 91 | assert explainer.model_prefix == GPT2_MODEL.base_model_prefix 92 | assert explainer.device == GPT2_MODEL.device 93 | 94 | assert explainer.accepts_position_ids is True 95 | assert explainer.accepts_token_type_ids is True 96 | 97 | assert explainer.model.config.model_type == "gpt2" 98 | assert explainer.position_embeddings is not None 99 | assert explainer.word_embeddings is not None 100 | 101 | 102 | def test_explainer_init_cpu(): 103 | old_device = DISTILBERT_MODEL.device 104 | try: 105 | DISTILBERT_MODEL.to("cpu") 106 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 107 | assert explainer.device.type == "cpu" 108 | finally: 109 | DISTILBERT_MODEL.to(old_device) 110 | 111 | 112 | def test_explainer_init_cuda(): 113 | if not torch.cuda.is_available(): 114 | print("Cuda device not available to test. 
Skipping.") 115 | else: 116 | old_device = DISTILBERT_MODEL.device 117 | try: 118 | DISTILBERT_MODEL.to("cuda") 119 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 120 | assert explainer.device.type == "cuda" 121 | finally: 122 | DISTILBERT_MODEL.to(old_device) 123 | 124 | 125 | def test_explainer_make_input_reference_pair(): 126 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 127 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 128 | assert isinstance(input_ids, Tensor) 129 | assert isinstance(ref_input_ids, Tensor) 130 | assert isinstance(len_inputs, int) 131 | 132 | assert len(input_ids[0]) == len(ref_input_ids[0]) == (len_inputs + 2) 133 | assert ref_input_ids[0][0] == input_ids[0][0] 134 | assert ref_input_ids[0][-1] == input_ids[0][-1] 135 | assert ref_input_ids[0][0] == explainer.cls_token_id 136 | assert ref_input_ids[0][-1] == explainer.sep_token_id 137 | 138 | 139 | def test_explainer_make_input_reference_pair_gpt2(): 140 | explainer = DummyExplainer(GPT2_MODEL, GPT2_TOKENIZER) 141 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 142 | assert isinstance(input_ids, Tensor) 143 | assert isinstance(ref_input_ids, Tensor) 144 | assert isinstance(len_inputs, int) 145 | 146 | assert len(input_ids[0]) == len(ref_input_ids[0]) == (len_inputs) 147 | 148 | 149 | def test_explainer_make_input_token_type_pair_no_sep_idx(): 150 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 151 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 152 | ( 153 | token_type_ids, 154 | ref_token_type_ids, 155 | ) = explainer._make_input_reference_token_type_pair(input_ids) 156 | 157 | assert ref_token_type_ids[0][0] == torch.zeros(len(input_ids[0]))[0] 158 | for i, val in enumerate(token_type_ids[0]): 159 | if i == 0: 160 | assert val == 0 161 | else: 162 | assert val == 
1 163 | 164 | 165 | def test_explainer_make_input_token_type_pair_sep_idx(): 166 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 167 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 168 | ( 169 | token_type_ids, 170 | ref_token_type_ids, 171 | ) = explainer._make_input_reference_token_type_pair(input_ids, 3) 172 | 173 | assert ref_token_type_ids[0][0] == torch.zeros(len(input_ids[0]))[0] 174 | for i, val in enumerate(token_type_ids[0]): 175 | if i <= 3: 176 | assert val == 0 177 | else: 178 | assert val == 1 179 | 180 | 181 | def test_explainer_make_input_reference_position_id_pair(): 182 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 183 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 184 | position_ids, ref_position_ids = explainer._make_input_reference_position_id_pair(input_ids) 185 | 186 | assert ref_position_ids[0][0] == torch.zeros(len(input_ids[0]))[0] 187 | for i, val in enumerate(position_ids[0]): 188 | assert val == i 189 | 190 | 191 | def test_explainer_make_attention_mask(): 192 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 193 | input_ids, ref_input_ids, len_inputs = explainer._make_input_reference_pair("this is a test string") 194 | attention_mask = explainer._make_attention_mask(input_ids) 195 | assert len(attention_mask[0]) == len(input_ids[0]) 196 | for i, val in enumerate(attention_mask[0]): 197 | assert val == 1 198 | 199 | 200 | def test_explainer_str(): 201 | explainer = DummyExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 202 | s = "DummyExplainer(" 203 | s += f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__}," 204 | s += f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__}" 205 | s += ")" 206 | assert s == explainer.__str__() 207 | -------------------------------------------------------------------------------- 
/test/text/test_multilabel_classification_explainer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 3 | 4 | from transformers_interpret import MultiLabelClassificationExplainer 5 | 6 | DISTILBERT_MODEL = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") 7 | DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") 8 | 9 | BERT_MODEL = AutoModelForSequenceClassification.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification") 10 | BERT_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification") 11 | 12 | 13 | def test_multilabel_classification_explainer_init_distilbert(): 14 | seq_explainer = MultiLabelClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 15 | assert seq_explainer.attribution_type == "lig" 16 | assert seq_explainer.label2id == DISTILBERT_MODEL.config.label2id 17 | assert seq_explainer.id2label == DISTILBERT_MODEL.config.id2label 18 | assert seq_explainer.attributions is None 19 | assert seq_explainer.labels == [] 20 | 21 | 22 | def test_multilabel_classification_explainer_init_bert(): 23 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 24 | assert seq_explainer.attribution_type == "lig" 25 | assert seq_explainer.label2id == BERT_MODEL.config.label2id 26 | assert seq_explainer.id2label == BERT_MODEL.config.id2label 27 | assert seq_explainer.attributions is None 28 | assert seq_explainer.labels == [] 29 | 30 | 31 | def test_multilabel_classification_explainer_word_attributes_is_dict(): 32 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 33 | wa = seq_explainer("this is a sample text") 34 | assert isinstance(wa, dict) 35 | 36 | 37 | def 
test_multilabel_classification_explainer_word_attributes_is_equals_label_length(): 38 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 39 | wa = seq_explainer("this is a sample text") 40 | assert len(wa) == len(BERT_MODEL.config.id2label) 41 | 42 | 43 | def test_multilabel_classification_word_attributions_not_calculated_raises(): 44 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 45 | with pytest.raises(ValueError): 46 | seq_explainer.word_attributions 47 | 48 | 49 | def test_multilabel_classification_viz(): 50 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 51 | wa = seq_explainer("this is a sample text") 52 | seq_explainer.visualize() 53 | 54 | 55 | @pytest.mark.skip(reason="Slow test") 56 | def test_multilabel_classification_classification_custom_steps(): 57 | explainer_string = "I love you , I like you" 58 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 59 | seq_explainer(explainer_string, n_steps=1) 60 | 61 | 62 | @pytest.mark.skip(reason="Slow test") 63 | def test_multilabel_classification_classification_internal_batch_size(): 64 | explainer_string = "I love you , I like you" 65 | seq_explainer = MultiLabelClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 66 | seq_explainer(explainer_string, internal_batch_size=1) 67 | -------------------------------------------------------------------------------- /test/text/test_question_answering_explainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from transformers import AutoModelForQuestionAnswering, AutoTokenizer 5 | 6 | from transformers_interpret import QuestionAnsweringExplainer 7 | from transformers_interpret.errors import ( 8 | AttributionTypeNotSupportedError, 9 | InputIdsNotCalculatedError, 10 | ) 11 | 12 | DISTILBERT_QA_MODEL = 
AutoModelForQuestionAnswering.from_pretrained("mrm8488/bert-tiny-5-finetuned-squadv2") 13 | DISTILBERT_QA_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-tiny-5-finetuned-squadv2") 14 | 15 | 16 | def test_question_answering_explainer_init_distilbert(): 17 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 18 | assert qa_explainer.attribution_type == "lig" 19 | assert qa_explainer.attributions is None 20 | assert qa_explainer.position == 0 21 | 22 | 23 | def test_question_answering_explainer_init_attribution_type_error(): 24 | with pytest.raises(AttributionTypeNotSupportedError): 25 | QuestionAnsweringExplainer( 26 | DISTILBERT_QA_MODEL, 27 | DISTILBERT_QA_TOKENIZER, 28 | attribution_type="UNSUPPORTED", 29 | ) 30 | 31 | 32 | def test_question_answering_encode(): 33 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 34 | 35 | _input = "this is a sample of text to be encoded" 36 | tokens = qa_explainer.encode(_input) 37 | assert isinstance(tokens, list) 38 | assert tokens[0] != qa_explainer.cls_token_id 39 | assert tokens[-1] != qa_explainer.sep_token_id 40 | assert len(tokens) >= len(_input.split(" ")) 41 | 42 | 43 | def test_question_answering_decode(): 44 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 45 | explainer_question = "what is his name ?" 
46 | explainer_text = "his name is bob" 47 | input_ids, _, _ = qa_explainer._make_input_reference_pair(explainer_question, explainer_text) 48 | decoded = qa_explainer.decode(input_ids) 49 | assert decoded[0] == qa_explainer.tokenizer.cls_token 50 | assert decoded[-1] == qa_explainer.tokenizer.sep_token 51 | assert " ".join(decoded[1:-1]) == explainer_question.lower() + " [SEP] " + explainer_text.lower() 52 | 53 | 54 | def test_question_answering_word_attributions(): 55 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 56 | explainer_question = "what is his name ?" 57 | explainer_text = "his name is bob" 58 | word_attributions = qa_explainer(explainer_question, explainer_text) 59 | assert isinstance(word_attributions, dict) 60 | assert "start" in word_attributions.keys() 61 | assert "end" in word_attributions.keys() 62 | assert len(word_attributions["start"]) == len(word_attributions["end"]) 63 | 64 | 65 | def test_question_answering_word_attributions_input_ids_not_calculated(): 66 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 67 | 68 | with pytest.raises(ValueError): 69 | qa_explainer.word_attributions 70 | 71 | 72 | def test_question_answering_start_pos(): 73 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 74 | explainer_question = "what is his name ?" 75 | explainer_text = "his name is Bob" 76 | qa_explainer(explainer_question, explainer_text) 77 | start_pos = qa_explainer.start_pos 78 | assert start_pos == 10 79 | 80 | 81 | def test_question_answering_end_pos(): 82 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 83 | explainer_question = "what is his name ?" 
84 | explainer_text = "his name is Bob" 85 | qa_explainer(explainer_question, explainer_text) 86 | end_pos = qa_explainer.end_pos 87 | assert end_pos == 10 88 | 89 | 90 | def test_question_answering_start_pos_input_ids_not_calculated(): 91 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 92 | with pytest.raises(InputIdsNotCalculatedError): 93 | qa_explainer.start_pos 94 | 95 | 96 | def test_question_answering_end_pos_input_ids_not_calculated(): 97 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 98 | with pytest.raises(InputIdsNotCalculatedError): 99 | qa_explainer.end_pos 100 | 101 | 102 | def test_question_answering_predicted_answer(): 103 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 104 | explainer_question = "what is his name ?" 105 | explainer_text = "his name is Bob" 106 | qa_explainer(explainer_question, explainer_text) 107 | predicted_answer = qa_explainer.predicted_answer 108 | assert predicted_answer == "bob" 109 | 110 | 111 | def test_question_answering_predicted_answer_input_ids_not_calculated(): 112 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 113 | with pytest.raises(InputIdsNotCalculatedError): 114 | qa_explainer.predicted_answer 115 | 116 | 117 | def test_question_answering_visualize(): 118 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 119 | explainer_question = "what is his name ?" 120 | explainer_text = "his name is Bob" 121 | qa_explainer(explainer_question, explainer_text) 122 | qa_explainer.visualize() 123 | 124 | 125 | def test_question_answering_visualize_save(): 126 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 127 | explainer_question = "what is his name ?" 
128 | explainer_text = "his name is Bob" 129 | qa_explainer(explainer_question, explainer_text) 130 | 131 | html_filename = "./test/qa_test.html" 132 | qa_explainer.visualize(html_filename) 133 | assert os.path.exists(html_filename) 134 | os.remove(html_filename) 135 | 136 | 137 | def test_question_answering_visualize_save_append_html_file_ending(): 138 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 139 | explainer_question = "what is his name ?" 140 | explainer_text = "his name is Bob" 141 | qa_explainer(explainer_question, explainer_text) 142 | 143 | html_filename = "./test/qa_test" 144 | qa_explainer.visualize(html_filename) 145 | assert os.path.exists(html_filename + ".html") 146 | os.remove(html_filename + ".html") 147 | 148 | 149 | def xtest_question_answering_custom_steps(): 150 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 151 | explainer_question = "what is his name ?" 152 | explainer_text = "his name is Bob" 153 | qa_explainer(explainer_question, explainer_text, n_steps=1) 154 | 155 | 156 | def xtest_question_answering_custom_internal_batch_size(): 157 | qa_explainer = QuestionAnsweringExplainer(DISTILBERT_QA_MODEL, DISTILBERT_QA_TOKENIZER) 158 | explainer_question = "what is his name ?" 
159 | explainer_text = "his name is Bob" 160 | qa_explainer(explainer_question, explainer_text, internal_batch_size=1) 161 | -------------------------------------------------------------------------------- /test/text/test_sequence_classification_explainer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 3 | 4 | from transformers_interpret import ( 5 | PairwiseSequenceClassificationExplainer, 6 | SequenceClassificationExplainer, 7 | ) 8 | from transformers_interpret.errors import ( 9 | AttributionTypeNotSupportedError, 10 | InputIdsNotCalculatedError, 11 | ) 12 | 13 | DISTILBERT_MODEL = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") 14 | DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english") 15 | 16 | BERT_MODEL = AutoModelForSequenceClassification.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification") 17 | BERT_TOKENIZER = AutoTokenizer.from_pretrained("mrm8488/bert-mini-finetuned-age_news-classification") 18 | 19 | CROSS_ENCODER_MODEL = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-TinyBERT-L-2-v2") 20 | CROSS_ENCODER_TOKENIZER = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-TinyBERT-L-2-v2") 21 | 22 | 23 | def test_sequence_classification_explainer_init_distilbert(): 24 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 25 | assert seq_explainer.attribution_type == "lig" 26 | assert seq_explainer.label2id == DISTILBERT_MODEL.config.label2id 27 | assert seq_explainer.id2label == DISTILBERT_MODEL.config.id2label 28 | assert seq_explainer.attributions is None 29 | 30 | 31 | def test_sequence_classification_explainer_init_bert(): 32 | seq_explainer = SequenceClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 33 | assert 
seq_explainer.attribution_type == "lig" 34 | assert seq_explainer.label2id == BERT_MODEL.config.label2id 35 | assert seq_explainer.id2label == BERT_MODEL.config.id2label 36 | assert seq_explainer.attributions is None 37 | 38 | 39 | def test_sequence_classification_explainer_init_attribution_type_error(): 40 | with pytest.raises(AttributionTypeNotSupportedError): 41 | SequenceClassificationExplainer( 42 | DISTILBERT_MODEL, 43 | DISTILBERT_TOKENIZER, 44 | attribution_type="UNSUPPORTED", 45 | ) 46 | 47 | 48 | def test_sequence_classification_explainer_init_with_custom_labels(): 49 | labels = ["label_1", "label_2"] 50 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=labels) 51 | assert len(labels) == len(seq_explainer.id2label) 52 | assert len(labels) == len(seq_explainer.label2id) 53 | for (k1, v1), (k2, v2) in zip(seq_explainer.id2label.items(), seq_explainer.label2id.items()): 54 | assert v1 in labels and k2 in labels 55 | 56 | 57 | def test_sequence_classification_explainer_init_custom_labels_size_error(): 58 | with pytest.raises(ValueError): 59 | SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["few_labels"]) 60 | 61 | 62 | def test_sequence_classification_encode(): 63 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 64 | 65 | _input = "this is a sample of text to be encoded" 66 | tokens = seq_explainer.encode(_input) 67 | assert isinstance(tokens, list) 68 | assert tokens[0] != seq_explainer.cls_token_id 69 | assert tokens[-1] != seq_explainer.sep_token_id 70 | assert len(tokens) >= len(_input.split(" ")) 71 | 72 | 73 | def test_sequence_classification_decode(): 74 | explainer_string = "I love you , I hate you" 75 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 76 | input_ids, _, _ = seq_explainer._make_input_reference_pair(explainer_string) 77 | decoded = seq_explainer.decode(input_ids) 78 
| assert decoded[0] == seq_explainer.tokenizer.cls_token 79 | assert decoded[-1] == seq_explainer.tokenizer.sep_token 80 | assert " ".join(decoded[1:-1]) == explainer_string.lower() 81 | 82 | 83 | def test_sequence_classification_run_text_given(): 84 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 85 | word_attributions = seq_explainer._run("I love you, I just love you") 86 | assert isinstance(word_attributions, list) 87 | 88 | actual_tokens = [token for token, _ in word_attributions] 89 | expected_tokens = [ 90 | "[CLS]", 91 | "i", 92 | "love", 93 | "you", 94 | ",", 95 | "i", 96 | "just", 97 | "love", 98 | "you", 99 | "[SEP]", 100 | ] 101 | assert actual_tokens == expected_tokens 102 | 103 | 104 | def test_sequence_classification_explain_on_cls_index(): 105 | explainer_string = "I love you , I like you" 106 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 107 | seq_explainer._run(explainer_string, index=0) 108 | assert seq_explainer.predicted_class_index == 1 109 | assert seq_explainer.predicted_class_index != seq_explainer.selected_index 110 | assert seq_explainer.predicted_class_name != seq_explainer.id2label[seq_explainer.selected_index] 111 | assert seq_explainer.predicted_class_name != "NEGATIVE" 112 | assert seq_explainer.predicted_class_name == "POSITIVE" 113 | 114 | 115 | def test_sequence_classification_explain_position_embeddings(): 116 | explainer_string = "I love you , I like you" 117 | seq_explainer = SequenceClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 118 | pos_attributions = seq_explainer(explainer_string, embedding_type=1) 119 | word_attributions = seq_explainer(explainer_string, embedding_type=0) 120 | 121 | assert pos_attributions != word_attributions 122 | 123 | 124 | def test_sequence_classification_explain_position_embeddings_not_available(): 125 | explainer_string = "I love you , I like you" 126 | seq_explainer = 
SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 127 | pos_attributions = seq_explainer(explainer_string, embedding_type=1) 128 | word_attributions = seq_explainer(explainer_string, embedding_type=0) 129 | 130 | assert pos_attributions == word_attributions 131 | 132 | 133 | def test_sequence_classification_explain_embedding_incorrect_value(): 134 | explainer_string = "I love you , I like you" 135 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 136 | 137 | word_attributions = seq_explainer(explainer_string, embedding_type=0) 138 | incorrect_word_attributions = seq_explainer(explainer_string, embedding_type=-42) 139 | 140 | assert incorrect_word_attributions == word_attributions 141 | 142 | 143 | def test_sequence_classification_predicted_class_name_no_id2label_defaults_idx(): 144 | explainer_string = "I love you , I like you" 145 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 146 | seq_explainer.id2label = {"test": "value"} 147 | seq_explainer._run(explainer_string) 148 | assert seq_explainer.predicted_class_name == 1 149 | 150 | 151 | def test_sequence_classification_explain_on_cls_name(): 152 | explainer_string = "I love you , I like you" 153 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 154 | seq_explainer._run(explainer_string, class_name="NEGATIVE") 155 | assert seq_explainer.predicted_class_index == 1 156 | assert seq_explainer.predicted_class_index != seq_explainer.selected_index 157 | assert seq_explainer.predicted_class_name != seq_explainer.id2label[seq_explainer.selected_index] 158 | assert seq_explainer.predicted_class_name != "NEGATIVE" 159 | assert seq_explainer.predicted_class_name == "POSITIVE" 160 | 161 | 162 | def test_sequence_classification_explain_on_cls_name_with_custom_labels(): 163 | explainer_string = "I love you , I like you" 164 | seq_explainer = SequenceClassificationExplainer( 165 | 
DISTILBERT_MODEL, DISTILBERT_TOKENIZER, custom_labels=["sad", "happy"] 166 | ) 167 | seq_explainer._run(explainer_string, class_name="sad") 168 | assert seq_explainer.predicted_class_index == 1 169 | assert seq_explainer.predicted_class_index != seq_explainer.selected_index 170 | assert seq_explainer.predicted_class_name != seq_explainer.id2label[seq_explainer.selected_index] 171 | assert seq_explainer.predicted_class_name != "sad" 172 | assert seq_explainer.predicted_class_name == "happy" 173 | 174 | 175 | def test_sequence_classification_explain_on_cls_name_not_in_dict(): 176 | explainer_string = "I love you , I like you" 177 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 178 | seq_explainer._run(explainer_string, class_name="UNKNOWN") 179 | assert seq_explainer.selected_index == 1 180 | assert seq_explainer.predicted_class_index == 1 181 | 182 | 183 | def test_sequence_classification_explain_raises_on_input_ids_not_calculated(): 184 | with pytest.raises(InputIdsNotCalculatedError): 185 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 186 | seq_explainer.predicted_class_index 187 | 188 | 189 | def test_sequence_classification_word_attributions(): 190 | explainer_string = "I love you , I like you" 191 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 192 | seq_explainer(explainer_string) 193 | assert isinstance(seq_explainer.word_attributions, list) 194 | for element in seq_explainer.word_attributions: 195 | assert isinstance(element, tuple) 196 | 197 | 198 | def test_sequence_classification_word_attributions_not_calculated_raises(): 199 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 200 | with pytest.raises(ValueError): 201 | seq_explainer.word_attributions 202 | 203 | 204 | def test_sequence_classification_explainer_str(): 205 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, 
DISTILBERT_TOKENIZER) 206 | s = "SequenceClassificationExplainer(" 207 | s += f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__}," 208 | s += f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__}," 209 | s += "\n\tattribution_type='lig'," 210 | s += ")" 211 | assert s == seq_explainer.__str__() 212 | 213 | 214 | def test_sequence_classification_viz(): 215 | explainer_string = "I love you , I like you" 216 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 217 | seq_explainer(explainer_string) 218 | seq_explainer.visualize() 219 | 220 | 221 | def sequence_classification_custom_steps(): 222 | explainer_string = "I love you , I like you" 223 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 224 | seq_explainer(explainer_string, n_steps=1) 225 | 226 | 227 | def sequence_classification_internal_batch_size(): 228 | explainer_string = "I love you , I like you" 229 | seq_explainer = SequenceClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 230 | seq_explainer(explainer_string, internal_batch_size=1) 231 | 232 | 233 | def test_pairwise_sequence_classification(): 234 | string1 = "How many people live in berlin?" 235 | string2 = "there are 1000000 people living in berlin" 236 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 237 | 238 | attr = explainer(string1, string2) 239 | assert explainer.text1 == string1 240 | assert explainer.text2 == string2 241 | assert attr 242 | 243 | 244 | def test_pairwise_sequence_classification_flip_attribute_sign(): 245 | string1 = "How many people live in berlin?" 246 | string2 = "this string is not related to the question." 
247 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 248 | 249 | original_sign_attr = explainer(string1, string2) 250 | flipped_sign_attr = explainer(string1, string2, flip_sign=True) 251 | 252 | for flipped_wa, original_wa in zip(flipped_sign_attr, original_sign_attr): 253 | assert flipped_wa[1] == -original_wa[1] 254 | 255 | 256 | def test_pairwise_sequence_classification_viz(): 257 | string1 = "How many people live in berlin?" 258 | string2 = "there are 1000000 people living in berlin" 259 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 260 | 261 | explainer(string1, string2) 262 | explainer.visualize() 263 | 264 | 265 | def test_pairwise_sequence_classification_custom_steps(): 266 | string1 = "How many people live in berlin?" 267 | string2 = "there are 1000000 people living in berlin" 268 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 269 | 270 | explainer(string1, string2, n_steps=1) 271 | 272 | 273 | def test_pairwise_sequence_classification_internal_batch_size(): 274 | string1 = "How many people live in berlin?" 275 | string2 = "there are 1000000 people living in berlin" 276 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 277 | 278 | explainer(string1, string2, internal_batch_size=1) 279 | 280 | 281 | def test_pairwise_sequence_classification_position_embeddings(): 282 | string1 = "How many people live in berlin?" 283 | string2 = "there are 1000000 people living in berlin" 284 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 285 | 286 | explainer(string1, string2, embedding_type=1) 287 | 288 | 289 | def test_pairwise_sequence_classification_position_embeddings_not_accepted(): 290 | string1 = "How many people live in berlin?" 
291 | string2 = "there are 1000000 people living in berlin" 292 | explainer = PairwiseSequenceClassificationExplainer(CROSS_ENCODER_MODEL, CROSS_ENCODER_TOKENIZER) 293 | explainer.accepts_position_ids = False 294 | 295 | explainer(string1, string2, embedding_type=1) 296 | -------------------------------------------------------------------------------- /test/text/test_token_classification_explainer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import AutoModelForTokenClassification, AutoTokenizer 3 | 4 | from transformers_interpret import TokenClassificationExplainer 5 | from transformers_interpret.errors import ( 6 | AttributionTypeNotSupportedError, 7 | InputIdsNotCalculatedError, 8 | ) 9 | 10 | DISTILBERT_MODEL = AutoModelForTokenClassification.from_pretrained( 11 | "elastic/distilbert-base-cased-finetuned-conll03-english" 12 | ) 13 | DISTILBERT_TOKENIZER = AutoTokenizer.from_pretrained("elastic/distilbert-base-cased-finetuned-conll03-english") 14 | 15 | 16 | BERT_MODEL = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER") 17 | BERT_TOKENIZER = AutoTokenizer.from_pretrained("dslim/bert-base-NER") 18 | 19 | 20 | def test_token_classification_explainer_init_distilbert(): 21 | ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER) 22 | assert ner_explainer.attribution_type == "lig" 23 | assert ner_explainer.label2id == DISTILBERT_MODEL.config.label2id 24 | assert ner_explainer.id2label == DISTILBERT_MODEL.config.id2label 25 | assert ner_explainer.attributions is None 26 | 27 | 28 | def test_token_classification_explainer_init_bert(): 29 | ner_explainer = TokenClassificationExplainer(BERT_MODEL, BERT_TOKENIZER) 30 | assert ner_explainer.attribution_type == "lig" 31 | assert ner_explainer.label2id == BERT_MODEL.config.label2id 32 | assert ner_explainer.id2label == BERT_MODEL.config.id2label 33 | assert ner_explainer.attributions is None 34 | 
def test_token_classification_explainer_init_attribution_type_error():
    """Only 'lig' attribution is supported; anything else raises."""
    with pytest.raises(AttributionTypeNotSupportedError):
        TokenClassificationExplainer(
            DISTILBERT_MODEL,
            DISTILBERT_TOKENIZER,
            attribution_type="UNSUPPORTED",
        )


def test_token_classification_selected_indexes_only_ignored_indexes():
    """ignored_indexes removes exactly those token positions from the selection."""
    explainer_string = "We visited Paris during the weekend, where Emmanuel Macron lives."
    expected_all_indexes = list(range(15))
    ignored_indexes = [0, 1, 2, 3, 4, 5, 7, 8, 9, 11, 12, 13]
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    # return value intentionally unused; we only inspect the explainer's state
    ner_explainer(explainer_string, ignored_indexes=ignored_indexes)

    assert len(ner_explainer._selected_indexes) == (len(expected_all_indexes) - len(ignored_indexes))

    for index in ner_explainer._selected_indexes:
        assert index in expected_all_indexes
        assert index not in ignored_indexes


def test_token_classification_selected_indexes_only_ignored_labels():
    """ignored_labels keeps only tokens whose predicted label is not ignored."""
    ignored_labels = ["O", "I-LOC", "B-LOC"]
    expected_indexes = [8, 9, 10]
    explainer_string = "We visited Paris last weekend, where Emmanuel Macron lives."

    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    ner_explainer(explainer_string, ignored_labels=ignored_labels)

    assert len(expected_indexes) == len(ner_explainer._selected_indexes)

    for index in ner_explainer._selected_indexes:
        assert index in expected_indexes


def test_token_classification_selected_indexes_all():
    """With no filters every token position is selected, in order."""
    explainer_string = "We visited Paris during the weekend, where Emmanuel Macron lives."
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    ner_explainer(explainer_string)

    assert len(ner_explainer._selected_indexes) == ner_explainer.input_ids.shape[1]

    for i, index in enumerate(ner_explainer._selected_indexes):
        assert i == index


def test_token_classification_selected_indexes_ignored_indexes_and_labels():
    """Index and label filters combine; only unfiltered positions remain."""
    ignored_labels = ["O", "I-PER", "B-PER"]
    ignored_indexes = [4, 5, 6]
    explainer_string = "We visited Paris last weekend"
    # this model erroneously classifies '[SEP]' as a location
    expected_selected_indexes = [3]

    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string, ignored_indexes=ignored_indexes, ignored_labels=ignored_labels)

    assert len(expected_selected_indexes) == len(ner_explainer._selected_indexes)

    for i, index in enumerate(ner_explainer._selected_indexes):
        assert expected_selected_indexes[i] == index


def test_token_classification_encode():
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    _input = "this is a sample of text to be encoded"
    tokens = ner_explainer.encode(_input)
    assert isinstance(tokens, list)
    # encode() returns plain ids, without the special CLS/SEP tokens
    assert tokens[0] != ner_explainer.cls_token_id
    assert tokens[-1] != ner_explainer.sep_token_id
    assert len(tokens) >= len(_input.split(" "))


def test_token_classification_decode():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    input_ids, _, _ = ner_explainer._make_input_reference_pair(explainer_string)
    decoded = ner_explainer.decode(input_ids)
    # decode() keeps the special tokens at the boundaries
    assert decoded[0] == ner_explainer.tokenizer.cls_token
    assert decoded[-1] == ner_explainer.tokenizer.sep_token
    assert " ".join(decoded[1:-1]) == explainer_string


def test_token_classification_run_text_given():
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    word_attributions = ner_explainer._run("We visited Paris during the weekend")
    assert isinstance(word_attributions, dict)

    expected_tokens = [
        "[CLS]",
        "We",
        "visited",
        "Paris",
        "during",
        "the",
        "weekend",
        "[SEP]",
    ]
    assert list(word_attributions.keys()) == expected_tokens


def test_token_classification_explain_position_embeddings():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(BERT_MODEL, BERT_TOKENIZER)
    pos_attributions = ner_explainer(explainer_string, embedding_type=1)
    word_attributions = ner_explainer(explainer_string, embedding_type=0)

    # the original looped over every token asserting this identical comparison;
    # a single assertion is equivalent
    assert pos_attributions != word_attributions


def test_token_classification_explain_position_embeddings_incorrect_value():
    """An unknown embedding_type value falls back to word embeddings."""
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)

    word_attributions = ner_explainer(explainer_string, embedding_type=0)
    incorrect_word_attributions = ner_explainer(explainer_string, embedding_type=-42)

    assert incorrect_word_attributions == word_attributions


def test_token_classification_predicted_class_names():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer._run(explainer_string)
    ground_truths = ["O", "O", "O", "B-LOC", "O", "O", "O", "O"]

    assert len(ground_truths) == len(ner_explainer.predicted_class_names)

    for i, class_name in enumerate(ner_explainer.predicted_class_names):
        assert ground_truths[i] == class_name


def test_token_classification_predicted_class_names_no_id2label_defaults_idx():
    """Without a usable id2label mapping, raw class indexes are returned."""
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer.id2label = {"test": "value"}
    ner_explainer._run(explainer_string)
    class_labels = list(range(9))

    assert len(ner_explainer.predicted_class_names) == 8

    for class_name in ner_explainer.predicted_class_names:
        assert class_name in class_labels


def test_token_classification_explain_raises_on_input_ids_not_calculated():
    """predicted_class_indexes is undefined before any text has been run."""
    with pytest.raises(InputIdsNotCalculatedError):
        ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
        ner_explainer.predicted_class_indexes


def test_token_classification_word_attributions():
    """word_attributions maps token -> {'label': str, 'attribution_scores': [(str, float)]}."""
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string)

    assert isinstance(ner_explainer.word_attributions, dict)

    for elements in ner_explainer.word_attributions.values():
        assert isinstance(elements, dict)
        assert list(elements.keys()) == ["label", "attribution_scores"]
        assert isinstance(elements["label"], str)
        assert isinstance(elements["attribution_scores"], list)
        for score in elements["attribution_scores"]:
            assert isinstance(score, tuple)
            assert isinstance(score[0], str)
            assert isinstance(score[1], float)


def test_token_classification_word_attributions_not_calculated_raises():
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    with pytest.raises(ValueError):
        ner_explainer.word_attributions


def test_token_classification_explainer_str():
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    expected = (
        "TokenClassificationExplainer("
        f"\n\tmodel={DISTILBERT_MODEL.__class__.__name__},"
        f"\n\ttokenizer={DISTILBERT_TOKENIZER.__class__.__name__},"
        "\n\tattribution_type='lig',"
        ")"
    )
    assert expected == str(ner_explainer)


def test_token_classification_viz():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string)
    ner_explainer.visualize()


def test_token_classification_viz_on_true_classes_value_error():
    """true_classes must match the token count, otherwise visualize raises."""
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string)
    true_classes = ["None", "Location", "None"]
    with pytest.raises(ValueError):
        ner_explainer.visualize(true_classes=true_classes)


# NOTE(review): the two functions below lack the `test_` prefix, so pytest does
# not collect them — possibly deliberate (mirroring the `xtest_` convention used
# in the zero-shot tests); confirm intent before renaming.
def token_classification_custom_steps():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string, n_steps=1)


def token_classification_internal_batch_size():
    explainer_string = "We visited Paris during the weekend"
    ner_explainer = TokenClassificationExplainer(DISTILBERT_MODEL, DISTILBERT_TOKENIZER)
    ner_explainer(explainer_string, internal_batch_size=1)


# --------------------------------------------------------------------------
# test/text/test_zero_shot_explainer.py
# --------------------------------------------------------------------------
import os
from unittest.mock import patch
4 | import pytest 5 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 6 | 7 | from transformers_interpret import ZeroShotClassificationExplainer 8 | from transformers_interpret.errors import AttributionTypeNotSupportedError 9 | 10 | DISTILBERT_MNLI_MODEL = AutoModelForSequenceClassification.from_pretrained("typeform/distilbert-base-uncased-mnli") 11 | DISTILBERT_MNLI_TOKENIZER = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli") 12 | 13 | 14 | def test_zero_shot_explainer_init_distilbert(): 15 | zero_shot_explainer = ZeroShotClassificationExplainer( 16 | DISTILBERT_MNLI_MODEL, 17 | DISTILBERT_MNLI_TOKENIZER, 18 | ) 19 | 20 | assert zero_shot_explainer.attribution_type == "lig" 21 | assert zero_shot_explainer.attributions == [] 22 | assert zero_shot_explainer.label_exists is True 23 | assert zero_shot_explainer.entailment_key == "ENTAILMENT" 24 | 25 | 26 | def test_zero_shot_explainer_init_attribution_type_error(): 27 | with pytest.raises(AttributionTypeNotSupportedError): 28 | ZeroShotClassificationExplainer( 29 | DISTILBERT_MNLI_MODEL, 30 | DISTILBERT_MNLI_TOKENIZER, 31 | attribution_type="UNSUPPORTED", 32 | ) 33 | 34 | 35 | @patch.object( 36 | ZeroShotClassificationExplainer, 37 | "_entailment_label_exists", 38 | return_value=(False, None), 39 | ) 40 | def test_zero_shot_explainer_no_entailment_label(mock_method): 41 | with pytest.raises(ValueError): 42 | ZeroShotClassificationExplainer( 43 | DISTILBERT_MNLI_MODEL, 44 | DISTILBERT_MNLI_TOKENIZER, 45 | ) 46 | 47 | 48 | def test_zero_shot_explainer_word_attributions(): 49 | zero_shot_explainer = ZeroShotClassificationExplainer( 50 | DISTILBERT_MNLI_MODEL, 51 | DISTILBERT_MNLI_TOKENIZER, 52 | ) 53 | labels = ["urgent", "phone", "tablet", "computer"] 54 | word_attributions = zero_shot_explainer( 55 | "I have a problem with my iphone that needs to be resolved asap!!", 56 | labels=labels, 57 | ) 58 | assert isinstance(word_attributions, dict) 59 | for label in labels: 60 | 
assert label in word_attributions.keys() 61 | 62 | 63 | def test_zero_shot_explainer_call_word_attributions_early_raises_error(): 64 | with pytest.raises(ValueError): 65 | zero_shot_explainer = ZeroShotClassificationExplainer( 66 | DISTILBERT_MNLI_MODEL, 67 | DISTILBERT_MNLI_TOKENIZER, 68 | ) 69 | 70 | zero_shot_explainer.word_attributions 71 | 72 | 73 | def test_zero_shot_explainer_word_attributions_include_hypothesis(): 74 | zero_shot_explainer = ZeroShotClassificationExplainer( 75 | DISTILBERT_MNLI_MODEL, 76 | DISTILBERT_MNLI_TOKENIZER, 77 | ) 78 | labels = ["urgent", "phone", "tablet", "computer"] 79 | word_attributions_with_hyp = zero_shot_explainer( 80 | "I have a problem with my iphone that needs to be resolved asap!!", 81 | labels=labels, 82 | include_hypothesis=True, 83 | ) 84 | word_attributions_without_hyp = zero_shot_explainer( 85 | "I have a problem with my iphone that needs to be resolved asap!!", 86 | labels=labels, 87 | include_hypothesis=False, 88 | ) 89 | 90 | for label in labels: 91 | assert len(word_attributions_with_hyp[label]) > len(word_attributions_without_hyp[label]) 92 | 93 | 94 | def test_zero_shot_explainer_visualize(): 95 | zero_shot_explainer = ZeroShotClassificationExplainer( 96 | DISTILBERT_MNLI_MODEL, 97 | DISTILBERT_MNLI_TOKENIZER, 98 | ) 99 | 100 | zero_shot_explainer( 101 | "I have a problem with my iphone that needs to be resolved asap!!", 102 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 103 | ) 104 | zero_shot_explainer.visualize() 105 | 106 | 107 | def test_zero_shot_explainer_visualize_save(): 108 | zero_shot_explainer = ZeroShotClassificationExplainer( 109 | DISTILBERT_MNLI_MODEL, 110 | DISTILBERT_MNLI_TOKENIZER, 111 | ) 112 | 113 | zero_shot_explainer( 114 | "I have a problem with my iphone that needs to be resolved asap!!", 115 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 116 | ) 117 | html_filename = "./test/zero_test.html" 118 | 
zero_shot_explainer.visualize(html_filename) 119 | assert os.path.exists(html_filename) 120 | os.remove(html_filename) 121 | 122 | 123 | def test_zero_shot_explainer_visualize_include_hypothesis(): 124 | zero_shot_explainer = ZeroShotClassificationExplainer( 125 | DISTILBERT_MNLI_MODEL, 126 | DISTILBERT_MNLI_TOKENIZER, 127 | ) 128 | 129 | zero_shot_explainer( 130 | "I have a problem with my iphone that needs to be resolved asap!!", 131 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 132 | include_hypothesis=True, 133 | ) 134 | zero_shot_explainer.visualize() 135 | 136 | 137 | def test_zero_explainer_visualize_save_append_html_file_ending(): 138 | zero_shot_explainer = ZeroShotClassificationExplainer( 139 | DISTILBERT_MNLI_MODEL, 140 | DISTILBERT_MNLI_TOKENIZER, 141 | ) 142 | 143 | zero_shot_explainer( 144 | "I have a problem with my iphone that needs to be resolved asap!!", 145 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 146 | ) 147 | 148 | html_filename = "./test/zero_test" 149 | zero_shot_explainer.visualize(html_filename) 150 | assert os.path.exists(html_filename + ".html") 151 | os.remove(html_filename + ".html") 152 | 153 | 154 | def test_zero_shot_model_does_not_have_entailment_label(): 155 | with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"l1": 0, "l2": 1, "l3": 2}): 156 | with pytest.raises(ValueError): 157 | ZeroShotClassificationExplainer( 158 | DISTILBERT_MNLI_MODEL, 159 | DISTILBERT_MNLI_TOKENIZER, 160 | ) 161 | 162 | 163 | def test_zero_shot_model_uppercase_entailment(): 164 | with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"ENTAILMENT": 0, "l2": 1, "l3": 2}): 165 | ZeroShotClassificationExplainer( 166 | DISTILBERT_MNLI_MODEL, 167 | DISTILBERT_MNLI_TOKENIZER, 168 | ) 169 | 170 | 171 | def test_zero_shot_model_lowercase_entailment(): 172 | with patch.object(DISTILBERT_MNLI_MODEL.config, "label2id", {"entailment": 0, "l2": 1, "l3": 2}): 173 | ZeroShotClassificationExplainer( 174 | 
DISTILBERT_MNLI_MODEL, 175 | DISTILBERT_MNLI_TOKENIZER, 176 | ) 177 | 178 | 179 | def xtest_zero_shot_custom_steps(): 180 | zero_shot_explainer = ZeroShotClassificationExplainer( 181 | DISTILBERT_MNLI_MODEL, 182 | DISTILBERT_MNLI_TOKENIZER, 183 | ) 184 | 185 | zero_shot_explainer( 186 | "I have a problem with my iphone that needs to be resolved asap!!", 187 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 188 | n_steps=1, 189 | ) 190 | 191 | 192 | def xtest_zero_shot_internal_batch_size(): 193 | zero_shot_explainer = ZeroShotClassificationExplainer( 194 | DISTILBERT_MNLI_MODEL, 195 | DISTILBERT_MNLI_TOKENIZER, 196 | ) 197 | 198 | zero_shot_explainer( 199 | "I have a problem with my iphone that needs to be resolved asap!!", 200 | labels=["urgent", " not", "urgent", "phone", "tablet", "computer"], 201 | internal_batch_size=1, 202 | ) 203 | -------------------------------------------------------------------------------- /test/vision/test_image_classification.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import requests 5 | from PIL import Image 6 | from transformers import AutoFeatureExtractor, AutoModelForImageClassification 7 | 8 | from transformers_interpret import ImageClassificationExplainer 9 | from transformers_interpret.explainers.vision.attribution_types import AttributionType 10 | 11 | model_name = "apple/mobilevit-small" 12 | MODEL = AutoModelForImageClassification.from_pretrained(model_name) 13 | FEATURE_EXTRACTOR = AutoFeatureExtractor.from_pretrained(model_name) 14 | 15 | IMAGE_LINK = "https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=2342&q=80" 16 | TEST_IMAGE = Image.open(requests.get(IMAGE_LINK, stream=True).raw) 17 | 18 | 19 | def test_image_classification_init(): 20 | img_cls_explainer = ImageClassificationExplainer(model=MODEL, 
feature_extractor=FEATURE_EXTRACTOR) 21 | 22 | assert img_cls_explainer.model == MODEL 23 | assert img_cls_explainer.feature_extractor == FEATURE_EXTRACTOR 24 | assert img_cls_explainer.id2label == MODEL.config.id2label 25 | assert img_cls_explainer.label2id == MODEL.config.label2id 26 | 27 | assert img_cls_explainer.attributions is None 28 | 29 | 30 | def test_image_classification_init_attribution_type_not_supported(): 31 | with pytest.raises(ValueError): 32 | ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR, attribution_type="not_supported") 33 | 34 | 35 | def test_image_classification_init_custom_labels(): 36 | labels = [f"label_{i}" for i in range(len(MODEL.config.id2label) - 1)] 37 | img_cls_explainer = ImageClassificationExplainer( 38 | model=MODEL, feature_extractor=FEATURE_EXTRACTOR, custom_labels=labels 39 | ) 40 | 41 | assert list(img_cls_explainer.label2id.keys()) == labels 42 | 43 | 44 | def test_image_classification_init_custom_labels_not_valid(): 45 | with pytest.raises(ValueError): 46 | ImageClassificationExplainer(model=MODEL, feature_extractor=FEATURE_EXTRACTOR, custom_labels=["label_0"]) 47 | 48 | 49 | def test_image_classification_call(): 50 | img_cls_explainer = ImageClassificationExplainer( 51 | model=MODEL, 52 | feature_extractor=FEATURE_EXTRACTOR, 53 | ) 54 | 55 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 56 | 57 | assert img_cls_explainer.attributions is not None 58 | assert img_cls_explainer.predicted_index is not None 59 | assert img_cls_explainer.n_steps == 1 60 | assert img_cls_explainer.n_steps_noise_tunnel == 1 61 | assert img_cls_explainer.noise_tunnel_n_samples == 1 62 | assert img_cls_explainer.internal_batch_size == 1 63 | 64 | 65 | def test_image_classification_call_attribution_type_not_supported(): 66 | img_cls_explainer = ImageClassificationExplainer( 67 | model=MODEL, 68 | feature_extractor=FEATURE_EXTRACTOR, 69 | ) 70 | 71 | 
with pytest.raises(ValueError): 72 | img_cls_explainer( 73 | TEST_IMAGE, 74 | n_steps=1, 75 | n_steps_noise_tunnel=1, 76 | noise_tunnel_n_samples=1, 77 | internal_batch_size=1, 78 | noise_tunnel_type="not_supported", 79 | ) 80 | 81 | 82 | def test_image_classification_visualize(): 83 | img_cls_explainer = ImageClassificationExplainer( 84 | model=MODEL, 85 | feature_extractor=FEATURE_EXTRACTOR, 86 | ) 87 | 88 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 89 | 90 | img_cls_explainer.visualize(method="alpha_scaling", save_path=None, sign="all", outlier_threshold=0.15) 91 | 92 | 93 | def test_image_classification_visualize_use_normed_pixel_values(): 94 | img_cls_explainer = ImageClassificationExplainer( 95 | model=MODEL, 96 | feature_extractor=FEATURE_EXTRACTOR, 97 | ) 98 | 99 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 100 | 101 | img_cls_explainer.visualize( 102 | method="overlay", save_path=None, sign="all", outlier_threshold=0.15, use_original_image_pixels=False 103 | ) 104 | 105 | 106 | def test_image_classification_visualize_wrt_class_name(): 107 | img_cls_explainer = ImageClassificationExplainer( 108 | model=MODEL, 109 | feature_extractor=FEATURE_EXTRACTOR, 110 | ) 111 | 112 | img_cls_explainer( 113 | TEST_IMAGE, 114 | n_steps=1, 115 | n_steps_noise_tunnel=1, 116 | noise_tunnel_n_samples=1, 117 | internal_batch_size=1, 118 | class_name="banana", 119 | ) 120 | 121 | img_cls_explainer.visualize(method="heatmap", save_path=None, sign="all", outlier_threshold=0.15) 122 | 123 | 124 | def test_image_classification_visualize_wrt_class_index(): 125 | img_cls_explainer = ImageClassificationExplainer( 126 | model=MODEL, 127 | feature_extractor=FEATURE_EXTRACTOR, 128 | ) 129 | 130 | img_cls_explainer( 131 | TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1, index=3 132 | ) 133 | 134 | 
img_cls_explainer.visualize(method="overlay", save_path=None, sign="all", outlier_threshold=0.15) 135 | 136 | 137 | def test_image_classification_visualize_integrated_gradients_no_noise_tunnel(): 138 | img_cls_explainer = ImageClassificationExplainer( 139 | model=MODEL, feature_extractor=FEATURE_EXTRACTOR, attribution_type=AttributionType.INTEGRATED_GRADIENTS.value 140 | ) 141 | 142 | img_cls_explainer( 143 | TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1, index=3 144 | ) 145 | 146 | img_cls_explainer.visualize(method="masked_image", save_path=None, sign="all", outlier_threshold=0.15) 147 | 148 | 149 | def test_image_classification_visualize_save_image(): 150 | img_cls_explainer = ImageClassificationExplainer( 151 | model=MODEL, 152 | feature_extractor=FEATURE_EXTRACTOR, 153 | ) 154 | 155 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 156 | 157 | img_cls_explainer.visualize( 158 | method="overlay", 159 | save_path="./test.png", 160 | sign="all", 161 | outlier_threshold=0.15, 162 | side_by_side=True, 163 | ) 164 | 165 | os.remove("./test.png") 166 | 167 | 168 | def test_image_classification_visualize_positive_sign_for_unsupported_methods(): 169 | img_cls_explainer = ImageClassificationExplainer( 170 | model=MODEL, 171 | feature_extractor=FEATURE_EXTRACTOR, 172 | ) 173 | 174 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 175 | 176 | img_cls_explainer.visualize( 177 | method="masked_image", 178 | save_path="./test.png", 179 | sign="all", 180 | outlier_threshold=0.15, 181 | side_by_side=True, 182 | ) 183 | 184 | os.remove("./test.png") 185 | 186 | 187 | def test_image_classification_visualize_unsupported_viz_method(): 188 | img_cls_explainer = ImageClassificationExplainer( 189 | model=MODEL, 190 | feature_extractor=FEATURE_EXTRACTOR, 191 | ) 192 | 193 | img_cls_explainer(TEST_IMAGE, 
n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 194 | 195 | with pytest.raises(ValueError): 196 | img_cls_explainer.visualize(method="not_supported", save_path=None, sign="all", outlier_threshold=0.15) 197 | 198 | 199 | def test_image_classification_visualize_unsupported_sign(): 200 | img_cls_explainer = ImageClassificationExplainer( 201 | model=MODEL, 202 | feature_extractor=FEATURE_EXTRACTOR, 203 | ) 204 | 205 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 206 | 207 | with pytest.raises(ValueError): 208 | img_cls_explainer.visualize(method="overlay", save_path=None, sign="not_supported", outlier_threshold=0.15) 209 | 210 | 211 | def test_image_classification_visualize_side_by_side(): 212 | img_cls_explainer = ImageClassificationExplainer( 213 | model=MODEL, 214 | feature_extractor=FEATURE_EXTRACTOR, 215 | ) 216 | 217 | img_cls_explainer(TEST_IMAGE, n_steps=1, n_steps_noise_tunnel=1, noise_tunnel_n_samples=1, internal_batch_size=1) 218 | 219 | methods = ["overlay", "heatmap", "masked_image", "alpha_scaling"] 220 | for method in methods: 221 | img_cls_explainer.visualize( 222 | method=method, save_path=None, sign="all", outlier_threshold=0.15, side_by_side=True 223 | ) 224 | -------------------------------------------------------------------------------- /transformers_interpret/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cdpierse/transformers-interpret/7c2f938451b37459533bed2f637020bb06000a5e/transformers_interpret/.DS_Store -------------------------------------------------------------------------------- /transformers_interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from .attributions import Attributions, LIGAttributions # noqa: F401 2 | from .explainer import BaseExplainer # noqa: F401 3 | from 
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

from transformers_interpret.errors import AttributionsNotCalculatedError


class Attributions:
    """Base container for attribution calculations.

    Holds the wrapped model forward function, the embedding layer with
    respect to which attributions are computed, and the string tokens of
    the explained input.
    """

    def __init__(self, custom_forward: Callable, embeddings: nn.Module, tokens: list):
        self.custom_forward = custom_forward
        self.embeddings = embeddings
        self.tokens = tokens


class LIGAttributions(Attributions):
    """Layer Integrated Gradients attributions for a single input sequence.

    Attributions are computed eagerly in the constructor via captum's
    ``LayerIntegratedGradients``. Call :meth:`summarize` afterwards to reduce
    the per-embedding-dimension attributions to one normalized score per token.
    """

    def __init__(
        self,
        custom_forward: Callable,
        embeddings: nn.Module,
        tokens: list,
        input_ids: torch.Tensor,
        ref_input_ids: torch.Tensor,
        sep_id: int,
        attention_mask: torch.Tensor,
        target: Optional[Union[int, Tuple, torch.Tensor, List]] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        ref_token_type_ids: Optional[torch.Tensor] = None,
        ref_position_ids: Optional[torch.Tensor] = None,
        internal_batch_size: Optional[int] = None,
        n_steps: int = 50,
    ):
        """
        Args:
            custom_forward (Callable): Model forward wrapper returning the score to attribute.
            embeddings (nn.Module): Embedding layer attributions are taken w.r.t.
            tokens (list): String tokens corresponding to `input_ids`.
            input_ids (torch.Tensor): Token ids of the real input.
            ref_input_ids (torch.Tensor): Baseline (reference) token ids, same shape.
            sep_id (int): Separator token id. Unused here, kept for interface stability.
            attention_mask (torch.Tensor): Attention mask for the input.
            target (optional): Output index/indices to attribute against.
            token_type_ids / ref_token_type_ids (optional): Segment ids and their baseline.
            position_ids / ref_position_ids (optional): Position ids and their baseline.
            internal_batch_size (int, optional): Chunk size for captum's internal batching.
            n_steps (int): Number of integral-approximation steps. Defaults to 50.
        """
        super().__init__(custom_forward, embeddings, tokens)
        self.input_ids = input_ids
        self.ref_input_ids = ref_input_ids
        self.attention_mask = attention_mask
        self.target = target
        self.token_type_ids = token_type_ids
        self.position_ids = position_ids
        self.ref_token_type_ids = ref_token_type_ids
        self.ref_position_ids = ref_position_ids
        self.internal_batch_size = internal_batch_size
        self.n_steps = n_steps

        self.lig = LayerIntegratedGradients(self.custom_forward, self.embeddings)

        # Build the (inputs, baselines) pair once instead of duplicating the
        # full attribute() call across four nearly identical branches.
        if self.token_type_ids is not None and self.position_ids is not None:
            inputs = (self.input_ids, self.token_type_ids, self.position_ids)
            baselines = (self.ref_input_ids, self.ref_token_type_ids, self.ref_position_ids)
        elif self.position_ids is not None:
            inputs = (self.input_ids, self.position_ids)
            baselines = (self.ref_input_ids, self.ref_position_ids)
        elif self.token_type_ids is not None:
            inputs = (self.input_ids, self.token_type_ids)
            baselines = (self.ref_input_ids, self.ref_token_type_ids)
        else:
            inputs = self.input_ids
            baselines = self.ref_input_ids

        attribute_kwargs = dict(
            inputs=inputs,
            baselines=baselines,
            target=self.target,
            return_convergence_delta=True,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        if self.token_type_ids is not None or self.position_ids is not None:
            # NOTE(review): the original code forwarded the attention mask only
            # when token type ids or position ids were present; that asymmetry
            # is preserved here. Confirm whether the input-ids-only path should
            # also receive the mask.
            attribute_kwargs["additional_forward_args"] = self.attention_mask

        self._attributions, self.delta = self.lig.attribute(**attribute_kwargs)

    @property
    def word_attributions(self) -> list:
        """Return (token, score) pairs.

        Raises:
            AttributionsNotCalculatedError: If :meth:`summarize` has not been
                called (or produced an empty result). Previously an un-summarized
                instance raised a bare AttributeError instead.
        """
        if getattr(self, "attributions_sum", None) is None or len(self.attributions_sum) < 1:
            raise AttributionsNotCalculatedError("Attributions are not yet calculated")
        return [(word, float(score.cpu().data.numpy())) for word, score in zip(self.tokens, self.attributions_sum)]

    def summarize(self, end_idx=None, flip_sign: bool = False):
        """Collapse per-dimension attributions into one L2-normalized score per token.

        Args:
            end_idx (int, optional): Only keep scores for tokens before this index.
            flip_sign (bool): Multiply all scores by -1 (useful when attributing
                the "negative" class of a single-output model).
        """
        multiplier = -1 if flip_sign else 1
        self.attributions_sum = self._attributions.sum(dim=-1).squeeze(0) * multiplier
        self.attributions_sum = self.attributions_sum[:end_idx] / torch.norm(self.attributions_sum[:end_idx])

    def visualize_attributions(self, pred_prob, pred_class, true_class, attr_class, all_tokens):
        """Build a captum `VisualizationDataRecord` for HTML rendering."""
        return viz.VisualizationDataRecord(
            self.attributions_sum,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            self.attributions_sum.sum(),
            all_tokens,
            self.delta,
        )
import inspect
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union

import torch

if TYPE_CHECKING:
    # transformers types are only referenced in annotations, so defer the
    # import; this keeps the module importable without transformers installed.
    from transformers import PreTrainedModel, PreTrainedTokenizer


class BaseExplainer(ABC):
    """Shared plumbing for all text explainers.

    Resolves tokenizer special-token ids, builds reference (baseline) inputs
    for attribution, and introspects the model's forward signature so
    subclasses know which auxiliary ids the model accepts.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
    ):
        self.model = model
        self.tokenizer = tokenizer

        # GPT-2 has no pad token, so EOS is used as the attribution baseline token.
        if self.model.config.model_type == "gpt2":
            self.ref_token_id = self.tokenizer.eos_token_id
        else:
            self.ref_token_id = self.tokenizer.pad_token_id

        # Fall back to EOS/BOS for tokenizers without dedicated SEP/CLS tokens.
        self.sep_token_id = (
            self.tokenizer.sep_token_id if self.tokenizer.sep_token_id is not None else self.tokenizer.eos_token_id
        )
        self.cls_token_id = (
            self.tokenizer.cls_token_id if self.tokenizer.cls_token_id is not None else self.tokenizer.bos_token_id
        )

        self.model_prefix = model.base_model_prefix

        # Roberta lists these parameters in its signature but is treated as
        # not accepting them here.
        nonstandard_model_types = ["roberta"]
        self.accepts_position_ids = (
            self._model_forward_signature_accepts_parameter("position_ids")
            and self.model.config.model_type not in nonstandard_model_types
        )
        self.accepts_token_type_ids = (
            self._model_forward_signature_accepts_parameter("token_type_ids")
            and self.model.config.model_type not in nonstandard_model_types
        )

        self.device = self.model.device

        self.word_embeddings = self.model.get_input_embeddings()
        self.position_embeddings = None
        self.token_type_embeddings = None

        self._set_available_embedding_types()

    @abstractmethod
    def encode(self, text: str = None):
        """Encode given text with a model's tokenizer."""
        raise NotImplementedError

    @abstractmethod
    def decode(self, input_ids: torch.Tensor) -> List[str]:
        """Decode received input_ids into a list of word tokens.

        Args:
            input_ids (torch.Tensor): Input ids representing
                word tokens for a sentence/document.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def word_attributions(self):
        # property + abstractmethod replaces the deprecated abstractproperty.
        raise NotImplementedError

    @abstractmethod
    def _run(self) -> list:
        raise NotImplementedError

    @abstractmethod
    def _forward(self):
        """Forward defines a function for passing inputs through a model's forward method."""
        raise NotImplementedError

    @abstractmethod
    def _calculate_attributions(self):
        """Internal method for calculating the attribution values for the input text."""
        raise NotImplementedError

    def _make_input_reference_pair(self, text: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Tokenize `text` into `input_ids` and build a same-length baseline
        `ref_input_ids` used for attributions.

        Args:
            text (str): Text for which input ids and the corresponding
                reference ids are created.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, int]: input ids, reference ids,
            and the length of the text without special tokens.
        """
        if isinstance(text, list):
            raise NotImplementedError("Lists of text are not currently supported.")

        text_ids = self.encode(text)
        input_ids = self.tokenizer.encode(text, add_special_tokens=True)

        # If the tokenizer added no special tokens the baseline is all
        # reference tokens; otherwise it is wrapped in CLS/SEP like the input.
        if len(text_ids) == len(input_ids):
            ref_input_ids = [self.ref_token_id] * len(text_ids)
        else:
            ref_input_ids = [self.cls_token_id] + [self.ref_token_id] * len(text_ids) + [self.sep_token_id]

        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(text_ids),
        )

    def _make_input_reference_token_type_pair(
        self, input_ids: torch.Tensor, sep_idx: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return token type ids (0 up to and including `sep_idx`, 1 after)
        for `input_ids`, plus an all-zero reference tensor of the same shape.

        Args:
            input_ids (torch.Tensor): Tensor of text converted to `input_ids`
            sep_idx (int, optional): Defaults to 0.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        token_type_ids = torch.tensor(
            [0 if i <= sep_idx else 1 for i in range(seq_len)], device=self.device
        ).expand_as(input_ids)
        # zeros_like already matches shape/device; the original extra expand_as was a no-op.
        ref_token_type_ids = torch.zeros_like(token_type_ids)

        return (token_type_ids, ref_token_type_ids)

    def _make_input_reference_position_id_pair(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return positional ids for `input_ids` and a zeroed reference tensor.

        Args:
            input_ids (torch.Tensor): inputs to create positional encoding for.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]
        """
        seq_len = input_ids.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=self.device)
        ref_position_ids = torch.zeros(seq_len, dtype=torch.long, device=self.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
        return (position_ids, ref_position_ids)

    def _make_attention_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return an all-ones attention mask matching `input_ids`."""
        return torch.ones_like(input_ids)

    def _get_preds(
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Call the model, forwarding only the auxiliary ids it accepts."""
        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if self.accepts_token_type_ids:
            model_kwargs["token_type_ids"] = token_type_ids
        if self.accepts_position_ids:
            model_kwargs["position_ids"] = position_ids
        return self.model(**model_kwargs)

    def _clean_text(self, text: str) -> str:
        """Pad punctuation with spaces and collapse runs of whitespace."""
        # Raw strings avoid the invalid '\s' escape in the original pattern.
        text = re.sub(r"([.,!?()])", r" \1 ", text)
        text = re.sub(r"\s{2,}", " ", text)
        return text

    def _model_forward_signature_accepts_parameter(self, parameter: str) -> bool:
        """True if `parameter` is named in the model's forward() signature."""
        signature = inspect.signature(self.model.forward)
        return parameter in signature.parameters

    def _set_available_embedding_types(self):
        """Locate position/token-type embedding modules on the base model, if any."""
        model_base = getattr(self.model, self.model_prefix)
        if self.model.config.model_type == "gpt2" and hasattr(model_base, "wpe"):
            self.position_embeddings = model_base.wpe.weight
        else:
            if hasattr(model_base, "embeddings"):
                self.model_embeddings = getattr(model_base, "embeddings")
                if hasattr(self.model_embeddings, "position_embeddings"):
                    self.position_embeddings = self.model_embeddings.position_embeddings
                if hasattr(self.model_embeddings, "token_type_embeddings"):
                    self.token_type_embeddings = self.model_embeddings.token_type_embeddings

    def __str__(self):
        s = f"{self.__class__.__name__}("
        s += f"\n\tmodel={self.model.__class__.__name__},"
        s += f"\n\ttokenizer={self.tokenizer.__class__.__name__}"
        s += ")"

        return s
from .sequence_classification import SequenceClassificationExplainer

SUPPORTED_ATTRIBUTION_TYPES = ["lig"]


class MultiLabelClassificationExplainer(SequenceClassificationExplainer):
    """
    Explainer for independently explaining label attributions, multi-label
    style, for models of type `{MODEL_NAME}ForSequenceClassification` from the
    Transformers package. Every label is explained independently; the word
    attributions are a dictionary mapping each label to its attributions,
    even when the model itself is not multi-label.

    Attribution for `text` is calculated with the given model and tokenizer.
    Because every label is explained, calculation time scales linearly with
    the number of labels.

    Attributions can also be taken w.r.t. a particular embedding type via
    `embedding_type`: `0` (default) uses word embeddings, `1` uses position
    embeddings. If a model's forward method takes no position ids
    (e.g. distilbert) a warning is emitted and word embeddings are used.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type="lig",
        custom_labels: Optional[List[str]] = None,
    ):
        super().__init__(model, tokenizer, attribution_type, custom_labels)
        self.labels = []

    @property
    def word_attributions(self) -> dict:
        "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."
        if self.attributions != [] and self.labels != []:
            return {
                label: attribution.word_attributions
                for label, attribution in zip(self.labels, self.attributions)
            }
        raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

    def visualize(self, html_filepath: str = None, true_class: str = None):
        """
        Visualizes word attributions. If in a notebook the table is displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization is
        saved as a html file.

        If the true class is known for the text it can be passed as `true_class`.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        score_viz = []
        for i, attribution in enumerate(self.attributions):
            score_viz.append(
                attribution.visualize_attributions(  # type: ignore
                    self.pred_probs_list[i],
                    "",  # a predicted class name does not make sense for this explainer
                    "n/a" if not true_class else true_class,  # no true class name by default
                    self.labels[i],
                    tokens,
                )
            )

        html = viz.visualize_text(score_viz)

        new_html_data = html._repr_html_().replace("Predicted Label", "Prediction Score")
        new_html_data = new_html_data.replace("True Label", "n/a")
        html.data = new_html_data

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html

    def _forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        # Sigmoid (not softmax) so each label is scored independently.
        logits = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)[0]
        probs = torch.sigmoid(logits)

        # Single output node: the model collapses to one score.
        if len(logits[0]) == 1:
            self._single_node_output = True
            self.pred_probs = probs[0][0]
            return probs[:, :]

        self.pred_probs = probs[0][self.selected_index]
        return probs[:, self.selected_index]

    def __call__(
        self,
        text: str,
        embedding_type: int = 0,
        internal_batch_size: int = None,
        n_steps: int = None,
    ) -> dict:
        """
        Calculates attributions for `text` using the model and tokenizer given
        in the constructor. Attributions are calculated for every label output
        in the model.

        Attributions can also be taken w.r.t. a particular embedding type via
        `embedding_type`: `0` (default) uses word embeddings, `1` uses
        position embeddings. If a model's forward method takes no position ids
        (e.g. distilbert) a warning is emitted and word embeddings are used.

        Args:
            text (str): Text to provide attributions for.
            embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
            internal_batch_size (int, optional): Divides total #steps * #examples
                data points into chunks of size at most internal_batch_size,
                which are computed (forward / backward passes) sequentially.
                If None, all evaluations are processed in one batch.
            n_steps (int, optional): The number of steps used by the approximation
                method. Default: 50.

        Returns:
            dict: A dictionary of label to list of attributions.
        """
        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size

        self.attributions = []
        self.pred_probs_list = []
        self.labels = [label for label, _ in sorted(self.label2id.items(), key=lambda kv: kv[1])]
        self.label_probs_dict = {}

        for label_index in range(self.model.config.num_labels):
            single_label_explainer = SequenceClassificationExplainer(self.model, self.tokenizer)
            self.selected_index = label_index
            # Route the child's forward through our sigmoid-based forward so
            # each label is scored independently.
            single_label_explainer._forward = self._forward
            single_label_explainer(text, label_index, embedding_type)

            self.attributions.append(single_label_explainer.attributions)
            self.input_ids = single_label_explainer.input_ids
            self.pred_probs_list.append(self.pred_probs)
            self.label_probs_dict[self.id2label[label_index]] = self.pred_probs

        return self.word_attributions

    def __str__(self):
        parts = [
            f"{self.__class__.__name__}(",
            f"\n\tmodel={self.model.__class__.__name__},",
            f"\n\ttokenizer={self.tokenizer.__class__.__name__},",
            f"\n\tattribution_type='{self.attribution_type}',",
            f"\n\tcustom_labels={self.custom_labels},",
            ")",
        ]
        return "".join(parts)
import warnings
from typing import Union

import torch
from captum.attr import visualization as viz
from torch.nn.modules.sparse import Embedding
from transformers import PreTrainedModel, PreTrainedTokenizer

from transformers_interpret import BaseExplainer, LIGAttributions
from transformers_interpret.errors import (
    AttributionTypeNotSupportedError,
    InputIdsNotCalculatedError,
)

SUPPORTED_ATTRIBUTION_TYPES = ["lig"]


class QuestionAnsweringExplainer(BaseExplainer):
    """
    Explainer for explaining attributions for models of type `{MODEL_NAME}ForQuestionAnswering`
    from the Transformers package.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type: str = "lig",
    ):
        """
        Args:
            model (PreTrainedModel): Pretrained huggingface Question Answering model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".

        Raises:
            AttributionTypeNotSupportedError: If `attribution_type` is not one of
                the supported attribution types.
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.attribution_type = attribution_type

        self.attributions: Union[None, LIGAttributions] = None
        self.start_attributions = None
        self.end_attributions = None
        self.input_ids: torch.Tensor = torch.Tensor()

        # 0 -> attribute w.r.t. the start logits, 1 -> w.r.t. the end logits.
        self.position = 0

        self.internal_batch_size = None
        self.n_steps = 50

    def encode(self, text: str) -> list:  # type: ignore
        "Encode 'text' using tokenizer, special tokens are not added"
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, input_ids: torch.Tensor) -> list:
        "Decode 'input_ids' to string using tokenizer"
        return self.tokenizer.convert_ids_to_tokens(input_ids[0])

    @property
    def word_attributions(self) -> dict:
        """
        Returns the word attributions (as `dict`) for both start and end positions of QA model.

        Raises error if attributions not calculated.
        """
        if self.start_attributions is not None and self.end_attributions is not None:
            return {
                "start": self.start_attributions.word_attributions,
                "end": self.end_attributions.word_attributions,
            }

        else:
            raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

    @property
    def start_pos(self):
        "Returns predicted start position for answer"
        if len(self.input_ids) > 0:
            preds = self._get_preds(
                self.input_ids,
                self.token_type_ids,
                self.position_ids,
                self.attention_mask,
            )

            preds = preds[0]
            return int(preds.argmax())
        else:
            # Fixed stray backtick at the end of the original message.
            raise InputIdsNotCalculatedError("input_ids have not been created yet.")

    @property
    def end_pos(self):
        "Returns predicted end position for answer"
        if len(self.input_ids) > 0:
            preds = self._get_preds(
                self.input_ids,
                self.token_type_ids,
                self.position_ids,
                self.attention_mask,
            )

            preds = preds[1]
            return int(preds.argmax())
        else:
            raise InputIdsNotCalculatedError("input_ids have not been created yet.")

    @property
    def predicted_answer(self):
        "Returns predicted answer span from provided `text`"
        if len(self.input_ids) > 0:
            preds = self._get_preds(
                self.input_ids,
                self.token_type_ids,
                self.position_ids,
                self.attention_mask,
            )

            start = preds[0].argmax()
            end = preds[1].argmax()
            return " ".join(self.decode(self.input_ids)[start : end + 1])
        else:
            raise InputIdsNotCalculatedError("input_ids have not been created yet.")

    def visualize(self, html_filepath: str = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.
        """
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        predicted_answer = self.predicted_answer

        # Render attributions for the start position first, then the end position.
        self.position = 0
        start_pred_probs = self._forward(self.input_ids, self.token_type_ids, self.position_ids)
        start_pos = self.start_pos
        start_pos_str = tokens[start_pos] + " (" + str(start_pos) + ")"
        start_score_viz = self.start_attributions.visualize_attributions(
            float(start_pred_probs),
            str(predicted_answer),
            start_pos_str,
            start_pos_str,
            tokens,
        )

        self.position = 1

        end_pred_probs = self._forward(self.input_ids, self.token_type_ids, self.position_ids)
        end_pos = self.end_pos
        end_pos_str = tokens[end_pos] + " (" + str(end_pos) + ")"
        end_score_viz = self.end_attributions.visualize_attributions(
            float(end_pred_probs),
            str(predicted_answer),
            end_pos_str,
            end_pos_str,
            tokens,
        )

        html = viz.visualize_text([start_score_viz, end_score_viz])

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html

    def _make_input_reference_pair(self, question: str, text: str):  # type: ignore
        """Build `[CLS] question [SEP] text [SEP]` input ids plus a baseline
        where question and text tokens are replaced by the reference token.

        Returns:
            Tuple of input ids, reference ids, and the question length.
        """
        question_ids = self.encode(question)
        text_ids = self.encode(text)

        input_ids = [self.cls_token_id] + question_ids + [self.sep_token_id] + text_ids + [self.sep_token_id]

        ref_input_ids = (
            [self.cls_token_id]
            + [self.ref_token_id] * len(question_ids)
            + [self.sep_token_id]
            + [self.ref_token_id] * len(text_ids)
            + [self.sep_token_id]
        )

        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(question_ids),
        )

    def _get_preds(
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Call the model, forwarding only the auxiliary ids it accepts."""
        model_kwargs = {"attention_mask": attention_mask}
        if self.accepts_token_type_ids:
            model_kwargs["token_type_ids"] = token_type_ids
        if self.accepts_position_ids:
            model_kwargs["position_ids"] = position_ids
        return self.model(input_ids, **model_kwargs)

    def _forward(  # type: ignore
        self,
        input_ids: torch.Tensor,
        token_type_ids=None,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        """Return the max start or end logit depending on `self.position`."""
        preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask)

        # preds[0] are start logits, preds[1] are end logits.
        preds = preds[self.position]

        return preds.max(1).values

    def _run(self, question: str, text: str, embedding_type: int) -> dict:
        """Resolve the embedding layer to attribute against, then calculate attributions."""
        if embedding_type == 0:
            embeddings = self.word_embeddings
        try:
            if embedding_type == 1:
                if self.accepts_position_ids and self.position_embeddings is not None:
                    embeddings = self.position_embeddings
                else:
                    warnings.warn(
                        "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
                    )
                    embeddings = self.word_embeddings
            elif embedding_type == 2:
                embeddings = self.model_embeddings

            else:
                embeddings = self.word_embeddings
        except Exception:
            warnings.warn(
                "This model doesn't support the embedding type you selected for attributions. Defaulting to word embeddings"
            )
            embeddings = self.word_embeddings

        self.question = question
        self.text = text

        self._calculate_attributions(embeddings)
        return self.word_attributions

    def _calculate_attributions(self, embeddings: Embedding):  # type: ignore
        """Compute LIG attributions for both the start and the end position."""
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.question, self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        self.position = 0
        start_lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        start_lig.summarize()
        self.start_attributions = start_lig

        self.position = 1
        end_lig = LIGAttributions(
            self._forward,
            embeddings,
            reference_tokens,
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
            self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        end_lig.summarize()
        self.end_attributions = end_lig
        self.attributions = [self.start_attributions, self.end_attributions]

    def __call__(
        self,
        question: str,
        text: str,
        embedding_type: int = 2,
        internal_batch_size: int = None,
        n_steps: int = None,
    ) -> dict:
        """
        Calculates start and end position word attributions for `question` and `text` using the model
        and tokenizer given in the constructor.

        This explainer also allows for attributions with respect to a particular embedding type.
        This can be selected by passing a `embedding_type`. The default value is `2` which
        attempts to calculate for all embeddings. If `0` is passed then attributions are w.r.t word_embeddings,
        if `1` is passed then attributions are w.r.t position_embeddings.


        Args:
            question (str): The question text
            text (str): The text or context from which the model finds an answers
            embedding_type (int, optional): The embedding type word(0), position(1), all(2) to calculate attributions for.
                Defaults to 2.
            internal_batch_size (int, optional): Divides total #steps * #examples
                data points into chunks of size at most internal_batch_size,
                which are computed (forward / backward passes)
                sequentially. If internal_batch_size is None, then all evaluations are
                processed in one batch.
            n_steps (int, optional): The number of steps used by the approximation
                method. Default: 50.

        Returns:
            dict: Dict for start and end position word attributions.
        """

        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size
        return self._run(question, text, embedding_type)
35 | 36 | """ 37 | 38 | def __init__( 39 | self, 40 | model: PreTrainedModel, 41 | tokenizer: PreTrainedTokenizer, 42 | attribution_type: str = "lig", 43 | custom_labels: Optional[List[str]] = None, 44 | ): 45 | """ 46 | Args: 47 | model (PreTrainedModel): Pretrained huggingface Sequence Classification model. 48 | tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer 49 | attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig". 50 | custom_labels (List[str], optional): Applies custom labels to label2id and id2label configs. 51 | Labels must be same length as the base model configs' labels. 52 | Labels and ids are applied index-wise. Defaults to None. 53 | 54 | Raises: 55 | AttributionTypeNotSupportedError: 56 | """ 57 | super().__init__(model, tokenizer) 58 | if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES: 59 | raise AttributionTypeNotSupportedError( 60 | f"""Attribution type '{attribution_type}' is not supported. 61 | Supported types are {SUPPORTED_ATTRIBUTION_TYPES}""" 62 | ) 63 | self.attribution_type = attribution_type 64 | 65 | if custom_labels is not None: 66 | if len(custom_labels) != len(model.config.label2id): 67 | raise ValueError( 68 | f"""`custom_labels` size '{len(custom_labels)}' should match pretrained model's label2id size 69 | '{len(model.config.label2id)}'""" 70 | ) 71 | 72 | self.id2label, self.label2id = self._get_id2label_and_label2id_dict(custom_labels) 73 | else: 74 | self.label2id = model.config.label2id 75 | self.id2label = model.config.id2label 76 | 77 | self.attributions: Union[None, LIGAttributions] = None 78 | self.input_ids: torch.Tensor = torch.Tensor() 79 | 80 | self._single_node_output = False 81 | 82 | self.internal_batch_size = None 83 | self.n_steps = 50 84 | 85 | @staticmethod 86 | def _get_id2label_and_label2id_dict( 87 | labels: List[str], 88 | ) -> Tuple[Dict[int, str], Dict[str, int]]: 89 | id2label: Dict[int, str] = dict() 90 | label2id: Dict[str, int] = dict() 91 | 
for idx, label in enumerate(labels): 92 | id2label[idx] = label 93 | label2id[label] = idx 94 | 95 | return id2label, label2id 96 | 97 | def encode(self, text: str = None) -> list: 98 | return self.tokenizer.encode(text, add_special_tokens=False) 99 | 100 | def decode(self, input_ids: torch.Tensor) -> list: 101 | "Decode 'input_ids' to string using tokenizer" 102 | return self.tokenizer.convert_ids_to_tokens(input_ids[0]) 103 | 104 | @property 105 | def predicted_class_index(self) -> int: 106 | "Returns predicted class index (int) for model with last calculated `input_ids`" 107 | if len(self.input_ids) > 0: 108 | # we call this before _forward() so it has to be calculated twice 109 | preds = self.model(self.input_ids)[0] 110 | self.pred_class = torch.argmax(torch.softmax(preds, dim=0)[0]) 111 | return torch.argmax(torch.softmax(preds, dim=1)[0]).cpu().detach().numpy() 112 | 113 | else: 114 | raise InputIdsNotCalculatedError("input_ids have not been created yet.`") 115 | 116 | @property 117 | def predicted_class_name(self): 118 | "Returns predicted class name (str) for model with last calculated `input_ids`" 119 | try: 120 | index = self.predicted_class_index 121 | return self.id2label[int(index)] 122 | except Exception: 123 | return self.predicted_class_index 124 | 125 | @property 126 | def word_attributions(self) -> list: 127 | "Returns the word attributions for model and the text provided. Raises error if attributions not calculated." 128 | if self.attributions is not None: 129 | return self.attributions.word_attributions 130 | else: 131 | raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.") 132 | 133 | def visualize(self, html_filepath: str = None, true_class: str = None): 134 | """ 135 | Visualizes word attributions. If in a notebook table will be displayed inline. 136 | 137 | Otherwise pass a valid path to `html_filepath` and the visualization will be saved 138 | as a html file. 
139 | 140 | If the true class is known for the text that can be passed to `true_class` 141 | 142 | """ 143 | tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)] 144 | attr_class = self.id2label[self.selected_index] 145 | 146 | if self._single_node_output: 147 | if true_class is None: 148 | true_class = round(float(self.pred_probs)) 149 | predicted_class = round(float(self.pred_probs)) 150 | attr_class = round(float(self.pred_probs)) 151 | 152 | else: 153 | if true_class is None: 154 | true_class = self.selected_index 155 | predicted_class = self.predicted_class_name 156 | 157 | score_viz = self.attributions.visualize_attributions( # type: ignore 158 | self.pred_probs, 159 | predicted_class, 160 | true_class, 161 | attr_class, 162 | tokens, 163 | ) 164 | html = viz.visualize_text([score_viz]) 165 | 166 | if html_filepath: 167 | if not html_filepath.endswith(".html"): 168 | html_filepath = html_filepath + ".html" 169 | with open(html_filepath, "w") as html_file: 170 | html_file.write(html.data) 171 | return html 172 | 173 | def _forward( # type: ignore 174 | self, 175 | input_ids: torch.Tensor, 176 | token_type_ids=None, 177 | position_ids: torch.Tensor = None, 178 | attention_mask: torch.Tensor = None, 179 | ): 180 | 181 | preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask) 182 | preds = preds[0] 183 | 184 | # if it is a single output node 185 | if len(preds[0]) == 1: 186 | self._single_node_output = True 187 | self.pred_probs = torch.sigmoid(preds)[0][0] 188 | return torch.sigmoid(preds)[:, :] 189 | 190 | self.pred_probs = torch.softmax(preds, dim=1)[0][self.selected_index] 191 | return torch.softmax(preds, dim=1)[:, self.selected_index] 192 | 193 | def _calculate_attributions(self, embeddings: Embedding, index: int = None, class_name: str = None): # type: ignore 194 | ( 195 | self.input_ids, 196 | self.ref_input_ids, 197 | self.sep_idx, 198 | ) = self._make_input_reference_pair(self.text) 199 | 200 | ( 201 | 
self.position_ids, 202 | self.ref_position_ids, 203 | ) = self._make_input_reference_position_id_pair(self.input_ids) 204 | 205 | ( 206 | self.token_type_ids, 207 | self.ref_token_type_ids, 208 | ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx) 209 | 210 | self.attention_mask = self._make_attention_mask(self.input_ids) 211 | 212 | if index is not None: 213 | self.selected_index = index 214 | elif class_name is not None: 215 | if class_name in self.label2id.keys(): 216 | self.selected_index = int(self.label2id[class_name]) 217 | else: 218 | s = f"'{class_name}' is not found in self.label2id keys." 219 | s += "Defaulting to predicted index instead." 220 | warnings.warn(s) 221 | self.selected_index = int(self.predicted_class_index) 222 | else: 223 | self.selected_index = int(self.predicted_class_index) 224 | 225 | reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)] 226 | lig = LIGAttributions( 227 | custom_forward=self._forward, 228 | embeddings=embeddings, 229 | tokens=reference_tokens, 230 | input_ids=self.input_ids, 231 | ref_input_ids=self.ref_input_ids, 232 | sep_id=self.sep_idx, 233 | attention_mask=self.attention_mask, 234 | position_ids=self.position_ids, 235 | ref_position_ids=self.ref_position_ids, 236 | token_type_ids=self.token_type_ids, 237 | ref_token_type_ids=self.ref_token_type_ids, 238 | internal_batch_size=self.internal_batch_size, 239 | n_steps=self.n_steps, 240 | ) 241 | 242 | lig.summarize() 243 | self.attributions = lig 244 | 245 | def _run( 246 | self, 247 | text: str, 248 | index: int = None, 249 | class_name: str = None, 250 | embedding_type: int = None, 251 | ) -> list: # type: ignore 252 | if embedding_type is None: 253 | embeddings = self.word_embeddings 254 | else: 255 | if embedding_type == 0: 256 | embeddings = self.word_embeddings 257 | elif embedding_type == 1: 258 | if self.accepts_position_ids and self.position_embeddings is not None: 259 | embeddings = 
self.position_embeddings 260 | else: 261 | warnings.warn( 262 | "This model doesn't support position embeddings for attributions. Defaulting to word embeddings" 263 | ) 264 | embeddings = self.word_embeddings 265 | else: 266 | embeddings = self.word_embeddings 267 | 268 | self.text = self._clean_text(text) 269 | 270 | self._calculate_attributions(embeddings=embeddings, index=index, class_name=class_name) 271 | return self.word_attributions # type: ignore 272 | 273 | def __call__( 274 | self, 275 | text: str, 276 | index: int = None, 277 | class_name: str = None, 278 | embedding_type: int = 0, 279 | internal_batch_size: int = None, 280 | n_steps: int = None, 281 | ) -> list: 282 | """ 283 | Calculates attribution for `text` using the model 284 | and tokenizer given in the constructor. 285 | 286 | Attributions can be forced along the axis of a particular output index or class name. 287 | To do this provide either a valid `index` for the class label's output or if the outputs 288 | have provided labels you can pass a `class_name`. 289 | 290 | This explainer also allows for attributions with respect to a particular embedding type. 291 | This can be selected by passing a `embedding_type`. The default value is `0` which 292 | is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings. 293 | If a model does not take position ids in its forward method (distilbert) a warning will 294 | occur and the default word_embeddings will be chosen instead. 295 | 296 | Args: 297 | text (str): Text to provide attributions for. 298 | index (int, optional): Optional output index to provide attributions for. Defaults to None. 299 | class_name (str, optional): Optional output class name to provide attributions for. Defaults to None. 300 | embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0. 
301 | internal_batch_size (int, optional): Divides total #steps * #examples 302 | data points into chunks of size at most internal_batch_size, 303 | which are computed (forward / backward passes) 304 | sequentially. If internal_batch_size is None, then all evaluations are 305 | processed in one batch. 306 | n_steps (int, optional): The number of steps used by the approximation 307 | method. Default: 50. 308 | Returns: 309 | list: List of tuples containing words and their associated attribution scores. 310 | """ 311 | 312 | if n_steps: 313 | self.n_steps = n_steps 314 | if internal_batch_size: 315 | self.internal_batch_size = internal_batch_size 316 | return self._run(text, index, class_name, embedding_type=embedding_type) 317 | 318 | def __str__(self): 319 | s = f"{self.__class__.__name__}(" 320 | s += f"\n\tmodel={self.model.__class__.__name__}," 321 | s += f"\n\ttokenizer={self.tokenizer.__class__.__name__}," 322 | s += f"\n\tattribution_type='{self.attribution_type}'," 323 | s += ")" 324 | 325 | return s 326 | 327 | 328 | class PairwiseSequenceClassificationExplainer(SequenceClassificationExplainer): 329 | def _make_input_reference_pair( 330 | self, text1: Union[List, str], text2: Union[List, str] 331 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 332 | 333 | t1_ids = self.tokenizer.encode(text1, add_special_tokens=False) 334 | t2_ids = self.tokenizer.encode(text2, add_special_tokens=False) 335 | input_ids = self.tokenizer.encode([text1, text2], add_special_tokens=True) 336 | if self.model.config.model_type == "roberta": 337 | ref_input_ids = ( 338 | [self.cls_token_id] 339 | + [self.ref_token_id] * len(t1_ids) 340 | + [self.sep_token_id] 341 | + [self.sep_token_id] 342 | + [self.ref_token_id] * len(t2_ids) 343 | + [self.sep_token_id] 344 | ) 345 | 346 | else: 347 | 348 | ref_input_ids = ( 349 | [self.cls_token_id] 350 | + [self.ref_token_id] * len(t1_ids) 351 | + [self.sep_token_id] 352 | + [self.ref_token_id] * len(t2_ids) 353 | + [self.sep_token_id] 354 | ) 
        return (
            torch.tensor([input_ids], device=self.device),
            torch.tensor([ref_input_ids], device=self.device),
            len(t1_ids) + 1,  # +1 for CLS token
        )

    def _calculate_attributions(
        self,
        embeddings: Embedding,
        index: int = None,
        class_name: str = None,
        flip_sign: bool = False,
    ):  # type: ignore
        # Prepare (input, reference) tensors for Layer Integrated Gradients
        # over the jointly-encoded text pair.
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text1, self.text2)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        (
            self.token_type_ids,
            self.ref_token_type_ids,
        ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        # Resolve the output node to attribute against: explicit index wins,
        # then a recognized class name, otherwise the predicted class.
        if index is not None:
            self.selected_index = index
        elif class_name is not None:
            if class_name in self.label2id.keys():
                self.selected_index = int(self.label2id[class_name])
            else:
                s = f"'{class_name}' is not found in self.label2id keys."
                s += "Defaulting to predicted index instead."
                warnings.warn(s)
                self.selected_index = int(self.predicted_class_index)
        else:
            self.selected_index = int(self.predicted_class_index)

        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        lig = LIGAttributions(
            custom_forward=self._forward,
            embeddings=embeddings,
            tokens=reference_tokens,
            input_ids=self.input_ids,
            ref_input_ids=self.ref_input_ids,
            sep_id=self.sep_idx,
            attention_mask=self.attention_mask,
            position_ids=self.position_ids,
            ref_position_ids=self.ref_position_ids,
            token_type_ids=self.token_type_ids,
            ref_token_type_ids=self.ref_token_type_ids,
            internal_batch_size=self.internal_batch_size,
            n_steps=self.n_steps,
        )
        # flip_sign only applies to single-output-node models (see __call__ docs).
        if self._single_node_output:
            lig.summarize(flip_sign=flip_sign)
        else:
            lig.summarize()
        self.attributions = lig

    def _run(
        self,
        text1: str,
        text2: str,
        index: int = None,
        class_name: str = None,
        embedding_type: int = None,
        flip_sign: bool = False,
    ) -> list:  # type: ignore
        # Map the requested embedding type onto an embedding layer, falling
        # back to word embeddings when position embeddings are unsupported.
        if embedding_type is None:
            embeddings = self.word_embeddings
        else:
            if embedding_type == 0:
                embeddings = self.word_embeddings
            elif embedding_type == 1:
                if self.accepts_position_ids and self.position_embeddings is not None:
                    embeddings = self.position_embeddings
                else:
                    warnings.warn(
                        "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
                    )
                    embeddings = self.word_embeddings
            else:
                embeddings = self.word_embeddings

        self.text1 = text1
        self.text2 = text2

        self._calculate_attributions(
            embeddings=embeddings,
            index=index,
            class_name=class_name,
            flip_sign=flip_sign,
        )
        return self.word_attributions  # type: ignore

    def __call__(
        self,
        text1: str,
        text2: str,
        index: int = None,
        class_name: str = None,
        embedding_type: int = 0,
        internal_batch_size: int = None,
        n_steps: int = None,
        flip_sign: bool = False,
    ):
        """
        Calculates pairwise attributions for two inputs `text1` and `text2` using the model
        and tokenizer given in the constructor. Pairwise attributions are useful for models where
        two distinct inputs separated by the model separator token are fed to the model, such as cross-encoder
        models for similarity classification.

        Attributions can be forced along the axis of a particular output index or class name if there is more than one.
        To do this provide either a valid `index` for the class label's output or if the outputs
        have provided labels you can pass a `class_name`.

        This explainer also allows for attributions with respect to a particular embedding type.
        This can be selected by passing a `embedding_type`. The default value is `0` which
        is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings.
        If a model does not take position ids in its forward method (distilbert) a warning will
        occur and the default word_embeddings will be chosen instead.

        Additionally, this explainer allows for attributions signs to be flipped in cases where the model
        only outputs a single node. By default for models that output a single node the attributions are
        with respect to the inputs pushing the scores closer to 1.0, however if you want to see the
        attributions with respect to scores closer to 0.0 you can pass `flip_sign=True`. For similarity
        based models this is useful, as the model might predict a score closer to 0.0 for the two inputs
        and in that case we would flip the attributions sign to explain why the two inputs are dissimilar.

        Args:
            text1 (str): First text input to provide pairwise attributions for.
            text2 (str): Second text to provide pairwise attributions for.
            index (int, optional): Optional output index to provide attributions for. Defaults to None.
            class_name (str, optional): Optional output class name to provide attributions for. Defaults to None.
            embedding_type (int, optional):The embedding type word(0) or position(1) to calculate attributions for. Defaults to 0.
            internal_batch_size (int, optional): Divides total #steps * #examples
                data points into chunks of size at most internal_batch_size,
                which are computed (forward / backward passes)
                sequentially. If internal_batch_size is None, then all evaluations are
                processed in one batch.
            n_steps (int, optional): The number of steps used by the approximation
                method. Default: 50.
            flip_sign (bool, optional): Boolean flag determining whether to flip the sign of attributions. Defaults to False.

        Returns:
            _type_: _description_
        """
        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size
        return self._run(
            text1=text1,
            text2=text2,
            embedding_type=embedding_type,
            index=index,
            class_name=class_name,
            flip_sign=flip_sign,
        )

--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/token_classification.py:
--------------------------------------------------------------------------------

import warnings
from typing import Dict, List, Optional, Union

import torch
from captum.attr import visualization as viz
from torch.nn.modules.sparse import Embedding
from transformers import PreTrainedModel, PreTrainedTokenizer

from transformers_interpret import BaseExplainer
from transformers_interpret.attributions import LIGAttributions
from transformers_interpret.errors import (
    AttributionTypeNotSupportedError,
    InputIdsNotCalculatedError,
)

SUPPORTED_ATTRIBUTION_TYPES = ["lig"]


class TokenClassificationExplainer(BaseExplainer):
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        attribution_type="lig",
    ):

        """
        Args:
            model (PreTrainedModel): Pretrained huggingface Sequence Classification model.
            tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer
            attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig".

        Raises:
            AttributionTypeNotSupportedError:
        """
        super().__init__(model, tokenizer)
        if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES:
            raise AttributionTypeNotSupportedError(
                f"""Attribution type '{attribution_type}' is not supported.
                Supported types are {SUPPORTED_ATTRIBUTION_TYPES}"""
            )
        self.attribution_type: str = attribution_type

        self.label2id = model.config.label2id
        self.id2label = model.config.id2label

        # Filters applied per-call (set in __call__): token positions and/or
        # predicted NER labels to exclude from attribution.
        self.ignored_indexes: Optional[List[int]] = None
        self.ignored_labels: Optional[List[str]] = None

        # One LIGAttributions per selected token index once calculated.
        self.attributions: Union[None, Dict[int, LIGAttributions]] = None
        self.input_ids: torch.Tensor = torch.Tensor()

        self.internal_batch_size = None
        self.n_steps = 50

    def encode(self, text: str = None) -> list:
        "Encode the text using tokenizer"
        return self.tokenizer.encode(text, add_special_tokens=False)

    def decode(self, input_ids: torch.Tensor) -> list:
        "Decode 'input_ids' to string using tokenizer"
        return self.tokenizer.convert_ids_to_tokens(input_ids[0])

    @property
    def predicted_class_indexes(self) -> List[int]:
        "Returns the predicted class indexes (int) for model with last calculated `input_ids`"
        if len(self.input_ids) > 0:

            preds = self.model(self.input_ids)
            preds = preds[0]
            # Softmax over the class axis of [N_BATCH, N_TOKENS, N_CLASSES] logits.
            self.pred_class = torch.softmax(preds, dim=2)[0]

            return torch.argmax(torch.softmax(preds, dim=2), dim=2)[0].cpu().detach().numpy()

        else:
            raise InputIdsNotCalculatedError("input_ids have not been created yet.`")

    @property
    def predicted_class_names(self):
        "Returns predicted class names (str) for model with last calculated `input_ids`"
        try:
            indexes = self.predicted_class_indexes
            return [self.id2label[int(index)] for index in indexes]
        except Exception:
            # id2label lookup failed: fall back to the raw indexes.
            return self.predicted_class_indexes

    @property
    def word_attributions(self) -> Dict:
        "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."

        if self.attributions is not None:
            word_attr = dict()
            tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
            labels = self.predicted_class_names

            for index, attr in self.attributions.items():
                try:
                    predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
                except KeyError:
                    # No label name for this class id: report the raw id.
                    predicted_class = torch.argmax(self.pred_probs[index]).item()

                word_attr[tokens[index]] = {
                    "label": predicted_class,
                    "attribution_scores": attr.word_attributions,
                }

            return word_attr
        else:
            raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

    @property
    def _selected_indexes(self) -> List[int]:
        """Returns the indexes for which the attributions must be calculated considering the
        ignored indexes and the ignored labels, in that order of priority"""

        selected_indexes = set(range(self.input_ids.shape[1]))  # all indexes

        if self.ignored_indexes is not None:
            selected_indexes = selected_indexes.difference(set(self.ignored_indexes))

        if self.ignored_labels is not None:
            ignored_indexes_extra = []
            pred_labels = [self.id2label[id] for id in self.predicted_class_indexes]

            for index, label in enumerate(pred_labels):
                if label in self.ignored_labels:
                    ignored_indexes_extra.append(index)
            selected_indexes = selected_indexes.difference(ignored_indexes_extra)

        return sorted(list(selected_indexes))

    def visualize(self, html_filepath: str = None, true_classes: List[str] = None):
        """
        Visualizes word attributions. If in a notebook table will be displayed inline.

        Otherwise pass a valid path to `html_filepath` and the visualization will be saved
        as a html file.

        If the true class is known for the text that can be passed to `true_class`

        """
        if true_classes is not None and len(true_classes) != self.input_ids.shape[1]:
            raise ValueError(f"""The length of `true_classes` must be equal to the number of tokens""")

        score_vizs = []
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        # One visualization row per attributed token index.
        for index in self._selected_indexes:
            pred_prob = torch.max(self.pred_probs[index])
            predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]

            attr_class = tokens[index]
            if true_classes is None:
                true_class = predicted_class
            else:
                true_class = true_classes[index]

            score_vizs.append(
                self.attributions[index].visualize_attributions(
                    pred_prob,
                    predicted_class,
                    true_class,
                    attr_class,
                    tokens,
                )
            )

        html = viz.visualize_text(score_vizs)

        if html_filepath:
            if not html_filepath.endswith(".html"):
                html_filepath = html_filepath + ".html"
            with open(html_filepath, "w") as html_file:
                html_file.write(html.data)
        return html

    def _forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
    ):
        # Forward pass used by Captum; returns the class distribution for the
        # token at `self.index` (set by _calculate_attributions).
        if self.accepts_position_ids:
            preds = self.model(
                input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
            )
        else:
            preds = self.model(input_ids, attention_mask)

        preds = preds.logits  # preds.shape = [N_BATCH, N_TOKENS, N_CLASSES]

        self.pred_probs = torch.softmax(preds, dim=2)[0]
        return torch.softmax(preds, dim=2)[:, self.index, :]

    def _calculate_attributions(
        self,
        embeddings: Embedding,
    ) -> None:
        # Prepare the (input, reference) tensors shared by every per-token run.
        (
            self.input_ids,
            self.ref_input_ids,
            self.sep_idx,
        ) = self._make_input_reference_pair(self.text)

        (
            self.position_ids,
            self.ref_position_ids,
        ) = self._make_input_reference_position_id_pair(self.input_ids)

        self.attention_mask = self._make_attention_mask(self.input_ids)

        pred_classes = self.predicted_class_indexes
        reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]

        ligs = {}

        # One full LIG run per selected token; `self.index` steers _forward.
        for index in self._selected_indexes:
            self.index = index
            lig = LIGAttributions(
                self._forward,
                embeddings,
                reference_tokens,
                self.input_ids,
                self.ref_input_ids,
                self.sep_idx,
                self.attention_mask,
                target=int(pred_classes[index]),
                position_ids=self.position_ids,
                ref_position_ids=self.ref_position_ids,
                internal_batch_size=self.internal_batch_size,
                n_steps=self.n_steps,
            )
            lig.summarize()
            ligs[index] = lig

        self.attributions = ligs

    def _run(
        self,
        text: str,
        embedding_type: int = None,
    ) -> dict:
        # Map the requested embedding type onto an embedding layer, falling
        # back to word embeddings when position embeddings are unsupported.
        if embedding_type is None:
            embeddings = self.word_embeddings
        else:
            if embedding_type == 0:
                embeddings = self.word_embeddings
            elif embedding_type == 1:
                if self.accepts_position_ids and self.position_embeddings is not None:
                    embeddings = self.position_embeddings
                else:
                    warnings.warn(
                        "This model doesn't support position embeddings for attributions. Defaulting to word embeddings"
                    )
                    embeddings = self.word_embeddings
            else:
                embeddings = self.word_embeddings

        self.text = self._clean_text(text)

        self._calculate_attributions(embeddings=embeddings)
        return self.word_attributions

    def __call__(
        self,
        text: str,
        embedding_type: int = 0,
        internal_batch_size: Optional[int] = None,
        n_steps: Optional[int] = None,
        ignored_indexes: Optional[List[int]] = None,
        ignored_labels: Optional[List[str]] = None,
    ) -> dict:
        """
        Args:
            text (str): Sentence whose NER predictions are to be explained.
            embedding_type (int, default = 0): Custom type of embedding.
            internal_batch_size (int, optional): Custom internal batch size for the attributions calculation.
            n_steps (int): Custom number of steps in the approximation used in the attributions calculation.
            ignored_indexes (List[int], optional): Indexes that are to be ignored by the explainer.
            ignored_labels (List[str], optional)): NER labels that are to be ignored by the explainer. The
                explainer will ignore those indexes whose predicted label is
                in `ignored_labels`.
        """

        if n_steps:
            self.n_steps = n_steps
        if internal_batch_size:
            self.internal_batch_size = internal_batch_size

        self.ignored_indexes = ignored_indexes
        self.ignored_labels = ignored_labels

        return self._run(text, embedding_type=embedding_type)

    def __str__(self):
        s = f"{self.__class__.__name__}("
        s += f"\n\tmodel={self.model.__class__.__name__},"
        s += f"\n\ttokenizer={self.tokenizer.__class__.__name__},"
        s += f"\n\tattribution_type='{self.attribution_type}',"
        s += ")"

        return s

--------------------------------------------------------------------------------
/transformers_interpret/explainers/text/zero_shot_classification.py:
--------------------------------------------------------------------------------

from typing import List, Tuple

import torch
from captum.attr import visualization as viz
from torch.nn.modules.sparse import Embedding
from transformers import PreTrainedModel, PreTrainedTokenizer

from transformers_interpret import LIGAttributions
from transformers_interpret.errors import AttributionTypeNotSupportedError

from .question_answering import QuestionAnsweringExplainer
from .sequence_classification import SequenceClassificationExplainer

SUPPORTED_ATTRIBUTION_TYPES = ["lig"]


class ZeroShotClassificationExplainer(SequenceClassificationExplainer, QuestionAnsweringExplainer):
    """
    Explainer for explaining attributions for models that can perform
    zero-shot classification, specifically models trained on nli downstream tasks.

    This explainer uses the same "trick" as Huggingface to achieve attributions on
    arbitrary labels provided at inference time.
24 | 25 | Model's provided to this explainer must be nli sequence classification models 26 | and must have the label "entailment" or "ENTAILMENT" in 27 | `model.config.label2id.keys()` in order for it to work correctly. 28 | 29 | This explainer works by forcing the model to explain it's output with respect to 30 | the entailment class. For each label passed at inference the explainer forms a hypothesis with each 31 | and calculates attributions for each hypothesis label. The label with the highest predicted probability 32 | can be accessed via the attribute `predicted_label`. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model: PreTrainedModel, 38 | tokenizer: PreTrainedTokenizer, 39 | attribution_type: str = "lig", 40 | ): 41 | """ 42 | 43 | Args: 44 | model (PreTrainedModel):Pretrained huggingface Sequence Classification model. Must be a NLI model. 45 | tokenizer (PreTrainedTokenizer): Pretrained huggingface tokenizer 46 | attribution_type (str, optional): The attribution method to calculate on. Defaults to "lig". 47 | 48 | Raises: 49 | AttributionTypeNotSupportedError: [description] 50 | ValueError: [description] 51 | """ 52 | super().__init__(model, tokenizer) 53 | if attribution_type not in SUPPORTED_ATTRIBUTION_TYPES: 54 | raise AttributionTypeNotSupportedError( 55 | f"""Attribution type '{attribution_type}' is not supported. 56 | Supported types are {SUPPORTED_ATTRIBUTION_TYPES}""" 57 | ) 58 | self.label_exists, self.entailment_key = self._entailment_label_exists() 59 | if not self.label_exists: 60 | raise ValueError('Expected label "entailment" in `model.label2id` ') 61 | 62 | self.entailment_idx = self.label2id[self.entailment_key] 63 | self.include_hypothesis = False 64 | self.attributions = [] 65 | 66 | self.internal_batch_size = None 67 | self.n_steps = 50 68 | 69 | @property 70 | def word_attributions(self) -> dict: 71 | "Returns the word attributions for model and the text provided. Raises error if attributions not calculated." 
72 | if self.attributions != []: 73 | if self.include_hypothesis: 74 | return dict( 75 | zip( 76 | self.labels, 77 | [attr.word_attributions for attr in self.attributions], 78 | ) 79 | ) 80 | else: 81 | spliced_wa = [attr.word_attributions[: self.sep_idx] for attr in self.attributions] 82 | return dict(zip(self.labels, spliced_wa)) 83 | else: 84 | raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.") 85 | 86 | def visualize(self, html_filepath: str = None, true_class: str = None): 87 | """ 88 | Visualizes word attributions. If in a notebook table will be displayed inline. 89 | 90 | Otherwise pass a valid path to `html_filepath` and the visualization will be saved 91 | as a html file. 92 | 93 | If the true class is known for the text that can be passed to `true_class` 94 | 95 | """ 96 | tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)] 97 | 98 | if not self.include_hypothesis: 99 | tokens = tokens[: self.sep_idx] 100 | 101 | score_viz = [ 102 | self.attributions[i].visualize_attributions( # type: ignore 103 | self.pred_probs[i], 104 | self.labels[i], 105 | self.labels[i], 106 | self.labels[i], 107 | tokens, 108 | ) 109 | for i in range(len(self.attributions)) 110 | ] 111 | html = viz.visualize_text(score_viz) 112 | 113 | if html_filepath: 114 | if not html_filepath.endswith(".html"): 115 | html_filepath = html_filepath + ".html" 116 | with open(html_filepath, "w") as html_file: 117 | html_file.write(html.data) 118 | return html 119 | 120 | def _entailment_label_exists(self) -> bool: 121 | if "entailment" in self.label2id.keys(): 122 | return True, "entailment" 123 | elif "ENTAILMENT" in self.label2id.keys(): 124 | return True, "ENTAILMENT" 125 | 126 | return False, None 127 | 128 | def _get_top_predicted_label_idx(self, text, hypothesis_labels: List[str]) -> int: 129 | 130 | entailment_outputs = [] 131 | for label in hypothesis_labels: 132 | input_ids, _, sep_idx = 
self._make_input_reference_pair(text, label) 133 | position_ids, _ = self._make_input_reference_position_id_pair(input_ids) 134 | token_type_ids, _ = self._make_input_reference_token_type_pair(input_ids, sep_idx) 135 | attention_mask = self._make_attention_mask(input_ids) 136 | preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask) 137 | entailment_outputs.append(float(torch.sigmoid(preds[0])[0][self.entailment_idx])) 138 | 139 | normed_entailment_outputs = [float(i) / sum(entailment_outputs) for i in entailment_outputs] 140 | 141 | self.pred_probs = normed_entailment_outputs 142 | 143 | return entailment_outputs.index(max(entailment_outputs)) 144 | 145 | def _make_input_reference_pair( 146 | self, 147 | text: str, 148 | hypothesis_text: str, 149 | ) -> Tuple[torch.Tensor, torch.Tensor, int]: 150 | hyp_ids = self.encode(hypothesis_text) 151 | text_ids = self.encode(text) 152 | 153 | input_ids = [self.cls_token_id] + text_ids + [self.sep_token_id] + hyp_ids + [self.sep_token_id] 154 | 155 | ref_input_ids = ( 156 | [self.cls_token_id] 157 | + [self.ref_token_id] * len(text_ids) 158 | + [self.sep_token_id] 159 | + [self.ref_token_id] * len(hyp_ids) 160 | + [self.sep_token_id] 161 | ) 162 | 163 | return ( 164 | torch.tensor([input_ids], device=self.device), 165 | torch.tensor([ref_input_ids], device=self.device), 166 | len(text_ids), 167 | ) 168 | 169 | def _forward( # type: ignore 170 | self, 171 | input_ids: torch.Tensor, 172 | token_type_ids=None, 173 | position_ids: torch.Tensor = None, 174 | attention_mask: torch.Tensor = None, 175 | ): 176 | 177 | preds = self._get_preds(input_ids, token_type_ids, position_ids, attention_mask) 178 | preds = preds[0] 179 | 180 | return torch.softmax(preds, dim=1)[:, self.selected_index] 181 | 182 | def _calculate_attributions(self, embeddings: Embedding, class_name: str, index: int = None): # type: ignore 183 | ( 184 | self.input_ids, 185 | self.ref_input_ids, 186 | self.sep_idx, 187 | ) = 
self._make_input_reference_pair(self.text, self.hypothesis_text) 188 | 189 | ( 190 | self.position_ids, 191 | self.ref_position_ids, 192 | ) = self._make_input_reference_position_id_pair(self.input_ids) 193 | 194 | ( 195 | self.token_type_ids, 196 | self.ref_token_type_ids, 197 | ) = self._make_input_reference_token_type_pair(self.input_ids, self.sep_idx) 198 | 199 | self.attention_mask = self._make_attention_mask(self.input_ids) 200 | 201 | self.selected_index = int(self.label2id[class_name]) 202 | 203 | reference_tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)] 204 | lig = LIGAttributions( 205 | self._forward, 206 | embeddings, 207 | reference_tokens, 208 | self.input_ids, 209 | self.ref_input_ids, 210 | self.sep_idx, 211 | self.attention_mask, 212 | position_ids=self.position_ids, 213 | ref_position_ids=self.ref_position_ids, 214 | token_type_ids=self.token_type_ids, 215 | ref_token_type_ids=self.ref_token_type_ids, 216 | internal_batch_size=self.internal_batch_size, 217 | n_steps=self.n_steps, 218 | ) 219 | if self.include_hypothesis: 220 | lig.summarize() 221 | else: 222 | lig.summarize(self.sep_idx) 223 | self.attributions.append(lig) 224 | 225 | def __call__( 226 | self, 227 | text: str, 228 | labels: List[str], 229 | embedding_type: int = 0, 230 | hypothesis_template="this text is about {} .", 231 | include_hypothesis: bool = False, 232 | internal_batch_size: int = None, 233 | n_steps: int = None, 234 | ) -> dict: 235 | """ 236 | Calculates attribution for `text` using the model and 237 | tokenizer given in the constructor. Since `self.model` is 238 | a NLI type model each label in `labels` is formatted to the 239 | `hypothesis_template`. By default attributions are provided for all 240 | labels. The top predicted label can be found in the `predicted_label` 241 | attribute. 242 | 243 | Attribution is forced to be on the axis of whatever index 244 | the entailment class resolves to. e.g. 
{"entailment": 0, "neutral": 1, "contradiction": 2 } 245 | in the above case attributions would be for the label at index 0. 246 | 247 | This explainer also allows for attributions with respect to a particlar embedding type. 248 | This can be selected by passing a `embedding_type`. The default value is `0` which 249 | is for word_embeddings, if `1` is passed then attributions are w.r.t to position_embeddings. 250 | If a model does not take position ids in its forward method (distilbert) a warning will 251 | occur and the default word_embeddings will be chosen instead. 252 | 253 | The default `hypothesis_template` can also be overridden by providing a formattable 254 | string which accepts exactly one formattable value for the label. 255 | 256 | If `include_hypothesis` is set to `True` then the word attributions and visualization 257 | of the attributions will also included the hypothesis text which gives a complete indication 258 | of what the model sees at inference. 259 | 260 | Args: 261 | text (str): Text to provide attributions for. 262 | labels (List[str]): The labels to classify the text to. If only one label is provided in the list then 263 | attributions are guaranteed to be for that label. 264 | embedding_type (int, optional): The embedding type word(0) or position(1) to calculate attributions for. 265 | Defaults to 0. 266 | hypothesis_template (str, optional): Hypothesis presetned to NLI model given text. 267 | Defaults to "this text is about {} .". 268 | include_hypothesis (bool, optional): Alternative option to include hypothesis text in attributions 269 | and visualization. Defaults to False. 270 | internal_batch_size (int, optional): Divides total #steps * #examples 271 | data points into chunks of size at most internal_batch_size, 272 | which are computed (forward / backward passes) 273 | sequentially. If internal_batch_size is None, then all evaluations are 274 | processed in one batch. 
275 | n_steps (int, optional): The number of steps used by the approximation 276 | method. Default: 50. 277 | Returns: 278 | list: List of tuples containing words and their associated attribution scores. 279 | """ 280 | 281 | if n_steps: 282 | self.n_steps = n_steps 283 | if internal_batch_size: 284 | self.internal_batch_size = internal_batch_size 285 | self.attributions = [] 286 | self.pred_probs = [] 287 | self.include_hypothesis = include_hypothesis 288 | self.labels = labels 289 | self.hypothesis_labels = [hypothesis_template.format(label) for label in labels] 290 | 291 | predicted_text_idx = self._get_top_predicted_label_idx(text, self.hypothesis_labels) 292 | 293 | for i, _ in enumerate(self.labels): 294 | self.hypothesis_text = self.hypothesis_labels[i] 295 | self.predicted_label = labels[i] + " (" + self.entailment_key.lower() + ")" 296 | super().__call__( 297 | text, 298 | class_name=self.entailment_key, 299 | embedding_type=embedding_type, 300 | ) 301 | self.predicted_label = self.labels[predicted_text_idx] 302 | return self.word_attributions 303 | -------------------------------------------------------------------------------- /transformers_interpret/explainers/vision/attribution_types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, unique 2 | 3 | 4 | @unique 5 | class AttributionType(Enum): 6 | INTEGRATED_GRADIENTS = "IG" 7 | INTEGRATED_GRADIENTS_NOISE_TUNNEL = "IGNT" 8 | 9 | 10 | class NoiseTunnelType(Enum): 11 | SMOOTHGRAD = "smoothgrad" 12 | SMOOTHGRAD_SQUARED = "smoothgrad_sq" 13 | VARGRAD = "vargrad" 14 | -------------------------------------------------------------------------------- /transformers_interpret/explainers/vision/image_classification.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from enum import Enum 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | from captum.attr import 
IntegratedGradients, NoiseTunnel 7 | from captum.attr import visualization as viz 8 | from PIL import Image 9 | from transformers.image_utils import ImageFeatureExtractionMixin 10 | from transformers.modeling_utils import PreTrainedModel 11 | 12 | from .attribution_types import AttributionType, NoiseTunnelType 13 | 14 | 15 | class ImageClassificationExplainer: 16 | """ 17 | This class is used to explain the output of a model on an image. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | model: PreTrainedModel, 23 | feature_extractor: ImageFeatureExtractionMixin, 24 | attribution_type: str = AttributionType.INTEGRATED_GRADIENTS_NOISE_TUNNEL.value, 25 | custom_labels: Optional[List[str]] = None, 26 | ): 27 | self.model = model 28 | self.feature_extractor = feature_extractor 29 | if attribution_type not in [attribution.value for attribution in AttributionType]: 30 | raise ValueError(f"Attribution type {attribution_type} not supported.") 31 | 32 | self.attribution_type = attribution_type 33 | 34 | if custom_labels is not None: 35 | if len(custom_labels) != len(model.config.label2id): 36 | raise ValueError( 37 | f"""`custom_labels` size '{len(custom_labels)}' should match pretrained model's label2id size 38 | '{len(model.config.label2id)}'""" 39 | ) 40 | 41 | self.id2label, self.label2id = self._get_id2label_and_label2id_dict(custom_labels) 42 | else: 43 | self.label2id = model.config.label2id 44 | self.id2label = model.config.id2label 45 | 46 | self.device = self.model.device 47 | 48 | self.internal_batch_size = None 49 | self.n_steps = 50 50 | self.n_steps_noise_tunnel = 5 51 | self.noise_tunnel_n_samples = 10 52 | self.noise_tunnel_type = NoiseTunnelType.SMOOTHGRAD.value 53 | 54 | self.attributions = None 55 | 56 | def visualize( 57 | self, 58 | save_path: Union[str, None] = None, 59 | method: str = "overlay", 60 | sign: str = "all", 61 | outlier_threshold: float = 0.1, 62 | use_original_image_pixels: bool = True, 63 | side_by_side: bool = False, 64 | ): 65 | 
outlier_threshold = min(outlier_threshold * 100, 100) 66 | attributions_t = np.transpose(self.attributions.squeeze(), (1, 2, 0)) 67 | if use_original_image_pixels: 68 | np_image = np.asarray( 69 | self.feature_extractor.resize(self._image, size=(attributions_t.shape[0], attributions_t.shape[1])) 70 | ) 71 | else: 72 | # uses the normalized image pixels which is what the model sees, but can be hard to interpret visually 73 | np_image = np.transpose(self.inputs.squeeze(), (1, 2, 0)) 74 | 75 | if sign == "all" and method in ["alpha_scaling", "masked_image"]: 76 | warnings.warn( 77 | "sign='all' is not supported for method='alpha_scaling' or method='masked_image'. " 78 | "Please use sign='positive', sign='negative', or sign='absolute'. " 79 | "Changing sign to default 'positive'." 80 | ) 81 | sign = "positive" 82 | 83 | visualizer = ImageAttributionVisualizer( 84 | attributions=attributions_t, 85 | pixel_values=np_image, 86 | outlier_threshold=outlier_threshold, 87 | pred_class=self.id2label[self.predicted_index], 88 | visualization_method=method, 89 | sign=sign, 90 | side_by_side=side_by_side, 91 | ) 92 | 93 | viz_result = visualizer() 94 | if save_path: 95 | viz_result[0].savefig(save_path) 96 | 97 | return viz_result 98 | 99 | def _forward_func(self, inputs): 100 | outputs = self.model(inputs) 101 | return outputs["logits"] 102 | 103 | def _calculate_attributions(self, class_name: Union[int, None], index: Union[int, None]) -> np.ndarray: 104 | 105 | if class_name: 106 | self.selected_index = self.label2id[class_name] 107 | 108 | if index: 109 | self.selected_index = index 110 | else: 111 | self.selected_index = self.predicted_index 112 | 113 | if self.attribution_type == AttributionType.INTEGRATED_GRADIENTS.value: 114 | ig = IntegratedGradients(self._forward_func) 115 | self.attributions, self.delta = ig.attribute( 116 | self.inputs, 117 | target=self.selected_index, 118 | internal_batch_size=self.internal_batch_size, 119 | n_steps=self.n_steps, 120 | 
return_convergence_delta=True, 121 | ) 122 | self.delta = self.delta.cpu().detach().numpy() 123 | if self.attribution_type == AttributionType.INTEGRATED_GRADIENTS_NOISE_TUNNEL.value: 124 | ig_nt = IntegratedGradients(self._forward_func) 125 | nt = NoiseTunnel(ig_nt) 126 | self.attributions = nt.attribute( 127 | self.inputs, 128 | nt_samples=self.noise_tunnel_n_samples, 129 | nt_type=self.noise_tunnel_type, 130 | target=self.selected_index, 131 | n_steps=self.n_steps_noise_tunnel, 132 | ) 133 | 134 | self.inputs = self.inputs.cpu().detach().numpy() 135 | self.attributions = self.attributions.cpu().detach().numpy() 136 | return self.attributions 137 | 138 | @staticmethod 139 | def _get_id2label_and_label2id_dict( 140 | labels: List[str], 141 | ) -> Tuple[Dict[int, str], Dict[str, int]]: 142 | id2label: Dict[int, str] = dict() 143 | label2id: Dict[str, int] = dict() 144 | for idx, label in enumerate(labels): 145 | id2label[idx] = label 146 | label2id[label] = idx 147 | 148 | return id2label, label2id 149 | 150 | def __call__( 151 | self, 152 | image: Image, 153 | index: int = None, 154 | class_name: str = None, 155 | internal_batch_size: Union[int, None] = None, 156 | n_steps: Union[int, None] = None, 157 | n_steps_noise_tunnel: Union[int, None] = None, 158 | noise_tunnel_n_samples: Union[int, None] = None, 159 | noise_tunnel_type: NoiseTunnelType = NoiseTunnelType.SMOOTHGRAD.value, 160 | ): 161 | self._image: Image = image 162 | try: 163 | self.noise_tunnel_type = NoiseTunnelType(noise_tunnel_type).value 164 | except ValueError: 165 | raise ValueError(f"noise_tunnel_type must be one of {NoiseTunnelType.__members__}") 166 | 167 | self.inputs = self.feature_extractor(image, return_tensors="pt").to(self.device)["pixel_values"] 168 | self.predicted_index = self.model(self.inputs).logits.argmax().item() 169 | 170 | if n_steps: 171 | self.n_steps = n_steps 172 | if n_steps_noise_tunnel: 173 | self.n_steps_noise_tunnel = n_steps_noise_tunnel 174 | if internal_batch_size: 
175 | self.internal_batch_size = internal_batch_size 176 | if noise_tunnel_n_samples: 177 | self.noise_tunnel_n_samples = noise_tunnel_n_samples 178 | 179 | return self._calculate_attributions(class_name, index) 180 | 181 | 182 | class VisualizationMethods(Enum): 183 | HEATMAP = "heatmap" 184 | OVERLAY = "overlay" 185 | ALPHA_SCALING = "alpha_scaling" 186 | MASKED_IMAGE = "masked_image" 187 | 188 | 189 | class SignType(Enum): 190 | ALL = "all" 191 | POSITIVE = "positive" 192 | NEGATIVE = "negative" 193 | ABSOLUTE = "absolute" 194 | 195 | 196 | class ImageAttributionVisualizer: 197 | def __init__( 198 | self, 199 | attributions: np.ndarray, 200 | pixel_values: np.ndarray, 201 | outlier_threshold: float, 202 | pred_class: str, 203 | sign: str, 204 | visualization_method: str, 205 | side_by_side: bool = False, 206 | ): 207 | self.attributions = attributions 208 | self.pixel_values = pixel_values 209 | self.outlier_threshold = outlier_threshold 210 | self.pred_class = pred_class 211 | self.render_pyplot = self._using_ipython() 212 | self.side_by_side = side_by_side 213 | try: 214 | self.visualization_method = VisualizationMethods(visualization_method) 215 | except ValueError: 216 | raise ValueError( 217 | f"""`visualization_method` must be one of the following: {list(VisualizationMethods.__members__.keys())}""" 218 | ) 219 | 220 | try: 221 | self.sign = SignType(sign) 222 | except ValueError: 223 | raise ValueError(f"""`sign` must be one of the following: {list(SignType.__members__.keys())}""") 224 | if self.visualization_method == VisualizationMethods.HEATMAP: 225 | self.plot_function = self.heatmap 226 | elif self.visualization_method == VisualizationMethods.OVERLAY: 227 | self.plot_function = self.overlay 228 | elif self.visualization_method == VisualizationMethods.ALPHA_SCALING: 229 | self.plot_function = self.alpha_scaling 230 | elif self.visualization_method == VisualizationMethods.MASKED_IMAGE: 231 | self.plot_function = self.masked_image 232 | 233 | def 
overlay(self): 234 | if self.side_by_side: 235 | return viz.visualize_image_attr_multiple( 236 | attr=self.attributions, 237 | original_image=self.pixel_values, 238 | methods=["original_image", "blended_heat_map"], 239 | signs=["all", "absolute_value" if self.sign.value == "absolute" else self.sign.value], 240 | show_colorbar=True, 241 | use_pyplot=self.render_pyplot, 242 | outlier_perc=self.outlier_threshold, 243 | titles=["Original Image", f"Heatmap overlay IG. Prediction: {self.pred_class}"], 244 | ) 245 | return viz.visualize_image_attr( 246 | original_image=self.pixel_values, 247 | attr=self.attributions, 248 | sign="absolute_value" if self.sign.value == "absolute" else self.sign.value, 249 | method="blended_heat_map", 250 | show_colorbar=True, 251 | outlier_perc=self.outlier_threshold, 252 | title=f"Heatmap Overlay IG. Prediction - {self.pred_class}", 253 | use_pyplot=self.render_pyplot, 254 | ) 255 | 256 | def heatmap(self): 257 | if self.side_by_side: 258 | return viz.visualize_image_attr_multiple( 259 | attr=self.attributions, 260 | original_image=self.pixel_values, 261 | methods=["original_image", "heat_map"], 262 | signs=["all", "absolute_value" if self.sign.value == "absolute" else self.sign.value], 263 | show_colorbar=True, 264 | use_pyplot=self.render_pyplot, 265 | outlier_perc=self.outlier_threshold, 266 | titles=["Original Image", f"Heatmap IG. Prediction: {self.pred_class}"], 267 | ) 268 | 269 | return viz.visualize_image_attr( 270 | original_image=self.pixel_values, 271 | attr=self.attributions, 272 | sign="absolute_value" if self.sign.value == "absolute" else self.sign.value, 273 | method="heat_map", 274 | show_colorbar=True, 275 | outlier_perc=self.outlier_threshold, 276 | title=f"Heatmap IG. 
Prediction - {self.pred_class}", 277 | use_pyplot=self.render_pyplot, 278 | ) 279 | 280 | def alpha_scaling(self): 281 | if self.side_by_side: 282 | return viz.visualize_image_attr_multiple( 283 | attr=self.attributions, 284 | original_image=self.pixel_values, 285 | methods=["original_image", "alpha_scaling"], 286 | signs=["all", "absolute_value" if self.sign.value == "absolute" else self.sign.value], 287 | show_colorbar=True, 288 | use_pyplot=self.render_pyplot, 289 | outlier_perc=self.outlier_threshold, 290 | titles=["Original Image", f"Alpha Scaled IG. Prediction: {self.pred_class}"], 291 | ) 292 | 293 | return viz.visualize_image_attr( 294 | original_image=self.pixel_values, 295 | attr=self.attributions, 296 | sign="absolute_value" if self.sign.value == "absolute" else self.sign.value, 297 | method="alpha_scaling", 298 | show_colorbar=True, 299 | outlier_perc=self.outlier_threshold, 300 | title=f"Alpha Scaling IG. Prediction - {self.pred_class}", 301 | use_pyplot=self.render_pyplot, 302 | ) 303 | 304 | def masked_image(self): 305 | if self.side_by_side: 306 | return viz.visualize_image_attr_multiple( 307 | attr=self.attributions, 308 | original_image=self.pixel_values, 309 | methods=["original_image", "masked_image"], 310 | signs=["all", "absolute_value" if self.sign.value == "absolute" else self.sign.value], 311 | show_colorbar=True, 312 | use_pyplot=self.render_pyplot, 313 | outlier_perc=self.outlier_threshold, 314 | titles=["Original Image", f"Masked Image IG. Prediction: {self.pred_class}"], 315 | ) 316 | return viz.visualize_image_attr( 317 | original_image=self.pixel_values, 318 | attr=self.attributions, 319 | sign="absolute_value" if self.sign.value == "absolute" else self.sign.value, 320 | method="masked_image", 321 | show_colorbar=True, 322 | outlier_perc=self.outlier_threshold, 323 | title=f"Masked Image IG. 
Prediction - {self.pred_class}", 324 | use_pyplot=self.render_pyplot, 325 | ) 326 | 327 | def _using_ipython(self) -> bool: 328 | try: 329 | eval("__IPYTHON__") 330 | except NameError: 331 | return False 332 | else: # pragma: no cover 333 | return True 334 | 335 | def __call__(self): 336 | return self.plot_function() 337 | --------------------------------------------------------------------------------