├── docs ├── source │ ├── api │ │ ├── .gitignore │ │ ├── api │ │ │ ├── ferret.Benchmark.score.rst │ │ │ ├── ferret.Benchmark.explain.rst │ │ │ ├── ferret.Benchmark.show_table.rst │ │ │ ├── ferret.Benchmark.load_dataset.rst │ │ │ ├── ferret.Benchmark.get_dataframe.rst │ │ │ ├── ferret.Benchmark.evaluate_samples.rst │ │ │ ├── ferret.Benchmark.evaluate_explanation.rst │ │ │ ├── ferret.Benchmark.evaluate_explanations.rst │ │ │ ├── ferret.Benchmark.show_evaluation_table.rst │ │ │ ├── ferret.Benchmark.show_samples_evaluation_table.rst │ │ │ ├── ferret.BaseDataset.rst │ │ │ ├── ferret.SSTDataset.rst │ │ │ ├── ferret.MovieReviews.rst │ │ │ ├── ferret.BaseExplainer.rst │ │ │ ├── ferret.SHAPExplainer.rst │ │ │ ├── ferret.HateXplainDataset.rst │ │ │ ├── ferret.LIMEExplainer.rst │ │ │ ├── ferret.GradientExplainer.rst │ │ │ ├── ferret.explainers.BaseExplainer.rst │ │ │ ├── ferret.Benchmark.rst │ │ │ ├── ferret.BaseEvaluator.rst │ │ │ ├── ferret.ThermostatDataset.rst │ │ │ ├── ferret.evaluators.BaseEvaluator.rst │ │ │ ├── ferret.IntegratedGradientExplainer.rst │ │ │ ├── ferret.TauLOO_Evaluation.rst │ │ │ ├── ferret.AOPC_Sufficiency_Evaluation.rst │ │ │ ├── ferret.AUPRC_PlausibilityEvaluation.rst │ │ │ ├── ferret.Tokenf1_PlausibilityEvaluation.rst │ │ │ ├── ferret.TokenIOU_PlausibilityEvaluation.rst │ │ │ └── ferret.AOPC_Comprehensiveness_Evaluation.rst │ │ ├── datasets.rst │ │ ├── explainers.rst │ │ ├── evaluators.rst │ │ ├── index.rst │ │ └── benchmark.rst │ ├── history.rst │ ├── _static │ │ ├── logo.png │ │ ├── banner.png │ │ ├── banner_v2.png │ │ ├── favicon.ico │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ ├── apple-touch-icon.png │ │ ├── android-chrome-192x192.png │ │ └── versions.json │ ├── user_guide │ │ ├── _images │ │ │ ├── example_explanations_viz.png │ │ │ ├── example_evaluation_faithfulness_viz.png │ │ │ └── example_evaluation_plausibility_viz.png │ │ ├── _speechxai_images │ │ │ ├── example_paralinguistic_expl.png │ │ │ ├── example_paralinguistic_variations.png │ │ │ └── example_word-level-audio-segments-loo.png │ │ ├── index.rst │ │ ├── whatisferret.rst │ │ ├── advanced.rst │ │ ├── quickstart.rst │ │ ├── speechxai.rst │ │ ├── explaining.rst │ │ ├── benchmarking.rst │ │ ├── tasks.rst │ │ ├── notions.explainers.rst │ │ └── notions.benchmarking.rst │ ├── index.rst │ └── conf.py ├── .DS_Store ├── requirements.txt ├── Makefile └── make.bat ├── ferret ├── explainers │ ├── explanation_speech │ │ ├── __init__.py │ │ ├── pink_noise.mp3 │ │ ├── white_noise.mp3 │ │ ├── explanation_speech.py │ │ ├── loo_speech_explainer.py │ │ ├── equal_width │ │ │ ├── loo_equal_width_explainer.py │ │ │ ├── lime_equal_width_explainer.py │ │ │ └── gradient_equal_width_explainer.py │ │ ├── lime_speech_explainer.py │ │ ├── utils_removal.py │ │ └── gradient_speech_explainer.py │ ├── utils.py │ ├── explanation.py │ ├── dummy.py │ ├── __init__.py │ ├── shap.py │ ├── lime.py │ └── gradient.py ├── modeling │ ├── speech_model_helpers │ │ ├── __init__.py │ │ ├── model_helper_er.py │ │ ├── model_helper_italic.py │ │ └── model_helper_fsc.py │ ├── __init__.py │ └── base_helpers.py ├── evaluators │ ├── evaluation.py │ ├── perturbation.py │ ├── class_measures.py │ ├── utils_from_soft_to_discrete.py │ └── __init__.py ├── datasets │ ├── __init__.py │ └── utils_sst_rationale_generation.py ├── __init__.py └── visualization.py ├── tests └── __init__.py ├── .readthedocs.yaml ├── MANIFEST.in ├── tox.ini ├── authors.md ├── .editorconfig ├── AUTHORS.rst ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ ├── main.yml │ ├── 
publish-to-pypi.yml │ └── flake8-pytest.yml ├── LICENSE ├── HISTORY.rst ├── setup.py ├── examples └── README.md ├── .gitignore ├── pyproject.toml ├── Makefile ├── CONTRIBUTING.rst └── README.md /docs/source/api/.gitignore: -------------------------------------------------------------------------------- 1 | ./api/ -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ferret/modeling/speech_model_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../../HISTORY.rst 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for ferret.""" 2 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/.DS_Store -------------------------------------------------------------------------------- /docs/source/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/logo.png -------------------------------------------------------------------------------- /docs/source/_static/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/banner.png -------------------------------------------------------------------------------- /docs/source/_static/banner_v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/banner_v2.png -------------------------------------------------------------------------------- /docs/source/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/favicon-16x16.png -------------------------------------------------------------------------------- /docs/source/_static/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/favicon-32x32.png -------------------------------------------------------------------------------- /docs/source/_static/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/apple-touch-icon.png -------------------------------------------------------------------------------- 
/docs/source/_static/android-chrome-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/_static/android-chrome-192x192.png -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/pink_noise.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/ferret/explainers/explanation_speech/pink_noise.mp3 -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/white_noise.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/ferret/explainers/explanation_speech/white_noise.mp3 -------------------------------------------------------------------------------- /docs/source/user_guide/_images/example_explanations_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_images/example_explanations_viz.png -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.score.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.score 2 | ====================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.score -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.explain.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.explain 2 | ======================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.explain -------------------------------------------------------------------------------- /docs/source/user_guide/_images/example_evaluation_faithfulness_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_images/example_evaluation_faithfulness_viz.png -------------------------------------------------------------------------------- /docs/source/user_guide/_images/example_evaluation_plausibility_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_images/example_evaluation_plausibility_viz.png -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.show_table.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.show\_table 2 | ============================ 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. 
automethod:: Benchmark.show_table -------------------------------------------------------------------------------- /docs/source/user_guide/_speechxai_images/example_paralinguistic_expl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_speechxai_images/example_paralinguistic_expl.png -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.load_dataset.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.load\_dataset 2 | ============================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.load_dataset -------------------------------------------------------------------------------- /docs/source/user_guide/_speechxai_images/example_paralinguistic_variations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_speechxai_images/example_paralinguistic_variations.png -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.get_dataframe.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.get\_dataframe 2 | =============================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.get_dataframe -------------------------------------------------------------------------------- /docs/source/user_guide/_speechxai_images/example_word-level-audio-segments-loo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g8a9/ferret/HEAD/docs/source/user_guide/_speechxai_images/example_word-level-audio-segments-loo.png -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.evaluate_samples.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.evaluate\_samples 2 | ================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.evaluate_samples -------------------------------------------------------------------------------- /ferret/explainers/utils.py: -------------------------------------------------------------------------------- 1 | def parse_explainer_args(explainer_args): 2 | init_args = explainer_args.get("init_args", {}) 3 | call_args = explainer_args.get("call_args", {}) 4 | return init_args, call_args -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.evaluate_explanation.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.evaluate\_explanation 2 | ====================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.evaluate_explanation -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.evaluate_explanations.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.evaluate\_explanations 2 | ======================================= 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. 
automethod:: Benchmark.evaluate_explanations -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.show_evaluation_table.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.show\_evaluation\_table 2 | ======================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.show_evaluation_table -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.show_samples_evaluation_table.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark.show\_samples\_evaluation\_table 2 | ================================================= 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. automethod:: Benchmark.show_samples_evaluation_table -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.10" 7 | 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | python: 12 | install: 13 | - requirements: docs/requirements.txt 14 | - method: pip 15 | path: . -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | datasets 3 | captum 4 | einops 5 | shap 6 | seaborn 7 | matplotlib 8 | scikit-image 9 | lime 10 | opencv-python 11 | pytreebank 12 | tqdm 13 | sphinx-toggleprompt>=0.3.1 14 | sphinx-copybutton>=0.5.1 15 | sphinx-favicon>=1.0.1 16 | pydata_sphinx_theme>=0.12.0 17 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | 7 | recursive-include tests * 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[co] 10 | 11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 12 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py36, py37, py38, flake8 3 | 4 | [travis] 5 | python = 6 | 3.8: py38 7 | 3.7: py37 8 | 3.6: py36 9 | 10 | [testenv:flake8] 11 | basepython = python 12 | deps = flake8 13 | commands = flake8 ferret tests 14 | 15 | [testenv] 16 | setenv = 17 | PYTHONPATH = {toxinidir} 18 | 19 | commands = python setup.py test 20 | -------------------------------------------------------------------------------- /authors.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Credits 3 | --- 4 | 5 | # Development Lead 6 | 7 | - Giuseppe Attanasio \<\> 8 | - Eliana Pastor \<\> 9 | - Debora Nozza \<\> 10 | - Chiara Di Bonaventura \<\> 11 | 12 | # Contributors 13 | 14 | None yet. Why not be the first? 
15 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Giuseppe Attanasio 9 | * Eliana Pastor 10 | * Debora Nozza 11 | * Chiara Di Bonaventura 12 | 13 | Contributors 14 | ------------ 15 | 16 | None yet. Why not be the first? 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * ferret version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout code 12 | uses: actions/checkout@v4 13 | with: 14 | fetch-depth: 0 15 | - name: Secret Scanning 16 | uses: trufflesecurity/trufflehog@main 17 | with: 18 | extra_args: --only-verified 19 | -------------------------------------------------------------------------------- /docs/source/api/datasets.rst: -------------------------------------------------------------------------------- 1 | .. _api.datasets: 2 | 3 | ======== 4 | Datasets 5 | ======== 6 | 7 | .. currentmodule:: ferret 8 | 9 | Abstract Classes 10 | ---------------- 11 | 12 | .. autosummary:: 13 | :toctree: api/ 14 | 15 | BaseDataset 16 | 17 | 18 | Integrated XAI Datasets 19 | ----------------------- 20 | 21 | .. autosummary:: 22 | :toctree: api/ 23 | 24 | HateXplainDataset 25 | MovieReviews 26 | SSTDataset 27 | ThermostatDataset -------------------------------------------------------------------------------- /docs/source/api/explainers.rst: -------------------------------------------------------------------------------- 1 | .. _api.explainers: 2 | 3 | ========== 4 | Explainers 5 | ========== 6 | 7 | .. currentmodule:: ferret 8 | 9 | Abstract Classes 10 | ---------------- 11 | 12 | .. autosummary:: 13 | :toctree: api/ 14 | 15 | BaseExplainer 16 | 17 | 18 | Post-Hoc Feature Explainers 19 | --------------------------- 20 | 21 | .. 
autosummary:: 22 | :toctree: api/ 23 | 24 | GradientExplainer 25 | IntegratedGradientExplainer 26 | SHAPExplainer 27 | LIMEExplainer 28 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.BaseDataset.rst: -------------------------------------------------------------------------------- 1 | ferret.BaseDataset 2 | ================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: BaseDataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~BaseDataset.__init__ 17 | ~BaseDataset.get_instance 18 | ~BaseDataset.get_true_rationale_from_words_to_tokens 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~BaseDataset.NAME 29 | ~BaseDataset.avg_rationale_size 30 | 31 | -------------------------------------------------------------------------------- /ferret/evaluators/evaluation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import List, Tuple 4 | 5 | from ferret.explainers.explanation import Explanation 6 | 7 | from . import BaseEvaluator 8 | 9 | 10 | @dataclass 11 | class EvaluationMetricOutput: 12 | """Output to store any metric result.""" 13 | 14 | metric: BaseEvaluator 15 | value: float 16 | 17 | 18 | @dataclass 19 | class ExplanationEvaluation: 20 | """Generic class to represent an Evaluation""" 21 | 22 | explanation: Explanation 23 | evaluation_outputs: List[EvaluationMetricOutput] 24 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.SSTDataset.rst: -------------------------------------------------------------------------------- 1 | ferret.SSTDataset 2 | ================= 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: SSTDataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~SSTDataset.__init__ 17 | ~SSTDataset.get_instance 18 | ~SSTDataset.get_true_rationale_from_words_to_tokens 19 | ~SSTDataset.len 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~SSTDataset.NAME 30 | ~SSTDataset.avg_rationale_size 31 | 32 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.MovieReviews.rst: -------------------------------------------------------------------------------- 1 | ferret.MovieReviews 2 | =================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: MovieReviews 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~MovieReviews.__init__ 17 | ~MovieReviews.get_instance 18 | ~MovieReviews.get_true_rationale_from_words_to_tokens 19 | ~MovieReviews.len 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~MovieReviews.NAME 30 | ~MovieReviews.avg_rationale_size 31 | 32 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.BaseExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.BaseExplainer 2 | ==================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: BaseExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. 
autosummary:: 15 | 16 | ~BaseExplainer.__init__ 17 | ~BaseExplainer.compute_feature_importance 18 | ~BaseExplainer.get_input_embeds 19 | ~BaseExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~BaseExplainer.NAME 30 | ~BaseExplainer.device 31 | ~BaseExplainer.tokenizer 32 | 33 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.SHAPExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.SHAPExplainer 2 | ==================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: SHAPExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~SHAPExplainer.__init__ 17 | ~SHAPExplainer.compute_feature_importance 18 | ~SHAPExplainer.get_input_embeds 19 | ~SHAPExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~SHAPExplainer.NAME 30 | ~SHAPExplainer.device 31 | ~SHAPExplainer.tokenizer 32 | 33 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.HateXplainDataset.rst: -------------------------------------------------------------------------------- 1 | ferret.HateXplainDataset 2 | ======================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: HateXplainDataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~HateXplainDataset.__init__ 17 | ~HateXplainDataset.get_instance 18 | ~HateXplainDataset.get_true_rationale_from_words_to_tokens 19 | ~HateXplainDataset.len 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~HateXplainDataset.NAME 30 | ~HateXplainDataset.avg_rationale_size 31 | 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = ferret 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /ferret/explainers/explanation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | 7 | @dataclass 8 | class Explanation: 9 | """Generic class to represent an Explanation""" 10 | 11 | text: str 12 | tokens: str 13 | scores: np.array 14 | explainer: str 15 | target_pos_idx: int 16 | helper_type: str 17 | target_token_pos_idx: Optional[int] = None 18 | target: Optional[str] = None 19 | target_token: Optional[str] = None 20 | 21 | 22 | @dataclass 23 | class ExplanationWithRationale(Explanation): 24 | """Specific explanation to contain the gold rationale""" 25 | 26 | rationale: Optional[np.array] = None 27 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.LIMEExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.LIMEExplainer 2 | ==================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: LIMEExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~LIMEExplainer.__init__ 17 | ~LIMEExplainer.compute_feature_importance 18 | ~LIMEExplainer.get_input_embeds 19 | ~LIMEExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~LIMEExplainer.MAX_SAMPLES 30 | ~LIMEExplainer.NAME 31 | ~LIMEExplainer.device 32 | ~LIMEExplainer.tokenizer 33 | 34 | -------------------------------------------------------------------------------- /docs/source/api/evaluators.rst: -------------------------------------------------------------------------------- 1 | .. _api.evaluators: 2 | 3 | ========== 4 | Evaluators 5 | ========== 6 | 7 | .. currentmodule:: ferret 8 | 9 | Abstract Classes 10 | ---------------- 11 | 12 | .. autosummary:: 13 | :toctree: api/ 14 | 15 | BaseEvaluator 16 | 17 | 18 | Evaluation Methods 19 | ------------------ 20 | 21 | .. autosummary:: 22 | :toctree: api/ 23 | :caption: Faithfulness 24 | 25 | AOPC_Comprehensiveness_Evaluation 26 | AOPC_Sufficiency_Evaluation 27 | TauLOO_Evaluation 28 | 29 | .. autosummary:: 30 | :toctree: api/ 31 | :caption: Plausibility 32 | 33 | AUPRC_PlausibilityEvaluation 34 | Tokenf1_PlausibilityEvaluation 35 | TokenIOU_PlausibilityEvaluation -------------------------------------------------------------------------------- /docs/source/api/api/ferret.GradientExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.GradientExplainer 2 | ======================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: GradientExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~GradientExplainer.__init__ 17 | ~GradientExplainer.compute_feature_importance 18 | ~GradientExplainer.get_input_embeds 19 | ~GradientExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. 
autosummary:: 28 | 29 | ~GradientExplainer.NAME 30 | ~GradientExplainer.device 31 | ~GradientExplainer.tokenizer 32 | 33 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.explainers.BaseExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.explainers.BaseExplainer 2 | =============================== 3 | 4 | .. currentmodule:: ferret.explainers 5 | 6 | .. autoclass:: BaseExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~BaseExplainer.__init__ 17 | ~BaseExplainer.compute_feature_importance 18 | ~BaseExplainer.get_input_embeds 19 | ~BaseExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~BaseExplainer.NAME 30 | ~BaseExplainer.device 31 | ~BaseExplainer.tokenizer 32 | 33 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Benchmark.rst: -------------------------------------------------------------------------------- 1 | ferret.Benchmark 2 | ================ 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: Benchmark 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Benchmark.__init__ 17 | ~Benchmark.evaluate_explanation 18 | ~Benchmark.evaluate_explanations 19 | ~Benchmark.evaluate_samples 20 | ~Benchmark.explain 21 | ~Benchmark.get_dataframe 22 | ~Benchmark.load_dataset 23 | ~Benchmark.score 24 | ~Benchmark.show_evaluation_table 25 | ~Benchmark.show_samples_evaluation_table 26 | ~Benchmark.show_table 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /ferret/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_helpers import ( 2 | SequenceClassificationHelper, 3 | TokenClassificationHelper, 4 | ZeroShotTextClassificationHelper, 5 | ) 6 | 7 | SUPPORTED_TASKS_TO_HELPERS = { 8 | "text-classification": SequenceClassificationHelper, 9 | "nli": SequenceClassificationHelper, 10 | "zero-shot-text-classification": ZeroShotTextClassificationHelper, 11 | "ner": TokenClassificationHelper, 12 | } 13 | 14 | 15 | def create_helper(model, tokenizer, task_name): 16 | helper = SUPPORTED_TASKS_TO_HELPERS.get(task_name, None) 17 | if helper is None: 18 | raise ValueError(f"Task {task_name} is not supported.") 19 | else: 20 | return helper(model, tokenizer) 21 | -------------------------------------------------------------------------------- /docs/source/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | User Guide 3 | ========== 4 | 5 | This guide is an overview and explain the main features in ferret and how to use them. 6 | Head to :ref:`quickstart` for a basic example. 7 | 8 | .. toctree:: 9 | :caption: Getting started 10 | :maxdepth: 2 11 | 12 | whatisferret 13 | quickstart 14 | 15 | .. toctree:: 16 | :caption: Using ferret 17 | :maxdepth: 2 18 | 19 | explaining 20 | benchmarking 21 | advanced 22 | speechxai 23 | 24 | .. toctree:: 25 | :caption: Tasks 26 | :maxdepth: 1 27 | 28 | tasks 29 | 30 | .. 
toctree:: 31 | :caption: Notions 32 | :maxdepth: 2 33 | 34 | notions.explainers 35 | notions.benchmarking 36 | -------------------------------------------------------------------------------- /ferret/explainers/dummy.py: -------------------------------------------------------------------------------- 1 | """Dummy Explainer module""" 2 | import numpy as np 3 | 4 | from . import BaseExplainer 5 | from .explanation import Explanation 6 | from .utils import parse_explainer_args 7 | 8 | 9 | class DummyExplainer(BaseExplainer): 10 | """Dummy Explainer that assigns random scores to tokens.""" 11 | 12 | NAME = "dummy" 13 | 14 | def __init__(self, model, tokenizer): 15 | super().__init__(model, tokenizer) 16 | 17 | def compute_feature_importance(self, text, target=1, **explainer_args): 18 | tokens = self._tokenize(text) 19 | output = Explanation( 20 | text, self.get_tokens(text), np.random.randn(len(tokens)), self.NAME, target 21 | ) 22 | return output 23 | -------------------------------------------------------------------------------- /docs/source/user_guide/whatisferret.rst: -------------------------------------------------------------------------------- 1 | .. _whatisferret: 2 | 3 | *************** 4 | What is ferret? 5 | *************** 6 | 7 | 8 | **ferret** is a Python library for benchmarking interpretability techniques. The library 9 | is deeply integrated with the `transformers`_ library, models, and code. 10 | 11 | Concretely, ferret lets you: 12 | 13 | - visualize and compare Transformers-based models' output explanations using state-of-the-art XAI methods 14 | - benchmark explanations with ad-hoc XAI metrics for faithfulness and plausibility 15 | - streamline the evaluation on existing XAI corpora 16 | - work on both text-based and audio-based data (speech classification tasks) 17 | 18 | .. _transformers: https://huggingface.co/docs/transformers/index -------------------------------------------------------------------------------- /docs/source/api/api/ferret.BaseEvaluator.rst: -------------------------------------------------------------------------------- 1 | ferret.BaseEvaluator 2 | ==================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: BaseEvaluator 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~BaseEvaluator.__init__ 17 | ~BaseEvaluator.aggregate_score 18 | ~BaseEvaluator.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. 
autosummary:: 27 | 28 | ~BaseEvaluator.BEST_SORTING_ASCENDING 29 | ~BaseEvaluator.INIT_VALUE 30 | ~BaseEvaluator.NAME 31 | ~BaseEvaluator.SHORT_NAME 32 | ~BaseEvaluator.TYPE_METRIC 33 | ~BaseEvaluator.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/_static/versions.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "latest", 4 | "version": "0.4.1", 5 | "url": "https://ferret.readthedocs.io/en/latest/" 6 | }, 7 | { 8 | "name": "v0.4.1 (Furo)", 9 | "version": "0.4.1", 10 | "url": "https://ferret.readthedocs.io/en/v0.4.1/" 11 | }, 12 | { 13 | "name": "dev", 14 | "version": "dev", 15 | "url": "https://ferret.readthedocs.io/en/dev/" 16 | }, 17 | { 18 | "name": "v0.3.5", 19 | "version": "0.3.5", 20 | "url": "https://ferret.readthedocs.io/en/0.3.5/" 21 | }, 22 | { 23 | "name": "v0.2.4", 24 | "version": "0.2.4", 25 | "url": "https://ferret.readthedocs.io/en/0.2.4/" 26 | } 27 | ] -------------------------------------------------------------------------------- /docs/source/api/api/ferret.ThermostatDataset.rst: -------------------------------------------------------------------------------- 1 | ferret.ThermostatDataset 2 | ======================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: ThermostatDataset 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~ThermostatDataset.__init__ 17 | ~ThermostatDataset.get_explanations 18 | ~ThermostatDataset.get_instance 19 | ~ThermostatDataset.get_target_explanations 20 | ~ThermostatDataset.get_true_rationale_from_words_to_tokens 21 | ~ThermostatDataset.len 22 | 23 | 24 | 25 | 26 | 27 | .. rubric:: Attributes 28 | 29 | .. autosummary:: 30 | 31 | ~ThermostatDataset.NAME 32 | ~ThermostatDataset.avg_rationale_size 33 | 34 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.evaluators.BaseEvaluator.rst: -------------------------------------------------------------------------------- 1 | ferret.evaluators.BaseEvaluator 2 | =============================== 3 | 4 | .. currentmodule:: ferret.evaluators 5 | 6 | .. autoclass:: BaseEvaluator 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~BaseEvaluator.__init__ 17 | ~BaseEvaluator.aggregate_score 18 | ~BaseEvaluator.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~BaseEvaluator.BEST_SORTING_ASCENDING 29 | ~BaseEvaluator.INIT_VALUE 30 | ~BaseEvaluator.NAME 31 | ~BaseEvaluator.SHORT_NAME 32 | ~BaseEvaluator.TYPE_METRIC 33 | ~BaseEvaluator.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/explanation_speech.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import numpy as np 3 | from typing import Optional 4 | 5 | 6 | @dataclass 7 | class ExplanationSpeech: 8 | features: list 9 | scores: np.array 10 | explainer: str 11 | target: list 12 | audio_path: Optional[str] = None 13 | 14 | 15 | @dataclass 16 | class EvaluationSpeech: 17 | """ 18 | Generic class to represent a speech Evaluation. 
19 | 20 | Note: this has a subset of the `Explanation` dataclass' attributes, so it 21 | should be possible to write smaller common parent class for both 22 | very similar to this (the `Explanation` class - for text - is more 23 | specific). 24 | """ 25 | 26 | name: str 27 | score: list 28 | target: list -------------------------------------------------------------------------------- /docs/source/api/api/ferret.IntegratedGradientExplainer.rst: -------------------------------------------------------------------------------- 1 | ferret.IntegratedGradientExplainer 2 | ================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: IntegratedGradientExplainer 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~IntegratedGradientExplainer.__init__ 17 | ~IntegratedGradientExplainer.compute_feature_importance 18 | ~IntegratedGradientExplainer.get_input_embeds 19 | ~IntegratedGradientExplainer.get_tokens 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~IntegratedGradientExplainer.NAME 30 | ~IntegratedGradientExplainer.device 31 | ~IntegratedGradientExplainer.tokenizer 32 | 33 | -------------------------------------------------------------------------------- /ferret/evaluators/perturbation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | class PertubationHelper: 5 | def __init__(self, tokenizer): 6 | self.tokenizer = tokenizer 7 | 8 | def edit_one_token(self, input_ids, strategy, mask_token_id=None): 9 | samples = list() 10 | 11 | for occ_idx in range(len(input_ids)): 12 | sample = copy.copy(input_ids) 13 | 14 | if strategy == "remove": 15 | sample.pop(occ_idx) 16 | elif strategy == "mask": 17 | if mask_token_id is None: 18 | mask_token_id = self.tokenizer.mask_token_id 19 | sample[occ_idx] = mask_token_id 20 | else: 21 | raise ValueError(f"Unknown strategy: {strategy}") 22 | 23 | samples.append(sample) 24 | return samples 25 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.TauLOO_Evaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.TauLOO\_Evaluation 2 | ========================= 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: TauLOO_Evaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~TauLOO_Evaluation.__init__ 17 | ~TauLOO_Evaluation.aggregate_score 18 | ~TauLOO_Evaluation.compute_evaluation 19 | ~TauLOO_Evaluation.compute_leave_one_out_occlusion 20 | 21 | 22 | 23 | 24 | 25 | .. rubric:: Attributes 26 | 27 | .. autosummary:: 28 | 29 | ~TauLOO_Evaluation.BEST_SORTING_ASCENDING 30 | ~TauLOO_Evaluation.INIT_VALUE 31 | ~TauLOO_Evaluation.NAME 32 | ~TauLOO_Evaluation.SHORT_NAME 33 | ~TauLOO_Evaluation.TYPE_METRIC 34 | ~TauLOO_Evaluation.tokenizer 35 | 36 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Reference 3 | ========= 4 | 5 | This list contains all main classes and methods. 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | 10 | benchmark 11 | explainers 12 | evaluators 13 | datasets 14 | 15 | .. Benchmark 16 | .. --------- 17 | 18 | .. .. autoclass:: ferret.benchmark.Benchmark 19 | .. :members: 20 | .. :undoc-members: 21 | 22 | .. Explainers 23 | .. 
---------- 24 | 25 | .. .. autoclass:: ferret.explainers.gradient.GradientExplainer 26 | .. :members: 27 | .. :undoc-members: 28 | 29 | .. .. autoclass:: ferret.explainers.gradient.IntegratedGradientExplainer 30 | .. :members: 31 | .. :undoc-members: 32 | 33 | .. .. autoclass:: ferret.explainers.shap.SHAPExplainer 34 | .. :members: 35 | .. :undoc-members: 36 | 37 | .. .. autoclass:: ferret.explainers.lime.LIMEExplainer 38 | .. :members: 39 | .. :undoc-members: -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=ferret 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.AOPC_Sufficiency_Evaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.AOPC\_Sufficiency\_Evaluation 2 | ==================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: AOPC_Sufficiency_Evaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~AOPC_Sufficiency_Evaluation.__init__ 17 | ~AOPC_Sufficiency_Evaluation.aggregate_score 18 | ~AOPC_Sufficiency_Evaluation.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~AOPC_Sufficiency_Evaluation.BEST_SORTING_ASCENDING 29 | ~AOPC_Sufficiency_Evaluation.INIT_VALUE 30 | ~AOPC_Sufficiency_Evaluation.NAME 31 | ~AOPC_Sufficiency_Evaluation.SHORT_NAME 32 | ~AOPC_Sufficiency_Evaluation.TYPE_METRIC 33 | ~AOPC_Sufficiency_Evaluation.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.AUPRC_PlausibilityEvaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.AUPRC\_PlausibilityEvaluation 2 | ==================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: AUPRC_PlausibilityEvaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~AUPRC_PlausibilityEvaluation.__init__ 17 | ~AUPRC_PlausibilityEvaluation.aggregate_score 18 | ~AUPRC_PlausibilityEvaluation.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. 
autosummary:: 27 | 28 | ~AUPRC_PlausibilityEvaluation.BEST_SORTING_ASCENDING 29 | ~AUPRC_PlausibilityEvaluation.INIT_VALUE 30 | ~AUPRC_PlausibilityEvaluation.NAME 31 | ~AUPRC_PlausibilityEvaluation.SHORT_NAME 32 | ~AUPRC_PlausibilityEvaluation.TYPE_METRIC 33 | ~AUPRC_PlausibilityEvaluation.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.Tokenf1_PlausibilityEvaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.Tokenf1\_PlausibilityEvaluation 2 | ====================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: Tokenf1_PlausibilityEvaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~Tokenf1_PlausibilityEvaluation.__init__ 17 | ~Tokenf1_PlausibilityEvaluation.aggregate_score 18 | ~Tokenf1_PlausibilityEvaluation.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~Tokenf1_PlausibilityEvaluation.BEST_SORTING_ASCENDING 29 | ~Tokenf1_PlausibilityEvaluation.INIT_VALUE 30 | ~Tokenf1_PlausibilityEvaluation.NAME 31 | ~Tokenf1_PlausibilityEvaluation.SHORT_NAME 32 | ~Tokenf1_PlausibilityEvaluation.TYPE_METRIC 33 | ~Tokenf1_PlausibilityEvaluation.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.TokenIOU_PlausibilityEvaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.TokenIOU\_PlausibilityEvaluation 2 | ======================================= 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: TokenIOU_PlausibilityEvaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~TokenIOU_PlausibilityEvaluation.__init__ 17 | ~TokenIOU_PlausibilityEvaluation.aggregate_score 18 | ~TokenIOU_PlausibilityEvaluation.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. autosummary:: 27 | 28 | ~TokenIOU_PlausibilityEvaluation.BEST_SORTING_ASCENDING 29 | ~TokenIOU_PlausibilityEvaluation.INIT_VALUE 30 | ~TokenIOU_PlausibilityEvaluation.NAME 31 | ~TokenIOU_PlausibilityEvaluation.SHORT_NAME 32 | ~TokenIOU_PlausibilityEvaluation.TYPE_METRIC 33 | ~TokenIOU_PlausibilityEvaluation.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/api/ferret.AOPC_Comprehensiveness_Evaluation.rst: -------------------------------------------------------------------------------- 1 | ferret.AOPC\_Comprehensiveness\_Evaluation 2 | ========================================== 3 | 4 | .. currentmodule:: ferret 5 | 6 | .. autoclass:: AOPC_Comprehensiveness_Evaluation 7 | 8 | 9 | .. automethod:: __init__ 10 | 11 | 12 | .. rubric:: Methods 13 | 14 | .. autosummary:: 15 | 16 | ~AOPC_Comprehensiveness_Evaluation.__init__ 17 | ~AOPC_Comprehensiveness_Evaluation.aggregate_score 18 | ~AOPC_Comprehensiveness_Evaluation.compute_evaluation 19 | 20 | 21 | 22 | 23 | 24 | .. rubric:: Attributes 25 | 26 | .. 
autosummary:: 27 | 28 | ~AOPC_Comprehensiveness_Evaluation.BEST_SORTING_ASCENDING 29 | ~AOPC_Comprehensiveness_Evaluation.INIT_VALUE 30 | ~AOPC_Comprehensiveness_Evaluation.NAME 31 | ~AOPC_Comprehensiveness_Evaluation.SHORT_NAME 32 | ~AOPC_Comprehensiveness_Evaluation.TYPE_METRIC 33 | ~AOPC_Comprehensiveness_Evaluation.tokenizer 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/benchmark.rst: -------------------------------------------------------------------------------- 1 | .. _api.benchmark: 2 | 3 | ========= 4 | Benchmark 5 | ========= 6 | 7 | .. currentmodule:: ferret 8 | 9 | Constructor 10 | ----------- 11 | 12 | .. autosummary:: 13 | :toctree: api/ 14 | 15 | Benchmark 16 | 17 | Explaining 18 | ---------- 19 | 20 | .. autosummary:: 21 | :toctree: api/ 22 | 23 | Benchmark.explain 24 | 25 | 26 | Benchmarking Explanations 27 | ------------------------- 28 | 29 | .. autosummary:: 30 | :toctree: api/ 31 | 32 | Benchmark.evaluate_explanation 33 | Benchmark.evaluate_explanations 34 | 35 | 36 | Visualization 37 | ------------- 38 | 39 | .. autosummary:: 40 | :toctree: api/ 41 | 42 | Benchmark.show_table 43 | Benchmark.show_evaluation_table 44 | Benchmark.get_dataframe 45 | Benchmark.show_samples_evaluation_table 46 | 47 | 48 | Datasets Interface 49 | ------------------ 50 | 51 | .. autosummary:: 52 | :toctree: api/ 53 | 54 | Benchmark.load_dataset 55 | Benchmark.evaluate_samples 56 | 57 | 58 | Inference 59 | --------- 60 | 61 | .. autosummary:: 62 | :toctree: api/ 63 | 64 | Benchmark.score 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022, Giuseppe Attanasio 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /ferret/modeling/base_helpers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class BaseTaskHelper(ABC): 5 | """ 6 | Base helper class to handle basic steps of the pipeline (e.g., tokenization, inference). 
7 | """ 8 | 9 | def __init__(self, model, tokenizer=None): 10 | self.model = model 11 | self.tokenizer = tokenizer 12 | 13 | @abstractmethod 14 | def _check_target(self, target, **kwargs): 15 | """Validate the specific target requested for the explanation""" 16 | pass 17 | 18 | @abstractmethod 19 | def _check_sample(self, input, **kwargs): 20 | """Validate the specific input requested for the explanation""" 21 | pass 22 | 23 | def _prepare_sample(self, sample, **kwargs): 24 | """Format the input before the explanation""" 25 | return sample 26 | 27 | def format_target(self, target, **kwargs): 28 | """Format the target variable 29 | 30 | In all our current explainers, 'target' must be a positional integer for the 31 | logits matrix. Default: leave target unchanged. 32 | """ 33 | return target 34 | 35 | def _postprocess_logits(self, logits, **kwargs): 36 | """Process the logits before computing the explanation""" 37 | return logits 38 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '[0-9]+.[0-9]+.[0-9]+' 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@master 15 | - name: Set up Python 3.10 16 | uses: actions/setup-python@v3 17 | with: 18 | python-version: "3.10" 19 | 20 | - name: Install pypa/build 21 | run: >- 22 | python -m 23 | pip install 24 | build 25 | --user 26 | - name: Build a binary wheel and a source tarball 27 | run: >- 28 | python -m 29 | build 30 | --sdist 31 | --wheel 32 | --outdir dist/ 33 | . 34 | 35 | - name: Publish distribution 📦 to Test PyPI 36 | uses: pypa/gh-action-pypi-publish@release/v1 37 | with: 38 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 39 | repository_url: https://test.pypi.org/legacy/ 40 | 41 | - name: Publish distribution 📦 to PyPI 42 | # if: startsWith(github.ref, 'refs/tags') 43 | uses: pypa/gh-action-pypi-publish@release/v1 44 | with: 45 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | Release notes 3 | ============= 4 | 5 | 0.5.0 (2024-02-27) 6 | ------------------ 7 | 8 | * [added] Task-API interface. (#35) The Benchmark class now allows specifying one of the supported NLP tasks and handles explanation and evaluation according to the semantics of the task. 9 | * [added] Support for speech models for classification (#36). The library exposes a new SpeechBenchmark class that implements the methodology presented in `this paper `_. 10 | * [deprecated] We deprecated the methods *evaluate_samples* and *show_samples_evaluation_table* since they run basic aggregation / averaging, which we decided to leave to the user. 11 | 12 | 13 | 0.4.1 (2022-12-27) 14 | ------------------ 15 | 16 | * [added] Integrated interface to Thermostat datasets and pre-computed feature attributions 17 | 18 | 0.4.0 (2022-09-01) 19 | ------------------ 20 | 21 | * [added] GPU inference for all the supported explainers 22 | * [added] Batched inference on both CPU and GPU (see our `usage guides `_) 23 | * New cool-looking `docs `_ using `Furo `_.
24 | 25 | 0.1.0 (2022-05-30) 26 | ------------------ 27 | 28 | * First release on PyPI. And a lot of dreams ahead of us. 29 | 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """The setup script.""" 4 | 5 | from setuptools import find_packages, setup 6 | 7 | with open("README.md") as readme_file: 8 | readme = readme_file.read() 9 | 10 | with open("HISTORY.rst") as history_file: 11 | history = history_file.read() 12 | 13 | requirements = list() 14 | 15 | test_requirements = list() 16 | 17 | setup( 18 | author="Giuseppe Attanasio", 19 | author_email="giuseppeattanasio6@gmail.com", 20 | python_requires=">=3.8", 21 | classifiers=[ 22 | "Development Status :: 2 - Pre-Alpha", 23 | "Intended Audience :: Developers", 24 | "License :: OSI Approved :: MIT License", 25 | "Natural Language :: English", 26 | "Programming Language :: Python :: 3", 27 | "Programming Language :: Python :: 3.8", 28 | ], 29 | description="A Python package for benchmarking interpretability approaches.", 30 | install_requires=requirements, 31 | license="MIT license", 32 | long_description=readme, # + "\n\n" + history, 33 | long_description_content_type="text/markdown", 34 | include_package_data=True, 35 | keywords="ferret", 36 | name="ferret-xai", 37 | packages=find_packages(include=["ferret", "ferret.*"]), 38 | test_suite="tests", 39 | tests_require=test_requirements, 40 | url="https://github.com/g8a9/ferret", 41 | version="0.4.1", 42 | zip_safe=False, 43 | ) 44 | -------------------------------------------------------------------------------- /.github/workflows/flake8-pytest.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.9", "3.10", "3.11"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest poetry 31 | # poetry install --all-extras 32 | pip install -e .[all] 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | - name: Test with pytest 40 | run: | 41 | pytest 42 | -------------------------------------------------------------------------------- /docs/source/user_guide/advanced.rst: -------------------------------------------------------------------------------- 1 | .. _advanced: 2 | 3 | ************** 4 | Advanced Usage 5 | ************** 6 | 7 | 8 | Inference on GPU 9 | ^^^^^^^^^^^^^^^^ 10 | 11 | Starting from version 0.4.0, ferret supports inference on GPU. 
12 | In practice, ferret will use the device of the `model`, with no explicit calls or changes required on your side. 13 | 14 | Assuming that your model is currently on CPU, to run explanations on GPU you just need to move it there first: 15 | 16 | .. code-block:: python 17 | 18 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 19 | from ferret import LIMEExplainer 20 | 21 | name = "cardiffnlp/twitter-xlm-roberta-base-sentiment" 22 | m = AutoModelForSequenceClassification.from_pretrained(name).to("cuda:0") 23 | t = AutoTokenizer.from_pretrained(name) 24 | 25 | exp = LIMEExplainer(m, t) 26 | explanation = exp("You look stunning!", target=1) 27 | 28 | Batched Inference 29 | ^^^^^^^^^^^^^^^^^ 30 | 31 | Some explainers (e.g., LIME or IntegratedGradients) require running inference on a large number 32 | of data points, which might be computationally infeasible for transformer models. 33 | 34 | Since version 0.4.0, ferret automatically supports batched inference (on both CPU and GPU). 35 | When batched inference is available, you can specify both `batch_size` and 36 | `show_progress`. 37 | 38 | .. code-block:: python 39 | 40 | exp = LIMEExplainer(m, t) 41 | call_args = {"num_samples": 5000, "show_progress": True, "batch_size": 16} 42 | explanation = exp("You look stunning!", call_args=call_args) -------------------------------------------------------------------------------- /docs/source/user_guide/quickstart.rst: -------------------------------------------------------------------------------- 1 | .. _quickstart: 2 | 3 | ********** 4 | Quickstart 5 | ********** 6 | 7 | Here is a code snippet showing **ferret** integrated with your existing **transformers** models for a text-based task. 8 | 9 | .. code-block:: python 10 | 11 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 12 | from ferret import Benchmark 13 | 14 | name = "cardiffnlp/twitter-xlm-roberta-base-sentiment" 15 | model = AutoModelForSequenceClassification.from_pretrained(name) 16 | tokenizer = AutoTokenizer.from_pretrained(name) 17 | 18 | bench = Benchmark(model, tokenizer) 19 | explanations = bench.explain("You look stunning!", target=1) 20 | evaluations = bench.evaluate_explanations(explanations, target=1) 21 | 22 | bench.show_evaluation_table(evaluations) 23 | 24 | The ferret library also streamlines working with audio (speech) data. 25 | 26 | .. code-block:: python 27 | 28 | from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor 29 | from ferret import SpeechBenchmark, AOPC_Comprehensiveness_Evaluation_Speech 30 | 31 | model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic") 32 | feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic") 33 | 34 | speech_benchmark = SpeechBenchmark(model, feature_extractor) 35 | explanation = speech_benchmark.explain(audio_path=audio_path, methodology='LOO') 36 | aopc_compr = AOPC_Comprehensiveness_Evaluation_Speech(speech_benchmark.model_helper) 37 | evaluation_output_c = aopc_compr.compute_evaluation(explanation) 38 | 39 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Task Support Matrix 2 | 3 | *ferret* integrates seamlessly with a wide range of tasks. Please refer to the matrix below 4 | to see which tasks we currently support off-the-shelf (note: **ferret-xai >= 0.5.0 is required**).
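To give an idea of how a row of the matrix translates into code, here is a minimal sketch for the NER row. It is only a sketch: the checkpoint is a placeholder, and the task identifier passed to `Benchmark` (here assumed to be `task_name="ner"`, following the 0.5.0 Task-API) as well as the `target`/`target_token` values are assumptions — see the linked tutorial notebook for the exact call.

```python
from transformers import AutoModelForTokenClassification, AutoTokenizer
from ferret import Benchmark

# Placeholder checkpoint: any HF token-classification model should do.
name = "dslim/bert-base-NER"
model = AutoModelForTokenClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# Assumption: the task is selected through the Task-API introduced in 0.5.0.
bench = Benchmark(model, tokenizer, task_name="ner")

# For token-level tasks we explain the prediction of one token for one class.
# The `target` label and `target_token` values below are illustrative.
explanations = bench.explain(
    "My name is Sarah and I live in London",
    target="I-LOC",
    target_token="London",
)
bench.show_table(explanations)
```

The sequence-level rows follow the same pattern; only the `AutoModel*` class (and the task identifier) changes.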
5 | 6 | 7 | | Task (`HF Class`) | G | IG | SHAP | LIME | Tutorial | 8 | |-------------------------------|:-:|:--:|:----:|:----:|----------| 9 | | Sequence Classification (`AutoModelForSequenceClassification`) | ✅ | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/g8a9/ferret/blob/task-API/examples/sentiment_classification.ipynb) | 10 | | Natural Language Inference (`AutoModelForSequenceClassification`) | ✅ | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/g8a9/ferret/blob/task-API/examples/nli.ipynb) | 11 | | Zero-Shot Text Classification (`AutoModelForSequenceClassification`) | ✅ | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/g8a9/ferret/blob/task-API/examples/zeroshot_text_classification.ipynb) | 12 | | Named Entity Recognition (`AutoModelForTokenClassification`) | ✅ | ✅ | ✅️ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/g8a9/ferret/blob/task-API/examples/ner.ipynb) | 13 | | _Multiple Choice_ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | 14 | | _Masked Language Modeling_ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | 15 | | _Casual Language Modeling_ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | ⚙️ | 16 | 17 | Where: 18 | - ✅: we got you covered! 19 | - ⚙️: working on it... 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .DS_Store 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # dotenv 86 | .env 87 | 88 | # virtualenv 89 | .venv 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # IDE settings 107 | .vscode/ 108 | .idea/ 109 | 110 | tmp/ 111 | .DS_Store 112 | -------------------------------------------------------------------------------- /ferret/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Datasets API""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | TRAIN_SET = "TRAIN_SET" 7 | VALIDATION_SET = "VALIDATION_SET" 8 | TEST_SET = "TEST_SET" 9 | 10 | 11 | class BaseDataset(ABC): 12 | @property 13 | @abstractmethod 14 | def NAME(self): 15 | pass 16 | 17 | @property 18 | @abstractmethod 19 | def avg_rationale_size(self): 20 | # Default value 21 | return 5 22 | 23 | def __init__(self, tokenizer): 24 | self.tokenizer = tokenizer 25 | 26 | @abstractmethod 27 | def get_instance(self, idx: int, split_type: str = TEST_SET): 28 | pass 29 | 30 | @abstractmethod 31 | def _get_item(self, idx: int, split_type: str = TEST_SET): 32 | pass 33 | 34 | @abstractmethod 35 | def _get_text(self, idx, split_type: str = TEST_SET): 36 | pass 37 | 38 | @abstractmethod 39 | def _get_rationale(self, idx, split_type: str = TEST_SET): 40 | pass 41 | 42 | @abstractmethod 43 | def _get_ground_truth(self, idx, split_type: str = TEST_SET): 44 | pass 45 | 46 | def get_true_rationale_from_words_to_tokens( 47 | self, word_based_tokens: List[str], words_based_rationales: List[int] 48 | ) -> List[int]: 49 | # original_tokens --> list of words. 50 | # rationale_original_tokens --> 0 or 1, if the token belongs to the rationale or not 51 | # Typically, the importance is associated with each word rather than each token. 52 | # We convert each word in token using the tokenizer. If a word is in the rationale, 53 | # we consider as important all the tokens of the word. 
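# Illustrative example (the actual sub-tokens depend on the tokenizer): if the
# word "unbelievable" is split into ["un", "##believ", "##able"] and its
# word-level rationale value is 1, the token-level rationale gains [1, 1, 1].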
54 | token_rationale = [] 55 | for t, rationale_t in zip(word_based_tokens, words_based_rationales): 56 | converted_token = self.tokenizer.encode(t)[1:-1] 57 | 58 | for token_i in converted_token: 59 | token_rationale.append(rationale_t) 60 | return token_rationale 61 | -------------------------------------------------------------------------------- /ferret/evaluators/class_measures.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | import numpy as np 4 | 5 | from ferret.explainers.explanation import Explanation, ExplanationWithRationale 6 | 7 | from ..modeling import create_helper 8 | from .evaluation import EvaluationMetricOutput 9 | from .faithfulness_measures import AOPC_Comprehensiveness_Evaluation 10 | 11 | 12 | class AOPC_Comprehensiveness_Evaluation_by_class: 13 | NAME = "aopc_class_comprehensiveness" 14 | SHORT_NAME = "aopc_class_compr" 15 | # Higher is better 16 | BEST_SORTING_ASCENDING = False 17 | TYPE_METRIC = "class_faithfulness" 18 | 19 | def __init__( 20 | self, 21 | model, 22 | tokenizer, 23 | task_name, 24 | aopc_compr_eval: AOPC_Comprehensiveness_Evaluation = None, 25 | ): 26 | if aopc_compr_eval is None: 27 | if model is None or tokenizer is None: 28 | raise ValueError("Please specify a model and a tokenizer.") 29 | 30 | self.helper = create_helper(model, tokenizer, task_name) 31 | self.aopc_compr_eval = AOPC_Comprehensiveness_Evaluation( 32 | model, tokenizer, task_name 33 | ) 34 | else: 35 | self.aopc_compr_eval = aopc_compr_eval 36 | 37 | def compute_evaluation( 38 | self, 39 | class_explanation: List[Union[Explanation, ExplanationWithRationale]], 40 | **evaluation_args 41 | ): 42 | 43 | """ 44 | Each element of the list is the explanation for a target class 45 | """ 46 | 47 | evaluation_args["only_pos"] = True 48 | 49 | aopc_values = [] 50 | for target, explanation in enumerate(class_explanation): 51 | aopc_values.append( 52 | self.aopc_compr_eval.compute_evaluation( 53 | explanation, target, **evaluation_args 54 | ).score 55 | ) 56 | aopc_class_score = np.mean(aopc_values) 57 | evaluation_output = EvaluationMetricOutput(self.SHORT_NAME, aopc_class_score) 58 | return evaluation_output 59 | 60 | def aggregate_score(self, score, total, **aggregation_args): 61 | return score / total 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ferret-xai" 3 | version = "0.5.0" 4 | description = "A python package for benchmarking interpretability approaches." 
5 | authors = ["Giuseppe Attanasio "] 6 | license = "MIT" 7 | readme = "README.md" 8 | homepage = "https://github.com/g8a9/ferret" 9 | repository = "https://github.com/g8a9/ferret" 10 | documentation = "https://ferret.readthedocs.io/en/latest/" 11 | keywords = ["interpretability", "benchmarking", "xai", "nlp", "ml", "ai"] 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Intended Audience :: Science/Research", 15 | "Intended Audience :: Education", 16 | "Intended Audience :: Developers", 17 | "License :: OSI Approved :: MIT License", 18 | "Natural Language :: English", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | ] 23 | packages = [ 24 | { include = "ferret" }, 25 | ] 26 | 27 | 28 | [tool.poetry.dependencies] 29 | python = "^3.9.0" 30 | transformers = "^4.36.2" 31 | datasets = "^2.16.1" 32 | sentencepiece = "^0.1.99" 33 | captum = "^0.7.0" 34 | shap = "^0.44.0" 35 | seaborn = "^0.13.1" 36 | matplotlib = "^3.7.4" 37 | numpy = "^1.24.4" 38 | pandas = "^2.0.3" 39 | tqdm = "^4.66.1" 40 | scikit-image = "^0.21.0" 41 | opencv-python = "^4.9.0.80" 42 | lime = "^0.2.0.1" 43 | joblib = "^1.3.2" 44 | pytreebank = "^0.2.7" 45 | thermostat-datasets = "^1.1.0" 46 | # Speech-XAI additional requirements to allow for `pip install ferret[speech]`. 47 | pydub = { version = "0.25.1", optional = true } 48 | audiomentations = { version = "0.34.1", optional = true } 49 | audiostretchy = { version = "1.3.5", optional = true } 50 | pyroomacoustics = { version = "0.7.3", optional = true } 51 | whisperx = { version = "3.1.2", optional = true } 52 | 53 | [tool.poetry.extras] 54 | speech = [ 55 | "pydub", 56 | "audiomentations", 57 | "audiostretchy", 58 | "pyroomacoustics", 59 | "whisperx" 60 | ] 61 | all = [ 62 | "pydub", 63 | "audiomentations", 64 | "audiostretchy", 65 | "pyroomacoustics", 66 | "whisperx" 67 | ] 68 | 69 | 70 | [build-system] 71 | requires = ["poetry-core"] 72 | build-backend = "poetry.core.masonry.api" 73 | 74 | [tool.black] 75 | line-length = 89 76 | 77 | [tool.isort] 78 | profile = "black" 79 | -------------------------------------------------------------------------------- /ferret/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for ferret.""" 2 | 3 | __author__ = """Giuseppe Attanasio""" 4 | __email__ = "giuseppeattanasio6@gmail.com" 5 | __version__ = "0.5.0" 6 | 7 | from logging import getLogger 8 | 9 | logger = getLogger(__name__) 10 | 11 | from .benchmark import Benchmark 12 | 13 | # Dataset Interface 14 | from .datasets import BaseDataset 15 | from .datasets.datamanagers import HateXplainDataset, MovieReviews, SSTDataset 16 | from .datasets.datamanagers_thermostat import ThermostatDataset 17 | 18 | # Benchmarking methods 19 | from .evaluators import BaseEvaluator 20 | from .evaluators.faithfulness_measures import ( 21 | AOPC_Comprehensiveness_Evaluation, 22 | AOPC_Sufficiency_Evaluation, 23 | TauLOO_Evaluation, 24 | ) 25 | from .evaluators.plausibility_measures import ( 26 | AUPRC_PlausibilityEvaluation, 27 | Tokenf1_PlausibilityEvaluation, 28 | TokenIOU_PlausibilityEvaluation, 29 | ) 30 | 31 | # Explainers 32 | from .explainers import BaseExplainer 33 | from .explainers.dummy import DummyExplainer 34 | from .explainers.gradient import GradientExplainer, IntegratedGradientExplainer 35 | from .explainers.lime import LIMEExplainer 36 | from .explainers.shap import SHAPExplainer 37 | from .modeling.text_helpers 
import TokenClassificationHelper 38 | 39 | 40 | # Conditional imports for speech-related tasks 41 | try: 42 | # Explainers 43 | from .explainers.explanation_speech.paraling_speech_explainer import ( 44 | ParalinguisticSpeechExplainer, 45 | ) 46 | from .explainers.explanation_speech.loo_speech_explainer import LOOSpeechExplainer 47 | from .explainers.explanation_speech.explanation_speech import ExplanationSpeech 48 | 49 | # Model Helpers 50 | from .modeling.speech_model_helpers.model_helper_er import ModelHelperER 51 | from .modeling.speech_model_helpers.model_helper_fsc import ModelHelperFSC 52 | from .modeling.speech_model_helpers.model_helper_italic import ModelHelperITALIC 53 | from .benchmark_speech import SpeechBenchmark 54 | 55 | # Benchmarking methods 56 | from .evaluators.faithfulness_measures_speech import ( 57 | AOPC_Comprehensiveness_Evaluation_Speech, 58 | AOPC_Sufficiency_Evaluation_Speech, 59 | ) 60 | except ImportError: 61 | logger.info( 62 | "Speech-related modules could not be imported. It is very likely that ferret was installed in the standard, text-only mode. Run `pip install ferret-xai[speech]` or `pip install ferret-xai[all] to include them." 63 | ) 64 | -------------------------------------------------------------------------------- /ferret/explainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Explainers API""" 2 | 3 | import warnings 4 | from abc import ABC, abstractmethod 5 | from typing import Optional, Union 6 | 7 | from ..modeling import create_helper 8 | from ..modeling.base_helpers import BaseTaskHelper 9 | 10 | 11 | class BaseExplainer(ABC): 12 | @property 13 | @abstractmethod 14 | def NAME(self): 15 | pass 16 | 17 | def __init__( 18 | self, model, tokenizer, model_helper: Optional[BaseTaskHelper] = None, **kwargs 19 | ): 20 | # We use the task_name parameter to specify the correct helper via the create_helper() function 21 | task_name = kwargs.pop('task_name', None) 22 | 23 | if model is None or tokenizer is None: 24 | raise ValueError("Please specify a model and a tokenizer.") 25 | 26 | self.init_args = kwargs 27 | 28 | # The user can now specify the task name even for explainers, and that will set the correct helper 29 | # even if no model_helper is specified. If the user does not specify anything, we show the Warning. 30 | if model_helper is None: 31 | if task_name is None: 32 | task_name = "text-classification" 33 | warnings.warn( 34 | "No helper provided. Using default 'text-classification' helper." 
35 | ) 36 | self.helper = create_helper(model, tokenizer, task_name) 37 | else: 38 | self.helper=model_helper 39 | 40 | @property 41 | def device(self): 42 | return self.helper.model.device 43 | 44 | @property 45 | def model(self): 46 | return self.helper.model 47 | 48 | @property 49 | def tokenizer(self): 50 | return self.helper.tokenizer 51 | 52 | def _tokenize(self, text, **tok_kwargs): 53 | return self.helper._tokenize(text, **tok_kwargs) 54 | 55 | def get_tokens(self, text): 56 | return self.helper.get_tokens(text) 57 | 58 | def get_input_embeds(self, text): 59 | return self.helper.get_input_embeds(text) 60 | 61 | @abstractmethod 62 | def compute_feature_importance( 63 | self, text: str, target: int, target_token: Optional[str], **explainer_args 64 | ): 65 | pass 66 | 67 | def __call__( 68 | self, 69 | text: str, 70 | target: Union[str,int], 71 | target_token: Optional[str] = None, 72 | **explainer_args 73 | ): 74 | return self.compute_feature_importance( 75 | text, target, target_token, **explainer_args 76 | ) 77 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-build clean-pyc clean-test coverage dist docs help install lint lint/flake8 lint/black 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . 
-name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | lint/flake8: ## check style with flake8 51 | flake8 ferret tests 52 | lint/black: ## check style with black 53 | black --check ferret tests 54 | 55 | lint: lint/flake8 lint/black ## check style 56 | 57 | test: ## run tests quickly with the default Python 58 | python setup.py test 59 | 60 | test-all: ## run tests on every Python version with tox 61 | tox 62 | 63 | coverage: ## check code coverage quickly with the default Python 64 | coverage run --source ferret setup.py test 65 | coverage report -m 66 | coverage html 67 | $(BROWSER) htmlcov/index.html 68 | 69 | docs: ## generate Sphinx HTML documentation, including API docs 70 | rm -f docs/ferret.rst 71 | rm -f docs/modules.rst 72 | sphinx-apidoc -o docs/ ferret 73 | $(MAKE) -C docs clean 74 | $(MAKE) -C docs html 75 | $(BROWSER) docs/_build/html/index.html 76 | 77 | servedocs: docs ## compile the docs watching for changes 78 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 79 | 80 | release: dist ## package and upload a release 81 | twine upload dist/* 82 | 83 | dist: clean ## builds source and wheel package 84 | python setup.py sdist 85 | python setup.py bdist_wheel 86 | ls -l dist 87 | 88 | install: clean ## install the package to the active Python's site-packages 89 | python setup.py install 90 | -------------------------------------------------------------------------------- /ferret/evaluators/utils_from_soft_to_discrete.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _get_id_tokens_greater_th(soft_score_explanation, th, only_pos=None): 5 | id_top = np.where(soft_score_explanation > th)[0] 6 | return id_top 7 | 8 | 9 | def _get_id_tokens_top_k(soft_score_explanation, k, only_pos=True): 10 | if only_pos: 11 | id_top_k = [ 12 | i 13 | for i in np.array(soft_score_explanation).argsort()[-k:][::-1] 14 | if soft_score_explanation[i] > 0 15 | ] 16 | else: 17 | id_top_k = np.array(soft_score_explanation).argsort()[-k:][::-1] 18 | # None if we take no token 19 | if id_top_k == []: 20 | return None 21 | return id_top_k 22 | 23 | 24 | def _get_id_tokens_percentage(soft_score_explanation, percentage, only_pos=True): 25 | v = int(percentage * len(soft_score_explanation)) 26 | # Only if we remove at least instance. TBD 27 | if v > 0 and v <= len(soft_score_explanation): 28 | return _get_id_tokens_top_k(soft_score_explanation, v, only_pos=only_pos) 29 | else: 30 | return None 31 | 32 | 33 | def get_discrete_explanation_topK(score_explanation, topK, only_pos=False): 34 | 35 | # Indexes in the top k. 
If only pos is true, we only consider scores>0 36 | topk_indices = _get_id_tokens_top_k(score_explanation, topK, only_pos=only_pos) 37 | 38 | # Return default score 39 | if topk_indices is None: 40 | return None 41 | 42 | # topk_score_explanations: one hot encoding: 1 if the token is in the rationale, 0 otherwise 43 | # i hate you [0, 1, 1] 44 | 45 | topk_score_explanations = [ 46 | 1 if i in topk_indices else 0 for i in range(len(score_explanation)) 47 | ] 48 | return topk_score_explanations 49 | 50 | 51 | def _check_and_define_get_id_discrete_rationale_function(based_on): 52 | if based_on == "th": 53 | get_discrete_rationale_function = _get_id_tokens_greater_th 54 | elif based_on == "k": 55 | get_discrete_rationale_function = _get_id_tokens_top_k 56 | elif based_on == "perc": 57 | get_discrete_rationale_function = _get_id_tokens_percentage 58 | else: 59 | raise ValueError(f"{based_on} type not supported. Specify th, k or perc.") 60 | return get_discrete_rationale_function 61 | 62 | 63 | def parse_evaluator_args(evaluator_args): 64 | # Default parameters 65 | 66 | # We omit the scores of [CLS] and [SEP] 67 | remove_first_last = evaluator_args.get("remove_first_last", True) 68 | 69 | # As a default, we consider in the rationale only the terms that positively influence the prediction 70 | only_pos = evaluator_args.get("only_pos", True) 71 | 72 | removal_args_input = evaluator_args.get("removal_args", None) 73 | 74 | # As a default, we remove from 10% to 100% of the tokens. 75 | removal_args = { 76 | "remove_tokens": True, 77 | "based_on": "perc", 78 | "thresholds": np.arange(0.1, 1.1, 0.1), 79 | } 80 | 81 | if removal_args_input: 82 | removal_args.update(removal_args_input) 83 | 84 | # Top k tokens to be considered for the hard evaluation of plausibility 85 | # This is typically set as the average size of human rationales 86 | top_k_hard_rationale = evaluator_args.get("top_k_rationale", 5) 87 | 88 | return remove_first_last, only_pos, removal_args, top_k_hard_rationale 89 | -------------------------------------------------------------------------------- /docs/source/user_guide/speechxai.rst: -------------------------------------------------------------------------------- 1 | .. _speechxai: 2 | 3 | ***************************** 4 | Speech XAI 5 | ***************************** 6 | 7 | 8 | ferret offers Speech XAI functionalities through the `SpeechBenchmark` class (analogous to the `Benchmark` one for text data). We provide two types of insights. 🚀 9 | 10 | - Word-level. We measure the impact of each audio segment aligned with a word on the outcome. 11 | 12 | - Paralinguistic. We evaluate how non-linguistic features (e.g., prosody and background noise) affect the outcome if perturbed. 13 | 14 | Explanation 15 | =========== 16 | The code below provides a minimal example of how to generate word-level audio segment and paralinguistic attributions. 17 | 18 | We start by loading the model to explain: 19 | 20 | .. code-block:: python 21 | 22 | from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor 23 | 24 | model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic") 25 | feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic") 26 | 27 | We generate explanations by simply specifying the path of the audio to explain. 28 | 29 | Here we derive a word-level audio segment explanation via the leave-one-out technique: 30 | 31 | ..
code-block:: python 32 | 33 | from ferret import SpeechBenchmark 34 | 35 | speech_benchmark = SpeechBenchmark(model, feature_extractor) 36 | 37 | explanation = speech_benchmark.explain( 38 | audio_path=audio_path, 39 | methodology='LOO') 40 | 41 | display(speech_benchmark.show_table(explanation, decimals=3)), 42 | 43 | .. image:: _speechxai_images/example_word-level-audio-segments-loo.png 44 | :width: 400 45 | :alt: Example of word-level audio segment explanation 46 | 47 | Here we derive paralinguistic attributions 48 | 49 | .. code-block:: python 50 | 51 | paraling_expl = speech_benchmark.explain( 52 | audio_path=audio_path, 53 | methodology='perturb_paraling', 54 | ) 55 | 56 | display(speech_benchmark.show_table(paraling_expl, decimals=2)) 57 | 58 | .. image:: _speechxai_images/example_paralinguistic_expl.png 59 | :width: 400 60 | :alt: Example of paralinguistic attribution 61 | 62 | We can also plot the impact on the prediction probability when varying the degree of perturbations of the paralinguistic features: 63 | 64 | .. code-block:: python 65 | 66 | variations_table = speech_benchmark.explain_variations( 67 | audio_path=audio_path, 68 | perturbation_types=['time stretching', 'pitch shifting', 'reverberation']) 69 | 70 | speech_benchmark.plot_variations(variations_table, show_diff = True); 71 | 72 | .. image:: _speechxai_images/example_paralinguistic_variations.png 73 | :width: 400 74 | :alt: Example of paralinguistic explanation 75 | 76 | Evaluation 77 | ========== 78 | We can evaluate the faithfulness of our word-level segment explanation in terms of comprehensiveness and sufficiency: 79 | 80 | .. code-block:: python 81 | 82 | from ferret import AOPC_Comprehensiveness_Evaluation_Speech, AOPC_Sufficiency_Evaluation_Speech 83 | 84 | aopc_compr = AOPC_Comprehensiveness_Evaluation_Speech(speech_benchmark.model_helper) 85 | evaluation_output_c = aopc_compr.compute_evaluation(explanation) 86 | 87 | aopc_suff = AOPC_Sufficiency_Evaluation_Speech(speech_benchmark.model_helper) 88 | evaluation_output_s = aopc_suff.compute_evaluation(explanation) 89 | -------------------------------------------------------------------------------- /docs/source/user_guide/explaining.rst: -------------------------------------------------------------------------------- 1 | .. _explaining: 2 | 3 | ********** 4 | Explaining 5 | ********** 6 | 7 | In this page, we show how to use ferret's built-in explainers to generate post-hoc feature attribution scores on a simple text. 8 | 9 | :ref:`Post-hoc feature attribution methods ` explain why a model made a specific prediction for a given text. 10 | These methods assign an importance score to each input. In the context of text data, we typically assign a score to each token, and so in ferret. 11 | Given a model, a target class, and a prediction, ferret lets you measure how much each token contributed to that prediction. 12 | 13 | ferret integrates multiple post-hoc feature attribution methods: Gradient, GradientXInput, Integrated Gradient, SHAP, LIME. 14 | We can explain a prediction with the multiple supported approaches and visualize explanations. 15 | 16 | 17 | 18 | 19 | .. _explain-predictions: 20 | 21 | Explain predictions 22 | ====================== 23 | 24 | ferret offers direct integration with Hugging Face models and naming conventions. Hence, we can easily explain Hugging face models for text classification. 25 | 26 | Consider a common text classification pipeline 27 | 28 | .. 
code-block:: python 29 | 30 | 31 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 32 | from ferret import Benchmark 33 | 34 | name = "cardiffnlp/twitter-xlm-roberta-base-sentiment" 35 | tokenizer = AutoTokenizer.from_pretrained(name) 36 | model = AutoModelForSequenceClassification.from_pretrained(name) 37 | 38 | 39 | .. _generate-explanations: 40 | 41 | Generate explanations 42 | ---------------------------- 43 | 44 | We first specify the model and tokenizer in use through ferret's main API access point, the `Benchmark` class. 45 | If we do not pass any additional parameters at initialization, ferret uses all supported post-hoc explainers with their default parameters. 46 | Each explainer will provide a list of feature importance scores that quantify how *large* the contribution of each token was to a target class. 47 | A positive attribution score indicates that the token positively contributed to the final prediction. 48 | 49 | We can explain the prediction for a given input text with respect to a target class directly using the **explain** method. 50 | 51 | 52 | .. code-block:: python 53 | 54 | from ferret import Benchmark 55 | bench = Benchmark(model, tokenizer) 56 | explanations = bench.explain('I love your style!', target=2) 57 | 58 | The **explain** method returns a list of Explanations, one for each explainer. An **Explanation** has the following form. 59 | 60 | .. code-block:: python 61 | 62 | Explanation(text='I love your style!', tokens=['', '▁I', '▁love', '▁your', '▁style', '!', ''], scores=array([-6.40356006e-08, 1.44730296e-02, 4.23283947e-01, 2.80506348e-01, 2.20774370e-01, 6.09622411e-02, 0.00000000e+00]), explainer='Partition SHAP', target=2) 63 | 64 | It stores the input text, the tokens, the importance **score** for each token, the explainer name and the target class. 65 | 66 | 67 | .. _visualize-explanations: 68 | 69 | Visualize explanations 70 | ---------------------------- 71 | 72 | We can visualize the explanations using the **show_table** method. 73 | 74 | .. code-block:: python 75 | 76 | bench.show_table(explanations) 77 | 78 | 79 | Here is the output for our example. 80 | 81 | .. image:: _images/example_explanations_viz.png 82 | :width: 400 83 | :alt: Example of explanation visualization 84 | -------------------------------------------------------------------------------- /ferret/explainers/shap.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from typing import Dict, Optional, Text, Union 3 | import logging 4 | 5 | import numpy as np 6 | import shap 7 | from shap.maskers import Text as TextMasker 8 | 9 | from .
import BaseExplainer 10 | from .explanation import Explanation 11 | from .utils import parse_explainer_args 12 | 13 | 14 | class SHAPExplainer(BaseExplainer): 15 | NAME = "Partition SHAP" 16 | 17 | def __init__( 18 | self, 19 | model, 20 | tokenizer, 21 | model_helper: Optional[str] = None, 22 | silent: bool = True, 23 | algorithm: str = "partition", 24 | seed: int = 42, 25 | **kwargs, 26 | ): 27 | super().__init__(model, tokenizer, model_helper, **kwargs) 28 | # Initializing SHAP-specific arguments 29 | self.init_args["silent"] = silent 30 | self.init_args["algorithm"] = algorithm 31 | self.init_args["seed"] = seed 32 | 33 | def compute_feature_importance( 34 | self, 35 | text, 36 | target: Union[int, Text] = 1, 37 | target_token: Optional[Union[int, Text]] = None, 38 | **kwargs, 39 | ): 40 | # sanity checks 41 | target_pos_idx = self.helper._check_target(target) 42 | target_token_pos_idx = self.helper._check_target_token(text, target_token) 43 | text = self.helper._check_sample(text) 44 | 45 | # Removing 'target_option' if passed as it's not relevant here 46 | if 'target_option' in kwargs: 47 | logging.warning("The 'target_option' argument is not used in SHAPExplainer and will be removed.") 48 | kwargs.pop('target_option') 49 | 50 | # Function to compute logits for SHAP explainer 51 | def func(texts: np.array): 52 | _, logits = self.helper._forward(texts.tolist()) 53 | # Adjust logits based on the target token position 54 | logits = self.helper._postprocess_logits( 55 | logits, target_token_pos_idx=target_token_pos_idx 56 | ) 57 | return logits.softmax(-1).cpu().numpy() 58 | 59 | masker = TextMasker(self.tokenizer) 60 | explainer_partition = shap.Explainer(model=func, masker=masker, **self.init_args) 61 | shap_values = explainer_partition(text, **kwargs) 62 | attr = shap_values.values[0][:, target_pos_idx] 63 | # Tokenize the text for token-level explanation 64 | item = self._tokenize(text, return_special_tokens_mask=True) 65 | token_ids = item['input_ids'][0].tolist() 66 | token_scores = np.zeros_like(token_ids, dtype=float) 67 | # Assigning SHAP values to tokens, ignoring special tokens 68 | for i, (shap_value, is_special_token) in enumerate(zip(attr, item['special_tokens_mask'][0])): 69 | if not is_special_token: 70 | token_scores[i] = shap_value 71 | 72 | output = Explanation( 73 | text=text, 74 | tokens=self.get_tokens(text), 75 | scores=token_scores, 76 | explainer=self.NAME, 77 | helper_type=self.helper.HELPER_TYPE, 78 | target_pos_idx=target_pos_idx, 79 | target_token_pos_idx=target_token_pos_idx, 80 | target=self.helper.model.config.id2label[target_pos_idx], 81 | target_token=self.helper.tokenizer.decode( 82 | item["input_ids"][0, target_token_pos_idx].item() 83 | ) 84 | if self.helper.HELPER_TYPE == "token-classification" 85 | else None, 86 | ) 87 | return output 88 | -------------------------------------------------------------------------------- /ferret/modeling/speech_model_helpers/model_helper_er.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, List, Union, Tuple 3 | from pydub import AudioSegment 4 | import torch 5 | from ...speechxai_utils import pydub_to_np 6 | 7 | 8 | 9 | class ModelHelperER: 10 | """ 11 | Wrapper class to interface with HuggingFace models 12 | """ 13 | 14 | def __init__(self, model, feature_extractor, device, language="en"): 15 | self.model = model 16 | self.feature_extractor = feature_extractor 17 | self.device = device 18 | self.n_labels = 1 # Single label problem 
19 | self.language = language 20 | self.label_name = "class" 21 | 22 | # PREDICT SINGLE 23 | def predict( 24 | self, 25 | audios: List[np.ndarray], 26 | ) -> np.ndarray: 27 | """ 28 | Predicts action, object and location from audio one sample at a time. 29 | Returns probs for each class. 30 | # We do separately for consistency with FSC/IC model and the bug of padding 31 | """ 32 | 33 | probs = np.empty((len(audios), self.model.config.num_labels)) 34 | for e, audio in enumerate(audios): 35 | probs[e] = self._predict([audio]) 36 | return probs 37 | 38 | def _predict( 39 | self, 40 | audios: List[np.ndarray], 41 | ) -> np.ndarray: 42 | """ 43 | Predicts emotion from audio. 44 | Returns probs for each class. 45 | """ 46 | 47 | ## Extract features 48 | inputs = self.feature_extractor( 49 | [audio.squeeze() for audio in audios], 50 | sampling_rate=self.feature_extractor.sampling_rate, 51 | padding=True, 52 | return_tensors="pt", 53 | ) 54 | 55 | ## Predict logits 56 | with torch.no_grad(): 57 | logits = ( 58 | self.model(inputs.input_values.to(self.device)) 59 | .logits.detach() 60 | .cpu() 61 | # .numpy() 62 | ) 63 | logits = logits 64 | 65 | return logits.softmax(-1).numpy() 66 | 67 | def get_text_labels(self, targets) -> str: 68 | if type(targets) is list: 69 | class_index = targets[0] 70 | else: 71 | class_index = targets 72 | return self.model.config.id2label[class_index] 73 | 74 | def get_text_labels_with_class(self, targets) -> str: 75 | """ 76 | Return the text labels with the class name as strings (e.g., ['action = increase', 'object = lights', 'location = kitchen']]) 77 | """ 78 | text_target = self.get_text_labels(targets) 79 | return f"{self.label_name}={text_target}" 80 | 81 | def get_predicted_classes(self, audio_path=None, audio=None): 82 | if audio is None and audio_path is None: 83 | raise ValueError("Specify the audio path or the audio as a numpy array") 84 | 85 | if audio is None: 86 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 87 | 88 | logits = self.predict([audio]) 89 | predicted_ids = np.argmax(logits, axis=1)[0] 90 | return predicted_ids 91 | 92 | def get_predicted_probs(self, audio_path=None, audio=None): 93 | if audio is None and audio_path is None: 94 | raise ValueError("Specify the audio path or the audio as a numpy array") 95 | 96 | if audio is None: 97 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 98 | 99 | logits = self.predict([audio]) 100 | predicted_id = np.argmax(logits, axis=1)[0] 101 | 102 | # TODO - these are not the logits, but the probs.. rename! 
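# Note: `predict` already applies a softmax in `_predict`, so the values used
# here are class probabilities rather than raw logits.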
103 | 104 | predicted_probs = logits[:, predicted_id][0] 105 | return predicted_probs -------------------------------------------------------------------------------- /ferret/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluators API""" 2 | 3 | from abc import ABC, abstractmethod 4 | from enum import Enum 5 | from typing import Any, List, Union 6 | 7 | from ..explainers.explanation import Explanation, ExplanationWithRationale 8 | from ..modeling import create_helper 9 | from ..explainers.explanation_speech.explanation_speech import ( 10 | ExplanationSpeech, EvaluationSpeech) 11 | 12 | 13 | class EvaluationMetricFamily(Enum): 14 | """Enum to represent the family of an EvaluationMetric""" 15 | 16 | FAITHFULNESS = "faithfulness" 17 | PLAUSIBILITY = "plausibility" 18 | 19 | 20 | class BaseEvaluator(ABC): 21 | @property 22 | @abstractmethod 23 | def NAME(self): 24 | pass 25 | 26 | @property 27 | @abstractmethod 28 | def MIN_VALUE(self): 29 | pass 30 | 31 | @property 32 | @abstractmethod 33 | def MAX_VALUE(self): 34 | pass 35 | 36 | @property 37 | @abstractmethod 38 | def SHORT_NAME(self): 39 | pass 40 | 41 | @property 42 | @abstractmethod 43 | def LOWER_IS_BETTER(self): 44 | pass 45 | 46 | @property 47 | @abstractmethod 48 | def METRIC_FAMILY(self) -> EvaluationMetricFamily: 49 | pass 50 | 51 | def __repr__(self) -> str: 52 | return str( 53 | dict( 54 | NAME=self.NAME, 55 | SHORT_NAME=self.SHORT_NAME, 56 | MIN_VALUE=self.MIN_VALUE, 57 | MAX_VALUE=self.MAX_VALUE, 58 | LOWER_IS_BETTER=self.LOWER_IS_BETTER, 59 | METRIC_FAMILY=self.METRIC_FAMILY, 60 | ) 61 | ) 62 | 63 | @property 64 | def tokenizer(self): 65 | return self.helper.tokenizer 66 | 67 | def __init__(self, model, tokenizer, task_name): 68 | if model is None or tokenizer is None: 69 | raise ValueError("Please specify a model and a tokenizer.") 70 | 71 | self.helper = create_helper(model, tokenizer, task_name) 72 | 73 | def __call__(self, explanation: Explanation): 74 | return self.compute_evaluation(explanation) 75 | 76 | @abstractmethod 77 | def compute_evaluation( 78 | self, explanation: Union[Explanation, ExplanationWithRationale] 79 | ): 80 | pass 81 | 82 | # def aggregate_score(self, score, total, **aggregation_args): 83 | # return score / total 84 | 85 | 86 | class SpeechBaseEvaluator(ABC): 87 | """ 88 | Abstract base class for evaluator objects (metrics) for speech 89 | explainability. 90 | 91 | Notes: 92 | * Should we include `MIN_VALUE` and `MAX_VALUE` properties, as for 93 | the text-based evaluators? 
94 | """ 95 | @property 96 | @abstractmethod 97 | def NAME(self): 98 | pass 99 | 100 | @property 101 | @abstractmethod 102 | def SHORT_NAME(self): 103 | pass 104 | 105 | @property 106 | @abstractmethod 107 | def LOWER_IS_BETTER(self): 108 | pass 109 | 110 | @property 111 | @abstractmethod 112 | def METRIC_FAMILY(self) -> EvaluationMetricFamily: 113 | pass 114 | 115 | def __repr__(self) -> str: 116 | return str( 117 | dict( 118 | NAME=self.NAME, 119 | SHORT_NAME=self.SHORT_NAME, 120 | LOWER_IS_BETTER=self.LOWER_IS_BETTER, 121 | METRIC_FAMILY=self.METRIC_FAMILY, 122 | ) 123 | ) 124 | 125 | def __init__(self, model_helper, **kwargs): 126 | self.model_helper = model_helper 127 | 128 | @abstractmethod 129 | def compute_evaluation( 130 | self, 131 | explanation: ExplanationSpeech, 132 | target: List = None, 133 | words_trascript: List = None, 134 | **evaluation_args, 135 | ) -> EvaluationSpeech: 136 | pass -------------------------------------------------------------------------------- /docs/source/user_guide/benchmarking.rst: -------------------------------------------------------------------------------- 1 | .. _benchmarking: 2 | 3 | ***************************** 4 | Benchmarking 5 | ***************************** 6 | 7 | 8 | In this page, we show how to evaluate and compare a set of explanations with our built-in benchmarking methods. 9 | 10 | Given a set of explanations from multiple explainers as described in the :ref:`Explaining ` section, we are interested in quantitatively evaluating and comparing them. 11 | ferret offers multiple measures which evaluate both the :ref:`faithfulness ` and plausibility of explanations. 12 | 13 | .. _benchmarking-faithfulness: 14 | 15 | Evaluate faithfulness 16 | ======================= 17 | Faithfulness evaluates how accurately the explanation reflects the inner workings of the model (Jacovi and Goldberg, 2020). 18 | 19 | ferret offers the following measures of faithfulness: 20 | 21 | - :ref:`AOPC Comprehensiveness ` - (aopc_compr, ↑) - goes from 0 to 1 (best) 22 | - :ref:`AOPC Sufficiency ` - (aopc_suff, ↓) - goes from 0 (best) to 1; 23 | - :ref:`Kendall's Tau correlation with Leave-One-Out token removal ` - (taucorr_loo, ↑) - goes from -1 to 1 (best). 24 | 25 | 26 | The Benchmark class exposes the **evaluate_explanations** method to evaluate the explanations produced. 27 | If no parameters are specified, we compute all supported faithfulness measures for all explanations. 28 | 29 | .. code-block:: python 30 | 31 | explanation_evaluations = bench.evaluate_explanations(explanations, target=2) 32 | 33 | 34 | We can visualize the faithfulness results in tabular form using the **show_evaluation_table** method. 35 | 36 | .. code-block:: python 37 | 38 | bench.show_evaluation_table(explanation_evaluations) 39 | 40 | Here is the output for our example. 41 | 42 | 43 | .. image:: _images/example_evaluation_faithfulness_viz.png 44 | :width: 400 45 | :alt: Example of faithfulness result 46 | 47 | 48 | 49 | 50 | .. _benchmarking-plausibility: 51 | 52 | Evaluate plausibility 53 | ======================= 54 | Plausibility reflects how explanations are aligned with human reasoning. 55 | 56 | ferret offers the following measures of plausibility: 57 | 58 | - :ref:`Area-Under-Precision-Recall-Curve ` - (auprc_plau, ↑) - goes from 0 to 1 (best) 59 | - :ref:`Token F1 ` - (token_f1_plau, ↑) - goes from 0 to 1 (best); 60 | - :ref:`Token Intersection Over Union ` - (token_iou_plau, ↑) - goes from 0 to 1 (best).
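Each measure above is implemented as an evaluator class that can also be used on its own. Below is a minimal sketch of how to inspect one programmatically; it assumes that `model` and `tokenizer` are the ones loaded in the Explaining section and that the constructor follows the `BaseEvaluator` interface (model, tokenizer, task name) — treat the exact task identifier as an assumption.

.. code-block:: python

    from ferret import AUPRC_PlausibilityEvaluation

    auprc = AUPRC_PlausibilityEvaluation(model, tokenizer, "text-classification")

    # Printing an evaluator shows its metadata: name, short name,
    # value range, and whether lower scores are better.
    print(auprc)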
61 | 62 | The evaluation is performed by comparing explanations with human rationales. 63 | 64 | **Human rationales** are annotations highlighting the most relevant words a human annotator attributed to a given class label. 65 | 66 | 67 | To evaluate the plausibility of explanations for our input text, we need to specify the human rationales, i.e., which tokens we expect to be salient. 68 | In ferret, we represent a human rationale as a list of 0/1 values. A value of 1 means that the corresponding token is part of the human rationale, 0 otherwise. 69 | 70 | In our example with the text 'I love your style!', 'love' could be the token we associate with a positive sentiment (class 2 in our example) as humans. 71 | 72 | 73 | We specify the human rationale as input to the evaluate_explanations method. ferret will evaluate plausibility measures (and faithfulness) for our explanations. 74 | 75 | .. code-block:: python 76 | 77 | human_rationale = {"▁I": 0, "▁love": 1, "▁your": 0, "▁style": 0, "!": 0} 78 | 79 | evaluations = bench.evaluate_explanations(explanations, target=2, human_rationale=list(human_rationale.values())) 80 | 81 | 82 | We can visualize both the faithfulness results and the plausibility ones using the **show_evaluation_table** method. 83 | 84 | .. code-block:: python 85 | 86 | bench.show_evaluation_table(evaluations) 87 | 88 | Here is the output for our example. 89 | 90 | 91 | .. image:: _images/example_evaluation_plausibility_viz.png 92 | :width: 600 93 | :alt: Example of faithfulness and plausibility result 94 | 95 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/g8a9/ferret/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it. 37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | ferret could always use more documentation, whether as part of the 42 | official ferret docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/g8a9/ferret/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute?
Here's how to set up `ferret` for local development. 61 | 62 | 1. Fork the `ferret` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/ferret.git 66 | 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 68 | 69 | $ mkvirtualenv ferret 70 | $ cd ferret/ 71 | $ python setup.py develop 72 | 73 | 4. Create a branch for local development:: 74 | 75 | $ git checkout -b name-of-your-bugfix-or-feature 76 | 77 | Now you can make your changes locally. 78 | 79 | 5. When you're done making changes, check that your changes pass flake8 and the 80 | tests, including testing other Python versions with tox:: 81 | 82 | $ flake8 ferret tests 83 | $ python setup.py test or pytest 84 | $ tox 85 | 86 | To get flake8 and tox, just pip install them into your virtualenv. 87 | 88 | 6. Commit your changes and push your branch to GitHub:: 89 | 90 | $ git add . 91 | $ git commit -m "Your detailed description of your changes." 92 | $ git push origin name-of-your-bugfix-or-feature 93 | 94 | 7. Submit a pull request through the GitHub website. 95 | 96 | Pull Request Guidelines 97 | ----------------------- 98 | 99 | Before you submit a pull request, check that it meets these guidelines: 100 | 101 | 1. The pull request should include tests. 102 | 2. If the pull request adds functionality, the docs should be updated. Put 103 | your new functionality into a function with a docstring, and add the 104 | feature to the list in README.md. 105 | 3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check 106 | https://travis-ci.com/g8a9/ferret/pull_requests 107 | and make sure that the tests pass for all supported Python versions. 108 | 109 | Tips 110 | ---- 111 | 112 | To run a subset of tests:: 113 | 114 | 115 | $ python -m unittest tests.test_ferret 116 | 117 | Deploying 118 | --------- 119 | 120 | A reminder for the maintainers on how to deploy. 121 | Make sure all your changes are committed (including an entry in HISTORY.rst). 122 | Then run:: 123 | 124 | $ bump2version patch # possible: major / minor / patch 125 | $ git push 126 | $ git push --tags 127 | 128 | Travis will then deploy to PyPI if tests pass. 129 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ferret documentation 2 | ==================== 3 | 4 | |pypi badge| |demo badge| |youtube badge| |arxiv badge| |downloads badge| 5 | 6 | .. |pypi badge| image:: https://img.shields.io/pypi/v/ferret-xai.svg 7 | :target: https://pypi.python.org/pypi/ferret-xai 8 | :alt: Latest PyPI version 9 | 10 | .. |Docs Badge| image:: https://readthedocs.org/projects/ferret/badge/?version=latest 11 | :alt: Documentation Status 12 | :target: https://ferret.readthedocs.io/en/latest/?version=latest 13 | 14 | .. |demo badge| image:: https://img.shields.io/badge/HF%20Spaces-Demo-yellow 15 | :alt: HuggingFace Spaces Demo 16 | :target: https://huggingface.co/spaces/g8a9/ferret 17 | 18 | .. |youtube badge| image:: https://img.shields.io/badge/youtube-video-red 19 | :alt: YouTube Video 20 | :target: https://www.youtube.com/watch?v=kX0HcSah_M4 21 | 22 | .. |banner| image:: /_static/banner.png 23 | :alt: Ferret circular logo with the name to the right 24 | 25 | .. 
|arxiv badge| image:: https://img.shields.io/badge/arXiv-2208.01575-b31b1b.svg 26 | :alt: arxiv preprint 27 | :target: https://arxiv.org/abs/2208.01575 28 | 29 | .. |downloads badge| image:: https://pepy.tech/badge/ferret-xai 30 | :alt: downloads badge 31 | :target: https://pepy.tech/project/ferret-xai 32 | 33 | 34 | ferret is a Python library for benchmarking interpretability techniques on 35 | Transformers. 36 | 37 | Use any of the badges above to test our live demo, view a video demonstration, or explore our technical paper in detail. 38 | 39 | 40 | Installation 41 | ------------ 42 | 43 | To install our latest stable release in default mode (which does not include the dependencies for the speech XAI functionalities), run this command in your terminal: 44 | 45 | .. code-block:: console 46 | 47 | pip install -U ferret-xai 48 | 49 | If the speech XAI functionalities are needed, then run: 50 | 51 | .. code-block:: console 52 | 53 | pip install -U ferret-xai[speech] 54 | 55 | At the moment, the speech XAI-related dependencies are the only extra ones, so installing with :code:`ferret-xai[speech]` or :code:`ferret-xai[all]` is equivalent. 56 | 57 | These are the preferred methods to install ferret, as they will always install the most recent stable release. 58 | 59 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 60 | you through the process. 61 | 62 | .. _pip: https://pip.pypa.io 63 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 64 | 65 | 66 | Citation 67 | -------- 68 | 69 | If you are using ferret for your work, please consider citing us! 70 | 71 | .. code-block:: bibtex 72 | 73 | @inproceedings{attanasio-etal-2023-ferret, 74 | title = "ferret: a Framework for Benchmarking Explainers on Transformers", 75 | author = "Attanasio, Giuseppe and Pastor, Eliana and Di Bonaventura, Chiara and Nozza, Debora", 76 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics: System Demonstrations", 77 | month = may, 78 | year = "2023", 79 | publisher = "Association for Computational Linguistics", 80 | } 81 | 82 | Also, ferret's Speech XAI functionalities are based on: 83 | 84 | .. code-block:: bibtex 85 | 86 | @misc{pastor2023explaining, 87 | title = "Explaining Speech Classification Models via Word-Level Audio Segments and Paralinguistic Features", 88 | author = "Pastor, Eliana and Koudounas, Alkis and Attanasio, Giuseppe and Hovy, Dirk and Baralis, Elena", 89 | month = sep, 90 | year = "2023", 91 | eprint = "2309.07733", 92 | archivePrefix = "arXiv", 93 | primaryClass = "cs.CL", 94 | } 95 | 96 | 97 | .. Indices and tables 98 | .. ================== 99 | .. * :ref:`genindex` 100 | .. * :ref:`modindex` 101 | .. * :ref:`search` 102 | 103 | Index 104 | ----- 105 | 106 | ..
toctree:: 108 | :maxdepth: 2 109 | 110 | user_guide/index 111 | api/index 112 | history -------------------------------------------------------------------------------- /ferret/modeling/speech_model_helpers/model_helper_italic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, List, Union, Tuple 3 | import torch 4 | from pydub import AudioSegment 5 | from ...speechxai_utils import pydub_to_np 6 | 7 | 8 | class ModelHelperITALIC: 9 | """ 10 | Wrapper class ITALIC dataset 11 | """ 12 | 13 | def __init__(self, model, feature_extractor, device, language="it"): 14 | self.model = model 15 | self.feature_extractor = feature_extractor 16 | self.device = device 17 | self.n_labels = 1 # Single label problem 18 | self.max_duration = 10.0 19 | self.language = language 20 | self.label_name = "class" 21 | 22 | # PREDICT SINGLE 23 | def predict( 24 | self, 25 | audios: List[np.ndarray], 26 | ) -> np.ndarray: 27 | """ 28 | Predicts action, object and location from audio one sample at a time. 29 | Returns probs for each class. 30 | # We do separately for consistency with FSC/IC model and the bug of padding 31 | """ 32 | 33 | probs = np.empty((len(audios), self.model.config.num_labels)) 34 | for e, audio in enumerate(audios): 35 | probs[e] = self._predict([audio]) 36 | return probs 37 | 38 | def _predict( 39 | self, 40 | audios: List[np.ndarray], 41 | ) -> np.ndarray: 42 | """ 43 | Predicts emotion from audio. 44 | Returns probs for each class. 45 | """ 46 | 47 | ## Extract features 48 | inputs = self.feature_extractor( 49 | [audio.squeeze() for audio in audios], 50 | sampling_rate=self.feature_extractor.sampling_rate, 51 | return_tensors="pt", 52 | max_length=int(self.feature_extractor.sampling_rate * self.max_duration), 53 | truncation=True, 54 | padding="max_length", 55 | ) 56 | 57 | ## Predict logits 58 | with torch.no_grad(): 59 | logits = ( 60 | self.model(inputs.input_values.to(self.device)) 61 | .logits.detach() 62 | .cpu() 63 | # .numpy() 64 | ) 65 | logits = logits 66 | 67 | return logits.softmax(-1).numpy() 68 | 69 | def get_logits_from_input_embeds(self, input_embeds): 70 | logits = self.model(input_embeds.to(self.device)).logits 71 | return logits 72 | 73 | def get_text_labels(self, targets) -> str: 74 | if type(targets) is list: 75 | class_index = targets[0] 76 | else: 77 | class_index = targets 78 | return self.model.config.id2label[class_index] 79 | 80 | def get_text_labels_with_class(self, targets) -> str: 81 | """ 82 | Return the text labels with the class name as strings (e.g., ['action = increase', 'object = lights', 'location = kitchen']]) 83 | """ 84 | text_target = self.get_text_labels(targets) 85 | return f"{self.label_name}={text_target}" 86 | 87 | def get_predicted_classes(self, audio_path=None, audio=None): 88 | if audio is None and audio_path is None: 89 | raise ValueError("Specify the audio path or the audio as a numpy array") 90 | 91 | if audio is None: 92 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 93 | 94 | logits = self.predict([audio]) 95 | predicted_ids = np.argmax(logits, axis=1)[0] 96 | return predicted_ids 97 | 98 | def get_predicted_probs(self, audio_path=None, audio=None): 99 | if audio is None and audio_path is None: 100 | raise ValueError("Specify the audio path or the audio as a numpy array") 101 | 102 | if audio is None: 103 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 104 | 105 | logits = self.predict([audio]) 106 | predicted_id = np.argmax(logits, 
axis=1)[0] 107 | 108 | # TODO - these are not the logits, but the probs.. rename! 109 | 110 | predicted_probs = logits[:, predicted_id][0] 111 | return predicted_probs -------------------------------------------------------------------------------- /ferret/datasets/utils_sst_rationale_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | From https://github.com/BoulderDS/evaluating-human-rationales/blob/66402dbe8ccdf8b841c185cd8050b8bdc04ef3d2/scripts/download_and_process_sst.py 4 | Evaluating and Characterizing Human Rationales 5 | Samuel Carton, Anirudh Rathore, Chenhao Tan 6 | 7 | MIT License by Samuel Carton 8 | 9 | Copyright (c) 2020 Data Science @ University of Colorado Boulder 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 
28 | 29 | """ 30 | 31 | import numpy as np 32 | import pytreebank 33 | 34 | 35 | def get_leaves(tree): 36 | leaves = [] 37 | if len(tree.children) > 0: 38 | for child in tree.children: 39 | leaves += get_leaves(child) 40 | else: 41 | leaves.append(tree) 42 | return leaves 43 | 44 | 45 | def get_sst_rationale(item): 46 | """ 47 | Author: Eliana Pastor 48 | Adapted from https://github.com/BoulderDS/evaluating-human-rationales/blob/66402dbe8ccdf8b841c185cd8050b8bdc04ef3d2/scripts/download_and_process_sst.py#L74 49 | """ 50 | rationale = [] 51 | count_leaves_and_extreme_descendants(item) 52 | phrases = [] 53 | assemble_rationale_phrases(item, phrases) 54 | for phrase in phrases: 55 | phrase_rationale = [np.abs(normalize_label(phrase.label))] * phrase.num_leaves 56 | rationale.extend(phrase_rationale) 57 | pass 58 | rationale = np.array(rationale) 59 | return rationale 60 | 61 | 62 | def explanatory_phrase(tree): 63 | if len(tree.children) == 0: 64 | return True 65 | else: 66 | normalized_label = normalize_label(tree.label) 67 | normalized_max_descendant = normalize_label(tree.max_descendant) 68 | normalized_min_descendant = normalize_label(tree.min_descendant) 69 | 70 | if abs(normalized_label) > abs(normalized_max_descendant) and abs( 71 | normalized_label 72 | ) > abs(normalized_min_descendant): 73 | return True 74 | else: 75 | return False 76 | 77 | 78 | def assemble_rationale_phrases(tree, phrases, **kwargs): 79 | if explanatory_phrase(tree, **kwargs): 80 | phrases.append(tree) 81 | else: 82 | for child in tree.children: 83 | assemble_rationale_phrases(child, phrases, **kwargs) 84 | 85 | 86 | def count_leaves_and_extreme_descendants(tree): 87 | 88 | if len(tree.children) == 0: # if is leaf 89 | tcount = 1 90 | tmax = tmin = tree.label 91 | else: 92 | tcount = 0 93 | child_labels = [child.label for child in tree.children] 94 | tmax = max(child_labels) 95 | tmin = min(child_labels) 96 | 97 | for child in tree.children: 98 | ccount, cmax, cmin = count_leaves_and_extreme_descendants(child) 99 | tcount += ccount 100 | tmax = max(tmax, cmax) 101 | tmin = min(tmin, cmin) 102 | 103 | tree.num_leaves = tcount 104 | tree.max_descendant = tmax 105 | tree.min_descendant = tmin 106 | 107 | if tree.label == 4: 108 | _ = None 109 | return tcount, tmax, tmin 110 | 111 | 112 | def normalize_label(label): 113 | return (label - 2) / 2 114 | -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/loo_speech_explainer.py: -------------------------------------------------------------------------------- 1 | """LOO Speech Explainer module""" 2 | import numpy as np 3 | from typing import Dict, List, Union, Tuple 4 | from pydub import AudioSegment 5 | from IPython.display import display 6 | from .explanation_speech import ExplanationSpeech 7 | from .utils_removal import transcribe_audio, remove_word 8 | from ...speechxai_utils import pydub_to_np, print_log 9 | 10 | 11 | class LOOSpeechExplainer: 12 | NAME = "loo_speech" 13 | 14 | def __init__(self, model_helper): 15 | self.model_helper = model_helper 16 | 17 | def remove_words( 18 | self, 19 | audio_path: str, 20 | removal_type: str = "nothing", 21 | words_trascript: List = None, 22 | display_audio: bool = False, 23 | ) -> Tuple[List[AudioSegment], List[Dict[str, Union[str, float]]]]: 24 | """ 25 | Remove words from audio using pydub, by replacing them with: 26 | - nothing 27 | - silence 28 | - white noise 29 | - pink noise 30 | """ 31 | 32 | ## Transcribe audio 33 | 34 | if words_trascript is None: 35 
| text, words_trascript = transcribe_audio( 36 | audio_path=audio_path, 37 | device=self.model_helper.device.type, 38 | batch_size=2, 39 | compute_type="float32", 40 | language=self.model_helper.language, 41 | ) 42 | 43 | ## Load audio as pydub.AudioSegment 44 | audio = AudioSegment.from_wav(audio_path) 45 | 46 | ## Remove word 47 | audio_no_words = [] 48 | 49 | for word in words_trascript: 50 | audio_removed = remove_word(audio, word, removal_type) 51 | 52 | audio_no_words.append(pydub_to_np(audio_removed)[0]) 53 | 54 | if display_audio: 55 | print_log(word["word"]) 56 | display(audio_removed) 57 | 58 | return audio_no_words, words_trascript 59 | 60 | def compute_explanation( 61 | self, 62 | audio_path: str, 63 | target_class=None, 64 | removal_type: str = None, 65 | words_trascript: List = None, 66 | ) -> ExplanationSpeech: 67 | """ 68 | Computes the importance of each word in the audio. 69 | """ 70 | 71 | ## Get modified audio by leaving a single word out and the words 72 | modified_audios, words = self.remove_words( 73 | audio_path, removal_type, words_trascript=words_trascript 74 | ) 75 | 76 | logits_modified = self.model_helper.predict(modified_audios) 77 | 78 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 79 | 80 | logits_original = self.model_helper.predict([audio]) 81 | 82 | # Check if single label or multilabel scenario as for FSC 83 | n_labels = self.model_helper.n_labels 84 | 85 | # TODO 86 | if target_class is not None: 87 | targets = target_class 88 | 89 | else: 90 | if n_labels > 1: 91 | # Multilabel scenario as for FSC 92 | targets = [ 93 | np.argmax(logits_original[i], axis=1)[0] for i in range(n_labels) 94 | ] 95 | else: 96 | targets = np.argmax(logits_original, axis=1)[0] 97 | 98 | ## Get the most important word for each class (action, object, location) 99 | 100 | if n_labels > 1: 101 | # Multilabel scenario as for FSC 102 | modified_trg = [logits_modified[i][:, targets[i]] for i in range(n_labels)] 103 | original_gt = [ 104 | logits_original[i][:, targets[i]][0] for i in range(n_labels) 105 | ] 106 | 107 | else: 108 | modified_trg = logits_modified[:, targets] 109 | original_gt = logits_original[:, targets][0] 110 | 111 | features = [word["word"] for word in words] 112 | 113 | if n_labels > 1: 114 | # Multilabel scenario as for FSC 115 | prediction_diff = [ 116 | original_gt[i] - modified_trg[i] for i in range(n_labels) 117 | ] 118 | else: 119 | prediction_diff = [original_gt - modified_trg] 120 | 121 | scores = np.array(prediction_diff) 122 | 123 | explanation = ExplanationSpeech( 124 | features=features, 125 | scores=scores, 126 | explainer=self.NAME + "+" + removal_type, 127 | target=targets if n_labels > 1 else [targets], 128 | audio_path=audio_path, 129 | ) 130 | 131 | return explanation 132 | -------------------------------------------------------------------------------- /docs/source/user_guide/tasks.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Tasks Documentation 3 | ===================== 4 | 5 | This document provides a comprehensive guide to the tasks available in the Ferret project. Each task is detailed with its purpose, usage, and associated parameters. 6 | 7 | .. contents:: 8 | :local: 9 | :depth: 2 10 | 11 | Sequence Classification 12 | ======================= 13 | 14 | .. _sequence-classification: 15 | 16 | Introduction 17 | ------------ 18 | Sequence Classification is a task that involves categorizing text sequences into predefined labels or classes. 
This task is commonly used for sentiment analysis, topic labeling, and similar applications where text needs to be classified according to its content or sentiment. 19 | 20 | Usage 21 | ----- 22 | .. code-block:: python 23 | 24 | 25 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 26 | from ferret import Benchmark 27 | model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment") 28 | tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-xlm-roberta-base-sentiment") 29 | bench = Benchmark(model, tokenizer) 30 | text = "You look stunning!" 31 | exp = bench.explain(text, target=1) 32 | bench.show_table(exp) 33 | # 'explanation' contains SHAP values for each token in the text. 34 | 35 | Natural Language Inference (NLI) 36 | ================================= 37 | 38 | .. _natural-language-inference: 39 | 40 | Introduction 41 | ------------ 42 | Natural Language Inference focuses on determining the relationship between a premise and a hypothesis, categorizing the relationship as entailment, contradiction, or neutral. 43 | 44 | Usage 45 | ----- 46 | .. code-block:: python 47 | 48 | 49 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 50 | from ferret import Benchmark 51 | model = AutoModelForSequenceClassification.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli") 52 | tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli") 53 | premise = "A soccer game with multiple males playing." 54 | hypothesis = "A sports activity." 55 | sample = (premise, hypothesis) 56 | bench = Benchmark(model, tokenizer, task_name="nli") 57 | exp = bench.explain(sample, target="contradiction") 58 | bench.show_table(exp) 59 | 60 | Zero-Shot Classification 61 | ======================== 62 | 63 | .. _zero-shot-classification: 64 | 65 | Introduction 66 | ------------ 67 | Zero-Shot Classification refers to classifying text into categories that were not seen during training. It's used for tasks where predefined categories are not available. 68 | 69 | Usage 70 | ----- 71 | .. code-block:: python 72 | 73 | 74 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 75 | from ferret import Benchmark 76 | 77 | tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") 78 | model = AutoModelForSequenceClassification.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli") 79 | sequence_to_classify = "A new Tesla model was unveiled." 80 | candidate_labels = ["technology", "economy", "sports"] 81 | bench = Benchmark(model, tokenizer, task_name="zero-shot-text-classification") 82 | scores = bench.score(sequence_to_classify, options=candidate_labels, return_probs=True) 83 | # get the label with the highest score, and use it as 'target_option' 84 | most_probable_label = max(scores, key=scores.get) 85 | exp = bench.explain(sequence_to_classify, target="entailment", target_option=most_probable_label) 86 | # 'explanation' shows how the model associates the text with the categories. 87 | 88 | Named Entity Recognition (NER) 89 | ============================== 90 | 91 | .. _named-entity-recognition: 92 | 93 | Introduction 94 | ------------ 95 | Named Entity Recognition involves identifying and categorizing key information (entities) in text, such as names of people, places, organizations, etc. 96 | 97 | Usage 98 | ----- 99 | .. 
code-block:: python 100 | 101 | 102 | from transformers import AutoModelForTokenClassification, AutoTokenizer 103 | from ferret import Benchmark 104 | tokenizer = AutoTokenizer.from_pretrained("Babelscape/wikineural-multilingual-ner") 105 | model = AutoModelForTokenClassification.from_pretrained("Babelscape/wikineural-multilingual-ner") 106 | text = "My name is John and I live in New York" 107 | bench = Benchmark(model, tokenizer, task_name="ner") 108 | exp = bench.explain(text, target="I-LOC", target_token="York") 109 | bench.show_table(exp) 110 | .. note:: 111 | The usage examples provided in this document are intended to guide users through the various tasks. For detailed explanations of the different explainers, please refer to the respective documentation files. -------------------------------------------------------------------------------- /docs/source/user_guide/notions.explainers.rst: -------------------------------------------------------------------------------- 1 | .. _notions.explainers: 2 | 3 | ************************************ 4 | Post-Hoc Feature Attribution Methods 5 | ************************************ 6 | 7 | Post-hoc feature attribution methods explain why a model made a specific prediction for a given text. 8 | These methods assign an importance score to each input feature. In the context of text data, we typically assign a score to each token, and so does ferret. 9 | Given a model, a target class, and a prediction, ferret lets you measure how much each token contributed to that prediction. 10 | 11 | 12 | .. _overview-explainers: 13 | 14 | Overview of explainers 15 | ---------------------------- 16 | 17 | ferret integrates the following post-hoc attribution methods: 18 | 19 | - :ref:`LIME ` 20 | - :ref:`SHAP ` 21 | - :ref:`Gradient `, plain gradients or multiplied by input token embeddings 22 | - :ref:`Integrated Gradient `, plain gradients or multiplied by input token embeddings 23 | 24 | 25 | .. _explainers-lime: 26 | 27 | LIME 28 | ---------------------------- 29 | LIME (Local Interpretable Model-agnostic Explanations) is a model-agnostic method for explaining individual predictions by learning an interpretable model in the locality of the prediction. 30 | The interpretable model is a local surrogate model that mimics the behavior of the original model locally. 31 | 32 | LIME learns an interpretable model only in the locality of the instance to derive the relevant features for the individual label assignment. The approach derives the locality by generating perturbed samples of the instance, weighting the samples by their proximity to the original instance. LIME optimizes the fidelity of the local surrogate model to the original one while preserving its understandability. 33 | 34 | More details can be found in the `LIME paper `_. 35 | 36 | 37 | ferret uses the Captum implementation of `LIME `_. 38 | 39 | 40 | .. _explainers-shap: 41 | 42 | SHAP 43 | ---------------------------- 44 | 45 | SHAP (SHapley Additive exPlanations) is a game-theoretic approach to explain individual predictions. 46 | 47 | SHAP is based on the notion of Shapley values. The Shapley value is a concept from coalitional game theory used to assign a score to the players who cooperate to achieve a specific total score. In the context of prediction explanations, the attribute values of the instance to explain are the players, and the prediction probability is the score. 48 | The exact estimation requires computing the omission effect for the power set of the attributes.
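For reference, the Shapley value of a feature can be written as its average marginal contribution over all feature subsets (a sketch of the standard definition, not ferret-specific notation, with :math:`F` the set of input features and :math:`f(S)` the model output when only the features in :math:`S` are present):

.. math::

   \phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|! \, (|F| - |S| - 1)!}{|F|!} \left[ f(S \cup \{i\}) - f(S) \right]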
Hence, multiple solutions have been proposed to overcome the problem of its exponential complexity. 49 | 50 | SHAP's authors propose practical approximations for estimating Shapley values, such as KernelSHAP and PartitionSHAP. KernelSHAP is an approximation approach based on local surrogate models; the estimation relies on weighted linear regression models. 51 | More recently, the authors proposed PartitionSHAP, which uses hierarchical data clustering to define feature coalitions. Like KernelSHAP, the approach is model-agnostic, but it makes the computation more tractable and typically requires less time. 52 | More details can be found in the `SHAP paper `_ and in the SHAP library `documentation `_. 53 | 54 | 55 | ferret integrates the `SHAP library implementation `_ and uses PartitionSHAP as the default algorithm, which is also the default for textual data in the SHAP library. 56 | 57 | 58 | .. _explainers-gradient: 59 | 60 | Gradient (Saliency) and GradientXInput 61 | ---------------------------------------- 62 | 63 | The Gradient approach, also known as Saliency, is one of the earliest gradient-based approaches. This class of approaches computes the gradient of the prediction score with respect to the input features, and the methods differ in how the gradient is computed. 64 | The Gradient approach directly computes the gradient of the prediction score for the target class with respect to the input. 65 | More details can be found in the `corresponding paper `_. 66 | 67 | The GradientXInput approach multiplies the gradient with respect to the input by the input itself. More details can be found `here `_. 68 | 69 | ferret uses the Captum implementations of `Gradient `_ and `GradientXInput `_. 70 | 71 | 72 | .. _explainers-integratedgradient: 73 | 74 | Integrated Gradients and Integrated Gradient X Input 75 | ------------------------------------------------------- 76 | 77 | Integrated Gradients is a gradient-based approach. 78 | The approach considers a baseline, that is, an uninformative input. In the case of text, it could correspond to an empty text or a zero embedding vector. 79 | The approach considers the straight-line path from the baseline to the input and computes the gradients along that path. 80 | Integrated gradients are obtained by accumulating these gradients. 81 | 82 | The method description can be found in the original `paper `_. 83 | 84 | ferret adopts the `Captum implementation `_ and also includes the version multiplied by the input. -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/equal_width/loo_equal_width_explainer.py: -------------------------------------------------------------------------------- 1 | """LOO Speech Explainer module""" 2 | import os 3 | import numpy as np 4 | from typing import Dict, List, Union, Tuple 5 | import whisperx 6 | from pydub import AudioSegment 7 | from IPython.display import display 8 | from ..explanation_speech import ExplanationSpeech 9 | from ....speechxai_utils import pydub_to_np, print_log 10 | 11 | 12 | def remove_audio_segment(audio, start_s, end_s, removal_type: str = "silence"): 13 | """ 14 | Remove an audio segment from audio using pydub, by replacing it with: 15 | - nothing 16 | - silence 17 | - white noise 18 | - pink noise 19 | 20 | Args: 21 | audio (pydub.AudioSegment): audio 22 | start_s, end_s (float): start and end time of the segment to remove, in seconds 23 | removal_type (str, optional): type of removal. Defaults to "silence".
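Note: the "nothing" option cuts the segment out, so the returned audio is shorter; the other options replace the segment with one of equal duration.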
24 | """ 25 | 26 | start_idx = int(start_s * 1000) 27 | end_idx = int(end_s * 1000) 28 | before_word_audio = audio[:start_idx] 29 | after_word_audio = audio[end_idx:] 30 | word_duration = end_idx - start_idx 31 | 32 | if removal_type == "nothing": 33 | replace_word_audio = AudioSegment.empty() 34 | elif removal_type == "silence": 35 | replace_word_audio = AudioSegment.silent(duration=word_duration) 36 | 37 | elif removal_type == "white noise": 38 | sound_path = os.path.join(os.path.dirname(__file__), "white_noise.mp3") 39 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 40 | 41 | # display(audio_removed) 42 | elif removal_type == "pink noise": 43 | sound_path = os.path.join(os.path.dirname(__file__), "pink_noise.mp3") 44 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 45 | 46 | audio_removed = before_word_audio + replace_word_audio + after_word_audio 47 | return audio_removed 48 | 49 | 50 | class LOOSpeechEqualWidthExplainer: 51 | NAME = "loo_speech_equal_width" 52 | 53 | def __init__(self, model_helper): 54 | self.model_helper = model_helper 55 | 56 | def compute_explanation( 57 | self, 58 | audio_path: str, 59 | target_class=None, 60 | removal_type: str = "silence", 61 | num_s_split: float = 0.25, 62 | display_audio: bool = False, 63 | ) -> ExplanationSpeech: 64 | """ 65 | Computes the importance of each equal-width audio segment in the audio. 66 | """ 67 | 68 | ## Load audio as pydub.AudioSegment 69 | audio = AudioSegment.from_wav(audio_path) 70 | audio_np = pydub_to_np(audio)[0] 71 | 72 | ## Remove equal-width segments 73 | audio_remove_segments = [] 74 | 75 | duration_s = len(audio) / 1000 76 | 77 | for i in np.arange(0, duration_s, num_s_split): 78 | start_s = i 79 | end_s = min(i + num_s_split, duration_s) 80 | audio_removed = remove_audio_segment(audio, start_s, end_s, removal_type) 81 | 82 | audio_remove_segments.append(pydub_to_np(audio_removed)[0]) 83 | 84 | if display_audio: 85 | print_log(int(start_s / num_s_split), start_s, end_s) 86 | display(audio_removed) 87 | 88 | # Get original logits 89 | logits_original = self.model_helper.predict([audio_np]) 90 | 91 | # Get logits for the modified audio by leaving out the equal width segments 92 | logits_modified = self.model_helper.predict(audio_remove_segments) 93 | 94 | # Check if single label or multilabel scenario as for FSC 95 | n_labels = self.model_helper.n_labels 96 | 97 | # TODO 98 | if target_class is not None: 99 | targets = target_class 100 | 101 | else: 102 | if n_labels > 1: 103 | # Multilabel scenario as for FSC 104 | targets = [ 105 | np.argmax(logits_original[i], axis=1)[0] for i in range(n_labels) 106 | ] 107 | else: 108 | targets = np.argmax(logits_original, axis=1)[0] 109 | 110 | ## Get the most important word for each class (action, object, location) 111 | 112 | if n_labels > 1: 113 | # Multilabel scenario as for FSC 114 | modified_trg = [logits_modified[i][:, targets[i]] for i in range(n_labels)] 115 | original_gt = [ 116 | logits_original[i][:, targets[i]][0] for i in range(n_labels) 117 | ] 118 | 119 | else: 120 | modified_trg = logits_modified[:, targets] 121 | original_gt = logits_original[:, targets][0] 122 | 123 | features = [idx for idx in range(len(audio_remove_segments))] 124 | 125 | if n_labels > 1: 126 | # Multilabel scenario as for FSC 127 | prediction_diff = [ 128 | original_gt[i] - modified_trg[i] for i in range(n_labels) 129 | ] 130 | else: 131 | prediction_diff = [original_gt - modified_trg] 132 | 133 | scores = np.array(prediction_diff) 134 | 135 |
explanation = ExplanationSpeech( 136 | features=features, 137 | scores=scores, 138 | explainer=self.NAME + "+" + removal_type, 139 | target=targets if n_labels > 1 else [targets], 140 | audio_path=audio_path, 141 | ) 142 | 143 | return explanation -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/equal_width/lime_equal_width_explainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydub import AudioSegment 3 | import numpy as np 4 | from ..lime_timeseries import LimeTimeSeriesExplainer 5 | from ..utils_removal import transcribe_audio 6 | from ..explanation_speech import ExplanationSpeech 7 | from ....speechxai_utils import pydub_to_np 8 | 9 | EMPTY_SPAN = "---" 10 | 11 | 12 | class LIMEEqualWidthSpeechExplainer: 13 | NAME = "LIME_equal_width" 14 | 15 | def __init__(self, model_helper): 16 | self.model_helper = model_helper 17 | 18 | def compute_explanation( 19 | self, 20 | audio_path: str, 21 | target_class=None, 22 | removal_type: str = "silence", 23 | num_samples: int = 1000, 24 | num_s_split: float = 0.25, 25 | ) -> ExplanationSpeech: 26 | """ 27 | Compute the word-level explanation for the given audio. 28 | audio_path: path to the audio file 29 | target_class: target class - int - If None, use the predicted class 30 | removal_type: 31 | """ 32 | 33 | if removal_type not in ["silence", "noise"]: 34 | raise ValueError( 35 | "Removal method not supported, choose between 'silence' and 'noise'" 36 | ) 37 | 38 | # Load audio and convert to np.array 39 | audio_as = AudioSegment.from_wav(audio_path) 40 | audio = pydub_to_np(audio_as)[0] 41 | 42 | # Predict logits/probabilities 43 | logits_original = self.model_helper.predict([audio]) 44 | 45 | # Check if single label or multilabel scenario as for FSC 46 | n_labels = self.model_helper.n_labels 47 | 48 | # TODO 49 | if target_class is not None: 50 | targets = target_class 51 | 52 | else: 53 | if n_labels > 1: 54 | # Multilabel scenario as for FSC 55 | targets = [ 56 | int(np.argmax(logits_original[i], axis=1)[0]) 57 | for i in range(n_labels) 58 | ] 59 | else: 60 | targets = [int(np.argmax(logits_original, axis=1)[0])] 61 | 62 | audio_np = audio.reshape(1, -1) 63 | 64 | # Get the start and end indexes of the segments. 
These will be used to split the audio and derive LIME interpretable features 65 | sampling_rate = self.model_helper.feature_extractor.sampling_rate 66 | splits = [] 67 | 68 | duration_s = len(audio_as) / 1000 69 | 70 | a, b = 0, 0 71 | for e, i in enumerate(np.arange(0, duration_s, num_s_split)): 72 | start_s = i 73 | end_s = min(i + num_s_split, duration_s) 74 | 75 | start, end = int((start_s + a) * sampling_rate), int( 76 | (end_s + b) * sampling_rate 77 | ) 78 | splits.append({"start": start, "end": end, "word": e}) 79 | 80 | lime_explainer = LimeTimeSeriesExplainer() 81 | 82 | # Compute gradient importance for each target label 83 | # This also handles the multilabel scenario as for FSC 84 | scores = [] 85 | for target_label, target_class in enumerate(targets): 86 | if self.model_helper.n_labels > 1: 87 | # We get the prediction probability for the given label 88 | predict_proba_function = ( 89 | self.model_helper.get_prediction_function_by_label(target_label) 90 | ) 91 | else: 92 | predict_proba_function = self.model_helper.predict 93 | from copy import deepcopy 94 | 95 | input_audio = deepcopy(audio_np) 96 | 97 | # Explain the instance using the splits as interpretable features 98 | exp = lime_explainer.explain_instance( 99 | input_audio, 100 | predict_proba_function, 101 | num_features=len(splits), 102 | num_samples=num_samples, 103 | num_slices=len(splits), 104 | replacement_method=removal_type, 105 | splits=splits, 106 | labels=(target_class,), 107 | ) 108 | 109 | map_scores = {k: v for k, v in exp.as_map()[target_class]} 110 | map_scores = { 111 | k: v 112 | for k, v in sorted( 113 | map_scores.items(), key=lambda x: x[0], reverse=False 114 | ) 115 | } 116 | 117 | # Remove the 'empty' spans, the spans between words 118 | map_scores = [ 119 | (splits[k]["word"], v) 120 | for k, v in map_scores.items() 121 | if splits[k]["word"] != EMPTY_SPAN 122 | ] 123 | 124 | features = list(list(zip(*map_scores))[0]) 125 | importances = list(list(zip(*map_scores))[1]) 126 | scores.append(np.array(importances)) 127 | 128 | if n_labels > 1: 129 | # Multilabel scenario as for FSC 130 | scores = np.array(scores) 131 | else: 132 | scores = np.array([importances]) 133 | 134 | explanation = ExplanationSpeech( 135 | features=features, 136 | scores=scores, 137 | explainer=self.NAME + "+" + removal_type, 138 | target=targets if n_labels > 1 else targets, 139 | audio_path=audio_path, 140 | ) 141 | 142 | return explanation -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/lime_speech_explainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydub import AudioSegment 3 | import numpy as np 4 | from .lime_timeseries import LimeTimeSeriesExplainer 5 | from .utils_removal import transcribe_audio 6 | from .explanation_speech import ExplanationSpeech 7 | from ...speechxai_utils import pydub_to_np 8 | 9 | EMPTY_SPAN = "---" 10 | 11 | 12 | class LIMESpeechExplainer: 13 | NAME = "LIME" 14 | 15 | def __init__(self, model_helper): 16 | self.model_helper = model_helper 17 | 18 | def compute_explanation( 19 | self, 20 | audio_path: str, 21 | target_class=None, 22 | words_trascript: List = None, 23 | removal_type: str = "silence", 24 | num_samples: int = 1000, 25 | ) -> ExplanationSpeech: 26 | """ 27 | Compute the word-level explanation for the given audio. 
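Each transcribed word span is used as a LIME interpretable feature; the gaps between words are marked as empty spans and dropped from the returned scores.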
28 | Args: 29 | audio_path: path to the audio file 30 | target_class: target class - int - If None, use the predicted class 31 | removal_type: 32 | """ 33 | 34 | if removal_type not in ["silence", "noise"]: 35 | raise ValueError( 36 | "Removal method not supported, choose between 'silence' and 'noise'" 37 | ) 38 | 39 | # Load audio and convert to np.array 40 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 41 | 42 | # Predict logits/probabilities 43 | logits_original = self.model_helper.predict([audio]) 44 | 45 | # Check if single label or multilabel scenario as for FSC 46 | n_labels = self.model_helper.n_labels 47 | 48 | # TODO 49 | if target_class is not None: 50 | targets = target_class 51 | 52 | else: 53 | if n_labels > 1: 54 | # Multilabel scenario as for FSC 55 | targets = [ 56 | int(np.argmax(logits_original[i], axis=1)[0]) 57 | for i in range(n_labels) 58 | ] 59 | else: 60 | targets = [int(np.argmax(logits_original, axis=1)[0])] 61 | 62 | if words_trascript is None: 63 | # Transcribe audio 64 | _, words_trascript = transcribe_audio( 65 | audio_path=audio_path, language=self.model_helper.language 66 | ) 67 | audio_np = audio.reshape(1, -1) 68 | 69 | # Get the start and end indexes of the words. These will be used to split the audio and derive LIME interpretable features 70 | tot_len = audio.shape[0] 71 | sampling_rate = self.model_helper.feature_extractor.sampling_rate 72 | splits = [] 73 | old_start = 0 74 | a, b = 0, 0 75 | for word in words_trascript: 76 | start, end = int((word["start"] + a) * sampling_rate), int( 77 | (word["end"] + b) * sampling_rate 78 | ) 79 | splits.append({"start": old_start, "end": start, "word": EMPTY_SPAN}) 80 | splits.append({"start": start, "end": end, "word": word["word"]}) 81 | old_start = end 82 | splits.append({"start": old_start, "end": tot_len, "word": EMPTY_SPAN}) 83 | 84 | lime_explainer = LimeTimeSeriesExplainer() 85 | 86 | # Compute gradient importance for each target label 87 | # This also handles the multilabel scenario as for FSC 88 | scores = [] 89 | for target_label, target_class in enumerate(targets): 90 | if self.model_helper.n_labels > 1: 91 | # We get the prediction probability for the given label 92 | predict_proba_function = ( 93 | self.model_helper.get_prediction_function_by_label(target_label) 94 | ) 95 | else: 96 | predict_proba_function = self.model_helper.predict 97 | from copy import deepcopy 98 | 99 | input_audio = deepcopy(audio_np) 100 | 101 | # Explain the instance using the splits as interpretable features 102 | exp = lime_explainer.explain_instance( 103 | input_audio, 104 | predict_proba_function, 105 | num_features=len(splits), 106 | num_samples=num_samples, 107 | num_slices=len(splits), 108 | replacement_method=removal_type, 109 | splits=splits, 110 | labels=(target_class,), 111 | ) 112 | 113 | map_scores = {k: v for k, v in exp.as_map()[target_class]} 114 | map_scores = { 115 | k: v 116 | for k, v in sorted( 117 | map_scores.items(), key=lambda x: x[0], reverse=False 118 | ) 119 | } 120 | 121 | # Remove the 'empty' spans, the spans between words 122 | map_scores = [ 123 | (splits[k]["word"], v) 124 | for k, v in map_scores.items() 125 | if splits[k]["word"] != EMPTY_SPAN 126 | ] 127 | if map_scores == []: 128 | features = [] 129 | importances = [] 130 | else: 131 | features = list(list(zip(*map_scores))[0]) 132 | importances = list(list(zip(*map_scores))[1]) 133 | scores.append(np.array(importances)) 134 | 135 | if n_labels > 1: 136 | # Multilabel scenario as for FSC 137 | scores = np.array(scores) 138 | 
else: 139 | scores = np.array([importances]) 140 | 141 | explanation = ExplanationSpeech( 142 | features=features, 143 | scores=scores, 144 | explainer=self.NAME + "+" + removal_type, 145 | target=targets if n_labels > 1 else targets, 146 | audio_path=audio_path, 147 | ) 148 | 149 | return explanation -------------------------------------------------------------------------------- /ferret/explainers/lime.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from lime.lime_text import LimeTextExplainer 7 | 8 | from . import BaseExplainer 9 | from .explanation import Explanation 10 | from .utils import parse_explainer_args 11 | 12 | 13 | class LIMEExplainer(BaseExplainer): 14 | NAME = "LIME" 15 | 16 | def compute_feature_importance( 17 | self, 18 | text, 19 | target=1, 20 | target_token: Optional[Union[int, str]] = None, 21 | token_masking_strategy="mask", 22 | batch_size=8, 23 | show_progress=True, 24 | num_samples=None, 25 | max_samples=5000, 26 | **kwargs 27 | ): 28 | # init_args, call_args = parse_explainer_args(explainer_args) 29 | # sanity checks 30 | target_pos_idx = self.helper._check_target(target) 31 | text = self.helper._check_sample(text) 32 | target_token_pos_idx = self.helper._check_target_token(text, target_token) 33 | 34 | # token_masking_strategy = call_args.pop("token_masking_strategy", "mask") 35 | # show_progress = call_args.pop("show_progress", False) 36 | # batch_size = call_args.pop("batch_size", 8) 37 | 38 | 39 | def fn_prediction_token_ids(token_ids_sentences: List[str]): 40 | """Run inference on a list of strings made of token ids. 41 | 42 | Masked token ids are represented with 'UNKWORDZ'. 43 | Note that with transformers language models, results differ if tokens are masked or removed before inference. 44 | We let the user choose with the parameter 'token_masking_strategy' 45 | 46 | :param token_ids_sentences: list of strings made of token ids. 47 | """ 48 | if token_masking_strategy == "mask": 49 | unk_substitute = str(self.helper.tokenizer.mask_token_id) 50 | elif token_masking_strategy == "remove": 51 | #  TODO We don't have yet a way to handle empty string produced by sampling 52 | raise NotImplementedError() 53 | #  unk_substitute = "" 54 | else: 55 | raise NotImplementedError() 56 | 57 | # 1. replace or remove UNKWORDZ 58 | token_ids_sentences = [ 59 | s.replace("UNKWORDZ", unk_substitute) for s in token_ids_sentences 60 | ] 61 | # 2. turn tokens into input_ids 62 | token_ids = [ 63 | [int(i) for i in s.split(" ") if i != ""] for s in token_ids_sentences 64 | ] 65 | #  3. remove empty strings 66 | #  token_ids = [t for t in token_ids if t] # TODO yet to define how to handle empty strings 67 | # 4. decode to list of tokens 68 | masked_texts = self.helper.tokenizer.batch_decode(token_ids) 69 | # 4. forward pass on the batch 70 | _, logits = self.helper._forward( 71 | masked_texts, 72 | output_hidden_states=False, 73 | add_special_tokens=False, 74 | show_progress=show_progress, 75 | batch_size=batch_size, 76 | ) 77 | logits = self.helper._postprocess_logits( 78 | logits, target_token_pos_idx=target_token_pos_idx 79 | ) 80 | 81 | return logits.softmax(-1).detach().cpu().numpy() 82 | 83 | def run_lime_explainer(token_ids, target_pos_idx, num_samples, lime_args): 84 | """ 85 | Runs the LIME explainer on a given set of token IDs to obtain feature importance scores. 
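Token ids are joined into a whitespace-separated string and LIME is run with bow=False, so each token id is perturbed as a single unit and its score maps back to exactly one token.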
86 | 87 | Args: 88 | token_ids (List[int]): A list of token IDs representing the text to be explained. 89 | target_pos_idx (int): The index of the target class for which explanations are being generated. 90 | num_samples (int): The number of samples to use in the LIME explanation process. 91 | lime_args (Dict): Additional arguments to pass to the LimeTextExplainer. 92 | 93 | Returns: 94 | LimeTextExplainer.Explanation: The explanation object from LIME with feature importance scores. 95 | """ 96 | explainer_args = {k: v for k, v in self.init_args.items() if k != 'task_type'} 97 | 98 | lime_explainer = LimeTextExplainer(bow=False, **explainer_args) 99 | 100 | lime_args["num_samples"] = num_samples 101 | return lime_explainer.explain_instance( 102 | " ".join([str(i) for i in token_ids]), 103 | fn_prediction_token_ids, 104 | labels=[target_pos_idx], 105 | num_features=len(token_ids), 106 | **lime_args, 107 | ) 108 | 109 | 110 | lime_args = kwargs.get('call_args', {}) 111 | 112 | item = self._tokenize(text, return_special_tokens_mask=True) 113 | token_ids = item["input_ids"][0].tolist() 114 | 115 | if num_samples is None: 116 | num_samples = min(len(token_ids) ** 2, max_samples) # powerset size 117 | 118 | expl = run_lime_explainer(token_ids, target_pos_idx, num_samples, lime_args) 119 | 120 | token_scores = np.array( 121 | [list(dict(sorted(expl.local_exp[target_pos_idx])).values())] 122 | ) 123 | token_scores[item["special_tokens_mask"].bool().cpu().numpy()] = 0.0 124 | # token_scores is initially created as a 2D array with a single row, where each column 125 | # contains the importance score of each token in the analyzed text. 126 | # By setting token_scores = token_scores[0], we convert it to a 1D array for ease of use, 127 | # as it contains scores for the single text sequence processed by LIME. 128 | token_scores = token_scores[0] 129 | 130 | output = Explanation( 131 | text=text, 132 | tokens=self.get_tokens(text), 133 | scores=token_scores, 134 | explainer=self.NAME, 135 | helper_type=self.helper.HELPER_TYPE, 136 | target_pos_idx=target_pos_idx, 137 | target_token_pos_idx=target_token_pos_idx, 138 | target=self.helper.model.config.id2label[target_pos_idx], 139 | target_token=self.helper.tokenizer.decode( 140 | item["input_ids"][0, target_token_pos_idx].item() 141 | ) 142 | if self.helper.HELPER_TYPE == "token-classification" 143 | else None, 144 | ) 145 | return output 146 | -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/equal_width/gradient_equal_width_explainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydub import AudioSegment 3 | from captum.attr import Saliency, InputXGradient 4 | import numpy as np 5 | import torch 6 | from ..explanation_speech import ExplanationSpeech 7 | from ....speechxai_utils import pydub_to_np 8 | # TODO - include in utils 9 | from ..loo_speech_explainer import transcribe_audio 10 | 11 | 12 | class GradientEqualWidthSpeechExplainer: 13 | NAME = "Gradient_equal_width" 14 | 15 | def __init__(self, model_helper, multiply_by_inputs: bool = False): 16 | self.model_helper = model_helper 17 | self.multiply_by_inputs = multiply_by_inputs 18 | 19 | if self.multiply_by_inputs: 20 | self.NAME += " (x Input)" 21 | 22 | def _get_gradient_importance_frame_level( 23 | self, audio, target_class, target_label=None 24 | ): 25 | """ 26 | Compute the gradient importance for each frame of the audio w.r.t. the target class. 
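Attributions are computed with Captum's Saliency (or InputXGradient when multiply_by_inputs is set) over the feature extractor's input values.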
27 | Args: 28 | audio: audio - np.array 29 | target_class: target class - int 30 | target_label: target label - int - Used only in a multilabel scenario as for FSC 31 | """ 32 | torch.set_grad_enabled(True) # Context-manager 33 | 34 | # Function which returns the logits 35 | if self.model_helper.n_labels > 1: 36 | # We get the logits for the given label 37 | func = self.model_helper.get_logits_function_from_input_embeds_by_label( 38 | target_label 39 | ) 40 | else: 41 | func = self.model_helper.get_logits_from_input_embeds 42 | 43 | dl = InputXGradient(func) if self.multiply_by_inputs else Saliency(func) 44 | 45 | inputs = self.model_helper.feature_extractor( 46 | [audio_i.squeeze() for audio_i in [audio]], 47 | sampling_rate=self.model_helper.feature_extractor.sampling_rate, 48 | padding=True, 49 | return_tensors="pt", 50 | ) 51 | input_len = inputs["attention_mask"].sum().item() 52 | attr = dl.attribute(inputs.input_values, target=target_class) 53 | attr = attr[0, :input_len].detach().cpu() 54 | 55 | # pool over hidden size 56 | attr = attr.numpy() 57 | return attr 58 | 59 | def compute_explanation( 60 | self, 61 | audio_path: str, 62 | target_class=None, 63 | aggregation: str = "mean", 64 | num_s_split: float = 0.25, 65 | ) -> ExplanationSpeech: 66 | """ 67 | Compute the word-level explanation for the given audio. 68 | Args: 69 | audio_path: path to the audio file 70 | target_class: target class - int - If None, use the predicted class 71 | no_before_span: if True, it also consider the span before the word. This is because we observe gradient give importance also for the frame just before the word 72 | aggregation: aggregation method for the frames of the word. Can be "mean" or "max" 73 | num_s_split: float = number of seconds of each audio segment in which to split the audio, 74 | """ 75 | 76 | if aggregation not in ["mean", "max"]: 77 | raise ValueError( 78 | "Aggregation method not supported, choose between 'mean' and 'max'" 79 | ) 80 | 81 | # Load audio and convert to np.array 82 | audio_as = AudioSegment.from_wav(audio_path) 83 | audio = pydub_to_np(audio_as)[0] 84 | 85 | # Predict logits/probabilities 86 | logits_original = self.model_helper.predict([audio]) 87 | 88 | # Check if single label or multilabel scenario as for FSC 89 | n_labels = self.model_helper.n_labels 90 | 91 | # TODO 92 | if target_class is not None: 93 | targets = target_class 94 | 95 | else: 96 | if n_labels > 1: 97 | # Multilabel scenario as for FSC 98 | targets = [ 99 | int(np.argmax(logits_original[i], axis=1)[0]) 100 | for i in range(n_labels) 101 | ] 102 | else: 103 | targets = [int(np.argmax(logits_original, axis=1)[0])] 104 | 105 | # Compute gradient importance for each target label 106 | # This also handles the multilabel scenario as for FSC 107 | scores = [] 108 | for target_label, target_class in enumerate(targets): 109 | # Get gradient importance for each frame 110 | attr = self._get_gradient_importance_frame_level( 111 | audio, target_class, target_label 112 | ) 113 | 114 | old_start = 0 115 | old_start_ms = 0 116 | features = [] 117 | importances = [] 118 | a, b = 0, 0 # 50, 20 119 | 120 | duration_s = len(audio_as) / 1000 121 | 122 | a, b = 0, 0 123 | for e, i in enumerate(np.arange(0, duration_s, num_s_split)): 124 | start = i 125 | end = min(i + num_s_split, duration_s) 126 | 127 | start_ms = (start * 1000 - a) / 1000 128 | end_ms = (end * 1000 + b) / 1000 129 | 130 | start, end = int( 131 | start_ms * self.model_helper.feature_extractor.sampling_rate 132 | ), int(end_ms * 
self.model_helper.feature_extractor.sampling_rate) 133 | 134 | # Slice of the importance for the given word 135 | segment_importance = attr[start:end] 136 | 137 | # Consider also the spans between words 138 | # #span_before = attr[old_start:start] 139 | 140 | if aggregation == "max": 141 | segment_importance = np.max(segment_importance) 142 | else: 143 | segment_importance = np.mean(segment_importance) 144 | 145 | old_start = end 146 | old_start_ms = end_ms 147 | importances.append(segment_importance) 148 | features.append(e) 149 | 150 | scores.append(np.array(importances)) 151 | 152 | if n_labels > 1: 153 | # Multilabel scenario as for FSC 154 | scores = np.array(scores) 155 | else: 156 | scores = np.array([importances]) 157 | 158 | explanation = ExplanationSpeech( 159 | features=features, 160 | scores=scores, 161 | explainer=self.NAME + "-" + aggregation, 162 | target=targets if n_labels > 1 else targets, 163 | audio_path=audio_path, 164 | ) 165 | 166 | return explanation -------------------------------------------------------------------------------- /ferret/explainers/gradient.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | from captum.attr import InputXGradient, IntegratedGradients, Saliency 6 | from cv2 import multiply 7 | import numpy as np 8 | 9 | from . import BaseExplainer 10 | from .explanation import Explanation 11 | from .utils import parse_explainer_args 12 | 13 | 14 | class GradientExplainer(BaseExplainer): 15 | NAME = "Gradient" 16 | 17 | def __init__( 18 | self, 19 | model, 20 | tokenizer, 21 | model_helper: Optional[str] = None, 22 | multiply_by_inputs: bool = True, 23 | **kwargs, 24 | ): 25 | super().__init__(model, tokenizer, model_helper, **kwargs) 26 | 27 | self.multiply_by_inputs = multiply_by_inputs 28 | if self.multiply_by_inputs: 29 | self.NAME += " (x Input)" 30 | 31 | def compute_feature_importance( 32 | self, 33 | text: Union[str, Tuple[str, str]], 34 | target: Union[int, str] = 1, 35 | target_token: Optional[Union[int, str]] = None, 36 | **kwargs, 37 | ): 38 | def func(input_embeds): 39 | outputs = self.helper.model( 40 | inputs_embeds=input_embeds, attention_mask=item["attention_mask"] 41 | ) 42 | logits = self.helper._postprocess_logits( 43 | outputs.logits, target_token_pos_idx=target_token_pos_idx 44 | ) 45 | return logits 46 | 47 | # Sanity checks 48 | # TODO these checks have already been conducted if used within the benchmark class. Remove them here if possible. 
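# The checks below map the target class (label name or index) to a positional index, resolve the target token for token-classification helpers, and normalize the input sample format.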
49 | target_pos_idx = self.helper._check_target(target) 50 | target_token_pos_idx = self.helper._check_target_token(text, target_token) 51 | text = self.helper._check_sample(text) 52 | 53 | item = self._tokenize(text) 54 | item = {k: v.to(self.device) for k, v in item.items()} 55 | input_len = item["attention_mask"].sum().item() 56 | dl = ( 57 | InputXGradient(func, **self.init_args) 58 | if self.multiply_by_inputs 59 | else Saliency(func, **self.init_args) 60 | ) 61 | 62 | inputs = self.get_input_embeds(text) 63 | 64 | attr = dl.attribute(inputs, target=target_pos_idx, **kwargs) 65 | attr = attr[0, :input_len, :].detach().cpu() 66 | 67 | # pool over hidden size 68 | attr = attr.sum(-1).numpy() 69 | 70 | output = Explanation( 71 | text=text, 72 | tokens=self.get_tokens(text), 73 | scores=attr, 74 | explainer=self.NAME, 75 | helper_type=self.helper.HELPER_TYPE, 76 | target_pos_idx=target_pos_idx, 77 | target_token_pos_idx=target_token_pos_idx, 78 | target=self.helper.model.config.id2label[target_pos_idx], 79 | target_token=self.helper.tokenizer.decode( 80 | item["input_ids"][0, target_token_pos_idx].item() 81 | ) 82 | if self.helper.HELPER_TYPE == "token-classification" 83 | else None, 84 | ) 85 | return output 86 | 87 | 88 | class IntegratedGradientExplainer(BaseExplainer): 89 | NAME = "Integrated Gradient" 90 | 91 | def __init__( 92 | self, 93 | model, 94 | tokenizer, 95 | model_helper: Optional[str] = None, 96 | multiply_by_inputs: bool = True, 97 | **kwargs, 98 | ): 99 | super().__init__(model, tokenizer, model_helper, **kwargs) 100 | 101 | self.multiply_by_inputs = multiply_by_inputs 102 | if self.multiply_by_inputs: 103 | self.NAME += " (x Input)" 104 | 105 | def _generate_baselines(self, input_len): 106 | ids = ( 107 | [self.helper.tokenizer.cls_token_id] 108 | + [self.helper.tokenizer.pad_token_id] * (input_len - 2) 109 | + [self.helper.tokenizer.sep_token_id] 110 | ) 111 | embeddings = self.helper._get_input_embeds_from_ids( 112 | torch.tensor(ids, device=self.device) 113 | ) 114 | return embeddings.unsqueeze(0) 115 | 116 | def compute_feature_importance( 117 | self, 118 | text: Union[str, Tuple[str, str]], 119 | target: Union[int, str] = 1, 120 | target_token: Optional[Union[int, str]] = None, 121 | show_progress: bool = False, 122 | **kwargs, 123 | ): 124 | # Sanity checks 125 | # TODO these checks have already been conducted if used within the benchmark class. Remove them here if possible. 
126 | 127 | target_pos_idx = self.helper._check_target(target) 128 | target_token_pos_idx = self.helper._check_target_token(text, target_token) 129 | text = self.helper._check_sample(text) 130 | 131 | def func(input_embeds): 132 | attention_mask = torch.ones( 133 | *input_embeds.shape[:2], dtype=torch.uint8, device=self.device 134 | ) 135 | _, logits = self.helper._forward_with_input_embeds( 136 | input_embeds, attention_mask, show_progress=show_progress 137 | ) 138 | logits = self.helper._postprocess_logits( 139 | logits, target_token_pos_idx=target_token_pos_idx 140 | ) 141 | return logits 142 | 143 | item = self._tokenize(text) 144 | input_len = item["attention_mask"].sum().item() 145 | dl = IntegratedGradients( 146 | func, multiply_by_inputs=self.multiply_by_inputs, **self.init_args 147 | ) 148 | inputs = self.get_input_embeds(text) 149 | baselines = self._generate_baselines(input_len) 150 | 151 | attr = dl.attribute(inputs, baselines=baselines, target=target_pos_idx, **kwargs) 152 | 153 | attr = attr[0, :input_len, :].detach().cpu() 154 | 155 | # pool over hidden size 156 | attr = attr.sum(-1).numpy() 157 | 158 | # norm_attr = self._normalize_input_attributions(attr.detach()) 159 | output = Explanation( 160 | text=text, 161 | tokens=self.get_tokens(text), 162 | scores=attr, 163 | explainer=self.NAME, 164 | helper_type=self.helper.HELPER_TYPE, 165 | target_pos_idx=target_pos_idx, 166 | target_token_pos_idx=target_token_pos_idx, 167 | target=self.helper.model.config.id2label[target_pos_idx], 168 | target_token=self.helper.tokenizer.decode( 169 | item["input_ids"][0, target_token_pos_idx].item() 170 | ) 171 | if self.helper.HELPER_TYPE == "token-classification" 172 | else None, 173 | ) 174 | return output 175 | -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/utils_removal.py: -------------------------------------------------------------------------------- 1 | from pydub import AudioSegment 2 | import whisperx 3 | import os 4 | from typing import Dict, List, Union, Tuple 5 | 6 | 7 | def remove_specified_words(audio, words, removal_type: str = "nothing"): 8 | """ 9 | Remove a word from audio using pydub, by replacing it with: 10 | - nothing 11 | - silence 12 | - white noise 13 | - pink noise 14 | 15 | Args: 16 | audio (pydub.AudioSegment): audio 17 | word: word to remove with its start and end times 18 | removal_type (str, optional): type of removal. Defaults to "nothing". 
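Note: each word span is widened by a fixed margin (100 ms before the start, 40 ms after the end) before being replaced.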
19 | """ 20 | 21 | from copy import deepcopy 22 | 23 | audio_removed = deepcopy(audio) 24 | 25 | a, b = 100, 40 26 | 27 | from IPython.display import display 28 | 29 | for word in words: 30 | start = int(word["start"] * 1000) 31 | end = int(word["end"] * 1000) 32 | 33 | before_word_audio = audio_removed[: start - a] 34 | after_word_audio = audio_removed[end + b :] 35 | 36 | word_duration = (end - start) + a + b 37 | 38 | if removal_type == "nothing": 39 | replace_word_audio = AudioSegment.empty() 40 | elif removal_type == "silence": 41 | replace_word_audio = AudioSegment.silent(duration=word_duration) 42 | elif removal_type == "white noise": 43 | sound_path = (os.path.join(os.path.dirname(__file__), "white_noise.mp3"),) 44 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 45 | elif removal_type == "pink noise": 46 | sound_path = (os.path.join(os.path.dirname(__file__), "pink_noise.mp3"),) 47 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 48 | 49 | audio_removed = before_word_audio + replace_word_audio + after_word_audio 50 | return audio_removed 51 | 52 | 53 | def transcribe_audio( 54 | audio_path: str, 55 | device: str = "cuda", 56 | batch_size: int = 2, 57 | compute_type: str = "float32", 58 | language: str = "en", 59 | model_name_whisper: str = "large-v2", 60 | ) -> Tuple[str, List[Dict[str, Union[str, float]]]]: 61 | """ 62 | Transcribe audio using whisperx, 63 | and return the text (transcription) and the words with their start and end times. 64 | """ 65 | 66 | ## Load whisperx model 67 | model_whisperx = whisperx.load_model( 68 | model_name_whisper, 69 | device, 70 | compute_type=compute_type, 71 | language=language, 72 | ) 73 | 74 | ## Transcribe audio 75 | audio = whisperx.load_audio(audio_path) 76 | result = model_whisperx.transcribe(audio, batch_size=batch_size) 77 | model_a, metadata = whisperx.load_align_model( 78 | language_code=result["language"], device=device 79 | ) 80 | 81 | ## Align timestamps 82 | result = whisperx.align( 83 | result["segments"], 84 | model_a, 85 | metadata, 86 | audio, 87 | device, 88 | return_char_alignments=False, 89 | ) 90 | 91 | if result is None or "segments" not in result or len(result["segments"]) == 0: 92 | return "", [] 93 | 94 | if len(result["segments"]) == 1: 95 | text = result["segments"][0]["text"] 96 | words = result["segments"][0]["words"] 97 | else: 98 | text = " ".join( 99 | result["segments"][i]["text"] for i in range(len(result["segments"])) 100 | ) 101 | words = [word for segment in result["segments"] for word in segment["words"]] 102 | 103 | # Remove words that are not properly transcribed 104 | words = [word for word in words if "start" in word] 105 | return text, words 106 | 107 | 108 | def transcribe_audio_given_model( 109 | model_whisperx, 110 | audio_path: str, 111 | batch_size: int = 2, 112 | device: str = "cuda", 113 | ) -> Tuple[str, List[Dict[str, Union[str, float]]]]: 114 | """ 115 | Transcribe audio using whisperx, 116 | and return the text (transcription) and the words with their start and end times. 
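An alignment model for the language detected by whisperx is then loaded to attach word-level start and end timestamps to the transcription.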
117 | """ 118 | 119 | ## Transcribe audio 120 | audio = whisperx.load_audio(audio_path) 121 | result = model_whisperx.transcribe(audio, batch_size=batch_size) 122 | model_a, metadata = whisperx.load_align_model( 123 | language_code=result["language"], device=device 124 | ) 125 | 126 | ## Align timestamps 127 | result = whisperx.align( 128 | result["segments"], 129 | model_a, 130 | metadata, 131 | audio, 132 | device, 133 | return_char_alignments=False, 134 | ) 135 | 136 | if result is None or "segments" not in result or len(result["segments"]) == 0: 137 | return "", [] 138 | 139 | if len(result["segments"]) == 1: 140 | text = result["segments"][0]["text"] 141 | words = result["segments"][0]["words"] 142 | else: 143 | text = " ".join( 144 | result["segments"][i]["text"] for i in range(len(result["segments"])) 145 | ) 146 | words = [word for segment in result["segments"] for word in segment["words"]] 147 | 148 | # Remove words that are not properly transcribed 149 | words = [word for word in words if "start" in word] 150 | return text, words 151 | 152 | 153 | def remove_word(audio, word, removal_type: str = "nothing"): 154 | """ 155 | Remove a word from audio using pydub, by replacing it with: 156 | - nothing 157 | - silence 158 | - white noise 159 | - pink noise 160 | 161 | Args: 162 | audio (pydub.AudioSegment): audio 163 | word: word to remove with its start and end times 164 | removal_type (str, optional): type of removal. Defaults to "nothing". 165 | """ 166 | 167 | a, b = 100, 40 168 | 169 | before_word_audio = audio[: word["start"] * 1000 - a] 170 | after_word_audio = audio[word["end"] * 1000 + b :] 171 | word_duration = (word["end"] * 1000 - word["start"] * 1000) + a + b 172 | 173 | if removal_type == "nothing": 174 | replace_word_audio = AudioSegment.empty() 175 | elif removal_type == "silence": 176 | replace_word_audio = AudioSegment.silent(duration=word_duration) 177 | 178 | elif removal_type == "white noise": 179 | sound_path = (os.path.join(os.path.dirname(__file__), "white_noise.mp3"),) 180 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 181 | 182 | # display(audio_removed) 183 | elif removal_type == "pink noise": 184 | sound_path = (os.path.join(os.path.dirname(__file__), "pink_noise.mp3"),) 185 | replace_word_audio = AudioSegment.from_mp3(sound_path)[:word_duration] 186 | 187 | audio_removed = before_word_audio + replace_word_audio + after_word_audio 188 | return audio_removed 189 | -------------------------------------------------------------------------------- /ferret/modeling/speech_model_helpers/model_helper_fsc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, List, Union, Tuple 3 | import torch 4 | from pydub import AudioSegment 5 | from ...speechxai_utils import pydub_to_np 6 | 7 | 8 | class ModelHelperFSC: 9 | """ 10 | Wrapper class to interface with HuggingFace models 11 | """ 12 | 13 | def __init__(self, model, feature_extractor, device, language="en"): 14 | self.model = model 15 | self.feature_extractor = feature_extractor 16 | self.device = device 17 | self.n_labels = 3 # Multi label problem 18 | self.language = language 19 | self.label_name = ["action", "object", "location"] 20 | 21 | # PREDICT SINGLE 22 | def predict( 23 | self, 24 | audios: List[np.ndarray], 25 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 26 | """ 27 | Predicts action, object and location from audio one sample at a time. 28 | Returns probs for each class. 
29 | # This is to fix the bug of padding 30 | """ 31 | 32 | action_probs = np.empty((len(audios), 6)) 33 | object_probs = np.empty((len(audios), 14)) 34 | location_probs = np.empty((len(audios), 4)) 35 | for e, audio in enumerate(audios): 36 | action_probs[e], object_probs[e], location_probs[e] = self._predict([audio]) 37 | return action_probs, object_probs, location_probs 38 | 39 | def predict_action( 40 | self, 41 | audios: List[np.ndarray], 42 | ): 43 | action_probs = np.empty((len(audios), 6)) 44 | for e, audio in enumerate(audios): 45 | action_probs[e], _, _ = self._predict([audio]) 46 | return action_probs 47 | 48 | def predict_object( 49 | self, 50 | audios: List[np.ndarray], 51 | ): 52 | object_probs = np.empty((len(audios), 14)) 53 | for e, audio in enumerate(audios): 54 | _, object_probs[e], _ = self._predict([audio]) 55 | return object_probs 56 | 57 | def predict_location( 58 | self, 59 | audios: List[np.ndarray], 60 | ): 61 | location_probs = np.empty((len(audios), 4)) 62 | for e, audio in enumerate(audios): 63 | _, _, location_probs[e] = self._predict([audio]) 64 | return location_probs 65 | 66 | def get_prediction_function_by_label(self, label): 67 | if label == 0: 68 | return self.predict_action 69 | elif label == 1: 70 | return self.predict_object 71 | elif label == 2: 72 | return self.predict_location 73 | else: 74 | raise ValueError("label should be 0, 1 or 2") 75 | 76 | def _predict( 77 | self, 78 | audios: List[np.ndarray], 79 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 80 | """ 81 | Predicts action, object and location from audio. 82 | Returns probs for each class. 83 | """ 84 | 85 | ## Extract features 86 | inputs = self.feature_extractor( 87 | [audio.squeeze() for audio in audios], 88 | sampling_rate=self.feature_extractor.sampling_rate, 89 | padding=True, 90 | return_tensors="pt", 91 | ) 92 | 93 | ## Predict logits 94 | with torch.no_grad(): 95 | logits = ( 96 | self.model(inputs.input_values.to(self.device)) 97 | .logits.detach() 98 | .cpu() 99 | # .numpy() 100 | ) 101 | action_logits = logits[:, :6] 102 | object_logits = logits[:, 6:20] 103 | location_logits = logits[:, 20:24] 104 | 105 | return ( 106 | action_logits.softmax(-1).numpy(), 107 | object_logits.softmax(-1).numpy(), 108 | location_logits.softmax(-1).numpy(), 109 | ) 110 | 111 | def get_logits_action(self, input_embeds): 112 | logits = self.model(input_embeds.to(self.device)).logits 113 | logits = logits[:, :6] 114 | return logits 115 | 116 | def get_logits_object(self, input_embeds): 117 | logits = self.model(input_embeds.to(self.device)).logits 118 | logits = logits[:, 6:20] 119 | return logits 120 | 121 | def get_logits_location(self, input_embeds): 122 | logits = self.model(input_embeds.to(self.device)).logits 123 | logits = logits[:, 20:24] 124 | return logits 125 | 126 | def get_logits_function_from_input_embeds_by_label(self, label): 127 | if label == 0: 128 | return self.get_logits_action 129 | elif label == 1: 130 | return self.get_logits_object 131 | elif label == 2: 132 | return self.get_logits_location 133 | else: 134 | raise ValueError("label should be 0, 1 or 2") 135 | 136 | def get_text_labels(self, targets) -> Tuple[str, str, str]: 137 | action_ind, object_ind, location_ind = targets 138 | return ( 139 | self.model.config.id2label[action_ind], 140 | self.model.config.id2label[object_ind + 6], 141 | self.model.config.id2label[location_ind + 20], 142 | ) 143 | 144 | def get_text_labels_with_class(self, targets) -> Tuple[str, str, str]: 145 | """ 146 | Return the text labels with the 
class name as strings (e.g., ['action = increase', 'object = lights', 'location = kitchen']]) 147 | """ 148 | text_targets = self.get_text_labels(targets) 149 | label_and_target_class_names = [ 150 | f"{label}={target_class_name}" 151 | for label, target_class_name in zip(self.label_name, text_targets) 152 | ] 153 | return label_and_target_class_names 154 | 155 | def get_predicted_classes(self, audio_path=None, audio=None): 156 | if audio is None and audio_path is None: 157 | raise ValueError("Specify the audio path or the audio as a numpy array") 158 | 159 | if audio is None: 160 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 161 | 162 | logits_action, logits_object, logits_location = self.predict([audio]) 163 | action_ind = np.argmax(logits_action, axis=1)[0] 164 | object_ind = np.argmax(logits_object, axis=1)[0] 165 | location_ind = np.argmax(logits_location, axis=1)[0] 166 | return action_ind, object_ind, location_ind 167 | 168 | def get_predicted_probs(self, audio_path=None, audio=None): 169 | if audio is None and audio_path is None: 170 | raise ValueError("Specify the audio path or the audio as a numpy array") 171 | 172 | if audio is None: 173 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 174 | 175 | logits_action, logits_object, logits_location = self.predict([audio]) 176 | action_ind = np.argmax(logits_action, axis=1)[0] 177 | object_ind = np.argmax(logits_object, axis=1)[0] 178 | location_ind = np.argmax(logits_location, axis=1)[0] 179 | 180 | action_gt = logits_action[:, action_ind][0] 181 | object_gt = logits_object[:, object_ind][0] 182 | location_gt = logits_location[:, location_ind][0] 183 | return action_gt, object_gt, location_gt -------------------------------------------------------------------------------- /ferret/explainers/explanation_speech/gradient_speech_explainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pydub import AudioSegment 3 | from captum.attr import Saliency, InputXGradient 4 | import numpy as np 5 | import torch 6 | from .explanation_speech import ExplanationSpeech 7 | from ...speechxai_utils import pydub_to_np 8 | # TODO - include in utils 9 | from .loo_speech_explainer import transcribe_audio 10 | 11 | 12 | class GradientSpeechExplainer: 13 | NAME = "Gradient" 14 | 15 | def __init__(self, model_helper, multiply_by_inputs: bool = False): 16 | self.model_helper = model_helper 17 | self.multiply_by_inputs = multiply_by_inputs 18 | 19 | if self.multiply_by_inputs: 20 | self.NAME += " (x Input)" 21 | 22 | def _get_gradient_importance_frame_level( 23 | self, audio, target_class, target_label=None 24 | ): 25 | """ 26 | Compute the gradient importance for each frame of the audio w.r.t. the target class. 
27 | Args: 28 | audio: audio - np.array 29 | target_class: target class - int 30 | target_label: target label - int - Used only in a multilabel scenario as for FSC 31 | """ 32 | torch.set_grad_enabled(True) # Context-manager 33 | 34 | # Function which returns the logits 35 | if self.model_helper.n_labels > 1: 36 | # We get the logits for the given label 37 | func = self.model_helper.get_logits_function_from_input_embeds_by_label( 38 | target_label 39 | ) 40 | else: 41 | func = self.model_helper.get_logits_from_input_embeds 42 | 43 | dl = InputXGradient(func) if self.multiply_by_inputs else Saliency(func) 44 | 45 | inputs = self.model_helper.feature_extractor( 46 | [audio_i.squeeze() for audio_i in [audio]], 47 | sampling_rate=self.model_helper.feature_extractor.sampling_rate, 48 | padding=True, 49 | return_tensors="pt", 50 | ) 51 | input_len = inputs["attention_mask"].sum().item() 52 | attr = dl.attribute(inputs.input_values, target=target_class) 53 | attr = attr[0, :input_len].detach().cpu() 54 | 55 | # pool over hidden size 56 | attr = attr.numpy() 57 | return attr 58 | 59 | def compute_explanation( 60 | self, 61 | audio_path: str, 62 | target_class=None, 63 | words_trascript: List = None, 64 | no_before_span: bool = True, 65 | aggregation: str = "mean", 66 | ) -> ExplanationSpeech: 67 | """ 68 | Compute the word-level explanation for the given audio. 69 | Args: 70 | audio_path: path to the audio file 71 | target_class: target class - int - If None, use the predicted class 72 | no_before_span: if True, it also consider the span before the word. This is because we observe gradient give importance also for the frame just before the word 73 | aggregation: aggregation method for the frames of the word. Can be "mean" or "max" 74 | """ 75 | 76 | if aggregation not in ["mean", "max"]: 77 | raise ValueError( 78 | "Aggregation method not supported, choose between 'mean' and 'max'" 79 | ) 80 | 81 | # Load audio and convert to np.array 82 | audio = pydub_to_np(AudioSegment.from_wav(audio_path))[0] 83 | 84 | # Predict logits/probabilities 85 | logits_original = self.model_helper.predict([audio]) 86 | 87 | # Check if single label or multilabel scenario as for FSC 88 | n_labels = self.model_helper.n_labels 89 | 90 | # TODO 91 | if target_class is not None: 92 | targets = target_class 93 | 94 | else: 95 | if n_labels > 1: 96 | # Multilabel scenario as for FSC 97 | targets = [ 98 | int(np.argmax(logits_original[i], axis=1)[0]) 99 | for i in range(n_labels) 100 | ] 101 | else: 102 | targets = [int(np.argmax(logits_original, axis=1)[0])] 103 | 104 | if words_trascript is None: 105 | # Transcribe audio 106 | _, words_trascript = transcribe_audio( 107 | audio_path=audio_path, language=self.model_helper.language 108 | ) 109 | 110 | # Compute gradient importance for each target label 111 | # This also handles the multilabel scenario as for FSC 112 | scores = [] 113 | for target_label, target_class in enumerate(targets): 114 | # Get gradient importance for each frame 115 | attr = self._get_gradient_importance_frame_level( 116 | audio, target_class, target_label 117 | ) 118 | 119 | old_start = 0 120 | old_start_ms = 0 121 | features = [] 122 | importances = [] 123 | a, b = 0, 0 # 50, 20 124 | 125 | for word in words_trascript: 126 | if no_before_span: 127 | # We directly consider the transcribed word 128 | start_ms = (word["start"] * 1000 - a) / 1000 129 | end_ms = (word["end"] * 1000 + b) / 1000 130 | 131 | else: 132 | # We also include the frames before the word 133 | start_ms = old_start_ms 134 | end_ms = 
(word["end"] * 1000) / 1000 135 | 136 | start, end = int( 137 | start_ms * self.model_helper.feature_extractor.sampling_rate 138 | ), int(end_ms * self.model_helper.feature_extractor.sampling_rate) 139 | 140 | # Slice of the importance for the given word 141 | word_importance = attr[start:end] 142 | 143 | # Consider also the spans between words 144 | # #span_before = attr[old_start:start] 145 | 146 | if aggregation == "max": 147 | word_importance = np.max(word_importance) 148 | else: 149 | word_importance = np.mean(word_importance) 150 | 151 | old_start = end 152 | old_start_ms = end_ms 153 | importances.append(word_importance) 154 | features.append(word["word"]) 155 | 156 | # Consider also the spans between words 157 | # importances.append(np.mean(span_before)) 158 | # features.append('-') 159 | 160 | # Consider also the spans between words 161 | # Final span 162 | # final_span = attr[old_start:len(audio_np)] 163 | # features.append('-') 164 | 165 | # if aggregation == "max": 166 | # importances.append(np.max(final_span)) 167 | # else: 168 | # importances.append(np.mean(final_span)) 169 | scores.append(np.array(importances)) 170 | 171 | if n_labels > 1: 172 | # Multilabel scenario as for FSC 173 | scores = np.array(scores) 174 | else: 175 | scores = np.array([importances]) 176 | 177 | features = [word["word"] for word in words_trascript] 178 | 179 | explanation = ExplanationSpeech( 180 | features=features, 181 | scores=scores, 182 | explainer=self.NAME + "-" + aggregation, 183 | target=targets if n_labels > 1 else targets, 184 | audio_path=audio_path, 185 | ) 186 | 187 | return explanation -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # ferret documentation build configuration file, created by 4 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another 16 | # directory, add these directories to sys.path here. If the directory is 17 | # relative to the documentation root, use os.path.abspath to make it 18 | # absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath("..")) 24 | autodoc_mock_imports = ["_tkinter"] 25 | 26 | import ferret 27 | 28 | # -- General configuration --------------------------------------------- 29 | 30 | # If your documentation needs a minimal Sphinx version, state it here. 31 | # 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 36 | extensions = [ 37 | "sphinx.ext.autodoc", 38 | "sphinx.ext.autosummary", 39 | "sphinx.ext.viewcode", 40 | "sphinx.ext.napoleon", 41 | "sphinx_copybutton", 42 | "sphinx_toggleprompt", 43 | "sphinx_favicon", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 
50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = ".rst" 54 | 55 | # The master toctree document. 56 | master_doc = "index" 57 | 58 | # General information about the project. 59 | project = "ferret" 60 | copyright = "2022, Giuseppe Attanasio" 61 | author = "Giuseppe Attanasio" 62 | 63 | # The version info for the project you're documenting, acts as replacement 64 | # for |version| and |release|, also used in various other places throughout 65 | # the built documents. 66 | # 67 | # The short X.Y version. 68 | version = ferret.__version__ 69 | # The full version, including alpha/beta/rc tags. 70 | release = ferret.__version__ 71 | 72 | # The language for content autogenerated by Sphinx. Refer to documentation 73 | # for a list of supported languages. 74 | # 75 | # This is also used if you do content translation via gettext catalogs. 76 | # Usually you set "language" from the command line for these cases. 77 | language = "en" 78 | 79 | # List of patterns, relative to source directory, that match files and 80 | # directories to ignore when looking for source files. 81 | # This patterns also effect to html_static_path and html_extra_path 82 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 83 | 84 | # The name of the Pygments (syntax highlighting) style to use. 85 | pygments_style = "sphinx" 86 | 87 | # If true, `todo` and `todoList` produce output, else they produce nothing. 88 | todo_include_todos = False 89 | 90 | 91 | # -- Options for HTML output ------------------------------------------- 92 | 93 | # The theme to use for HTML and HTML Help pages. See the documentation for 94 | # a list of builtin themes. 95 | # 96 | # html_theme = "sphinx_rtd_theme" 97 | # html_theme = "furo" 98 | html_theme = "pydata_sphinx_theme" 99 | html_title = "ferret" 100 | 101 | # Theme options are theme-specific and customize the look and feel of a 102 | # theme further. For a list of options available for each theme, see the 103 | # documentation. 
104 | # 105 | # switcher_version = version 106 | # if ".dev" in version: 107 | # switcher_version = "dev" 108 | # elif "rc" in version: 109 | # switcher_version = version.split("rc", maxsplit=1)[0] + " (rc)" 110 | 111 | html_theme_options = { 112 | # "navbar_start": ["navbar-logo"], 113 | "navbar_center": ["navbar-nav"], 114 | "navbar_end": ["version-switcher", "theme-switcher", "navbar-icon-links"], 115 | "navbar_persistent": ["search-button"], 116 | "navbar_align": "content", 117 | "secondary_sidebar_items": ["page-toc", "edit-this-page", "sourcelink"], 118 | "show_prev_next": True, 119 | "footer_items": ["copyright", "sphinx-version", "theme-version"], 120 | "github_url": "https://github.com/g8a9/ferret", 121 | "switcher": { 122 | "json_url": "versions.json", 123 | "version_match": version, 124 | }, 125 | "external_links": [ 126 | {"name": "Demo", "url": "https://huggingface.co/spaces/g8a9/ferret"}, 127 | ], 128 | "favicons": [ 129 | { 130 | "rel": "icon", 131 | "sizes": "48x48", 132 | "href": "favicon.ico", 133 | }, 134 | { 135 | "rel": "icon", 136 | "sizes": "16x16", 137 | "href": "favicon-16x16.png", 138 | }, 139 | { 140 | "rel": "icon", 141 | "sizes": "32x32", 142 | "href": "favicon-32x32.png", 143 | }, 144 | { 145 | "rel": "apple-touch-icon", 146 | "sizes": "180x180", 147 | "href": "apple-touch-icon.png", 148 | }, 149 | { 150 | "rel": "android-chrome-192x192", 151 | "sizes": "192x192", 152 | "href": "android-chrome-192x192.png", 153 | }, 154 | ] 155 | # "logo": { 156 | # "image_light": "logo.png", 157 | # "image_dark": "logo.png", 158 | # }, 159 | } 160 | 161 | # html_sidebars = {"**": ["sidebar-nav-bs", "sidebar-ethical-ads"]} 162 | # html_logo = "logo.png" 163 | 164 | 165 | # Add any paths that contain custom static files (such as style sheets) here, 166 | # relative to this directory. They are copied after the builtin static files, 167 | # so a file named "default.css" will overwrite the builtin "default.css". 168 | html_static_path = ["_static"] 169 | 170 | 171 | # -- Options for HTMLHelp output --------------------------------------- 172 | 173 | # Output file base name for HTML help builder. 174 | htmlhelp_basename = "ferretdoc" 175 | 176 | 177 | # -- Options for LaTeX output ------------------------------------------ 178 | 179 | latex_elements = { 180 | # The paper size ('letterpaper' or 'a4paper'). 181 | # 182 | # 'papersize': 'letterpaper', 183 | # The font size ('10pt', '11pt' or '12pt'). 184 | # 185 | # 'pointsize': '10pt', 186 | # Additional stuff for the LaTeX preamble. 187 | # 188 | # 'preamble': '', 189 | # Latex figure (float) alignment 190 | # 191 | # 'figure_align': 'htbp', 192 | } 193 | 194 | # Grouping the document tree into LaTeX files. List of tuples 195 | # (source start file, target name, title, author, documentclass 196 | # [howto, manual, or own class]). 197 | latex_documents = [ 198 | ( 199 | master_doc, 200 | "ferret.tex", 201 | "ferret Documentation", 202 | "Giuseppe Attanasio", 203 | "manual", 204 | ), 205 | ] 206 | 207 | 208 | # -- Options for manual page output ------------------------------------ 209 | 210 | # One entry per manual page. List of tuples 211 | # (source start file, name, description, authors, manual section). 212 | man_pages = [(master_doc, "ferret", "ferret Documentation", [author], 1)] 213 | 214 | 215 | # -- Options for Texinfo output ---------------------------------------- 216 | 217 | # Grouping the document tree into Texinfo files. 
List of tuples 218 | # (source start file, target name, title, author, 219 | # dir menu entry, description, category) 220 | texinfo_documents = [ 221 | ( 222 | master_doc, 223 | "ferret", 224 | "ferret Documentation", 225 | author, 226 | "ferret", 227 | "One line description of project.", 228 | "Miscellaneous", 229 | ), 230 | ] 231 | -------------------------------------------------------------------------------- /docs/source/user_guide/notions.benchmarking.rst: -------------------------------------------------------------------------------- 1 | .. _notions.benchmarking: 2 | 3 | *********************** 4 | Evaluating Explanations 5 | *********************** 6 | 7 | Benchmarking Metrics 8 | ======================= 9 | 10 | We evaluate explanations on the faithfulness and plausibility properties. Specifically, *ferret* implements three state-of-the-art metrics to measure faithfulness and three for plausibility [1]_ [2]_. 11 | 12 | .. [1] Towards Faithfully Interpretable NLP Systems: How Should We Define and Evaluate Faithfulness? (Jacovi & Goldberg, ACL 2020) 13 | .. [2] ERASER: A Benchmark to Evaluate Rationalized NLP Models (DeYoung et al., ACL 2020) 14 | 15 | 16 | .. _explanations-type: 17 | 18 | Type of explanations 19 | ======================= 20 | 21 | Before describing the faithfulness and plausibility metrics, we first define the types of explanations we handle: continuous score explanations, discrete explanations and human rationale. 22 | 23 | .. glossary:: 24 | 25 | Continuous score explanations 26 | Continuous score explanations assign a continuous score to each token. All the post-hoc feature attribution methods of ferret generate continuous score explanations. 27 | Continuous score explanations are also called soft scores or continuous token attribution scores. 28 | 29 | Discrete explanations 30 | Discrete explanations or rationale indicates the set of tokens supporting the prediction. 31 | 32 | Human rationales 33 | Human rationales are annotations highlighting the most relevant words (phrases, or sentences) a human annotator attributed to a given class label. 34 | Typically, human rationales are discrete explanations, indicating the set of words relevant for a human. 35 | 36 | 37 | ======================= 38 | .. _faithfulness-overview: 39 | 40 | Faithfulness measures 41 | ======================= 42 | Faithfulness evaluates how accurately the explanation reflects the inner working of the model (Jacovi and Goldberg, 2020). 43 | 44 | ferret offers the following measures of faithfulness: 45 | 46 | - :ref:`AOPC Comprehensiveness ` - (aopc_compr, ↑) - goes from 0 to 1 (best) 47 | - :ref:`AOPC Sufficiency ` - (aopc_suff, ↓)) - goes from 0 (best) to 1; 48 | - :ref:`Kendall's Tau correlation with Leave-One-Out token removal ` - (taucorr_loo, ↑) - goes from -1 to 1 (best). 49 | 50 | 51 | 52 | .. _faithfulness-aopc_compr: 53 | 54 | AOPC Comprehensiveness 55 | --------------------------- 56 | 57 | 58 | Comprehensiveness evaluates whether the explanation captures the tokens the model used to make the prediction. 59 | 60 | Given a set of relevant token that defines a discrete explanation, comprehensiveness measures the drop in the model probability if the relevant tokens are removed. 61 | A high value of comprehensiveness indicates that the tokens in rj are relevant for the prediction. 62 | 63 | 64 | More formally, let :math:`x` be a sentence and let :math:`f_j` be the prediction probability of the model :math:`f` for a target class :math:`j`. 
65 | Let :math:`r_j` be a discrete explanation indicating the set of tokens supporting the prediction :math:`f_j`.
66 | Comprehensiveness is defined as
67 | 
68 | .. math::
69 |     \textnormal{comprehensiveness} = f_j(x) - f_j(x \setminus r_j)
70 | 
71 | where :math:`x \setminus r_j` is the sentence :math:`x` where the tokens in :math:`r_j` are removed.
72 | 
73 | The higher the value, the more the explainer is able to select the relevant tokens for the prediction.
74 | 
75 | While comprehensiveness is defined for discrete explanations, ferret explanations assign a continuous score to each token. The selection of the most important tokens from continuous score explanations impacts the results.
76 | Hence, following DeYoung et al. (2020), we measure comprehensiveness via the Area Over the Perturbation Curve.
77 | First, we filter out tokens with a negative contribution (i.e., they pull the prediction away from the chosen label).
78 | Then, we progressively consider the *k* most important tokens, with *k* ranging from 10% to 100% (step of 10%). Finally, we average the result.
79 | 
80 | See `DeYoung et al. (2020) <https://doi.org/10.18653/v1/2020.acl-main.408>`_ for its detailed definition.
81 | 
82 | 
83 | .. _faithfulness-aopc_suff:
84 | 
85 | AOPC Sufficiency
86 | ---------------------------
87 | 
88 | Sufficiency captures whether the tokens in the explanation are sufficient for the model to make the prediction.
89 | 
90 | 
91 | Let :math:`x` be a sentence and let :math:`f_j` be the prediction probability of the model :math:`f` for a target class :math:`j`.
92 | Let :math:`r_j` be a discrete explanation indicating the set of tokens supporting the prediction :math:`f_j`.
93 | Sufficiency is defined as
94 | 
95 | .. math::
96 |     \textnormal{sufficiency} = f_j(x) - f_j(r_j)
97 | 
98 | where :math:`f_j(r_j)` is the prediction probability when the sentence :math:`x` is reduced to the tokens in :math:`r_j`.
99 | 
100 | It goes from 0 (best) to 1.
101 | A low score indicates that the tokens in the discrete explanation :math:`r_j` are indeed the ones driving the prediction.
102 | 
103 | 
104 | As for comprehensiveness, we compute the Area Over the Perturbation Curve by varying the number of relevant tokens in :math:`r_j`.
105 | Specifically, we first filter out tokens with a negative contribution for the chosen target class.
106 | Then, we compute sufficiency varying the *k* most important tokens in :math:`r_j` (by default from 10% to 100%, with a step of 10%) and we average the result.
107 | 
108 | 
109 | See `DeYoung et al. (2020) <https://doi.org/10.18653/v1/2020.acl-main.408>`_ for its detailed definition.
110 | 
111 | 
112 | .. _faithfulness-taucorr_loo:
113 | 
114 | Correlation with Leave-One-Out scores
115 | ------------------------------------------------------
116 | The correlation with Leave-One-Out (taucorr_loo) measures the correlation between the explanation and a baseline explanation referred to as the leave-one-out scores.
117 | The leave-one-out (LOO) scores are the prediction differences when one feature at a time is omitted.
118 | taucorr_loo measures the Kendall's tau correlation between the explanation and the leave-one-out scores.
119 | 
120 | It goes from -1 to 1; a value closer to 1 means higher faithfulness to LOO.
121 | 
122 | 
123 | See `Jain and Wallace (2019) <https://aclanthology.org/N19-1357/>`_ for its detailed definition.
124 | 
125 | 
126 | 
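To make the AOPC-style computation described above concrete, the following is a minimal, illustrative sketch of how comprehensiveness and sufficiency can be averaged over the top-*k* tokens. It is **not** ferret's implementation: the ``predict_proba`` callable (returning class probabilities for a list of tokens) and the helper name are assumptions made for this example only.

.. code-block:: python

    import numpy as np

    def aopc_scores(tokens, scores, predict_proba, target,
                    fractions=np.arange(0.1, 1.1, 0.1)):
        """Illustrative AOPC comprehensiveness and sufficiency (sketch only)."""
        full_prob = predict_proba(tokens)[target]
        # Rank tokens by importance and keep only those with a positive contribution.
        ranked = [i for i in np.argsort(scores)[::-1] if scores[i] > 0]
        if not ranked:
            return 0.0, 0.0
        compr, suff = [], []
        for frac in fractions:
            k = max(1, int(round(frac * len(ranked))))
            top_k = set(ranked[:k])
            without_rationale = [t for i, t in enumerate(tokens) if i not in top_k]
            only_rationale = [t for i, t in enumerate(tokens) if i in top_k]
            # Probability drop when the rationale is removed (higher is better).
            compr.append(full_prob - predict_proba(without_rationale)[target])
            # Probability drop when only the rationale is kept (lower is better).
            suff.append(full_prob - predict_proba(only_rationale)[target])
        return float(np.mean(compr)), float(np.mean(suff))

ferret's ``aopc_compr`` and ``aopc_suff`` evaluators follow the same top-*k* sweep and averaging idea.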
127 | .. _plausibility:
128 | 
129 | Plausibility measures
130 | =======================
131 | 
132 | Plausibility evaluates how well the explanation agrees with human rationales.
133 | 
134 | 
135 | - Token Intersection Over Union (hard score) - (token_iou_plau, ↑)
136 | - Token F1 (hard score) - (token_f1_plau, ↑)
137 | - Area-Under-Precision-Recall-Curve - (auprc_plau, ↑)
138 | 
139 | 
140 | .. _plausibility-token_iou_plau:
141 | 
142 | Intersection-Over-Union (IOU)
143 | ------------------------------------------------------
144 | 
145 | Given a human rationale and a discrete explanation, the Intersection-Over-Union (IOU) is the size of the overlap of the tokens they cover divided by the size of their union.
146 | 
147 | We derive the discrete rationale from continuous score explanations by taking the top-K values with positive contribution.
148 | 
149 | When the set of human rationales for the dataset is available, K is set as the average rationale length (as in ERASER).
150 | Otherwise, K is set as default to 5.
151 | 
152 | See `DeYoung et al. (2020) <https://doi.org/10.18653/v1/2020.acl-main.408>`_ for its detailed definition.
153 | 
154 | 
155 | .. _plausibility-token_f1_plau:
156 | 
157 | Token-level F1 score
158 | ------------------------------------------------------
159 | 
160 | The token-level F1 score (↑) is the F1 score computed from the precision and recall at the token level, considering the human rationale as the ground-truth explanation and the discrete explanation as the predicted one.
161 | 
162 | As for the IOU, we derive the discrete rationale from explanations by taking the top-K values with positive contribution.
163 | 
164 | When the set of human rationales for the dataset is available, K is set as the average rationale length.
165 | Otherwise, K is set as default to 5.
166 | 
167 | See `DeYoung et al. (2020) <https://doi.org/10.18653/v1/2020.acl-main.408>`_ for its detailed definition.
168 | 
169 | 
170 | .. _plausibility-auprc_plau:
171 | 
172 | Area Under the Precision Recall curve (AUPRC)
173 | ------------------------------------------------------
174 | The Area Under the Precision Recall curve (AUPRC) is computed by varying a threshold over token importance scores, using the human rationale as ground truth.
175 | 
176 | The advantage of AUPRC with respect to the IOU and Token-level F1 is that it directly considers continuous score explanations.
177 | Hence, it takes into account tokens’ relative ranking and degree of importance.
178 | 
179 | See `DeYoung et al. (2020) <https://doi.org/10.18653/v1/2020.acl-main.408>`_ for its detailed definition.
180 | 
--------------------------------------------------------------------------------
/ferret/visualization.py:
--------------------------------------------------------------------------------
1 | from collections import Counter, defaultdict
2 | from typing import Dict, List, Optional
3 | 
4 | import pandas as pd
5 | import seaborn as sns
6 | from matplotlib.colors import LinearSegmentedColormap
7 | 
8 | from .evaluators.evaluation import ExplanationEvaluation
9 | from .explainers.explanation import Explanation
10 | from .evaluators import EvaluationMetricFamily
11 | 
12 | 
13 | def get_colormap(format):
14 | 
15 |     if format == "blue_red":
16 |         return sns.diverging_palette(240, 10, as_cmap=True)
17 |     elif format == "white_purple":
18 |         return sns.light_palette("purple", as_cmap=True)
19 |     elif format == "purple_white":
20 |         return sns.light_palette("purple", as_cmap=True, reverse=True)
21 |     elif format == "white_purple_white":
22 |         colors = ["white", "purple", "white"]
23 |         return LinearSegmentedColormap.from_list("diverging_white_purple", colors)
24 |     else:
25 |         raise ValueError(f"Unknown format {format}")
26 | 
27 | 
28 | def get_dataframe(explanations: List[Explanation]) -> pd.DataFrame:
29 |     """Convert explanations into a pandas DataFrame.
30 | 31 | Args: 32 | explanations (List[Explanation]): list of explanations 33 | 34 | Returns: 35 | pd.DataFrame: explanations in table format. The columns are the tokens and the rows are the explanation scores, one for each explainer. 36 | """ 37 | scores = {e.explainer: e.scores for e in explanations} 38 | scores["Token"] = explanations[0].tokens 39 | table = pd.DataFrame(scores).set_index("Token").T 40 | return table 41 | 42 | 43 | def deduplicate_column_names(df): 44 | # Create a copy of the DataFrame to avoid modifying the original 45 | df_copy = df.copy() 46 | 47 | column_counts = Counter(df_copy.columns) 48 | 49 | new_columns = list() 50 | seen_names = defaultdict(int) 51 | for column in df_copy.columns: 52 | 53 | count = column_counts[column] 54 | if count > 1: 55 | new_columns.append(f"{column}_{seen_names[column]}") 56 | seen_names[column] += 1 57 | else: 58 | new_columns.append(column) 59 | 60 | df_copy.columns = new_columns 61 | return df_copy 62 | 63 | 64 | def style_heatmap(df: pd.DataFrame, subsets_info: List[Dict]): 65 | """Style a pandas DataFrame as a heatmap. 66 | 67 | Args: 68 | df (pd.DataFrame): a pandas DataFrame 69 | subsets_info (List[Dict]): a list of dictionaries containing the style information for each subset of the DataFrame. Each dictionary should contain the following keys: vmin, vmax, cmap, axis, subset. See https://pandas.pydata.org/pandas-docs/stable/user_guide/style.html#Building-Styles for more information. 70 | 71 | Returns: 72 | pd.io.formats.style.Styler: a styled pandas DataFrame 73 | """ 74 | 75 | style = df.style 76 | for si in subsets_info: 77 | style = style.background_gradient(**si) 78 | 79 | # Set stick index 80 | style = style.set_sticky(axis="index") 81 | 82 | return style.format("{:.2f}") 83 | 84 | 85 | def show_table( 86 | explanations: List[Explanation], remove_first_last: bool, style: str, **style_kwargs 87 | ): 88 | """Format explanation scores into a colored table. 89 | 90 | Args: 91 | explanations (List[Explanation]): list of explanations 92 | apply_style (bool): apply color to the table of explanation scores 93 | remove_first_last (bool): do not visualize the first and last tokens, typically CLS and EOS tokens 94 | 95 | Returns: 96 | pd.DataFrame: a colored (styled) pandas dataframed 97 | """ 98 | 99 | # Get scores as a pandas DataFrame 100 | table = get_dataframe(explanations) 101 | 102 | if remove_first_last: 103 | table = table.iloc[:, 1:-1] 104 | 105 | # add count as prefix for duplicated tokens 106 | table = deduplicate_column_names(table) 107 | if not style: 108 | return table.style.format("{:.2f}") 109 | 110 | if style == "heatmap": 111 | subset_info = { 112 | "vmin": style_kwargs.get("vmin", -1), 113 | "vmax": style_kwargs.get("vmax", 1), 114 | "cmap": style_kwargs.get("cmap", get_colormap("blue_red")), 115 | "axis": None, 116 | "subset": None, 117 | } 118 | return style_heatmap(table, [subset_info]) 119 | else: 120 | raise ValueError(f"Style {style} is not supported.") 121 | 122 | 123 | def show_evaluation_table( 124 | explanation_evaluations: List[ExplanationEvaluation], 125 | style: Optional[str], 126 | ) -> pd.DataFrame: 127 | """Format evaluation scores into a colored table. 
128 | 129 | Args: 130 | explanation_evaluations (List[ExplanationEvaluation]): a list of evaluations of explanations 131 | apply_style (bool): color the table of evaluation scores 132 | 133 | Returns: 134 | pd.DataFrame: a colored (styled) pandas dataframe of evaluation scores 135 | """ 136 | 137 | # Flatten to a tabular format: explainers x evaluation metrics 138 | flat = list() 139 | for evaluation in explanation_evaluations: 140 | d = dict() 141 | d["Explainer"] = evaluation.explanation.explainer 142 | 143 | for metric_output in evaluation.evaluation_outputs: 144 | 145 | d[metric_output.metric.SHORT_NAME] = metric_output.value 146 | flat.append(d) 147 | 148 | table = pd.DataFrame(flat).set_index("Explainer") 149 | 150 | if not style: 151 | return table.format("{:.2f}") 152 | 153 | if style == "heatmap": 154 | 155 | subsets_info = list() 156 | 157 | # TODO: we use here the first explainer evaluation, assuming every evaluation in the list will have the same list of outputs 158 | outputs = explanation_evaluations[0].evaluation_outputs 159 | for output in outputs: 160 | 161 | vmin = output.metric.MIN_VALUE 162 | vmax = output.metric.MAX_VALUE 163 | best_value = ( 164 | # The BEST_VALUE attribute is defined for the FAITHFULNESS 165 | # metric family only. 166 | output.metric.BEST_VALUE if output.metric.METRIC_FAMILY == EvaluationMetricFamily.FAITHFULNESS 167 | else None 168 | ) 169 | 170 | if vmin == -1 and vmax == 1 and best_value == 0: 171 | cmap = get_colormap("white_purple_white") 172 | else: 173 | cmap = ( 174 | get_colormap("purple_white") 175 | if output.metric.LOWER_IS_BETTER 176 | else get_colormap("white_purple") 177 | ) 178 | 179 | subsets_info.append( 180 | dict( 181 | vmin=vmin, 182 | vmax=vmax, 183 | cmap=cmap, 184 | axis=1, 185 | subset=[output.metric.SHORT_NAME], 186 | ) 187 | ) 188 | 189 | style = style_heatmap(table, subsets_info) 190 | return style 191 | else: 192 | raise ValueError(f"Style {style} is not supported.") 193 | 194 | 195 | # def _style_evaluation(table: pd.DataFrame) -> pd.DataFrame: 196 | 197 | # """Apply style to evaluation scores. 
198 | 199 | # Args: 200 | # table (pd.DataFrame): the evaluation scores as pandas DataFrame 201 | 202 | # Returns: 203 | # pd.io.formats.style.Styler: a colored and styled pandas dataframe of evaluation scores 204 | # """ 205 | 206 | # table_style = table.style.background_gradient( 207 | # axis=1, cmap=SCORES_PALETTE, vmin=-1, vmax=1 208 | # ) 209 | 210 | # show_higher_cols, show_lower_cols = list(), list() 211 | 212 | # # Color differently the evaluation measures for which "high score is better" or "low score is better" 213 | # # Darker colors mean better performance 214 | # for evaluation_measure in self.evaluators + self.class_based_evaluators: 215 | # if evaluation_measure.SHORT_NAME in table.columns: 216 | # if evaluation_measure.BEST_SORTING_ASCENDING == False: 217 | # # Higher is better 218 | # show_higher_cols.append(evaluation_measure.SHORT_NAME) 219 | # else: 220 | # # Lower is better 221 | # show_lower_cols.append(evaluation_measure.SHORT_NAME) 222 | 223 | # if show_higher_cols: 224 | # table_style.background_gradient( 225 | # axis=1, 226 | # cmap=EVALUATION_PALETTE, 227 | # vmin=-1, 228 | # vmax=1, 229 | # subset=show_higher_cols, 230 | # ) 231 | 232 | # if show_lower_cols: 233 | # table_style.background_gradient( 234 | # axis=1, 235 | # cmap=EVALUATION_PALETTE_REVERSED, 236 | # vmin=-1, 237 | # vmax=1, 238 | # subset=show_lower_cols, 239 | # ) 240 | # return table_style 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Ferret circular logo with the name to the right](/docs/source/_static/banner_v2.png) 2 | 3 | [![Latest PyPI version](https://img.shields.io/pypi/v/ferret-xai.svg)](https://pypi.python.org/pypi/ferret-xai) 4 | [![Documentation Status](https://readthedocs.org/projects/ferret/badge/?version=latest)](https://ferret.readthedocs.io/en/latest/?version=latest) 5 | [![HuggingFace Spaces Demo](https://img.shields.io/badge/HF%20Spaces-Demo-yellow)](https://huggingface.co/spaces/g8a9/ferret) 6 | [![YouTube Video](https://img.shields.io/badge/youtube-video-red)](https://www.youtube.com/watch?v=kX0HcSah_M4) 7 | [![arxiv preprint](https://img.shields.io/badge/arXiv-2208.01575-b31b1b.svg)](https://arxiv.org/abs/2208.01575) 8 | [![downloads badge](https://pepy.tech/badge/ferret-xai/month)](https://pepy.tech/project/ferret-xai) 9 | 10 | ferret is Python library that streamlines the use and benchmarking of interpretability techniques on Transformers models. 11 | 12 | - Documentation: https://ferret.readthedocs.io 13 | - Paper: https://aclanthology.org/2023.eacl-demo.29/ 14 | - Demo: https://huggingface.co/spaces/g8a9/ferret 15 | 16 | **ferret** is meant to integrate seamlessly with 🤗 **transformers** models, among which it currently supports text models only. 17 | We provide: 18 | - 🔍 Four established interpretability techniques based on **Token-level Feature Attribution**. Use them to find the most relevant words to your model output quickly. 19 | - ⚖️ Six **Faithfulness and Plausibility evaluation protocols**. Benchmark any token-level explanation against these tests to guide your choice toward the most reliable explainer. 
20 | 
21 | ACL Anthology Bibkey:
22 | ```bash
23 | attanasio-etal-2023-ferret
24 | ```
25 | 
26 | ### 📝 Examples
27 | 
28 | All-around tutorial (to test all explainers, evaluation metrics, and interface with XAI datasets): [Colab](https://colab.research.google.com/github/g8a9/ferret/blob/main/examples/benchmark.ipynb)
29 | 
30 | Text Classification
31 | 
32 | - Intent Detection with Multilingual XLM RoBERTa: [Colab](https://colab.research.google.com/drive/17AXeA9-u7lOLlE_DWtUixMg7Mi0NFPIp?usp=sharing)
33 | 
34 | 
35 | ## Getting Started
36 | 
37 | ### Installation
38 | 
39 | For the default installation, which does **not** include the dependencies for the speech XAI functionalities, run:
40 | 
41 | ```bash
42 | pip install -U ferret-xai
43 | ```
44 | 
45 | Our main dependencies are 🤗 `transformers` and `datasets`.
46 | 
47 | If you need the speech XAI functionalities, install with:
48 | 
49 | ```bash
50 | pip install -U ferret-xai[speech]
51 | ```
52 | 
53 | At the moment, the speech XAI-related dependencies are the only extra ones, so installing with `ferret-xai[speech]` or `ferret-xai[all]` is equivalent.
54 | 
55 | **Important** Some of our dependencies might still use the deprecated `sklearn` package name for `scikit-learn`, and that breaks the ferret installation. \
56 | If your pip install command fails, try:
57 | 
58 | ```bash
59 | SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True pip install -U ferret-xai
60 | ```
61 | 
62 | This is hopefully a temporary situation!
63 | 
64 | ### Explain & Benchmark
65 | 
66 | The code below provides a minimal example to run all the feature-attribution explainers supported by ferret and benchmark them on faithfulness metrics.
67 | 
68 | We start from a common text classification pipeline:
69 | 
70 | ```python
71 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
72 | from ferret import Benchmark
73 | 
74 | name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
75 | model = AutoModelForSequenceClassification.from_pretrained(name)
76 | tokenizer = AutoTokenizer.from_pretrained(name)
77 | ```
78 | 
79 | Using *ferret* is as simple as:
80 | 
81 | ```python
82 | bench = Benchmark(model, tokenizer)
83 | explanations = bench.explain("You look stunning!", target=1)
84 | evaluations = bench.evaluate_explanations(explanations, target=1)
85 | 
86 | bench.show_evaluation_table(evaluations)
87 | ```
88 | 
89 | Be sure to run the code in a Jupyter Notebook/Colab: the cell above will produce a nicely formatted table to analyze the saliency maps.
90 | 
91 | ## Features
92 | 
93 | **ferret** offers a *painless* integration with Hugging Face models and
94 | naming conventions. If you are already using the
95 | [transformers](https://github.com/huggingface/transformers) library, you
96 | immediately get access to our **Explanation and Evaluation API**.
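If you prefer to work with the raw attribution scores rather than the rendered tables, they can also be collected into a plain `pandas` DataFrame. The snippet below is an illustrative sketch: it reuses `bench` and `explanations` from the example above and assumes `Benchmark.get_dataframe` mirrors `ferret.visualization.get_dataframe` (one row per explainer, one column per token) — check the API reference for the exact signature.

```python
# Collect token-level attribution scores in a DataFrame
# (rows: explainers, columns: tokens of the explained text).
scores_df = bench.get_dataframe(explanations)
print(scores_df.round(2))
```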
97 | 
98 | ### Post-Hoc Explainers
99 | 
100 | - Gradient (plain gradients or multiplied by input token embeddings) ([Simonyan et al., 2014](https://arxiv.org/abs/1312.6034))
101 | - Integrated Gradient (plain gradients or multiplied by input token embeddings) ([Sundararajan et al., 2017](http://proceedings.mlr.press/v70/sundararajan17a.html))
102 | - SHAP (via Partition SHAP approximation of Shapley values) ([Lundberg and Lee, 2017](https://proceedings.neurips.cc/paper/2017/hash/8a20a8621978632d76c43dfd28b67767-Abstract.html))
103 | - LIME ([Ribeiro et al., 2016](https://dl.acm.org/doi/abs/10.1145/2939672.2939778))
104 | 
105 | ### Evaluation Metrics
106 | 
107 | **Faithfulness** measures:
108 | 
109 | - AOPC Comprehensiveness ([DeYoung et al.,
110 | 2020](https://doi.org/10.18653/v1/2020.acl-main.408))
111 | - AOPC Sufficiency ([DeYoung et al.,
112 | 2020](https://doi.org/10.18653/v1/2020.acl-main.408))
113 | - Kendall's Tau correlation with Leave-One-Out token removal ([Jain
114 | and Wallace, 2019](https://aclanthology.org/N19-1357/))
115 | 
116 | **Plausibility** measures:
117 | 
118 | - Area-Under-Precision-Recall-Curve (soft score) ([DeYoung et al., 2020](https://doi.org/10.18653/v1/2020.acl-main.408))
119 | - Token F1 (hard score) ([DeYoung et al., 2020](https://doi.org/10.18653/v1/2020.acl-main.408))
120 | - Token Intersection Over Union (hard score) ([DeYoung et al., 2020](https://doi.org/10.18653/v1/2020.acl-main.408))
121 | 
122 | See our [paper](https://arxiv.org/abs/2208.01575) for details.
123 | 
124 | ## Visualization
125 | 
126 | The `Benchmark` class exposes easy-to-use table
127 | visualization methods (e.g., within Jupyter Notebooks):
128 | 
129 | ```python
130 | bench = Benchmark(model, tokenizer)
131 | 
132 | # Pretty-print feature attribution scores by all supported explainers
133 | explanations = bench.explain("You look stunning!")
134 | bench.show_table(explanations)
135 | 
136 | # Pretty-print all the supported evaluation metrics
137 | evaluations = bench.evaluate_explanations(explanations)
138 | bench.show_evaluation_table(evaluations)
139 | ```
140 | 
141 | ## Dataset Evaluations
142 | 
143 | The `Benchmark` class has a handy method to compute and
144 | average our evaluation metrics across multiple samples from a dataset.
145 | 
146 | ```python
147 | import numpy as np
148 | bench = Benchmark(model, tokenizer)
149 | 
150 | # Compute and average evaluation scores on one of the supported datasets
151 | samples = np.arange(20)
152 | hatexdata = bench.load_dataset("hatexplain")
153 | sample_evaluations = bench.evaluate_samples(hatexdata, samples)
154 | 
155 | # Pretty-print the results
156 | bench.show_samples_evaluation_table(sample_evaluations)
157 | ```
158 | 
159 | ## Planned Development
160 | 
161 | See [the changelog file](https://github.com/g8a9/ferret/blob/main/HISTORY.rst) for further
162 | details.
163 | 
164 | - ✅ GPU acceleration support for inference (**v0.4.0**)
165 | - ✅ Batched Inference for internal methods' approximation steps (e.g., LIME or SHAP) (**v0.4.0**)
166 | - ⚙️ Simplified Task API to support NLI, Zero-Shot Text Classification, Language Modeling ([branch](https://github.com/g8a9/ferret/tree/task-API)).
167 | - ⚙️ Multi-sample explanation generation and evaluation
168 | - ⚙️ Support for seq2seq and autoregressive generation explainers through [inseq](https://github.com/inseq-team/inseq).
169 | - ⚙️ New evaluation measure: Sensitivity, Stability ([Yin et al.](https://aclanthology.org/2022.acl-long.188/)) 170 | - ⚙️ New evaluation measure: Area Under the Threshold-Performance Curve (AUC-TP) ([Atanasova et al.](https://aclanthology.org/2020.emnlp-main.263/)) 171 | - ⚙️ New explainer: Sampling and Occlusion (SOC) ([Jin et al., 2020](https://arxiv.org/abs/1911.06194)) 172 | - ⚙️ New explainer: Discretized Integrated Gradient (DIG) ([Sanyal and Ren, 2021](https://aclanthology.org/2021.emnlp-main.805/)) 173 | - ⚙️ New explainer: Value Zeroing ([Mohebbi et al, 2023](https://aclanthology.org/2023.eacl-main.245/)) 174 | - ⚙️ Support additional form of aggregation over embeddings' hidden dimension. 175 | 176 | 177 | ## Authors 178 | 179 | - [Giuseppe Attanasio](https://gattanasio.cc) 180 | - [Eliana Pastor](mailto:eliana.pastor@centai.eu) 181 | - [Debora Nozza](https://deboranozza.com/) 182 | - Chiara Di Bonaventura 183 | 184 | ## Credits 185 | 186 | This package was created with Cookiecutter and the 187 | *audreyr/cookiecutter-pypackage* project template. 188 | 189 | - Cookiecutter: https://github.com/audreyr/cookiecutter 190 | - [audreyr/cookiecutter-pypackage](https://github.com/audreyr/cookiecutter-pypackage) 191 | 192 | Logo and graphical assets made by [Luca Attanasio](https://www.behance.net/attanasiol624d). 193 | 194 | If you are using *ferret* for your work, please consider citing us! 195 | 196 | ```bibtex 197 | @inproceedings{attanasio-etal-2023-ferret, 198 | title = "ferret: a Framework for Benchmarking Explainers on Transformers", 199 | author = "Attanasio, Giuseppe and Pastor, Eliana and Di Bonaventura, Chiara and Nozza, Debora", 200 | booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics: System Demonstrations", 201 | month = may, 202 | year = "2023", 203 | publisher = "Association for Computational Linguistics", 204 | } 205 | ``` 206 | --------------------------------------------------------------------------------