├── .conda └── meta.yaml ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── ---bug-report.md │ ├── ---documentation.md │ ├── ---feature-request.md │ └── --questions-and-help.md └── workflows │ ├── lint.yml │ ├── retry.yml │ ├── test-conda-cpu.yml │ ├── test-pip-cpu-with-type-checks.yml │ ├── test-pip-cpu.yml │ ├── test-pip-gpu.yml │ ├── test-website-depoy.yml │ └── website-depoy.yml ├── .gitignore ├── .pyre_configuration ├── AWESOME_LIST.md ├── CITATION ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── captum ├── __init__.py ├── _utils │ ├── __init__.py │ ├── av.py │ ├── common.py │ ├── exceptions.py │ ├── gradient.py │ ├── models │ │ ├── __init__.py │ │ ├── linear_model │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── train.py │ │ └── model.py │ ├── progress.py │ ├── sample_gradient.py │ ├── transformers_typing.py │ └── typing.py ├── attr │ ├── __init__.py │ ├── _core │ │ ├── __init__.py │ │ ├── dataloader_attr.py │ │ ├── deep_lift.py │ │ ├── feature_ablation.py │ │ ├── feature_permutation.py │ │ ├── gradient_shap.py │ │ ├── guided_backprop_deconvnet.py │ │ ├── guided_grad_cam.py │ │ ├── input_x_gradient.py │ │ ├── integrated_gradients.py │ │ ├── kernel_shap.py │ │ ├── layer │ │ │ ├── __init__.py │ │ │ ├── grad_cam.py │ │ │ ├── internal_influence.py │ │ │ ├── layer_activation.py │ │ │ ├── layer_conductance.py │ │ │ ├── layer_deep_lift.py │ │ │ ├── layer_feature_ablation.py │ │ │ ├── layer_feature_permutation.py │ │ │ ├── layer_gradient_shap.py │ │ │ ├── layer_gradient_x_activation.py │ │ │ ├── layer_integrated_gradients.py │ │ │ └── layer_lrp.py │ │ ├── lime.py │ │ ├── llm_attr.py │ │ ├── lrp.py │ │ ├── neuron │ │ │ ├── __init__.py │ │ │ ├── neuron_conductance.py │ │ │ ├── neuron_deep_lift.py │ │ │ ├── neuron_feature_ablation.py │ │ │ ├── neuron_gradient.py │ │ │ ├── neuron_gradient_shap.py │ │ │ ├── neuron_guided_backprop_deconvnet.py │ │ │ └── neuron_integrated_gradients.py │ │ ├── noise_tunnel.py │ │ ├── occlusion.py │ │ ├── 
remote_provider.py │ │ ├── saliency.py │ │ └── shapley_value.py │ ├── _models │ │ ├── __init__.py │ │ └── base.py │ └── _utils │ │ ├── __init__.py │ │ ├── approximation_methods.py │ │ ├── attribution.py │ │ ├── baselines.py │ │ ├── batching.py │ │ ├── class_summarizer.py │ │ ├── common.py │ │ ├── custom_modules.py │ │ ├── input_layer_wrapper.py │ │ ├── interpretable_input.py │ │ ├── lrp_rules.py │ │ ├── stat.py │ │ ├── summarizer.py │ │ └── visualization.py ├── concept │ ├── __init__.py │ ├── _core │ │ ├── __init__.py │ │ ├── cav.py │ │ ├── concept.py │ │ └── tcav.py │ └── _utils │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── common.py │ │ └── data_iterator.py ├── influence │ ├── __init__.py │ ├── _core │ │ ├── __init__.py │ │ ├── arnoldi_influence_function.py │ │ ├── influence.py │ │ ├── influence_function.py │ │ ├── similarity_influence.py │ │ ├── tracincp.py │ │ └── tracincp_fast_rand_proj.py │ └── _utils │ │ ├── __init__.py │ │ ├── common.py │ │ └── nearest_neighbors.py ├── insights │ ├── __init__.py │ ├── _utils │ │ └── __init__.py │ ├── attr_vis │ │ ├── __init__.py │ │ ├── _utils │ │ │ ├── __init__.py │ │ │ └── transforms.py │ │ ├── app.py │ │ ├── attribution_calculation.py │ │ ├── config.py │ │ ├── example.py │ │ ├── features.py │ │ ├── frontend │ │ │ ├── README.md │ │ │ ├── package.json │ │ │ ├── public │ │ │ │ └── index.html │ │ │ ├── src │ │ │ │ ├── App.css │ │ │ │ ├── App.module.css │ │ │ │ ├── App.test.js │ │ │ │ ├── App.tsx │ │ │ │ ├── WebApp.tsx │ │ │ │ ├── components │ │ │ │ │ ├── Arguments.tsx │ │ │ │ │ ├── ClassFilter.tsx │ │ │ │ │ ├── Contributions.tsx │ │ │ │ │ ├── Feature.tsx │ │ │ │ │ ├── Filter.tsx │ │ │ │ │ ├── FilterContainer.tsx │ │ │ │ │ ├── Header.tsx │ │ │ │ │ ├── LabelButton.tsx │ │ │ │ │ ├── Spinner.tsx │ │ │ │ │ ├── Tooltip.tsx │ │ │ │ │ ├── Visualization.tsx │ │ │ │ │ └── VisualizationGroup.tsx │ │ │ │ ├── index.css │ │ │ │ ├── index.tsx │ │ │ │ ├── models │ │ │ │ │ ├── filter.ts │ │ │ │ │ ├── insightsConfig.ts │ │ │ │ │ ├── 
typeHelpers.ts │ │ │ │ │ └── visualizationOutput.ts │ │ │ │ ├── react-app-env.d.ts │ │ │ │ └── utils │ │ │ │ │ ├── color.ts │ │ │ │ │ ├── cx.ts │ │ │ │ │ └── dataPoint.ts │ │ │ ├── tsconfig.json │ │ │ ├── widget │ │ │ │ ├── src │ │ │ │ │ ├── Widget.js │ │ │ │ │ ├── embed.js │ │ │ │ │ ├── extension.js │ │ │ │ │ └── index.js │ │ │ │ └── webpack.config.js │ │ │ └── yarn.lock │ │ ├── models │ │ │ └── cifar_torchvision.pt │ │ ├── server.py │ │ └── widget │ │ │ ├── __init__.py │ │ │ ├── _version.py │ │ │ ├── jupyter-captum-insights.json │ │ │ └── widget.py │ └── example.py ├── log │ ├── __init__.py │ └── dummy_log.py ├── metrics │ ├── __init__.py │ ├── _core │ │ ├── __init__.py │ │ ├── infidelity.py │ │ └── sensitivity.py │ └── _utils │ │ ├── __init__.py │ │ └── batching.py ├── module │ ├── __init__.py │ ├── binary_concrete_stochastic_gates.py │ ├── gaussian_stochastic_gates.py │ └── stochastic_gates_base.py ├── robust │ ├── __init__.py │ └── _core │ │ ├── __init__.py │ │ ├── fgsm.py │ │ ├── metrics │ │ ├── __init__.py │ │ ├── attack_comparator.py │ │ └── min_param_perturbation.py │ │ ├── perturbation.py │ │ └── pgd.py └── testing │ ├── attr │ └── helpers │ │ ├── __init__.py │ │ ├── attribution_delta_util.py │ │ ├── conductance_reference.py │ │ ├── gen_test_utils.py │ │ ├── get_config_util.py │ │ ├── neuron_layer_testing_util.py │ │ └── test_config.py │ └── helpers │ ├── __init__.py │ ├── basic.py │ ├── basic_models.py │ ├── classification_models.py │ ├── evaluate_linear_model.py │ └── influence │ ├── __init__.py │ └── common.py ├── docs ├── Captum_Attribution_Algos.png ├── README.md ├── algorithms_comparison_matrix.md ├── attribution_algorithms.md ├── captum_insights.md ├── contribution_guide.md ├── extension │ └── integrated_gradients.md ├── faq.md ├── getting_started.md ├── introduction.md ├── overview.md └── presentations │ └── Captum_NeurIPS_2019_final.key ├── environment.yml ├── pyproject.toml ├── scripts ├── build_docs.sh ├── build_insights.sh ├── 
install_via_conda.sh ├── install_via_pip.sh ├── parse_sphinx.py ├── parse_tutorials.py ├── publish_site.sh ├── run_mypy.sh ├── update_versions_html.py └── versions.js ├── setup.cfg ├── setup.py ├── sphinx ├── Makefile ├── README.md ├── make.bat └── source │ ├── attribution.rst │ ├── base_classes.rst │ ├── binary_concrete_stg.rst │ ├── concept.rst │ ├── conf.py │ ├── deconvolution.rst │ ├── deep_lift.rst │ ├── deep_lift_shap.rst │ ├── feature_ablation.rst │ ├── feature_permutation.rst │ ├── gaussian_stg.rst │ ├── gradient_shap.rst │ ├── guided_backprop.rst │ ├── guided_grad_cam.rst │ ├── index.rst │ ├── influence.rst │ ├── input_x_gradient.rst │ ├── insights.rst │ ├── integrated_gradients.rst │ ├── kernel_shap.rst │ ├── layer.rst │ ├── lime.rst │ ├── llm_attr.rst │ ├── lrp.rst │ ├── metrics.rst │ ├── module.rst │ ├── neuron.rst │ ├── noise_tunnel.rst │ ├── occlusion.rst │ ├── robust.rst │ ├── saliency.rst │ ├── shapley_value_sampling.rst │ └── utilities.rst ├── tests ├── __init__.py ├── attr │ ├── __init__.py │ ├── layer │ │ ├── __init__.py │ │ ├── test_grad_cam.py │ │ ├── test_internal_influence.py │ │ ├── test_layer_ablation.py │ │ ├── test_layer_activation.py │ │ ├── test_layer_conductance.py │ │ ├── test_layer_deeplift.py │ │ ├── test_layer_feature_permutation.py │ │ ├── test_layer_gradient_shap.py │ │ ├── test_layer_gradient_x_activation.py │ │ ├── test_layer_integrated_gradients.py │ │ └── test_layer_lrp.py │ ├── models │ │ ├── __init__.py │ │ └── test_base.py │ ├── neuron │ │ ├── __init__.py │ │ ├── test_neuron_ablation.py │ │ ├── test_neuron_conductance.py │ │ ├── test_neuron_deeplift.py │ │ ├── test_neuron_gradient.py │ │ ├── test_neuron_gradient_shap.py │ │ └── test_neuron_integrated_gradients.py │ ├── test_approximation_methods.py │ ├── test_baselines.py │ ├── test_class_summarizer.py │ ├── test_common.py │ ├── test_data_parallel.py │ ├── test_dataloader_attr.py │ ├── test_deconvolution.py │ ├── test_deeplift_basic.py │ ├── test_deeplift_classification.py 
│ ├── test_feature_ablation.py │ ├── test_feature_permutation.py │ ├── test_gradient_shap.py │ ├── test_guided_backprop.py │ ├── test_guided_grad_cam.py │ ├── test_hook_removal.py │ ├── test_input_layer_wrapper.py │ ├── test_input_x_gradient.py │ ├── test_integrated_gradients_basic.py │ ├── test_integrated_gradients_classification.py │ ├── test_interpretable_input.py │ ├── test_jit.py │ ├── test_kernel_shap.py │ ├── test_lime.py │ ├── test_llm_attr.py │ ├── test_llm_attr_hf_compatibility.py │ ├── test_lrp.py │ ├── test_occlusion.py │ ├── test_saliency.py │ ├── test_shapley.py │ ├── test_stat.py │ ├── test_summarizer.py │ ├── test_targets.py │ └── test_utils_batching.py ├── concept │ ├── __init__.py │ ├── test_concept.py │ └── test_tcav.py ├── influence │ ├── __init__.py │ ├── _core │ │ ├── __init__.py │ │ ├── test_arnoldi_influence.py │ │ ├── test_dataloader.py │ │ ├── test_naive_influence.py │ │ ├── test_similarity_influence.py │ │ ├── test_tracin_aggregate_influence.py │ │ ├── test_tracin_intermediate_quantities.py │ │ ├── test_tracin_k_most_influential.py │ │ ├── test_tracin_regression.py │ │ ├── test_tracin_self_influence.py │ │ ├── test_tracin_show_progress.py │ │ ├── test_tracin_validation.py │ │ └── test_tracin_xor.py │ └── _utils │ │ ├── __init__.py │ │ └── test_common.py ├── insights │ ├── test_contribution.py │ └── test_features.py ├── metrics │ ├── __init__.py │ ├── test_infidelity.py │ └── test_sensitivity.py ├── module │ ├── __init__.py │ ├── test_binary_concrete_stochastic_gates.py │ ├── test_binary_concrete_stochastic_gates_cuda.py │ ├── test_gaussian_stochastic_gates.py │ └── test_gaussian_stochastic_gates_cuda.py ├── robust │ ├── __init__.py │ ├── test_FGSM.py │ ├── test_PGD.py │ ├── test_attack_comparator.py │ └── test_min_param_perturbation.py └── utils │ ├── __init__.py │ ├── models │ ├── __init__.py │ └── linear_models │ │ ├── __init__.py │ │ └── _test_linear_classifier.py │ ├── test_av.py │ ├── test_common.py │ ├── test_gradient.py │ ├── 
test_helpers.py │ ├── test_jacobian.py │ ├── test_linear_model.py │ ├── test_progress.py │ └── test_sample_gradient.py ├── tutorials ├── Bert_SQUAD_Interpret.ipynb ├── Bert_SQUAD_Interpret2.ipynb ├── CIFAR_Captum_Robustness.ipynb ├── CIFAR_TorchVision_Captum_Insights.ipynb ├── CIFAR_TorchVision_Interpret.ipynb ├── DLRM_Tutorial.ipynb ├── Distributed_Attribution.ipynb ├── House_Prices_Regression_Interpret.ipynb ├── IMDB_TorchText_Interpret.ipynb ├── Image_and_Text_Classification_LIME.ipynb ├── Llama2_LLM_Attribution.ipynb ├── Multimodal_VQA_Captum_Insights.ipynb ├── Multimodal_VQA_Interpret.ipynb ├── README.md ├── Resnet_TorchVision_Ablation.ipynb ├── Segmentation_Interpret.ipynb ├── TCAV_Image.ipynb ├── TCAV_NLP.ipynb ├── Titanic_Basic_Interpret.ipynb ├── TorchVision_Interpret.ipynb ├── TracInCP_Tutorial.ipynb ├── data │ ├── dlrm │ │ ├── X_S_T_test │ │ └── X_S_T_test_above_0999 │ └── tcav │ │ ├── image │ │ ├── concepts │ │ │ └── README.md │ │ └── imagenet │ │ │ └── README.md │ │ └── text-sensitivity │ │ ├── neutral.csv │ │ ├── neutral2.csv │ │ ├── neutral3.csv │ │ ├── neutral4.csv │ │ ├── neutral5.csv │ │ └── positive-adjectives.csv ├── img │ ├── bert │ │ └── visuals_of_start_end_predictions.png │ ├── captum_insights.png │ ├── captum_insights_vqa.png │ ├── dlrm_arch.png │ ├── lime_text_viz.png │ ├── resnet │ │ ├── swan-3299528_1280.jpg │ │ └── swan-3299528_1280_attribution.jpg │ ├── sentiment_analysis.png │ └── vqa │ │ ├── elephant.jpg │ │ ├── elephant_attribution.jpg │ │ ├── siamese.jpg │ │ ├── siamese_attribution.jpg │ │ ├── zebra.jpg │ │ └── zebra_attribution.jpg └── models │ ├── california_model.pt │ ├── cifar_torchvision.pt │ ├── embedding_bag_ag_news.pt │ ├── imdb-model-cnn-large.pt │ └── titanic_model.pt └── website ├── core ├── Footer.js ├── Tutorial.js └── TutorialSidebar.js ├── package.json ├── pages ├── en │ └── index.js └── tutorials │ └── index.js ├── sidebars.json ├── siteConfig.js ├── static ├── .nojekyll ├── CNAME ├── css │ ├── alabaster.css │ ├── 
basic.css │ ├── code_block_buttons.css │ └── custom.css ├── img │ ├── IG_eq1.png │ ├── algorithms_comparison_matrix.png │ ├── captum-icon.png │ ├── captum.ico │ ├── captum_insights_screenshot.png │ ├── captum_insights_screenshot_vqa.png │ ├── captum_logo.png │ ├── captum_logo.svg │ ├── conductance_eq_1.png │ ├── conductance_eq_2.png │ ├── deepLIFT_multipliers_eq1.png │ ├── expanding_arrows.svg │ ├── infidelity_eq.png │ ├── landing-background.jpg │ ├── multi-modal.png │ ├── oss_logo.png │ ├── pytorch_logo.svg │ └── sensitivity_eq.png └── js │ └── code_block_buttons.js ├── tutorials.json └── yarn.lock /.conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data(setup_file="../setup.py", from_recipe_dir=True) %} 2 | 3 | package: 4 | name: {{ data.get("name")|lower }} 5 | version: {{ data.get("version") }} 6 | 7 | source: 8 | path: .. 9 | 10 | build: 11 | noarch: python 12 | script: "$PYTHON setup.py install --single-version-externally-managed --record=record.txt" 13 | 14 | requirements: 15 | host: 16 | - python>=3.9 17 | - setuptools 18 | run: 19 | - numpy<2.0 20 | - pytorch>=1.10 21 | - matplotlib-base 22 | - tqdm 23 | - packaging 24 | 25 | test: 26 | imports: 27 | - captum 28 | 29 | about: 30 | home: https://captum.ai 31 | license: BSD-3 32 | license_file: LICENSE 33 | summary: Model interpretability for PyTorch 34 | description: | 35 | Captum is a model interpretability and understanding library for PyTorch. 36 | Captum means comprehension in Latin and contains general purpose implementations 37 | of integrated gradients, saliency maps, smoothgrad, vargrad and others for 38 | PyTorch models. It has quick integration for models built with domain-specific 39 | libraries such as torchvision, torchtext, and others. 
40 | doc_url: https://captum.ai 41 | dev_url: https://github.com/pytorch/captum 42 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | tutorials/* linguist-documentation 2 | *.pt binary 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug report" 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. 19 | 1. 20 | 1. 21 | 22 | 23 | 24 | ## Expected behavior 25 | 26 | 27 | 28 | ## Environment 29 | Describe the environment used for Captum 30 | 31 | ``` 32 | 33 | - Captum / PyTorch Version (e.g., 1.0 / 0.4.0): 34 | - OS (e.g., Linux): 35 | - How you installed Captum / PyTorch (`conda`, `pip`, source): 36 | - Build command you used (if compiling from source): 37 | - Python version: 38 | - CUDA/cuDNN version: 39 | - GPU models and configuration: 40 | - Any other relevant information: 41 | 42 | ## Additional context 43 | 44 | 45 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue with, or suggest an improvement to, the Captum documentation
4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 📚 Documentation 11 | 12 | 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🚀 Feature 11 | 12 | 13 | ## Motivation 14 | 15 | 16 | 17 | ## Pitch 18 | 19 | 20 | 21 | ## Alternatives 22 | 23 | 24 | 25 | ## Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/--questions-and-help.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓ Questions and Help" 3 | about: Ask a question or get help using Captum 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | 12 | We have a set of resources listed on the website and in the FAQ: https://captum.ai/ and https://captum.ai/docs/faq. Feel free to open an issue here on GitHub or to post in our discussion forum: 13 | 14 | - [Discussion Forum](https://discuss.pytorch.org/c/captum) 15 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Captum Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | tests: 13 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 14 | with: 15 | runner: linux.12xlarge 16 | docker-image: cimg/python:3.11 17 | repository: pytorch/captum 18 | script: | 19 | sudo chmod -R 777 . 20 | ./scripts/install_via_pip.sh 21 | ufmt check .
22 | flake8 23 | sphinx-build -WT --keep-going sphinx/source sphinx/build 24 | -------------------------------------------------------------------------------- /.github/workflows/retry.yml: -------------------------------------------------------------------------------- 1 | name: Rerun tests if failed 2 | on: 3 | workflow_run: 4 | workflows: ["Unit-tests for Conda install", "Unit-tests for Pip install with type checks", "Unit-tests for Pip install"] 5 | types: ["completed"] 6 | 7 | permissions: 8 | actions: write 9 | 10 | jobs: 11 | rerun-tests: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Log workflow metadata 15 | run: | 16 | echo "ID: ${{ github.event.workflow_run.id }}" 17 | echo "attempt: ${{ github.event.workflow_run.run_attempt }}" 18 | echo "event: ${{ github.event.workflow_run.conclusion }}" 19 | echo "event: ${{ github.event.workflow_run.event }}" 20 | - name: Rerun Failed Workflows 21 | if: github.event.workflow_run.conclusion == 'failure' && github.event.workflow_run.run_attempt <= 3 22 | env: 23 | GH_TOKEN: ${{ github.token }} 24 | RUN_ID: ${{ github.event.workflow_run.id }} 25 | run: | 26 | gh run rerun ${RUN_ID} --repo="${{ github.repository }}" --failed 27 | -------------------------------------------------------------------------------- /.github/workflows/test-conda-cpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests for Conda install 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | workflow_dispatch: 10 | 11 | env: 12 | CHANNEL: "nightly" 13 | 14 | jobs: 15 | tests: 16 | strategy: 17 | matrix: 18 | python_version: ["3.9", "3.10", "3.11", "3.12"] 19 | fail-fast: false 20 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 21 | with: 22 | runner: linux.12xlarge 23 | repository: pytorch/captum 24 | script: | 25 | # Set up Environment Variables 26 | export PYTHON_VERSION="${{ matrix.python_version }}" 27 | 28 | # Create Conda Env 29 | conda create 
-yp ci_env python="${PYTHON_VERSION}" 30 | conda activate /pytorch/captum/ci_env 31 | ./scripts/install_via_conda.sh 32 | 33 | # Run Tests 34 | python3 -m pytest -ra --cov=. --cov-report term-missing 35 | -------------------------------------------------------------------------------- /.github/workflows/test-pip-cpu-with-type-checks.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests for Pip install with type checks 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | tests: 13 | strategy: 14 | matrix: 15 | pytorch_args: ["", "-n"] 16 | fail-fast: false 17 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 18 | with: 19 | runner: linux.12xlarge 20 | docker-image: cimg/python:3.11 21 | repository: pytorch/captum 22 | script: | 23 | sudo chmod -R 777 . 24 | ./scripts/install_via_pip.sh ${{ matrix.pytorch_args }} 25 | ./scripts/run_mypy.sh 26 | pyre check 27 | # Run Tests 28 | python3 -m pytest -ra --cov=. 
--cov-report term-missing 29 | -------------------------------------------------------------------------------- /.github/workflows/test-pip-cpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests for Pip install 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | tests: 13 | strategy: 14 | matrix: 15 | pytorch_args: ["-v 1.10", "-v 1.11", "-v 1.12", "-v 1.13", "-v 2.0.0", "-v 2.1.0", "-v 2.2.0", "-v 2.3.0"] 16 | transformers_args: ["-t 4.38.0", "-t 4.39.0", "-t 4.41.0", "-t 4.43.0", "-t 4.45.2"] 17 | docker_img: ["cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] 18 | exclude: 19 | - pytorch_args: "-v 1.10" 20 | docker_img: "cimg/python:3.10" 21 | - pytorch_args: "-v 1.10" 22 | docker_img: "cimg/python:3.11" 23 | - pytorch_args: "-v 1.11" 24 | docker_img: "cimg/python:3.11" 25 | - pytorch_args: "-v 1.12" 26 | docker_img: "cimg/python:3.11" 27 | - pytorch_args: "-v 1.10" 28 | docker_img: "cimg/python:3.12" 29 | - pytorch_args: "-v 1.11" 30 | docker_img: "cimg/python:3.12" 31 | - pytorch_args: "-v 1.12" 32 | docker_img: "cimg/python:3.12" 33 | - pytorch_args: "-v 1.13" 34 | docker_img: "cimg/python:3.12" 35 | - pytorch_args: "-v 2.0.0" 36 | docker_img: "cimg/python:3.12" 37 | - pytorch_args: "-v 2.1.0" 38 | docker_img: "cimg/python:3.12" 39 | fail-fast: false 40 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 41 | with: 42 | runner: linux.12xlarge 43 | docker-image: ${{ matrix.docker_img }} 44 | repository: pytorch/captum 45 | script: | 46 | sudo chmod -R 777 . 47 | ./scripts/install_via_pip.sh ${{ matrix.pytorch_args }} ${{ matrix.transformers_args }} 48 | # Run Tests 49 | python3 -m pytest -ra --cov=. 
--cov-report term-missing 50 | -------------------------------------------------------------------------------- /.github/workflows/test-pip-gpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests for Pip install 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - master 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | tests: 13 | strategy: 14 | matrix: 15 | cuda_arch_version: ["12.1"] 16 | fail-fast: false 17 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 18 | with: 19 | runner: linux.4xlarge.nvidia.gpu 20 | repository: pytorch/captum 21 | gpu-arch-type: cuda 22 | gpu-arch-version: ${{ matrix.cuda_arch_version }} 23 | script: | 24 | python3 -m pip install --upgrade pip --progress-bar off 25 | python3 -m pip install -e .[dev] --progress-bar off 26 | 27 | # Build package 28 | python3 -m pip install build --progress-bar off 29 | python3 -m build 30 | 31 | # Run Tests 32 | python3 -m pytest -ra --cov=. --cov-report term-missing 33 | -------------------------------------------------------------------------------- /.github/workflows/test-website-depoy.yml: -------------------------------------------------------------------------------- 1 | name: Test deployment 2 | 3 | on: 4 | pull_request: 5 | # Review gh actions docs if you want to further define triggers, paths, etc 6 | # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#on 7 | 8 | jobs: 9 | test-deploy: 10 | name: Test deployment 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Setup / build docs 16 | run: | 17 | sudo chmod -R 777 . 
18 | python3 -m pip install --upgrade pip --progress-bar off 19 | python3 -m pip install -e .[dev] --progress-bar off 20 | python3 -m pip install beautifulsoup4 ipython jinja2==3.0.0 nbconvert==5.6.1 ipython_genutils --progress-bar off 21 | ./scripts/build_docs.sh -b 22 | cd website 23 | -------------------------------------------------------------------------------- /.github/workflows/website-depoy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | # Review gh actions docs if you want to further define triggers, paths, etc 8 | # https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#on 9 | 10 | permissions: 11 | contents: write 12 | pages: write 13 | 14 | jobs: 15 | deploy: 16 | name: Deploy to GitHub Pages 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - name: Setup / build docs 22 | run: | 23 | sudo chmod -R 777 . 
24 | python3 -m pip install --upgrade pip --progress-bar off 25 | python3 -m pip install -e .[dev] --progress-bar off 26 | python3 -m pip install beautifulsoup4 ipython jinja2==3.0.0 nbconvert==5.6.1 ipython_genutils --progress-bar off 27 | ./scripts/build_docs.sh -b 28 | cd website 29 | 30 | 31 | # Popular action to deploy to GitHub Pages: 32 | # Docs: https://github.com/peaceiris/actions-gh-pages#%EF%B8%8F-docusaurus 33 | - name: Deploy to GitHub Pages 34 | uses: peaceiris/actions-gh-pages@v3 35 | with: 36 | github_token: ${{ secrets.GITHUB_TOKEN }} 37 | # Build output to publish to the `gh-pages` branch: 38 | publish_dir: ./website/build/captum/ 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Downloaded data 7 | data/ 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | captum/insights/frontend/node_modules/ 19 | captum/insights/frontend/.pnp/ 20 | captum/insights/frontend/build/ 21 | captum/insights/widget/static/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # watchman 34 | # Watchman helper 35 | .watchmanconfig 36 | 37 | # OSX 38 | *.DS_Store 39 | 40 | # Atom plugin files and ctags 41 | .ftpconfig 42 | .ftpconfig.cson 43 | .ftpignore 44 | *.tags 45 | *.tags1 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | captum/insights/frontend/npm-debug.log* 57 | captum/insights/frontend/yarn-debug.log* 58 | captum/insights/frontend/yarn-error.log* 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | captum/insights/frontend/coverage 63 | .tox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # mypy 91 | .mypy_cache/ 92 | 93 | # vim 94 | *.swp 95 | 96 | # Sphinx documentation 97 | sphinx/build/ 98 | 99 | # Docusaurus 100 | website/build/ 101 | website/i18n/ 102 | website/node_modules/ 103 | 104 | ## Generated for tutorials 105 | website/_tutorials/ 106 | website/static/files/ 107 | website/pages/tutorials/* 108 | !website/pages/tutorials/index.js 109 | 110 | ## Generated for Sphinx 111 | website/pages/api/ 112 | website/static/_sphinx/ 113 | 114 | # Insight 115 | captum/insights/attr_vis/frontend/node_modules/ 116 | captum/insights/attr_vis/widget/static 117 | -------------------------------------------------------------------------------- /.pyre_configuration: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | ".*/build/.*", 4 | ".*/docs/.*", 5 | ".*/scripts/.*", 6 | ".*/sphinx/.*", 7 | ".*/website/.*", 8 | ".*/fb/.*", 9 | ".*/tests/.*", 10 | ".*/setup.py" 11 | ], 12 | "site_package_search_strategy": "all", 13 | "source_directories": [ 14 | "." 
15 | ], 16 | "strict": true 17 | } 18 | -------------------------------------------------------------------------------- /AWESOME_LIST.md: -------------------------------------------------------------------------------- 1 | # Awesome List 2 | 3 | There is a lot of awesome research and development happening in the interpretability community that we would like to share. Here we maintain a curated list of research, implementations, and resources. We would love to learn about more! Please feel free to make a pull request to contribute to the list. 4 | 5 | 6 | #### TorchRay: Visualization methods for deep CNNs 7 | TorchRay focuses on attribution, namely the problem of determining which part of the input, usually an image, is responsible for the value computed by a neural network. 8 | - [https://github.com/facebookresearch/TorchRay](https://github.com/facebookresearch/TorchRay) 9 | 10 | 11 | #### Score-CAM: A gradient-free CAM extension 12 | Score-CAM is a gradient-free visualization method extended from Grad-CAM and Grad-CAM++. It provides score-weighted visual explanations for CNNs. 13 | - [Paper](https://arxiv.org/abs/1910.01279) 14 | - [https://github.com/haofanwang/Score-CAM](https://github.com/haofanwang/Score-CAM) 15 | 16 | 17 | #### White Noise Analysis 18 | White noise stimuli are fed to a classifier, and the ones that are categorized into a particular class are averaged. This gives an estimate of the templates a classifier uses for classification. The method is based on two popular and related methods in psychophysics and neurophysiology, namely classification images and spike-triggered analysis. 19 | - [Paper](https://arxiv.org/abs/1912.12106) 20 | - [https://github.com/aliborji/WhiteNoiseAnalysis.git](https://github.com/aliborji/WhiteNoiseAnalysis.git) 21 | 22 | 23 | #### FastCAM: Multiscale Saliency Map with SMOE scale 24 | An attribution method that uses information at the end of each network scale, which is then combined into a single saliency map.
25 | - [Paper](https://arxiv.org/abs/1911.11293) 26 | - [https://github.com/LLNL/fastcam](https://github.com/LLNL/fastcam) 27 | - [pull request](https://github.com/pytorch/captum/pull/442) 28 | - [jupyter notebook demo](https://github.com/LLNL/fastcam/blob/captum/demo-captum.ipynb) 29 | -------------------------------------------------------------------------------- /CITATION: -------------------------------------------------------------------------------- 1 | @misc{kokhlikyan2020captum, 2 | title={Captum: A unified and generic model interpretability library for PyTorch}, 3 | author={Narine Kokhlikyan and Vivek Miglani and Miguel Martin and Edward Wang and Bilal Alsallakh and Jonathan Reynolds and Alexander Melnikov and Natalia Kliushkina and Carlos Araya and Siqi Yan and Orion Reblitz-Richardson}, 4 | year={2020}, 5 | eprint={2009.07896}, 6 | archivePrefix={arXiv}, 7 | primaryClass={cs.LG} 8 | } 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, PyTorch team 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /captum/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | import captum.attr as attr 5 | import captum.concept as concept 6 | import captum.influence as influence 7 | import captum.log as log 8 | import captum.metrics as metrics 9 | import captum.robust as robust 10 | 11 | 12 | __version__ = "0.8.0" 13 | 14 | __all__ = ["attr", "concept", "influence", "log", "metrics", "robust"] 15 | -------------------------------------------------------------------------------- /captum/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/_utils/__init__.py -------------------------------------------------------------------------------- /captum/_utils/exceptions.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | # pyre-strict 4 | 5 | 6 | class FeatureAblationFutureError(Exception): 7 | """This custom error is raised when an error 8 | occurs within the callback chain of a 9 | FeatureAblation attribution call""" 10 | 11 | pass 12 | 13 | 14 | class ShapleyValueFutureError(Exception): 15 | """This custom error is raised when an error 16 | occurs within the callback chain of a 17 | ShapleyValue attribution call""" 18 | 19 | pass 20 | -------------------------------------------------------------------------------- /captum/_utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from captum._utils.models.model import Model 3 | 4 | __all__ = [ 5 | "Model", 6 | ] 7 | -------------------------------------------------------------------------------- /captum/_utils/models/linear_model/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from captum._utils.models.linear_model.model import ( 3 | LinearModel, 4 | SGDLasso, 5 | SGDLinearModel, 6 | SGDLinearRegression, 7 | SGDRidge, 8 | SkLearnLasso, 9 | SkLearnLinearModel, 10 | SkLearnLinearRegression, 11 | SkLearnRidge, 12 | ) 13 | 14 | __all__ = [ 15 | "LinearModel", 16 | "SGDLinearModel", 17 | "SGDLasso", 18 | "SGDRidge", 19 | "SGDLinearRegression", 20 | "SkLearnLinearModel", 21 | "SkLearnLasso", 22 | "SkLearnRidge", 23 | "SkLearnLinearRegression", 24 | ] 25 | -------------------------------------------------------------------------------- /captum/_utils/models/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Dict, Optional, Union 7 | 8 | from captum._utils.typing import TensorOrTupleOfTensorsGeneric 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader 11 | 12 | 13 | class Model(ABC): 14 | r""" 15 | Abstract Class to 
describe the interface of a trainable model to be used 16 | within the algorithms of captum. 17 | 18 | Please note that this is an experimental feature. 19 | """ 20 | 21 | @abstractmethod 22 | def fit( 23 | self, 24 | train_data: DataLoader, 25 | # pyre-fixme[2]: Parameter must be annotated. 26 | **kwargs, 27 | ) -> Optional[Dict[str, Union[int, float, Tensor]]]: 28 | r""" 29 | Override this method to actually train your model. 30 | 31 | The specification of the dataloader will be supplied by the algorithm 32 | you are using within captum. This will likely be a supervised learning 33 | task, thus you should expect batched (x, y) pairs or (x, y, w) triples. 34 | 35 | Args: 36 | train_data (DataLoader): 37 | The data to train on 38 | 39 | Returns: 40 | Optional statistics about training, e.g. iterations it took to 41 | train, training loss, etc. 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def representation(self) -> Tensor: 47 | r""" 48 | Returns the underlying representation of the interpretable model. For a 49 | linear model this is simply a tensor (the concatenation of weights 50 | and bias). For something slightly more complicated, such as a decision 51 | tree, this could be the nodes of a decision tree. 52 | 53 | Returns: 54 | A Tensor describing the representation of the model. 55 | """ 56 | pass 57 | 58 | @abstractmethod 59 | def __call__( 60 | self, x: TensorOrTupleOfTensorsGeneric 61 | ) -> TensorOrTupleOfTensorsGeneric: 62 | r""" 63 | Predicts with the interpretable model. 64 | 65 | Args: 66 | x (TensorOrTupleOfTensorsGeneric) 67 | A batched input of tensor(s) to the model to predict 68 | Returns: 69 | The prediction of the input as a TensorOrTupleOfTensorsGeneric. 
70 | """ 71 | pass 72 | -------------------------------------------------------------------------------- /captum/_utils/typing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from collections import UserDict 6 | from typing import ( 7 | List, 8 | Literal, 9 | Optional, 10 | overload, 11 | Protocol, 12 | Tuple, 13 | TYPE_CHECKING, 14 | TypeVar, 15 | Union, 16 | ) 17 | 18 | from torch import Tensor 19 | from torch.nn import Module 20 | 21 | TensorOrTupleOfTensorsGeneric = TypeVar( 22 | "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...] 23 | ) 24 | TupleOrTensorOrBoolGeneric = TypeVar( 25 | "TupleOrTensorOrBoolGeneric", Tuple[Tensor, ...], Tensor, bool 26 | ) 27 | PassThroughOutputType = TypeVar("PassThroughOutputType") 28 | ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module]) 29 | TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]] 30 | BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]] 31 | BaselineType = Union[None, Tensor, int, float, BaselineTupleType] 32 | 33 | TensorLikeList1D = List[float] 34 | TensorLikeList2D = List[TensorLikeList1D] 35 | TensorLikeList3D = List[TensorLikeList2D] 36 | TensorLikeList4D = List[TensorLikeList3D] 37 | TensorLikeList5D = List[TensorLikeList4D] 38 | TensorLikeList = Union[ 39 | TensorLikeList1D, 40 | TensorLikeList2D, 41 | TensorLikeList3D, 42 | TensorLikeList4D, 43 | TensorLikeList5D, 44 | ] 45 | 46 | try: 47 | # Subscripted slice syntax is not supported in previous Python versions, 48 | # falling back to slice type. 49 | SliceIntType = slice[int, int, int] 50 | except TypeError: 51 | # pyre-ignore[24]: Generic type `slice` expects 3 type parameters. 52 | SliceIntType = slice # type: ignore 53 | 54 | # Necessary for Python >=3.7 and <3.9! 
55 | if TYPE_CHECKING: 56 | BatchEncodingType = UserDict[Union[int, str], object] 57 | else: 58 | BatchEncodingType = UserDict 59 | 60 | 61 | class TokenizerLike(Protocol): 62 | """A protocol for tokenizer-like objects that can be used with Captum 63 | LLM attribution methods.""" 64 | 65 | @overload 66 | def encode( 67 | self, text: str, add_special_tokens: bool = ..., return_tensors: None = ... 68 | ) -> List[int]: ... 69 | 70 | @overload 71 | def encode( 72 | self, 73 | text: str, 74 | add_special_tokens: bool = ..., 75 | return_tensors: Literal["pt"] = ..., 76 | ) -> Tensor: ... 77 | 78 | def encode( 79 | self, 80 | text: str, 81 | add_special_tokens: bool = True, 82 | return_tensors: Optional[str] = None, 83 | ) -> Union[List[int], Tensor]: ... 84 | 85 | def decode(self, token_ids: Tensor) -> str: ... 86 | 87 | @overload 88 | def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ... 89 | @overload 90 | def convert_ids_to_tokens(self, token_ids: int) -> str: ... 91 | 92 | def convert_ids_to_tokens( 93 | self, token_ids: Union[List[int], int] 94 | ) -> Union[List[str], str]: ... 95 | 96 | @overload 97 | def convert_tokens_to_ids(self, tokens: str) -> int: ... 98 | @overload 99 | def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ... 100 | 101 | def convert_tokens_to_ids( 102 | self, tokens: Union[List[str], str] 103 | ) -> Union[List[int], int]: ... 104 | 105 | def __call__( 106 | self, 107 | text: Optional[Union[str, List[str], List[List[str]]]] = None, 108 | add_special_tokens: bool = True, 109 | return_offsets_mapping: bool = False, 110 | ) -> BatchEncodingType: ... 
111 | -------------------------------------------------------------------------------- /captum/attr/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/attr/_core/__init__.py -------------------------------------------------------------------------------- /captum/attr/_core/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/attr/_core/layer/__init__.py -------------------------------------------------------------------------------- /captum/attr/_core/neuron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/attr/_core/neuron/__init__.py -------------------------------------------------------------------------------- /captum/attr/_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/attr/_models/__init__.py -------------------------------------------------------------------------------- /captum/attr/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/attr/_utils/__init__.py -------------------------------------------------------------------------------- /captum/attr/_utils/baselines.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | # pyre-strict 4 | import random 5 | from typing import Any, Dict, List, Tuple, Union 6 | 7 | 8 | class ProductBaselines: 9 | """ 10 | A Callable Baselines class that returns a sample from the Cartesian product of 11 | the inputs' available baselines. 12 | 13 | Args: 14 | baseline_values (List or Dict): A list or dict of lists containing 15 | the possible values for each feature. If a dict is provided, the keys 16 | can a string of the feature name and the values is a list of available 17 | baselines. The keys can also be a tuple of strings to group 18 | multiple features whose baselines are not independent to each other. 19 | If the key is a tuple, the value must be a list of tuples of 20 | the corresponding values. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | # pyre-fixme[2]: Parameter annotation cannot contain `Any`. 26 | baseline_values: Union[ 27 | List[List[Any]], 28 | Dict[Union[str, Tuple[str, ...]], List[Any]], 29 | ], 30 | ) -> None: 31 | if isinstance(baseline_values, dict): 32 | dict_keys = list(baseline_values.keys()) 33 | baseline_values = [baseline_values[k] for k in dict_keys] 34 | else: 35 | dict_keys = [] 36 | 37 | # pyre-fixme[4]: Attribute must be annotated. 38 | self.dict_keys = dict_keys 39 | self.baseline_values = baseline_values 40 | 41 | # pyre-fixme[3]: Return annotation cannot contain `Any`. 42 | def sample(self) -> Union[List[Any], Dict[str, Any]]: 43 | baselines = [ 44 | random.choice(baseline_list) for baseline_list in self.baseline_values 45 | ] 46 | 47 | if not self.dict_keys: 48 | return baselines 49 | 50 | dict_baselines = {} 51 | for key, val in zip(self.dict_keys, baselines): 52 | if not isinstance(key, tuple): 53 | key, val = (key,), (val,) 54 | 55 | for k, v in zip(key, val): 56 | dict_baselines[k] = v 57 | 58 | return dict_baselines 59 | 60 | # pyre-fixme[3]: Return annotation cannot contain `Any`. 
61 | def __call__(self) -> Union[List[Any], Dict[str, Any]]: 62 | """ 63 | Returns: 64 | 65 | baselines (List or Dict): A sample from the Cartesian product of 66 | the inputs' available baselines (the same result type as `sample()`). 67 | """ 68 | return self.sample() 69 | -------------------------------------------------------------------------------- /captum/attr/_utils/custom_modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | class Addition_Module(nn.Module): 9 | """Custom addition module that takes both summands as separate inputs to ensure correct relevance 10 | propagation. Any addition in a forward function needs to be replaced with this 11 | module before using LRP.""" 12 | 13 | def __init__(self) -> None: 14 | super().__init__() 15 | 16 | def forward(self, x1: Tensor, x2: Tensor) -> Tensor: 17 | return x1 + x2 18 | -------------------------------------------------------------------------------- /captum/attr/_utils/input_layer_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | import inspect 6 | from typing import List 7 | 8 | import torch.nn as nn 9 | from torch import Tensor 10 | 11 | 12 | class InputIdentity(nn.Module): 13 | def __init__(self, input_name: str) -> None: 14 | r""" 15 | The identity operation: forward returns its input unchanged. 16 | 17 | Args: 18 | input_name (str) 19 | The name of the input this layer is associated with. Used for 20 | debugging purposes. 21 | """ 22 | super().__init__() 23 | self.input_name = input_name 24 | 25 | def forward(self, x: Tensor) -> Tensor: 26 | return x 27 | 28 | 29 | class ModelInputWrapper(nn.Module): 30 | def __init__(self, module_to_wrap: nn.Module) -> None: 31 | r""" 32 | This is a convenience class.
This wraps a model by first feeding the 33 | model's inputs to separate layers (one for each input) and then feeding 34 | the (unmodified) inputs to the underlying model (`module_to_wrap`). Each 35 | input is fed through an `InputIdentity` layer/module. This class does 36 | not change how you feed inputs to your model, so feel free to use your 37 | model as you normally would. 38 | 39 | To access a wrapped input layer, simply access it via the `input_maps` 40 | ModuleDict, e.g. to get the corresponding module for input "x", simply 41 | provide/write `my_wrapped_module.input_maps["x"]` 42 | 43 | This is done such that one can use layer attribution methods on inputs, 44 | which allows you to mix layers and inputs when using these 45 | attribution methods. This is especially useful for multimodal models which 46 | take discrete features (mapped to embeddings, such as text) and regular 47 | continuous feature vectors as input. 48 | 49 | Notes: 50 | - Since inputs are mapped with the identity, attributing to the 51 | input/feature can be done with either the input or output of the 52 | layer, i.e. attributing to an input/feature doesn't depend on whether 53 | attribute_to_layer_input is True or False for 54 | LayerIntegratedGradients. 55 | - Please refer to the multimodal tutorial or unit tests 56 | (test/attr/test_layer_wrapper.py) for an example.
57 | 58 | Args: 59 | module_to_wrap (nn.Module): 60 | The model/module you want to wrap 61 | """ 62 | super().__init__() 63 | self.module = module_to_wrap 64 | 65 | # ignore self 66 | self.arg_name_list: List[str] = inspect.getfullargspec( 67 | module_to_wrap.forward 68 | ).args[1:] 69 | self.input_maps = nn.ModuleDict( 70 | {arg_name: InputIdentity(arg_name) for arg_name in self.arg_name_list} 71 | ) 72 | 73 | def forward(self, *args: object, **kwargs: object) -> object: 74 | args_list = list(args) 75 | for idx, (arg_name, arg) in enumerate(zip(self.arg_name_list, args_list)): 76 | args_list[idx] = self.input_maps[arg_name](arg) 77 | 78 | for arg_name in kwargs.keys(): 79 | kwargs[arg_name] = self.input_maps[arg_name](kwargs[arg_name]) 80 | 81 | return self.module(*tuple(args_list), **kwargs) 82 | -------------------------------------------------------------------------------- /captum/concept/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | from captum.concept._core.cav import CAV 5 | from captum.concept._core.concept import Concept, ConceptInterpreter 6 | from captum.concept._core.tcav import TCAV 7 | from captum.concept._utils.classifier import Classifier, DefaultClassifier 8 | 9 | __all__ = [ 10 | "CAV", 11 | "Concept", 12 | "ConceptInterpreter", 13 | "TCAV", 14 | "Classifier", 15 | "DefaultClassifier", 16 | ] 17 | -------------------------------------------------------------------------------- /captum/concept/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/concept/_core/__init__.py -------------------------------------------------------------------------------- /captum/concept/_utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/concept/_utils/__init__.py -------------------------------------------------------------------------------- /captum/concept/_utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from typing import List 6 | 7 | from captum.concept._core.concept import Concept 8 | 9 | 10 | def concepts_to_str(concepts: List[Concept]) -> str: 11 | r""" 12 | Returns a string of hyphen("-") concatenated concept ids. 13 | Example output: "0-1-2" for concepts with ids 0, 1 and 2 14 | 15 | Args: 16 | concepts (list[Concept]): a list of Concept objects whose ids are 17 | concatenated and used as a concepts key. These are the 18 | Concept objects used to train 19 | the classifier. 20 | Returns: 21 | ids_str (str): A string of hyphen("-") concatenated 22 | concept ids, e.g. "0-1-2". 23 | """ 24 | 25 | return "-".join([str(c.id) for c in concepts]) 26 | -------------------------------------------------------------------------------- /captum/concept/_utils/data_iterator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | import glob 6 | import os 7 | from typing import Callable, Iterator 8 | 9 | from torch import Tensor 10 | from torch.utils.data import DataLoader, Dataset, IterableDataset 11 | 12 | 13 | class CustomIterableDataset(IterableDataset): 14 | r""" 15 | An IterableDataset that yields the tensors produced by `transform_filename_to_tensor` from the file(s) at `path`. 16 | """ 17 | 18 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 19 | def __init__(self, transform_filename_to_tensor: Callable, path: str) -> None: 20 | r""" 21 | Args: 22 | transform_filename_to_tensor (Callable): Function to read a data 23 | file from path and return a tensor from that file. 24 | path (str): Path to dataset files.
This can be either a path to a 25 | directory or a file where input examples are stored. 26 | """ 27 | # pyre-fixme[4]: Attribute must be annotated. 28 | self.file_itr = None 29 | self.path = path 30 | 31 | if os.path.isdir(self.path): 32 | self.file_itr = glob.glob(self.path + "*") 33 | 34 | self.transform_filename_to_tensor = transform_filename_to_tensor 35 | 36 | def __iter__(self) -> Iterator[Tensor]: 37 | r""" 38 | Returns: 39 | iter (Iterator[Tensor]): A map from a function that 40 | processes a list of file path(s) to a list of Tensors. 41 | """ 42 | if self.file_itr is not None: 43 | return map(self.transform_filename_to_tensor, self.file_itr) 44 | else: 45 | return self.transform_filename_to_tensor(self.path) 46 | 47 | 48 | def dataset_to_dataloader(dataset: Dataset, batch_size: int = 64) -> DataLoader: 49 | r""" 50 | An auxiliary function that creates torch DataLoader from torch Dataset 51 | using input `batch_size`. 52 | 53 | Args: 54 | dataset (Dataset): A torch dataset that allows to iterate over 55 | the batches of examples. 56 | batch_size (int, optional): Batch size of for each tensor in the 57 | iteration. 58 | 59 | Returns: 60 | dataloader_iter (DataLoader): a DataLoader for data iteration. 
61 | """ 62 | 63 | return DataLoader(dataset, batch_size=batch_size) 64 | -------------------------------------------------------------------------------- /captum/influence/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from captum.influence._core.influence import DataInfluence 6 | from captum.influence._core.influence_function import NaiveInfluenceFunction 7 | from captum.influence._core.similarity_influence import SimilarityInfluence 8 | from captum.influence._core.tracincp import TracInCP, TracInCPBase 9 | from captum.influence._core.tracincp_fast_rand_proj import ( 10 | TracInCPFast, 11 | TracInCPFastRandProj, 12 | ) 13 | 14 | __all__ = [ 15 | "DataInfluence", 16 | "SimilarityInfluence", 17 | "TracInCPBase", 18 | "TracInCP", 19 | "TracInCPFast", 20 | "TracInCPFastRandProj", 21 | "NaiveInfluenceFunction", 22 | ] 23 | -------------------------------------------------------------------------------- /captum/influence/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/influence/_core/__init__.py -------------------------------------------------------------------------------- /captum/influence/_core/influence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Any, Type 7 | 8 | from torch.nn import Module 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class DataInfluence(ABC): 13 | r""" 14 | An abstract base class defining the skeleton of data influence methods. 15 | """ 16 | 17 | def __init_(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:  # NOTE(review): single trailing underscore, so Python never calls this as the constructor (the pyre-fixme[16] comments below are a symptom); presumably a typo of `__init__` -- confirm no subclass depends on the current behavior before renaming 18 | r""" 19 | Args: 20 | model (torch.nn.Module): An instance of pytorch model.
21 | train_dataset (torch.utils.data.Dataset): PyTorch Dataset that is 22 | used to create a PyTorch Dataloader to iterate over the dataset and 23 | its labels. This is the dataset for which we will be seeking 24 | influential instances. In most cases this is the training dataset. 25 | **kwargs: Additional key-value arguments that are necessary for specific 26 | implementation of `DataInfluence` abstract class. 27 | """ 28 | # pyre-fixme[16]: `DataInfluence` has no attribute `model`. 29 | self.model = model 30 | # pyre-fixme[16]: `DataInfluence` has no attribute `train_dataset`. 31 | self.train_dataset = train_dataset 32 | 33 | @abstractmethod 34 | # pyre-fixme[3]: Return annotation cannot be `Any`. 35 | # pyre-fixme[2]: Parameter annotation cannot be `Any`. 36 | def influence(self, inputs: Any = None, **kwargs: Any) -> Any: 37 | r""" 38 | Args: 39 | inputs (Any): Batch of examples for which influential 40 | instances are computed. They are passed to the forward_func. If 41 | `inputs` is a tensor or tuple of tensors, the first dimension 42 | of a tensor corresponds to the batch dimension. 43 | **kwargs: Additional key-value arguments that are necessary for specific 44 | implementation of `DataInfluence` abstract class. 45 | 46 | Returns: 47 | influences (Any): We do not add restrictions on the return type for now, 48 | though this may change in the future. 49 | """ 50 | pass 51 | 52 | @classmethod 53 | def get_name(cls: Type["DataInfluence"]) -> str: 54 | r""" 55 | Create readable class name. Due to the nature of the names of `TracInCPBase` 56 | subclasses, simply returns the class name. For example, for a class called 57 | TracInCP, we return the string TracInCP.
58 | 59 | Returns: 60 | name (str): a readable class name 61 | """ 62 | return cls.__name__ 63 | -------------------------------------------------------------------------------- /captum/influence/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/influence/_utils/__init__.py -------------------------------------------------------------------------------- /captum/insights/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from captum.insights.attr_vis import AttributionVisualizer, Batch, features 3 | 4 | __all__ = [ 5 | "AttributionVisualizer", 6 | "Batch", 7 | "features", 8 | ] 9 | -------------------------------------------------------------------------------- /captum/insights/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/insights/_utils/__init__.py -------------------------------------------------------------------------------- /captum/insights/attr_vis/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from captum.insights.attr_vis.app import AttributionVisualizer, Batch 3 | 4 | __all__ = [ 5 | "AttributionVisualizer", 6 | "Batch", 7 | ] 8 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/insights/attr_vis/_utils/__init__.py -------------------------------------------------------------------------------- /captum/insights/attr_vis/_utils/transforms.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from typing import Callable, List, Optional, Union 6 | 7 | 8 | def format_transforms( 9 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 10 | transforms: Optional[Union[Callable, List[Callable]]], 11 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 12 | ) -> List[Callable]: 13 | if transforms is None: 14 | return [] 15 | if callable(transforms): 16 | return [transforms] 17 | return transforms 18 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union 5 | 6 | from captum.attr._core import ( 7 | deep_lift, 8 | feature_ablation, 9 | guided_backprop_deconvnet, 10 | input_x_gradient, 11 | integrated_gradients, 12 | occlusion, 13 | saliency, 14 | ) 15 | from captum.attr._utils.approximation_methods import SUPPORTED_METHODS 16 | 17 | 18 | class NumberConfig(NamedTuple): 19 | value: int = 1 20 | limit: Tuple[Optional[int], Optional[int]] = (None, None) 21 | type: str = "number" 22 | 23 | 24 | class StrEnumConfig(NamedTuple): 25 | value: str 26 | limit: List[str] 27 | type: str = "enum" 28 | 29 | 30 | class StrConfig(NamedTuple): 31 | value: str 32 | type: str = "string" 33 | 34 | 35 | Config = Union[NumberConfig, StrEnumConfig, StrConfig] 36 | 37 | SUPPORTED_ATTRIBUTION_METHODS = [ 38 | guided_backprop_deconvnet.Deconvolution, 39 | deep_lift.DeepLift, 40 | guided_backprop_deconvnet.GuidedBackprop, 41 | input_x_gradient.InputXGradient, 42 | integrated_gradients.IntegratedGradients, 43 | saliency.Saliency, 44 | feature_ablation.FeatureAblation, 45 | occlusion.Occlusion, 46 | ] 47 | 48 | 49 | # pyre-fixme[2]: 
Parameter annotation cannot contain `Any`. 50 | class ConfigParameters(NamedTuple): 51 | params: Dict[str, Config] 52 | help_info: Optional[str] = None # TODO fill out help for each method 53 | # pyre-fixme[4]: Attribute annotation cannot contain `Any`. 54 | post_process: Optional[Dict[str, Callable[[Any], Any]]] = None 55 | 56 | 57 | ATTRIBUTION_NAMES_TO_METHODS: Dict[ 58 | str, 59 | Type[ 60 | Union[ 61 | deep_lift.DeepLift, 62 | feature_ablation.FeatureAblation, 63 | guided_backprop_deconvnet.Deconvolution, 64 | guided_backprop_deconvnet.GuidedBackprop, 65 | input_x_gradient.InputXGradient, 66 | integrated_gradients.IntegratedGradients, 67 | saliency.Saliency, 68 | ] 69 | ], 70 | ] = { 71 | # mypy bug - treating it as a type instead of a class 72 | cls.get_name(): cls # type: ignore 73 | for cls in SUPPORTED_ATTRIBUTION_METHODS 74 | } 75 | 76 | 77 | def _str_to_tuple(s: Union[str, Tuple[int, ...]]) -> Tuple[int, ...]: 78 | if isinstance(s, tuple):  # already parsed: pass through unchanged 79 | return s 80 | return tuple([int(i) for i in s.split()])  # whitespace-separated ints, e.g. "3 15 15" -> (3, 15, 15); "" -> () 81 | 82 | 83 | ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = { 84 | integrated_gradients.IntegratedGradients.get_name(): ConfigParameters( 85 | params={ 86 | "n_steps": NumberConfig(value=25, limit=(2, None)), 87 | "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), 88 | }, 89 | post_process={"n_steps": int}, 90 | ), 91 | feature_ablation.FeatureAblation.get_name(): ConfigParameters( 92 | params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}, 93 | ), 94 | occlusion.Occlusion.get_name(): ConfigParameters( 95 | params={ 96 | "sliding_window_shapes": StrConfig(value=""), 97 | "strides": StrConfig(value=""), 98 | "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)), 99 | }, 100 | post_process={ 101 | "sliding_window_shapes": _str_to_tuple, 102 | "strides": _str_to_tuple, 103 | "perturbations_per_eval": int, 104 | }, 105 | ), 106 | } 107 |
-------------------------------------------------------------------------------- /captum/insights/attr_vis/example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | import os 5 | from typing import List 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from captum.insights import AttributionVisualizer, Batch 12 | 13 | from captum.insights.attr_vis.features import ImageFeature 14 | 15 | 16 | def get_classes() -> List[str]: 17 | classes = [ 18 | "Plane", 19 | "Car", 20 | "Bird", 21 | "Cat", 22 | "Deer", 23 | "Dog", 24 | "Frog", 25 | "Horse", 26 | "Ship", 27 | "Truck", 28 | ] 29 | return classes 30 | 31 | 32 | class Net(nn.Module): 33 | def __init__(self) -> None: 34 | super(Net, self).__init__() 35 | self.conv1 = nn.Conv2d(3, 6, 5) 36 | self.pool1 = nn.MaxPool2d(2, 2) 37 | self.pool2 = nn.MaxPool2d(2, 2) 38 | self.conv2 = nn.Conv2d(6, 16, 5) 39 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 40 | self.fc2 = nn.Linear(120, 84) 41 | self.fc3 = nn.Linear(84, 10) 42 | self.relu1 = nn.ReLU() 43 | self.relu2 = nn.ReLU() 44 | self.relu3 = nn.ReLU() 45 | self.relu4 = nn.ReLU() 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | x = self.pool1(self.relu1(self.conv1(x))) 49 | x = self.pool2(self.relu2(self.conv2(x))) 50 | x = x.view(-1, 16 * 5 * 5) 51 | x = self.relu3(self.fc1(x)) 52 | x = self.relu4(self.fc2(x)) 53 | x = self.fc3(x) 54 | return x 55 | 56 | 57 | def get_pretrained_model() -> Net: 58 | net = Net() 59 | pt_path = os.path.abspath( 60 | os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt") 61 | ) 62 | net.load_state_dict(torch.load(pt_path)) 63 | return net 64 | 65 | 66 | # pyre-fixme[3]: Return type must be annotated. 67 | # pyre-fixme[2]: Parameter must be annotated. 
68 | def baseline_func(input): 69 | return input * 0 70 | 71 | 72 | # pyre-fixme[3]: Return type must be annotated. 73 | def formatted_data_iter(): 74 | dataset = torchvision.datasets.CIFAR10( 75 | root="data/test", train=False, download=True, transform=transforms.ToTensor() 76 | ) 77 | dataloader = iter( 78 | torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2) 79 | ) 80 | while True: 81 | images, labels = next(dataloader) 82 | yield Batch(inputs=images, labels=labels) 83 | 84 | 85 | def main() -> None: 86 | normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 87 | model = get_pretrained_model() 88 | visualizer = AttributionVisualizer( 89 | models=[model], 90 | score_func=lambda o: torch.nn.functional.softmax(o, 1), 91 | classes=get_classes(), 92 | features=[ 93 | ImageFeature( 94 | "Photo", 95 | baseline_transforms=[baseline_func], 96 | input_transforms=[normalize], 97 | ) 98 | ], 99 | dataset=formatted_data_iter(), 100 | ) 101 | 102 | visualizer.serve(debug=True) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Captum Insights 2 | 3 | ## Development 4 | 5 | Captum Insights frontend is built on top of React, using create-react-app. 6 | 7 | `yarn build` to compile files. 8 | 9 | Javascript should be prettier formatted. 10 | CSS uses BEM syntax. 
11 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "frontend", 3 | "version": "0.8.0", 4 | "private": true, 5 | "homepage": ".", 6 | "dependencies": { 7 | "@babel/helper-call-delegate": "^7.8.7", 8 | "@babel/plugin-proposal-class-properties": "^7.5.5", 9 | "@babel/preset-typescript": "^7.10.1", 10 | "@types/jest": "^26.0.0", 11 | "@types/node": "^14.0.13", 12 | "@types/react": "^16.9.38", 13 | "@types/react-dom": "^16.9.8", 14 | "@types/react-tag-autocomplete": "^5.12.0", 15 | "babel-loader": "^8.0.6", 16 | "chart.js": "^2.9.4", 17 | "css-loader": "3.3.0", 18 | "js-levenshtein": "^1.1.6", 19 | "react": "^16.9.0", 20 | "react-chartjs-2": "^2.11.1", 21 | "react-dom": "^16.9.0", 22 | "react-scripts": "3.4.1", 23 | "react-tag-autocomplete": "^5.11.1", 24 | "typescript": "^3.9.5", 25 | "webpack-cli": "^3.3.9" 26 | }, 27 | "scripts": { 28 | "start": "react-scripts start", 29 | "build": "react-scripts build", 30 | "test": "react-scripts test", 31 | "eject": "react-scripts eject" 32 | }, 33 | "eslintConfig": { 34 | "extends": "react-app" 35 | }, 36 | "browserslist": { 37 | "production": [ 38 | ">0.2%", 39 | "not dead", 40 | "not op_mini all" 41 | ], 42 | "development": [ 43 | "last 1 chrome version", 44 | "last 1 firefox version", 45 | "last 1 safari version" 46 | ] 47 | }, 48 | "devDependencies": { 49 | "@types/chart.js": "^2.9.29" 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Captum Insights 7 | 8 | 9 | 10 |
11 | 12 | 13 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/App.css: -------------------------------------------------------------------------------- 1 | .react-tags { 2 | position: relative; 3 | padding: 6px 0 0 6px; 4 | border: 1px solid #d1d1d1; 5 | border-radius: 1px; 6 | 7 | /* shared font styles */ 8 | font-size: 12px; 9 | line-height: 1.2; 10 | 11 | /* clicking anywhere will focus the input */ 12 | cursor: text; 13 | } 14 | 15 | .react-tags.is-focused { 16 | border-color: #b1b1b1; 17 | } 18 | 19 | .react-tags__selected { 20 | display: inline; 21 | } 22 | 23 | .react-tags__selected-tag { 24 | display: inline-block; 25 | box-sizing: border-box; 26 | margin: 0 6px 6px 0; 27 | padding: 6px 8px; 28 | border: 1px solid #d1d1d1; 29 | border-radius: 2px; 30 | background: #f1f1f1; 31 | 32 | /* match the font styles */ 33 | font-size: inherit; 34 | line-height: inherit; 35 | } 36 | 37 | .react-tags__selected-tag:after { 38 | content: "\2715"; 39 | color: #aaa; 40 | margin-left: 8px; 41 | } 42 | 43 | .react-tags__selected-tag:hover, 44 | .react-tags__selected-tag:focus { 45 | border-color: #b1b1b1; 46 | } 47 | 48 | .react-tags__search { 49 | display: inline-block; 50 | 51 | /* match tag layout */ 52 | padding: 7px 2px; 53 | margin-bottom: 6px; 54 | 55 | /* prevent autoresize overflowing the container */ 56 | max-width: 100%; 57 | } 58 | 59 | @media screen and (min-width: 30em) { 60 | .react-tags__search { 61 | /* this will become the offsetParent for suggestions */ 62 | position: relative; 63 | } 64 | } 65 | 66 | .react-tags__search input { 67 | /* prevent autoresize overflowing the container */ 68 | max-width: 100%; 69 | 70 | /* remove styles and layout from this element */ 71 | margin: 0; 72 | padding: 0; 73 | border: 0; 74 | outline: none; 75 | 76 | /* match the font styles */ 77 | font-size: inherit; 78 | line-height: inherit; 79 | } 80 | 81 | .react-tags__search input::-ms-clear { 82 | 
display: none; 83 | } 84 | 85 | .react-tags__suggestions { 86 | z-index: 9999999; 87 | position: absolute; 88 | top: 100%; 89 | left: 0; 90 | width: 100%; 91 | } 92 | 93 | @media screen and (min-width: 30em) { 94 | .react-tags__suggestions { 95 | width: 100px; 96 | } 97 | } 98 | 99 | .react-tags__suggestions ul { 100 | margin: 4px -1px; 101 | padding: 0; 102 | list-style: none; 103 | background: rgba(255, 255, 255, 0.95); 104 | border: 1px solid #d1d1d1; 105 | border-radius: 2px; 106 | box-shadow: 0 2px 6px rgba(0, 0, 0, 0.2); 107 | } 108 | 109 | .react-tags__suggestions li { 110 | border-bottom: 1px solid #ddd; 111 | padding: 6px 8px; 112 | } 113 | 114 | .react-tags__suggestions li mark { 115 | text-decoration: underline; 116 | background: none; 117 | font-weight: 600; 118 | } 119 | 120 | .react-tags__suggestions li:hover { 121 | cursor: pointer; 122 | background: #eee; 123 | } 124 | 125 | .react-tags__suggestions li.is-active { 126 | background: #b7cfe0; 127 | } 128 | 129 | .react-tags__suggestions li.is-disabled { 130 | opacity: 0.5; 131 | cursor: auto; 132 | } 133 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/App.test.js: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import ReactDOM from "react-dom"; 3 | import WebApp from "./WebApp"; 4 | 5 | it("renders without crashing", () => { 6 | const div = document.createElement("div"); 7 | ReactDOM.render(, div); 8 | ReactDOM.unmountComponentAtNode(div); 9 | }); 10 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/App.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "./App.module.css"; 3 | import Header from "./components/Header"; 4 | import cx from "./utils/cx"; 5 | import Spinner from "./components/Spinner"; 6 | 
import FilterContainer from "./components/FilterContainer"; 7 | import VisualizationGroupDisplay from "./components/VisualizationGroup"; 8 | import "./App.css"; 9 | import { VisualizationGroup } from "./models/visualizationOutput"; 10 | import { FilterConfig } from "./models/filter"; 11 | 12 | interface VisualizationsProps { 13 | loading: boolean; 14 | data: VisualizationGroup[]; 15 | onTargetClick: ( 16 | labelIndex: number, 17 | inputIndex: number, 18 | modelIndex: number, 19 | callback: () => void 20 | ) => void; 21 | } 22 | 23 | function Visualizations(props: VisualizationsProps) { 24 | if (props.loading) { 25 | return ( 26 |
27 |
28 | 29 |
30 |
31 | ); 32 | } 33 | 34 | if (!props.data || props.data.length === 0) { 35 | return ( 36 |
37 |
38 |
39 | Please press{" "} 40 | Fetch to 41 | start loading data. 42 |
43 |
44 |
45 | ); 46 | } 47 | return ( 48 |
49 | {props.data.map((vg, i) => ( 50 | 56 | ))} 57 |
58 | ); 59 | } 60 | 61 | interface AppBaseProps { 62 | fetchInit: () => void; 63 | fetchData: (filter_config: FilterConfig) => void; 64 | config: any; 65 | data: VisualizationGroup[]; 66 | loading: boolean; 67 | onTargetClick: ( 68 | labelIndex: number, 69 | inputIndex: number, 70 | modelIndex: number, 71 | callback: () => void 72 | ) => void; 73 | } 74 | 75 | class AppBase extends React.Component { 76 | componentDidMount() { 77 | this.props.fetchInit(); 78 | } 79 | 80 | render() { 81 | return ( 82 |
83 |
84 | 89 | 94 |
95 | ); 96 | } 97 | } 98 | 99 | export default AppBase; 100 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/WebApp.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import AppBase from "./App"; 3 | import { FilterConfig } from "./models/filter"; 4 | import { VisualizationGroup } from "./models/visualizationOutput"; 5 | import { InsightsConfig } from "./models/insightsConfig"; 6 | 7 | interface WebAppState { 8 | data: VisualizationGroup[]; 9 | config: InsightsConfig; 10 | loading: boolean; 11 | } 12 | 13 | class WebApp extends React.Component<{}, WebAppState> { 14 | constructor(props: {}) { 15 | super(props); 16 | this.state = { 17 | data: [], 18 | config: { 19 | classes: [], 20 | methods: [], 21 | method_arguments: {}, 22 | selected_method: "", 23 | }, 24 | loading: false, 25 | }; 26 | this._fetchInit(); 27 | } 28 | 29 | _fetchInit = () => { 30 | fetch("init") 31 | .then((r) => r.json()) 32 | .then((r) => this.setState({ config: r })); 33 | }; 34 | 35 | fetchData = (filter_config: FilterConfig) => { 36 | this.setState({ loading: true }); 37 | fetch("fetch", { 38 | method: "POST", 39 | headers: { 40 | "Content-Type": "application/json", 41 | }, 42 | body: JSON.stringify(filter_config), 43 | }) 44 | .then((response) => response.json()) 45 | .then((response) => this.setState({ data: response, loading: false })); 46 | }; 47 | 48 | onTargetClick = ( 49 | labelIndex: number, 50 | inputIndex: number, 51 | modelIndex: number, 52 | callback: () => void 53 | ) => { 54 | fetch("attribute", { 55 | method: "POST", 56 | headers: { 57 | "Content-Type": "application/json", 58 | }, 59 | body: JSON.stringify({ labelIndex, inputIndex, modelIndex }), 60 | }) 61 | .then((response) => response.json()) 62 | .then((response) => { 63 | const data = this.state.data ?? 
[]; 64 | data[inputIndex][modelIndex] = response; 65 | this.setState({ data }); 66 | callback(); 67 | }); 68 | }; 69 | 70 | render() { 71 | return ( 72 | 80 | ); 81 | } 82 | } 83 | 84 | export default WebApp; 85 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/Arguments.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "../App.module.css"; 3 | import cx from "../utils/cx"; 4 | import { GenericArgumentConfig } from "../models/insightsConfig"; 5 | import { UserInputField } from "../models/typeHelpers"; 6 | 7 | interface ArgumentProps { 8 | name: string; 9 | handleInputChange: React.ChangeEventHandler; 10 | } 11 | 12 | function NumberArgument(props: ArgumentProps & GenericArgumentConfig) { 13 | var min = props.limit[0]; 14 | var max = props.limit[1]; 15 | return ( 16 |
17 | {props.name}: 18 | 27 |
28 | ); 29 | } 30 | 31 | function EnumArgument(props: ArgumentProps & GenericArgumentConfig) { 32 | const options = props.limit.map((item, key) => ( 33 | 34 | )); 35 | return ( 36 |
37 | {props.name}: 38 | 46 |
47 | ); 48 | } 49 | 50 | function StringArgument(props: ArgumentProps & { value: string }) { 51 | return ( 52 |
53 | {props.name}: 54 | 61 |
62 | ); 63 | } 64 | 65 | export { StringArgument, EnumArgument, NumberArgument }; 66 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/ClassFilter.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import ReactTags from "react-tag-autocomplete"; 3 | import { TagClass } from "../models/filter"; 4 | 5 | interface ClassFilterProps { 6 | suggestedClasses: TagClass[]; 7 | classes: TagClass[]; 8 | handleClassDelete: (classId: number) => void; 9 | handleClassAdd: (newClass: TagClass) => void; 10 | } 11 | 12 | function ClassFilter(props: ClassFilterProps) { 13 | const handleAddition = (newTag: { id: number | string; name: string }) => { 14 | /** 15 | * Need this type check as we expect tagId to be number while the `react-tag-autocomplete` has 16 | * id as number | string. 17 | */ 18 | if (typeof newTag.id === "string") { 19 | throw Error("Invalid tag id received from ReactTags"); 20 | } else { 21 | props.handleClassAdd({ id: newTag.id, name: newTag.name }); 22 | } 23 | }; 24 | 25 | return ( 26 | 35 | ); 36 | } 37 | 38 | export default ClassFilter; 39 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/Contributions.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "../App.module.css"; 3 | import { calcHSLFromScore } from "../utils/color"; 4 | import { FeatureOutput } from "../models/visualizationOutput"; 5 | 6 | interface ContributionsProps { 7 | feature_outputs: FeatureOutput[]; 8 | } 9 | 10 | function Contributions(props: ContributionsProps) { 11 | return ( 12 | <> 13 | {props.feature_outputs.map((f) => { 14 | // pad bar height so features with 0 contribution can still be seen 15 | // in graph 16 | const contribution = f.contribution * 100; 17 | const 
bar_height = contribution > 10 ? contribution : contribution + 10; 18 | return ( 19 |
20 |
27 |
{f.name}
28 |
29 | ); 30 | })} 31 | 32 | ); 33 | } 34 | 35 | export default Contributions; 36 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/Header.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "../App.module.css"; 3 | import cx from "../utils/cx"; 4 | 5 | function Header() { 6 | return ( 7 |
8 |
Captum Insights
9 | 21 |
22 | ); 23 | } 24 | 25 | export default Header; 26 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/LabelButton.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import cx from "../utils/cx"; 3 | import styles from "../App.module.css"; 4 | 5 | interface LabelButtonProps { 6 | labelIndex: number; 7 | inputIndex: number; 8 | modelIndex: number; 9 | active: boolean; 10 | onTargetClick: ( 11 | labelIndex: number, 12 | inputIndex: number, 13 | modelIndex: number 14 | ) => void; 15 | } 16 | 17 | function LabelButton(props: React.PropsWithChildren) { 18 | const onClick = (e: React.MouseEvent) => { 19 | e.preventDefault(); 20 | props.onTargetClick(props.labelIndex, props.inputIndex, props.modelIndex); 21 | }; 22 | 23 | return ( 24 | 34 | ); 35 | } 36 | 37 | export default LabelButton; 38 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/Spinner.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "../App.module.css"; 3 | 4 | function Spinner() { 5 | return
; 6 | } 7 | 8 | export default Spinner; 9 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/Tooltip.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | import styles from "../App.module.css"; 4 | 5 | function Tooltip(props: { label: string }) { 6 | return ( 7 |
8 |
{props.label}
9 |
10 | ); 11 | } 12 | 13 | export default Tooltip; 14 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/components/VisualizationGroup.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import styles from "../App.module.css"; 3 | import cx from "../utils/cx"; 4 | import Visualization from "../components/Visualization"; 5 | import { VisualizationGroup } from "../models/visualizationOutput"; 6 | 7 | interface VisualizationGroupDisplayProps { 8 | inputIndex: number; 9 | data: VisualizationGroup; 10 | onTargetClick: ( 11 | labelIndex: number, 12 | inputIndex: number, 13 | modelIndex: number, 14 | callback: () => void 15 | ) => void; 16 | } 17 | 18 | function VisualizationGroupDisplay(props: VisualizationGroupDisplayProps) { 19 | return ( 20 |
26 | {props.data.map((v, i) => ( 27 | 33 | ))} 34 |
35 | ); 36 | } 37 | 38 | export default VisualizationGroupDisplay; 39 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/index.css: -------------------------------------------------------------------------------- 1 | html, 2 | body { 3 | height: 100%; 4 | } 5 | 6 | body { 7 | margin: 0; 8 | font-family: "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", 9 | "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif; 10 | -webkit-font-smoothing: antialiased; 11 | -moz-osx-font-smoothing: grayscale; 12 | background-color: rgba(0, 0, 0, 0.05); 13 | } 14 | 15 | * { 16 | box-sizing: border-box; 17 | } 18 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/index.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import ReactDOM from "react-dom"; 3 | import "./index.css"; 4 | import WebApp from "./WebApp"; 5 | 6 | ReactDOM.render(, document.getElementById("root")); 7 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/models/filter.ts: -------------------------------------------------------------------------------- 1 | export interface FilterConfig { 2 | attribution_method: string; 3 | arguments: { [key: string]: any }; 4 | prediction: string; 5 | classes: string[]; 6 | } 7 | 8 | export interface TagClass { 9 | id: number; 10 | name: string; 11 | } 12 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/models/insightsConfig.ts: -------------------------------------------------------------------------------- 1 | export enum ArgumentType { 2 | Number = "number", 3 | Enum = "enum", 4 | String = "string", 5 | Boolean = "boolean", 6 | } 7 | 8 | export type GenericArgumentConfig = { 9 | value: T; 10 | limit: T[]; 11 | }; 12 | 13 | 
export type ArgumentConfig = 14 | | ({ type: ArgumentType.Number } & GenericArgumentConfig) 15 | | ({ type: ArgumentType.Enum } & GenericArgumentConfig) 16 | | ({ type: ArgumentType.String } & { value: string }) 17 | | ({ type: ArgumentType.Boolean } & { value: boolean }); 18 | 19 | export interface MethodsArguments { 20 | [method_name: string]: { 21 | [arg_name: string]: ArgumentConfig; 22 | }; 23 | } 24 | 25 | export interface InsightsConfig { 26 | classes: string[]; 27 | methods: string[]; 28 | method_arguments: MethodsArguments; 29 | selected_method: string; 30 | } 31 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/models/typeHelpers.ts: -------------------------------------------------------------------------------- 1 | export type UserInputField = HTMLInputElement | HTMLSelectElement; 2 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/models/visualizationOutput.ts: -------------------------------------------------------------------------------- 1 | interface OutputScore { 2 | label: string; 3 | index: number; 4 | score: number; 5 | } 6 | 7 | export enum FeatureType { 8 | TEXT = "text", 9 | IMAGE = "image", 10 | GENERAL = "general", 11 | EMPTY = "empty", 12 | } 13 | 14 | type GenericFeatureOutput = { 15 | type: F; 16 | name: string; 17 | contribution: number; 18 | } & T; 19 | 20 | export type FeatureOutput = 21 | | GenericFeatureOutput< 22 | FeatureType.TEXT, 23 | { base: number[]; modified: number[] } 24 | > 25 | | GenericFeatureOutput 26 | | GenericFeatureOutput< 27 | FeatureType.GENERAL, 28 | { base: number[]; modified: number[] } 29 | > 30 | | GenericFeatureOutput; 31 | 32 | export interface VisualizationOutput { 33 | model_index: number; 34 | feature_outputs: FeatureOutput[]; 35 | actual: OutputScore; 36 | predicted: OutputScore[]; 37 | active_index: number; 38 | } 39 | 40 | //When multiple models are compared, 
visualizations are grouped together 41 | export type VisualizationGroup = VisualizationOutput[]; 42 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/utils/color.ts: -------------------------------------------------------------------------------- 1 | function calcHSLFromScore(percentage: number, zeroDefault = false) { 2 | const blue_hsl = [220, 100, 80]; 3 | const red_hsl = [10, 100, 67]; 4 | 5 | let target_hsl = null; 6 | if (percentage > 0) { 7 | target_hsl = blue_hsl; 8 | } else { 9 | target_hsl = red_hsl; 10 | } 11 | 12 | const default_hsl = [0, 40, zeroDefault ? 100 : 90]; 13 | const abs_percent = Math.abs(percentage * 0.01); 14 | if (abs_percent < 0.02) { 15 | return `hsl(${default_hsl[0]}, ${default_hsl[1]}%, ${default_hsl[2]}%)`; 16 | } 17 | 18 | const color = [ 19 | target_hsl[0], 20 | (target_hsl[1] - default_hsl[1]) * abs_percent + default_hsl[1], 21 | (target_hsl[2] - default_hsl[2]) * abs_percent + default_hsl[2], 22 | ]; 23 | return `hsl(${color[0]}, ${color[1]}%, ${color[2]}%)`; 24 | } 25 | 26 | export { calcHSLFromScore }; 27 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/utils/cx.ts: -------------------------------------------------------------------------------- 1 | // helper method to convert an array or object into a valid classname 2 | function cx(obj: any) { 3 | if (Array.isArray(obj)) { 4 | return obj.join(" "); 5 | } 6 | return Object.keys(obj) 7 | .filter((k) => !!obj[k]) 8 | .join(" "); 9 | } 10 | 11 | export default cx; 12 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/src/utils/dataPoint.ts: 
-------------------------------------------------------------------------------- 1 | import * as chartjs from "chart.js"; 2 | 3 | // Because there's no data point type exported by the 4 | // main type declaration for chart.js, we have our own. 5 | 6 | export interface DataPoint { 7 | chart?: object; 8 | dataIndex?: number; 9 | dataset?: chartjs.ChartDataSets; 10 | datasetIndex?: number; 11 | } 12 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": ["dom", "dom.iterable", "esnext"], 5 | "allowJs": true, 6 | "skipLibCheck": true, 7 | "esModuleInterop": true, 8 | "allowSyntheticDefaultImports": true, 9 | "strict": true, 10 | "forceConsistentCasingInFileNames": true, 11 | "module": "esnext", 12 | "moduleResolution": "node", 13 | "resolveJsonModule": true, 14 | "isolatedModules": true, 15 | "noEmit": true, 16 | "jsx": "react" 17 | }, 18 | "include": ["src"] 19 | } 20 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/widget/src/Widget.js: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import ReactDOM from "react-dom"; 3 | import AppBase from "../../src/App"; 4 | import * as widgets from "@jupyter-widgets/base"; 5 | import * as _ from "lodash"; 6 | 7 | class Widget extends React.Component { 8 | constructor(props) { 9 | super(props); 10 | this.state = { 11 | data: [], 12 | config: { 13 | classes: [], 14 | methods: [], 15 | method_arguments: {}, 16 | }, 17 | loading: false, 18 | callback: null, 19 | }; 20 | this.backbone = this.props.backbone; 21 | } 22 | 23 | componentDidMount() { 24 | this.backbone.model.on("change:output", this._outputChanged, this); 25 | this.backbone.model.on( 26 | "change:attribution", 27 | 
this._attributionChanged, 28 | this 29 | ); 30 | } 31 | 32 | _outputChanged(model, output, options) { 33 | if (_.isEmpty(output)) return; 34 | this.setState({ data: output, loading: false }); 35 | } 36 | 37 | _attributionChanged(model, attribution, options) { 38 | if (_.isEmpty(attribution)) return; 39 | const data = Object.assign([], this.state.data); 40 | const callback = this.state.callback; 41 | const labelDetails = model.attributes.label_details; 42 | data[labelDetails.inputIndex][labelDetails.modelIndex] = attribution; 43 | this.setState({ data: data, callback: null }, () => { 44 | callback(); 45 | }); 46 | } 47 | 48 | _fetchInit = () => { 49 | this.setState({ 50 | config: this.backbone.model.get("insights_config"), 51 | }); 52 | }; 53 | 54 | fetchData = (filterConfig) => { 55 | this.setState({ loading: true }, () => { 56 | this.backbone.model.save({ config: filterConfig, output: [] }); 57 | }); 58 | }; 59 | 60 | onTargetClick = (labelIndex, inputIndex, modelIndex, callback) => { 61 | this.setState({ callback: callback }, () => { 62 | this.backbone.model.save({ 63 | label_details: { labelIndex, inputIndex, modelIndex }, 64 | attribution: {}, 65 | }); 66 | }); 67 | }; 68 | 69 | render() { 70 | return ( 71 | 79 | ); 80 | } 81 | } 82 | 83 | var CaptumInsightsModel = widgets.DOMWidgetModel.extend({ 84 | defaults: _.extend(widgets.DOMWidgetModel.prototype.defaults(), { 85 | _model_name: "CaptumInsightsModel", 86 | _view_name: "CaptumInsightsView", 87 | _model_module: "jupyter-captum-insights", 88 | _view_module: "jupyter-captum-insights", 89 | _model_module_version: "0.1.0", 90 | _view_module_version: "0.1.0", 91 | }), 92 | }); 93 | 94 | var CaptumInsightsView = widgets.DOMWidgetView.extend({ 95 | initialize() { 96 | const $app = document.createElement("div"); 97 | ReactDOM.render(, $app); 98 | this.el.append($app); 99 | }, 100 | }); 101 | 102 | export { Widget as default, CaptumInsightsModel, CaptumInsightsView }; 103 | 
-------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/widget/src/embed.js: -------------------------------------------------------------------------------- 1 | // Entry point for the unpkg bundle containing custom model definitions. 2 | // 3 | // It differs from the notebook bundle in that it does not need to define a 4 | // dynamic baseURL for the static assets and may load some css that would 5 | // already be loaded by the notebook otherwise. 6 | 7 | // Export widget models and views, and the npm package version number. 8 | module.exports = require("./Widget.js"); 9 | module.exports["version"] = require("../../package.json").version; 10 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/widget/src/extension.js: -------------------------------------------------------------------------------- 1 | // This file contains the javascript that is run when the notebook is loaded. 2 | // It contains some requirejs configuration and the `load_ipython_extension` 3 | // which is required for any notebook extension. 4 | // 5 | // Some static assets may be required by the custom widget javascript. The base 6 | // url for the notebook is not known at build time and is therefore computed 7 | // dynamically. 
8 | __webpack_public_path__ = 9 | document.querySelector("body").getAttribute("data-base-url") + 10 | "nbextensions/jupyter-captum-insights"; 11 | 12 | // Configure requirejs 13 | if (window.require) { 14 | window.require.config({ 15 | map: { 16 | "*": { 17 | "jupyter-captum-insights": "nbextensions/jupyter-captum-insights/index", 18 | }, 19 | }, 20 | }); 21 | } 22 | 23 | // Export the required load_ipython_extension 24 | module.exports = { 25 | load_ipython_extension: function () {}, 26 | }; 27 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/widget/src/index.js: -------------------------------------------------------------------------------- 1 | module.exports = require("./Widget.js"); 2 | module.exports["version"] = require("../../package.json").version; 3 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/frontend/widget/webpack.config.js: -------------------------------------------------------------------------------- 1 | var path = require("path"); 2 | 3 | // Custom webpack rules are generally the same for all webpack bundles, hence 4 | // stored in a separate local variable. 
5 | var rules = [ 6 | { 7 | test: /\.module.css$/, 8 | use: [ 9 | "style-loader", 10 | { 11 | loader: "css-loader", 12 | options: { 13 | modules: true, 14 | }, 15 | }, 16 | ], 17 | }, 18 | { test: /^((?!\.module).)*.css$/, use: ["style-loader", "css-loader"] }, 19 | { 20 | test: /\.(js|ts|tsx)$/, 21 | exclude: /node_modules/, 22 | loaders: "babel-loader", 23 | options: { 24 | presets: [ 25 | "@babel/preset-react", 26 | "@babel/preset-env", 27 | "@babel/preset-typescript", 28 | ], 29 | plugins: ["@babel/plugin-proposal-class-properties"], 30 | }, 31 | }, 32 | ]; 33 | 34 | var extensions = [".js", ".ts", ".tsx"]; 35 | 36 | module.exports = [ 37 | { 38 | // Notebook extension 39 | // 40 | // This bundle only contains the part of the JavaScript that is run on 41 | // load of the notebook. This section generally only performs 42 | // some configuration for requirejs, and provides the legacy 43 | // "load_ipython_extension" function which is required for any notebook 44 | // extension. 45 | // 46 | mode: "production", 47 | entry: "./src/extension.js", 48 | output: { 49 | filename: "extension.js", 50 | path: path.resolve(__dirname, "..", "..", "widget", "static"), 51 | libraryTarget: "amd", 52 | }, 53 | resolveLoader: { 54 | modules: ["../node_modules"], 55 | extensions: extensions, 56 | }, 57 | resolve: { 58 | modules: ["../node_modules"], 59 | }, 60 | externals: ["moment"], // Removes unused dependency-of-dependency 61 | }, 62 | { 63 | // Bundle for the notebook containing the custom widget views and models 64 | // 65 | // This bundle contains the implementation for the custom widget views and 66 | // custom widget. 
67 | // It must be an amd module 68 | // 69 | mode: "production", 70 | entry: "./src/index.js", 71 | output: { 72 | filename: "index.js", 73 | path: path.resolve(__dirname, "..", "..", "widget", "static"), 74 | libraryTarget: "amd", 75 | }, 76 | module: { 77 | rules: rules, 78 | }, 79 | resolveLoader: { 80 | modules: ["../node_modules"], 81 | }, 82 | resolve: { 83 | modules: ["../node_modules"], 84 | extensions: extensions, 85 | }, 86 | externals: ["@jupyter-widgets/base", "moment"], 87 | }, 88 | ]; 89 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/models/cifar_torchvision.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/insights/attr_vis/models/cifar_torchvision.pt -------------------------------------------------------------------------------- /captum/insights/attr_vis/widget/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from typing import Dict, List 3 | 4 | from captum.insights.attr_vis.widget._version import __version__, version_info 5 | from captum.insights.attr_vis.widget.widget import CaptumInsights 6 | 7 | 8 | def _jupyter_nbextension_paths() -> List[Dict[str, str]]: 9 | return [ 10 | { 11 | "section": "notebook", 12 | "src": "static", 13 | "dest": "jupyter-captum-insights", 14 | "require": "jupyter-captum-insights/extension", 15 | } 16 | ] 17 | 18 | 19 | __all__ = ["__version__", "version_info", "CaptumInsights"] 20 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/widget/_version.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | version_info = (0, 1, 0, "alpha", 0) 3 | 4 | _specifier_ = {"alpha": "a", "beta": "b", "candidate": "rc", "final": ""} 5 | 6 | __version__: str = 
"%s.%s.%s%s" % ( 7 | version_info[0], 8 | version_info[1], 9 | version_info[2], 10 | ( 11 | "" 12 | if version_info[3] == "final" 13 | else _specifier_[version_info[3]] + str(version_info[4]) 14 | ), 15 | ) 16 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/widget/jupyter-captum-insights.json: -------------------------------------------------------------------------------- 1 | { 2 | "load_extensions": { 3 | "jupyter-captum-insights/extension": true 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /captum/insights/attr_vis/widget/widget.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | import ipywidgets as widgets 5 | from captum.insights import AttributionVisualizer 6 | from captum.insights.attr_vis.server import namedtuple_to_dict 7 | from traitlets import Dict, Instance, List, observe, Unicode 8 | 9 | 10 | @widgets.register 11 | class CaptumInsights(widgets.DOMWidget): 12 | """A widget for interacting with Captum Insights.""" 13 | 14 | # pyre-fixme[4]: Attribute must be annotated. 15 | _view_name = Unicode("CaptumInsightsView").tag(sync=True) 16 | # pyre-fixme[4]: Attribute must be annotated. 17 | _model_name = Unicode("CaptumInsightsModel").tag(sync=True) 18 | # pyre-fixme[4]: Attribute must be annotated. 19 | _view_module = Unicode("jupyter-captum-insights").tag(sync=True) 20 | # pyre-fixme[4]: Attribute must be annotated. 21 | _model_module = Unicode("jupyter-captum-insights").tag(sync=True) 22 | # pyre-fixme[4]: Attribute must be annotated. 23 | _view_module_version = Unicode("^0.1.0").tag(sync=True) 24 | # pyre-fixme[4]: Attribute must be annotated. 25 | _model_module_version = Unicode("^0.1.0").tag(sync=True) 26 | # pyre-fixme[4]: Attribute must be annotated. 
27 | visualizer = Instance(klass=AttributionVisualizer) 28 | 29 | # pyre-fixme[4]: Attribute must be annotated. 30 | insights_config = Dict().tag(sync=True) 31 | # pyre-fixme[4]: Attribute must be annotated. 32 | label_details = Dict().tag(sync=True) 33 | # pyre-fixme[4]: Attribute must be annotated. 34 | attribution = Dict().tag(sync=True) 35 | # pyre-fixme[4]: Attribute must be annotated. 36 | config = Dict().tag(sync=True) 37 | output = List().tag(sync=True) # type: ignore 38 | 39 | # pyre-fixme[2]: Parameter must be annotated. 40 | def __init__(self, **kwargs) -> None: 41 | super(CaptumInsights, self).__init__(**kwargs) 42 | self.insights_config = self.visualizer.get_insights_config() 43 | self.out = widgets.Output() 44 | with self.out: 45 | print("Captum Insights widget created.") 46 | 47 | @observe("config") 48 | # pyre-fixme[2]: Parameter must be annotated. 49 | def _fetch_data(self, change) -> None: 50 | if not self.config: 51 | return 52 | with self.out: 53 | self.visualizer._update_config(self.config) 54 | self.output = namedtuple_to_dict(self.visualizer.visualize()) 55 | self.config = {} 56 | 57 | @observe("label_details") 58 | # pyre-fixme[2]: Parameter must be annotated. 59 | def _fetch_attribution(self, change) -> None: 60 | if not self.label_details: 61 | return 62 | with self.out: 63 | self.attribution = namedtuple_to_dict( 64 | self.visualizer._calculate_attribution_from_cache( 65 | self.label_details["inputIndex"], 66 | self.label_details["modelIndex"], 67 | self.label_details["labelIndex"], 68 | ) 69 | ) 70 | self.label_details = {} 71 | -------------------------------------------------------------------------------- /captum/insights/example.py: -------------------------------------------------------------------------------- 1 | # for legacy purposes 2 | 3 | # pyre-strict 4 | import warnings 5 | 6 | from captum.insights.attr_vis.example import * # noqa 7 | 8 | warnings.warn( 9 | "Deprecated. 
Please import from captum.insights.attr_vis.example instead.", 10 | stacklevel=1, 11 | ) 12 | 13 | 14 | main() # noqa 15 | -------------------------------------------------------------------------------- /captum/log/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | try: 6 | from captum.log.fb.internal_log import ( # type: ignore 7 | disable_detailed_logging, 8 | log, 9 | log_usage, 10 | patch_methods, 11 | set_environment, 12 | TimedLog, 13 | ) 14 | except ImportError: 15 | # bug with mypy: https://github.com/python/mypy/issues/1153 16 | from captum.log.dummy_log import ( # type: ignore 17 | disable_detailed_logging, 18 | log, 19 | log_usage, 20 | patch_methods, 21 | set_environment, 22 | TimedLog, 23 | ) 24 | 25 | __all__ = [ 26 | "log", 27 | "log_usage", 28 | "TimedLog", 29 | "set_environment", 30 | "disable_detailed_logging", 31 | "patch_methods", 32 | ] 33 | -------------------------------------------------------------------------------- /captum/log/dummy_log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from functools import wraps 6 | from types import TracebackType 7 | from typing import Any, List, Optional, Union 8 | 9 | 10 | def log(*args: Any, **kwargs: Any) -> None: 11 | pass 12 | 13 | 14 | class TimedLog: 15 | def __init__(self, *args: Any, **kwargs: Any) -> None: 16 | pass 17 | 18 | def __enter__(self) -> "TimedLog": 19 | return self 20 | 21 | def __exit__( 22 | self, 23 | exception_type: Optional[BaseException], 24 | exception_value: Optional[BaseException], 25 | traceback: Optional[TracebackType], 26 | ) -> bool: 27 | return exception_value is not None 28 | 29 | 30 | # pyre-fixme[3]: Return type must be annotated. 31 | def log_usage(*log_args: Any, **log_kwargs: Any): 32 | # pyre-fixme[3]: Return type must be annotated. 
33 | # pyre-fixme[2]: Parameter must be annotated. 34 | def _log_usage(func): 35 | @wraps(func) 36 | # pyre-fixme[53]: Captured variable `func` is not annotated. 37 | # pyre-fixme[3]: Return type must be annotated. 38 | def wrapper(*args: Any, **kwargs: Any): 39 | return func(*args, **kwargs) 40 | 41 | return wrapper 42 | 43 | return _log_usage 44 | 45 | 46 | def set_environment(env: Union[None, List[str], str]) -> None: 47 | pass 48 | 49 | 50 | def disable_detailed_logging() -> None: 51 | pass 52 | 53 | 54 | def patch_methods(_, patch_log: bool = True) -> None: 55 | pass 56 | -------------------------------------------------------------------------------- /captum/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from captum.metrics._core.infidelity import ( 6 | infidelity, 7 | infidelity_perturb_func_decorator, 8 | ) 9 | from captum.metrics._core.sensitivity import sensitivity_max 10 | 11 | __all__ = [ 12 | "infidelity", 13 | "infidelity_perturb_func_decorator", 14 | "sensitivity_max", 15 | ] 16 | -------------------------------------------------------------------------------- /captum/metrics/_core/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /captum/metrics/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/metrics/_utils/__init__.py -------------------------------------------------------------------------------- /captum/module/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | from captum.module.binary_concrete_stochastic_gates import BinaryConcreteStochasticGates 3 | from 
captum.module.gaussian_stochastic_gates import GaussianStochasticGates 4 | from captum.module.stochastic_gates_base import StochasticGatesBase 5 | 6 | __all__ = [ 7 | "BinaryConcreteStochasticGates", 8 | "GaussianStochasticGates", 9 | "StochasticGatesBase", 10 | ] 11 | -------------------------------------------------------------------------------- /captum/robust/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | from captum.robust._core.fgsm import FGSM 6 | from captum.robust._core.metrics.attack_comparator import AttackComparator 7 | from captum.robust._core.metrics.min_param_perturbation import MinParamPerturbation 8 | from captum.robust._core.perturbation import Perturbation 9 | from captum.robust._core.pgd import PGD 10 | 11 | __all__ = ["FGSM", "AttackComparator", "MinParamPerturbation", "Perturbation", "PGD"] 12 | -------------------------------------------------------------------------------- /captum/robust/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/robust/_core/__init__.py -------------------------------------------------------------------------------- /captum/robust/_core/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/robust/_core/metrics/__init__.py -------------------------------------------------------------------------------- /captum/robust/_core/perturbation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | from typing import Callable 5 | 6 | 7 | class Perturbation: 8 | r""" 9 | All perturbation and attack algorithms extend this class. 
It enforces 10 | its child classes to extend and override core `perturb` method. 11 | """ 12 | 13 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. 14 | # pyre-fixme[13]: Attribute `perturb` is never initialized. 15 | perturb: Callable 16 | r""" 17 | This method computes and returns the perturbed input for each input tensor. 18 | Deriving classes are responsible for implementing its logic accordingly. 19 | 20 | Specific adversarial attack algorithms that extend this class take relevant 21 | arguments. 22 | 23 | Args: 24 | 25 | inputs (Tensor or tuple[Tensor, ...]): Input for which adversarial attack 26 | is computed. It can be provided as a single tensor or 27 | a tuple of multiple tensors. If multiple input tensors 28 | are provided, the batch sizes must be aligned across all 29 | tensors. 30 | 31 | Returns: 32 | 33 | - **perturbed inputs** (*Tensor* or *tuple[Tensor, ...]*): 34 | Perturbed input for each 35 | input tensor. The perturbed inputs have the same shape and 36 | dimensionality as the inputs. 37 | If a single tensor is provided as inputs, a single tensor 38 | is returned. If a tuple is provided for inputs, a tuple of 39 | corresponding sized tensors is returned. 40 | """ 41 | 42 | # pyre-fixme[3]: Return type must be annotated. 43 | # pyre-fixme[2]: Parameter must be annotated. 
44 | def __call__(self, *args, **kwargs): 45 | return self.perturb(*args, **kwargs) 46 | -------------------------------------------------------------------------------- /captum/testing/attr/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/captum/testing/attr/helpers/__init__.py -------------------------------------------------------------------------------- /captum/testing/attr/helpers/attribution_delta_util.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | # pyre-strict 4 | from typing import Tuple, Union 5 | 6 | import torch 7 | from captum.testing.helpers import BaseTest 8 | from torch import Tensor 9 | 10 | 11 | def assert_attribution_delta( 12 | test: BaseTest, 13 | inputs: Union[Tensor, Tuple[Tensor, ...]], 14 | attributions: Union[Tensor, Tuple[Tensor, ...]], 15 | n_samples: int, 16 | delta: Tensor, 17 | delta_thresh: Union[float, Tensor] = 0.0006, 18 | is_layer: bool = False, 19 | ) -> None: 20 | if not is_layer: 21 | for input, attribution in zip(inputs, attributions): 22 | test.assertEqual(attribution.shape, input.shape) 23 | if isinstance(inputs, tuple): 24 | bsz = inputs[0].shape[0] 25 | else: 26 | bsz = inputs.shape[0] 27 | test.assertEqual([bsz * n_samples], list(delta.shape)) 28 | 29 | delta = torch.mean(delta.reshape(bsz, -1), dim=1) 30 | assert_delta(test, delta, delta_thresh) 31 | 32 | 33 | def assert_delta( 34 | test: BaseTest, delta: Tensor, delta_thresh: Union[Tensor, float] = 0.0006 35 | ) -> None: 36 | delta_condition = (delta.abs() < delta_thresh).all() 37 | test.assertTrue( 38 | delta_condition, 39 | "Sum of SHAP values {} does" 40 | " not match the difference of endpoints.".format(delta), 41 | ) 42 | -------------------------------------------------------------------------------- 
/captum/testing/attr/helpers/gen_test_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | import typing 6 | from typing import Any, cast, Dict, List, Tuple, Type, Union 7 | 8 | from captum.attr._core.lime import Lime 9 | from captum.attr._models.base import _get_deep_layer_name 10 | from captum.attr._utils.attribution import Attribution 11 | from torch.nn import Module 12 | 13 | 14 | def gen_test_name( 15 | prefix: str, test_name: str, algorithm: Type[Attribution], noise_tunnel: bool 16 | ) -> str: 17 | # Generates test name for dynamically generated tests 18 | return ( 19 | prefix 20 | + "_" 21 | + test_name 22 | + "_" 23 | + algorithm.__name__ 24 | + ("NoiseTunnel" if noise_tunnel else "") 25 | ) 26 | 27 | 28 | def parse_test_config( 29 | # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use 30 | # `typing.Dict[, ]` to avoid runtime subscripting errors. 31 | test_config: Dict, 32 | ) -> Tuple[List[Type[Attribution]], Module, Dict[str, Any], Module, bool, bool]: 33 | algorithms = cast(List[Type[Attribution]], test_config["algorithms"]) 34 | model = test_config["model"] 35 | # pyre-fixme[33]: Given annotation cannot contain `Any`. 
36 | args = cast(Dict[str, Any], test_config["attribute_args"]) 37 | layer = test_config["layer"] if "layer" in test_config else None 38 | noise_tunnel = ( 39 | test_config["noise_tunnel"] if "noise_tunnel" in test_config else False 40 | ) 41 | baseline_distr = ( 42 | test_config["baseline_distr"] if "baseline_distr" in test_config else False 43 | ) 44 | return algorithms, model, args, layer, noise_tunnel, baseline_distr # type: ignore 45 | 46 | 47 | def should_create_generated_test(algorithm: Type[Attribution]) -> bool: 48 | if issubclass(algorithm, Lime): 49 | try: 50 | import sklearn # noqa: F401 51 | 52 | assert ( 53 | sklearn.__version__ >= "0.23.0" 54 | ), "Must have sklearn version 0.23.0 or higher to use " 55 | "sample_weight in Lasso regression." 56 | return True 57 | except (ImportError, AssertionError): 58 | return False 59 | return True 60 | 61 | 62 | @typing.overload 63 | def get_target_layer(model: Module, layer_name: str) -> Module: ... 64 | 65 | 66 | @typing.overload 67 | def get_target_layer(model: Module, layer_name: List[str]) -> List[Module]: ... 68 | 69 | 70 | def get_target_layer( 71 | model: Module, layer_name: Union[str, List[str]] 72 | ) -> Union[Module, List[Module]]: 73 | if isinstance(layer_name, str): 74 | return _get_deep_layer_name(model, layer_name) 75 | else: 76 | return [ 77 | _get_deep_layer_name(model, single_layer_name) 78 | for single_layer_name in layer_name 79 | ] 80 | -------------------------------------------------------------------------------- /captum/testing/attr/helpers/get_config_util.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | # pyre-strict 4 | from typing import Any, Tuple 5 | 6 | import torch 7 | from captum._utils.gradient import compute_gradients 8 | from captum.testing.helpers.basic_models import BasicModel, BasicModel5_MultiArgs 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | 13 | # pyre-fixme[3]: Return annotation cannot contain `Any`. 14 | def get_basic_config() -> Tuple[Module, Tensor, Tensor, Any]: 15 | input = torch.tensor([1.0, 2.0, 3.0, 0.0, -1.0, 7.0], requires_grad=True).T # NOTE(review): `.T` on a 1-D tensor may warn in newer PyTorch — confirm it is needed 16 | # manually precomputed gradients 17 | grads = torch.tensor([-0.0, -0.0, -0.0, 1.0, 1.0, -0.0]) 18 | return BasicModel(), input, grads, None 19 | 20 | 21 | # pyre-fixme[3]: Return annotation cannot contain `Any`. 22 | def get_multiargs_basic_config() -> ( 23 | Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Any] 24 | ): 25 | model = BasicModel5_MultiArgs() 26 | additional_forward_args = ([2, 3], 1) 27 | inputs = ( 28 | torch.tensor([[1.5, 2.0, 34.3], [3.4, 1.2, 2.0]], requires_grad=True), 29 | torch.tensor([[3.0, 3.5, 23.2], [2.3, 1.2, 0.3]], requires_grad=True), 30 | ) 31 | grads = compute_gradients( 32 | model, inputs, additional_forward_args=additional_forward_args 33 | ) 34 | return model, inputs, grads, additional_forward_args 35 | 36 | 37 | # pyre-fixme[3]: Return annotation cannot contain `Any`.
38 | def get_multiargs_basic_config_large() -> ( 39 | Tuple[Module, Tuple[Tensor, ...], Tuple[Tensor, ...], Any] 40 | ): 41 | model = BasicModel5_MultiArgs() 42 | additional_forward_args = ([2, 3], 1) 43 | inputs = ( 44 | torch.tensor( 45 | [[10.5, 12.0, 34.3], [43.4, 51.2, 32.0]], requires_grad=True 46 | ).repeat_interleave(3, dim=0), 47 | torch.tensor( 48 | [[1.0, 3.5, 23.2], [2.3, 1.2, 0.3]], requires_grad=True 49 | ).repeat_interleave(3, dim=0), 50 | ) 51 | grads = compute_gradients( 52 | model, inputs, additional_forward_args=additional_forward_args 53 | ) 54 | return model, inputs, grads, additional_forward_args 55 | -------------------------------------------------------------------------------- /captum/testing/attr/helpers/neuron_layer_testing_util.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 2 | 3 | # pyre-strict 4 | from typing import Tuple 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | 10 | def create_inps_and_base_for_deeplift_neuron_layer_testing() -> ( 11 | Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]] 12 | ): 13 | x1 = torch.tensor([[-10.0, 1.0, -5.0]], requires_grad=True) 14 | x2 = torch.tensor([[3.0, 3.0, 1.0]], requires_grad=True) 15 | 16 | b1 = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True) 17 | b2 = torch.tensor([[0.0, 0.0, 0.0]], requires_grad=True) 18 | 19 | inputs = (x1, x2) 20 | baselines = (b1, b2) 21 | 22 | return inputs, baselines 23 | 24 | 25 | def create_inps_and_base_for_deepliftshap_neuron_layer_testing() -> ( 26 | Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]] 27 | ): 28 | x1 = torch.tensor([[-10.0, 1.0, -5.0]], requires_grad=True) 29 | x2 = torch.tensor([[3.0, 3.0, 1.0]], requires_grad=True) 30 | 31 | b1 = torch.tensor( 32 | [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], requires_grad=True 33 | ) 34 | b2 = torch.tensor( 35 | [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], 
requires_grad=True 36 | ) 37 | 38 | inputs = (x1, x2) 39 | baselines = (b1, b2) 40 | 41 | return inputs, baselines 42 | -------------------------------------------------------------------------------- /captum/testing/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-strict 4 | 5 | try: 6 | from captum.testing.helpers.fb.internal_base import ( # type: ignore 7 | FbBaseTest as _BaseTest, 8 | ) 9 | 10 | except ImportError: 11 | # tests/helpers/__init__.py:13: error: Incompatible import of "_BaseTest" 12 | # (imported name has type "type[BaseTest]", local name has type 13 | # "type[FbBaseTest]") [assignment] 14 | from captum.testing.helpers.basic import BaseTest as _BaseTest # type: ignore 15 | 16 | BaseTest = _BaseTest 17 | 18 | __all__ = [ 19 | "BaseTest", 20 | ] 21 | -------------------------------------------------------------------------------- /captum/testing/helpers/evaluate_linear_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
3 | 4 | # pyre-strict 5 | 6 | from typing import cast, Dict 7 | 8 | import torch 9 | 10 | from captum._utils.models.linear_model.model import LinearModel 11 | from torch import Tensor 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | def evaluate(test_data: DataLoader, classifier: LinearModel) -> Dict[str, Tensor]: 16 | classifier.eval() 17 | 18 | l1_loss = 0.0 19 | l2_loss = 0.0 20 | n = 0 21 | l2_losses = [] 22 | with torch.no_grad(): 23 | for data in test_data: 24 | if len(data) == 2: 25 | x, y = data 26 | w = None 27 | else: 28 | x, y, w = data 29 | 30 | out = classifier(x) 31 | 32 | y = y.view(x.shape[0], -1) 33 | assert y.shape == out.shape 34 | 35 | if w is None: 36 | l1_loss += (out - y).abs().sum(0).to(dtype=torch.float64) 37 | l2_loss += ((out - y) ** 2).sum(0).to(dtype=torch.float64) 38 | l2_losses.append(((out - y) ** 2).to(dtype=torch.float64)) 39 | else: 40 | l1_loss += ( 41 | (w.view(-1, 1) * (out - y)).abs().sum(0).to(dtype=torch.float64) 42 | ) 43 | l2_loss += ( 44 | (w.view(-1, 1) * ((out - y) ** 2)).sum(0).to(dtype=torch.float64) 45 | ) 46 | l2_losses.append( 47 | (w.view(-1, 1) * ((out - y) ** 2)).to(dtype=torch.float64) 48 | ) 49 | 50 | n += x.shape[0] 51 | 52 | l2_losses = torch.cat(l2_losses, dim=0) 53 | assert n > 0 54 | 55 | # just to double check 56 | assert ((l2_losses.mean(0) - l2_loss / n).abs() <= 0.1).all() 57 | 58 | classifier.train() 59 | return {"l1": cast(Tensor, l1_loss / n), "l2": cast(Tensor, l2_loss / n)} 60 | -------------------------------------------------------------------------------- /captum/testing/helpers/influence/__init__.py: -------------------------------------------------------------------------------- 1 | # pyre-strict 2 | -------------------------------------------------------------------------------- /docs/Captum_Attribution_Algos.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/docs/Captum_Attribution_Algos.png -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | This directory contains the source files for Captum's Docusaurus documentation. 2 | See the repo's [README](../README.md) for additional information. 3 | -------------------------------------------------------------------------------- /docs/captum_insights.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: captum_insights 3 | title: Captum Insights 4 | --- 5 | 6 | Interpreting model output in complex models can be difficult. Even with interpretability libraries like Captum, it can be difficult to understand models without proper visualizations. Image and text input features can be especially difficult to understand without these visualizations. 7 | 8 | Captum Insights is an interpretability visualization widget built on top of Captum to facilitate model understanding. Captum Insights works across images, text, and other features to help users understand feature attribution. Some examples of the widget are below. 9 | 10 | Getting started with Captum Insights is easy. You can learn how to use Captum Insights with the [Getting started with Captum Insights](/tutorials/CIFAR_TorchVision_Captum_Insights) tutorial. Alternatively, to analyze a sample model on CIFAR10 via Captum Insights execute the line below and navigate to the URL specified in the output. 
11 | 12 | ``` 13 | python -m captum.insights.example 14 | ``` 15 | 16 | 17 | Below are some sample screenshots of Captum Insights: 18 | 19 | ![screenshot1](/img/captum_insights_screenshot.png) 20 | 21 | ![screenshot2](/img/captum_insights_screenshot_vqa.png) 22 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: getting_started 3 | title: Getting Started 4 | --- 5 | 6 | This section shows you how to get up and running with Captum. 7 | 8 | 9 | ## Installing Captum 10 | 11 | #### Installation Requirements: 12 | 13 | - Python >= 3.6 14 | - PyTorch >= 1.2 15 | - numpy 16 | 17 | Captum is easily installed via 18 | [Anaconda](https://www.anaconda.com/distribution/#download-section) (recommended) 19 | or `pip`: 20 | 21 | 22 | 23 | ```bash 24 | conda install captum -c pytorch 25 | ``` 26 | 27 | ```bash 28 | pip install captum 29 | ``` 30 | 31 | 32 | For more detailed installation instructions, please see the 33 | [Project Readme](https://github.com/pytorch/captum/blob/master/README.md) 34 | on GitHub. 35 | 36 | 37 | ## Tutorials 38 | 39 | We have several tutorials to help get you off the ground with Captum. The tutorials are Jupyter notebooks and cover the basics along with demonstrating usage of Captum with models of different modalities. 40 | 41 | View the tutorials page [here](../tutorials/). 42 | 43 | 44 | ## API Reference 45 | 46 | For an in-depth reference of the various Captum internals, see our 47 | [API Reference](../api/). 48 | 49 | 50 | ## Contributing 51 | 52 | You'd like to contribute to Captum? Great! Please see 53 | [here](https://github.com/pytorch/captum/blob/master/CONTRIBUTING.md) 54 | for how to help out. 
55 | -------------------------------------------------------------------------------- /docs/introduction.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: introduction 3 | title: Introduction 4 | --- 5 | Captum (“comprehension” in Latin) is an open source, extensible library for model interpretability built on PyTorch. 6 | 7 | With the increase in model complexity and the resulting lack of transparency, model interpretability methods have become increasingly important. Model understanding is both an active area of research and an area of focus for practical applications across industries using machine learning. Captum provides state-of-the-art algorithms, including Integrated Gradients, to provide researchers and developers with an easy way to understand which features are contributing to a model’s output. 8 | 9 | Captum helps ML researchers more easily implement interpretability algorithms that can interact with PyTorch models. It also allows researchers to quickly benchmark their work against other existing algorithms available in the library. 10 | 11 | For model developers, Captum can be used to improve and troubleshoot models by facilitating the identification of different features that contribute to a model’s output in order to design better models and troubleshoot unexpected model outputs. 12 | 13 | ## Target Audience 14 | 15 | The primary audiences for Captum are model developers who are looking to improve their models and understand which features are important, and interpretability researchers focused on identifying algorithms that can better interpret many types of models. 16 | 17 | Captum can also be used by application engineers who are using trained models in production. Captum provides easier troubleshooting through improved model interpretability, and the potential for delivering better explanations to end users on why they’re seeing a specific piece of content, such as a movie recommendation.
18 | -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | --- 2 | id: overview 3 | title: Overview 4 | --- 5 | 6 | This overview describes the basic components of Captum and how they work 7 | together. For a high-level view of what Captum tries to achieve in more 8 | abstract terms, please see the [Introduction](introduction). 9 | 10 | 11 | TODO: Add content here 12 | -------------------------------------------------------------------------------- /docs/presentations/Captum_NeurIPS_2019_final.key: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/docs/presentations/Captum_NeurIPS_2019_final.key -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: captum 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - numpy<2.0 6 | - pytorch>=1.10 7 | - matplotlib-base 8 | - tqdm 9 | - packaging 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | first_party_detection = false 3 | 4 | [tool.black] 5 | target-version = ['py39'] 6 | -------------------------------------------------------------------------------- /scripts/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | 3 | # run this script from the project root using `./scripts/build_docs.sh` 4 | 5 | usage() { 6 | echo "Usage: $0 [-b]" 7 | echo "" 8 | echo "Build Captum documentation." 
9 | echo "" 10 | echo " -b Build static version of documentation (otherwise start server)" 11 | echo "" 12 | exit 1 13 | } 14 | 15 | BUILD_STATIC=false 16 | 17 | while getopts 'hb' flag; do 18 | case "${flag}" in 19 | h) 20 | usage 21 | ;; 22 | b) 23 | BUILD_STATIC=true 24 | ;; 25 | *) 26 | usage 27 | ;; 28 | esac 29 | done 30 | 31 | echo "-----------------------------------" 32 | echo "Generating API reference via Sphinx" 33 | echo "-----------------------------------" 34 | cd sphinx || exit 35 | make html 36 | cd .. || exit 37 | 38 | # run script to parse html generated by sphinx 39 | echo "--------------------------------------------" 40 | echo "Parsing Sphinx docs and moving to Website" 41 | echo "--------------------------------------------" 42 | mkdir -p "website/pages/api/" 43 | 44 | cwd=$(pwd) 45 | python scripts/parse_sphinx.py -i "${cwd}/sphinx/build/html/" -o "${cwd}/website/pages/api/" 46 | 47 | SPHINX_STATIC_DIR='sphinx/build/html/_static/' 48 | WEBSITE_SPHINX_DIR='website/static/_sphinx/' 49 | 50 | mkdir -p $WEBSITE_SPHINX_DIR 51 | 52 | # move static files from /sphinx/build/html/_static/*: 53 | for sphinx_static_file in 'documentation_options.js' \ 54 | 'doctools.js' \ 55 | 'language_data.js' \ 56 | 'searchtools.js' \ 57 | 'katex_autorenderer.js' \ 58 | 'pygments.css' \ 59 | 'katex-math.css' 60 | do 61 | cp "${SPHINX_STATIC_DIR}${sphinx_static_file}" "${WEBSITE_SPHINX_DIR}${sphinx_static_file}" 62 | done 63 | 64 | # searchindex.js is not static util 65 | cp "sphinx/build/html/searchindex.js" "${WEBSITE_SPHINX_DIR}searchindex.js" 66 | 67 | echo "-----------------------------------" 68 | echo "Generating tutorials" 69 | echo "-----------------------------------" 70 | mkdir -p "website/_tutorials" 71 | mkdir -p "website/static/files" 72 | python scripts/parse_tutorials.py -w "${cwd}" 73 | 74 | echo "-----------------------------------" 75 | echo "Install Website dependencies" 76 | echo "-----------------------------------" 77 | cd website || exit 78 | 
yarn 79 | 80 | if [[ $BUILD_STATIC == true ]]; then 81 | echo "-----------------------------------" 82 | echo "Building static site" 83 | echo "-----------------------------------" 84 | yarn build 85 | else 86 | echo "-----------------------------------" 87 | echo "Starting local server" 88 | echo "-----------------------------------" 89 | yarn start 90 | fi 91 | -------------------------------------------------------------------------------- /scripts/build_insights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This builds the Captum Insights React-based UI using yarn. 4 | # 5 | # Run this script from the project root using `./scripts/build_insights.sh` 6 | 7 | set -e 8 | 9 | usage() { 10 | echo "Usage: $0" 11 | echo "" 12 | echo "Build Captum Insights." 13 | echo "" 14 | exit 1 15 | } 16 | 17 | while getopts 'h' flag; do 18 | case "${flag}" in 19 | h) 20 | usage 21 | ;; 22 | *) 23 | usage 24 | ;; 25 | esac 26 | done 27 | 28 | # check if yarn is installed 29 | if [ ! -x "$(command -v yarn)" ]; then 30 | echo "" 31 | echo "Please install yarn with 'conda install -c conda-forge yarn' or equivalent." 
32 | echo "" 33 | exit 1 34 | fi 35 | 36 | # go into subdir 37 | pushd captum/insights/attr_vis/frontend || exit 38 | 39 | echo "-----------------------------------" 40 | echo "Install Dependencies" 41 | echo "-----------------------------------" 42 | yarn install 43 | 44 | echo "-----------------------------------" 45 | echo "Building Captum Insights" 46 | echo "-----------------------------------" 47 | yarn build 48 | 49 | pushd widget || exit 50 | 51 | echo "-----------------------------------" 52 | echo "Building Captum Insights widget" 53 | echo "-----------------------------------" 54 | 55 | ../node_modules/.bin/webpack-cli --config webpack.config.js 56 | 57 | # exit subdir 58 | popd || exit 59 | 60 | # exit subdir 61 | popd || exit 62 | -------------------------------------------------------------------------------- /scripts/install_via_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | while getopts 'nf' flag; do 6 | case "${flag}" in 7 | f) FRAMEWORKS=true ;; 8 | *) echo "usage: $0 [-n] [-f]" >&2 9 | exit 1 ;; 10 | esac 11 | done 12 | 13 | # update conda 14 | # removing due to setuptools error during update 15 | #conda update -y -n base -c defaults conda 16 | conda update -q --all --yes 17 | 18 | 19 | # install CPU version for much smaller download 20 | conda install -q -y pytorch cpuonly -c pytorch 21 | 22 | # install other deps 23 | conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug 24 | conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress conda-build openai 25 | conda install -q -y transformers 26 | 27 | # install captum 28 | python setup.py develop 29 | -------------------------------------------------------------------------------- /scripts/install_via_pip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | PYTORCH_NIGHTLY=false 6 | DEPLOY=false 7 
| CHOSEN_TORCH_VERSION=-1 8 | CHOSEN_TRANSFORMERS_VERSION=-1 9 | 10 | while getopts 'ndfv:t:' flag; do 11 | case "${flag}" in 12 | n) PYTORCH_NIGHTLY=true ;; 13 | d) DEPLOY=true ;; 14 | f) FRAMEWORKS=true ;; 15 | v) CHOSEN_TORCH_VERSION=${OPTARG};; 16 | t) CHOSEN_TRANSFORMERS_VERSION=${OPTARG};; 17 | *) echo "usage: $0 [-n] [-d] [-f] [-v version] [-t transformers_version]" >&2 18 | exit 1 ;; 19 | esac 20 | done 21 | 22 | # NOTE: Only Debian variants are supported, since this script is only 23 | # used by our tests on GitHub Actions. In the future we might generalize, 24 | # but users should hopefully be using conda installs. 25 | 26 | # install nodejs and yarn for insights build 27 | sudo apt-get update 28 | sudo apt install apt-transport-https ca-certificates 29 | curl -sL https://deb.nodesource.com/setup_14.x | sudo -E bash - 30 | curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - 31 | echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list 32 | sudo apt update 33 | sudo apt install nodejs 34 | sudo apt install yarn 35 | 36 | # yarn needs terminal info 37 | export TERM=xterm 38 | 39 | # Remove all items from pip cache to avoid hash mismatch 40 | pip cache purge 41 | 42 | # upgrade pip 43 | pip install --upgrade pip --progress-bar off 44 | 45 | # install captum with dev deps 46 | pip install -e .[dev] --progress-bar off 47 | BUILD_INSIGHTS=1 python setup.py develop 48 | 49 | # install pytorch nightly if asked for 50 | if [[ $PYTORCH_NIGHTLY == true ]]; then 51 | pip install --upgrade --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --progress-bar off 52 | else 53 | # If no version is specified, upgrade to the latest release. 
54 | if [[ $CHOSEN_TORCH_VERSION == -1 ]]; then 55 | pip install --upgrade torch --progress-bar off 56 | else 57 | pip install torch=="$CHOSEN_TORCH_VERSION" --progress-bar off 58 | fi 59 | fi 60 | 61 | # install deployment bits if asked for 62 | if [[ $DEPLOY == true ]]; then 63 | pip install beautifulsoup4 ipython nbconvert==5.6.1 --progress-bar off 64 | fi 65 | 66 | # install appropriate transformers version 67 | # If no version is specified, upgrade to the latest release. 68 | if [[ $CHOSEN_TRANSFORMERS_VERSION == -1 ]]; then 69 | pip install --upgrade transformers --progress-bar off 70 | else 71 | pip install transformers=="$CHOSEN_TRANSFORMERS_VERSION" --progress-bar off 72 | fi 73 | -------------------------------------------------------------------------------- /scripts/parse_sphinx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import os 5 | 6 | from bs4 import BeautifulSoup 7 | 8 | # no need to import css from built path 9 | # coz docusaurus merge all css files within static folder automatically 10 | # https://v1.docusaurus.io/docs/en/api-pages#styles 11 | base_scripts = """ 12 | 14 | 15 | 16 | 17 | 18 | 19 | """ # noqa: E501 20 | 21 | search_js_scripts = """ 22 | 25 | 26 | 27 | """ 28 | 29 | katex_scripts = """ 30 | 31 | 32 | 33 | 34 | """ # noqa: E501 35 | 36 | 37 | def parse_sphinx(input_dir, output_dir): 38 | for cur, _, files in os.walk(input_dir): 39 | for fname in files: 40 | if fname.endswith(".html"): 41 | with open(os.path.join(cur, fname), "r") as f: 42 | soup = BeautifulSoup(f.read(), "html.parser") 43 | doc = soup.find("div", {"class": "document"}) 44 | wrapped_doc = doc.wrap( 45 | soup.new_tag("div", **{"class": "sphinx wrapper"}) 46 | ) 47 | # add scripts that sphinx pages need 48 | if fname == "search.html": 49 | out = ( 50 | base_scripts 51 | + search_js_scripts 52 | + katex_scripts 53 | + str(wrapped_doc) 54 | ) 55 | else: 56 | out = base_scripts 
+ katex_scripts + str(wrapped_doc) 57 | output_path = os.path.join(output_dir, os.path.relpath(cur, input_dir)) 58 | os.makedirs(output_path, exist_ok=True) 59 | with open(os.path.join(output_path, fname), "w") as fout: 60 | fout.write(out) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser(description="Strip HTML body from Sphinx docs.") 65 | parser.add_argument( 66 | "-i", 67 | "--input_dir", 68 | metavar="path", 69 | required=True, 70 | help="Input directory for Sphinx HTML.", 71 | ) 72 | parser.add_argument( 73 | "-o", 74 | "--output_dir", 75 | metavar="path", 76 | required=True, 77 | help="Output directory in website.", 78 | ) 79 | args = parser.parse_args() 80 | parse_sphinx(args.input_dir, args.output_dir) 81 | -------------------------------------------------------------------------------- /scripts/run_mypy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # This runs mypy's static type checker on parts of Captum supporting type 5 | # hints. 
6 | 7 | mypy -p captum.attr --ignore-missing-imports --allow-redefinition 8 | mypy -p captum.insights --ignore-missing-imports --allow-redefinition 9 | mypy -p captum.metrics --ignore-missing-imports --allow-redefinition 10 | mypy -p captum.robust --ignore-missing-imports --allow-redefinition 11 | mypy -p captum.concept --ignore-missing-imports --allow-redefinition 12 | mypy -p captum.influence --ignore-missing-imports --allow-redefinition 13 | mypy -p captum._utils --ignore-missing-imports --allow-redefinition 14 | mypy -p tests --ignore-missing-imports --allow-redefinition 15 | -------------------------------------------------------------------------------- /scripts/update_versions_html.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | 6 | from bs4 import BeautifulSoup 7 | 8 | BASE_URL = "/" 9 | 10 | 11 | def updateVersionHTML(base_path, base_url=BASE_URL): 12 | with open(base_path + "/captum-master/website/_versions.json", "rb") as infile: 13 | versions = json.loads(infile.read()) 14 | 15 | with open(base_path + "/new-site/versions.html", "rb") as infile: 16 | html = infile.read() 17 | 18 | versions.append("latest") 19 | 20 | def prepend_url(a_tag, base_url, version): 21 | href = a_tag.attrs["href"] 22 | if href.startswith("https://") or href.startswith("http://"): 23 | return href 24 | else: 25 | return "{base_url}versions/{version}{original_url}".format( 26 | base_url=base_url, version=version, original_url=href 27 | ) 28 | 29 | for v in versions: 30 | soup = BeautifulSoup(html, "html.parser") 31 | 32 | # title 33 | title_link = soup.find("header").find("a") 34 | title_link.attrs["href"] = prepend_url(title_link, base_url, v) 35 | 36 | # nav 37 | nav_links = soup.find("nav").findAll("a") 38 | for link in nav_links: 39 | link.attrs["href"] = prepend_url(link, base_url, v) 40 | 41 | # version link 42 | t = soup.find("h2", {"class": 
"headerTitleWithLogo"}).find_next("a") 43 | t.string = v 44 | t.attrs["href"] = prepend_url(t, base_url, v) 45 | 46 | # output files 47 | with open( 48 | base_path + "/new-site/versions/{}/versions.html".format(v), "w" 49 | ) as outfile: 50 | outfile.write(str(soup)) 51 | with open( 52 | base_path + "/new-site/versions/{}/en/versions.html".format(v), "w" 53 | ) as outfile: 54 | outfile.write(str(soup)) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser( 59 | description=( 60 | "Fix links in version.html files for Docusaurus site." 61 | "This is used to ensure that the versions.js for older " 62 | "versions in versions subdirectory are up-to-date and " 63 | "will have a way to navigate back to newer versions." 64 | ) 65 | ) 66 | parser.add_argument( 67 | "-p", 68 | "--base_path", 69 | metavar="path", 70 | required=True, 71 | help="Input directory for rolling out new version of site.", 72 | ) 73 | args = parser.parse_args() 74 | updateVersionHTML(args.base_path) 75 | -------------------------------------------------------------------------------- /scripts/versions.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2019-present, Facebook, Inc. 3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | * 7 | * @format 8 | */ 9 | 10 | const React = require('react'); 11 | 12 | const CompLibrary = require('../../core/CompLibrary'); 13 | 14 | const Container = CompLibrary.Container; 15 | 16 | const CWD = process.cwd(); 17 | 18 | const versions = require(`${CWD}/_versions.json`); 19 | versions.sort().reverse(); 20 | 21 | function Versions(props) { 22 | const {config: siteConfig} = props; 23 | const baseUrl = siteConfig.baseUrl; 24 | const latestVersion = versions[0]; 25 | return ( 26 |
27 | 28 |
29 |
30 |

{siteConfig.title} Versions

31 |
32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 45 | 48 | 49 | 50 | 54 | 59 | 62 | 63 | 64 |
VersionInstall withDocumentation
{`stable (${latestVersion})`} 43 | conda install -c pytorch captum 44 | 46 | stable 47 |
51 | {'latest'} 52 | {' (master)'} 53 | 55 | 56 | pip install git+ssh://git@github.com/pytorch/captum.git 57 | 58 | 60 | latest 61 |
65 | 66 |

Past Versions

67 | 68 | 69 | {versions.map( 70 | version => 71 | version !== latestVersion && ( 72 | 73 | 74 | 79 | 80 | ), 81 | )} 82 | 83 |
{version} 75 | 76 | Documentation 77 | 78 |
84 |
85 |
86 |
87 | ); 88 | } 89 | 90 | module.exports = Versions; 91 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: black and flake8 disagree on whitespace before ':' 3 | # W503: black and flake8 disagree on how to place operators 4 | # E704: black and flake8 disagree on Multiple statements on one line (def) 5 | ignore = E203, W503, E704 6 | max-line-length = 88 7 | exclude = 8 | build, dist, tutorials, website 9 | 10 | [coverage:report] 11 | omit = 12 | tests/* 13 | setup.py 14 | 15 | [mypy] 16 | exclude = ^.*fb.*$ 17 | 18 | [mypy-captum.log.fb.*] 19 | ignore_errors = True 20 | -------------------------------------------------------------------------------- /sphinx/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /sphinx/README.md: -------------------------------------------------------------------------------- 1 | # sphinx API reference 2 | 3 | This file describes the sphinx setup for auto-generating the API reference.
4 | 5 | 6 | ## Installation 7 | 8 | **Requirements**: 9 | - sphinx >= 2.0 10 | - sphinx_autodoc_typehints 11 | - sphinxcontrib-katex 12 | 13 | You can install these via `pip install sphinx sphinx_autodoc_typehints sphinxcontrib-katex`. 14 | 15 | 16 | ## Building 17 | 18 | From the `captum/sphinx` directory, run `make html`. 19 | 20 | Generated HTML output can be found in the `captum/sphinx/build` directory. The main index page is `captum/sphinx/build/html/index.html`. 21 | 22 | 23 | ## Structure 24 | 25 | `source/index.rst` contains the main index. The API reference for each module lives in its own file, e.g. `concept.rst` for the `captum.concept` module. 26 | -------------------------------------------------------------------------------- /sphinx/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /sphinx/source/attribution.rst: -------------------------------------------------------------------------------- 1 | Attribution 2 | ==================== 3 | ..
toctree:: 4 | 5 | integrated_gradients 6 | saliency 7 | deep_lift 8 | deep_lift_shap 9 | gradient_shap 10 | input_x_gradient 11 | guided_backprop 12 | guided_grad_cam 13 | deconvolution 14 | feature_ablation 15 | occlusion 16 | feature_permutation 17 | shapley_value_sampling 18 | lime 19 | kernel_shap 20 | lrp 21 | -------------------------------------------------------------------------------- /sphinx/source/base_classes.rst: -------------------------------------------------------------------------------- 1 | Base Classes 2 | ======================== 3 | 4 | Attribution 5 | ^^^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.Attribution 8 | :members: 9 | 10 | Layer Attribution 11 | ^^^^^^^^^^^^^^^^^^^^^^^^ 12 | 13 | .. autoclass:: captum.attr.LayerAttribution 14 | :members: 15 | 16 | Neuron Attribution 17 | ^^^^^^^^^^^^^^^^^^^^^^^^ 18 | 19 | .. autoclass:: captum.attr.NeuronAttribution 20 | :members: 21 | 22 | Gradient Attribution 23 | ^^^^^^^^^^^^^^^^^^^^^^^^ 24 | 25 | .. autoclass:: captum.attr.GradientAttribution 26 | :members: 27 | 28 | Perturbation Attribution 29 | ^^^^^^^^^^^^^^^^^^^^^^^^ 30 | 31 | .. autoclass:: captum.attr.PerturbationAttribution 32 | :members: 33 | -------------------------------------------------------------------------------- /sphinx/source/binary_concrete_stg.rst: -------------------------------------------------------------------------------- 1 | BinaryConcreteStochasticGates 2 | ==================================== 3 | 4 | .. autoclass:: captum.module.BinaryConcreteStochasticGates 5 | :members: 6 | :inherited-members: Module 7 | -------------------------------------------------------------------------------- /sphinx/source/concept.rst: -------------------------------------------------------------------------------- 1 | Concept-based Interpretability 2 | ============================== 3 | 4 | TCAV 5 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. 
autoclass:: captum.concept.TCAV 8 | :members: 9 | 10 | 11 | ConceptInterpreter 12 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 13 | 14 | .. autoclass:: captum.concept.ConceptInterpreter 15 | :members: 16 | 17 | 18 | Concept 19 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. autoclass:: captum.concept.Concept 22 | :members: 23 | 24 | 25 | Classifier 26 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 27 | 28 | .. autoclass:: captum.concept.Classifier 29 | :members: 30 | -------------------------------------------------------------------------------- /sphinx/source/deconvolution.rst: -------------------------------------------------------------------------------- 1 | Deconvolution 2 | ============= 3 | 4 | .. autoclass:: captum.attr.Deconvolution 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/deep_lift.rst: -------------------------------------------------------------------------------- 1 | DeepLift 2 | ========= 3 | 4 | .. autoclass:: captum.attr.DeepLift 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/deep_lift_shap.rst: -------------------------------------------------------------------------------- 1 | DeepLiftShap 2 | ============= 3 | 4 | .. autoclass:: captum.attr.DeepLiftShap 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/feature_ablation.rst: -------------------------------------------------------------------------------- 1 | Feature Ablation 2 | ================ 3 | 4 | .. autoclass:: captum.attr.FeatureAblation 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | -------------------------------------------------------------------------------- /sphinx/source/feature_permutation.rst: -------------------------------------------------------------------------------- 1 | Feature Permutation 2 | =================== 3 | 4 | .. 
autoclass:: captum.attr.FeaturePermutation 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | -------------------------------------------------------------------------------- /sphinx/source/gaussian_stg.rst: -------------------------------------------------------------------------------- 1 | GaussianStochasticGates 2 | ==================================== 3 | 4 | .. autoclass:: captum.module.GaussianStochasticGates 5 | :members: 6 | :inherited-members: Module 7 | -------------------------------------------------------------------------------- /sphinx/source/gradient_shap.rst: -------------------------------------------------------------------------------- 1 | GradientShap 2 | ============ 3 | 4 | .. autoclass:: captum.attr.GradientShap 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/guided_backprop.rst: -------------------------------------------------------------------------------- 1 | Guided Backprop 2 | =============== 3 | 4 | .. autoclass:: captum.attr.GuidedBackprop 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/guided_grad_cam.rst: -------------------------------------------------------------------------------- 1 | Guided GradCAM 2 | ============== 3 | 4 | .. autoclass:: captum.attr.GuidedGradCam 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/pytorch/captum 2 | 3 | Captum API Reference 4 | ==================== 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: API Reference 13 | 14 | attribution 15 | llm_attr 16 | noise_tunnel 17 | layer 18 | neuron 19 | metrics 20 | robust 21 | concept 22 | influence 23 | module 24 | utilities 25 | base_classes 26 | 27 | .. 
toctree:: 28 | :maxdepth: 2 29 | :caption: Insights API Reference 30 | 31 | insights 32 | 33 | 34 | Indices and Tables 35 | ================== 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | -------------------------------------------------------------------------------- /sphinx/source/influence.rst: -------------------------------------------------------------------------------- 1 | Influential Examples 2 | ==================== 3 | 4 | DataInfluence 5 | ^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.influence.DataInfluence 8 | :members: 9 | 10 | 11 | SimilarityInfluence 12 | ^^^^^^^^^^^^^^^^^^^^ 13 | 14 | .. autoclass:: captum.influence.SimilarityInfluence 15 | :members: 16 | 17 | 18 | TracInCPBase 19 | ^^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. autoclass:: captum.influence.TracInCPBase 22 | :members: 23 | 24 | 25 | TracInCP 26 | ^^^^^^^^^^^^^^^^^^^^ 27 | 28 | .. autoclass:: captum.influence.TracInCP 29 | :members: 30 | 31 | TracInCPFast 32 | ^^^^^^^^^^^^^^^^^^^^ 33 | 34 | .. autoclass:: captum.influence.TracInCPFast 35 | :members: 36 | 37 | TracInCPFastRandProj 38 | ^^^^^^^^^^^^^^^^^^^^ 39 | 40 | .. autoclass:: captum.influence.TracInCPFastRandProj 41 | :members: 42 | -------------------------------------------------------------------------------- /sphinx/source/input_x_gradient.rst: -------------------------------------------------------------------------------- 1 | Input X Gradient 2 | ================ 3 | 4 | .. autoclass:: captum.attr.InputXGradient 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/insights.rst: -------------------------------------------------------------------------------- 1 | Insights 2 | ============ 3 | 4 | Batch 5 | ^^^^^ 6 | 7 | .. autoclass:: captum.insights.Batch 8 | :members: 9 | 10 | AttributionVisualizer 11 | ^^^^^^^^^^^^^^^^^^^^^ 12 | .. 
autoclass:: captum.insights.AttributionVisualizer 13 | :members: 14 | 15 | 16 | Features 17 | ======== 18 | 19 | BaseFeature 20 | ^^^^^^^^^^^ 21 | .. autoclass:: captum.insights.features.BaseFeature 22 | :members: 23 | 24 | GeneralFeature 25 | ^^^^^^^^^^^^^^ 26 | .. autoclass:: captum.insights.features.GeneralFeature 27 | 28 | TextFeature 29 | ^^^^^^^^^^^ 30 | .. autoclass:: captum.insights.features.TextFeature 31 | 32 | ImageFeature 33 | ^^^^^^^^^^^^ 34 | .. autoclass:: captum.insights.features.ImageFeature 35 | -------------------------------------------------------------------------------- /sphinx/source/integrated_gradients.rst: -------------------------------------------------------------------------------- 1 | Integrated Gradients 2 | ==================== 3 | 4 | .. autoclass:: captum.attr.IntegratedGradients 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/kernel_shap.rst: -------------------------------------------------------------------------------- 1 | KernelShap 2 | ==================== 3 | 4 | .. autoclass:: captum.attr.KernelShap 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | -------------------------------------------------------------------------------- /sphinx/source/layer.rst: -------------------------------------------------------------------------------- 1 | Layer Attribution 2 | =========================== 3 | 4 | Layer Conductance 5 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.LayerConductance 8 | :members: 9 | 10 | 11 | Layer Activation 12 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 13 | 14 | .. autoclass:: captum.attr.LayerActivation 15 | :members: 16 | 17 | Internal Influence 18 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 19 | 20 | .. autoclass:: captum.attr.InternalInfluence 21 | :members: 22 | 23 | Layer Gradient X Activation 24 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 25 | 26 | .. 
autoclass:: captum.attr.LayerGradientXActivation 27 | :members: 28 | 29 | GradCAM 30 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 31 | 32 | .. autoclass:: captum.attr.LayerGradCam 33 | :members: 34 | 35 | Layer DeepLift 36 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 37 | 38 | .. autoclass:: captum.attr.LayerDeepLift 39 | :members: 40 | 41 | Layer DeepLiftShap 42 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 43 | 44 | .. autoclass:: captum.attr.LayerDeepLiftShap 45 | :members: 46 | 47 | Layer GradientShap 48 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 49 | 50 | .. autoclass:: captum.attr.LayerGradientShap 51 | :members: 52 | 53 | Layer Integrated Gradients 54 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 55 | 56 | .. autoclass:: captum.attr.LayerIntegratedGradients 57 | :members: 58 | 59 | Layer Feature Ablation 60 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 61 | 62 | .. autoclass:: captum.attr.LayerFeatureAblation 63 | :members: 64 | 65 | Layer Feature Permutation 66 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 67 | 68 | .. autoclass:: captum.attr.LayerFeaturePermutation 69 | :members: 70 | 71 | Layer LRP 72 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 73 | 74 | .. autoclass:: captum.attr.LayerLRP 75 | :members: 76 | -------------------------------------------------------------------------------- /sphinx/source/lime.rst: -------------------------------------------------------------------------------- 1 | Lime 2 | ==================== 3 | 4 | .. autoclass:: captum.attr.LimeBase 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | .. autoclass:: captum.attr.Lime 8 | :members: 9 | 10 | .. autofunction:: captum.attr._core.lime.get_exp_kernel_similarity_function 11 | -------------------------------------------------------------------------------- /sphinx/source/llm_attr.rst: -------------------------------------------------------------------------------- 1 | LLM Attribution Classes 2 | ======================== 3 | 4 | LLMAttribution 5 | ^^^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. 
autoclass:: captum.attr.LLMAttribution 8 | :members: 9 | 10 | LLMGradientAttribution 11 | ^^^^^^^^^^^^^^^^^^^^^^^^ 12 | 13 | .. autoclass:: captum.attr.LLMGradientAttribution 14 | :members: 15 | 16 | 17 | LLMAttributionResult 18 | ^^^^^^^^^^^^^^^^^^^^^^^^ 19 | 20 | .. autoclass:: captum.attr.LLMAttributionResult 21 | :members: 22 | -------------------------------------------------------------------------------- /sphinx/source/lrp.rst: -------------------------------------------------------------------------------- 1 | LRP 2 | ==================== 3 | 4 | .. autoclass:: captum.attr.LRP 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | =========== 3 | 4 | Infidelity 5 | ^^^^^^^^^^^ 6 | 7 | .. autofunction:: captum.metrics.infidelity 8 | 9 | 10 | 11 | Sensitivity 12 | ^^^^^^^^^^^ 13 | 14 | .. autofunction:: captum.metrics.sensitivity_max 15 | 16 | -------------------------------------------------------------------------------- /sphinx/source/module.rst: -------------------------------------------------------------------------------- 1 | Module 2 | ==================== 3 | .. toctree:: 4 | 5 | binary_concrete_stg 6 | gaussian_stg 7 | -------------------------------------------------------------------------------- /sphinx/source/neuron.rst: -------------------------------------------------------------------------------- 1 | Neuron Attribution 2 | =========================== 3 | 4 | Neuron Gradient 5 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.attr.NeuronGradient 8 | :members: 9 | 10 | Neuron Integrated Gradients 11 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 12 | 13 | .. autoclass:: captum.attr.NeuronIntegratedGradients 14 | :members: 15 | 16 | Neuron Conductance 17 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 18 | 19 | ..
autoclass:: captum.attr.NeuronConductance 20 | :members: 21 | 22 | Neuron DeepLift 23 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 24 | 25 | .. autoclass:: captum.attr.NeuronDeepLift 26 | :members: 27 | 28 | Neuron DeepLiftShap 29 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 30 | 31 | .. autoclass:: captum.attr.NeuronDeepLiftShap 32 | :members: 33 | 34 | Neuron GradientShap 35 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 36 | 37 | .. autoclass:: captum.attr.NeuronGradientShap 38 | :members: 39 | 40 | Neuron Guided Backprop 41 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 42 | 43 | .. autoclass:: captum.attr.NeuronGuidedBackprop 44 | :members: 45 | 46 | Neuron Deconvolution 47 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 48 | 49 | .. autoclass:: captum.attr.NeuronDeconvolution 50 | :members: 51 | 52 | Neuron Feature Ablation 53 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 54 | 55 | .. autoclass:: captum.attr.NeuronFeatureAblation 56 | :members: 57 | :exclude-members: compute_convergence_delta 58 | -------------------------------------------------------------------------------- /sphinx/source/noise_tunnel.rst: -------------------------------------------------------------------------------- 1 | NoiseTunnel 2 | ============ 3 | 4 | .. autoclass:: captum.attr.NoiseTunnel 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | -------------------------------------------------------------------------------- /sphinx/source/occlusion.rst: -------------------------------------------------------------------------------- 1 | Occlusion 2 | ========= 3 | 4 | .. autoclass:: captum.attr.Occlusion 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | -------------------------------------------------------------------------------- /sphinx/source/robust.rst: -------------------------------------------------------------------------------- 1 | Robustness 2 | ====================== 3 | 4 | FGSM 5 | ^^^^^^^^^^^^^^^^^^^^^^ 6 | 7 | .. autoclass:: captum.robust.FGSM 8 | :members: 9 | 10 | 11 | PGD 12 | ^^^^^^^^^^^^^^^^^^^^^^ 13 | 14 | .. 
autoclass:: captum.robust.PGD 15 | :members: 16 | 17 | 18 | Attack Comparator 19 | ^^^^^^^^^^^^^^^^^^^^^^ 20 | 21 | .. autoclass:: captum.robust.AttackComparator 22 | :members: 23 | 24 | 25 | Min Param Perturbation 26 | ^^^^^^^^^^^^^^^^^^^^^^ 27 | 28 | .. autoclass:: captum.robust.MinParamPerturbation 29 | :members: 30 | -------------------------------------------------------------------------------- /sphinx/source/saliency.rst: -------------------------------------------------------------------------------- 1 | Saliency 2 | =============== 3 | 4 | .. autoclass:: captum.attr.Saliency 5 | :members: 6 | -------------------------------------------------------------------------------- /sphinx/source/shapley_value_sampling.rst: -------------------------------------------------------------------------------- 1 | Shapley Value Sampling 2 | ====================== 3 | 4 | .. autoclass:: captum.attr.ShapleyValueSampling 5 | :members: 6 | :exclude-members: compute_convergence_delta 7 | .. autoclass:: captum.attr.ShapleyValues 8 | :members: 9 | :exclude-members: compute_convergence_delta 10 | -------------------------------------------------------------------------------- /sphinx/source/utilities.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========== 3 | 4 | Interpretable Input 5 | ^^^^^^^^^^^^^^^^^^^^^ 6 | .. autoclass:: captum.attr.InterpretableInput 7 | :members: 8 | 9 | .. autoclass:: captum.attr.TextTemplateInput 10 | :members: 11 | 12 | .. autoclass:: captum.attr.TextTokenInput 13 | :members: 14 | 15 | 16 | Visualization 17 | ^^^^^^^^^^^^^^ 18 | 19 | .. autofunction:: captum.attr.visualization.visualize_image_attr 20 | 21 | .. autofunction:: captum.attr.visualization.visualize_image_attr_multiple 22 | 23 | .. autofunction:: captum.attr.visualization.visualize_timeseries_attr 24 | 25 | 26 | Interpretable Embeddings 27 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 28 | 29 | .. 
autoclass:: captum.attr.InterpretableEmbeddingBase 30 | :members: 31 | 32 | .. autofunction:: captum.attr.configure_interpretable_embedding_layer 33 | 34 | .. autofunction:: captum.attr.remove_interpretable_embedding_layer 35 | 36 | 37 | Token Reference Base 38 | ^^^^^^^^^^^^^^^^^^^^^ 39 | 40 | .. autoclass:: captum.attr.TokenReferenceBase 41 | :members: 42 | 43 | 44 | Linear Models 45 | ^^^^^^^^^^^^^^^^^^^^^ 46 | 47 | .. autoclass:: captum._utils.models.model.Model 48 | :members: 49 | .. autoclass:: captum._utils.models.linear_model.SkLearnLinearModel 50 | :members: 51 | .. autoclass:: captum._utils.models.linear_model.SkLearnLinearRegression 52 | :members: 53 | .. autoclass:: captum._utils.models.linear_model.SkLearnLasso 54 | :members: 55 | .. autoclass:: captum._utils.models.linear_model.SkLearnRidge 56 | :members: 57 | .. autoclass:: captum._utils.models.linear_model.SGDLinearModel 58 | :members: 59 | .. autoclass:: captum._utils.models.linear_model.SGDLasso 60 | :members: 61 | .. autoclass:: captum._utils.models.linear_model.SGDRidge 62 | :members: 63 | 64 | Baselines 65 | ^^^^^^^^^^^^^^^^ 66 | 67 | .. 
autoclass:: captum.attr.ProductBaselines 68 | :members: 69 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/__init__.py -------------------------------------------------------------------------------- /tests/attr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/attr/__init__.py -------------------------------------------------------------------------------- /tests/attr/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/attr/layer/__init__.py -------------------------------------------------------------------------------- /tests/attr/layer/test_layer_feature_permutation.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | # pyre-unsafe 4 | 5 | import torch 6 | from captum.attr._core.layer.layer_feature_permutation import LayerFeaturePermutation 7 | from captum.testing.helpers import BaseTest 8 | from captum.testing.helpers.basic import assertTensorAlmostEqual 9 | from captum.testing.helpers.basic_models import BasicModel_MultiLayer 10 | from torch import Tensor 11 | 12 | 13 | class TestLayerFeaturePermutation(BaseTest): 14 | def test_single_input(self) -> None: 15 | net = BasicModel_MultiLayer() 16 | feature_importance = LayerFeaturePermutation( 17 | forward_func=net, 18 | layer=net.linear0, 19 | ) 20 | 21 | batch_size = 2 22 | input_size = (3,) 23 | constant_value = 10000 24 | 25 | inp = torch.randn((batch_size,) + input_size) 26 | inp[:, 0] = constant_value 27 | 28 | attribs = feature_importance.attribute(inputs=inp) 29 | 30 | self.assertTrue(isinstance(attribs, Tensor)) 31 | self.assertEqual(len(attribs), 4) 32 | self.assertEqual(attribs.squeeze(0).size(), (2 * batch_size,) + input_size) 33 | zeros = torch.zeros(2 * batch_size) 34 | assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0, mode="max") 35 | self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) 36 | -------------------------------------------------------------------------------- /tests/attr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/attr/models/__init__.py -------------------------------------------------------------------------------- /tests/attr/neuron/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/attr/neuron/__init__.py -------------------------------------------------------------------------------- /tests/attr/test_baselines.py: 
-------------------------------------------------------------------------------- 1 | # pyre-unsafe 2 | from typing import cast, Dict, List, Tuple, Union 3 | 4 | from captum.attr._utils.baselines import ProductBaselines 5 | 6 | # from parameterized import parameterized 7 | from captum.testing.helpers import BaseTest 8 | 9 | 10 | class TestProductBaselines(BaseTest): 11 | def test_list(self) -> None: 12 | baseline_values = [ 13 | [1, 2, 3], 14 | [4, 5, 6, 7], 15 | [8, 9], 16 | ] 17 | 18 | baselines = ProductBaselines(baseline_values) 19 | 20 | baseline_sample = baselines() 21 | 22 | self.assertIsInstance(baseline_sample, list) 23 | for sample_val, vals in zip(baseline_sample, baseline_values): 24 | self.assertIn(sample_val, vals) 25 | 26 | def test_dict(self) -> None: 27 | baseline_values = { 28 | "f1": [1, 2, 3], 29 | "f2": [4, 5, 6, 7], 30 | "f3": [8, 9], 31 | } 32 | 33 | baselines = ProductBaselines( 34 | cast(Dict[Union[str, Tuple[str, ...]], List[int]], baseline_values) 35 | ) 36 | 37 | baseline_sample = baselines() 38 | 39 | self.assertIsInstance(baseline_sample, dict) 40 | baseline_sample = cast(dict, baseline_sample) 41 | 42 | for sample_key, sample_val in baseline_sample.items(): 43 | self.assertIn(sample_val, baseline_values[sample_key]) 44 | 45 | def test_dict_tuple_key(self) -> None: 46 | baseline_values: Dict[Union[str, Tuple[str, ...]], List] = { 47 | ("f1", "f2"): [(1, "1"), (2, "2"), (3, "3")], 48 | "f3": [4, 5], 49 | } 50 | 51 | baselines = ProductBaselines(baseline_values) 52 | 53 | baseline_sample = baselines() 54 | 55 | self.assertIsInstance(baseline_sample, dict) 56 | baseline_sample = cast(dict, baseline_sample) 57 | 58 | self.assertEqual(len(baseline_sample), 3) 59 | 60 | self.assertIn( 61 | (baseline_sample["f1"], baseline_sample["f2"]), 62 | baseline_values[("f1", "f2")], 63 | ) 64 | self.assertIn(baseline_sample["f3"], baseline_values["f3"]) 65 | -------------------------------------------------------------------------------- 
/tests/attr/test_common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-unsafe 4 | 5 | import torch 6 | from captum.attr._core.noise_tunnel import SUPPORTED_NOISE_TUNNEL_TYPES 7 | from captum.attr._utils.common import _validate_input, _validate_noise_tunnel_type 8 | from captum.testing.helpers import BaseTest 9 | 10 | 11 | class Test(BaseTest): 12 | def test_validate_input(self) -> None: 13 | with self.assertRaises(AssertionError) as err: 14 | _validate_input( 15 | (torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0, 0.0, 1.0]),) 16 | ) 17 | self.assertEqual( 18 | "Baseline can be provided as a tensor for just one input and " 19 | "broadcasted to the batch or input and baseline must have the " 20 | "same shape or the baseline corresponding to each input tensor " 21 | "must be a scalar. Found baseline: tensor([-2., 0., 1.]) and " 22 | "input: tensor([-1., 1.])", 23 | str(err.exception), 24 | ) 25 | 26 | with self.assertRaises(AssertionError) as err: 27 | _validate_input( 28 | (torch.tensor([-1.0, 1.0]),), (torch.tensor([-1.0, 1.0]),), n_steps=-1 29 | ) 30 | self.assertEqual( 31 | "The number of steps must be a positive integer. Given: -1", 32 | str(err.exception), 33 | ) 34 | 35 | with self.assertRaises(AssertionError) as err: 36 | _validate_input( 37 | (torch.tensor([-1.0, 1.0]),), 38 | (torch.tensor([-1.0, 1.0]),), 39 | method="abcde", 40 | ) 41 | self.assertIn( 42 | "Approximation method must be one for the following", 43 | str(err.exception), 44 | ) 45 | # any baseline which is broadcastable to match the input is supported, which 46 | # includes a scalar / single-element tensor. 
47 | _validate_input((torch.tensor([-1.0, 1.0]),), (torch.tensor([-2.0]),)) 48 | _validate_input((torch.tensor([-1.0]),), (torch.tensor([-2.0]),)) 49 | _validate_input( 50 | (torch.tensor([-1.0]),), (torch.tensor([-2.0]),), method="gausslegendre" 51 | ) 52 | 53 | def test_validate_nt_type(self) -> None: 54 | with self.assertRaises( 55 | AssertionError, 56 | ) as err: 57 | _validate_noise_tunnel_type("abc", SUPPORTED_NOISE_TUNNEL_TYPES) 58 | self.assertIn( 59 | "Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`.", 60 | str(err.exception), 61 | ) 62 | 63 | _validate_noise_tunnel_type("smoothgrad", SUPPORTED_NOISE_TUNNEL_TYPES) 64 | _validate_noise_tunnel_type("smoothgrad_sq", SUPPORTED_NOISE_TUNNEL_TYPES) 65 | _validate_noise_tunnel_type("vargrad", SUPPORTED_NOISE_TUNNEL_TYPES) 66 | -------------------------------------------------------------------------------- /tests/attr/test_summarizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-unsafe 4 | import torch 5 | from captum.attr import CommonStats, Summarizer 6 | from captum.testing.helpers import BaseTest 7 | 8 | 9 | class Test(BaseTest): 10 | def test_single_input(self) -> None: 11 | size = (2, 3) 12 | summarizer = Summarizer(stats=CommonStats()) 13 | for _ in range(10): 14 | attrs = torch.randn(size) 15 | summarizer.update(attrs) 16 | 17 | summ = summarizer.summary 18 | self.assertIsNotNone(summ) 19 | self.assertTrue(isinstance(summ, dict)) 20 | 21 | for k in summ: 22 | self.assertTrue(summ[k].size() == size) 23 | 24 | def test_multi_input(self) -> None: 25 | size1 = (10, 5, 5) 26 | size2 = (3, 5) 27 | 28 | summarizer = Summarizer(stats=CommonStats()) 29 | for _ in range(10): 30 | a1 = torch.randn(size1) 31 | a2 = torch.randn(size2) 32 | summarizer.update((a1, a2)) 33 | 34 | summ = summarizer.summary 35 | self.assertIsNotNone(summ) 36 | self.assertTrue(len(summ) == 2) 37 | self.assertTrue(isinstance(summ[0], dict)) 
38 | self.assertTrue(isinstance(summ[1], dict)) 39 | 40 | for k in summ[0]: 41 | self.assertTrue(summ[0][k].size() == size1) 42 | self.assertTrue(summ[1][k].size() == size2) 43 | -------------------------------------------------------------------------------- /tests/attr/test_utils_batching.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-unsafe 4 | 5 | import torch 6 | from captum.attr._utils.batching import ( 7 | _batched_generator, 8 | _batched_operator, 9 | _tuple_splice_range, 10 | ) 11 | from captum.testing.helpers import BaseTest 12 | from captum.testing.helpers.basic import assertTensorAlmostEqual 13 | 14 | 15 | class Test(BaseTest): 16 | def test_tuple_splice_range(self) -> None: 17 | test_tuple = ( 18 | torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), 19 | "test", 20 | torch.tensor([[6, 7, 8], [0, 1, 2], [3, 4, 5]]), 21 | ) 22 | spliced_tuple = _tuple_splice_range(test_tuple, 1, 3) 23 | assertTensorAlmostEqual(self, spliced_tuple[0], [[3, 4, 5], [6, 7, 8]]) 24 | self.assertEqual(spliced_tuple[1], "test") 25 | assertTensorAlmostEqual(self, spliced_tuple[2], [[0, 1, 2], [3, 4, 5]]) 26 | 27 | def test_tuple_splice_range_3d(self) -> None: 28 | test_tuple = ( 29 | torch.tensor([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [6, 7, 8]]]), 30 | "test", 31 | ) 32 | spliced_tuple = _tuple_splice_range(test_tuple, 1, 2) 33 | assertTensorAlmostEqual(self, spliced_tuple[0], [[[6, 7, 8], [6, 7, 8]]]) 34 | self.assertEqual(spliced_tuple[1], "test") 35 | 36 | def test_batched_generator(self) -> None: 37 | def sample_operator(inputs, additional_forward_args, target_ind, scale): 38 | return ( 39 | scale * (sum(inputs)), 40 | scale * sum(additional_forward_args), 41 | target_ind, 42 | ) 43 | 44 | array1 = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 45 | array2 = [[6, 7, 8], [0, 1, 2], [3, 4, 5]] 46 | array3 = [[0, 1, 2], [0, 0, 0], [0, 0, 0]] 47 | inp1, inp2, inp3 = ( 48 | torch.tensor(array1), 49 | 
torch.tensor(array2), 50 | torch.tensor(array3), 51 | ) 52 | for index, (inp, add, targ) in enumerate( 53 | _batched_generator((inp1, inp2), (inp3, 5), 7, 1) 54 | ): 55 | assertTensorAlmostEqual(self, inp[0], [array1[index]]) 56 | assertTensorAlmostEqual(self, inp[1], [array2[index]]) 57 | assertTensorAlmostEqual(self, add[0], [array3[index]]) 58 | self.assertEqual(add[1], 5) 59 | self.assertEqual(targ, 7) 60 | 61 | def test_batched_operator_0_bsz(self) -> None: 62 | inp1 = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 63 | with self.assertRaises(AssertionError): 64 | _batched_operator(lambda x: x, inputs=inp1, internal_batch_size=0) 65 | 66 | def test_batched_operator(self) -> None: 67 | def _sample_operator(inputs, additional_forward_args, target_ind, scale): 68 | return ( 69 | scale * (sum(inputs)), 70 | scale * sum(additional_forward_args) + target_ind[0], 71 | ) 72 | 73 | inp1 = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) 74 | inp2 = torch.tensor([[6, 7, 8], [0, 1, 2], [3, 4, 5]]) 75 | inp3 = torch.tensor([[0, 1, 2], [0, 0, 0], [0, 0, 0]]) 76 | batched_result = _batched_operator( 77 | _sample_operator, 78 | inputs=(inp1, inp2), 79 | additional_forward_args=(inp3), 80 | target_ind=[0, 1, 2], 81 | scale=2.0, 82 | internal_batch_size=1, 83 | ) 84 | assertTensorAlmostEqual( 85 | self, batched_result[0], [[12, 16, 20], [6, 10, 14], [18, 22, 26]] 86 | ) 87 | assertTensorAlmostEqual( 88 | self, batched_result[1], [[0, 2, 4], [1, 1, 1], [2, 2, 2]] 89 | ) 90 | -------------------------------------------------------------------------------- /tests/concept/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/concept/__init__.py -------------------------------------------------------------------------------- /tests/concept/test_concept.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | 3 | # pyre-unsafe 4 | 5 | from typing import cast, Iterable 6 | 7 | import torch 8 | from captum.concept._core.concept import Concept 9 | from captum.concept._utils.data_iterator import dataset_to_dataloader 10 | from captum.testing.helpers import BaseTest 11 | from torch.utils.data import IterableDataset 12 | 13 | 14 | class CustomIterableDataset(IterableDataset): 15 | r""" 16 | An auxiliary class for iterating through an image dataset. 17 | """ 18 | 19 | def __init__(self, get_tensor_from_filename_func, path) -> None: 20 | r""" 21 | Args: 22 | 23 | path (str): Path to dataset files 24 | """ 25 | 26 | self.path = path 27 | self.file_itr = ["x"] * 2 28 | self.get_tensor_from_filename_func = get_tensor_from_filename_func 29 | 30 | def get_tensor_from_filename(self, filename): 31 | 32 | return self.get_tensor_from_filename_func(filename) 33 | 34 | def __iter__(self): 35 | 36 | mapped_itr = map(self.get_tensor_from_filename, self.file_itr) 37 | 38 | return mapped_itr 39 | 40 | 41 | class Test(BaseTest): 42 | def test_create_concepts_from_images(self) -> None: 43 | def get_tensor_from_filename(filename): 44 | return torch.rand(3, 224, 224) 45 | 46 | # Striped 47 | concepts_path = "./dummy/concepts/striped/" 48 | dataset = CustomIterableDataset(get_tensor_from_filename, concepts_path) 49 | striped_iter = dataset_to_dataloader(dataset) 50 | 51 | self.assertEqual( 52 | len(cast(CustomIterableDataset, striped_iter.dataset).file_itr), 2 53 | ) 54 | 55 | concept = Concept(id=0, name="striped", data_iter=striped_iter) 56 | 57 | for data in cast(Iterable, concept.data_iter): 58 | self.assertEqual(data.shape[1:], torch.Size([3, 224, 224])) 59 | 60 | # Random 61 | concepts_path = "./dummy/concepts/random/" 62 | dataset = CustomIterableDataset(get_tensor_from_filename, concepts_path) 63 | random_iter = dataset_to_dataloader(dataset) 64 | self.assertEqual( 65 | len(cast(CustomIterableDataset, random_iter.dataset).file_itr), 2 66 | ) 67 | 68 | 
concept = Concept(id=1, name="random", data_iter=random_iter) 69 | for data in cast(Iterable, concept.data_iter): 70 | self.assertEqual(data.shape[1:], torch.Size([3, 224, 224])) 71 | -------------------------------------------------------------------------------- /tests/influence/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/influence/__init__.py -------------------------------------------------------------------------------- /tests/influence/_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/influence/_core/__init__.py -------------------------------------------------------------------------------- /tests/influence/_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/influence/_utils/__init__.py -------------------------------------------------------------------------------- /tests/influence/_utils/test_common.py: -------------------------------------------------------------------------------- 1 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
2 | 3 | # pyre-unsafe 4 | 5 | # !/usr/bin/env python3 6 | 7 | import torch 8 | 9 | from captum.influence._utils.common import _jacobian_loss_wrt_inputs 10 | from captum.testing.helpers import BaseTest 11 | from captum.testing.helpers.basic import assertTensorAlmostEqual 12 | 13 | 14 | class TestCommon(BaseTest): 15 | def setUp(self) -> None: 16 | super().setUp() 17 | 18 | def test_jacobian_loss_wrt_inputs(self) -> None: 19 | with self.assertRaises(ValueError) as err: 20 | _jacobian_loss_wrt_inputs( 21 | torch.nn.BCELoss(reduction="sum"), 22 | torch.tensor([-1.0, 1.0]), 23 | torch.tensor([1.0]), 24 | True, 25 | "", 26 | ) 27 | self.assertEqual( 28 | "`` is not a valid value for reduction_type. " 29 | "Must be either 'sum' or 'mean'.", 30 | str(err.exception), 31 | ) 32 | 33 | with self.assertRaises(AssertionError) as err: 34 | _jacobian_loss_wrt_inputs( 35 | torch.nn.BCELoss(reduction="sum"), 36 | torch.tensor([-1.0, 1.0]), 37 | torch.tensor([1.0]), 38 | True, 39 | "mean", 40 | ) 41 | self.assertEqual( 42 | "loss_fn.reduction `sum` does not matchreduction type `mean`." 
43 | " Please ensure they are matching.", 44 | str(err.exception), 45 | ) 46 | 47 | res = _jacobian_loss_wrt_inputs( 48 | torch.nn.BCELoss(reduction="sum"), 49 | torch.tensor([0.5, 1.0]), 50 | torch.tensor([0.0, 1.0]), 51 | True, 52 | "sum", 53 | ) 54 | assertTensorAlmostEqual( 55 | self, res, torch.tensor([2.0, 0.0]), delta=0.0, mode="sum" 56 | ) 57 | -------------------------------------------------------------------------------- /tests/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/metrics/__init__.py -------------------------------------------------------------------------------- /tests/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/module/__init__.py -------------------------------------------------------------------------------- /tests/module/test_binary_concrete_stochastic_gates_cuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 3 | 4 | # pyre-strict 5 | 6 | from .test_binary_concrete_stochastic_gates import TestBinaryConcreteStochasticGates 7 | 8 | 9 | class TestBinaryConcreteStochasticGatesCUDA(TestBinaryConcreteStochasticGates): 10 | testing_device: str = "cuda" 11 | -------------------------------------------------------------------------------- /tests/module/test_gaussian_stochastic_gates_cuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
3 | 4 | # pyre-strict 5 | 6 | from .test_gaussian_stochastic_gates import TestGaussianStochasticGates 7 | 8 | 9 | class TestGaussianStochasticGatesCUDA(TestGaussianStochasticGates): 10 | testing_device: str = "cuda" 11 | -------------------------------------------------------------------------------- /tests/robust/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/robust/__init__.py -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/utils/models/__init__.py -------------------------------------------------------------------------------- /tests/utils/models/linear_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tests/utils/models/linear_models/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # pyre-unsafe 4 | 5 | import torch 6 | from captum.testing.helpers import BaseTest 7 | from captum.testing.helpers.basic import assertTensorAlmostEqual 8 | 9 | 10 | class HelpersTest(BaseTest): 11 | def test_assert_tensor_almost_equal(self) -> None: 12 | with 
self.assertRaises(AssertionError) as cm: 13 | assertTensorAlmostEqual(self, [[1.0]], [[1.0]]) 14 | self.assertEqual( 15 | cm.exception.args, 16 | ("Actual parameter given for comparison must be a tensor.",), 17 | ) 18 | 19 | with self.assertRaises(AssertionError) as cm: 20 | assertTensorAlmostEqual(self, torch.tensor([[]]), torch.tensor([[1.0]])) 21 | self.assertEqual( 22 | cm.exception.args, 23 | ( 24 | "Expected tensor with shape: torch.Size([1, 1]). Actual shape torch.Size([1, 0]).", # noqa: E501 25 | ), 26 | ) 27 | 28 | assertTensorAlmostEqual(self, torch.tensor([[1.0]]), [[1.0]]) 29 | 30 | with self.assertRaises(AssertionError) as cm: 31 | assertTensorAlmostEqual(self, torch.tensor([[1.0]]), [1.0]) 32 | self.assertEqual( 33 | cm.exception.args, 34 | ( 35 | "Expected tensor with shape: torch.Size([1]). Actual shape torch.Size([1, 1]).", # noqa: E501 36 | ), 37 | ) 38 | 39 | assertTensorAlmostEqual( 40 | self, torch.tensor([[1.0, 1.0]]), [[1.0, 0.0]], delta=1.0, mode="max" 41 | ) 42 | 43 | with self.assertRaises(AssertionError) as cm: 44 | assertTensorAlmostEqual( 45 | self, torch.tensor([[1.0, 1.0]]), [[1.0, 0.0]], mode="max" 46 | ) 47 | self.assertEqual( 48 | cm.exception.args, 49 | ( 50 | "Values at index 0, tensor([1., 1.]) and tensor([1., 0.]), differ more than by 0.0001", # noqa: E501 51 | ), 52 | ) 53 | 54 | assertTensorAlmostEqual( 55 | self, torch.tensor([[1.0, 1.0]]), [[1.0, 0.0]], delta=1.0 56 | ) 57 | 58 | with self.assertRaises(AssertionError): 59 | assertTensorAlmostEqual(self, torch.tensor([[1.0, 1.0]]), [[1.0, 0.0]]) 60 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | This folder contains Jupyter tutorial notebooks. These notebooks are also 2 | parsed and embedded into the captum website. 
3 | -------------------------------------------------------------------------------- /tutorials/data/dlrm/X_S_T_test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/data/dlrm/X_S_T_test -------------------------------------------------------------------------------- /tutorials/data/dlrm/X_S_T_test_above_0999: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/data/dlrm/X_S_T_test_above_0999 -------------------------------------------------------------------------------- /tutorials/data/tcav/image/concepts/README.md: -------------------------------------------------------------------------------- 1 | Add concept folders containing concept images under this directory. 2 | -------------------------------------------------------------------------------- /tutorials/data/tcav/image/imagenet/README.md: -------------------------------------------------------------------------------- 1 | Add ImageNet images that you want to test against different concepts under this folder. 2 | -------------------------------------------------------------------------------- /tutorials/data/tcav/text-sensitivity/positive-adjectives.csv: -------------------------------------------------------------------------------- 1 | . . so well . . . 2 | . . so good . . . 3 | . . . love it . pad 4 | . . . like it . pad 5 | . . even greater . . . 6 | . antic . . . . . 7 | . . . fantastical . . . 8 | . grotesque . . . . . 9 | fantastic . . . . . pad 10 | grand . . . . . pad 11 | . . . marvelous . . . 12 | rattling . . . . . pad 13 | . are terrific . . . . 14 | . . . . tremendous . pad 15 | . . . . . . wonderful 16 | . . . . wondrous . . 17 | . . . . . . fantastic 18 | pad pad pad . . wild . 19 | pad . . fantastical . . . 20 | fantastic . . . . . . 
21 | . . . . exceeding . . 22 | . . . . . . olympian 23 | pad . . . prodigious . pad 24 | . . surpassing . pad pad pad 25 | . . . . . especial . 26 | . . . . . particular . 27 | . . . . special . . 28 | . exceptional . . . . . 29 | . . . . truth here . 30 | . . first great work pad pad 31 | . . . worthwhile . . . 32 | . succeeds . impressive . . . 33 | . . and keener is fantastic . 34 | . . . grand . . pad 35 | . . howling . . . . 36 | . . is marvelous . . . 37 | . is rattling . . . pad . 38 | . . is terrific . . . 39 | tremendous . . . . . . 40 | . . . wonderful . . . 41 | . . marvellous . . . pad 42 | . . best . miraculous . . 43 | . . worthy . . . . 44 | salisfying . . worthy pad . . 45 | . . fanciful . . . . 46 | . . funny , very enjoyable . 47 | . . amusing , very gratifying . 48 | a sensitive , moving , brilliantly . 49 | . brightly constructed work . . . 50 | . . . beautifully . . . 51 | . . . . . attractively pad 52 | excellent . . . pad pad pad 53 | . . . . first-class . . 54 | . enhances the fantabulous . . . 55 | . . enhances the splendid . . 56 | . . . . cool . pad 57 | pad . . is cool . pad 58 | . . adorable . . . . 59 | . . a captivating . . . 60 | enchanting new . . . . . 61 | . . . . . . enthralling 62 | . . entrancing . . . . 63 | . fascinating . . pad pad pad 64 | . . . . . . hearty 65 | fun . . . . . . 66 | . . . . bright . . 67 | . . . . . . charming 68 | . . . surprises . . . 69 | . . . . quietly engaging . 70 | dense , exhilarating . . pad . 71 | . . . . a joyous . 72 | . . . absorbing . . . 73 | . genuine mind - - . . 74 | . great . interaction . pad pad 75 | . . . . intriguing and . 76 | . . . . hard to resist 77 | . . a true - blue delight 78 | pad pad a . fun ride . 79 | . . . rewarding . . ! 80 | sleek and arty . pad pad pad 81 | . . fantastic ! . . . 82 | pad pad pad intelligent . moving pad 83 | . . . genuinely unnerving . . 84 | . compelling . . . . . 85 | glamorous . . . . . pad 86 | . . . a riveting . 87 | . . enjoyable experience . 
. . 88 | I love . . highly . . 89 | . the perfect pad pad pad pad 90 | . . . beautiful . very nice 91 | . . It's great . . . 92 | . . . . I love this shelf. 93 | . . fabulous . . . . 94 | Wonderful . VERY pad pad pad pad 95 | pad pad Great . . . . 96 | . . loves it . It's great 97 | loves . . . . outstanding . 98 | . . . . loves . . 99 | . . . . . . loves 100 | . pad majuscule pad pad pad pad 101 | pad pad pad pad loves it pad 102 | . . . expectant . . . 103 | pad pad . gravid . . . 104 | It's great . . . . . 105 | . . . loves it . . 106 | . . . a beautiful . perfect 107 | . fun . . . . . 108 | . . . funny . a perfect 109 | . good . . fun . fascinating 110 | . sweet . . . highly . 111 | . liked . . . . . 112 | . . . wonderful . . . 113 | . . . . fantastic . good 114 | great . . amazing . . . 115 | pad pad enthralled . . are priceless 116 | . . and . . . great pad 117 | . . excellent . . fun . 118 | . . happy . . . excellent 119 | . . . . . . brilliant 120 | ingenious . . . . . . 121 | -------------------------------------------------------------------------------- /tutorials/img/bert/visuals_of_start_end_predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/bert/visuals_of_start_end_predictions.png -------------------------------------------------------------------------------- /tutorials/img/captum_insights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/captum_insights.png -------------------------------------------------------------------------------- /tutorials/img/captum_insights_vqa.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/captum_insights_vqa.png -------------------------------------------------------------------------------- /tutorials/img/dlrm_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/dlrm_arch.png -------------------------------------------------------------------------------- /tutorials/img/lime_text_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/lime_text_viz.png -------------------------------------------------------------------------------- /tutorials/img/resnet/swan-3299528_1280.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/resnet/swan-3299528_1280.jpg -------------------------------------------------------------------------------- /tutorials/img/resnet/swan-3299528_1280_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/resnet/swan-3299528_1280_attribution.jpg -------------------------------------------------------------------------------- /tutorials/img/sentiment_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/sentiment_analysis.png -------------------------------------------------------------------------------- /tutorials/img/vqa/elephant.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/elephant.jpg -------------------------------------------------------------------------------- /tutorials/img/vqa/elephant_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/elephant_attribution.jpg -------------------------------------------------------------------------------- /tutorials/img/vqa/siamese.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/siamese.jpg -------------------------------------------------------------------------------- /tutorials/img/vqa/siamese_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/siamese_attribution.jpg -------------------------------------------------------------------------------- /tutorials/img/vqa/zebra.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/zebra.jpg -------------------------------------------------------------------------------- /tutorials/img/vqa/zebra_attribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/img/vqa/zebra_attribution.jpg -------------------------------------------------------------------------------- /tutorials/models/california_model.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/models/california_model.pt -------------------------------------------------------------------------------- /tutorials/models/cifar_torchvision.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/models/cifar_torchvision.pt -------------------------------------------------------------------------------- /tutorials/models/embedding_bag_ag_news.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/models/embedding_bag_ag_news.pt -------------------------------------------------------------------------------- /tutorials/models/imdb-model-cnn-large.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/models/imdb-model-cnn-large.pt -------------------------------------------------------------------------------- /tutorials/models/titanic_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/tutorials/models/titanic_model.pt -------------------------------------------------------------------------------- /website/core/Tutorial.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2019-present, Facebook, Inc. 3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | * 7 | * @format 8 | */ 9 | 10 | const React = require('react'); 11 | 12 | const fs = require('fs-extra'); 13 | const path = require('path'); 14 | const CWD = process.cwd(); 15 | 16 | const CompLibrary = require(`${CWD}/node_modules/docusaurus/lib/core/CompLibrary.js`); 17 | const Container = CompLibrary.Container; 18 | 19 | const TutorialSidebar = require(`${CWD}/core/TutorialSidebar.js`); 20 | 21 | function renderDownloadIcon() { 22 | return ( 23 | 37 | ); 38 | } 39 | 40 | class Tutorial extends React.Component { 41 | render() { 42 | const {baseUrl, tutorialID} = this.props; 43 | 44 | const htmlFile = `${CWD}/_tutorials/${tutorialID}.html`; 45 | const normalizedHtmlFile = path.normalize(htmlFile); 46 | 47 | return ( 48 |
49 | 50 | 51 | 79 | ); 80 | } 81 | } 82 | 83 | module.exports = Tutorial; 84 | -------------------------------------------------------------------------------- /website/core/TutorialSidebar.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2019-present, Facebook, Inc. 3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | * 7 | * @format 8 | */ 9 | 10 | const React = require('react'); 11 | const fs = require('fs-extra'); 12 | const path = require('path'); 13 | const join = path.join; 14 | const CWD = process.cwd(); 15 | 16 | const CompLibrary = require(join( 17 | CWD, 18 | '/node_modules/docusaurus/lib/core/CompLibrary.js', 19 | )); 20 | const SideNav = require(join( 21 | CWD, 22 | '/node_modules/docusaurus/lib/core/nav/SideNav.js', 23 | )); 24 | 25 | const Container = CompLibrary.Container; 26 | 27 | const OVERVIEW_ID = 'tutorial_overview'; 28 | 29 | class TutorialSidebar extends React.Component { 30 | render() { 31 | const {currentTutorialID} = this.props; 32 | const current = { 33 | id: currentTutorialID || OVERVIEW_ID, 34 | }; 35 | 36 | const toc = [ 37 | { 38 | type: 'CATEGORY', 39 | title: 'Captum Tutorials', 40 | children: [ 41 | { 42 | type: 'LINK', 43 | item: { 44 | permalink: 'tutorials/', 45 | id: OVERVIEW_ID, 46 | title: 'Overview', 47 | }, 48 | }, 49 | ], 50 | }, 51 | ]; 52 | 53 | const jsonFile = join(CWD, 'tutorials.json'); 54 | const normJsonFile = path.normalize(jsonFile); 55 | const json = JSON.parse(fs.readFileSync(normJsonFile, {encoding: 'utf8'})); 56 | 57 | Object.keys(json).forEach(category => { 58 | const categoryItems = json[category]; 59 | const items = categoryItems.map(item => { 60 | if (item.id !== undefined) { 61 | return { 62 | type: 'LINK', 63 | item: { 64 | permalink: `tutorials/${item.id}`, 65 | id: item.id, 66 | title: item.title, 67 | }, 68 | } 69 | } 70 | 71 | return { 72 | type: 
'SUBCATEGORY', 73 | title: item.title, 74 | children: item.children.map(iitem => { 75 | return { 76 | type: 'Link', 77 | item: { 78 | permalink: `tutorials/${iitem.id}`, 79 | id: iitem.id, 80 | title: iitem.title, 81 | } 82 | }; 83 | }), 84 | } 85 | }); 86 | 87 | toc.push({ 88 | type: 'CATEGORY', 89 | title: category, 90 | children: items, 91 | }); 92 | }); 93 | 94 | return ( 95 | 96 | 103 | 104 | ); 105 | } 106 | } 107 | 108 | module.exports = TutorialSidebar; 109 | -------------------------------------------------------------------------------- /website/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "scripts": { 3 | "examples": "docusaurus-examples", 4 | "start": "docusaurus-start", 5 | "build": "docusaurus-build", 6 | "publish-gh-pages": "docusaurus-publish", 7 | "write-translations": "docusaurus-write-translations", 8 | "version": "docusaurus-version", 9 | "rename-version": "docusaurus-rename-version" 10 | }, 11 | "devDependencies": { 12 | "docusaurus": "^1.9.0" 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /website/sidebars.json: -------------------------------------------------------------------------------- 1 | { 2 | "docs": { 3 | "About": ["introduction"], 4 | "General": ["getting_started", "captum_insights", "attribution_algorithms", "algorithms_comparison_matrix", "faq", "contribution_guidelines"], 5 | "Usage": ["extension/integrated_gradients"] 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /website/siteConfig.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2017-present, Facebook, Inc. 3 | * 4 | * This source code is licensed under the BSD license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | * 7 | * @format 8 | */ 9 | 10 | // See https://docusaurus.io/docs/site-config for all the possible 11 | // site configuration options. 12 | 13 | // Define this so it can be easily modified in scripts (to host elsewhere) 14 | const baseUrl = "/"; 15 | 16 | // List of projects/orgs using your project for the users page. 17 | const users = []; 18 | 19 | const siteConfig = { 20 | title: "Captum", 21 | tagline: "Model Interpretability for PyTorch", 22 | url: "https://captum.ai", 23 | baseUrl: baseUrl, 24 | cleanUrl: true, // No .html extensions for paths 25 | 26 | // used for publishing and more 27 | organizationName: "pytorch", 28 | projectName: "captum", 29 | 30 | // Google analytics 31 | gaTrackingId: "UA-44373548-48", 32 | 33 | // links that will be used in the header navigation bar 34 | headerLinks: [ 35 | { doc: "introduction", label: "Docs" }, 36 | { href: `${baseUrl}tutorials/`, label: "Tutorials" }, 37 | { href: `${baseUrl}api/`, label: "API Reference" }, 38 | { href: "https://github.com/pytorch/captum", label: "GitHub" }, 39 | { search: true } // position search box to the very right 40 | ], 41 | 42 | // add users to the website 43 | users, 44 | 45 | // search integration w/ algolia 46 | algolia: { 47 | apiKey: "207c27d819f967749142d8611de7cb19", 48 | indexName: "captum" 49 | }, 50 | 51 | // images for header/footer and favicon 52 | headerIcon: "img/captum_logo.svg", 53 | // footerIcon: 'img/captum-icon.png', 54 | favicon: "img/captum.ico", 55 | 56 | // colors for website 57 | colors: { 58 | primaryColor: "#f29837", // orange 59 | secondaryColor: "#f0bc40" // yellow 60 | }, 61 | 62 | highlight: { 63 | theme: "default" 64 | }, 65 | 66 | // custom scripts that are placed in of each page 67 | scripts: [ 68 | // Github buttons 69 | "https://buttons.github.io/buttons.js", 70 | // Copy-to-clipboard button for code blocks 71 | `${baseUrl}js/code_block_buttons.js`, 72 | "https://cdnjs.cloudflare.com/ajax/libs/clipboard.js/2.0.0/clipboard.min.js" 73 | ], 74 | 75 | 
// CSS sources to load 76 | stylesheets: [], 77 | 78 | // enable on-page navigation for the current documentation page 79 | onPageNav: "separate", 80 | 81 | // enable scroll to top button at the bottom of the site 82 | scrollToTop: true, 83 | 84 | // if true, expand/collapse links & subcategories in sidebar 85 | docsSideNavCollapsible: false, 86 | 87 | // URL for editing docs 88 | editUrl: "https://github.com/pytorch/captum/edit/master/docs/", 89 | 90 | // Disable logo text so we can just show the logo 91 | disableHeaderTitle: true, 92 | 93 | // Open Graph and Twitter card images 94 | ogImage: "img/captum-icon.png", 95 | twitterImage: "img/captum.png", 96 | 97 | // show html docs generated by sphinx 98 | wrapPagesHTML: true 99 | }; 100 | 101 | module.exports = siteConfig; 102 | -------------------------------------------------------------------------------- /website/static/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/.nojekyll -------------------------------------------------------------------------------- /website/static/CNAME: -------------------------------------------------------------------------------- 1 | captum.ai -------------------------------------------------------------------------------- /website/static/css/code_block_buttons.css: -------------------------------------------------------------------------------- 1 | /* "Copy" code block button */ 2 | pre { 3 | position: relative; 4 | } 5 | 6 | pre .btnIcon { 7 | position: absolute; 8 | top: 4px; 9 | z-index: 2; 10 | cursor: pointer; 11 | border: 1px solid transparent; 12 | padding: 0; 13 | color: #000; 14 | background-color: transparent; 15 | height: 30px; 16 | transition: all .25s ease-out; 17 | } 18 | 19 | pre .btnIcon:hover { 20 | text-decoration: none; 21 | } 22 | 23 | .btnIcon__body { 24 | align-items: center; 25 | display: flex; 26 | } 27 | 28 | 
.btnIcon svg { 29 | fill: currentColor; 30 | margin-right: .4em; 31 | } 32 | 33 | .btnIcon__label { 34 | font-size: 11px; 35 | } 36 | 37 | .btnClipboard { 38 | right: 10px; 39 | } 40 | -------------------------------------------------------------------------------- /website/static/img/IG_eq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/IG_eq1.png -------------------------------------------------------------------------------- /website/static/img/algorithms_comparison_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/algorithms_comparison_matrix.png -------------------------------------------------------------------------------- /website/static/img/captum-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/captum-icon.png -------------------------------------------------------------------------------- /website/static/img/captum.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/captum.ico -------------------------------------------------------------------------------- /website/static/img/captum_insights_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/captum_insights_screenshot.png -------------------------------------------------------------------------------- /website/static/img/captum_insights_screenshot_vqa.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/captum_insights_screenshot_vqa.png -------------------------------------------------------------------------------- /website/static/img/captum_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/captum_logo.png -------------------------------------------------------------------------------- /website/static/img/captum_logo.svg: -------------------------------------------------------------------------------- 1 | captum-logo -------------------------------------------------------------------------------- /website/static/img/conductance_eq_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/conductance_eq_1.png -------------------------------------------------------------------------------- /website/static/img/conductance_eq_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/conductance_eq_2.png -------------------------------------------------------------------------------- /website/static/img/deepLIFT_multipliers_eq1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/deepLIFT_multipliers_eq1.png -------------------------------------------------------------------------------- /website/static/img/expanding_arrows.svg: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /website/static/img/infidelity_eq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/infidelity_eq.png -------------------------------------------------------------------------------- /website/static/img/landing-background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/landing-background.jpg -------------------------------------------------------------------------------- /website/static/img/multi-modal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/multi-modal.png -------------------------------------------------------------------------------- /website/static/img/oss_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/oss_logo.png -------------------------------------------------------------------------------- /website/static/img/pytorch_logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 8 | pytorch_logo 9 | 11 | 13 | 14 | -------------------------------------------------------------------------------- /website/static/img/sensitivity_eq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/captum/b60a9e93ac1bbe241876c7e859d15b8930417ae7/website/static/img/sensitivity_eq.png -------------------------------------------------------------------------------- 
/website/static/js/code_block_buttons.js: -------------------------------------------------------------------------------- 1 | // Turn off ESLint for this file because it's sent down to users as-is. 2 | /* eslint-disable */ 3 | window.addEventListener('load', function() { 4 | function button(label, ariaLabel, icon, className) { 5 | const btn = document.createElement('button'); 6 | btn.classList.add('btnIcon', className); 7 | btn.setAttribute('type', 'button'); 8 | btn.setAttribute('aria-label', ariaLabel); 9 | btn.innerHTML = 10 | '
' + 11 | icon + 12 | '' + 13 | label + 14 | '' + 15 | '
'; 16 | return btn; 17 | } 18 | 19 | function addButtons(codeBlockSelector, btn) { 20 | document.querySelectorAll(codeBlockSelector).forEach(function(code) { 21 | code.parentNode.appendChild(btn.cloneNode(true)); 22 | }); 23 | } 24 | 25 | const copyIcon = 26 | ''; 27 | 28 | addButtons( 29 | '.hljs', 30 | button('Copy', 'Copy code to clipboard', copyIcon, 'btnClipboard'), 31 | ); 32 | 33 | const clipboard = new ClipboardJS('.btnClipboard', { 34 | target: function(trigger) { 35 | return trigger.parentNode.querySelector('code'); 36 | }, 37 | }); 38 | 39 | clipboard.on('success', function(event) { 40 | event.clearSelection(); 41 | const textEl = event.trigger.querySelector('.btnIcon__label'); 42 | textEl.textContent = 'Copied'; 43 | setTimeout(function() { 44 | textEl.textContent = 'Copy'; 45 | }, 2000); 46 | }); 47 | }); 48 | -------------------------------------------------------------------------------- /website/tutorials.json: -------------------------------------------------------------------------------- 1 | { 2 | "Introduction to Captum": [ 3 | { 4 | "id": "Titanic_Basic_Interpret", 5 | "title": "Getting started with Captum" 6 | } 7 | ], 8 | "Attribution": [ 9 | { 10 | "id": "IMDB_TorchText_Interpret", 11 | "title": "Interpreting text models" 12 | }, 13 | { 14 | "id": "CIFAR_TorchVision_Interpret", 15 | "title": "Interpreting vision with CIFAR" 16 | }, 17 | { 18 | "id": "TorchVision_Interpret", 19 | "title": "Interpreting vision with Pretrained Models" 20 | }, 21 | { 22 | "id": "Resnet_TorchVision_Ablation", 23 | "title": "Feature ablation on images with ResNet" 24 | }, 25 | { 26 | "id": "Multimodal_VQA_Interpret", 27 | "title": "Interpreting multimodal models" 28 | }, 29 | { 30 | "id": "House_Prices_Regression_Interpret", 31 | "title": "Interpreting a regression model of California house prices" 32 | }, 33 | { 34 | "id": "Segmentation_Interpret", 35 | "title": "Interpreting semantic segmentation models" 36 | }, 37 | { 38 | "id": "Distributed_Attribution", 39 | 
"title": "Using Captum with torch.distributed" 40 | }, 41 | { 42 | "id": "DLRM_Tutorial", 43 | "title": "Interpreting Deep Learning Recommender Models" 44 | }, 45 | { 46 | "id": "Image_and_Text_Classification_LIME", 47 | "title": "Interpreting vision and text models with LIME" 48 | }, 49 | { 50 | "id": "Llama2_LLM_Attribution", 51 | "title": "Understanding Llama2 with Captum LLM Attribution" 52 | }, 53 | { 54 | "title": "Interpreting BERT", 55 | "children": [ 56 | { 57 | "id": "Bert_SQUAD_Interpret", 58 | "title": "Interpreting question answering with BERT Part 1" 59 | }, 60 | { 61 | "id": "Bert_SQUAD_Interpret2", 62 | "title": "Interpreting question answering with BERT Part 2" 63 | } 64 | ] 65 | } 66 | ], 67 | "Robustness": [ 68 | { 69 | "id": "CIFAR_Captum_Robustness", 70 | "title": "Applying robustness attacks and metrics to CIFAR model and dataset" 71 | } 72 | ], 73 | "Concept": [ 74 | { 75 | "id": "TCAV_Image", 76 | "title": "TCAV for image classification for googlenet model" 77 | }, 78 | { 79 | "id": "TCAV_NLP", 80 | "title": "TCAV for NLP sentiment analysis model" 81 | } 82 | ], 83 | "Influential Examples": [ 84 | { 85 | "id": "TracInCP_Tutorial", 86 | "title": "Identifying influential examples and mis-labelled examples with TracInCP" 87 | } 88 | ], 89 | "Captum Insight": [ 90 | { 91 | "id": "CIFAR_TorchVision_Captum_Insights", 92 | "title": "Getting started with Captum Insights" 93 | }, 94 | { 95 | "id": "Multimodal_VQA_Captum_Insights", 96 | "title": "Using Captum Insights with multimodal models (VQA)" 97 | } 98 | ] 99 | } 100 | --------------------------------------------------------------------------------